1use std::future::Future;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11#[cfg(not(target_arch = "wasm32"))]
12use tokio::time::sleep as tokio_sleep;
13
14#[cfg(target_arch = "wasm32")]
15use wasm_bindgen::JsCast;
16#[cfg(target_arch = "wasm32")]
17use wasm_bindgen_futures::JsFuture;
18#[cfg(target_arch = "wasm32")]
19use web_sys::window;
20
21pub async fn sleep(duration: Duration) {
35 #[cfg(not(target_arch = "wasm32"))]
36 {
37 tokio_sleep(duration).await;
38 }
39
40 #[cfg(target_arch = "wasm32")]
41 {
42 let millis = duration.as_millis() as i32;
43 let promise = js_sys::Promise::new(&mut |resolve, _reject| {
44 let window = window().expect("no global `window` exists");
45 let closure = wasm_bindgen::closure::Closure::once(move || {
46 resolve.call0(&wasm_bindgen::JsValue::NULL).unwrap();
47 });
48 window
49 .set_timeout_with_callback_and_timeout_and_arguments_0(
50 closure.as_ref().unchecked_ref(),
51 millis,
52 )
53 .expect("failed to set timeout");
54 closure.forget();
55 });
56 JsFuture::from(promise).await.unwrap();
57 }
58}
59
60pub fn spawn<F>(future: F)
76where
77 F: Future<Output = ()> + Send + 'static,
78{
79 #[cfg(not(target_arch = "wasm32"))]
80 {
81 tokio::spawn(future);
82 }
83
84 #[cfg(target_arch = "wasm32")]
85 {
86 wasm_bindgen_futures::spawn_local(future);
87 }
88}
89
90pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
95where
96 F: FnOnce() -> R + Send + 'static,
97 R: Send + 'static,
98{
99 #[cfg(not(target_arch = "wasm32"))]
100 {
101 let handle = tokio::task::spawn_blocking(f);
102 JoinHandle::Native(handle)
103 }
104
105 #[cfg(target_arch = "wasm32")]
106 {
107 let result = f();
109 JoinHandle::Wasm(Some(result))
110 }
111}
112
113#[derive(Debug)]
115pub enum JoinHandle<T> {
116 #[cfg(not(target_arch = "wasm32"))]
118 Native(tokio::task::JoinHandle<T>),
119 #[cfg(target_arch = "wasm32")]
121 Wasm(Option<T>),
122}
123
124impl<T: Unpin> Future for JoinHandle<T> {
125 type Output = Result<T, JoinError>;
126
127 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
128 let this = self.get_mut();
129 match this {
130 #[cfg(not(target_arch = "wasm32"))]
131 Self::Native(handle) => Pin::new(handle)
132 .poll(cx)
133 .map_err(|e| JoinError(e.to_string())),
134 #[cfg(target_arch = "wasm32")]
135 Self::Wasm(result) => {
136 let _ = cx; Poll::Ready(
138 result
139 .take()
140 .ok_or_else(|| JoinError("Already consumed".to_string())),
141 )
142 },
143 }
144 }
145}
146
/// Error produced when awaiting a `JoinHandle` fails.
///
/// It holds a human-readable description: the Tokio join error's message on
/// native targets, or `"Already consumed"` when a wasm handle is polled again
/// after it has already yielded its value.
#[derive(Debug)]
pub struct JoinError(String);

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let detail: &str = &self.0;
        f.write_str("Join error: ")?;
        f.write_str(detail)
    }
}

impl std::error::Error for JoinError {}
158
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// On native targets this reads `SystemTime::now()` and panics if the system
/// clock reports a time before the epoch. On `wasm32` it uses `Date.now()`.
pub fn timestamp_millis() -> u64 {
    #[cfg(not(target_arch = "wasm32"))]
    {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards");
        since_epoch.as_millis() as u64
    }

    #[cfg(target_arch = "wasm32")]
    {
        let now_ms: f64 = js_sys::Date::now();
        now_ms as u64
    }
}
184
185#[cfg(not(target_arch = "wasm32"))]
190pub use tokio::sync::Mutex;
191
192#[cfg(target_arch = "wasm32")]
193pub use std::sync::Mutex;
194
195#[cfg(not(target_arch = "wasm32"))]
197pub use tokio::sync::RwLock;
198
199#[cfg(target_arch = "wasm32")]
200pub use std::sync::RwLock;
201
/// Multi-producer, single-consumer channel.
///
/// On native targets this module is a glob re-export of `tokio::sync::mpsc`.
#[cfg(not(target_arch = "wasm32"))]
pub mod channel {
    pub use tokio::sync::mpsc::*;
}
207
/// Multi-producer, single-consumer channel for `wasm32`.
///
/// Provides a [`channel`] constructor, cloneable [`Sender`]s, and a
/// [`Receiver`] whose `recv` waits for the next value. `recv` yields `None`
/// once every sender has been dropped and the queue has been drained.
///
/// NOTE(review): `buffer` is only an initial capacity hint here. `send` never
/// waits for space, so unlike a bounded `tokio::sync::mpsc` channel the queue
/// is effectively unbounded.
#[cfg(target_arch = "wasm32")]
pub mod channel {
    use std::collections::VecDeque;
    use std::fmt;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
    use std::task::{Context, Poll, Waker};

