use std::io;
use std::sync::mpsc::{SyncSender, sync_channel};
use std::thread::{self, JoinHandle};
use crate::error::Error;
/// Default channel capacity for a [`Pipeline`]: up to this many items can be
/// queued ahead of the consumer before `send` blocks.
pub const DEFAULT_PIPELINE_DEPTH: usize = 4;

/// Minimal buffering depth: used as a pipeline `depth`, at most one item is
/// queued, so the producer can run at most one item ahead of the consumer.
// `allow`: this constant may be unreferenced in some builds of the crate.
#[allow(dead_code)]
pub const WRITE_THROUGH_DEPTH: usize = 1;
/// Signal returned by [`Sink::apply`] telling the pipeline whether to keep
/// delivering items to the sink.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Flow {
    /// Keep delivering items to the sink.
    Continue,
    /// Stop delivering items: any remaining items are received and dropped
    /// without reaching the sink, and `Sink::close` is still called.
    // `allow`: the variant may be unconstructed outside tests in some builds.
    #[allow(dead_code)]
    Stop,
}
/// Consumer half of a [`Pipeline`]: receives items one at a time on the
/// consumer thread and produces a final output once the stream ends.
pub trait Sink<I>
where
    Self: Send + 'static,
{
    /// Value produced by `close` and handed back from `Pipeline::finish`.
    type Output: Send + 'static;

    /// Processes one item.
    ///
    /// Returning `Flow::Continue` asks for the next item. Returning `Flow::Stop`
    /// or `Err` makes the pipeline discard every remaining item without calling
    /// `apply` again.
    fn apply(&mut self, item: I) -> Result<Flow, Error>;

    /// Consumes the sink after the channel has closed and yields its output.
    ///
    /// The pipeline only calls this when no `apply` call returned `Err`.
    fn close(self) -> Result<Self::Output, Error>;
}
/// Handle to a background consumer thread that feeds items into a [`Sink`]
/// through a bounded channel.
///
/// Create one with `Pipeline::spawn` or `Pipeline::spawn_named`, push items with
/// `send` / `try_send`, and call `finish` to close the channel and collect the
/// sink's output.
///
/// Dropping a `Pipeline` without calling `finish` closes the channel but
/// detaches the consumer thread, discarding its result.
// Bounds live on the `impl` block that needs them (spawning a thread), not on
// the type itself.
pub struct Pipeline<I, R> {
    /// Producer side of the bounded channel; dropping it ends the consumer loop.
    tx: SyncSender<I>,
    /// Consumer thread, yielding either the first `apply` error or the output of
    /// `Sink::close`.
    handle: JoinHandle<Result<R, Error>>,
}
impl<I: Send + 'static, R: Send + 'static> Pipeline<I, R> {
#[allow(dead_code)]
pub fn spawn<S: Sink<I, Output = R>>(depth: usize, sink: S) -> Result<Self, Error> {
Self::spawn_named("freemkv-pipeline-consumer", depth, sink)
}
pub fn spawn_named<S: Sink<I, Output = R>>(
name: &str,
depth: usize,
sink: S,
) -> Result<Self, Error> {
let (tx, rx) = sync_channel::<I>(depth);
let handle = thread::Builder::new()
.name(name.into())
.spawn(move || -> Result<R, Error> {
let mut sink = sink;
let mut first_err: Option<Error> = None;
let mut stopped = false;
while let Ok(item) = rx.recv() {
if first_err.is_some() || stopped {
continue;
}
match sink.apply(item) {
Ok(Flow::Continue) => {}
Ok(Flow::Stop) => {
stopped = true;
}
Err(e) => {
first_err = Some(e);
}
}
}
match first_err {
Some(e) => Err(e),
None => sink.close(),
}
})
.map_err(|e| Error::IoError { source: e })?;
Ok(Pipeline { tx, handle })
}
pub fn send(&self, item: I) -> Result<(), I> {
self.tx.send(item).map_err(|e| e.0)
}
pub fn try_send(&self, item: I) -> Result<(), std::sync::mpsc::TrySendError<I>> {
self.tx.try_send(item)
}
pub fn finish(self) -> Result<R, Error> {
let Pipeline { tx, handle } = self;
drop(tx);
match handle.join() {
Ok(result) => result,
Err(payload) => {
let msg = payload
.downcast_ref::<&'static str>()
.copied()
.or_else(|| payload.downcast_ref::<String>().map(|s| s.as_str()))
.unwrap_or("(no message)");
Err(Error::IoError {
source: io::Error::other(format!("pipeline consumer panicked: {msg}")),
})
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};

    /// Sums every item it receives; `close` yields the running total.
    struct SumSink {
        total: u64,
    }

    impl Sink<u64> for SumSink {
        type Output = u64;

        fn apply(&mut self, item: u64) -> Result<Flow, Error> {
            self.total += item;
            Ok(Flow::Continue)
        }

        fn close(self) -> Result<u64, Error> {
            Ok(self.total)
        }
    }

    #[test]
    fn happy_path_sums_items() {
        let pipe = Pipeline::spawn(DEFAULT_PIPELINE_DEPTH, SumSink { total: 0 })
            .expect("spawn should succeed");
        for i in 0..100u64 {
            pipe.send(i).expect("send should succeed");
        }
        let total = pipe.finish().expect("finish should succeed");
        assert_eq!(total, (0..100u64).sum::<u64>());
    }

    /// Sleeps `delay` per item to simulate a consumer slower than the producer.
    struct SlowSink {
        delay: Duration,
        count: Arc<AtomicUsize>,
    }

    impl Sink<()> for SlowSink {
        type Output = usize;

        fn apply(&mut self, _item: ()) -> Result<Flow, Error> {
            std::thread::sleep(self.delay);
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(Flow::Continue)
        }

        fn close(self) -> Result<usize, Error> {
            Ok(self.count.load(Ordering::SeqCst))
        }
    }

    #[test]
    fn back_pressure_blocks_sender() {
        let count = Arc::new(AtomicUsize::new(0));
        let sink = SlowSink {
            delay: Duration::from_millis(50),
            count: count.clone(),
        };
        let pipe = Pipeline::spawn(2, sink).expect("spawn should succeed");
        let start = Instant::now();
        for _ in 0..5 {
            pipe.send(()).expect("send should succeed");
        }
        let elapsed_send = start.elapsed();
        let total = pipe.finish().expect("finish should succeed");
        assert_eq!(total, 5);
        // Only a lower bound is checked: `sleep` never returns early, so this
        // cannot flake on a slow machine.
        assert!(
            elapsed_send >= Duration::from_millis(80),
            "back-pressure not observed: 5 sends with depth=2 and 50ms/apply \
             took {elapsed_send:?}, expected ≥ ~100ms (one or more sends \
             should have blocked behind the consumer)"
        );
    }

    /// Fails on the `n`-th applied item (1-based) and records calls to `close`.
    struct FailOnNthSink {
        n: usize,
        seen: Arc<AtomicUsize>,
        close_called: Arc<AtomicUsize>,
    }

    impl Sink<u64> for FailOnNthSink {
        type Output = ();

        fn apply(&mut self, _item: u64) -> Result<Flow, Error> {
            let i = self.seen.fetch_add(1, Ordering::SeqCst) + 1;
            if i == self.n {
                Err(Error::DecryptFailed)
            } else {
                Ok(Flow::Continue)
            }
        }

        fn close(self) -> Result<(), Error> {
            self.close_called.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[test]
    fn apply_error_drains_then_propagates() {
        let seen = Arc::new(AtomicUsize::new(0));
        let close_called = Arc::new(AtomicUsize::new(0));
        let pipe = Pipeline::spawn(
            DEFAULT_PIPELINE_DEPTH,
            FailOnNthSink {
                n: 3,
                seen: seen.clone(),
                close_called: close_called.clone(),
            },
        )
        .expect("spawn should succeed");
        for i in 0..10u64 {
            pipe.send(i).expect("send should succeed even after error");
        }
        let res = pipe.finish();
        assert!(matches!(res, Err(Error::DecryptFailed)));
        assert_eq!(
            close_called.load(Ordering::SeqCst),
            0,
            "close() must not be called when apply returned Err"
        );
        // Items after the failing one are drained, never applied.
        assert_eq!(seen.load(Ordering::SeqCst), 3);
    }

    /// Returns `Flow::Stop` from the `n`-th applied item onward; `close`
    /// reports how many items reached `apply`.
    struct StopOnNthSink {
        n: usize,
        seen: Arc<AtomicUsize>,
        close_called: Arc<AtomicUsize>,
    }

    impl Sink<u64> for StopOnNthSink {
        type Output = usize;

        fn apply(&mut self, _item: u64) -> Result<Flow, Error> {
            let i = self.seen.fetch_add(1, Ordering::SeqCst) + 1;
            if i >= self.n {
                Ok(Flow::Stop)
            } else {
                Ok(Flow::Continue)
            }
        }

        fn close(self) -> Result<usize, Error> {
            self.close_called.fetch_add(1, Ordering::SeqCst);
            Ok(self.seen.load(Ordering::SeqCst))
        }
    }

    #[test]
    fn apply_stop_calls_close_and_returns_output() {
        let seen = Arc::new(AtomicUsize::new(0));
        let close_called = Arc::new(AtomicUsize::new(0));
        let pipe = Pipeline::spawn(
            DEFAULT_PIPELINE_DEPTH,
            StopOnNthSink {
                n: 3,
                seen: seen.clone(),
                close_called: close_called.clone(),
            },
        )
        .expect("spawn should succeed");
        for i in 0..10u64 {
            // Whether sends after Stop succeed is not what this test checks.
            let _ = pipe.send(i);
        }
        let out = pipe.finish().expect("finish should succeed after Stop");
        assert_eq!(close_called.load(Ordering::SeqCst), 1);
        // The consumer applies items sequentially, so Stop returned by the 3rd
        // apply must prevent every later apply. A looser `>= 3` check would also
        // pass if Stop were ignored entirely (10 applies).
        assert_eq!(out, 3, "expected exactly 3 applies before Stop took effect, got {out}");
        assert_eq!(seen.load(Ordering::SeqCst), 3);
    }

    /// Panics on the first item, to exercise `finish`'s panic-to-error path.
    struct PanickingSink;

    impl Sink<u64> for PanickingSink {
        type Output = ();

        fn apply(&mut self, _item: u64) -> Result<Flow, Error> {
            panic!("synthetic test panic");
        }

        fn close(self) -> Result<(), Error> {
            Ok(())
        }
    }

    #[test]
    fn consumer_panic_becomes_io_error() {
        // Silence the default hook so the intentional panic does not clutter test
        // output. The hook is process-global, so it is restored as soon as the
        // consumer has been joined.
        let prev = std::panic::take_hook();
        std::panic::set_hook(Box::new(|_| {}));
        let pipe =
            Pipeline::spawn(DEFAULT_PIPELINE_DEPTH, PanickingSink).expect("spawn should succeed");
        for i in 0..5u64 {
            // Sends after the panic fail because the receiver is gone; ignore them.
            let _ = pipe.send(i);
        }
        let res = pipe.finish();
        std::panic::set_hook(prev);
        match res {
            Err(Error::IoError { source }) => {
                let msg = source.to_string();
                assert!(
                    msg.contains("pipeline consumer panicked"),
                    "expected constant panic prefix, got: {msg}"
                );
                assert!(
                    msg.contains("synthetic test panic"),
                    "expected original panic payload, got: {msg}"
                );
            }
            other => panic!("expected Err(IoError), got {other:?}"),
        }
    }
}