neco_server_runtime/
lib.rs1#![warn(missing_docs)]
2
3use core::future::Future;
6use std::collections::VecDeque;
7use std::io;
8use std::net::{SocketAddr, TcpListener};
9use std::pin::Pin;
10use std::sync::{Arc, Condvar, Mutex};
11use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
12use std::thread;
13
/// Abstraction over handing a future to some background execution facility.
///
/// `spawn` returns nothing, so callers get no handle to the submitted future.
pub trait Spawn: Send + Sync + 'static {
    /// Submits `future` to the spawner.
    ///
    /// The future must produce `()` and be `Send + 'static` so the
    /// implementation is free to move it to another thread.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F);
}
21
22pub trait EventChannel<T>: Send + Sync
24where
25 T: Clone + Send + 'static,
26{
27 type Receiver: EventReceiver<T>;
29
30 fn send(&self, value: T) -> Result<(), EventChannelError>;
32
33 fn subscribe(&self) -> Self::Receiver;
35
36 fn subscriber_count(&self) -> usize;
38}
39
40pub trait EventReceiver<T>: Send
42where
43 T: Send,
44{
45 fn recv(&mut self) -> impl Future<Output = Result<T, EventChannelError>> + Send;
47}
48
/// Errors reported by event channel operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EventChannelError {
    /// The event channel is closed.
    Closed,
    /// The receiver fell behind; the payload is the number of messages it lagged by.
    Lagged(u64),
}
57
58impl core::fmt::Display for EventChannelError {
59 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
60 match self {
61 Self::Closed => write!(f, "event channel closed"),
62 Self::Lagged(count) => write!(f, "event channel lagged by {count} messages"),
63 }
64 }
65}
66
// Empty impl: every `std::error::Error` method keeps its default, so the
// `Display` and `Debug` impls of `EventChannelError` provide all of the error's text.
impl std::error::Error for EventChannelError {}
68
/// Binds a [`TcpListener`] to `addr`.
///
/// # Errors
///
/// Returns the `io::Error` produced by [`TcpListener::bind`] unchanged.
pub fn bind_listener(addr: SocketAddr) -> io::Result<TcpListener> {
    let listener = TcpListener::bind(addr)?;
    Ok(listener)
}
73
/// Drives `future` to completion on the calling thread and returns its output.
///
/// This is a minimal busy-polling executor: the waker handed to the future does
/// nothing, so a pending future is simply polled again after yielding the
/// thread's time slice.
fn block_on<F>(future: F) -> F::Output
where
    F: Future,
{
    // Shared body for `wake`, `wake_by_ref` and `drop`: there is nothing to do.
    fn noop(_: *const ()) {}

    // Cloning yields another waker backed by the same no-op vtable.
    fn clone_noop(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
    }

    static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(clone_noop, noop, noop, noop);

    // SAFETY: none of the vtable functions dereference the (null) data pointer
    // and none of them own a resource, so they are sound to call from any
    // thread, any number of times.
    let waker = unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP_VTABLE)) };
    let mut cx = Context::from_waker(&waker);
    let mut pinned = Box::pin(future);

    loop {
        if let Poll::Ready(output) = pinned.as_mut().poll(&mut cx) {
            return output;
        }
        // Nobody will wake us, so give other threads a chance and poll again.
        thread::yield_now();
    }
}
108
/// Stateless spawner handle; it is a unit struct and carries no data.
#[derive(Clone, Debug, Default)]
pub struct DetachedTasks;
112
113impl DetachedTasks {
114 pub fn current() -> Self {
116 Self
117 }
118}
119
120impl Spawn for DetachedTasks {
121 fn spawn<F>(&self, future: F)
125 where
126 F: Future<Output = ()> + Send + 'static,
127 {
128 thread::spawn(move || {
129 block_on(future);
130 });
131 }
132}
133
/// Mutable channel state; it lives inside the mutex held by `FanoutShared`.
#[derive(Debug)]
struct FanoutState<T> {
    /// Retained messages paired with their sequence numbers. `send` appends at
    /// the back and evicts from the front, so the oldest message is at the front.
    buffer: VecDeque<(u64, T)>,
    /// Sequence number that `send` assigns to the next message.
    next_seq: u64,
    /// Live receiver count: incremented by `subscribe`, decremented when a
    /// `FanoutReceiver` is dropped.
    receiver_count: usize,
}
140
/// State shared, through an `Arc`, by a `FanoutChannel` and its receivers.
#[derive(Debug)]
struct FanoutShared<T> {
    /// Maximum number of messages kept in the buffer; `FanoutChannel::new`
    /// raises a requested capacity of zero to one.
    capacity: usize,
    /// Buffer and counters, protected by a mutex.
    state: Mutex<FanoutState<T>>,
    /// Notified by `send` after a message is buffered; `recv` waits on it when
    /// the message it needs has not been sent yet.
    condvar: Condvar,
}
147
148#[derive(Debug, Clone)]
150pub struct FanoutChannel<T: Clone + Send + 'static> {
151 shared: Arc<FanoutShared<T>>,
152}
153
154impl<T: Clone + Send + 'static> FanoutChannel<T> {
155 pub fn new(capacity: usize) -> Self {
157 let capacity = capacity.max(1);
158 Self {
159 shared: Arc::new(FanoutShared {
160 capacity,
161 state: Mutex::new(FanoutState {
162 buffer: VecDeque::new(),
163 next_seq: 0,
164 receiver_count: 0,
165 }),
166 condvar: Condvar::new(),
167 }),
168 }
169 }
170}
171
172impl<T: Clone + Send + 'static> EventChannel<T> for FanoutChannel<T> {
173 type Receiver = FanoutReceiver<T>;
174
175 fn send(&self, value: T) -> Result<(), EventChannelError> {
176 let mut state = self.shared.state.lock().expect("fanout channel poisoned");
177 let seq = state.next_seq;
178 state.next_seq += 1;
179 state.buffer.push_back((seq, value));
180 while state.buffer.len() > self.shared.capacity {
181 state.buffer.pop_front();
182 }
183 self.shared.condvar.notify_all();
184 Ok(())
185 }
186
187 fn subscribe(&self) -> Self::Receiver {
188 let mut state = self.shared.state.lock().expect("fanout channel poisoned");
189 state.receiver_count += 1;
190 let next_seq = state.next_seq;
191 drop(state);
192 FanoutReceiver {
193 shared: self.shared.clone(),
194 next_seq,
195 }
196 }
197
198 fn subscriber_count(&self) -> usize {
199 self.shared
200 .state
201 .lock()
202 .expect("fanout channel poisoned")
203 .receiver_count
204 }
205}
206
207pub struct FanoutReceiver<T: Clone + Send + 'static> {
209 shared: Arc<FanoutShared<T>>,
210 next_seq: u64,
211}
212
213impl<T: Clone + Send + 'static> Drop for FanoutReceiver<T> {
214 fn drop(&mut self) {
215 let mut state = self.shared.state.lock().expect("fanout channel poisoned");
216 state.receiver_count = state.receiver_count.saturating_sub(1);
217 }
218}
219
220impl<T: Clone + Send + 'static> EventReceiver<T> for FanoutReceiver<T> {
221 async fn recv(&mut self) -> Result<T, EventChannelError> {
222 loop {
223 let mut state = self.shared.state.lock().expect("fanout channel poisoned");
224
225 if let Some((oldest_seq, _)) = state.buffer.front() {
226 if self.next_seq < *oldest_seq {
227 let lagged = *oldest_seq - self.next_seq;
228 self.next_seq = *oldest_seq;
229 return Err(EventChannelError::Lagged(lagged));
230 }
231 }
232
233 if let Some((_, value)) = state
234 .buffer
235 .iter()
236 .find(|(seq, _)| *seq == self.next_seq)
237 .cloned()
238 {
239 self.next_seq += 1;
240 return Ok(value);
241 }
242
243 state = self
244 .shared
245 .condvar
246 .wait(state)
247 .expect("fanout channel poisoned");
248 drop(state);
249 }
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    /// A value sent after subscribing is delivered to that subscriber.
    #[test]
    fn fanout_channel_send_receive_smoke() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(16);
        let mut rx = bus.subscribe();
        bus.send(42).expect("send");
        let value = block_on(rx.recv()).expect("recv");
        assert_eq!(value, 42);
    }

    /// Sending succeeds even when nobody is subscribed.
    #[test]
    fn fanout_channel_send_with_no_subscriber_is_ok() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(16);
        bus.send(1).expect("send must be ok");
        bus.send(2).expect("send must be ok");
        assert_eq!(bus.subscriber_count(), 0);
    }

    /// A receiver that falls behind learns exactly how much it missed and
    /// then continues from the oldest retained message.
    #[test]
    fn fanout_channel_lag_returns_error() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(2);
        let mut rx = bus.subscribe();
        for value in 0..5 {
            bus.send(value).expect("send");
        }
        // Capacity 2 keeps only values 3 and 4, so messages 0, 1 and 2 were missed.
        assert_eq!(block_on(rx.recv()), Err(EventChannelError::Lagged(3)));
        assert_eq!(block_on(rx.recv()), Ok(3));
        assert_eq!(block_on(rx.recv()), Ok(4));
    }

    /// The subscriber count follows receivers being created and dropped.
    #[test]
    fn fanout_channel_subscriber_count_tracks_live_receivers() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(4);
        let first = bus.subscribe();
        let second = bus.subscribe();
        assert_eq!(bus.subscriber_count(), 2);
        drop(first);
        assert_eq!(bus.subscriber_count(), 1);
        drop(second);
        assert_eq!(bus.subscriber_count(), 0);
    }

    /// A spawned future is driven to completion on another thread.
    #[test]
    fn detached_tasks_runs_future() {
        let runtime = DetachedTasks::current();
        let (tx, rx) = mpsc::channel();
        runtime.spawn(async move {
            let _ = tx.send(7u64);
        });
        let value = rx.recv().expect("oneshot");
        assert_eq!(value, 7);
    }
}