Skip to main content

secure_exec_kernel/
poll.rs

1use crate::socket_table::SocketId;
2use std::ops::{BitOr, BitOrAssign};
3use std::sync::{Arc, Condvar, Mutex, MutexGuard};
4use std::time::{Duration, Instant};
5
/// A set of poll event flags packed into a `u16`.
///
/// Flags are combined with `|` / `|=` and queried with [`PollEvents::contains`]
/// and [`PollEvents::intersects`]. The `Default` value is the empty set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct PollEvents(u16);

impl PollEvents {
    /// The set with no flags.
    pub const fn empty() -> Self {
        Self::from_bits(0)
    }

    /// Wraps raw flag bits; every bit pattern is accepted unchanged.
    pub const fn from_bits(bits: u16) -> Self {
        Self(bits)
    }

    /// Returns the raw flag bits.
    pub const fn bits(self) -> u16 {
        self.0
    }

    /// `true` when no flag is set.
    pub const fn is_empty(self) -> bool {
        self.bits() == 0
    }

    /// `true` when every flag of `other` is also set in `self`.
    ///
    /// Any set therefore contains the empty set.
    pub const fn contains(self, other: Self) -> bool {
        // `other ⊆ self` exactly when `other` has no bit outside `self`.
        (other.bits() & !self.bits()) == 0
    }

    /// `true` when `self` and `other` share at least one flag.
    pub const fn intersects(self, other: Self) -> bool {
        !Self::from_bits(self.bits() & other.bits()).is_empty()
    }
}

impl BitOr for PollEvents {
    type Output = Self;

    /// Union of the two flag sets.
    fn bitor(self, rhs: Self) -> Self::Output {
        Self::from_bits(self.bits() | rhs.bits())
    }
}

impl BitOrAssign for PollEvents {
    /// In-place union with `rhs`.
    fn bitor_assign(&mut self, rhs: Self) {
        *self = *self | rhs;
    }
}

// Individual poll flags. NOTE: these numeric values coincide with the Linux
// `<poll.h>` assignments (0x0002 / POLLPRI is not defined here).
pub const POLLIN: PollEvents = PollEvents::from_bits(0x0001);
pub const POLLOUT: PollEvents = PollEvents::from_bits(0x0004);
pub const POLLERR: PollEvents = PollEvents::from_bits(0x0008);
pub const POLLHUP: PollEvents = PollEvents::from_bits(0x0010);
pub const POLLNVAL: PollEvents = PollEvents::from_bits(0x0020);
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq)]
56pub struct PollFd {
57    pub fd: u32,
58    pub events: PollEvents,
59    pub revents: PollEvents,
60}
61
62impl PollFd {
63    pub const fn new(fd: u32, events: PollEvents) -> Self {
64        Self {
65            fd,
66            events,
67            revents: PollEvents::empty(),
68        }
69    }
70}
71
72#[derive(Debug, Clone, PartialEq, Eq)]
73pub struct PollResult {
74    pub ready_count: usize,
75    pub fds: Vec<PollFd>,
76}
77
78#[derive(Debug, Clone, Copy, PartialEq, Eq)]
79pub enum PollTarget {
80    Fd(u32),
81    Socket(SocketId),
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq)]
85pub struct PollTargetEntry {
86    pub target: PollTarget,
87    pub events: PollEvents,
88    pub revents: PollEvents,
89}
90
91impl PollTargetEntry {
92    pub const fn new(target: PollTarget, events: PollEvents) -> Self {
93        Self {
94            target,
95            events,
96            revents: PollEvents::empty(),
97        }
98    }
99
100    pub const fn fd(fd: u32, events: PollEvents) -> Self {
101        Self::new(PollTarget::Fd(fd), events)
102    }
103
104    pub const fn socket(socket_id: SocketId, events: PollEvents) -> Self {
105        Self::new(PollTarget::Socket(socket_id), events)
106    }
107}
108
109#[derive(Debug, Clone, PartialEq, Eq)]
110pub struct PollTargetResult {
111    pub ready_count: usize,
112    pub targets: Vec<PollTargetEntry>,
113}
114
/// Cloneable handle to a shared "something changed" signal.
///
/// [`PollNotifier::notify`] bumps a generation counter and wakes every thread
/// blocked in [`PollNotifier::wait_for_change`]. A waiter takes a
/// [`PollNotifier::snapshot`] first and then waits for the counter to move past
/// it. A notification that lands between the snapshot and the wait is therefore
/// observed immediately rather than lost.
#[derive(Debug, Clone, Default)]
pub(crate) struct PollNotifier {
    inner: Arc<PollNotifierInner>,
}

/// State shared by all clones of a [`PollNotifier`].
#[derive(Debug, Default)]
struct PollNotifierInner {
    /// Bumped (wrapping) on every notification.
    generation: Mutex<u64>,
    /// Signalled with `notify_all` whenever `generation` is bumped.
    waiters: Condvar,
}

impl PollNotifier {
    /// Bumps the generation counter and wakes all current waiters.
    ///
    /// The counter wraps on overflow. Waiters compare for inequality, so a wrap
    /// from `u64::MAX` to `0` still counts as a change.
    pub(crate) fn notify(&self) {
        let mut generation = lock_or_recover(&self.inner.generation);
        *generation = generation.wrapping_add(1);
        self.inner.waiters.notify_all();
    }

    /// Returns the current generation, for a later `wait_for_change` call.
    pub(crate) fn snapshot(&self) -> u64 {
        *lock_or_recover(&self.inner.generation)
    }

