async_winit/
sync.rs

1/*
2
3`async-winit` is free software: you can redistribute it and/or modify it under the terms of one of
4the following licenses:
5
6* GNU Lesser General Public License as published by the Free Software Foundation, either
7  version 3 of the License, or (at your option) any later version.
8* Mozilla Public License as published by the Mozilla Foundation, version 2.
9
10`async-winit` is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
Public License and the Mozilla Public License for more details.
13
14You should have received a copy of the GNU Lesser General Public License and the Mozilla
15Public License along with `async-winit`. If not, see <https://www.gnu.org/licenses/>.
16
17*/
18
19use crate::reactor::Reactor;
20pub(crate) use __private::__ThreadSafety;
21
22use core::cell::{Cell, RefCell, RefMut};
23use core::convert::Infallible;
24use core::future::Future;
25use core::ops::Add;
26
27use std::collections::VecDeque;
28use std::rc::Rc;
29use std::sync::atomic;
30use std::thread;
31
32use unsend::channel as us_channel;
33
34pub(crate) mod prelude {
35    pub use super::__private::{Atomic, Mutex, OnceLock};
36}
37
38#[cfg(feature = "thread_safe")]
39pub use thread_safe::ThreadSafe;
40
41#[cfg(feature = "thread_safe")]
42type _DefaultTS = ThreadSafe;
43#[cfg(not(feature = "thread_safe"))]
44type _DefaultTS = ThreadUnsafe;
45
46/// The default thread safe type to use.
47pub type DefaultThreadSafety = _DefaultTS;
48
49/// A token that can be used to indicate whether the current implementation should be thread-safe or
50/// not.
51pub trait ThreadSafety: __ThreadSafety {}
52
53/// Use thread-unsafe primitives.
54#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
55pub struct ThreadUnsafe {
56    _private: (),
57}
58
59impl ThreadSafety for ThreadUnsafe {}
60
61impl __ThreadSafety for ThreadUnsafe {
62    type Error = Infallible;
63
64    type AtomicUsize = Cell<usize>;
65    type AtomicU64 = Cell<u64>;
66    type AtomicI64 = Cell<i64>;
67
68    type Receiver<T> = us_channel::Receiver<T>;
69    type Sender<T> = us_channel::Sender<T>;
70    type Rc<T> = Rc<T>;
71
72    type ConcurrentQueue<T> = RefCell<VecDeque<T>>;
73    type Mutex<T> = RefCell<T>;
74    type OnceLock<T> = once_cell::unsync::OnceCell<T>;
75
76    fn channel_bounded<T>(_capacity: usize) -> (Self::Sender<T>, Self::Receiver<T>) {
77        us_channel::channel()
78    }
79
80    fn get_reactor() -> Self::Rc<Reactor<Self>> {
81        use once_cell::sync::OnceCell;
82
83        /// The thread ID of the thread that created the reactor.
84        static REACTOR_THREAD_ID: OnceCell<thread::ThreadId> = OnceCell::new();
85
86        std::thread_local! {
87            static REACTOR: RefCell<Option<std::rc::Rc<Reactor<ThreadUnsafe>>>> = RefCell::new(None);
88        }
89
90        // Try to set the thread ID.
91        let thread_id = thread_id();
92        let reactor_thread_id = REACTOR_THREAD_ID.get_or_init(|| thread_id);
93
94        if thread_id != *reactor_thread_id {
95            panic!("The reactor must be created on the main thread");
96        }
97
98        REACTOR.with(|reactor| {
99            reactor
100                .borrow_mut()
101                .get_or_insert_with(|| std::rc::Rc::new(Reactor::new()))
102                .clone()
103        })
104    }
105}
106
/// The guard produced by locking a `TS::Mutex<T>`.
///
/// This is the `Lock<'a>` associated type of the backend's mutex implementation. It is
/// `RefMut<'a, T>` for `ThreadUnsafe` and `std::sync::MutexGuard<'a, T>` for `ThreadSafe`.
pub(crate) type MutexGuard<'a, T, TS> =
    <<TS as __ThreadSafety>::Mutex<T> as __private::Mutex<T>>::Lock<'a>;
