use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use encoding_rs::Encoding;
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio::sync::Notify;
use crate::buffer::{OutputBufferPolicy, OverflowMode};
/// Callback run once per decoded line, receiving the line with its
/// terminator already removed.
///
/// Stored behind an `Arc` so the same callback can be shared; the
/// `Send + Sync` bounds let it be moved into and called from async tasks.
pub(crate) type LineHandler = Arc<dyn Fn(&str) + Sync + Send>;
/// Line buffer shared between the pump that pushes decoded lines and the
/// code that reads them back (`drain`, `try_pop`).
///
/// The running totals are atomics so `count()`/`dropped()` can be read
/// without taking the lock; everything else lives behind `inner`.
pub(crate) struct SharedLines {
    /// Total number of lines ever pushed, retained or not.
    count: AtomicUsize,
    /// Bumped when the overflow policy discards output; consumer pops are
    /// not counted here.
    dropped: AtomicUsize,
    /// Retained lines together with the retention policy and flags.
    inner: Mutex<Inner>,
    /// Signalled on every push and on close.
    notify: Notify,
}
/// Lock-protected state of a `SharedLines` buffer.
struct Inner {
    /// Retained lines, oldest at the front.
    lines: VecDeque<String>,
    /// Byte total of the retained lines.
    bytes: usize,
    /// Cumulative byte length of every line ever pushed (saturating).
    seen_bytes: usize,
    /// Line cap: checked against the retained backlog in the drop modes and
    /// against the total pushed-line count in `OverflowMode::Error`.
    max_lines: Option<usize>,
    /// Byte cap: checked against `bytes` in the drop modes and against
    /// `seen_bytes` in `OverflowMode::Error`.
    max_bytes: Option<usize>,
    /// What `push` does when a cap would be exceeded.
    mode: OverflowMode,
    /// Set by `close`; once set and `lines` is empty, `try_pop` reports `Closed`.
    closed: bool,
    /// Set when `OverflowMode::Error` rejected a line.
    overflowed: bool,
}
impl Inner {
    /// Returns `true` when the retained backlog is over either cap
    /// (more lines than `max_lines`, or more bytes than `max_bytes`).
    fn over_backlog(&self) -> bool {
        let too_many_lines = matches!(self.max_lines, Some(cap) if self.lines.len() > cap);
        let too_many_bytes = matches!(self.max_bytes, Some(cap) if self.bytes > cap);
        too_many_lines || too_many_bytes
    }

