use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio_process_tools::{
BroadcastOutputStream, Consumable, Consumer, Delivery, LineParsingOptions, Next, ParseLines,
ProcessHandle, ReliableWithBackpressure, Replay, ReplayEnabled,
};
use unwrap_infallible::UnwrapInfallible;
/// Identifies which output stream of the driver process a line was read from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriverOutputSource {
/// The line was read from the process's standard output stream.
Stdout,
/// The line was read from the process's standard error stream.
Stderr,
}
/// A single line of driver output as delivered to a [`DriverOutputListener`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DriverOutputLine {
/// The stream the line was read from.
pub source: DriverOutputSource,
/// Number taken from a counter that `DriverOutputInspectors::start` shares between
/// the stdout and stderr inspectors; it starts at 0 and is incremented once per
/// line handed to the listener.
pub sequence: u64,
/// The text of the line, as an owned string.
pub line: String,
}
/// Cloneable handle around a callback that receives [`DriverOutputLine`]s.
///
/// The callback is stored behind an [`Arc`], so clones share the same callback.
#[derive(Clone)]
pub struct DriverOutputListener {
// Shared, thread-safe callback invoked by `emit` for every line.
on_line: Arc<dyn Fn(DriverOutputLine) + Send + Sync + 'static>,
}
impl fmt::Debug for DriverOutputListener {
    /// Formats the listener for diagnostics.
    ///
    /// The wrapped callback is an opaque trait object, so a fixed placeholder string is
    /// printed for the `on_line` field instead of the callback itself.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("DriverOutputListener");
        builder.field("on_line", &"<callback>");
        builder.finish()
    }
}
impl DriverOutputListener {
    /// Wraps `on_line` in a listener.
    ///
    /// The callback is placed behind an [`Arc`], so cloning the listener shares the
    /// same callback rather than copying it.
    #[must_use]
    pub fn new(on_line: impl Fn(DriverOutputLine) + Send + Sync + 'static) -> Self {
        let on_line: Arc<dyn Fn(DriverOutputLine) + Send + Sync + 'static> = Arc::new(on_line);
        Self { on_line }
    }

    /// Invokes the wrapped callback with `line`.
    pub(crate) fn emit(&self, line: DriverOutputLine) {
        let callback = &*self.on_line;
        callback(line);
    }
}
/// Holds the two output consumers created by `DriverOutputInspectors::start`,
/// one for the process's stdout stream and one for its stderr stream.
pub struct DriverOutputInspectors {
// Consumer attached to the process's stdout stream.
stdout: Consumer<()>,
// Consumer attached to the process's stderr stream.
stderr: Consumer<()>,
}
impl fmt::Debug for DriverOutputInspectors {
    /// Formats the inspectors for diagnostics.
    ///
    /// Only the `is_finished()` state of each consumer is reported, under the
    /// `stdout_finished` and `stderr_finished` field names.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("DriverOutputInspectors");
        builder.field("stdout_finished", &self.stdout.is_finished());
        builder.field("stderr_finished", &self.stderr.is_finished());
        builder.finish()
    }
}
impl DriverOutputInspectors {
    /// Starts one inspector on the stdout stream and one on the stderr stream of `process`.
    ///
    /// Both inspectors draw from a single counter that starts at 0, so the `sequence`
    /// values on lines handed to `listener` are unique across the two streams. When
    /// `listener` is `None`, lines are still inspected but nothing is forwarded.
    pub(crate) fn start(
        process: &ProcessHandle<BroadcastOutputStream<ReliableWithBackpressure, ReplayEnabled>>,
        listener: Option<DriverOutputListener>,
    ) -> Self {
        let counter = Arc::new(AtomicU64::new(0));

        // The stdout inspector gets its own handles to the counter and the listener;
        // the stderr inspector takes over the originals.
        let stdout = inspect_output(
            process.stdout(),
            DriverOutputSource::Stdout,
            Arc::clone(&counter),
            listener.clone(),
        );
        let stderr = inspect_output(
            process.stderr(),
            DriverOutputSource::Stderr,
            counter,
            listener,
        );

        Self { stdout, stderr }
    }
}
/// Attaches a line inspector to `stream` and returns the resulting consumer.
///
/// Lines are parsed with the default [`LineParsingOptions`]. Every line is logged at
/// debug level together with its `source`. If a `listener` is present, the line is also
/// forwarded to it as a [`DriverOutputLine`] carrying the next value of the shared
/// `sequence` counter. The inspector always asks to continue with the next line.
fn inspect_output<D, R>(
    stream: &BroadcastOutputStream<D, R>,
    source: DriverOutputSource,
    sequence: Arc<AtomicU64>,
    listener: Option<DriverOutputListener>,
) -> Consumer<()>
where
    D: Delivery,
    R: Replay,
{
    stream
        .consume(ParseLines::inspect(
            LineParsingOptions::default(),
            move |line| {
                let text: &str = &line;
                tracing::debug!(source = ?source, driver_output = text, "driver log");

                // The counter only advances for lines that are actually forwarded.
                if let Some(sink) = listener.as_ref() {
                    let sequence_number = sequence.fetch_add(1, Ordering::SeqCst);
                    sink.emit(DriverOutputLine {
                        source,
                        sequence: sequence_number,
                        line: line.into_owned(),
                    });
                }

                Next::Continue
            },
        ))
        .unwrap_infallible()
}
#[cfg(test)]
mod tests {
    use super::*;
    use assertr::prelude::*;
    use std::sync::Mutex;

    /// Emitting through a listener must hand the line to the wrapped callback unchanged.
    #[test]
    fn driver_output_listener_invokes_callback() {
        let expected = DriverOutputLine {
            source: DriverOutputSource::Stdout,
            sequence: 0,
            line: "ready".to_owned(),
        };

        // The callback records every line it receives so the test can inspect them.
        let received = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&received);
        let listener = DriverOutputListener::new(move |line| {
            sink.lock()
                .expect("lines mutex should not be poisoned")
                .push(line);
        });

        listener.emit(expected.clone());

        let received = received
            .lock()
            .expect("lines mutex should not be poisoned");
        assert_that!(received.as_slice()).contains_exactly([expected]);
    }
}