use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
#[cfg(not(target_arch = "wasm32"))]
use tokio::time::sleep as tokio_sleep;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::JsCast;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::JsFuture;
#[cfg(target_arch = "wasm32")]
use web_sys::window;
/// Suspends the calling task for `duration`.
///
/// * Native targets: awaits `tokio::time::sleep`.
/// * `wasm32`: awaits a JavaScript promise that is resolved from a
///   `setTimeout` callback registered on the global `window`.
///
/// The `wasm32` timer API takes its delay as an `i32` number of milliseconds,
/// so any duration longer than `i32::MAX` ms is clamped to `i32::MAX` ms
/// instead of wrapping around to a negative / tiny delay.
///
/// # Panics
///
/// On `wasm32`, panics if there is no global `window`, if the timeout cannot
/// be registered, or if the promise is rejected.
pub async fn sleep(duration: Duration) {
    #[cfg(not(target_arch = "wasm32"))]
    {
        tokio_sleep(duration).await;
    }
    #[cfg(target_arch = "wasm32")]
    {
        // Clamp before narrowing: a plain `as i32` cast wraps for large
        // durations, which would make the timer fire (almost) immediately.
        let millis = duration.as_millis().min(i32::MAX as u128) as i32;
        let promise = js_sys::Promise::new(&mut |resolve, _reject| {
            let window = window().expect("no global `window` exists");
            let closure = wasm_bindgen::closure::Closure::once(move || {
                resolve.call0(&wasm_bindgen::JsValue::NULL).unwrap();
            });
            window
                .set_timeout_with_callback_and_timeout_and_arguments_0(
                    closure.as_ref().unchecked_ref(),
                    millis,
                )
                .expect("failed to set timeout");
            // Hand ownership of the callback to the JS side so it is still
            // alive when the timer fires.
            closure.forget();
        });
        JsFuture::from(promise).await.unwrap();
    }
}
/// Starts `future` as a detached background task.
///
/// Native builds hand the future to `tokio::spawn` and discard the returned
/// handle; `wasm32` builds hand it to `wasm_bindgen_futures::spawn_local`.
/// No handle is returned, so the task cannot be awaited from here.
pub fn spawn<F>(future: F)
where
    F: Future<Output = ()> + Send + 'static,
{
    // Dropping the tokio handle detaches the task without cancelling it.
    #[cfg(not(target_arch = "wasm32"))]
    drop(tokio::spawn(future));

    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_futures::spawn_local(future);
}
/// Runs the closure `f` and returns a [`JoinHandle`] that yields its result.
///
/// * Native targets: `f` is passed to `tokio::task::spawn_blocking` and the
///   resulting tokio handle is wrapped.
/// * `wasm32`: `f` is executed right here, synchronously, before this
///   function returns; the handle merely carries the finished value.
pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    #[cfg(not(target_arch = "wasm32"))]
    return JoinHandle::Native(tokio::task::spawn_blocking(f));

    #[cfg(target_arch = "wasm32")]
    return JoinHandle::Wasm(Some(f()));
}
/// Awaitable handle to the result of work started with `spawn_blocking`.
///
/// Each build contains exactly one variant, selected by `target_arch`.
/// Awaiting the handle produces `Result<T, JoinError>`.
#[derive(Debug)]
pub enum JoinHandle<T> {
/// Native targets: wraps the handle returned by tokio.
#[cfg(not(target_arch = "wasm32"))]
Native(tokio::task::JoinHandle<T>),
/// `wasm32`: holds the already-computed value; becomes `None` once the
/// value has been taken out by polling.
#[cfg(target_arch = "wasm32")]
Wasm(Option<T>),
}
/// Awaiting a [`JoinHandle`] yields the task's value or a [`JoinError`].
impl<T: Unpin> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    /// Native: forwards to the wrapped tokio handle and converts a tokio join
    /// failure into a [`JoinError`] carrying its message.
    ///
    /// `wasm32`: completes immediately with the stored value; polling again
    /// after the value was taken yields an "Already consumed" error.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.get_mut() {
            #[cfg(not(target_arch = "wasm32"))]
            Self::Native(inner) => match Pin::new(inner).poll(cx) {
                Poll::Ready(Ok(value)) => Poll::Ready(Ok(value)),
                Poll::Ready(Err(err)) => Poll::Ready(Err(JoinError(err.to_string()))),
                Poll::Pending => Poll::Pending,
            },
            #[cfg(target_arch = "wasm32")]
            Self::Wasm(slot) => {
                // The context is not needed: the result is always available.
                let _ = cx;
                let outcome = match slot.take() {
                    Some(value) => Ok(value),
                    None => Err(JoinError("Already consumed".to_string())),
                };
                Poll::Ready(outcome)
            }
        }
    }
}
/// Error produced when awaiting a `JoinHandle` fails.
///
/// Wraps a human-readable description of the failure.
#[derive(Debug)]
pub struct JoinError(String);

