//! Latest-wins request coalescing on a dedicated worker thread.
//!
//! A [`Coalescer`] forwards requests to a background thread that runs a
//! [`Worker`]. Every request is tagged with an increasing generation number.
//! When several requests are already queued by the time the worker drains its
//! channel, only the one with the highest generation is handled and the rest
//! are dropped. Results come back as [`Output`] values, which can be collected
//! without blocking via [`Coalescer::poll`].
#![deny(missing_docs)]
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
/// A unit of work executed on the coalescer's background thread.
///
/// The implementor is moved onto that thread, so it must be `Send + 'static`.
/// Requests and responses cross the thread boundary through channels and
/// carry the same bounds.
pub trait Worker: Send + 'static {
    /// Input handed to [`Worker::handle`].
    type Request: Send + 'static;
    /// Value produced by [`Worker::handle`] and sent back to the caller.
    type Response: Send + 'static;
    /// Processes one request and returns its response.
    ///
    /// Called on the background thread, one request at a time. A request is
    /// never passed here if a newer one was already queued when the worker
    /// drained its channel.
    fn handle(&mut self, req: Self::Request) -> Self::Response;
}
/// A response produced by the worker, tagged with the generation of the
/// request that produced it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Output<T> {
    /// Generation number that [`Coalescer::submit`] assigned to the request.
    pub generation: u64,
    /// The worker's response for that request.
    pub value: T,
}
/// Messages sent from a [`Coalescer`] to its worker thread.
enum Msg<R> {
    /// Ends the worker loop. A request drained in the same batch is not run.
    Shutdown,
    /// Runs `request`, which the submitter tagged with `generation`.
    Run { request: R, generation: u64 },
}
/// Handle to a background worker thread that coalesces requests.
///
/// Each call to [`Coalescer::submit`] bumps a generation counter and queues
/// the request. Requests that are already queued together when the worker
/// drains its channel are collapsed, so only the newest one runs. Responses
/// are collected with [`Coalescer::poll`]. Dropping the handle sends the
/// worker thread a shutdown message.
pub struct Coalescer<W: Worker> {
    /// Request channel to the worker thread.
    tx: Sender<Msg<W::Request>>,
    /// Response channel from the worker thread.
    rx: Receiver<Output<W::Response>>,
    /// Generation of the most recently submitted request (0 before the first).
    generation: u64,
    /// Join handle of the worker thread.
    _thread: Option<thread::JoinHandle<()>>,
}
impl<W: Worker> Coalescer<W> {
    /// Starts a coalescer whose worker thread is named `coalesce-worker`.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread cannot be spawned.
    pub fn new(worker: W) -> Self {
        Self::spawn_named("coalesce-worker", worker)
    }

    /// Starts a coalescer whose worker thread is named `name`.
    ///
    /// `worker` is moved onto the new thread. It runs there until the
    /// coalescer shuts it down, the channels close, or `handle` panics.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread cannot be spawned.
    pub fn spawn_named(name: &str, worker: W) -> Self {
        let (req_tx, req_rx) = mpsc::channel::<Msg<W::Request>>();
        let (res_tx, res_rx) = mpsc::channel::<Output<W::Response>>();
        let thread = thread::Builder::new()
            .name(name.to_owned())
            .spawn(move || worker_loop(worker, req_rx, res_tx))
            .expect("failed to spawn coalescer worker thread");
        Self {
            tx: req_tx,
            rx: res_rx,
            generation: 0,
            _thread: Some(thread),
        }
    }

    /// Queues `request` for the worker and returns the generation assigned to it.
    ///
    /// Generations start at 1 and increase by one per call. If a newer request
    /// is already queued when the worker drains its channel, this one is
    /// skipped. If the worker thread has already exited, the request is
    /// silently discarded, but the returned generation is still consumed.
    pub fn submit(&mut self, request: W::Request) -> u64 {
        self.generation += 1;
        // Best-effort: a send error only means the worker thread is gone.
        let _ = self.tx.send(Msg::Run {
            generation: self.generation,
            request,
        });
        self.generation
    }

    /// Drains every response that has arrived and returns the one with the
    /// highest generation, or `None` if nothing is pending. Never blocks.
    ///
    /// Older responses drained in the same call are discarded.
    pub fn poll(&mut self) -> Option<Output<W::Response>> {
        // The strict `>` keeps the earlier response when generations tie.
        self.rx.try_iter().reduce(|best, out| {
            if out.generation > best.generation {
                out
            } else {
                best
            }
        })
    }

    /// Discards every response that has arrived but not yet been polled.
    ///
    /// A response the worker has not sent yet is not affected and will show up
    /// in a later [`Coalescer::poll`].
    pub fn flush_pending(&mut self) {
        self.rx.try_iter().for_each(drop);
    }