    /// Returns `true` when a line of `len` bytes can be appended without
    /// exceeding either cap; an absent cap never blocks.
    fn would_fit(&self, len: usize) -> bool {
        let has_line_room = match self.max_lines {
            Some(cap) => self.lines.len() < cap,
            None => true,
        };
        if !has_line_room {
            return false;
        }
        match self.max_bytes {
            Some(cap) => self.bytes + len <= cap,
            None => true,
        }
    }
}
/// Result of a non-blocking `SharedLines::try_pop`.
pub(crate) enum Popped {
    /// No line is buffered right now and the buffer is still open.
    Empty,
    /// No line is buffered and the buffer has been closed.
    Closed,
    /// The oldest retained line, now removed from the buffer.
    Line(String),
}
impl SharedLines {
pub(crate) fn new(policy: &OutputBufferPolicy) -> Arc<Self> {
Arc::new(Self {
inner: Mutex::new(Inner {
lines: VecDeque::new(),
max_lines: policy.max_lines,
max_bytes: policy.max_bytes,
bytes: 0,
seen_bytes: 0,
mode: policy.overflow,
closed: false,
overflowed: false,
}),
notify: Notify::new(),
count: AtomicUsize::new(0),
dropped: AtomicUsize::new(0),
})
}
pub(crate) fn push(&self, line: String) {
let total_lines = self.count.fetch_add(1, Ordering::Relaxed) + 1;
let mut policy_dropped = false;
{
let mut inner = self.inner.lock().expect("SharedLines poisoned");
inner.seen_bytes = inner.seen_bytes.saturating_add(line.len());
match inner.mode {
OverflowMode::Error => {
let over = match (inner.max_lines, inner.max_bytes) {
(None, None) => true,
(lines_cap, bytes_cap) => {
lines_cap.is_some_and(|n| total_lines > n)
|| bytes_cap.is_some_and(|b| inner.seen_bytes > b)
}
};
if over {
inner.overflowed = true;
policy_dropped = true;
} else {
inner.bytes += line.len();
inner.lines.push_back(line);
}
}
OverflowMode::DropOldest => {
inner.bytes += line.len();
inner.lines.push_back(line);
while inner.over_backlog() {
match inner.lines.pop_front() {
Some(old) => {
inner.bytes = inner.bytes.saturating_sub(old.len());
policy_dropped = true;
}
None => break,
}
}
}
OverflowMode::DropNewest => {
if inner.would_fit(line.len()) {
inner.bytes += line.len();
inner.lines.push_back(line);
} else {
policy_dropped = true;
}
}
}
}
if policy_dropped {
self.dropped.fetch_add(1, Ordering::Relaxed);
}
self.notify.notify_one();
}
fn close(&self) {
self.inner
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.closed = true;
self.notify.notify_one();
}
pub(crate) fn close_now(&self) {
self.close();
}
pub(crate) fn count(&self) -> usize {
self.count.load(Ordering::Relaxed)
}
pub(crate) fn seen_bytes(&self) -> usize {
self.inner
.lock()
.unwrap_or_else(|p| p.into_inner())
.seen_bytes
}
pub(crate) fn dropped(&self) -> usize {
self.dropped.load(Ordering::Relaxed)
}
pub(crate) fn overflowed(&self) -> bool {
self.inner
.lock()
.unwrap_or_else(|p| p.into_inner())
.overflowed
}
pub(crate) fn drain(&self) -> Vec<String> {
let mut inner = self.inner.lock().expect("SharedLines poisoned");
inner.bytes = 0;
inner.lines.drain(..).collect()
}
pub(crate) fn try_pop(&self) -> Popped {
let mut inner = self.inner.lock().expect("SharedLines poisoned");
if let Some(line) = inner.lines.pop_front() {
inner.bytes = inner.bytes.saturating_sub(line.len());
Popped::Line(line)
} else if inner.closed {
Popped::Closed
} else {
Popped::Empty
}
}
pub(crate) async fn changed(self: Arc<Self>) {
self.notify.notified().await;
}
}
/// Async writer that receives a copy of each decoded line followed by `\n`.
///
/// Guarded by tokio's async mutex because the pump keeps the lock held across
/// the `write_all` awaits.
pub(crate) type TeeSink =
    Arc<tokio::sync::Mutex<Box<dyn tokio::io::AsyncWrite + Unpin + Send>>>;
