neco-server-runtime 0.1.0

Runtime primitives for `neco-server`.
Documentation
#![warn(missing_docs)]

//! Runtime primitives for `neco-server`.

use core::future::Future;
use std::collections::VecDeque;
use std::io;
use std::net::{SocketAddr, TcpListener};
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use std::thread;

/// Executors that can start fire-and-forget (detached) tasks.
pub trait Spawn: Send + Sync + 'static {
    /// Starts `future` as a detached task.
    ///
    /// Nothing is returned, so the caller gets no handle to await or cancel the task.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F);
}

/// Multi-subscriber event channel.
pub trait EventChannel<T>: Send + Sync
where
    T: Clone + Send + 'static,
{
    /// Receiver type returned by `subscribe`.
    type Receiver: EventReceiver<T>;

    /// Sends a value to all subscribers.
    fn send(&self, value: T) -> Result<(), EventChannelError>;

    /// Creates a new receiver.
    fn subscribe(&self) -> Self::Receiver;

    /// Returns the current subscriber count.
    fn subscriber_count(&self) -> usize;
}

/// Consuming end of an [`EventChannel`] subscription.
pub trait EventReceiver<T: Send>: Send {
    /// Waits for the next value delivered to this receiver.
    ///
    /// Failures (for example the channel closing or the receiver falling behind)
    /// are reported as an [`EventChannelError`].
    fn recv(&mut self) -> impl Future<Output = Result<T, EventChannelError>> + Send;
}

/// Failure reported by an event channel or one of its receivers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EventChannelError {
    /// No further values can be delivered because the channel is closed.
    Closed,
    /// The receiver fell behind; the payload counts the messages it missed.
    Lagged(u64),
}

impl core::fmt::Display for EventChannelError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            EventChannelError::Closed => f.write_str("event channel closed"),
            EventChannelError::Lagged(missed) => {
                write!(f, "event channel lagged by {missed} messages")
            }
        }
    }
}

impl std::error::Error for EventChannelError {}

/// Opens a TCP listener bound to `addr` so incoming connections can be accepted.
///
/// # Errors
///
/// Returns whatever I/O error [`TcpListener::bind`] reports for `addr`.
pub fn bind_listener(addr: SocketAddr) -> io::Result<TcpListener> {
    let listener = TcpListener::bind(addr)?;
    Ok(listener)
}

/// Runs `future` to completion on the calling thread without any async runtime.
///
/// The future is polled in a loop. After every `Poll::Pending` the thread calls
/// [`thread::yield_now`] and polls again. The waker handed to the future ignores
/// every wake-up, so progress comes only from re-polling. Only pass futures that
/// eventually finish when polled repeatedly on this thread. Futures that rely on
/// an external reactor, timer wheel, or wake-driven async I/O must not be passed here.
fn block_on<F>(future: F) -> F::Output
where
    F: Future,
{
    /// `clone` entry of the vtable: hands out another no-op waker.
    fn noop_clone(_: *const ()) -> RawWaker {
        noop_raw_waker()
    }

    /// Shared `wake`, `wake_by_ref` and `drop` entry: there is nothing to do.
    fn noop(_: *const ()) {}

    /// Vtable whose functions never read the data pointer.
    static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);

    /// Builds a raw waker with a null data pointer and the no-op vtable.
    fn noop_raw_waker() -> RawWaker {
        RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
    }

    // SAFETY: every vtable function ignores its data pointer, so the null pointer is
    // never dereferenced. The functions have no state, which makes them trivially
    // thread-safe, as the `RawWakerVTable` contract requires.
    let waker = unsafe { Waker::from_raw(noop_raw_waker()) };
    let mut cx = Context::from_waker(&waker);
    let mut pinned: Pin<Box<F>> = Box::pin(future);

    loop {
        if let Poll::Ready(output) = pinned.as_mut().poll(&mut cx) {
            break output;
        }
        thread::yield_now();
    }
}

/// Task runner that gives every spawned task its own OS thread.
///
/// Each thread drives its future with the crate's busy-polling executor. The
/// thread's join handle is discarded, so tasks are fully detached.
#[derive(Debug, Clone, Default)]
pub struct DetachedTasks;

impl DetachedTasks {
    /// Returns a handle to the detached task runner.
    pub fn current() -> Self {
        DetachedTasks
    }
}

impl Spawn for DetachedTasks {
    /// Starts `future` on a freshly spawned OS thread and returns immediately.
    ///
    /// The thread polls the future in a busy loop, so the future must not depend
    /// on an external async runtime.
    fn spawn<F>(&self, future: F)
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let task = move || block_on(future);
        // Letting the join handle go out of scope detaches the thread.
        let _detached = thread::spawn(task);
    }
}

