#![warn(missing_docs)]
use core::future::Future;
use std::collections::VecDeque;
use std::io;
use std::net::{SocketAddr, TcpListener};
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use std::thread;
/// Something that can take ownership of a `'static` future and run it.
///
/// How and on which thread the future is polled is left to the implementor;
/// callers only hand the future over.
pub trait Spawn: Send + Sync + 'static {
    /// Hands `future` to this spawner to be driven to completion.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F);
}
/// A channel whose values (`T: Clone`) can be delivered to several subscribers.
///
/// Implementations must be shareable across threads (`Send + Sync`) and hand
/// out receiver handles through [`subscribe`](Self::subscribe).
pub trait EventChannel<T: Clone + Send + 'static>: Send + Sync {
    /// Receiver handle produced by [`subscribe`](Self::subscribe).
    type Receiver: EventReceiver<T>;

    /// Publishes `value` to the channel.
    ///
    /// # Errors
    ///
    /// Returns an [`EventChannelError`] if the implementation cannot accept
    /// the value.
    fn send(&self, value: T) -> Result<(), EventChannelError>;

    /// Creates a new receiver attached to this channel.
    fn subscribe(&self) -> Self::Receiver;

    /// Returns how many receivers are currently subscribed.
    fn subscriber_count(&self) -> usize;
}
/// The receiving side of an [`EventChannel`] subscription.
pub trait EventReceiver<T: Send>: Send {
    /// Returns a future that resolves to the next value from the channel.
    ///
    /// # Errors
    ///
    /// The future resolves to an [`EventChannelError`] when no value can be
    /// delivered — for example `Lagged(n)` when `n` values were missed.
    fn recv(&mut self) -> impl Future<Output = Result<T, EventChannelError>> + Send;
}
/// Errors reported by [`EventChannel`] and [`EventReceiver`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EventChannelError {
    /// The channel has been closed.
    Closed,
    /// The receiver fell behind; the payload is the number of values it missed.
    Lagged(u64),
}

impl core::fmt::Display for EventChannelError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            EventChannelError::Closed => f.write_str("event channel closed"),
            EventChannelError::Lagged(count) => {
                write!(f, "event channel lagged by {count} messages")
            }
        }
    }
}

impl std::error::Error for EventChannelError {}
/// Binds a TCP listener to `addr`.
///
/// Passing port `0` lets the operating system pick a free port; query the
/// chosen address with [`TcpListener::local_addr`].
///
/// # Errors
///
/// Returns any I/O error reported by [`TcpListener::bind`].
pub fn bind_listener(addr: SocketAddr) -> io::Result<TcpListener> {
    let listener = TcpListener::bind(addr)?;
    Ok(listener)
}
/// Drives `future` to completion on the current thread and returns its output.
///
/// The future is polled with a waker that unparks the calling thread, so while
/// the future is pending the thread sleeps in [`thread::park`] instead of
/// spinning. Spurious unparks are harmless: the future is simply polled again.
fn block_on<F>(future: F) -> F::Output
where
    F: Future,
{
    /// Waker payload: the thread to unpark when the future signals progress.
    struct ThreadUnparker(thread::Thread);

    impl std::task::Wake for ThreadUnparker {
        fn wake(self: Arc<Self>) {
            self.0.unpark();
        }

        fn wake_by_ref(self: &Arc<Self>) {
            self.0.unpark();
        }
    }

    let waker = Waker::from(Arc::new(ThreadUnparker(thread::current())));
    let mut context = Context::from_waker(&waker);
    let mut future = Box::pin(future);
    loop {
        match Pin::as_mut(&mut future).poll(&mut context) {
            Poll::Ready(value) => return value,
            // An unpark issued before we park (e.g. a wake from inside `poll`)
            // is remembered by the thread's park token, so no wakeup is lost.
            Poll::Pending => thread::park(),
        }
    }
}
/// A [`Spawn`] implementation that runs every future on its own OS thread.
///
/// The type carries no state, so every value is interchangeable.
#[derive(Debug, Default, Clone)]
pub struct DetachedTasks;

