1use crate::socket_table::SocketId;
2use std::ops::{BitOr, BitOrAssign};
3use std::sync::{Arc, Condvar, Mutex, MutexGuard};
4use std::time::{Duration, Instant};
5
/// Set of poll event flags stored as raw bits in a `u16`.
///
/// The derived `Default` is the all-zero (empty) set. No bit pattern is
/// rejected: any `u16` is a valid value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct PollEvents(u16);

impl PollEvents {
    /// Returns the set with every bit cleared.
    pub const fn empty() -> Self {
        Self::from_bits(0)
    }

    /// Wraps `bits` as-is, without masking or validation.
    pub const fn from_bits(bits: u16) -> Self {
        Self(bits)
    }

    /// Returns the underlying raw bits.
    pub const fn bits(self) -> u16 {
        let Self(raw) = self;
        raw
    }

    /// Returns `true` when no bit is set.
    pub const fn is_empty(self) -> bool {
        self.bits() == 0
    }

    /// Returns `true` when every bit set in `other` is also set in `self`.
    ///
    /// An empty `other` is contained in every set.
    pub const fn contains(self, other: Self) -> bool {
        let wanted = other.bits();
        (self.bits() & wanted) == wanted
    }

    /// Returns `true` when `self` and `other` share at least one set bit.
    pub const fn intersects(self, other: Self) -> bool {
        let shared = self.bits() & other.bits();
        shared != 0
    }
}
34
35impl BitOr for PollEvents {
36 type Output = Self;
37
38 fn bitor(self, rhs: Self) -> Self::Output {
39 Self(self.0 | rhs.0)
40 }
41}
42
43impl BitOrAssign for PollEvents {
44 fn bitor_assign(&mut self, rhs: Self) {
45 self.0 |= rhs.0;
46 }
47}
48
// NOTE(review): the names and bit values below match the Linux `poll(2)` event
// constants. The per-flag meanings given here are presumed from that
// correspondence and are not established by this file — confirm with callers.

/// Flag bit `0x0001`; presumably "data is available to read".
pub const POLLIN: PollEvents = PollEvents(0x0001);
/// Flag bit `0x0004`; presumably "writing is possible".
pub const POLLOUT: PollEvents = PollEvents(0x0004);
/// Flag bit `0x0008`; presumably "an error condition is pending".
pub const POLLERR: PollEvents = PollEvents(0x0008);
/// Flag bit `0x0010`; presumably "the peer hung up".
pub const POLLHUP: PollEvents = PollEvents(0x0010);
/// Flag bit `0x0020`; presumably "the polled target is not valid".
pub const POLLNVAL: PollEvents = PollEvents(0x0020);
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq)]
56pub struct PollFd {
57 pub fd: u32,
58 pub events: PollEvents,
59 pub revents: PollEvents,
60}
61
62impl PollFd {
63 pub const fn new(fd: u32, events: PollEvents) -> Self {
64 Self {
65 fd,
66 events,
67 revents: PollEvents::empty(),
68 }
69 }
70}
71
/// Result record pairing a ready count with a list of [`PollFd`] entries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PollResult {
    /// Count of ready entries.
    // NOTE(review): presumably the number of `fds` whose `revents` is
    // non-empty; the producer is not visible here — confirm.
    pub ready_count: usize,
    /// The descriptor entries this result covers.
    pub fds: Vec<PollFd>,
}
77
/// What a poll entry refers to: a numeric descriptor or a socket id.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollTarget {
    /// A target identified by a numeric descriptor.
    Fd(u32),
    /// A target identified by a `SocketId` from `crate::socket_table`.
    Socket(SocketId),
}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq)]
85pub struct PollTargetEntry {
86 pub target: PollTarget,
87 pub events: PollEvents,
88 pub revents: PollEvents,
89}
90
91impl PollTargetEntry {
92 pub const fn new(target: PollTarget, events: PollEvents) -> Self {
93 Self {
94 target,
95 events,
96 revents: PollEvents::empty(),
97 }
98 }
99
100 pub const fn fd(fd: u32, events: PollEvents) -> Self {
101 Self::new(PollTarget::Fd(fd), events)
102 }
103
104 pub const fn socket(socket_id: SocketId, events: PollEvents) -> Self {
105 Self::new(PollTarget::Socket(socket_id), events)
106 }
107}
108
/// Result record pairing a ready count with a list of [`PollTargetEntry`] values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PollTargetResult {
    /// Count of ready entries.
    // NOTE(review): presumably the number of `targets` whose `revents` is
    // non-empty; the producer is not visible here — confirm.
    pub ready_count: usize,
    /// The target entries this result covers.
    pub targets: Vec<PollTargetEntry>,
}
114
/// Cloneable handle to a shared change counter that threads can block on.
///
/// Every clone shares the same state through an `Arc`, so a [`notify`] issued
/// through any clone is observed by waiters blocked through any other clone.
///
/// [`notify`]: PollNotifier::notify
#[derive(Debug, Clone, Default)]
pub(crate) struct PollNotifier {
    inner: Arc<PollNotifierInner>,
}

/// State shared by all clones of a [`PollNotifier`].
#[derive(Debug, Default)]
struct PollNotifierInner {
    /// Counter bumped (with wrap-around) on every `notify` call.
    generation: Mutex<u64>,
    /// Condition variable used together with the `generation` mutex.
    waiters: Condvar,
}