/// Mutable state of one fanout channel, shared by every sender and receiver handle.
#[derive(Debug)]
struct FanoutState<T> {
    /// Retained messages paired with their sequence numbers, oldest first.
    ///
    /// Sequence numbers are consecutive, so the message with sequence `s` sits at
    /// index `s - front_seq`.
    buffer: VecDeque<(u64, T)>,
    /// Sequence number the next sent message will receive.
    next_seq: u64,
    /// Number of live [`FanoutReceiver`]s.
    receiver_count: usize,
    /// Number of live [`FanoutChannel`] handles.
    ///
    /// When it reaches zero no more messages can be sent, and receivers report
    /// `EventChannelError::Closed` once they have drained the buffer.
    sender_count: usize,
}

/// Allocation shared (through an `Arc`) by all handles of one channel.
#[derive(Debug)]
struct FanoutShared<T> {
    /// Maximum number of retained messages (always at least 1).
    capacity: usize,
    /// Buffer and bookkeeping, guarded by one mutex.
    state: Mutex<FanoutState<T>>,
    /// Signalled when a message is sent or the last sender goes away.
    condvar: Condvar,
}

impl<T> FanoutShared<T> {
    /// Locks the state and recovers the guard if a previous holder panicked.
    ///
    /// Used by `Clone`/`Drop`. A panicking `T::clone` inside `recv` poisons the
    /// mutex, and the receiver is then dropped during unwinding. Panicking there
    /// again would abort the process.
    fn lock_ignoring_poison(&self) -> std::sync::MutexGuard<'_, FanoutState<T>> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

/// Multi-subscriber fanout channel with ring-buffer semantics.
///
/// Every subscriber sees each message sent after it subscribed. If more than
/// `capacity` messages arrive before a subscriber reads them, the oldest are
/// discarded and that subscriber observes `EventChannelError::Lagged`.
///
/// Cloning yields another sending handle for the same channel. When the last
/// handle is dropped the channel closes: receivers drain what is still buffered
/// and then get `EventChannelError::Closed` instead of blocking forever.
#[derive(Debug)]
pub struct FanoutChannel<T: Clone + Send + 'static> {
    shared: Arc<FanoutShared<T>>,
}

impl<T: Clone + Send + 'static> FanoutChannel<T> {
    /// Creates a channel that retains up to `capacity` messages.
    ///
    /// A `capacity` of zero is treated as one.
    pub fn new(capacity: usize) -> Self {
        Self {
            shared: Arc::new(FanoutShared {
                capacity: capacity.max(1),
                state: Mutex::new(FanoutState {
                    buffer: VecDeque::new(),
                    next_seq: 0,
                    receiver_count: 0,
                    // The handle being returned is the first sender.
                    sender_count: 1,
                }),
                condvar: Condvar::new(),
            }),
        }
    }
}

impl<T: Clone + Send + 'static> Clone for FanoutChannel<T> {
    /// Returns another sending handle for the same channel.
    fn clone(&self) -> Self {
        self.shared.lock_ignoring_poison().sender_count += 1;
        Self {
            shared: Arc::clone(&self.shared),
        }
    }
}

impl<T: Clone + Send + 'static> Drop for FanoutChannel<T> {
    fn drop(&mut self) {
        let mut state = self.shared.lock_ignoring_poison();
        state.sender_count = state.sender_count.saturating_sub(1);
        if state.sender_count == 0 {
            // Wake blocked receivers so they can observe that the channel closed.
            self.shared.condvar.notify_all();
        }
    }
}

impl<T: Clone + Send + 'static> EventChannel<T> for FanoutChannel<T> {
    type Receiver = FanoutReceiver<T>;

    /// Appends `value` to the ring buffer and wakes every blocked receiver.
    ///
    /// If the buffer already holds `capacity` messages, the oldest one is
    /// discarded. Sending with no subscribers is not an error.
    ///
    /// # Errors
    ///
    /// Never returns an error; the `Result` is required by the trait.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex was poisoned.
    fn send(&self, value: T) -> Result<(), EventChannelError> {
        let mut state = self.shared.state.lock().expect("fanout channel poisoned");
        let seq = state.next_seq;
        state.next_seq += 1;
        state.buffer.push_back((seq, value));
        while state.buffer.len() > self.shared.capacity {
            state.buffer.pop_front();
        }
        self.shared.condvar.notify_all();
        Ok(())
    }

    /// Registers a receiver that observes only messages sent after this call.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex was poisoned.
    fn subscribe(&self) -> Self::Receiver {
        let mut state = self.shared.state.lock().expect("fanout channel poisoned");
        state.receiver_count += 1;
        let next_seq = state.next_seq;
        drop(state);
        FanoutReceiver {
            shared: Arc::clone(&self.shared),
            next_seq,
        }
    }

    /// Returns the number of live receivers.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex was poisoned.
    fn subscriber_count(&self) -> usize {
        self.shared
            .state
            .lock()
            .expect("fanout channel poisoned")
            .receiver_count
    }
}