109
/// Returns the `ThreadId` of the calling thread.
///
/// After the first lookup, the ID is cached in a thread-local so later calls can skip
/// `thread::current()`. If the thread-local is no longer accessible (e.g. while thread-local
/// destructors are running), the ID is queried directly without caching.
fn thread_id() -> thread::ThreadId {
    std::thread_local! {
        static CACHED_ID: Cell<Option<thread::ThreadId>> = Cell::new(None);
    }

    let lookup = CACHED_ID.try_with(|slot| match slot.get() {
        Some(cached) => cached,
        None => {
            let fresh = thread::current().id();
            slot.set(Some(fresh));
            fresh
        }
    });

    match lookup {
        Ok(id) => id,
        // The thread-local is unavailable (we are inside a destructor): fall back to the uncached query.
        Err(_) => thread::current().id(),
    }
}
129
130impl<T: Copy> __private::Atomic<T> for Cell<T> {
131    fn new(value: T) -> Self {
132        Self::new(value)
133    }
134
135    fn load(&self, _order: atomic::Ordering) -> T {
136        self.get()
137    }
138
139    fn store(&self, value: T, _order: atomic::Ordering) {
140        self.set(value);
141    }
142
143    fn fetch_add(&self, value: T, _order: atomic::Ordering) -> T
144    where
145        T: Add<Output = T>,
146    {
147        let old = self.get();
148        self.set(old + value);
149        old
150    }
151}
152
153impl<T> __private::Sender<T> for us_channel::Sender<T> {
154    type Error = Infallible;
155    type Send<'a> = core::future::Ready<Result<(), Self::Error>> where Self: 'a;
156
157    fn send(&self, value: T) -> Self::Send<'_> {
158        self.send(value).ok();
159        core::future::ready(Ok(()))
160    }
161
162    fn try_send(&self, value: T) -> Result<(), Self::Error> {
163        self.send(value).ok();
164        Ok(())
165    }
166}
167
168impl<T> __private::Receiver<T> for us_channel::Receiver<T> {
169    type Error = ();
170    type Recv<'a> = std::pin::Pin<Box<dyn Future<Output = Result<T, Self::Error>> + 'a>> where Self: 'a;
171
172    fn recv(&self) -> Self::Recv<'_> {
173        Box::pin(async move { self.recv().await.map_err(|_| ()) })
174    }
175
176    fn capacity(&self) -> usize {
177        usize::MAX
178    }
179
180    fn try_recv(&self) -> Option<T> {
181        self.try_recv().ok()
182    }
183
184    fn len(&self) -> usize {
185        todo!()
186    }
187}
188
189impl<T> __private::ConcurrentQueue<T> for RefCell<VecDeque<T>> {
190    type TryIter<'a> = TryIter<'a, T> where Self: 'a;
191
192    fn bounded(capacity: usize) -> Self {
193        Self::new(VecDeque::with_capacity(capacity))
194    }
195
196    fn push(&self, value: T) -> Result<(), T> {
197        self.borrow_mut().push_back(value);
198        Ok(())
199    }
200
201    fn pop(&self) -> Option<T> {
202        self.borrow_mut().pop_front()
203    }
204
205    fn capacity(&self) -> usize {
206        usize::MAX
207    }
208
209    fn try_iter(&self) -> Self::TryIter<'_> {
210        TryIter { queue: self }
211    }
212}
213
/// Draining iterator over the thread-unsafe queue, returned by its `try_iter` method.
///
/// Each call to `next` pops the current front element. The queue is borrowed only for the
/// duration of each pop, so it can still be pushed to or popped from while the iterator is alive.
/// Elements pushed in the meantime are yielded by later calls to `next`.
#[doc(hidden)]
pub struct TryIter<'a, T> {
    queue: &'a RefCell<VecDeque<T>>,
}

