Skip to main content

neco_server_runtime/
lib.rs

1#![warn(missing_docs)]
2
3//! Runtime primitives for `neco-server`.
4
5use core::future::Future;
6use std::collections::VecDeque;
7use std::io;
8use std::net::{SocketAddr, TcpListener};
9use std::pin::Pin;
10use std::sync::{Arc, Condvar, Mutex};
11use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
12use std::thread;
13
/// Spawns background work without handing back a join handle.
///
/// Implementors must be shareable across threads, since a single spawner is
/// typically cloned into or referenced from many tasks.
pub trait Spawn: Send + Sync + 'static {
    /// Starts `future` in the background and returns immediately.
    ///
    /// The caller receives no handle, so neither completion nor panics of the
    /// spawned task can be observed through this API.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F);
}
21
22/// Multi-subscriber event channel.
23pub trait EventChannel<T>: Send + Sync
24where
25    T: Clone + Send + 'static,
26{
27    /// Receiver type returned by `subscribe`.
28    type Receiver: EventReceiver<T>;
29
30    /// Sends a value to all subscribers.
31    fn send(&self, value: T) -> Result<(), EventChannelError>;
32
33    /// Creates a new receiver.
34    fn subscribe(&self) -> Self::Receiver;
35
36    /// Returns the current subscriber count.
37    fn subscriber_count(&self) -> usize;
38}
39
40/// Receiver side of an event channel.
41pub trait EventReceiver<T>: Send
42where
43    T: Send,
44{
45    /// Receives the next value.
46    fn recv(&mut self) -> impl Future<Output = Result<T, EventChannelError>> + Send;
47}
48
/// Failure reported by an [`EventChannel`] or one of its receivers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EventChannelError {
    /// The channel has shut down and will deliver no further values.
    Closed,
    /// The receiver fell behind; the payload is the number of values it missed.
    Lagged(u64),
}

impl core::fmt::Display for EventChannelError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            EventChannelError::Closed => f.write_str("event channel closed"),
            EventChannelError::Lagged(missed) => {
                write!(f, "event channel lagged by {} messages", missed)
            }
        }
    }
}

impl std::error::Error for EventChannelError {}
68
/// Opens a TCP listening socket on `addr`.
///
/// Binding to port 0 lets the operating system choose a free port; the chosen
/// address can be read back with [`TcpListener::local_addr`].
///
/// # Errors
/// Propagates any I/O error from [`TcpListener::bind`], such as the address
/// already being in use or insufficient permissions.
pub fn bind_listener(addr: SocketAddr) -> io::Result<TcpListener> {
    let listener = TcpListener::bind(addr)?;
    Ok(listener)
}
73
/// Runs `future` to completion on the calling thread by polling it in a loop.
///
/// No reactor or timer is driven, and the waker handed to the future does
/// nothing. After every `Poll::Pending` the thread yields to the OS scheduler
/// and then polls the future again. Only futures that become ready through
/// repeated polling on this thread (or that block internally until ready)
/// will complete. Futures that depend on an external runtime, timer wheel, or
/// wake-driven async I/O must not be passed here.
fn block_on<F>(future: F) -> F::Output
where
    F: Future,
{
    // Vtable whose entries all ignore their data pointer; `clone` hands back
    // another waker backed by this same vtable.
    static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);

    fn noop_clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
    }

    fn noop(_: *const ()) {}

    // SAFETY: every function in `NOOP_VTABLE` ignores its data pointer, so a
    // null pointer is valid for all of them. `clone` returns a waker with the
    // same vtable, and `wake`/`wake_by_ref`/`drop` do nothing, so the
    // `RawWaker` contract is trivially upheld.
    let waker = unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP_VTABLE)) };
    let mut cx = Context::from_waker(&waker);
    let mut pinned = Box::pin(future);

    loop {
        if let Poll::Ready(output) = pinned.as_mut().poll(&mut cx) {
            break output;
        }
        thread::yield_now();
    }
}
108
/// Task spawner that gives every spawned future its own OS thread.
///
/// The type carries no state, so creating, cloning, or defaulting it is free.
#[derive(Debug, Clone, Default)]
pub struct DetachedTasks;