/// Receiver returned by [`FanoutChannel::subscribe`].
#[derive(Debug)]
pub struct FanoutReceiver<T: Clone + Send + 'static> {
    shared: Arc<FanoutShared<T>>,
    /// Sequence number of the next message this receiver expects.
    next_seq: u64,
}

impl<T: Clone + Send + 'static> Drop for FanoutReceiver<T> {
    fn drop(&mut self) {
        let mut state = self.shared.lock_ignoring_poison();
        state.receiver_count = state.receiver_count.saturating_sub(1);
    }
}

impl<T: Clone + Send + 'static> EventReceiver<T> for FanoutReceiver<T> {
    /// Returns the next message for this receiver.
    ///
    /// This blocks the calling thread on a condition variable until a message is
    /// available or the channel closes. It never yields `Poll::Pending`, so awaiting
    /// it on a shared async-runtime worker thread stalls that worker.
    ///
    /// # Errors
    ///
    /// - `EventChannelError::Lagged(n)` when `n` messages were discarded before this
    ///   receiver read them. The receiver then resumes at the oldest retained message.
    /// - `EventChannelError::Closed` once every [`FanoutChannel`] handle has been
    ///   dropped and no buffered message remains for this receiver.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex was poisoned.
    async fn recv(&mut self) -> Result<T, EventChannelError> {
        let mut state = self.shared.state.lock().expect("fanout channel poisoned");
        loop {
            if let Some(&(oldest_seq, _)) = state.buffer.front() {
                if self.next_seq < oldest_seq {
                    let lagged = oldest_seq - self.next_seq;
                    self.next_seq = oldest_seq;
                    return Err(EventChannelError::Lagged(lagged));
                }
                // Buffered sequence numbers are consecutive, so the wanted message
                // (if it has been sent yet) sits at a fixed offset from the front.
                let offset = usize::try_from(self.next_seq - oldest_seq).ok();
                if let Some((_, value)) = offset.and_then(|index| state.buffer.get(index)) {
                    let value = value.clone();
                    self.next_seq += 1;
                    return Ok(value);
                }
            }
            if state.sender_count == 0 {
                return Err(EventChannelError::Closed);
            }
            // `wait` releases the lock while sleeping and re-acquires it on wake-up;
            // spurious wake-ups simply re-run the checks above.
            state = self
                .shared
                .condvar
                .wait(state)
                .expect("fanout channel poisoned");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    /// A message sent after subscribing reaches the subscriber.
    #[test]
    fn fanout_channel_send_receive_smoke() {
        let channel = FanoutChannel::<u64>::new(16);
        let mut receiver = channel.subscribe();
        channel.send(42).expect("send");
        let received = block_on(receiver.recv()).expect("recv");
        assert_eq!(received, 42);
    }

    /// Sending succeeds even when nobody is listening.
    #[test]
    fn fanout_channel_send_with_no_subscriber_is_ok() {
        let channel = FanoutChannel::<u64>::new(16);
        for value in [1, 2] {
            channel.send(value).expect("send must be ok");
        }
        assert_eq!(channel.subscriber_count(), 0);
    }

    /// Overflowing the ring buffer makes a slow receiver report a lag.
    #[test]
    fn fanout_channel_lag_returns_error() {
        let channel = FanoutChannel::<u64>::new(2);
        let mut receiver = channel.subscribe();
        (0..5).for_each(|value| channel.send(value).expect("send"));
        let outcome = block_on(receiver.recv());
        assert!(matches!(outcome, Err(EventChannelError::Lagged(_))));
    }

    /// A spawned future actually runs on its detached thread.
    #[test]
    fn detached_tasks_runs_future() {
        let (sender, receiver) = mpsc::channel();
        DetachedTasks::current().spawn(async move {
            let _ = sender.send(7u64);
        });
        assert_eq!(receiver.recv().expect("oneshot"), 7);
    }
}