impl std::fmt::Display for JoinError {
    /// Renders the error as `Join error: <description>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(reason) = self;
        f.write_fmt(format_args!("Join error: {}", reason))
    }
}

impl std::error::Error for JoinError {}
/// Returns the current wall-clock time in whole milliseconds since the Unix
/// epoch.
///
/// Native targets read `SystemTime::now()`; `wasm32` reads JavaScript's
/// `Date.now()`.
///
/// # Panics
///
/// On native targets, panics if the system clock reports a time earlier than
/// the Unix epoch.
pub fn timestamp_millis() -> u64 {
    #[cfg(not(target_arch = "wasm32"))]
    {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards");
        since_epoch.as_millis() as u64
    }
    #[cfg(target_arch = "wasm32")]
    {
        let now_ms = js_sys::Date::now();
        now_ms as u64
    }
}
// Lock types re-exported under one name per target: tokio's on native
// builds, the standard library's on `wasm32`.
//
// NOTE(review): the tokio and std lock types do not appear to expose the same
// locking API (tokio's `lock`/`read`/`write` are awaited, std's return a
// `LockResult`), so callers likely need target-specific code — confirm this
// is intended.
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::sync::Mutex;
#[cfg(target_arch = "wasm32")]
pub use std::sync::Mutex;
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::sync::RwLock;
#[cfg(target_arch = "wasm32")]
pub use std::sync::RwLock;
/// Channel primitives for native targets.
///
/// Re-exports every public item of `tokio::sync::mpsc` unchanged.
#[cfg(not(target_arch = "wasm32"))]
pub mod channel {
pub use tokio::sync::mpsc::*;
}
#[cfg(target_arch = "wasm32")]
/// Minimal in-process channel built only on the standard library.
///
/// It offers `channel`, `Sender::send`, `Receiver::recv` and `SendError`:
///
/// * `recv` waits until a value is available and returns `None` only once
///   every `Sender` has been dropped and the queue is empty.
/// * `send` fails with `SendError` (handing the value back) once the
///   `Receiver` has been dropped.
/// * `Sender` is `Clone`, so several producers can feed one receiver.
///
/// The `buffer` argument is used only as the initial queue capacity; `send`
/// never waits for free space.
pub mod channel {
    use std::collections::VecDeque;
    use std::fmt;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
    use std::task::{Context, Poll, Waker};