impl DetachedTasks {
    /// Returns a handle to the detached-thread spawner.
    pub fn current() -> Self {
        Self::default()
    }
}
impl Spawn for DetachedTasks {
    /// Runs `future` to completion on a freshly spawned OS thread.
    ///
    /// The join handle is discarded, so the task is detached: the caller is not
    /// notified when it finishes, and a panic inside it is not propagated.
    ///
    /// # Panics
    ///
    /// Panics if the operating system fails to create the thread (the behavior
    /// of [`std::thread::spawn`]).
    fn spawn<F>(&self, future: F)
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let task = move || block_on(future);
        let _detached = thread::spawn(task);
    }
}
#[derive(Debug)]
struct FanoutState<T> {
buffer: VecDeque<(u64, T)>,
next_seq: u64,
receiver_count: usize,
}
#[derive(Debug)]
struct FanoutShared<T> {
capacity: usize,
state: Mutex<FanoutState<T>>,
condvar: Condvar,
}
#[derive(Debug, Clone)]
pub struct FanoutChannel<T: Clone + Send + 'static> {
shared: Arc<FanoutShared<T>>,
}
impl<T: Clone + Send + 'static> FanoutChannel<T> {
pub fn new(capacity: usize) -> Self {
let capacity = capacity.max(1);
Self {
shared: Arc::new(FanoutShared {
capacity,
state: Mutex::new(FanoutState {
buffer: VecDeque::new(),
next_seq: 0,
receiver_count: 0,
}),
condvar: Condvar::new(),
}),
}
}
}
impl<T: Clone + Send + 'static> EventChannel<T> for FanoutChannel<T> {
type Receiver = FanoutReceiver<T>;
fn send(&self, value: T) -> Result<(), EventChannelError> {
let mut state = self.shared.state.lock().expect("fanout channel poisoned");
let seq = state.next_seq;
state.next_seq += 1;
state.buffer.push_back((seq, value));
while state.buffer.len() > self.shared.capacity {
state.buffer.pop_front();
}
self.shared.condvar.notify_all();
Ok(())
}
fn subscribe(&self) -> Self::Receiver {
let mut state = self.shared.state.lock().expect("fanout channel poisoned");
state.receiver_count += 1;
let next_seq = state.next_seq;
drop(state);
FanoutReceiver {
shared: self.shared.clone(),
next_seq,
}
}
fn subscriber_count(&self) -> usize {
self.shared
.state
.lock()
.expect("fanout channel poisoned")
.receiver_count
}
}
pub struct FanoutReceiver<T: Clone + Send + 'static> {
shared: Arc<FanoutShared<T>>,
next_seq: u64,
}
impl<T: Clone + Send + 'static> Drop for FanoutReceiver<T> {
fn drop(&mut self) {
let mut state = self.shared.state.lock().expect("fanout channel poisoned");
state.receiver_count = state.receiver_count.saturating_sub(1);
}
}
impl<T: Clone + Send + 'static> EventReceiver<T> for FanoutReceiver<T> {
async fn recv(&mut self) -> Result<T, EventChannelError> {
loop {
let mut state = self.shared.state.lock().expect("fanout channel poisoned");
if let Some((oldest_seq, _)) = state.buffer.front() {
if self.next_seq < *oldest_seq {
let lagged = *oldest_seq - self.next_seq;
self.next_seq = *oldest_seq;
return Err(EventChannelError::Lagged(lagged));
}
}
if let Some((_, value)) = state
.buffer
.iter()
.find(|(seq, _)| *seq == self.next_seq)
.cloned()
{
self.next_seq += 1;
return Ok(value);
}
state = self
.shared
.condvar
.wait(state)
.expect("fanout channel poisoned");
drop(state);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::time::Duration;

    #[test]
    fn fanout_channel_send_receive_smoke() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(16);
        let mut rx = bus.subscribe();
        bus.send(42).expect("send");
        let value = block_on(rx.recv()).expect("recv");
        assert_eq!(value, 42);
    }

    #[test]
    fn fanout_channel_send_with_no_subscriber_is_ok() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(16);
        bus.send(1).expect("send must be ok");
        bus.send(2).expect("send must be ok");
        assert_eq!(bus.subscriber_count(), 0);
    }

    #[test]
    fn fanout_channel_lag_returns_error() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(2);
        let mut rx = bus.subscribe();
        for value in 0..5 {
            bus.send(value).expect("send");
        }
        // Only the two newest values (3 and 4) are retained, so 0, 1 and 2 were missed.
        assert_eq!(block_on(rx.recv()), Err(EventChannelError::Lagged(3)));
        // After reporting the lag the receiver resumes at the oldest retained value.
        assert_eq!(block_on(rx.recv()), Ok(3));
        assert_eq!(block_on(rx.recv()), Ok(4));
    }

    #[test]
    fn fanout_channel_delivers_to_every_subscriber() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(4);
        let mut first = bus.subscribe();
        let mut second = bus.subscribe();
        assert_eq!(bus.subscriber_count(), 2);
        bus.send(9).expect("send");
        assert_eq!(block_on(first.recv()), Ok(9));
        assert_eq!(block_on(second.recv()), Ok(9));
        drop(first);
        assert_eq!(bus.subscriber_count(), 1);
    }

    #[test]
    fn fanout_channel_late_subscriber_skips_earlier_values() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(4);
        bus.send(1).expect("send");
        let mut rx = bus.subscribe();
        bus.send(2).expect("send");
        assert_eq!(block_on(rx.recv()), Ok(2));
    }

    #[test]
    fn fanout_channel_recv_waits_for_later_send() {
        let bus: FanoutChannel<u64> = FanoutChannel::new(4);
        let mut rx = bus.subscribe();
        let sender = bus.clone();
        let handle = std::thread::spawn(move || {
            // Give the receiver a chance to start waiting before the value arrives.
            std::thread::sleep(Duration::from_millis(20));
            sender.send(5).expect("send");
        });
        assert_eq!(block_on(rx.recv()), Ok(5));
        handle.join().expect("sender thread");
    }

    #[test]
    fn detached_tasks_runs_future() {
        let runtime = DetachedTasks::current();
        let (tx, rx) = mpsc::channel();
        runtime.spawn(async move {
            let _ = tx.send(7u64);
        });
        let value = rx.recv().expect("oneshot");
        assert_eq!(value, 7);
    }
}