    /// Creates a channel and returns its sending and receiving halves.
    ///
    /// `buffer` pre-sizes the internal queue; it does not limit how many
    /// values may be queued.
    pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) {
        let shared = Arc::new(Mutex::new(ChannelState {
            queue: VecDeque::with_capacity(buffer),
            closed: false,
            senders: 1,
            waker: None,
        }));

        (
            Sender {
                shared: Arc::clone(&shared),
            },
            Receiver { shared },
        )
    }

    /// State shared by all handles of one channel.
    struct ChannelState<T> {
        /// Values sent but not yet received, in FIFO order.
        queue: VecDeque<T>,
        /// Set once the receiver is dropped; later sends fail.
        closed: bool,
        /// Number of live `Sender` handles. `recv` ends when this reaches 0.
        senders: usize,
        /// Waker of a `recv` call that is waiting for a value.
        waker: Option<Waker>,
    }

    /// Locks the shared state, recovering from poisoning.
    ///
    /// Each critical section below performs a single queue or counter update,
    /// so a poisoned lock still guards consistent state. Unwrapping here would
    /// panic inside `Drop`, where a second panic aborts.
    fn lock<T>(shared: &Mutex<ChannelState<T>>) -> MutexGuard<'_, ChannelState<T>> {
        shared.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sending half of the channel. Clone it to add more producers.
    pub struct Sender<T> {
        shared: Arc<Mutex<ChannelState<T>>>,
    }

    impl<T> Sender<T> {
        /// Queues `value` for the receiver and wakes it if it is waiting.
        ///
        /// # Errors
        ///
        /// Returns the value back inside a [`SendError`] if the receiver has
        /// been dropped.
        pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
            let waker = {
                let mut state = lock(&self.shared);
                if state.closed {
                    return Err(SendError(value));
                }
                state.queue.push_back(value);
                state.waker.take()
            };
            // Wake after releasing the lock so a waker that polls synchronously
            // cannot deadlock on it.
            if let Some(waker) = waker {
                waker.wake();
            }
            Ok(())
        }
    }

    impl<T> Clone for Sender<T> {
        fn clone(&self) -> Self {
            lock(&self.shared).senders += 1;
            Sender {
                shared: Arc::clone(&self.shared),
            }
        }
    }

    impl<T> Drop for Sender<T> {
        fn drop(&mut self) {
            let waker = {
                let mut state = lock(&self.shared);
                state.senders -= 1;
                if state.senders == 0 {
                    state.waker.take()
                } else {
                    None
                }
            };
            // When the last sender goes away, a pending `recv` must be woken
            // so it can observe end-of-stream.
            if let Some(waker) = waker {
                waker.wake();
            }
        }
    }

    /// Receiving half of the channel.
    pub struct Receiver<T> {
        shared: Arc<Mutex<ChannelState<T>>>,
    }

    impl<T> Receiver<T> {
        /// Waits for the next value.
        ///
        /// Returns `None` once all senders have been dropped and every queued
        /// value has been received.
        pub async fn recv(&mut self) -> Option<T> {
            Recv {
                shared: &*self.shared,
            }
            .await
        }
    }

    impl<T> Drop for Receiver<T> {
        fn drop(&mut self) {
            let undelivered = {
                let mut state = lock(&self.shared);
                state.closed = true;
                state.waker = None;
                std::mem::take(&mut state.queue)
            };
            // Drop leftover values only after the lock is released, in case
            // their destructors touch this same channel.
            drop(undelivered);
        }
    }

    /// Future backing [`Receiver::recv`].
    struct Recv<'a, T> {
        shared: &'a Mutex<ChannelState<T>>,
    }

    impl<T> Future for Recv<'_, T> {
        type Output = Option<T>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
            let mut state = lock(self.shared);
            if let Some(value) = state.queue.pop_front() {
                return Poll::Ready(Some(value));
            }
            if state.senders == 0 {
                return Poll::Ready(None);
            }
            // Nothing queued yet. Park until a `send` or the last sender's
            // drop wakes us.
            state.waker = Some(cx.waker().clone());
            Poll::Pending
        }
    }

    /// Error returned by [`Sender::send`] when the receiver is gone. It carries
    /// the value that could not be delivered.
    pub struct SendError<T>(pub T);

    // These impls are hand-written without a `T: Debug` bound, so callers can
    // `.unwrap()` a send result on wasm just as they can on native targets.
    impl<T> fmt::Debug for SendError<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("SendError(..)")
        }
    }

    impl<T> fmt::Display for SendError<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("channel closed")
        }
    }

    impl<T> std::error::Error for SendError<T> {}
}
268
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_sleep() {
        use std::time::Instant;

        // Use the monotonic clock. A loaded CI machine can oversleep well past
        // the requested duration, so the upper bound is deliberately generous.
        let start = Instant::now();
        sleep(Duration::from_millis(100)).await;
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(100), "woke early: {:?}", elapsed);
        assert!(elapsed < Duration::from_secs(5), "overslept: {:?}", elapsed);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_spawn() {
        let (tx, mut rx) = channel::channel(1);
        spawn(async move {
            tx.send(42).await.unwrap();
        });
        assert_eq!(rx.recv().await, Some(42));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_channel_ends_when_senders_dropped() {
        let (tx, mut rx) = channel::channel(4);
        tx.send(1).await.unwrap();
        drop(tx);
        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(rx.recv().await, None);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_spawn_blocking() {
        let value = spawn_blocking(|| 6 * 7).await.unwrap();
        assert_eq!(value, 42);
    }

    #[test]
    fn test_join_error_display() {
        let err = JoinError("boom".to_string());
        assert_eq!(err.to_string(), "Join error: boom");
    }

    // `std::thread::sleep` is not available on browser wasm targets.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_timestamp() {
        let ts1 = timestamp_millis();
        std::thread::sleep(Duration::from_millis(10));
        let ts2 = timestamp_millis();
        assert!(ts2 > ts1);
    }
}
299}