1use std::{
4 sync::{Arc, Weak},
5 task::{Poll, Waker},
6};
7
8use parking_lot::Mutex;
9use smallvec::SmallVec;
10
/// Notifying half of the channel.
///
/// Wraps the shared channel state in an `Arc`, so cloning a `Sender` shares that state: every
/// clone advances the same fence and wakes the same set of waiters. Receivers only hold a weak
/// reference, so the channel stays open exactly as long as some `Sender` (strong reference)
/// exists.
#[derive(Debug, Clone)]
pub struct Sender(Arc<Mutex<Inner>>);
13
/// Listening half of the channel.
///
/// Field `0` is the last fence value this receiver has observed. `0` is used for fresh or
/// invalidated receivers; the sender's fence starts at `1` and advances by `2`, so `0` never
/// matches it. Field `1` is a weak handle to the sender's state; upgrading it fails once every
/// `Sender` is gone. Cloning copies the last-observed fence.
#[derive(Debug, Clone)]
pub struct Receiver(usize, Weak<Mutex<Inner>>);
16
17impl Default for Sender {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl Sender {
24 pub fn new() -> Self {
25 Self(Arc::new(Mutex::new(Inner { fence: 1, waiters: Default::default() })))
26 }
27
28 pub fn notify(&self) {
29 let mut inner = self.0.lock();
30 inner.fence = inner.fence.wrapping_add(2); inner.waiters.drain(..).for_each(|x| x.1.wake());
32 }
33
34 pub fn receiver(&self, fresh: bool) -> Receiver {
35 Receiver(if fresh { 0 } else { self.0.lock().fence }, Arc::downgrade(&self.0))
36 }
37}
38
39#[derive(Debug)]
40struct Inner {
41 fence: usize,
42 waiters: SmallVec<[(usize, Waker); 4]>,
43}
44
45impl Receiver {
46 pub fn invalidate(&mut self) {
47 self.0 = 0;
48 }
49
50 pub fn has_update(&self) -> Option<bool> {
51 self.1.upgrade().map(|x| x.lock().fence != self.0)
52 }
53
54 pub fn try_recv(&mut self) -> Result<(), TryWaitError> {
55 let flag = self.1.upgrade().ok_or(TryWaitError::Closed)?.lock().fence;
56 if self.0 != flag {
57 self.0 = flag;
58 Ok(())
59 } else {
60 Err(TryWaitError::Empty)
61 }
62 }
63
64 pub fn recv(&mut self) -> Wait {
65 Wait { rx: self, state: WaitState::Created }
66 }
67}
68
/// Error returned by [`Receiver::try_recv`].
#[derive(thiserror::Error, Debug)]
pub enum TryWaitError {
    /// The receiver's weak handle could not be upgraded: the shared state is gone
    /// (no `Sender` holds it any more).
    #[error("Closed notify channel")]
    Closed,

    /// The fence still equals the last value this receiver observed.
    #[error("There's no update")]
    Empty,
}
77
/// Error produced by the [`Wait`] future returned from [`Receiver::recv`].
#[derive(thiserror::Error, Debug)]
pub enum WaitError {
    /// The receiver's weak handle could not be upgraded: the shared state is gone.
    #[error("Closed notify channel")]
    Closed,

    /// The future was polled while already in its terminal (`Expired`) state.
    #[error("Expired notify channel")]
    Expired,
}
86
87#[derive(Debug)]
88pub struct Wait<'a> {
89 rx: &'a mut Receiver,
90 state: WaitState,
91}
92
/// Lifecycle of a [`Wait`] future.
#[derive(Debug, Clone, Copy)]
enum WaitState {
    /// Initial state; no waker has been stored in the sender's waiter list yet.
    Created,
    /// A waker for this future has been pushed into the sender's waiter list; dropping the
    /// future in this state unregisters it.
    Registered,
    /// Terminal state: further polls return `Err(WaitError::Expired)`.
    Expired,
}
99
100impl<'a> Wait<'a> {
101 fn unregister(&mut self) {
102 let id = self.get_id();
103
104 debug_assert!(matches!(self.state, WaitState::Registered));
106
107 let Some(inner) = self.rx.1.upgrade() else { return };
109 let inner = &mut inner.lock().waiters;
110
111 if let Some(idx) = inner.iter().position(|x| x.0 == id) {
113 inner.swap_remove(idx);
114 } else {
115 }
118 }
119
120 fn get_id(&self) -> usize {
121 self.rx as *const _ as usize
122 }
123}
124
125impl<'a> Drop for Wait<'a> {
126 fn drop(&mut self) {
127 if matches!(self.state, WaitState::Registered) {
128 self.unregister();
129 }
130 }
131}
132
133impl<'a> std::future::Future for Wait<'a> {
134 type Output = Result<(), WaitError>;
135
136 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
137 let this = self.get_mut();
138 let id = this.get_id();
139
140 match this.state {
141 WaitState::Created => {
142 let Some(inner) = this.rx.1.upgrade() else {
143 this.state = WaitState::Expired;
144 return Poll::Ready(Err(WaitError::Closed));
145 };
146
147 let mut inner = inner.lock();
148
149 if inner.fence != this.rx.0 {
150 this.rx.0 = inner.fence;
152 return Poll::Ready(Ok(()));
153 }
154
155 inner.waiters.push((id, cx.waker().clone()));
156 this.state = WaitState::Registered;
157
158 Poll::Pending
159 }
160
161 WaitState::Registered => {
162 let Some(inner) = this.rx.1.upgrade() else {
163 this.state = WaitState::Expired;
164 return Poll::Ready(Err(WaitError::Closed));
165 };
166
167 let mut inner = inner.lock();
168
169 if inner.fence != this.rx.0 {
170 this.rx.0 = inner.fence;
171 this.state = WaitState::Expired;
172
173 Poll::Ready(Ok(()))
174 } else {
175 if inner.waiters.iter().any(|x| x.0 == id) {
176 inner.waiters.push((id, cx.waker().clone()));
179 } else {
180 }
183
184 Poll::Pending
185 }
186 }
187
188 WaitState::Expired => Poll::Ready(Err(WaitError::Expired)),
189 }
190 }
191}