    /// Returns the generation of the most recently submitted request
    /// (0 if nothing has been submitted yet).
    pub fn current_generation(&self) -> u64 {
        self.generation
    }
}
impl<W: Worker> Drop for Coalescer<W> {
fn drop(&mut self) {
let _ = self.tx.send(Msg::Shutdown);
}
}
/// Body of the background thread spawned by `Coalescer::spawn_named`.
///
/// Each round blocks for one message, then drains whatever else is already
/// queued, keeping only the `Run` with the highest generation. That single
/// request is handed to `worker` and its result is sent back. The loop ends
/// when:
/// - a `Shutdown` appears in a batch (the batch's request is not run);
/// - the request channel closes;
/// - the response receiver has been dropped.
fn worker_loop<W: Worker>(
    mut worker: W,
    req_rx: Receiver<Msg<W::Request>>,
    res_tx: Sender<Output<W::Response>>,
) {
    /// Folds one message into the batch being collected.
    ///
    /// A `Run` replaces the held request only if its generation is strictly
    /// greater; on a tie the earlier request is kept. Returns `true` for
    /// `Shutdown`.
    fn absorb<R>(msg: Msg<R>, newest: &mut Option<(u64, R)>) -> bool {
        match msg {
            Msg::Shutdown => true,
            Msg::Run {
                generation,
                request,
            } => {
                let superseded =
                    matches!(newest.as_ref(), Some((seen, _)) if *seen >= generation);
                if !superseded {
                    *newest = Some((generation, request));
                }
                false
            }
        }
    }

    // A failed blocking receive means every sender is gone: stop.
    while let Ok(head) = req_rx.recv() {
        let mut newest: Option<(u64, W::Request)> = None;
        let mut stop = absorb(head, &mut newest);

        // Pull in everything already queued, without blocking.
        loop {
            match req_rx.try_recv() {
                Ok(msg) => stop |= absorb(msg, &mut newest),
                Err(mpsc::TryRecvError::Empty) => break,
                Err(mpsc::TryRecvError::Disconnected) => return,
            }
        }

        if stop {
            return;
        }

        if let Some((generation, request)) = newest {
            let value = worker.handle(request);
            // The response receiver is gone, so nobody will read further results.
            if res_tx.send(Output { generation, value }).is_err() {
                return;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};

    /// Echoes its request after a 10 ms delay, counting how often it runs.
    struct CountingWorker {
        calls: Arc<AtomicUsize>,
    }

    impl Worker for CountingWorker {
        type Request = u32;
        type Response = u32;
        fn handle(&mut self, req: u32) -> u32 {
            self.calls.fetch_add(1, Ordering::SeqCst);
            thread::sleep(Duration::from_millis(10));
            req
        }
    }

    /// Re-checks `cond` about once per millisecond until it returns `true` or
    /// `timeout` elapses. Returns whether `cond` ever succeeded.
    fn wait_until<F: FnMut() -> bool>(mut cond: F, timeout: Duration) -> bool {
        let start = Instant::now();
        while start.elapsed() < timeout {
            if cond() {
                return true;
            }
            thread::sleep(Duration::from_millis(1));
        }
        false
    }

    #[test]
    fn submit_and_poll_roundtrip() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mut c = Coalescer::new(CountingWorker {
            calls: Arc::clone(&calls),
        });
        let generation = c.submit(42);
        assert_eq!(generation, 1);
        let mut got = None;
        assert!(
            wait_until(
                || {
                    got = c.poll();
                    got.is_some()
                },
                Duration::from_secs(1),
            ),
            "timed out waiting for response",
        );
        let out = got.unwrap();
        assert_eq!(out.generation, 1);
        assert_eq!(out.value, 42);
    }

    #[test]
    fn poll_returns_newest_when_multiple_pending() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mut c = Coalescer::new(CountingWorker {
            calls: Arc::clone(&calls),
        });
        for i in 0..5 {
            c.submit(i);
        }
        // Let every response accumulate so a single poll has to pick among them.
        thread::sleep(Duration::from_millis(100));
        let out = c.poll().expect("should receive at least one response");
        assert!(out.generation <= 5);
        assert!(c.poll().is_none());
    }

    #[test]
    fn coalescing_drops_intermediate_requests() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mut c = Coalescer::new(CountingWorker {
            calls: Arc::clone(&calls),
        });
        for i in 0..100 {
            c.submit(i);
        }
        let mut max_gen = 0;
        let _ = wait_until(
            || {
                if let Some(out) = c.poll() {
                    max_gen = max_gen.max(out.generation);
                }
                max_gen == 100
            },
            Duration::from_secs(3),
        );
        assert_eq!(max_gen, 100, "final request should eventually complete");
        let total_calls = calls.load(Ordering::SeqCst);
        assert!(
            total_calls < 100,
            "expected coalescing to drop work, got {total_calls} calls"
        );
    }

    #[test]
    fn flush_pending_drops_unread_responses() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mut c = Coalescer::new(CountingWorker {
            calls: Arc::clone(&calls),
        });
        c.submit(1);
        assert!(
            wait_until(|| calls.load(Ordering::SeqCst) >= 1, Duration::from_secs(1)),
            "worker never picked up the request",
        );
        // `handle` sleeps 10 ms after counting the call; give it time to send.
        thread::sleep(Duration::from_millis(20));
        c.flush_pending();
        assert!(
            c.poll().is_none(),
            "flush should have dropped the pending response"
        );
    }

    #[test]
    fn generation_monotonic() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mut c = Coalescer::new(CountingWorker {
            calls: Arc::clone(&calls),
        });
        assert_eq!(c.current_generation(), 0);
        for i in 1..=5 {
            let g = c.submit(i);
            assert_eq!(g, i as u64);
        }
        assert_eq!(c.current_generation(), 5);
    }

    #[test]
    fn drop_shuts_down_cleanly() {
        let calls = Arc::new(AtomicUsize::new(0));
        let c = Coalescer::new(CountingWorker {
            calls: Arc::clone(&calls),
        });
        // One reference here, one owned by the worker on its thread.
        assert_eq!(Arc::strong_count(&calls), 2);
        drop(c);
        // The worker's reference is released only once the worker loop returns.
        assert!(
            wait_until(|| Arc::strong_count(&calls) == 1, Duration::from_secs(1)),
            "worker thread did not shut down after drop",
        );
    }
}