Skip to main content

ralph_workflow/runtime/
streaming.rs

//! Streaming I/O utilities for the runtime boundary.
//!
//! This module provides streaming readers for process output, plus helpers for
//! pumping stdout through a bounded channel on a background thread.
4
5use std::io::{self, BufRead, BufReader, Read};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::{mpsc, Arc};
8use std::time::Duration;
9
10use crate::pipeline::idle_timeout::{ActivityTrackingReader, SharedActivityTimestamp};
11
/// A line-oriented reader that processes data as it arrives.
///
/// Unlike `BufReader::lines()`, this reader yields lines immediately when newlines are
/// encountered, without waiting for the buffer to fill. This enables real-time streaming
/// for agents that output NDJSON gradually.
///
/// # Buffer Size Limit
///
/// This reader enforces a hard cap for a single line (bytes since the last '\n') to
/// prevent memory exhaustion from malicious or malformed input that never contains
/// newlines. The cap is [`MAX_BUFFER_SIZE`].
pub struct StreamingLineReader<R: Read> {
    /// Underlying source, wrapped in a small `BufReader`.
    inner: BufReader<R>,
    /// Bytes pulled from `inner` that have not all been handed to the caller yet.
    buffer: Vec<u8>,
    /// Number of leading bytes of `buffer` the caller has already consumed;
    /// the unread data is `buffer[consumed..]`.
    consumed: usize,
}
28
/// Maximum line size in bytes.
///
/// Important: `BufRead::lines()` uses `read_line()` under the hood. Without a per-line
/// cap, `read_line()` can accumulate arbitrarily large `String`s even if `fill_buf()`
/// only ever returns small chunks.
///
/// The same limit is also applied to the reader's unread buffered bytes each time the
/// internal buffer is refilled.
///
/// The value of 1 MiB was chosen to:
/// - Handle most legitimate JSON documents (typically < 100KB)
/// - Allow for reasonably long single-line JSON outputs
/// - Prevent memory exhaustion from malicious input
/// - Keep the buffer size manageable for most systems
///
/// If your use case requires larger single-line JSON, consider:
/// - Modifying your agent to output NDJSON (newline-delimited JSON)
/// - Adjusting this constant and rebuilding
pub const MAX_BUFFER_SIZE: usize = 1024 * 1024; // 1 MiB
45
46impl<R: Read> StreamingLineReader<R> {
47    /// Create a new streaming line reader with a small buffer for low latency.
48    pub fn new(inner: R) -> Self {
49        const BUFFER_SIZE: usize = 1024;
50        Self {
51            inner: BufReader::with_capacity(BUFFER_SIZE, inner),
52            buffer: Vec::new(),
53            consumed: 0,
54        }
55    }
56
57    fn fill_buffer(&mut self) -> io::Result<usize> {
58        let current_size = self.buffer.len() - self.consumed;
59        check_buffer_size_limit(current_size)?;
60
61        let mut read_buf = [0u8; 256];
62        let n = self.inner.read(&mut read_buf)?;
63        if n > 0 {
64            let new_size = current_size + n;
65            check_buffer_size_limit(new_size)?;
66            self.buffer.extend_from_slice(&read_buf[..n]);
67        }
68        Ok(n)
69    }
70}
71
72fn check_buffer_size_limit(current_size: usize) -> io::Result<()> {
73    if current_size >= MAX_BUFFER_SIZE {
74        return Err(io::Error::other(format!(
75            "StreamingLineReader buffer exceeded maximum size of {MAX_BUFFER_SIZE} bytes. \
76             This may indicate malformed input or an agent that is not sending newlines."
77        )));
78    }
79    Ok(())
80}
81
82fn check_line_size_limit(line_len: usize) -> io::Result<()> {
83    if line_len >= MAX_BUFFER_SIZE {
84        return Err(io::Error::other(format!(
85            "StreamingLineReader line exceeded maximum size of {MAX_BUFFER_SIZE} bytes. \
86             This may indicate malformed input or an agent that is not sending newlines."
87        )));
88    }
89    Ok(())
90}
91
92fn check_chunk_size_limit(line_len: usize, to_take: usize) -> io::Result<()> {
93    let remaining = MAX_BUFFER_SIZE - line_len;
94    if to_take > remaining {
95        return Err(io::Error::other(format!(
96            "StreamingLineReader line would exceed maximum size of {MAX_BUFFER_SIZE} bytes. \
97             This may indicate malformed input or an agent that is not sending newlines."
98        )));
99    }
100    Ok(())
101}
102
/// Interpret `chunk` as UTF-8 text.
///
/// # Errors
///
/// Returns an `InvalidData` error carrying the decoder's message when the bytes are
/// not valid UTF-8.
fn parse_utf8_chunk(chunk: &[u8]) -> io::Result<&str> {
    match std::str::from_utf8(chunk) {
        Ok(text) => Ok(text),
        Err(e) => {
            let message = format!("agent output is not valid UTF-8: {e}");
            Err(io::Error::new(io::ErrorKind::InvalidData, message))
        }
    }
}
111
112impl<R: Read> Read for StreamingLineReader<R> {
113    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
114        let available = self.buffer.len() - self.consumed;
115        if available > 0 {
116            let to_copy = available.min(buf.len());
117            buf[..to_copy].copy_from_slice(&self.buffer[self.consumed..self.consumed + to_copy]);
118            self.consumed += to_copy;
119
120            if self.consumed == self.buffer.len() {
121                self.buffer.clear();
122                self.consumed = 0;
123            }
124            return Ok(to_copy);
125        }
126
127        self.inner.read(buf)
128    }
129}
130
131impl<R: Read> BufRead for StreamingLineReader<R> {
132    fn fill_buf(&mut self) -> io::Result<&[u8]> {
133        const MAX_ATTEMPTS: usize = 8;
134
135        if self.consumed < self.buffer.len() {
136            return Ok(&self.buffer[self.consumed..]);
137        }
138
139        self.buffer.clear();
140        self.consumed = 0;
141
142        let total_read = fill_buffer_with_retry(self, MAX_ATTEMPTS)?;
143        if total_read == 0 {
144            return Ok(&[]);
145        }
146
147        Ok(&self.buffer[self.consumed..])
148    }
149
150    fn consume(&mut self, amt: usize) {
151        self.consumed = (self.consumed + amt).min(self.buffer.len());
152
153        if self.consumed == self.buffer.len() {
154            self.buffer.clear();
155            self.consumed = 0;
156        }
157    }
158
159    fn read_line(&mut self, buf: &mut String) -> io::Result<usize> {
160        let start_len = buf.len();
161        loop {
162            match read_line_step(self, buf, start_len)? {
163                ReadLineStep::Done => return Ok(buf.len() - start_len),
164                ReadLineStep::Continue => {}
165            }
166        }
167    }
168}
169
/// Outcome of one `read_line_step` call.
enum ReadLineStep {
    /// The line is finished: a newline was taken or no more input was available.
    Done,
    /// The chunk taken held no newline; another step is needed.
    Continue,
}
174
175fn read_line_step<R: Read>(
176    reader: &mut StreamingLineReader<R>,
177    buf: &mut String,
178    start_len: usize,
179) -> io::Result<ReadLineStep> {
180    check_line_size_limit(buf.len() - start_len)?;
181    let available = reader.fill_buf()?;
182    if available.is_empty() {
183        return Ok(ReadLineStep::Done);
184    }
185    let newline_pos = available.iter().position(|&b| b == b'\n');
186    let to_take = newline_pos.map_or(available.len(), |i| i + 1);
187    check_chunk_size_limit(buf.len() - start_len, to_take)?;
188    buf.push_str(parse_utf8_chunk(&available[..to_take])?);
189    reader.consume(to_take);
190    Ok(newline_pos.map_or(ReadLineStep::Continue, |_| ReadLineStep::Done))
191}
192
/// Outcome of a single fill attempt, as decided by `classify_fill_step`.
enum FillStepOutcome {
    /// Loop should stop; return this accumulated total.
    Stop(usize),
    /// Loop should continue with this updated total.
    Continue(usize),
}
200
201fn classify_fill_step(n: usize, total_read: usize, has_newline: bool) -> FillStepOutcome {
202    match n {
203        0 if total_read == 0 => FillStepOutcome::Stop(0),
204        0 => FillStepOutcome::Stop(total_read),
205        _ if has_newline => FillStepOutcome::Stop(total_read + n),
206        _ => FillStepOutcome::Continue(total_read + n),
207    }
208}
209
210fn fill_buffer_with_retry(
211    reader: &mut StreamingLineReader<impl Read>,
212    max_attempts: usize,
213) -> io::Result<usize> {
214    let mut total_read = 0;
215    for _ in 0..max_attempts {
216        let n = reader.fill_buffer()?;
217        match classify_fill_step(n, total_read, reader.buffer.contains(&b'\n')) {
218            FillStepOutcome::Stop(v) => return Ok(v),
219            FillStepOutcome::Continue(next) => total_read = next,
220        }
221    }
222    Ok(total_read)
223}
224
/// Sender/receiver pair of the bounded stdout channel, as returned by
/// `create_stdout_channel`.
///
/// Each message is either a chunk of stdout bytes or an I/O error from reading stdout.
type StdoutChannel = (
    mpsc::SyncSender<io::Result<Vec<u8>>>,
    mpsc::Receiver<io::Result<Vec<u8>>>,
);
230
/// Upper bound on stdout data buffered between the pump thread and the parser,
/// expressed as a number of queued chunks.
///
/// Each pump chunk is up to 4096 bytes.
pub const STDOUT_PUMP_CHANNEL_CAPACITY: usize = 256; // 256 * 4096B chunks ~= 1MiB worst-case
234
/// A reader that wraps a channel receiver with cancellation support.
///
/// This allows the reader to be stopped promptly when cancellation is requested,
/// even if the underlying receive is blocking.
pub struct CancelAwareReceiverBufRead {
    /// Source of byte chunks (or an I/O error) sent by the producer.
    rx: mpsc::Receiver<io::Result<Vec<u8>>>,
    /// Shared flag; once it is `true` the reader reports end of input.
    cancel: Arc<AtomicBool>,
    /// Timeout of each individual receive, i.e. how often `cancel` is re-checked
    /// while waiting for data.
    poll_interval: Duration,
    /// The most recently received chunk.
    buffer: Vec<u8>,
    /// Number of leading bytes of `buffer` already handed to the caller.
    consumed: usize,
    /// Set once cancellation, an empty chunk, or a disconnected sender was seen.
    eof: bool,
}
247
248impl CancelAwareReceiverBufRead {
249    /// Create a new cancel-aware reader.
250    pub fn new(
251        rx: mpsc::Receiver<io::Result<Vec<u8>>>,
252        cancel: Arc<AtomicBool>,
253        poll_interval: Duration,
254    ) -> Self {
255        Self {
256            rx,
257            cancel,
258            poll_interval,
259            buffer: Vec::new(),
260            consumed: 0,
261            eof: false,
262        }
263    }
264
265    fn apply_cancel_if_needed(&mut self) {
266        if self.cancel.load(Ordering::Acquire) {
267            self.buffer.clear();
268            self.consumed = 0;
269            self.eof = true;
270        }
271    }
272
273    fn recv_loop(&mut self) -> io::Result<()> {
274        loop {
275            if self.cancel.load(Ordering::Acquire) {
276                self.eof = true;
277                return Ok(());
278            }
279            if apply_recv_step(
280                self.rx.recv_timeout(self.poll_interval),
281                &mut self.buffer,
282                &mut self.eof,
283            )? {
284                return Ok(());
285            }
286        }
287    }
288
289    fn refill_if_needed(&mut self) -> io::Result<()> {
290        if should_cancel_or_eof(
291            self.cancel.load(Ordering::Acquire),
292            self.eof,
293            self.consumed,
294            &self.buffer,
295        ) {
296            self.apply_cancel_if_needed();
297            return Ok(());
298        }
299
300        self.buffer.clear();
301        self.consumed = 0;
302        self.recv_loop()
303    }
304}
305
/// Tell whether a refill must be skipped: the reader is cancelled, already at end of
/// input, or `buffer` still holds bytes beyond `consumed`.
fn should_cancel_or_eof(cancelled: bool, eof: bool, consumed: usize, buffer: &[u8]) -> bool {
    if cancelled || eof {
        return true;
    }
    consumed < buffer.len()
}
309
/// Classification of one `recv_timeout` result, produced by `apply_recv_result`.
enum RecvStep {
    /// A non-empty chunk of bytes arrived.
    Done(Vec<u8>),
    /// The stream ended: an empty chunk arrived or the sender disconnected.
    Eof,
    /// The receive timed out; the caller should poll again.
    Continue,
}
315
316fn apply_recv_result(
317    result: Result<io::Result<Vec<u8>>, mpsc::RecvTimeoutError>,
318) -> io::Result<RecvStep> {
319    match result {
320        Ok(Ok(chunk)) if chunk.is_empty() => Ok(RecvStep::Eof),
321        Ok(Ok(chunk)) => Ok(RecvStep::Done(chunk)),
322        Ok(Err(e)) => Err(e),
323        Err(mpsc::RecvTimeoutError::Timeout) => Ok(RecvStep::Continue),
324        Err(mpsc::RecvTimeoutError::Disconnected) => Ok(RecvStep::Eof),
325    }
326}
327
328/// Apply a single receive result to the buffer state.
329///
330/// Returns `Ok(true)` when the loop should stop (done or eof), `Ok(false)` to continue.
331fn apply_recv_step(
332    result: Result<io::Result<Vec<u8>>, mpsc::RecvTimeoutError>,
333    buffer: &mut Vec<u8>,
334    eof: &mut bool,
335) -> io::Result<bool> {
336    match apply_recv_result(result)? {
337        RecvStep::Done(chunk) => {
338            *buffer = chunk;
339            Ok(true)
340        }
341        RecvStep::Eof => {
342            *eof = true;
343            Ok(true)
344        }
345        RecvStep::Continue => Ok(false),
346    }
347}
348
349impl Read for CancelAwareReceiverBufRead {
350    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
351        self.refill_if_needed()?;
352        if self.eof {
353            return Ok(0);
354        }
355
356        let available = self.buffer.len() - self.consumed;
357        if available == 0 {
358            return Ok(0);
359        }
360        let to_copy = available.min(buf.len());
361        buf[..to_copy].copy_from_slice(&self.buffer[self.consumed..self.consumed + to_copy]);
362        self.consumed += to_copy;
363        Ok(to_copy)
364    }
365}
366
367impl BufRead for CancelAwareReceiverBufRead {
368    fn fill_buf(&mut self) -> io::Result<&[u8]> {
369        self.refill_if_needed()?;
370        if self.eof {
371            return Ok(&[]);
372        }
373        Ok(&self.buffer[self.consumed..])
374    }
375
376    fn consume(&mut self, amt: usize) {
377        self.consumed = (self.consumed + amt).min(self.buffer.len());
378        if self.consumed == self.buffer.len() {
379            self.buffer.clear();
380            self.consumed = 0;
381        }
382    }
383}
384
385/// Spawn a thread to pump stdout data from a reader into a channel.
386pub fn spawn_stdout_pump(
387    stdout: Box<dyn io::Read + Send>,
388    activity_timestamp: SharedActivityTimestamp,
389    tx: mpsc::SyncSender<io::Result<Vec<u8>>>,
390    cancel: Arc<AtomicBool>,
391) -> std::thread::JoinHandle<()> {
392    std::thread::spawn(move || {
393        let mut tracked_stdout = ActivityTrackingReader::new(stdout, activity_timestamp);
394        let mut buf = [0u8; 4096];
395
396        loop {
397            if cancel.load(Ordering::Acquire) {
398                return;
399            }
400            match tracked_stdout.read(&mut buf) {
401                Ok(0) => {
402                    if tx.send(Ok(Vec::new())).is_err() {
403                        return;
404                    }
405                    return;
406                }
407                Ok(n) => {
408                    if tx.send(Ok(buf[..n].to_vec())).is_err() {
409                        return;
410                    }
411                }
412                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
413                    if cancel.load(Ordering::Acquire) {
414                        return;
415                    }
416                    std::thread::sleep(Duration::from_millis(10));
417                }
418                Err(e) => {
419                    let _ = tx.send(Err(e));
420                    return;
421                }
422            }
423        }
424    })
425}
426
/// Tell whether the pump thread may have to be abandoned instead of joined
/// unconditionally: that is the case after cancellation or a parse failure.
fn pump_should_detach(cancelled: bool, parse_err: &io::Result<()>) -> bool {
    matches!((cancelled, parse_err), (true, _) | (_, Err(_)))
}
430
/// Warning text to log when the pump thread is being detached, or `None` when it
/// is not.
fn detach_message_for_logger(detached: bool) -> Option<&'static str> {
    if detached {
        Some("Stdout pump thread did not exit; detaching thread")
    } else {
        None
    }
}
434
/// Block until the pump thread has finished or `deadline` has passed, checking every
/// 10 ms.
fn wait_for_pump_deadline(pump_handle: &std::thread::JoinHandle<()>, deadline: std::time::Instant) {
    loop {
        if pump_handle.is_finished() || std::time::Instant::now() >= deadline {
            break;
        }
        std::thread::sleep(Duration::from_millis(10));
    }
}
440
441fn finalize_pump(pump_handle: std::thread::JoinHandle<()>, logger: &crate::logger::Logger) {
442    if pump_handle.is_finished() {
443        let _ = pump_handle.join();
444    } else {
445        if let Some(msg) = detach_message_for_logger(true) {
446            logger.warn(msg);
447        }
448        drop(pump_handle);
449    }
450}
451
452/// Clean up the stdout pump thread.
453fn join_or_detach_pump(pump_handle: std::thread::JoinHandle<()>, logger: &crate::logger::Logger) {
454    let deadline = std::time::Instant::now() + Duration::from_secs(2);
455    wait_for_pump_deadline(&pump_handle, deadline);
456    finalize_pump(pump_handle, logger);
457}
458
459pub fn cleanup_stdout_pump(
460    pump_handle: std::thread::JoinHandle<()>,
461    cancel: &Arc<AtomicBool>,
462    logger: &crate::logger::Logger,
463    parse_result: &io::Result<()>,
464) {
465    if parse_result.is_err() {
466        cancel.store(true, Ordering::Release);
467    }
468
469    let should_detach = pump_should_detach(cancel.load(Ordering::Acquire), parse_result);
470    if should_detach {
471        join_or_detach_pump(pump_handle, logger);
472    } else {
473        let _ = pump_handle.join();
474    }
475}
476
477/// Create a bounded channel for stdout pumping.
478pub fn create_stdout_channel() -> StdoutChannel {
479    mpsc::sync_channel(STDOUT_PUMP_CHANNEL_CAPACITY)
480}