impl DetachedTasks {
    /// Returns a handle for spawning detached, thread-backed tasks.
    pub fn current() -> Self {
        DetachedTasks
    }
}
119
120impl Spawn for DetachedTasks {
121    /// Spawns a future on a dedicated OS thread using the local busy-loop executor.
122    ///
123    /// The spawned future must not depend on an external async runtime.
124    fn spawn<F>(&self, future: F)
125    where
126        F: Future<Output = ()> + Send + 'static,
127    {
128        thread::spawn(move || {
129            block_on(future);
130        });
131    }
132}
133
/// Mutable state of a fanout channel, always accessed under `FanoutShared::state`.
#[derive(Debug)]
struct FanoutState<M> {
    // Retained messages tagged with their sequence numbers, oldest at the front.
    buffer: VecDeque<(u64, M)>,
    // Sequence number that the next sent message will receive.
    next_seq: u64,
    // Number of live receivers subscribed to the channel.
    receiver_count: usize,
}

/// State shared by every sender handle and receiver of one fanout channel.
#[derive(Debug)]
struct FanoutShared<M> {
    // Maximum number of messages kept in `state.buffer`.
    capacity: usize,
    // Buffer and bookkeeping guarded by a single lock.
    state: Mutex<FanoutState<M>>,
    // Signalled whenever a new message is pushed so blocked receivers re-check.
    condvar: Condvar,
}
147
148/// Multi-subscriber fanout channel with ring-buffer semantics.
149#[derive(Debug, Clone)]
150pub struct FanoutChannel<T: Clone + Send + 'static> {
151    shared: Arc<FanoutShared<T>>,
152}
153
154impl<T: Clone + Send + 'static> FanoutChannel<T> {
155    /// Creates a channel with the given ring buffer capacity.
156    pub fn new(capacity: usize) -> Self {
157        let capacity = capacity.max(1);
158        Self {
159            shared: Arc::new(FanoutShared {
160                capacity,
161                state: Mutex::new(FanoutState {
162                    buffer: VecDeque::new(),
163                    next_seq: 0,
164                    receiver_count: 0,
165                }),
166                condvar: Condvar::new(),
167            }),
168        }
169    }
170}
171
172impl<T: Clone + Send + 'static> EventChannel<T> for FanoutChannel<T> {
173    type Receiver = FanoutReceiver<T>;
174
175    fn send(&self, value: T) -> Result<(), EventChannelError> {
176        let mut state = self.shared.state.lock().expect("fanout channel poisoned");
177        let seq = state.next_seq;
178        state.next_seq += 1;
179        state.buffer.push_back((seq, value));
180        while state.buffer.len() > self.shared.capacity {
181            state.buffer.pop_front();
182        }
183        self.shared.condvar.notify_all();
184        Ok(())
185    }
186
187    fn subscribe(&self) -> Self::Receiver {
188        let mut state = self.shared.state.lock().expect("fanout channel poisoned");
189        state.receiver_count += 1;
190        let next_seq = state.next_seq;
191        drop(state);
192        FanoutReceiver {
193            shared: self.shared.clone(),
194            next_seq,
195        }
196    }
197
198    fn subscriber_count(&self) -> usize {
199        self.shared
200            .state
201            .lock()
202            .expect("fanout channel poisoned")
203            .receiver_count
204    }
205}
206
207/// Receiver returned by [`FanoutChannel::subscribe`].
208pub struct FanoutReceiver<T: Clone + Send + 'static> {
209    shared: Arc<FanoutShared<T>>,
210    next_seq: u64,
211}
212
213impl<T: Clone + Send + 'static> Drop for FanoutReceiver<T> {
214    fn drop(&mut self) {
215        let mut state = self.shared.state.lock().expect("fanout channel poisoned");
216        state.receiver_count = state.receiver_count.saturating_sub(1);
217    }
218}
219
220impl<T: Clone + Send + 'static> EventReceiver<T> for FanoutReceiver<T> {
221    async fn recv(&mut self) -> Result<T, EventChannelError> {
222        loop {
223            let mut state = self.shared.state.lock().expect("fanout channel poisoned");
224
225            if let Some((oldest_seq, _)) = state.buffer.front() {
226                if self.next_seq < *oldest_seq {
227                    let lagged = *oldest_seq - self.next_seq;
228                    self.next_seq = *oldest_seq;
229                    return Err(EventChannelError::Lagged(lagged));
230                }
231            }
232
233            if let Some((_, value)) = state
234                .buffer
235                .iter()
236                .find(|(seq, _)| *seq == self.next_seq)
237                .cloned()
238            {
239                self.next_seq += 1;
240                return Ok(value);
241            }
242
243            state = self
244                .shared
245                .condvar
246                .wait(state)
247                .expect("fanout channel poisoned");
248            drop(state);
249        }
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    #[test]
    fn fanout_channel_send_receive_smoke() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(16);
        let mut rx = bus.subscribe();
        bus.send(42).expect("send");
        let value = block_on(rx.recv()).expect("recv");
        assert_eq!(value, 42);
    }

    #[test]
    fn fanout_channel_send_with_no_subscriber_is_ok() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(16);
        bus.send(1).expect("send must be ok");
        bus.send(2).expect("send must be ok");
        assert_eq!(bus.subscriber_count(), 0);
    }

    #[test]
    fn fanout_channel_lag_returns_error() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(2);
        let mut rx = bus.subscribe();
        for value in 0..5 {
            bus.send(value).expect("send");
        }
        // Capacity 2 retains only messages 3 and 4, so exactly three were missed.
        assert_eq!(block_on(rx.recv()), Err(EventChannelError::Lagged(3)));
        // After reporting the lag, the receiver resumes at the oldest retained message.
        assert_eq!(block_on(rx.recv()), Ok(3));
        assert_eq!(block_on(rx.recv()), Ok(4));
    }

    #[test]
    fn fanout_channel_delivers_to_every_subscriber() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(4);
        let mut first = bus.subscribe();
        let mut second = bus.subscribe();
        bus.send(9).expect("send");
        assert_eq!(block_on(first.recv()), Ok(9));
        assert_eq!(block_on(second.recv()), Ok(9));
    }

    #[test]
    fn fanout_channel_subscriber_only_sees_later_messages() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(4);
        bus.send(1).expect("send");
        let mut rx = bus.subscribe();
        bus.send(2).expect("send");
        assert_eq!(block_on(rx.recv()), Ok(2));
    }

    #[test]
    fn fanout_channel_subscriber_count_tracks_drops() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(1);
        let first = bus.subscribe();
        let second = bus.clone().subscribe();
        assert_eq!(bus.subscriber_count(), 2);
        drop(first);
        assert_eq!(bus.subscriber_count(), 1);
        drop(second);
        assert_eq!(bus.subscriber_count(), 0);
    }

    #[test]
    fn fanout_channel_zero_capacity_is_clamped_to_one() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(0);
        let mut rx = bus.subscribe();
        bus.send(5).expect("send");
        assert_eq!(block_on(rx.recv()), Ok(5));
    }

    #[test]
    fn fanout_channel_wakes_blocked_receiver() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(4);
        let mut rx = bus.subscribe();
        let sender = bus.clone();
        let handle = std::thread::spawn(move || sender.send(11).expect("send"));
        assert_eq!(block_on(rx.recv()), Ok(11));
        handle.join().expect("sender thread");
    }

    #[test]
    fn detached_tasks_runs_future() {
        let runtime = DetachedTasks::current();
        let (tx, rx) = mpsc::channel();
        runtime.spawn(async move {
            let _ = tx.send(7u64);
        });
        let value = rx.recv().expect("oneshot");
        assert_eq!(value, 7);
    }
}