tokio_graceful/
trigger.rs1use std::{
19 future::Future,
20 pin::Pin,
21 task::{Context, Poll, Waker},
22};
23
24use pin_project_lite::pin_project;
25use slab::Slab;
26
27use crate::sync::{Arc, AtomicBool, Mutex, Ordering};
28
29type WakerList = Arc<Mutex<Slab<Option<Waker>>>>;
30type TriggerState = Arc<AtomicBool>;
31
32#[derive(Debug, Clone)]
35struct Subscriber {
36 wakers: WakerList,
37 state: TriggerState,
38}
39
/// Outcome of asking a [`Subscriber`] for its current state.
#[derive(Debug)]
enum SubscriberState {
    /// The trigger has not fired yet; the payload is the slab key under
    /// which the caller's waker is stored.
    Waiting(usize),
    /// The trigger has fired.
    Triggered,
}
49
50impl Subscriber {
51 pub fn state(&self, cx: &mut Context, key: Option<usize>) -> SubscriberState {
62 if self.state.load(Ordering::SeqCst) {
63 return SubscriberState::Triggered;
64 }
65
66 let mut wakers = self.wakers.lock().unwrap();
67
68 if self.state.load(Ordering::SeqCst) {
73 return SubscriberState::Triggered;
74 }
75
76 let waker = Some(cx.waker().clone());
77
78 SubscriberState::Waiting(if let Some(key) = key {
79 tracing::trace!("trigger::Subscriber: updating waker for key: {}", key);
80 *wakers.get_mut(key).unwrap() = waker;
81 key
82 } else {
83 let key = wakers.insert(waker);
84 tracing::trace!("trigger::Subscriber: insert waker for key: {}", key);
85 key
86 })
87 }
88}
89
90#[derive(Debug)]
96enum ReceiverState {
97 Open { sub: Subscriber, key: Option<usize> },
98 Closed,
99 Pending,
100}
101
102impl Clone for ReceiverState {
103 fn clone(&self) -> Self {
108 match self {
109 ReceiverState::Open { sub, .. } => ReceiverState::Open {
110 sub: sub.clone(),
111 key: None,
112 },
113 ReceiverState::Closed => ReceiverState::Closed,
114 ReceiverState::Pending => ReceiverState::Pending,
115 }
116 }
117}
118
119impl Drop for ReceiverState {
120 fn drop(&mut self) {
123 if let ReceiverState::Open { sub, key } = self {
124 if let Some(key) = key.take() {
125 let mut wakers = sub.wakers.lock().unwrap();
126 tracing::trace!(
127 "trigger::ReceiverState::Drop: remove waker for key: {}",
128 key
129 );
130 wakers.remove(key);
131 }
132 }
133 }
134}
135
136pin_project! {
137 #[derive(Debug, Clone)]
138 pub struct Receiver {
139 state: ReceiverState,
140 }
141}
142
143impl Receiver {
144 fn new(wakers: WakerList, state: TriggerState) -> Self {
145 Self {
146 state: ReceiverState::Open {
147 sub: Subscriber { wakers, state },
148 key: None,
149 },
150 }
151 }
152
153 pub(crate) fn closed() -> Self {
155 Self {
156 state: ReceiverState::Closed,
157 }
158 }
159
160 pub(crate) fn pending() -> Self {
162 Self {
163 state: ReceiverState::Pending,
164 }
165 }
166}
167
168impl Future for Receiver {
169 type Output = ();
170
171 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
176 let this = self.project();
177 match this.state {
178 ReceiverState::Open { sub, key } => {
179 let state = sub.state(cx, *key);
180 match state {
181 SubscriberState::Waiting(new_key) => {
182 *key = Some(new_key);
183 std::task::Poll::Pending
184 }
185 SubscriberState::Triggered => {
186 *this.state = ReceiverState::Closed;
187 std::task::Poll::Ready(())
188 }
189 }
190 }
191 ReceiverState::Closed => std::task::Poll::Ready(()),
192 ReceiverState::Pending => std::task::Poll::Pending,
193 }
194 }
195}
196
197#[derive(Debug, Clone)]
198pub struct Sender {
199 state: TriggerState,
200 wakers: WakerList,
201}
202
203impl Sender {
204 fn new(wakers: WakerList, state: TriggerState) -> Self {
205 Self { wakers, state }
206 }
207
208 pub fn trigger(&self) {
210 if self.state.swap(true, Ordering::SeqCst) {
211 return;
212 }
213
214 let mut wakers = self.wakers.lock().unwrap();
215 for (key, waker) in wakers.iter_mut() {
216 match waker.take() {
217 Some(waker) => {
218 tracing::trace!("trigger::Sender: wake up waker with key: {}", key);
219 waker.wake();
220 }
221 None => {
222 tracing::trace!(
223 "trigger::Sender: nop: waker already triggered with key: {}",
224 key
225 );
226 }
227 }
228 }
229 }
230}
231
232pub fn trigger() -> (Sender, Receiver) {
233 let wakers = Arc::new(Mutex::new(Slab::new()));
234 let state = Arc::new(AtomicBool::new(false));
235
236 let sender = Sender::new(wakers.clone(), state.clone());
237 let receiver = Receiver::new(wakers, state);
238
239 (sender, receiver)
240}
241
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;

    /// A receiver resolves once the sender is triggered from another task.
    #[tokio::test]
    async fn test_sender_trigger() {
        let (sender, receiver) = trigger();

        let trigger_task = tokio::spawn(async move {
            sender.trigger();
        });

        receiver.await;

        trigger_task.await.unwrap();
    }

    /// Without a trigger the receiver stays pending, so the timeout elapses.
    #[tokio::test]
    async fn test_sender_never_trigger() {
        let (_, receiver) = trigger();

        let wait = std::time::Duration::from_millis(100);
        let outcome = tokio::time::timeout(wait, receiver).await;

        outcome.unwrap_err();
    }
}
267
#[cfg(all(test, loom))]
mod loom_tests {
    use super::*;

    use loom::{future::block_on, thread};

    /// Runs a trigger on one thread against an awaiting receiver on another
    /// inside `loom::model`.
    #[test]
    fn test_loom_sender_trigger() {
        loom::model(|| {
            let (sender, receiver) = trigger();

            let trigger_thread = thread::spawn(move || {
                sender.trigger();
            });

            block_on(async move {
                receiver.await;
            });

            trigger_thread.join().unwrap();
        });
    }
}
290}