impl<'a, T> Iterator for TryIter<'a, T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        self.queue.borrow_mut().pop_front()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The queue is shared: elements may be pushed or popped through it between calls to
        // `next`. Its current length therefore bounds the remaining items in neither direction,
        // so report the unconstrained hint instead of one that can be violated.
        (0, None)
    }
}
231
232impl<T> __private::Mutex<T> for RefCell<T> {
233    type Error = Infallible;
234    type Lock<'a> = RefMut<'a, T> where Self: 'a;
235
236    fn new(value: T) -> Self {
237        Self::new(value)
238    }
239
240    fn lock(&self) -> Result<Self::Lock<'_>, Self::Error> {
241        Ok(self.borrow_mut())
242    }
243}
244
245impl<T> __private::OnceLock<T> for once_cell::unsync::OnceCell<T> {
246    fn new() -> Self {
247        Self::new()
248    }
249
250    fn get(&self) -> Option<&T> {
251        self.get()
252    }
253
254    fn set(&self, value: T) -> Result<(), T> {
255        self.set(value)
256    }
257
258    fn get_or_init<F>(&self, f: F) -> &T
259    where
260        F: FnOnce() -> T,
261    {
262        self.get_or_init(f)
263    }
264}
265
266impl<T> __private::Rc<T> for std::rc::Rc<T> {
267    fn new(value: T) -> Self {
268        Self::new(value)
269    }
270}
271
#[cfg(feature = "thread_safe")]
pub(crate) mod thread_safe {
    use super::*;

    use concurrent_queue::ConcurrentQueue;
    use std::sync::atomic;
    use std::sync::{Arc, Mutex};

