Skip to main content

yaar_lock/sync/
wait_node.rs

1use crate::ThreadEvent;
2use core::{cell::Cell, mem::MaybeUninit, ptr::null};
3
/// Lifecycle state of a `WaitNode` relative to the wait queue.
///
/// The state also records which of the node's other fields currently hold initialized data.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum WaitNodeState {
    /// Freshly constructed: the event and link fields are uninitialized, so reading them is UB.
    Uninit,
    /// Initialized by `enqueue()` (or `reset()`) and presumably linked into a wait queue.
    Waiting,
    /// Dequeued and woken; the thread should re-check the resource it was waiting for.
    Notified,
    /// Dequeued and woken with ownership of the resource handed to it directly.
    DirectNotified,
}
16
17/// An intrusive, doubly linked list node used to track blocked threads.
18pub(crate) struct WaitNode<E> {
19    state: Cell<WaitNodeState>,
20    event: Cell<MaybeUninit<E>>,
21    prev: Cell<MaybeUninit<*const Self>>,
22    next: Cell<MaybeUninit<*const Self>>,
23    tail: Cell<MaybeUninit<*const Self>>,
24}
25
26/// Lazy initialize WaitNode as it improves performance in the fast path.
27impl<E> Default for WaitNode<E> {
28    fn default() -> Self {
29        Self {
30            state: Cell::new(WaitNodeState::Uninit),
31            event: Cell::new(MaybeUninit::uninit()),
32            prev: Cell::new(MaybeUninit::uninit()),
33            next: Cell::new(MaybeUninit::uninit()),
34            tail: Cell::new(MaybeUninit::uninit()),
35        }
36    }
37}
38
39impl<E: Default> WaitNode<E> {
40    /// Given the head of the queue, prepend this WaitNode
41    /// to the queue by initializing it and returning the
42    /// new head of the queue.
43    pub fn enqueue(&self, head: *const Self) -> *const Self {
44        match self.state.get() {
45            // lazy initialize a node before prepending to the head
46            WaitNodeState::Uninit => {
47                self.state.set(WaitNodeState::Waiting);
48                self.prev.set(MaybeUninit::new(null()));
49                self.event.set(MaybeUninit::new(E::default()));
50            }
51            // node is already initialized, only change the links volatile to the head below.
52            WaitNodeState::Waiting => {}
53            // node is in an unknown state, unchecked in release for performance (less so than in
54            // notify())
55            #[cfg(not(debug_assertions))]
56            _ => unsafe { core::hint::unreachable_unchecked() },
57            // In debug mode, this fault should still be caught and reported
58            #[cfg(debug_assertions)]
59            unexpected => unreachable!(
60                "unexpected WaitNodeState: expected {:?} found {:?}",
61                WaitNodeState::Waiting,
62                unexpected,
63            ),
64        }
65
66        // prepare a node to be the new head of the queue
67        self.next.set(MaybeUninit::new(head));
68        if head.is_null() {
69            self.tail.set(MaybeUninit::new(self));
70        } else {
71            self.tail.set(MaybeUninit::new(null()));
72        }
73
74        // return ourselves as the new head
75        self as *const Self
76    }
77
78    /// Given the head of the queeu as ourselves,
79    /// dequeue a node from the queue returning the new tail
80    /// of the queue and the removed tail that was dequeued.
81    ///
82    /// This function is not pure like `enqueue()` and it modifies
83    /// the internal queue tail for tracking the tail node.
84    pub fn dequeue<'a>(&self) -> (*const Self, &'a Self) {
85        unsafe {
86            // Given the head of the queue
87            let head = self;
88            debug_assert_eq!(head.state.get(), WaitNodeState::Waiting);
89
90            // Find the tail, updating the links along the way
91            let mut current = head;
92            let mut tail = head.tail.get().assume_init();
93            while tail.is_null() {
94                let next = &*current.next.get().assume_init();
95                debug_assert_eq!((&*next).state.get(), WaitNodeState::Waiting);
96                next.prev.set(MaybeUninit::new(current));
97                tail = next.tail.get().assume_init();
98                current = next;
99            }
100
101            // Dequeue the tail, returning the new_tail and it.
102            debug_assert_eq!((&*tail).state.get(), WaitNodeState::Waiting);
103            let new_tail = (&*tail).prev.get().assume_init();
104            if (head as *const _) == tail {
105                (null(), &*tail)
106            } else {
107                head.tail.set(MaybeUninit::new(new_tail));
108                (new_tail, &*tail)
109            }
110        }
111    }
112}
113
114impl<E: ThreadEvent> WaitNode<E> {
115    /// Get a reference to the thread event, assuming the WaitNode is
116    /// initialized.
117    #[inline]
118    fn get_event(&self) -> &E {
119        unsafe { &*(&*self.event.as_ptr()).as_ptr() }
120    }
121
122    /// Reset the wait node without uninitializing it.
123    /// Less expensive than re-initialization, especially for larger
124    /// ThreadEvent's.
125    pub fn reset(&self) {
126        self.get_event().reset();
127        self.state.set(WaitNodeState::Waiting);
128        self.prev.set(MaybeUninit::new(null()));
129    }
130
131    /// Unblock this node, waking it up with either normal or direct notify.
132    /// This assumes that this WaitNode is in a waiting state.
133    pub fn notify(&self, is_direct: bool) {
134        let event = self.get_event();
135        debug_assert_eq!(self.state.get(), WaitNodeState::Waiting);
136        self.state.set(if is_direct {
137            WaitNodeState::DirectNotified
138        } else {
139            WaitNodeState::Notified
140        });
141        event.set();
142    }
143
144    /// Block this node, waiting to be notified by another WaitNode.
145    /// Returns whether the notification was direct.
146    /// This assumes that this node is initialized.
147    pub fn wait(&self) -> bool {
148        self.get_event().wait();
149        match self.state.get() {
150            WaitNodeState::Notified => false,
151            WaitNodeState::DirectNotified => true,
152            // Using unreachable_unchecked improves performance during benchmarks.
153            #[cfg(not(debug_assertions))]
154            _ => unsafe { core::hint::unreachable_unchecked() },
155            // In debug mode, this fault should still be caught and reported
156            #[cfg(debug_assertions)]
157            unexpected => unreachable!(
158                "unexpected WaitNodeState: expected {:?} or {:?} found {:?}",
159                WaitNodeState::Notified,
160                WaitNodeState::DirectNotified,
161                unexpected,
162            ),
163        }
164    }
165}