    /// Creates a connected sender/receiver pair.
    ///
    /// `buffer` pre-allocates room for that many queued values; it is not a
    /// hard limit.
    pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) {
        let shared = Arc::new(Mutex::new(ChannelState {
            queue: VecDeque::with_capacity(buffer),
            closed: false,
            senders: 1,
            waker: None,
        }));
        (
            Sender {
                shared: Arc::clone(&shared),
            },
            Receiver { shared },
        )
    }

    /// State shared by all senders and the receiver.
    struct ChannelState<T> {
        /// Values sent but not yet received, oldest first.
        queue: VecDeque<T>,
        /// `true` once the receiver has been dropped.
        closed: bool,
        /// Number of `Sender` handles currently alive.
        senders: usize,
        /// Waker of a receiver that is waiting for a value, if any.
        waker: Option<Waker>,
    }

    /// Locks the shared state, recovering from poisoning.
    ///
    /// Every critical section leaves the state consistent, so a poisoned lock
    /// is still safe to use; recovering also keeps the `Drop` impls from
    /// panicking during an unwind.
    fn lock_state<T>(shared: &Mutex<ChannelState<T>>) -> MutexGuard<'_, ChannelState<T>> {
        shared.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sending half of the channel. Clone it to obtain additional producers.
    pub struct Sender<T> {
        shared: Arc<Mutex<ChannelState<T>>>,
    }

    impl<T> Sender<T> {
        /// Queues `value` for the receiver and wakes it if it is waiting.
        ///
        /// # Errors
        ///
        /// Returns `SendError` containing `value` if the receiver has been
        /// dropped.
        pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
            let receiver_waker = {
                let mut state = lock_state(&self.shared);
                if state.closed {
                    return Err(SendError(value));
                }
                state.queue.push_back(value);
                state.waker.take()
            };
            // Wake only after the lock is released so the woken task can take
            // it without contention.
            if let Some(waker) = receiver_waker {
                waker.wake();
            }
            Ok(())
        }
    }

    impl<T> Clone for Sender<T> {
        fn clone(&self) -> Self {
            lock_state(&self.shared).senders += 1;
            Sender {
                shared: Arc::clone(&self.shared),
            }
        }
    }

    impl<T> Drop for Sender<T> {
        fn drop(&mut self) {
            let receiver_waker = {
                let mut state = lock_state(&self.shared);
                state.senders -= 1;
                if state.senders == 0 {
                    // Last producer gone: a waiting receiver must wake up to
                    // observe the end of the stream.
                    state.waker.take()
                } else {
                    None
                }
            };
            if let Some(waker) = receiver_waker {
                waker.wake();
            }
        }
    }

    /// Receiving half of the channel.
    pub struct Receiver<T> {
        shared: Arc<Mutex<ChannelState<T>>>,
    }

    impl<T> Receiver<T> {
        /// Waits for the next value.
        ///
        /// Returns `Some(value)` for each queued value in FIFO order, and
        /// `None` once the queue is empty and every `Sender` has been dropped.
        pub async fn recv(&mut self) -> Option<T> {
            Recv {
                shared: &self.shared,
            }
            .await
        }
    }

    impl<T> Drop for Receiver<T> {
        fn drop(&mut self) {
            // Take the waker out under the lock but let it drop afterwards.
            let _stale_waker = {
                let mut state = lock_state(&self.shared);
                state.closed = true;
                state.waker.take()
            };
        }
    }

    /// Future behind `Receiver::recv`.
    struct Recv<'a, T> {
        shared: &'a Mutex<ChannelState<T>>,
    }

    impl<'a, T> Future for Recv<'a, T> {
        type Output = Option<T>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let mut state = lock_state(self.shared);
            if let Some(value) = state.queue.pop_front() {
                return Poll::Ready(Some(value));
            }
            if state.senders == 0 {
                return Poll::Ready(None);
            }
            // Nothing queued yet: remember who to wake on the next send or
            // when the last sender goes away.
            state.waker = Some(cx.waker().clone());
            Poll::Pending
        }
    }

    /// Error returned by `Sender::send` when the receiver is gone; the value
    /// that could not be delivered is handed back in field `0`.
    pub struct SendError<T>(pub T);

    // Implemented by hand so that no `T: Debug` bound is needed for
    // `Result::unwrap` / `expect` on the result of `send`.
    impl<T> fmt::Debug for SendError<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("SendError(..)")
        }
    }

    impl<T> fmt::Display for SendError<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("channel closed")
        }
    }

    impl<T> std::error::Error for SendError<T> {}
}
// Unit tests for the helpers in this file. The two async tests are compiled
// only on non-wasm targets.
#[cfg(test)]
mod tests {
use super::*;
// Checks that `sleep` waits roughly the requested time, measured with
// `timestamp_millis`.
// NOTE(review): the measurement uses wall-clock time and a fixed 200 ms upper
// bound, so this may be flaky on a heavily loaded machine — confirm.
#[cfg(not(target_arch = "wasm32"))]
#[tokio::test]
async fn test_sleep() {
let start = timestamp_millis();
sleep(Duration::from_millis(100)).await;
let elapsed = timestamp_millis() - start;
assert!((100..200).contains(&elapsed));
}
// Checks that a task started with `spawn` runs: it sends a value through a
// channel that the test then receives.
#[cfg(not(target_arch = "wasm32"))]
#[tokio::test]
async fn test_spawn() {
let (tx, mut rx) = channel::channel(1);
spawn(async move {
tx.send(42).await.unwrap();
});
assert_eq!(rx.recv().await, Some(42));
}
// Checks that `timestamp_millis` advances across a 10 ms thread sleep.
// NOTE(review): unlike the tests above this one is not gated on the target
// and blocks with `std::thread::sleep` — confirm it is meant to run on wasm32.
#[test]
fn test_timestamp() {
let ts1 = timestamp_millis();
std::thread::sleep(Duration::from_millis(10));
let ts2 = timestamp_millis();
assert!(ts2 > ts1);
}
}