impl PollNotifier {
    /// Advances the generation counter and wakes every blocked waiter.
    ///
    /// The counter wraps from `u64::MAX` back to `0`, so a notification always
    /// changes its value.
    pub(crate) fn notify(&self) {
        let mut generation = lock_or_recover(&self.inner.generation);
        *generation = generation.wrapping_add(1);
        self.inner.waiters.notify_all();
    }

    /// Returns the current generation, for a later [`wait_for_change`] call.
    ///
    /// [`wait_for_change`]: PollNotifier::wait_for_change
    pub(crate) fn snapshot(&self) -> u64 {
        *lock_or_recover(&self.inner.generation)
    }

    /// Blocks until the generation differs from `observed` or `timeout` elapses.
    ///
    /// * `timeout == None` waits without a time limit.
    /// * `timeout == Some(Duration::ZERO)` only checks and never blocks.
    /// * A timeout so large that the deadline cannot be represented as an
    ///   [`Instant`] is treated as "no time limit" instead of panicking on the
    ///   `Instant + Duration` overflow.
    ///
    /// Returns `true` if the generation was seen to differ from `observed`,
    /// `false` if the timeout expired first. A poisoned mutex is recovered
    /// rather than propagated.
    pub(crate) fn wait_for_change(&self, observed: u64, timeout: Option<Duration>) -> bool {
        let mut generation = lock_or_recover(&self.inner.generation);
        if *generation != observed {
            return true;
        }

        let deadline = match timeout {
            None => None,
            // The counter was just seen unchanged and the lock is still held,
            // so a non-blocking check can only report "no change".
            Some(limit) if limit.is_zero() => return false,
            // `checked_add` yields `None` on overflow; that falls through to
            // the unbounded wait below.
            Some(limit) => Instant::now().checked_add(limit),
        };

        let Some(deadline) = deadline else {
            // Loop to absorb spurious wake-ups.
            while *generation == observed {
                generation = wait_or_recover(&self.inner.waiters, generation);
            }
            return true;
        };

        loop {
            // Recompute the remaining budget on every pass so spurious
            // wake-ups cannot extend the overall wait.
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                return false;
            }

            let (reacquired, wait_result) =
                wait_timeout_or_recover(&self.inner.waiters, generation, remaining);
            generation = reacquired;
            // Check the counter before the timeout flag: a change that lands
            // just as the wait times out still counts as a change.
            if *generation != observed {
                return true;
            }
            if wait_result.timed_out() {
                return false;
            }
        }
    }
}

/// Locks `mutex`, taking the guard out of the error if the mutex is poisoned.
fn lock_or_recover<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    mutex
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}

/// Waits on `condvar`, recovering the guard if the mutex is poisoned.
fn wait_or_recover<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    condvar
        .wait(guard)
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}

/// Waits on `condvar` for at most `timeout`, recovering the guard and the
/// timeout result if the mutex is poisoned.
fn wait_timeout_or_recover<'a, T>(
    condvar: &Condvar,
    guard: MutexGuard<'a, T>,
    timeout: Duration,
) -> (MutexGuard<'a, T>, std::sync::WaitTimeoutResult) {
    condvar
        .wait_timeout(guard, timeout)
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
199
#[cfg(test)]
mod tests {
    use super::PollNotifier;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    /// A waiter with no timeout stays blocked until `notify` is called and
    /// then reports a change.
    #[test]
    fn infinite_wait_returns_after_notification_without_waiter_storage() {
        let notifier = PollNotifier::default();
        let observed = notifier.snapshot();
        let waiter = notifier.clone();
        let (started_tx, started_rx) = mpsc::channel();
        let (done_tx, done_rx) = mpsc::channel();

        let handle = thread::spawn(move || {
            started_tx.send(()).expect("signal waiter start");
            let changed = waiter.wait_for_change(observed, None);
            done_tx.send(changed).expect("signal waiter result");
        });

        started_rx.recv().expect("waiter should start");
        // No notification has been sent yet, so no result may arrive.
        assert!(
            done_rx.recv_timeout(Duration::from_millis(25)).is_err(),
            "waiter should stay blocked before notification"
        );

        notifier.notify();
        assert!(done_rx
            .recv_timeout(Duration::from_secs(1))
            .expect("waiter should wake after notification"));
        handle.join().expect("waiter thread should finish");
    }

    /// With the counter at `u64::MAX`, `notify` wraps it around and the
    /// waiter still observes a change.
    #[test]
    fn saturated_generation_still_notifies_waiters() {
        let notifier = PollNotifier::default();
        {
            // Force the counter to its maximum so the next `notify` must wrap.
            let mut generation = super::lock_or_recover(&notifier.inner.generation);
            *generation = u64::MAX;
        }

        let observed = notifier.snapshot();
        let waiter = notifier.clone();
        let (started_tx, started_rx) = mpsc::channel();
        let (done_tx, done_rx) = mpsc::channel();

        let handle = thread::spawn(move || {
            started_tx.send(()).expect("signal waiter start");
            let changed = waiter.wait_for_change(observed, Some(Duration::from_secs(1)));
            done_tx.send(changed).expect("signal waiter result");
        });

        started_rx.recv().expect("waiter should start");
        notifier.notify();

        assert!(
            done_rx
                .recv_timeout(Duration::from_secs(2))
                .expect("waiter should return after saturated notify"),
            "saturated notify should still wake the waiter"
        );
        handle.join().expect("waiter thread should finish");
    }
}