/// Test-only shorthand for `pump_lines_core` without a tee writer.
#[cfg(test)]
pub(crate) async fn pump_lines<R>(
    reader: R,
    encoding: &'static Encoding,
    handler: Option<LineHandler>,
    sink: Arc<SharedLines>,
) where
    R: AsyncRead + Unpin,
{
    let no_tee: Option<TeeSink> = None;
    pump_lines_core(reader, encoding, handler, no_tee, sink).await;
}
/// Reads `reader` to the end, decodes it with `encoding`, splits it into
/// lines, and delivers each line to `handler` (if any), to `tee` (if any),
/// and to `sink`.
///
/// * A leading BOM is removed once. Malformed input decodes to U+FFFD instead
///   of failing.
/// * Lines split on `\n`, and exactly one trailing `\r` is stripped from each
///   terminated line. An unterminated final fragment is emitted unchanged.
/// * An `Interrupted` read is retried. Any other read error is treated as end
///   of input, and everything decoded so far is still emitted.
/// * `sink` is closed when this future completes or is dropped.
pub(crate) async fn pump_lines_core<R>(
    mut reader: R,
    encoding: &'static Encoding,
    handler: Option<LineHandler>,
    tee: Option<TeeSink>,
    sink: Arc<SharedLines>,
) where
    R: AsyncRead + Unpin,
{
    // Closes the sink when the pump finishes *or* when this future is dropped
    // mid-await, so a streaming consumer always eventually sees `Closed`.
    struct CloseOnDrop(Arc<SharedLines>);
    impl Drop for CloseOnDrop {
        fn drop(&mut self) {
            self.0.close();
        }
    }

    // Delivers one line, in order, to the handler, the tee, and the sink.
    // A panicking handler or a failing tee is set to `None` for the rest of
    // the run; the sink always receives the line.
    async fn emit(
        handler: &mut Option<LineHandler>,
        tee: &mut Option<TeeSink>,
        sink: &SharedLines,
        line: String,
    ) {
        if let Some(h) = handler {
            let invoked = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| h(&line)));
            if invoked.is_err() {
                *handler = None;
                #[cfg(feature = "tracing")]
                tracing::warn!(
                    target: "processkit",
                    "line handler panicked; disabled for the rest of the run"
                );
            }
        }
        if let Some(t) = tee {
            use tokio::io::AsyncWriteExt;
            let mut w = t.lock().await;
            let wrote = async {
                w.write_all(line.as_bytes()).await?;
                w.write_all(b"\n").await
            }
            .await;
            drop(w);
            if wrote.is_err() {
                *tee = None;
                #[cfg(feature = "tracing")]
                tracing::warn!(
                    target: "processkit",
                    "tee writer errored; disabled for the rest of the run"
                );
            }
        }
        sink.push(line);
    }

    let sink = CloseOnDrop(sink);
    let mut handler = handler;
    let mut tee = tee;
    let mut decoder = encoding.new_decoder_with_bom_removal();
    let mut pending = String::new();
    let mut chunk = [0u8; 8192];
    // Length of the prefix of `pending` already searched and known to hold no
    // '\n'. A long unterminated line is then scanned once in total, not once
    // per read.
    let mut scanned = 0;
    loop {
        let (n, last) = match reader.read(&mut chunk).await {
            Ok(0) => (0, true),
            Ok(n) => (n, false),
            // Transient; retrying avoids truncating the capture.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            // Any other error ends the stream; what was read so far is kept.
            Err(_) => (0, true),
        };
        // Reserve the decoder's worst-case output so `decode_to_string`,
        // which writes into spare capacity, can consume the whole chunk.
        if let Some(need) = decoder.max_utf8_buffer_length(n) {
            pending.reserve(need);
        }
        let _ = decoder.decode_to_string(&chunk[..n], &mut pending, last);

        // Emit every complete line, then remove them from `pending` with a
        // single shift instead of one shift per line.
        let mut consumed = 0;
        while let Some(offset) = pending[scanned..].find('\n') {
            let newline = scanned + offset;
            let raw = &pending[consumed..newline];
            let line = raw.strip_suffix('\r').unwrap_or(raw).to_owned();
            consumed = newline + 1;
            scanned = consumed;
            emit(&mut handler, &mut tee, &sink.0, line).await;
        }
        pending.replace_range(..consumed, "");
        scanned = pending.len();

        if last {
            if !pending.is_empty() {
                emit(
                    &mut handler,
                    &mut tee,
                    &sink.0,
                    std::mem::take(&mut pending),
                )
                .await;
            }
            if let Some(t) = &tee {
                use tokio::io::AsyncWriteExt;
                let _ = t.lock().await.flush().await;
            }
            break;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::OutputBufferPolicy;

    // Plain UTF-8: every newline-terminated line is counted and retained.
    #[tokio::test]
    async fn pumps_utf8_lines_and_counts() {
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(
            &b"one\ntwo\nthree\n"[..],
            encoding_rs::UTF_8,
            None,
            sink.clone(),
        )
        .await;
        assert_eq!(sink.count(), 3);
        assert_eq!(sink.drain(), vec!["one", "two", "three"]);
    }

    // A two-byte Shift_JIS sequence decodes to U+3042.
    #[tokio::test]
    async fn decodes_shift_jis() {
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(
            &[0x82, 0xA0, b'\n'][..],
            encoding_rs::SHIFT_JIS,
            None,
            sink.clone(),
        )
        .await;
        assert_eq!(sink.drain(), vec!["\u{3042}"]);
    }

    // `bounded(2)` keeps the two newest lines while `count` still sees all four.
    #[tokio::test]
    async fn drop_oldest_keeps_tail_but_counts_all() {
        let sink = SharedLines::new(&OutputBufferPolicy::bounded(2));
        pump_lines(&b"a\nb\nc\nd\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.count(), 4, "every line is counted");
        assert_eq!(sink.drain(), vec!["c", "d"], "only the newest two retained");
    }

    // DropNewest keeps the first lines that fit and discards later ones.
    #[tokio::test]
    async fn drop_newest_keeps_head() {
        let policy = OutputBufferPolicy::bounded(2).with_overflow(OverflowMode::DropNewest);
        let sink = SharedLines::new(&policy);
        pump_lines(&b"a\nb\nc\nd\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.drain(), vec!["a", "b"]);
    }

    // Fail-loud sets `overflowed` past the cap but keeps the lines up to it.
    #[tokio::test]
    async fn fail_loud_sets_overflow_once_full_but_retains_the_cap() {
        let sink = SharedLines::new(&OutputBufferPolicy::fail_loud(2));
        pump_lines(&b"a\nb\nc\nd\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert!(sink.overflowed(), "third line must trip the fail-loud flag");
        assert_eq!(sink.count(), 4, "every line is still counted");
        assert_eq!(sink.drain(), vec!["a", "b"], "retains up to the cap");
    }

    #[tokio::test]
    async fn fail_loud_under_the_cap_does_not_overflow() {
        let sink = SharedLines::new(&OutputBufferPolicy::fail_loud(5));
        pump_lines(&b"a\nb\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert!(!sink.overflowed(), "two lines under a 5-line cap is fine");
    }

    #[tokio::test]
    async fn fail_loud_zero_errors_on_the_first_line() {
        let sink = SharedLines::new(&OutputBufferPolicy::fail_loud(0));
        pump_lines(&b"oops\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert!(sink.overflowed(), "any line is over a 0-line ceiling");
        assert!(sink.drain().is_empty(), "still retains nothing");
    }

    // Error mode with neither cap set rejects every line (zero tolerance).
    #[tokio::test]
    async fn unbounded_with_error_mode_is_zero_tolerance_not_inert() {
        let sink =
            SharedLines::new(&OutputBufferPolicy::unbounded().with_overflow(OverflowMode::Error));
        pump_lines(&b"anything\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert!(
            sink.overflowed(),
            "unbounded + Error must fail loud on any output, not be inert"
        );
        assert!(sink.drain().is_empty(), "zero-tolerance retains nothing");
    }

    #[tokio::test]
    async fn unbounded_without_error_mode_retains_everything() {
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(&b"a\nb\nc\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert!(!sink.overflowed());
        assert_eq!(sink.drain(), ["a", "b", "c"]);
    }

    // `dropped` reflects policy discards only, never lines taken by `try_pop`.
    #[tokio::test]
    async fn dropped_counts_policy_drops_not_consumer_pops() {
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(&b"a\nb\nc\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.count(), 3);
        assert_eq!(sink.dropped(), 0, "unbounded policy discards nothing");
        assert!(matches!(sink.try_pop(), Popped::Line(_)));
        assert!(matches!(sink.try_pop(), Popped::Line(_)));
        assert_eq!(
            sink.dropped(),
            0,
            "a streaming consumer's pops are not truncation"
        );
        let bounded = SharedLines::new(&OutputBufferPolicy::bounded(2));
        pump_lines(
            &b"a\nb\nc\nd\n"[..],
            encoding_rs::UTF_8,
            None,
            bounded.clone(),
        )
        .await;
        assert_eq!(
            bounded.dropped(),
            2,
            "DropOldest discarded the two oldest lines"
        );
        let newest = SharedLines::new(
            &OutputBufferPolicy::bounded(2).with_overflow(OverflowMode::DropNewest),
        );
        pump_lines(
            &b"a\nb\nc\nd\n"[..],
            encoding_rs::UTF_8,
            None,
            newest.clone(),
        )
        .await;
        assert_eq!(
            newest.dropped(),
            2,
            "DropNewest discarded the two newest lines"
        );
    }

    // Only Error mode sets `overflowed`; a zero-line drop policy never does.
    #[tokio::test]
    async fn bounded_zero_without_error_mode_never_overflows() {
        let sink = SharedLines::new(&OutputBufferPolicy::bounded(0));
        pump_lines(&b"a\nb\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert!(!sink.overflowed());
    }

    // The handler runs before the sink's policy, so it sees discarded lines too.
    #[tokio::test]
    async fn handler_sees_every_line_even_when_nothing_retained() {
        let seen = Arc::new(Mutex::new(Vec::new()));
        let captured = seen.clone();
        let handler: LineHandler =
            Arc::new(move |line: &str| captured.lock().unwrap().push(line.to_owned()));
        let sink = SharedLines::new(&OutputBufferPolicy::bounded(0));
        pump_lines(
            &b"x\ny\n"[..],
            encoding_rs::UTF_8,
            Some(handler),
            sink.clone(),
        )
        .await;
        assert_eq!(sink.count(), 2);
        assert!(
            sink.drain().is_empty(),
            "retain-nothing policy keeps no lines"
        );
        assert_eq!(*seen.lock().unwrap(), vec!["x", "y"]);
    }

    #[tokio::test]
    async fn crlf_only_line_is_one_empty_line() {
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(&b"\r\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.count(), 1);
        assert_eq!(sink.drain(), vec![""]);
    }

    #[tokio::test]
    async fn final_line_without_a_trailing_newline_is_emitted() {
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(&b"alpha\nomega"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.count(), 2, "the un-terminated tail still counts");
        assert_eq!(sink.drain(), vec!["alpha", "omega"]);
    }

    // End of input with no data must still close the sink.
    #[tokio::test]
    async fn empty_reader_closes_with_no_lines() {
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(&b""[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.count(), 0);
        assert!(sink.drain().is_empty());
        assert!(
            matches!(sink.try_pop(), Popped::Closed),
            "the sink must close on EOF so a streaming consumer ends"
        );
    }

    // A truncated Shift_JIS lead byte becomes U+FFFD instead of an error.
    #[tokio::test]
    async fn invalid_multibyte_decodes_lossily_not_fatally() {
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(
            &[0x82, b'\n'][..],
            encoding_rs::SHIFT_JIS,
            None,
            sink.clone(),
        )
        .await;
        let lines = sink.drain();
        assert_eq!(lines.len(), 1);
        assert!(
            lines[0].contains('\u{FFFD}'),
            "invalid bytes decode to the replacement char: {lines:?}"
        );
    }

    // The handler panics on its second call; the pump must catch it, disable
    // the handler, and still capture all ten lines.
    #[tokio::test]
    async fn panicking_handler_is_isolated_and_capture_completes() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let calls = Arc::new(AtomicUsize::new(0));
        let handler: LineHandler = {
            let calls = calls.clone();
            Arc::new(move |_: &str| {
                if calls.fetch_add(1, Ordering::SeqCst) == 1 {
                    panic!("boom on the second line");
                }
            })
        };
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        let task = tokio::spawn(pump_lines(
            &b"1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n"[..],
            encoding_rs::UTF_8,
            Some(handler),
            sink.clone(),
        ));
        task.await
            .expect("the pump task must survive a handler panic");
        assert_eq!(sink.count(), 10, "every line captured despite the panic");
        assert_eq!(
            sink.drain(),
            (1..=10).map(|n| n.to_string()).collect::<Vec<_>>()
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            2,
            "the handler is disabled after its panic (called for lines 1 and 2 only)"
        );
        assert!(
            matches!(sink.try_pop(), Popped::Closed),
            "sink closes normally after the drain"
        );
    }

    // Test reader that hands out the queued chunks, one per `poll_read`.
    // A chunk larger than the caller's buffer is split and the rest re-queued.
    // With `err_at_end`, the first poll after the chunks run out returns an
    // `Other` error; every later poll reports EOF (zero bytes).
    struct ChunkedReader {
        chunks: VecDeque<Vec<u8>>,
        err_at_end: bool,
    }

    impl ChunkedReader {
        // Reader that yields `chunks` and then EOF.
        fn new(chunks: impl IntoIterator<Item = Vec<u8>>) -> Self {
            Self {
                chunks: chunks.into_iter().collect(),
                err_at_end: false,
            }
        }

        // Reader that yields `chunks`, then one error, then EOF.
        fn erroring(chunks: impl IntoIterator<Item = Vec<u8>>) -> Self {
            Self {
                chunks: chunks.into_iter().collect(),
                err_at_end: true,
            }
        }
    }

    impl AsyncRead for ChunkedReader {
        fn poll_read(
            mut self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            if let Some(chunk) = self.chunks.pop_front() {
                let n = chunk.len().min(buf.remaining());
                buf.put_slice(&chunk[..n]);
                if n < chunk.len() {
                    self.chunks.push_front(chunk[n..].to_vec());
                }
                std::task::Poll::Ready(Ok(()))
            } else if self.err_at_end {
                self.err_at_end = false;
                std::task::Poll::Ready(Err(std::io::Error::other("boom")))
            } else {
                // Nothing written to `buf`: the caller sees EOF.
                std::task::Poll::Ready(Ok(()))
            }
        }
    }

    #[tokio::test]
    async fn utf16le_lines_decode_and_split_correctly() {
        // "AB\nCD\n" encoded as UTF-16LE code units.
        let bytes = [
            0x41, 0x00, 0x42, 0x00, 0x0A, 0x00, 0x43, 0x00, 0x44, 0x00, 0x0A, 0x00,
        ];
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(&bytes[..], encoding_rs::UTF_16LE, None, sink.clone()).await;
        assert_eq!(sink.drain(), vec!["AB", "CD"]);
    }

    // The 'B' code unit is split across two reads; the decoder carries the
    // odd byte over to the next read.
    #[tokio::test]
    async fn utf16le_code_unit_split_across_reads_is_reassembled() {
        let reader = ChunkedReader::new([vec![0x41, 0x00, 0x42], vec![0x00, 0x0A, 0x00]]);
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(reader, encoding_rs::UTF_16LE, None, sink.clone()).await;
        assert_eq!(sink.drain(), vec!["AB"]);
    }

    #[tokio::test]
    async fn utf16le_leading_bom_is_removed_once() {
        let bytes = [0xFF, 0xFE, 0x41, 0x00, 0x0A, 0x00];
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(&bytes[..], encoding_rs::UTF_16LE, None, sink.clone()).await;
        assert_eq!(sink.drain(), vec!["A"]);
    }

    #[tokio::test]
    async fn utf8_leading_bom_is_removed_once_not_per_line() {
        let bytes = [0xEF, 0xBB, 0xBF, b'h', b'i', b'\n', b'b', b'y', b'e', b'\n'];
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(&bytes[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.drain(), vec!["hi", "bye"]);
    }

    #[tokio::test]
    async fn strips_exactly_one_trailing_cr_not_all() {
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(&b"data\r\r\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.drain(), vec!["data\r"]);
    }

    // CR stripping applies only to '\n'-terminated lines, not the final fragment.
    #[tokio::test]
    async fn lone_trailing_cr_at_eof_is_kept_as_content() {
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(&b"tail\r"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.drain(), vec!["tail\r"]);
    }

    // A read error ends the stream but the undelimited "part" is still emitted.
    #[tokio::test]
    async fn mid_stream_read_error_flushes_the_partial_tail() {
        let reader = ChunkedReader::erroring([b"done\npart".to_vec()]);
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(reader, encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.count(), 2, "the partial tail still counts");
        assert_eq!(sink.drain(), vec!["done", "part"]);
    }

    // With a single-byte encoding, FF FE are ordinary characters, not a BOM.
    #[tokio::test]
    async fn legacy_line_starting_with_bom_bytes_is_not_resniffed() {
        let bytes = [0xFF, 0xFE, b'x', b'\n'];
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines(&bytes[..], encoding_rs::WINDOWS_1252, None, sink.clone()).await;
        assert_eq!(sink.drain(), vec!["\u{00FF}\u{00FE}x"]);
    }

    // The fail-loud ceiling counts every push, even if the backlog is
    // consumed before it ever fills.
    #[tokio::test]
    async fn fail_loud_trips_on_total_even_when_streamed_dry() {
        let sink = SharedLines::new(&OutputBufferPolicy::fail_loud(2));
        sink.push("a".into());
        assert!(matches!(sink.try_pop(), Popped::Line(_)));
        sink.push("b".into());
        assert!(matches!(sink.try_pop(), Popped::Line(_)));
        assert!(!sink.overflowed(), "two lines is within the cap");
        sink.push("c".into());
        assert!(
            sink.overflowed(),
            "the 3rd line trips the ceiling even though the backlog was drained dry"
        );
    }

    #[tokio::test]
    async fn max_bytes_drop_oldest_evicts_to_fit_the_byte_cap() {
        let policy = OutputBufferPolicy::unbounded().with_max_bytes(5);
        let sink = SharedLines::new(&policy);
        pump_lines(&b"aa\nbb\ncc\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.drain(), vec!["bb", "cc"]);
        assert_eq!(sink.count(), 3, "every line is still counted");
    }

    // A line longer than the byte cap evicts itself rather than being truncated.
    #[tokio::test]
    async fn max_bytes_drops_a_single_oversized_line_whole() {
        let policy = OutputBufferPolicy::unbounded().with_max_bytes(3);
        let sink = SharedLines::new(&policy);
        pump_lines(
            &b"toolong\nok\n"[..],
            encoding_rs::UTF_8,
            None,
            sink.clone(),
        )
        .await;
        assert_eq!(sink.drain(), vec!["ok"], "the oversized line was dropped");
        assert_eq!(sink.count(), 2);
        assert!(sink.dropped() >= 1);
    }

    // The Error-mode byte cap is checked against cumulative bytes, not the backlog.
    #[tokio::test]
    async fn max_bytes_fail_loud_trips_on_byte_total() {
        let policy = OutputBufferPolicy::unbounded()
            .with_overflow(OverflowMode::Error)
            .with_max_bytes(4);
        let sink = SharedLines::new(&policy);
        pump_lines(&b"ab\ncd\nef\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert!(
            sink.overflowed(),
            "6 cumulative bytes over a 4-byte ceiling must trip it"
        );
    }

    #[tokio::test]
    async fn max_bytes_under_the_cap_does_not_trip_or_drop() {
        let policy = OutputBufferPolicy::fail_loud(10).with_max_bytes(100);
        let sink = SharedLines::new(&policy);
        pump_lines(&b"ab\ncd\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert!(!sink.overflowed());
        assert_eq!(sink.dropped(), 0);
        assert_eq!(sink.drain(), vec!["ab", "cd"]);
    }

    #[tokio::test]
    async fn max_bytes_drop_newest_keeps_head_within_byte_cap() {
        let policy = OutputBufferPolicy::unbounded()
            .with_overflow(OverflowMode::DropNewest)
            .with_max_bytes(4);
        let sink = SharedLines::new(&policy);
        pump_lines(&b"ab\ncd\nef\n"[..], encoding_rs::UTF_8, None, sink.clone()).await;
        assert_eq!(sink.drain(), vec!["ab", "cd"]);
    }

    // AsyncWrite that appends every write to a shared Vec; it never errors or
    // returns Pending.
    #[derive(Clone)]
    struct VecSink(Arc<Mutex<Vec<u8>>>);

    impl tokio::io::AsyncWrite for VecSink {
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            self.0.lock().unwrap().extend_from_slice(buf);
            std::task::Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            std::task::Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            std::task::Poll::Ready(Ok(()))
        }
    }

    // Wraps a writer into the `TeeSink` shape the pump expects.
    fn tee_of(sink: impl tokio::io::AsyncWrite + Send + Unpin + 'static) -> TeeSink {
        Arc::new(tokio::sync::Mutex::new(Box::new(sink)))
    }

    #[tokio::test]
    async fn tee_writes_each_decoded_line_plus_newline_to_the_async_sink() {
        let buf = Arc::new(Mutex::new(Vec::new()));
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines_core(
            &b"one\ntwo\n"[..],
            encoding_rs::UTF_8,
            None,
            Some(tee_of(VecSink(buf.clone()))),
            sink.clone(),
        )
        .await;
        assert_eq!(sink.drain(), vec!["one", "two"], "capture is unaffected");
        let teed = String::from_utf8(buf.lock().unwrap().clone()).unwrap();
        assert_eq!(teed, "one\ntwo\n", "the tee got each line + a newline");
    }

    // A tee whose writes always fail is disabled; capture keeps going.
    #[tokio::test]
    async fn tee_write_error_is_isolated_and_capture_continues() {
        struct ErrSink;

        impl tokio::io::AsyncWrite for ErrSink {
            fn poll_write(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
                _buf: &[u8],
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::task::Poll::Ready(Err(std::io::Error::other("nope")))
            }

            fn poll_flush(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::task::Poll::Ready(Ok(()))
            }

            fn poll_shutdown(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::task::Poll::Ready(Ok(()))
            }
        }

        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines_core(
            &b"a\nb\nc\n"[..],
            encoding_rs::UTF_8,
            None,
            Some(tee_of(ErrSink)),
            sink.clone(),
        )
        .await;
        assert_eq!(
            sink.drain(),
            vec!["a", "b", "c"],
            "capture survives a tee write error"
        );
    }

    #[tokio::test]
    async fn tee_and_line_handler_both_fire_independently() {
        let buf = Arc::new(Mutex::new(Vec::new()));
        let seen = Arc::new(Mutex::new(Vec::new()));
        let captured = seen.clone();
        let handler: LineHandler =
            Arc::new(move |line: &str| captured.lock().unwrap().push(line.to_owned()));
        let sink = SharedLines::new(&OutputBufferPolicy::unbounded());
        pump_lines_core(
            &b"x\ny\n"[..],
            encoding_rs::UTF_8,
            Some(handler),
            Some(tee_of(VecSink(buf.clone()))),
            sink.clone(),
        )
        .await;
        assert_eq!(*seen.lock().unwrap(), vec!["x", "y"], "handler fired");
        assert_eq!(
            String::from_utf8(buf.lock().unwrap().clone()).unwrap(),
            "x\ny\n",
            "tee fired"
        );
    }
}