use std::io;
use std::sync::mpsc::{SyncSender, TrySendError, sync_channel};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use crate::error::Error;
use crate::halt::Halt;
/// Maximum time, in seconds, that `Pipeline::finish_with_halt` waits for the
/// consumer thread before giving up (ten minutes).
pub const JOIN_TIMEOUT_SECS: u64 = 10 * 60;
/// How often `Pipeline::finish_with_halt` re-checks the consumer thread and
/// the halt flag while waiting.
const POLL_INTERVAL: Duration = Duration::from_millis(250);
/// Back-off between non-blocking send attempts in `Pipeline::send_with_halt`
/// while the channel is full.
const SEND_POLL_INTERVAL: Duration = Duration::from_millis(50);
/// Reports whether verbose pipeline tracing is switched on.
///
/// Reads the `FREEMKV_DEBUG` environment variable on every call; only the
/// exact values `1`, `true` and `yes` enable it. A missing or non-Unicode
/// value counts as disabled.
pub fn debug_enabled() -> bool {
    match std::env::var("FREEMKV_DEBUG") {
        Ok(value) => matches!(value.as_str(), "1" | "true" | "yes"),
        Err(_) => false,
    }
}
/// General-purpose `depth` (bounded-channel capacity) for `Pipeline::spawn`.
pub const DEFAULT_PIPELINE_DEPTH: usize = 4;
/// Deeper queue, presumably for read-side pipelines (TODO confirm against callers).
pub const READ_PIPELINE_DEPTH: usize = 32;
/// Queue depth presumably meant for write-side pipelines (TODO confirm against callers).
pub const WRITE_PIPELINE_DEPTH: usize = 16;
/// Single-slot queue: at most one item can wait while the consumer is busy.
// Not referenced in this file; the allow keeps builds without a caller warning-free.
#[allow(dead_code)]
pub const WRITE_THROUGH_DEPTH: usize = 1;
/// Verdict returned by `Sink::apply` telling the pipeline whether to keep
/// feeding items to the sink.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Flow {
    /// Keep delivering items to the sink.
    Continue,
    /// Stop calling `apply`. Items still arriving are received and dropped,
    /// and the sink is closed normally once the channel is closed.
    #[allow(dead_code)]
    Stop,
}
/// Consumer-side stage driven by a `Pipeline` on its own thread.
///
/// The pipeline calls [`Sink::apply`] for each item in the order it is received.
/// Once `apply` returns an error or `Flow::Stop`, it is not called again.
/// [`Sink::close`] is invoked a single time after the channel closes, and only if no
/// `apply` call returned an error.
pub trait Sink<I>
where
    Self: Send + 'static,
{
    /// Final value produced by [`Sink::close`] and handed back by the pipeline.
    type Output: Send + 'static;

    /// Processes one item. `Ok(Flow::Stop)` ends delivery without an error;
    /// `Err` aborts delivery and is reported instead of calling `close`.
    fn apply(&mut self, item: I) -> Result<Flow, Error>;

    /// Consumes the sink after the last delivered item and yields its output.
    fn close(self) -> Result<Self::Output, Error>;
}
/// Handle to a bounded single-consumer pipeline.
///
/// Items pushed through the handle travel over a `sync_channel` to a
/// dedicated thread that feeds them to a [`Sink`]. Consuming the handle with
/// `finish`/`finish_with_halt` closes the channel and collects the result.
pub struct Pipeline<I, R>
where
    I: Send + 'static,
    R: Send + 'static,
{
    /// Producer end of the bounded channel; dropping it ends the consumer loop.
    tx: SyncSender<I>,
    /// Consumer thread; yields the sink's output or its first error.
    handle: JoinHandle<Result<R, Error>>,
}
impl<I: Send + 'static, R: Send + 'static> Pipeline<I, R> {
#[allow(dead_code)]
pub fn spawn<S: Sink<I, Output = R>>(depth: usize, sink: S) -> Result<Self, Error> {
Self::spawn_named("freemkv-pipeline-consumer", depth, sink)
}
pub fn spawn_named<S: Sink<I, Output = R>>(
name: &str,
depth: usize,
sink: S,
) -> Result<Self, Error> {
let (tx, rx) = sync_channel::<I>(depth);
let handle = thread::Builder::new()
.name(name.into())
.spawn(move || -> Result<R, Error> {
let mut sink = sink;
let mut first_err: Option<Error> = None;
let mut stopped = false;
while let Ok(item) = rx.recv() {
if debug_enabled() {
tracing::debug!("Pipeline receive: item={}", std::any::type_name::<I>());
}
let apply_start = std::time::Instant::now();
if first_err.is_some() || stopped {
continue;
}
match sink.apply(item) {
Ok(Flow::Continue) => {}
Ok(Flow::Stop) => {
stopped = true;
if debug_enabled() {
tracing::debug!("Pipeline: consumer returned Flow::Stop");
}
}
Err(e) => {
if debug_enabled() {
tracing::debug!("Pipeline: apply error, stopping, err={:?}", e);
}
first_err = Some(e);
}
}
let apply_elapsed = apply_start.elapsed();
if debug_enabled() && apply_elapsed > std::time::Duration::from_millis(100) {
tracing::debug!(
"Pipeline apply: took {:.2}s, item={}",
apply_elapsed.as_secs_f64(),
std::any::type_name::<I>()
);
} else if debug_enabled() {
tracing::debug!(
"Pipeline apply: OK in {:.3}ms, item={}",
apply_elapsed.as_micros(),
std::any::type_name::<I>()
);
}
}
match first_err {
Some(e) => Err(e),
None => sink.close(),
}
})
.map_err(|e| Error::IoError { source: e })?;
Ok(Pipeline { tx, handle })
}
pub fn send(&self, item: I) -> Result<(), I> {
let start = std::time::Instant::now();
match self.tx.send(item) {
Ok(()) => {
let elapsed = start.elapsed();
if debug_enabled() && elapsed > std::time::Duration::from_millis(10) {
tracing::debug!(
"Pipeline send: blocked {:.2}s, item={}",
elapsed.as_secs_f64(),
std::any::type_name::<I>()
);
} else if debug_enabled() {
tracing::debug!("Pipeline send: OK in {:.3}ms", elapsed.as_micros());
}
Ok(())
}
Err(e) => {
let elapsed = start.elapsed();
if debug_enabled() && elapsed > std::time::Duration::from_millis(10) {
tracing::debug!(
"Pipeline send: blocked {:.2}s before channel closed, item={}",
elapsed.as_secs_f64(),
std::any::type_name::<I>()
);
} else if debug_enabled() {
tracing::debug!("Pipeline send: failed after {:.3}ms", elapsed.as_micros());
}
Err(e.0)
}
}
}
pub fn try_send(&self, item: I) -> Result<(), std::sync::mpsc::TrySendError<I>> {
self.tx.try_send(item)
}
pub fn send_with_halt(&self, item: I, halt: &Halt, deadline: Duration) -> Result<(), I> {
let end = Instant::now() + deadline;
let mut pending = item;
loop {
match self.tx.try_send(pending) {
Ok(()) => return Ok(()),
Err(TrySendError::Full(returned)) => {
pending = returned;
if halt.is_cancelled() {
if debug_enabled() {
tracing::debug!(
"Pipeline send_with_halt: halt observed, returning item={}",
std::any::type_name::<I>()
);
}
return Err(pending);
}
if Instant::now() >= end {
if debug_enabled() {
tracing::debug!(
"Pipeline send_with_halt: deadline elapsed, returning item={}",
std::any::type_name::<I>()
);
}
return Err(pending);
}
thread::sleep(SEND_POLL_INTERVAL);
}
Err(TrySendError::Disconnected(returned)) => {
if debug_enabled() {
tracing::debug!(
"Pipeline send_with_halt: consumer disconnected, item={}",
std::any::type_name::<I>()
);
}
return Err(returned);
}
}
}
}
pub fn finish(self) -> Result<R, Error> {
let Pipeline { tx, handle } = self;
drop(tx);
match handle.join() {
Ok(result) => result,
Err(payload) => {
let msg = payload
.downcast_ref::<&'static str>()
.copied()
.or_else(|| payload.downcast_ref::<String>().map(|s| s.as_str()))
.unwrap_or("(no message)");
Err(Error::IoError {
source: io::Error::other(format!("pipeline consumer panicked: {msg}")),
})
}
}
}
pub fn finish_with_halt(self, halt: Option<&Halt>) -> Result<R, Error> {
let Pipeline { tx, handle } = self;
drop(tx);
let deadline = Instant::now() + Duration::from_secs(JOIN_TIMEOUT_SECS);
loop {
if handle.is_finished() {
return match handle.join() {
Ok(result) => result,
Err(payload) => {
let msg = payload
.downcast_ref::<&'static str>()
.copied()
.or_else(|| payload.downcast_ref::<String>().map(|s| s.as_str()))
.unwrap_or("(no message)");
Err(Error::IoError {
source: io::Error::other(format!("pipeline consumer panicked: {msg}")),
})
}
};
}
if let Some(h) = halt {
if h.is_cancelled() {
return Err(Error::IoError {
source: io::Error::other("pipeline join halted"),
});
}
}
if Instant::now() >= deadline {
return Err(Error::IoError {
source: io::Error::other("pipeline join timed out"),
});
}
thread::sleep(POLL_INTERVAL);
}
}
}
// Unit tests for `Pipeline`. Several of them assert on wall-clock time
// (sleeps plus elapsed-time bounds), so their thresholds are deliberately loose.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};

    // Sink that adds every item to `total` and returns the sum from `close`.
    struct SumSink {
        total: u64,
    }
    impl Sink<u64> for SumSink {
        type Output = u64;
        fn apply(&mut self, item: u64) -> Result<Flow, Error> {
            self.total += item;
            Ok(Flow::Continue)
        }
        fn close(self) -> Result<u64, Error> {
            Ok(self.total)
        }
    }

    // Every sent item reaches the sink, and `finish` returns `close`'s output.
    #[test]
    fn happy_path_sums_items() {
        let pipe = Pipeline::spawn(DEFAULT_PIPELINE_DEPTH, SumSink { total: 0 })
            .expect("spawn should succeed");
        let mut expected = 0u64;
        for i in 0..100u64 {
            expected += i;
            pipe.send(i).expect("send should succeed");
        }
        let total = pipe.finish().expect("finish should succeed");
        assert_eq!(total, expected);
        assert_eq!(total, (0..100u64).sum::<u64>());
    }

    // Sink that sleeps `delay` per item and counts applies in a shared counter.
    struct SlowSink {
        delay: Duration,
        count: Arc<AtomicUsize>,
    }
    impl Sink<()> for SlowSink {
        type Output = usize;
        fn apply(&mut self, _item: ()) -> Result<Flow, Error> {
            std::thread::sleep(self.delay);
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(Flow::Continue)
        }
        fn close(self) -> Result<usize, Error> {
            Ok(self.count.load(Ordering::SeqCst))
        }
    }

    // With a depth-2 channel and a 50 ms consumer, five sends cannot all be
    // queued at once, so the later ones must wait for the consumer to free a
    // slot. The 80 ms bound leaves scheduling slack below the ~100 ms ideal.
    #[test]
    fn back_pressure_blocks_sender() {
        let count = Arc::new(AtomicUsize::new(0));
        let sink = SlowSink {
            delay: Duration::from_millis(50),
            count: count.clone(),
        };
        let pipe = Pipeline::spawn(2, sink).expect("spawn should succeed");
        let start = Instant::now();
        for _ in 0..5 {
            pipe.send(()).expect("send should succeed");
        }
        let elapsed_send = start.elapsed();
        let total = pipe.finish().expect("finish should succeed");
        assert_eq!(total, 5);
        assert!(
            elapsed_send >= Duration::from_millis(80),
            "back-pressure not observed: 5 sends with depth=2 and 50ms/apply \
             took {elapsed_send:?}, expected ≥ ~100ms (one or more sends \
             should have blocked behind the consumer)"
        );
    }

    // Sink whose n-th apply (1-based) fails with `DecryptFailed`; it counts
    // applies in `seen` and `close` calls in `close_called`.
    struct FailOnNthSink {
        n: usize,
        seen: Arc<AtomicUsize>,
        close_called: Arc<AtomicUsize>,
    }
    impl Sink<u64> for FailOnNthSink {
        type Output = ();
        fn apply(&mut self, _item: u64) -> Result<Flow, Error> {
            let i = self.seen.fetch_add(1, Ordering::SeqCst) + 1;
            if i == self.n {
                Err(Error::DecryptFailed)
            } else {
                Ok(Flow::Continue)
            }
        }
        fn close(self) -> Result<(), Error> {
            self.close_called.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    // After an apply error the consumer keeps draining, so sends still
    // succeed. It does not call apply again (seen stays 3) and skips close;
    // `finish` surfaces the original error.
    #[test]
    fn apply_error_drains_then_propagates() {
        let seen = Arc::new(AtomicUsize::new(0));
        let close_called = Arc::new(AtomicUsize::new(0));
        let pipe = Pipeline::spawn(
            DEFAULT_PIPELINE_DEPTH,
            FailOnNthSink {
                n: 3,
                seen: seen.clone(),
                close_called: close_called.clone(),
            },
        )
        .expect("spawn should succeed");
        for i in 0..10u64 {
            pipe.send(i).expect("send should succeed even after error");
        }
        let res = pipe.finish();
        assert!(matches!(res, Err(Error::DecryptFailed)));
        assert_eq!(
            close_called.load(Ordering::SeqCst),
            0,
            "close() must not be called when apply returned Err"
        );
        assert_eq!(seen.load(Ordering::SeqCst), 3);
    }

    // Sink that returns `Flow::Stop` from its n-th apply onward; `close`
    // reports how many applies happened.
    struct StopOnNthSink {
        n: usize,
        seen: Arc<AtomicUsize>,
        close_called: Arc<AtomicUsize>,
    }
    impl Sink<u64> for StopOnNthSink {
        type Output = usize;
        fn apply(&mut self, _item: u64) -> Result<Flow, Error> {
            let i = self.seen.fetch_add(1, Ordering::SeqCst) + 1;
            if i >= self.n {
                Ok(Flow::Stop)
            } else {
                Ok(Flow::Continue)
            }
        }
        fn close(self) -> Result<usize, Error> {
            self.close_called.fetch_add(1, Ordering::SeqCst);
            Ok(self.seen.load(Ordering::SeqCst))
        }
    }

    // `Flow::Stop` is not an error: `close` runs exactly once and its output
    // is returned from `finish`.
    // NOTE(review): the consumer never calls apply after Stop, so `out`
    // should be exactly 3; the `>=` bound is looser than necessary.
    #[test]
    fn apply_stop_calls_close_and_returns_output() {
        let seen = Arc::new(AtomicUsize::new(0));
        let close_called = Arc::new(AtomicUsize::new(0));
        let pipe = Pipeline::spawn(
            DEFAULT_PIPELINE_DEPTH,
            StopOnNthSink {
                n: 3,
                seen: seen.clone(),
                close_called: close_called.clone(),
            },
        )
        .expect("spawn should succeed");
        for i in 0..10u64 {
            let _ = pipe.send(i);
        }
        let out = pipe.finish().expect("finish should succeed after Stop");
        assert_eq!(close_called.load(Ordering::SeqCst), 1);
        assert!(
            out >= 3,
            "expected ≥ 3 applies before Stop took effect, got {out}"
        );
    }

    // Sink that panics on its first apply.
    struct PanickingSink;
    impl Sink<u64> for PanickingSink {
        type Output = ();
        fn apply(&mut self, _item: u64) -> Result<Flow, Error> {
            panic!("synthetic test panic");
        }
        fn close(self) -> Result<(), Error> {
            Ok(())
        }
    }

    // A panic in the consumer thread becomes `Error::IoError` carrying both
    // the fixed prefix and the original panic message. The panic hook is
    // silenced to keep output clean and restored before asserting.
    // NOTE(review): the hook is process-global, so panics in concurrently
    // running tests are also silenced while it is installed.
    #[test]
    fn consumer_panic_becomes_io_error() {
        let prev = std::panic::take_hook();
        std::panic::set_hook(Box::new(|_| {}));
        let pipe =
            Pipeline::spawn(DEFAULT_PIPELINE_DEPTH, PanickingSink).expect("spawn should succeed");
        // Send results are ignored: once the consumer dies, sends may fail.
        let _ = pipe.send(1);
        for i in 0..5u64 {
            let _ = pipe.send(i);
        }
        let res = pipe.finish();
        std::panic::set_hook(prev);
        match res {
            Err(Error::IoError { source }) => {
                let msg = source.to_string();
                assert!(
                    msg.contains("pipeline consumer panicked"),
                    "expected constant panic prefix, got: {msg}"
                );
                assert!(
                    msg.contains("synthetic test panic"),
                    "expected original panic payload, got: {msg}"
                );
            }
            other => panic!("expected Err(IoError), got {other:?}"),
        }
    }

    // Sink that sets `started` and then blocks inside apply until `cancel`
    // is set, simulating a wedged consumer.
    struct NeverDrainsSink {
        cancel: Arc<std::sync::atomic::AtomicBool>,
        started: Arc<std::sync::atomic::AtomicBool>,
    }
    impl Sink<u64> for NeverDrainsSink {
        type Output = ();
        fn apply(&mut self, _item: u64) -> Result<Flow, Error> {
            self.started.store(true, Ordering::SeqCst);
            while !self.cancel.load(Ordering::SeqCst) {
                std::thread::sleep(Duration::from_millis(20));
            }
            Ok(Flow::Continue)
        }
        fn close(self) -> Result<(), Error> {
            Ok(())
        }
    }

    // Polls until the consumer has entered apply, failing the test after `bail`.
    fn wait_for_started(started: &Arc<std::sync::atomic::AtomicBool>, bail: Duration) {
        let end = Instant::now() + bail;
        while !started.load(Ordering::SeqCst) {
            assert!(Instant::now() < end, "consumer never started apply()");
            std::thread::sleep(Duration::from_millis(10));
        }
    }

    // Depth 1: item 0 wedges the consumer and item 1 fills the only slot, so
    // item 99 cannot be enqueued and must come back once the 200 ms deadline
    // passes. Setting `cancel` afterwards lets the consumer exit for `finish`.
    #[test]
    fn send_with_halt_returns_item_on_deadline() {
        let cancel = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let started = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let pipe = Pipeline::spawn(
            1,
            NeverDrainsSink {
                cancel: cancel.clone(),
                started: started.clone(),
            },
        )
        .expect("spawn should succeed");
        pipe.send(0u64).expect("first send hands off to consumer");
        wait_for_started(&started, Duration::from_secs(2));
        pipe.send(1u64).expect("second send fills the buffer");
        let halt = crate::halt::Halt::new();
        let start = Instant::now();
        let res = pipe.send_with_halt(99u64, &halt, Duration::from_millis(200));
        let elapsed = start.elapsed();
        cancel.store(true, Ordering::SeqCst);
        let _ = pipe.finish();
        assert!(matches!(res, Err(99)), "expected item returned on deadline");
        assert!(
            elapsed >= Duration::from_millis(150),
            "deadline returned too early: {elapsed:?}"
        );
        assert!(
            elapsed < Duration::from_secs(2),
            "deadline blew past tolerance: {elapsed:?}"
        );
    }

    // Same full-channel setup, but a helper thread cancels `halt` after
    // ~100 ms; the item must come back long before the 10 s deadline.
    #[test]
    fn send_with_halt_returns_item_on_halt() {
        let cancel = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let started = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let pipe = Pipeline::spawn(
            1,
            NeverDrainsSink {
                cancel: cancel.clone(),
                started: started.clone(),
            },
        )
        .expect("spawn should succeed");
        pipe.send(0u64).expect("first send hands off to consumer");
        wait_for_started(&started, Duration::from_secs(2));
        pipe.send(1u64).expect("second send fills the buffer");
        let halt = crate::halt::Halt::new();
        let halt2 = halt.clone();
        std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(100));
            halt2.cancel();
        });
        let start = Instant::now();
        let res = pipe.send_with_halt(7u64, &halt, Duration::from_secs(10));
        let elapsed = start.elapsed();
        cancel.store(true, Ordering::SeqCst);
        let _ = pipe.finish();
        assert!(matches!(res, Err(7)), "expected item returned on halt");
        assert!(
            elapsed < Duration::from_secs(2),
            "halt observation took too long: {elapsed:?}"
        );
    }

    // With the consumer wedged, `finish_with_halt` must return the "halted"
    // error shortly after `halt` is cancelled (~400 ms) rather than waiting
    // for `JOIN_TIMEOUT_SECS`. `cancel` then lets the detached thread exit.
    #[test]
    fn finish_with_halt_returns_halted_when_consumer_wedged() {
        let cancel = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let started = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let pipe = Pipeline::spawn(
            DEFAULT_PIPELINE_DEPTH,
            NeverDrainsSink {
                cancel: cancel.clone(),
                started: started.clone(),
            },
        )
        .expect("spawn should succeed");
        pipe.send(0u64).expect("seed item the consumer wedges on");
        wait_for_started(&started, Duration::from_secs(2));
        let halt = crate::halt::Halt::new();
        let halt2 = halt.clone();
        std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(400));
            halt2.cancel();
        });
        let start = Instant::now();
        let res = pipe.finish_with_halt(Some(&halt));
        let elapsed = start.elapsed();
        cancel.store(true, Ordering::SeqCst);
        match res {
            Err(Error::IoError { source }) => {
                assert!(
                    source.to_string().contains("pipeline join halted"),
                    "expected halt-prefix error, got: {source}"
                );
            }
            other => panic!("expected Err(IoError) halted, got {other:?}"),
        }
        assert!(
            elapsed < Duration::from_secs(2),
            "halt observation took too long: {elapsed:?}"
        );
    }

    // Without a halt handle and with a healthy consumer, `finish_with_halt`
    // behaves like `finish` and returns the sink's output.
    #[test]
    fn finish_with_halt_happy_path_returns_output() {
        let pipe = Pipeline::spawn(DEFAULT_PIPELINE_DEPTH, SumSink { total: 0 })
            .expect("spawn should succeed");
        for i in 0..10u64 {
            pipe.send(i).expect("send should succeed");
        }
        let total = pipe
            .finish_with_halt(None)
            .expect("happy-path finish_with_halt should succeed");
        assert_eq!(total, (0..10u64).sum::<u64>());
    }
}