// yaar_lock/sync/wait_node.rs
use crate::ThreadEvent;
use core::{cell::Cell, mem::MaybeUninit, ptr::null};
3
/// Lifecycle state of a [`WaitNode`].
///
/// `Eq` is derived alongside `PartialEq`: the enum is field-less, so equality
/// is total and the type can be used wherever an `Eq` bound is required.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum WaitNodeState {
    /// State of a node produced by `WaitNode::default()`. The first
    /// `enqueue` call initialises the node's `prev` pointer and event and
    /// moves it to `Waiting`.
    Uninit,
    /// The node has been initialised (or `reset`) and is expected to be
    /// linked into a queue; `dequeue` and `notify` assert this state in
    /// debug builds.
    Waiting,
    /// Set by `notify(false)`; makes `wait` return `false`.
    Notified,
    /// Set by `notify(true)`; makes `wait` return `true`.
    // NOTE(review): the meaning of a "direct" notification is decided by the
    // callers of `notify`, which are not visible in this file.
    DirectNotified,
}
16
17pub(crate) struct WaitNode<E> {
19 state: Cell<WaitNodeState>,
20 event: Cell<MaybeUninit<E>>,
21 prev: Cell<MaybeUninit<*const Self>>,
22 next: Cell<MaybeUninit<*const Self>>,
23 tail: Cell<MaybeUninit<*const Self>>,
24}
25
26impl<E> Default for WaitNode<E> {
28 fn default() -> Self {
29 Self {
30 state: Cell::new(WaitNodeState::Uninit),
31 event: Cell::new(MaybeUninit::uninit()),
32 prev: Cell::new(MaybeUninit::uninit()),
33 next: Cell::new(MaybeUninit::uninit()),
34 tail: Cell::new(MaybeUninit::uninit()),
35 }
36 }
37}
38
39impl<E: Default> WaitNode<E> {
40 pub fn enqueue(&self, head: *const Self) -> *const Self {
44 match self.state.get() {
45 WaitNodeState::Uninit => {
47 self.state.set(WaitNodeState::Waiting);
48 self.prev.set(MaybeUninit::new(null()));
49 self.event.set(MaybeUninit::new(E::default()));
50 }
51 WaitNodeState::Waiting => {}
53 #[cfg(not(debug_assertions))]
56 _ => unsafe { core::hint::unreachable_unchecked() },
57 #[cfg(debug_assertions)]
59 unexpected => unreachable!(
60 "unexpected WaitNodeState: expected {:?} found {:?}",
61 WaitNodeState::Waiting,
62 unexpected,
63 ),
64 }
65
66 self.next.set(MaybeUninit::new(head));
68 if head.is_null() {
69 self.tail.set(MaybeUninit::new(self));
70 } else {
71 self.tail.set(MaybeUninit::new(null()));
72 }
73
74 self as *const Self
76 }
77
78 pub fn dequeue<'a>(&self) -> (*const Self, &'a Self) {
85 unsafe {
86 let head = self;
88 debug_assert_eq!(head.state.get(), WaitNodeState::Waiting);
89
90 let mut current = head;
92 let mut tail = head.tail.get().assume_init();
93 while tail.is_null() {
94 let next = &*current.next.get().assume_init();
95 debug_assert_eq!((&*next).state.get(), WaitNodeState::Waiting);
96 next.prev.set(MaybeUninit::new(current));
97 tail = next.tail.get().assume_init();
98 current = next;
99 }
100
101 debug_assert_eq!((&*tail).state.get(), WaitNodeState::Waiting);
103 let new_tail = (&*tail).prev.get().assume_init();
104 if (head as *const _) == tail {
105 (null(), &*tail)
106 } else {
107 head.tail.set(MaybeUninit::new(new_tail));
108 (new_tail, &*tail)
109 }
110 }
111 }
112}
113
114impl<E: ThreadEvent> WaitNode<E> {
115 #[inline]
118 fn get_event(&self) -> &E {
119 unsafe { &*(&*self.event.as_ptr()).as_ptr() }
120 }
121
122 pub fn reset(&self) {
126 self.get_event().reset();
127 self.state.set(WaitNodeState::Waiting);
128 self.prev.set(MaybeUninit::new(null()));
129 }
130
131 pub fn notify(&self, is_direct: bool) {
134 let event = self.get_event();
135 debug_assert_eq!(self.state.get(), WaitNodeState::Waiting);
136 self.state.set(if is_direct {
137 WaitNodeState::DirectNotified
138 } else {
139 WaitNodeState::Notified
140 });
141 event.set();
142 }
143
144 pub fn wait(&self) -> bool {
148 self.get_event().wait();
149 match self.state.get() {
150 WaitNodeState::Notified => false,
151 WaitNodeState::DirectNotified => true,
152 #[cfg(not(debug_assertions))]
154 _ => unsafe { core::hint::unreachable_unchecked() },
155 #[cfg(debug_assertions)]
157 unexpected => unreachable!(
158 "unexpected WaitNodeState: expected {:?} or {:?} found {:?}",
159 WaitNodeState::Notified,
160 WaitNodeState::DirectNotified,
161 unexpected,
162 ),
163 }
164 }
165}