    /// Use thread-safe primitives.
    ///
    /// This token selects std atomics, `std::sync::Mutex`, `Arc`, `async_channel` and
    /// `concurrent_queue` as the building blocks.
    #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct ThreadSafe {
        _private: (),
    }

    impl ThreadSafety for ThreadSafe {}

    impl __ThreadSafety for ThreadSafe {
        type Error = Box<dyn std::error::Error + Send + Sync>;

        type AtomicI64 = atomic::AtomicI64;
        type AtomicUsize = atomic::AtomicUsize;
        type AtomicU64 = atomic::AtomicU64;

        type Sender<T> = async_channel::Sender<T>;
        type Receiver<T> = async_channel::Receiver<T>;

        type ConcurrentQueue<T> = ConcurrentQueue<T>;
        type Mutex<T> = Mutex<T>;
        type OnceLock<T> = once_cell::sync::OnceCell<T>;
        type Rc<T> = Arc<T>;

        /// Creates a bounded `async_channel` with room for `capacity` messages.
        ///
        /// NOTE(review): `async_channel::bounded` panics on a zero capacity, whereas the
        /// thread-unsafe backend ignores the capacity entirely.
        fn channel_bounded<T>(capacity: usize) -> (Self::Sender<T>, Self::Receiver<T>) {
            async_channel::bounded(capacity)
        }

        /// Returns the process-wide reactor, creating it on first use.
        fn get_reactor() -> Self::Rc<crate::reactor::Reactor<Self>>
        where
            Self: super::ThreadSafety,
        {
            use once_cell::sync::OnceCell;

            static REACTOR: OnceCell<Arc<Reactor<ThreadSafe>>> = OnceCell::new();

            REACTOR.get_or_init(|| Arc::new(Reactor::new())).clone()
        }
    }

    // The atomic impls forward to the inherent std atomic methods of the same name.

    impl __private::Atomic<i64> for atomic::AtomicI64 {
        fn new(value: i64) -> Self {
            Self::new(value)
        }

        fn fetch_add(&self, value: i64, order: atomic::Ordering) -> i64 {
            self.fetch_add(value, order)
        }

        fn load(&self, order: atomic::Ordering) -> i64 {
            self.load(order)
        }

        fn store(&self, value: i64, order: atomic::Ordering) {
            self.store(value, order)
        }
    }

    impl __private::Atomic<usize> for atomic::AtomicUsize {
        fn new(value: usize) -> Self {
            Self::new(value)
        }

        fn fetch_add(&self, value: usize, order: atomic::Ordering) -> usize {
            self.fetch_add(value, order)
        }

        fn load(&self, order: atomic::Ordering) -> usize {
            self.load(order)
        }

        fn store(&self, value: usize, order: atomic::Ordering) {
            self.store(value, order)
        }
    }

    impl __private::Atomic<u64> for atomic::AtomicU64 {
        fn new(value: u64) -> Self {
            Self::new(value)
        }

        fn fetch_add(&self, value: u64, order: atomic::Ordering) -> u64 {
            self.fetch_add(value, order)
        }

        fn load(&self, order: atomic::Ordering) -> u64 {
            self.load(order)
        }

        fn store(&self, value: u64, order: atomic::Ordering) {
            self.store(value, order)
        }
    }

    impl<T> __private::Sender<T> for async_channel::Sender<T> {
        type Error = async_channel::SendError<T>;
        type Send<'a> = async_channel::Send<'a, T> where Self: 'a;

        fn send(&self, value: T) -> Self::Send<'_> {
            self.send(value)
        }

        /// Attempts to send `value` without waiting.
        ///
        /// Both failure modes of `async_channel::Sender::try_send` (channel full, channel closed)
        /// are reported as a `SendError` that hands the unsent value back to the caller.
        /// Previously this hit `todo!()` and panicked.
        fn try_send(&self, value: T) -> Result<(), Self::Error> {
            self.try_send(value)
                .map_err(|err| async_channel::SendError(err.into_inner()))
        }
    }

    impl<T> __private::Receiver<T> for async_channel::Receiver<T> {
        type Error = async_channel::RecvError;
        type Recv<'a> = async_channel::Recv<'a, T> where Self: 'a;

        fn recv(&self) -> Self::Recv<'_> {
            self.recv()
        }

        /// Returns the channel's capacity, or `usize::MAX` for an unbounded channel.
        ///
        /// This matches the thread-unsafe backend instead of panicking on `None`.
        fn capacity(&self) -> usize {
            self.capacity().unwrap_or(usize::MAX)
        }

        fn try_recv(&self) -> Option<T> {
            self.try_recv().ok()
        }

        fn len(&self) -> usize {
            self.len()
        }
    }

    impl<T> __private::ConcurrentQueue<T> for ConcurrentQueue<T> {
        type TryIter<'a> = concurrent_queue::TryIter<'a, T> where Self: 'a;

        fn bounded(capacity: usize) -> Self {
            Self::bounded(capacity)
        }

        /// Pushes `value`, returning it back if the queue rejects it.
        fn push(&self, value: T) -> Result<(), T> {
            self.push(value).map_err(|e| e.into_inner())
        }

        fn pop(&self) -> Option<T> {
            self.pop().ok()
        }

        /// Returns the queue's capacity, or `usize::MAX` for an unbounded queue.
        ///
        /// This matches the thread-unsafe backend instead of panicking on `None`.
        fn capacity(&self) -> usize {
            self.capacity().unwrap_or(usize::MAX)
        }

        fn try_iter(&self) -> Self::TryIter<'_> {
            self.try_iter()
        }
    }

    impl<T> __private::Mutex<T> for Mutex<T> {
        type Error = Infallible;
        type Lock<'a> = std::sync::MutexGuard<'a, T> where Self: 'a;

        fn new(value: T) -> Self {
            Self::new(value)
        }

        /// Locks the mutex. A poisoned mutex is recovered rather than reported as an error.
        fn lock(&self) -> Result<Self::Lock<'_>, Self::Error> {
            Ok(self.lock().unwrap_or_else(|e| e.into_inner()))
        }
    }

    impl<T> __private::OnceLock<T> for once_cell::sync::OnceCell<T> {
        fn new() -> Self {
            Self::new()
        }

        fn get(&self) -> Option<&T> {
            self.get()
        }

        fn set(&self, value: T) -> Result<(), T> {
            self.set(value)
        }

        fn get_or_init<F>(&self, f: F) -> &T
        where
            F: FnOnce() -> T,
        {
            self.get_or_init(f)
        }
    }

    impl<T> __private::Rc<T> for Arc<T> {
        fn new(value: T) -> Self {
            Self::new(value)
        }
    }
}
470
471pub(crate) mod __private {
472    use core::fmt::{Debug, Display};
473    use core::future::Future;
474    use core::ops::{Add, Deref, DerefMut};
475    use core::sync::atomic;
476
477    #[doc(hidden)]
478    pub trait __ThreadSafety: Sized {
479        type Error: Display + Debug;
480
481        type AtomicI64: Atomic<i64>;
482        type AtomicUsize: Atomic<usize>;
483        type AtomicU64: Atomic<u64>;
484
485        type Sender<T>: Sender<T>;
486        type Receiver<T>: Receiver<T>;
487
488        type ConcurrentQueue<T>: ConcurrentQueue<T>;
489        type Mutex<T>: Mutex<T>;
490        type OnceLock<T>: OnceLock<T>;
491        type Rc<T>: Rc<T>;
492
493        fn channel_bounded<T>(capacity: usize) -> (Self::Sender<T>, Self::Receiver<T>);
494        fn get_reactor() -> Self::Rc<crate::reactor::Reactor<Self>>
495        where
496            Self: super::ThreadSafety;
497    }
498
499    #[doc(hidden)]
500    pub trait Atomic<T> {
501        fn new(value: T) -> Self;
502        fn load(&self, order: atomic::Ordering) -> T;
503        fn store(&self, value: T, order: atomic::Ordering);
504        fn fetch_add(&self, value: T, order: atomic::Ordering) -> T
505        where
506            T: Add<Output = T>;
507    }
508
509    #[doc(hidden)]
510    pub trait Sender<T> {
511        type Error;
512        type Send<'a>: Future<Output = Result<(), Self::Error>> + 'a
513        where
514            Self: 'a;
515        fn send(&self, value: T) -> Self::Send<'_>;
516        fn try_send(&self, value: T) -> Result<(), Self::Error>;
517    }
518
519    #[doc(hidden)]
520    pub trait Receiver<T> {
521        type Error: std::fmt::Debug;
522        type Recv<'a>: Future<Output = Result<T, Self::Error>> + 'a
523        where
524            Self: 'a;
525
526        fn recv(&self) -> Self::Recv<'_>;
527        fn capacity(&self) -> usize;
528        fn try_recv(&self) -> Option<T>;
529        fn len(&self) -> usize;
530    }
531
532    #[doc(hidden)]
533    pub trait OnceLock<T> {
534        fn new() -> Self;
535        fn get(&self) -> Option<&T>;
536        fn get_or_init<F>(&self, f: F) -> &T
537        where
538            F: FnOnce() -> T;
539        fn set(&self, value: T) -> Result<(), T>;
540    }
541
542    #[doc(hidden)]
543    pub trait Mutex<T> {
544        type Error: Debug + Display;
545        type Lock<'a>: DerefMut<Target = T> + 'a
546        where
547            Self: 'a;
548
549        fn new(value: T) -> Self;
550        fn lock(&self) -> Result<Self::Lock<'_>, Self::Error>;
551    }
552
553    #[doc(hidden)]
554    pub trait ConcurrentQueue<T> {
555        type TryIter<'a>: Iterator<Item = T> + 'a
556        where
557            Self: 'a;
558
559        fn bounded(capacity: usize) -> Self;
560        fn push(&self, value: T) -> Result<(), T>;
561        fn pop(&self) -> Option<T>;
562        fn capacity(&self) -> usize;
563        fn try_iter(&self) -> Self::TryIter<'_>;
564    }
565
566    #[doc(hidden)]
567    pub trait Rc<T>: Clone + Deref<Target = T> {
568        fn new(value: T) -> Self;
569    }
570}