hybridfutex/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(missing_docs, missing_debug_implementations)]
3
4use std::{
5    future::Future,
6    process::abort,
7    sync::atomic::{AtomicBool, AtomicIsize, Ordering},
8    task::{Poll, Waker},
9    thread::{self, current, Thread},
10};
11
12use crossbeam_queue::SegQueue;
13
14/// A HybridFutex is a synchronization primitive that allows threads to wait for
15/// a notification from another thread. The HybridFutex maintains a counter that
16/// represents the number of waiters, and a queue of waiters. The counter is
17/// incremented when a thread calls `wait_sync` or `wait_async` methods, and
18/// decremented when a thread calls `notify_one` or `notify_many` methods.
19/// A thread calling `wait_sync` or `wait_async` is blocked until it is notified
20/// by another thread calling `notify_one` or `notify_many`.
21///
22/// # Examples
23///
24/// ```
25/// use std::sync::Arc;
26/// use std::thread;
27/// use std::time::Duration;
28/// use hybridfutex::HybridFutex;
29///
30/// let wait_queue = Arc::new(HybridFutex::new());
31/// let wait_queue_clone = wait_queue.clone();
32///
33/// // Spawn a thread that waits for a notification from another thread
34/// let handle = thread::spawn(move || {
35///     println!("Thread 1 is waiting");
36///     wait_queue_clone.wait_sync();
37///     println!("Thread 1 is notified");
38/// });
39///
40/// // Wait for a short time before notifying the other thread
41/// thread::sleep(Duration::from_millis(100));
42///
43/// // Notify the other thread
44/// wait_queue.notify_one();
45///
46/// // Wait for the other thread to finish
47/// handle.join().unwrap();
48/// ```
49#[derive(Debug)]
50pub struct HybridFutex {
51    counter: AtomicIsize,
52    queue: SegQueue<Waiter>,
53}
54
55impl Default for HybridFutex {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl HybridFutex {
62    /// Creates a new HybridFutex with an initial counter of 0 and an empty
63    /// queue of waiters.
64    ///
65    /// # Examples
66    ///
67    /// ```
68    /// use hybridfutex::HybridFutex;
69    ///
70    /// let wait_queue = HybridFutex::new();
71    /// ```
72    pub fn new() -> Self {
73        Self {
74            counter: AtomicIsize::new(0),
75            queue: SegQueue::new(),
76        }
77    }
78    /// Returns the current value of the counter of this HybridFutex.
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// use hybridfutex::HybridFutex;
84    ///
85    /// let wait_queue = HybridFutex::new();
86    ///
87    /// assert_eq!(wait_queue.get_counter(), 0);
88    /// ```
89    pub fn get_counter(&self) -> isize {
90        self.counter.load(Ordering::Relaxed)
91    }
92    /// Blocks the current thread until it is notified by another thread using
93    /// the `notify_one` or `notify_many` method. The method increments the
94    /// counter of the HybridFutex to indicate that the current thread is
95    /// waiting. If the counter is already negative, the method does not
96    /// block the thread and immediately returns.
97    ///
98    /// # Examples
99    ///
100    /// ```
101    /// use std::sync::Arc;
102    /// use std::thread;
103    /// use std::time::Duration;
104    /// use hybridfutex::HybridFutex;
105    ///
106    /// let wait_queue = Arc::new(HybridFutex::new());
107    /// let wait_queue_clone = wait_queue.clone();
108    ///
109    /// // Spawn a thread that waits for a notification from another thread
110    /// let handle = thread::spawn(move || {
111    ///     println!("Thread 1 is waiting");
112    ///     wait_queue_clone.wait_sync();
113    ///     println!("Thread 1 is notified");
114    /// });
115    ///
116    /// // Wait for a short time before notifying the other thread
117    /// thread::sleep(Duration::from_millis(100));
118    ///
119    /// // Notify the other thread
120    /// wait_queue.notify_one();
121    ///
122    /// // Wait for the other thread to finish
123    /// handle.join().unwrap();
124    /// ```
125    pub fn wait_sync(&self) {
126        let old_counter = self.counter.fetch_add(1, Ordering::SeqCst);
127        if old_counter >= 0 {
128            let awaken = AtomicBool::new(false);
129            self.queue.push(Waiter::Sync(SyncWaiter {
130                awaken: &awaken,
131                thread: current(),
132            }));
133            while {
134                thread::park();
135                !awaken.load(Ordering::Acquire)
136            } {}
137        }
138    }
139    /// Returns a `WaitFuture` that represents a future that resolves when the
140    /// current thread is notified by another thread using the `notify_one` or
141    /// `notify_many` method. The method increments the counter of the
142    /// HybridFutex to indicate that the current thread is waiting.
143    /// If the counter is already negative, the future immediately resolves.
144    ///
145    /// # Examples
146    ///
147    /// ```
148    /// use std::sync::Arc;
149    /// use std::thread;
150    /// use std::time::Duration;
151    /// use hybridfutex::HybridFutex;
152    /// use futures::executor::block_on;
153    ///
154    /// let wait_queue = Arc::new(HybridFutex::new());
155    ///
156    /// // Spawn a thread that waits for a notification from another thread
157    /// let wqc = wait_queue.clone();
158    /// let handle = thread::spawn(move || {
159    ///     let fut = wqc.wait_async();
160    ///     let _ = block_on(fut);
161    ///     println!("Thread 1 is notified");
162    /// });
163    ///
164    /// // Wait for a short time before notifying the other thread
165    /// thread::sleep(Duration::from_millis(100));
166    ///
167    /// // Notify the other thread
168    /// wait_queue.notify_one();
169    ///
170    /// // Wait for the other thread to finish
171    /// handle.join().unwrap();
172    /// ```
173    pub fn wait_async(&self) -> WaitFuture {
174        WaitFuture {
175            state: 0.into(),
176            wq: self,
177        }
178    }
179
180    /// Notifies one waiting thread that is waiting on this HybridFutex using
181    /// the `wait_sync` or `wait_async` method. If there is no current waiting
182    /// threads, this function call indirectly notifies future call to
183    /// `wait_sync` or `wait_async` using the internal counter.
184    ///
185    /// # Examples
186    ///
187    /// ```
188    /// use std::sync::Arc;
189    /// use std::thread;
190    /// use std::time::Duration;
191    /// use hybridfutex::HybridFutex;
192    ///
193    /// let wait_queue = Arc::new(HybridFutex::new());
194    /// let wait_queue_clone = wait_queue.clone();
195    ///
196    /// // Spawn a thread that waits for a notification from another thread
197    /// let handle = thread::spawn(move || {
198    ///     println!("Thread 1 is waiting");
199    ///     wait_queue_clone.wait_sync();
200    ///     println!("Thread 1 is notified");
201    /// });
202    ///
203    /// // Wait for a short time before notifying the other thread
204    /// thread::sleep(Duration::from_millis(100));
205    ///
206    /// // Notify the other thread
207    /// wait_queue.notify_one();
208    ///
209    /// // Wait for the other thread to finish
210    /// handle.join().unwrap();
211    /// ```
212    pub fn notify_one(&self) {
213        let old_counter = self.counter.fetch_sub(1, Ordering::SeqCst);
214        if old_counter > 0 {
215            loop {
216                if let Some(waker) = self.queue.pop() {
217                    waker.wake();
218                    break;
219                }
220            }
221        }
222    }
223
224    /// Notifies a specified number of waiting threads that are waiting on this
225    /// HybridFutex using the `wait_sync` or `wait_async` method. If there are
226    /// less waiting threads than provided count, it indirectly notifies
227    /// futures calls to to `wait_sync` and `wait_async` using the internal
228    /// counter.
229    ///
230    /// # Examples
231    ///
232    /// ```
233    /// use std::sync::Arc;
234    /// use std::thread;
235    /// use std::time::Duration;
236    /// use hybridfutex::HybridFutex;
237    ///
238    /// let wait_queue = Arc::new(HybridFutex::new());
239    ///
240    /// // Spawn multiple threads that wait for a notification from another thread
241    /// let handles: Vec<_> = (0..3).map(|i| {
242    ///     let wait_queue_clone = wait_queue.clone();
243    ///     thread::spawn(move || {
244    ///         println!("Thread {} is waiting", i);
245    ///         wait_queue_clone.wait_sync();
246    ///         println!("Thread {} is notified", i);
247    ///     })
248    /// }).collect();
249    ///
250    /// // Wait for a short time before notifying the threads
251    /// thread::sleep(Duration::from_millis(100));
252    ///
253    /// // Notify two threads
254    /// wait_queue.notify_many(2);
255    ///
256    /// // Notify single thread
257    /// wait_queue.notify_one();
258    ///
259    /// // Wait for the other threads to finish
260    /// for handle in handles {
261    ///     handle.join().unwrap();
262    /// }
263    /// ```
264    pub fn notify_many(&self, count: usize) {
265        let count = count as isize;
266        let old_counter = self.counter.fetch_sub(count, Ordering::SeqCst);
267        if old_counter > 0 {
268            for _ in 0..old_counter.min(count) {
269                loop {
270                    if let Some(waker) = self.queue.pop() {
271                        waker.wake();
272                        break;
273                    }
274                }
275            }
276        }
277    }
278}
279
280enum Waiter {
281    Sync(SyncWaiter),
282    Async(AsyncWaiter),
283}
284
285unsafe impl Send for Waiter {}
286
287impl Waiter {
288    fn wake(self) {
289        match self {
290            Waiter::Sync(w) => w.wake(),
291            Waiter::Async(w) => w.wake(),
292        }
293    }
294}
295
296struct SyncWaiter {
297    awaken: *const AtomicBool,
298    thread: Thread,
299}
300
301impl SyncWaiter {
302    fn wake(self) {
303        unsafe {
304            (*self.awaken).store(true, Ordering::Release);
305        }
306        self.thread.unpark();
307    }
308}
309
310struct AsyncWaiter {
311    state: *const AtomicIsize,
312    waker: Waker,
313}
314
315impl AsyncWaiter {
316    fn wake(self) {
317        unsafe {
318            (*self.state).store(!0, Ordering::Release);
319        }
320        self.waker.wake();
321    }
322}
323#[derive(Debug)]
324/// A future representing a thread that is waiting for a notification from
325/// another thread using a `HybridFutex` synchronization primitive.
326pub struct WaitFuture<'a> {
327    /// The current state of the future, represented as an `AtomicIsize`.
328    /// The value of this field is `0` if the future has not yet been polled,
329    /// `1` if the future is waiting for a notification, and `!0` if the future
330    /// has been notified.
331    state: AtomicIsize,
332
333    /// A reference to the `HybridFutex` that this future is waiting on.
334    wq: &'a HybridFutex,
335}
336
337impl<'a> Future for WaitFuture<'a> {
338    type Output = ();
339    /// Polls the future, returning `Poll::Pending` if the future is still waiting
340    /// for a notification, and `Poll::Ready(())` if the future has been notified.
341    ///
342    /// If the future has not yet been polled, this method increments the counter
343    /// of the `HybridFutex` that the future is waiting on to indicate that the
344    /// current thread is waiting. If the counter is already negative, the future
345    /// immediately resolves and returns `Poll::Ready(())`. Otherwise, the method
346    /// pushes a new `AsyncWaiter` onto the queue of waiters for the `HybridFutex`,
347    /// and returns `Poll::Pending`.
348    ///
349    /// If the future has already been polled and the value of the `state` field is
350    /// `1`, this method simply returns `Poll::Pending` without modifying the state
351    /// or the queue of waiters.
352    ///
353    /// If the future has already been notified and the value of the `state` field
354    /// is `!0`, this method returns `Poll::Ready(())` without modifying the state
355    /// or the queue of waiters.
356    fn poll(
357        self: std::pin::Pin<&mut Self>,
358        cx: &mut std::task::Context<'_>,
359    ) -> std::task::Poll<()> {
360        match self.state.load(Ordering::Acquire) {
361            0 => {
362                // If the future has not yet been polled, increment the counter
363                // of the HybridFutex and push a new AsyncWaiter onto the queue.
364                let old_counter = self.wq.counter.fetch_add(1, Ordering::SeqCst);
365                if old_counter >= 0 {
366                    self.state.store(1, Ordering::Relaxed);
367                    self.wq.queue.push(Waiter::Async(AsyncWaiter {
368                        state: &self.state,
369                        waker: cx.waker().clone(),
370                    }));
371                    Poll::Pending
372                } else {
373                    // If the counter is negative, the future has already been
374                    // notified, so set the state to !0 and return Poll::Ready(()).
375                    self.state.store(!0, Ordering::Relaxed);
376                    Poll::Ready(())
377                }
378            }
379            1 => Poll::Pending,
380            _ => Poll::Ready(()),
381        }
382    }
383}
384
385impl<'a> Drop for WaitFuture<'a> {
386    /// Drops the future, checking whether it has been polled before and
387    /// panicking if it has not. This is to prevent potential memory leaks
388    /// if the future is dropped before being polled.
389    fn drop(&mut self) {
390        if self.state.load(Ordering::Relaxed) == 1 {
391            abort();
392        }
393    }
394}