    /// Blocks until the generation differs from `observed` or `timeout` elapses.
    ///
    /// * `None` waits indefinitely.
    /// * `Some(Duration::ZERO)` only checks and never blocks.
    /// * A timeout too large to be represented as an [`Instant`] deadline is
    ///   treated like `None`. Previously, computing such a deadline panicked.
    ///
    /// Returns `true` if the generation changed, and `false` if the timeout
    /// elapsed first.
    pub(crate) fn wait_for_change(&self, observed: u64, timeout: Option<Duration>) -> bool {
        let mut generation = lock_or_recover(&self.inner.generation);
        if *generation != observed {
            return true;
        }

        let deadline = match timeout {
            // Nothing has changed, and the caller asked not to block.
            Some(timeout) if timeout.is_zero() => return false,
            Some(timeout) => deadline_after(timeout),
            None => None,
        };

        let Some(deadline) = deadline else {
            // No (representable) deadline: wait for a notification. The loop
            // absorbs spurious wakeups.
            while *generation == observed {
                generation = wait_or_recover(&self.inner.waiters, generation);
            }
            return true;
        };

        loop {
            let now = Instant::now();
            if now >= deadline {
                // Every path reaching this point has `*generation == observed`.
                return false;
            }

            let remaining = deadline.saturating_duration_since(now);
            let (next_generation, wait_result) =
                wait_timeout_or_recover(&self.inner.waiters, generation, remaining);
            generation = next_generation;
            if *generation != observed {
                return true;
            }
            if wait_result.timed_out() {
                return false;
            }
            // Spurious wakeup: loop and wait for whatever time is left.
        }
    }
}

/// Returns `now + timeout`, or `None` when that instant is not representable.
///
/// `Instant + Duration` panics on overflow. Very large timeouts (e.g.
/// `Duration::MAX`) are therefore mapped to "no deadline" instead.
fn deadline_after(timeout: Duration) -> Option<Instant> {
    Instant::now().checked_add(timeout)
}

/// Locks `mutex`, recovering the guard if a previous holder panicked.
///
/// In this file it only guards the `u64` generation counter, which cannot be
/// left half-updated, so ignoring poisoning is sound here.
fn lock_or_recover<'a, T>(mutex: &'a Mutex<T>) -> MutexGuard<'a, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}

/// `Condvar::wait` that recovers the guard instead of propagating poisoning.
fn wait_or_recover<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    match condvar.wait(guard) {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}

/// `Condvar::wait_timeout` that recovers the guard instead of propagating
/// poisoning.
fn wait_timeout_or_recover<'a, T>(
    condvar: &Condvar,
    guard: MutexGuard<'a, T>,
    timeout: Duration,
) -> (MutexGuard<'a, T>, std::sync::WaitTimeoutResult) {
    match condvar.wait_timeout(guard, timeout) {
        Ok(result) => result,
        Err(poisoned) => poisoned.into_inner(),
    }
}
199
#[cfg(test)]
mod tests {
    use super::PollNotifier;
    use std::sync::mpsc::{self, Receiver};
    use std::thread::{self, JoinHandle};
    use std::time::Duration;

    /// Spawns a thread that calls `wait_for_change(observed, timeout)` on a
    /// clone of `notifier`.
    ///
    /// Returns the start signal (sent just before the thread starts waiting),
    /// the channel carrying the wait result, and the thread's join handle.
    fn spawn_waiter(
        notifier: &PollNotifier,
        observed: u64,
        timeout: Option<Duration>,
    ) -> (Receiver<()>, Receiver<bool>, JoinHandle<()>) {
        let waiter = notifier.clone();
        let (started_tx, started_rx) = mpsc::channel();
        let (done_tx, done_rx) = mpsc::channel();

        let handle = thread::spawn(move || {
            started_tx.send(()).expect("signal waiter start");
            let changed = waiter.wait_for_change(observed, timeout);
            done_tx.send(changed).expect("signal waiter result");
        });

        (started_rx, done_rx, handle)
    }

    #[test]
    fn infinite_wait_returns_after_notification_without_waiter_storage() {
        let notifier = PollNotifier::default();
        let (started, done, handle) = spawn_waiter(&notifier, notifier.snapshot(), None);

        started.recv().expect("waiter should start");
        // Nothing has been notified yet, so the waiter must still be parked.
        let early = done.recv_timeout(Duration::from_millis(25));
        assert!(
            early.is_err(),
            "waiter should stay blocked before notification"
        );

        notifier.notify();
        let changed = done
            .recv_timeout(Duration::from_secs(1))
            .expect("waiter should wake after notification");
        assert!(changed);
        handle.join().expect("waiter thread should finish");
    }

    #[test]
    fn saturated_generation_still_notifies_waiters() {
        let notifier = PollNotifier::default();
        // Park the counter at its maximum so the next notify wraps to zero.
        // The temporary guard is dropped at the end of this statement.
        *super::lock_or_recover(&notifier.inner.generation) = u64::MAX;

        let (started, done, handle) =
            spawn_waiter(&notifier, notifier.snapshot(), Some(Duration::from_secs(1)));

        started.recv().expect("waiter should start");
        notifier.notify();

        let changed = done
            .recv_timeout(Duration::from_secs(2))
            .expect("waiter should return after saturated notify");
        assert!(changed, "saturated notify should still wake the waiter");
        handle.join().expect("waiter thread should finish");
    }
}