ralph_workflow/pipeline/
prompt.rs

1//! Prompt-based command execution.
2
3use crate::agents::{is_glm_like_agent, JsonParserType};
4use crate::common::{format_argv_for_log, split_command, truncate_text};
5use crate::config::Config;
6use crate::logger::Colors;
7use crate::logger::Logger;
8use crate::logger::{argv_requests_json, format_generic_json_for_display};
9use crate::pipeline::Timer;
10use std::fs::{self, File, OpenOptions};
11use std::io::{self, BufRead, BufReader, Read, Write};
12use std::path::Path;
13use std::process::{Child, ChildStdout, Command, Stdio};
14
/// A line-oriented reader that processes data as it arrives.
///
/// Data is pulled from the underlying source in small chunks, and
/// [`BufRead::fill_buf`] returns as soon as a chunk containing a newline has
/// been read, so line-based consumers (`read_line`, `lines`) see each line
/// without waiting for a large buffer to fill. This enables real-time
/// streaming for agents that output NDJSON gradually.
///
/// # Line Length Limit
///
/// The reader counts the bytes it has handed out since the last newline.
/// Once a single line would grow beyond [`MAX_BUFFER_SIZE`] bytes (the
/// newline itself is not counted), `fill_buf` returns an error instead of
/// more data. This protects line-based consumers, which accumulate a line in
/// their own buffer, from input that never contains a newline.
///
/// The limit is enforced on the `BufRead` path only; plain `Read::read` calls
/// are never rejected.
struct StreamingLineReader<R: Read> {
    /// Underlying source, wrapped in a small `BufReader`.
    inner: BufReader<R>,
    /// Bytes pulled from `inner` that have not all been handed out yet.
    buffer: Vec<u8>,
    /// Offset into `buffer` of the first byte that has not been handed out.
    consumed: usize,
    /// Number of bytes handed out since the most recent newline was handed out.
    line_len: usize,
}

/// Maximum length in bytes of a single line, to prevent unbounded memory growth.
///
/// This limits the impact of agents that output continuous data without newlines.
/// The value of 1 MiB was chosen to:
/// - Handle most legitimate JSON documents (typically < 100KB)
/// - Allow for reasonably long single-line JSON outputs
/// - Prevent memory exhaustion from malicious input
/// - Keep the buffer size manageable for most systems
///
/// If your use case requires larger single-line JSON, consider:
/// - Modifying your agent to output NDJSON (newline-delimited JSON)
/// - Adjusting this constant and rebuilding
const MAX_BUFFER_SIZE: usize = 1024 * 1024; // 1 MiB

impl<R: Read> StreamingLineReader<R> {
    /// Create a new streaming line reader with a small buffer for low latency.
    fn new(inner: R) -> Self {
        // Use a smaller buffer (1KB) than default (8KB) for lower latency.
        // This trades slightly more syscalls for faster response to newlines.
        const BUFFER_SIZE: usize = 1024;
        Self {
            inner: BufReader::with_capacity(BUFFER_SIZE, inner),
            buffer: Vec::new(),
            consumed: 0,
            line_len: 0,
        }
    }

    /// Returns the length of the current (unterminated) line after `delivered`
    /// has been handed to the caller, given that it was `current` bytes before.
    fn line_len_after(current: usize, delivered: &[u8]) -> usize {
        match delivered.iter().rposition(|&b| b == b'\n') {
            // Everything after the last newline belongs to a fresh line.
            Some(idx) => delivered.len() - idx - 1,
            None => current.saturating_add(delivered.len()),
        }
    }

    /// Marks up to `amt` buffered bytes as handed out and updates the line counter.
    ///
    /// `amt` is clamped to the number of buffered bytes. The buffer is reset once
    /// everything in it has been handed out.
    fn advance(&mut self, amt: usize) {
        let end = self.consumed.saturating_add(amt).min(self.buffer.len());
        self.line_len = Self::line_len_after(self.line_len, &self.buffer[self.consumed..end]);
        self.consumed = end;

        // Compact the buffer if we've consumed everything
        if self.consumed == self.buffer.len() {
            self.buffer.clear();
            self.consumed = 0;
        }
    }

    /// Pull one chunk (at most 256 bytes) from the underlying reader into the buffer.
    ///
    /// Returns the number of bytes appended; `0` means the underlying reader
    /// reported end of input.
    ///
    /// # Errors
    ///
    /// Returns an error if the current line would exceed `MAX_BUFFER_SIZE`, or if
    /// the underlying read fails. This prevents memory exhaustion from malicious
    /// input that never contains newlines.
    fn fill_buffer(&mut self) -> io::Result<usize> {
        // Bytes of the current line seen so far: what the caller already took plus
        // what is still buffered. This is only called while the buffered remainder
        // holds no newline, so all of it belongs to the current line.
        let pending = self
            .line_len
            .saturating_add(self.buffer.len() - self.consumed);
        if pending > MAX_BUFFER_SIZE {
            return Err(io::Error::other(format!(
                "StreamingLineReader buffer exceeded maximum size of {MAX_BUFFER_SIZE} bytes. \
                This may indicate malformed input or an agent that is not sending newlines."
            )));
        }

        let mut read_buf = [0u8; 256];
        let n = self.inner.read(&mut read_buf)?;
        if n == 0 {
            return Ok(0);
        }

        let chunk = &read_buf[..n];
        // Only the part of the chunk before its first newline extends the current line.
        let line_part = chunk.iter().position(|&b| b == b'\n').unwrap_or(n);
        if pending.saturating_add(line_part) > MAX_BUFFER_SIZE {
            // Record the overflow so that the check above keeps rejecting this line.
            self.line_len = self.line_len.saturating_add(line_part);
            return Err(io::Error::other(format!(
                "StreamingLineReader buffer would exceed maximum size of {MAX_BUFFER_SIZE} bytes. \
                This may indicate malformed input or an agent that is not sending newlines."
            )));
        }

        self.buffer.extend_from_slice(chunk);
        Ok(n)
    }
}

impl<R: Read> Read for StreamingLineReader<R> {
    /// Hands out buffered bytes first; with an empty buffer, reads straight from
    /// the underlying reader. The line length limit is not enforced here.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // First, consume from the buffer
        let available = self.buffer.len() - self.consumed;
        if available > 0 {
            let to_copy = available.min(buf.len());
            buf[..to_copy].copy_from_slice(&self.buffer[self.consumed..self.consumed + to_copy]);
            self.advance(to_copy);
            return Ok(to_copy);
        }

        // Buffer empty - read directly from underlying reader
        let n = self.inner.read(buf)?;
        self.line_len = Self::line_len_after(self.line_len, &buf[..n]);
        Ok(n)
    }
}

impl<R: Read> BufRead for StreamingLineReader<R> {
    /// Returns the unconsumed buffered bytes, reading more only when none remain.
    ///
    /// When reading, chunks are pulled until one contains a newline, the source
    /// reports end of input, or `MAX_ATTEMPTS` chunks have been read. An empty
    /// slice is returned only at end of input.
    ///
    /// # Errors
    ///
    /// Fails if the underlying read fails or the current line exceeds
    /// `MAX_BUFFER_SIZE`.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        // Bounds how much is buffered per call when no newline shows up.
        const MAX_ATTEMPTS: usize = 8;

        // If we have unconsumed data, return it
        if self.consumed < self.buffer.len() {
            return Ok(&self.buffer[self.consumed..]);
        }

        // Buffer was fully consumed - clear and try to read more
        self.buffer.clear();
        self.consumed = 0;

        let mut total_read = 0;
        for _ in 0..MAX_ATTEMPTS {
            match self.fill_buffer()? {
                0 if total_read == 0 => return Ok(&[]), // EOF with nothing buffered
                0 => break,                             // EOF; hand out what we have
                n => {
                    total_read += n;
                    // Earlier chunks were already checked, so only scan the new one.
                    let fresh = &self.buffer[self.buffer.len() - n..];
                    if fresh.contains(&b'\n') {
                        break;
                    }
                }
            }
        }

        Ok(&self.buffer[self.consumed..])
    }

    /// Marks `amt` bytes (clamped to what is buffered) as consumed.
    fn consume(&mut self, amt: usize) {
        self.advance(amt);
    }
}
156
157use super::clipboard::get_platform_clipboard_command;
158use super::types::CommandResult;
159
/// A single prompt-based agent invocation.
///
/// Bundles everything `run_with_prompt` needs to launch one agent command and
/// route its output.
pub struct PromptCommand<'a> {
    /// Heading logged as the step title when the command starts.
    pub label: &'a str,
    /// Name handed to the JSON parsers via `with_display_name`.
    pub display_name: &'a str,
    /// Command line for the agent; it is split into argv, and the prompt is
    /// appended as the final argument.
    pub cmd_str: &'a str,
    /// Prompt text: saved to the configured prompt file and passed to the agent.
    pub prompt: &'a str,
    /// Path of the log file; it is (re)created before the agent is spawned.
    pub logfile: &'a str,
    /// Selects which parser handles the agent's stdout.
    pub parser_type: JsonParserType,
    /// Environment variables injected into the agent subprocess.
    pub env_vars: &'a std::collections::HashMap<String, String>,
}
170
/// Runtime services required for running agent commands.
pub struct PipelineRuntime<'a> {
    /// Phase timer; `run_with_prompt` starts a phase on it and reports the elapsed time.
    pub timer: &'a mut Timer,
    /// Destination for step, info, and warning messages.
    pub logger: &'a Logger,
    /// Color codes used when formatting log messages and parser output.
    pub colors: &'a Colors,
    /// Pipeline configuration (verbosity, prompt path, interactive mode,
    /// streaming-metrics display).
    pub config: &'a Config,
}
178
/// Command configuration for building an agent command.
///
/// The subset of a `PromptCommand` that `build_agent_command` needs.
struct CommandConfig<'a> {
    /// Command line for the agent, split into argv before spawning.
    cmd_str: &'a str,
    /// Prompt text appended as the last argument.
    prompt: &'a str,
    /// Environment variables to set on the subprocess.
    env_vars: &'a std::collections::HashMap<String, String>,
    /// Path of the log file to create (parent directories included).
    logfile: &'a str,
    /// Parser type, reported in the log when the command is built.
    parser_type: JsonParserType,
}
187
188/// Saves the prompt to a file and optionally copies it to the clipboard.
189fn save_prompt_to_file_and_clipboard(
190    prompt: &str,
191    prompt_path: &std::path::PathBuf,
192    interactive: bool,
193    logger: &Logger,
194    colors: Colors,
195) -> io::Result<()> {
196    // Save prompt to file
197    if let Some(parent) = prompt_path.parent() {
198        fs::create_dir_all(parent)?;
199    }
200    fs::write(prompt_path, prompt)?;
201    logger.info(&format!(
202        "Prompt saved to {}{}{}",
203        colors.cyan(),
204        prompt_path.display(),
205        colors.reset()
206    ));
207
208    // Copy to clipboard if interactive
209    if interactive {
210        if let Some(clipboard_cmd) = get_platform_clipboard_command() {
211            if let Ok(mut child) = Command::new(clipboard_cmd.binary)
212                .args(clipboard_cmd.args)
213                .stdin(Stdio::piped())
214                .spawn()
215            {
216                if let Some(mut stdin) = child.stdin.take() {
217                    let _ = stdin.write_all(prompt.as_bytes());
218                }
219                let _ = child.wait();
220                logger.info(&format!(
221                    "Prompt copied to clipboard {}({}){}",
222                    colors.dim(),
223                    clipboard_cmd.paste_hint,
224                    colors.reset()
225                ));
226            }
227        }
228    }
229    Ok(())
230}
231
232/// Builds and configures the agent command with environment variables.
233fn build_agent_command(
234    config: &CommandConfig<'_>,
235    anthropic_env_vars_to_sanitize: &[&str],
236    logger: &Logger,
237    colors: Colors,
238) -> io::Result<(Vec<String>, Command)> {
239    let argv = split_command(config.cmd_str)?;
240    if argv.is_empty() || config.cmd_str.trim().is_empty() {
241        return Err(io::Error::new(
242            io::ErrorKind::InvalidInput,
243            "Agent command is empty or contains only whitespace",
244        ));
245    }
246
247    let mut argv_for_log = argv.clone();
248    argv_for_log.push("<PROMPT>".to_string());
249    let display_cmd = truncate_text(&format_argv_for_log(&argv_for_log), 160);
250    logger.info(&format!(
251        "Executing: {}{}{}",
252        colors.dim(),
253        display_cmd,
254        colors.reset()
255    ));
256
257    // GLM-specific debug logging
258    let is_glm_cmd = is_glm_like_agent(config.cmd_str);
259    if is_glm_cmd {
260        logger.info(&format!("GLM command details: {display_cmd}"));
261        if argv.iter().any(|arg| arg == "-p") {
262            logger.info("GLM command includes '-p' flag (correct)");
263        } else {
264            logger.warn("GLM command may be missing '-p' flag");
265        }
266    }
267
268    let _uses_json = config.parser_type != JsonParserType::Generic || argv_requests_json(&argv);
269    logger.info(&format!("Using {} parser...", config.parser_type));
270
271    if let Some(parent) = Path::new(config.logfile).parent() {
272        fs::create_dir_all(parent)?;
273    }
274    File::create(config.logfile)?;
275
276    let mut command = Command::new(&argv[0]);
277    command.args(&argv[1..]);
278    command.arg(config.prompt);
279
280    // Inject environment variables from agent config
281    if !config.env_vars.is_empty() {
282        logger.info(&format!(
283            "Injecting {} environment variable(s) into subprocess",
284            config.env_vars.len()
285        ));
286        for key in config.env_vars.keys() {
287            logger.info(&format!("  - {key}"));
288        }
289        for (key, value) in config.env_vars {
290            command.env(key, value);
291        }
292    }
293
294    // Set agent-side buffering disabling environment variables for real-time streaming.
295    // These are only set if not already explicitly configured by the user's env_vars.
296    // This mitigates the issue where AI agents buffer their stdout instead of streaming.
297    //
298    // Note: NODE_ENV is set to "production" (not "development") because production mode
299    // disables buffering in Node.js applications. This is necessary for real-time streaming
300    // but may affect error stack traces and logging levels in Node.js agents.
301    let buffering_vars = [("PYTHONUNBUFFERED", "1"), ("NODE_ENV", "production")];
302    for (key, value) in buffering_vars {
303        if !config.env_vars.contains_key(key) {
304            command.env(key, value);
305        }
306    }
307
308    // Clear problematic Anthropic env vars that weren't explicitly set by the agent.
309    for &var in anthropic_env_vars_to_sanitize {
310        if !config.env_vars.contains_key(var) {
311            command.env_remove(var);
312        }
313    }
314
315    Ok((argv, command))
316}
317
318/// Spawns the agent process with special error handling for `NotFound` and `PermissionDenied`.
319fn spawn_agent_process(
320    mut command: Command,
321    argv: &[String],
322) -> io::Result<Result<Child, CommandResult>> {
323    match command
324        .stdout(Stdio::piped())
325        .stderr(Stdio::piped())
326        .spawn()
327    {
328        Ok(child) => Ok(Ok(child)),
329        Err(e)
330            if matches!(
331                e.kind(),
332                io::ErrorKind::NotFound | io::ErrorKind::PermissionDenied
333            ) =>
334        {
335            let exit_code = if e.kind() == io::ErrorKind::NotFound {
336                127
337            } else {
338                126
339            };
340            Ok(Err(CommandResult {
341                exit_code,
342                stderr: format!("{}: {}", argv[0], e),
343            }))
344        }
345        Err(e) => Err(e),
346    }
347}
348
349/// Streams agent output based on parser type.
350fn stream_agent_output(
351    stdout: ChildStdout,
352    cmd: &PromptCommand<'_>,
353    runtime: &PipelineRuntime<'_>,
354) -> io::Result<()> {
355    // Use StreamingLineReader for real-time streaming instead of BufReader::lines().
356    // StreamingLineReader yields lines immediately when newlines are found,
357    // enabling character-by-character streaming for agents that output NDJSON gradually.
358    let reader = StreamingLineReader::new(stdout);
359
360    if cmd.parser_type != JsonParserType::Generic
361        || argv_requests_json(&split_command(cmd.cmd_str)?)
362    {
363        let stdout_io = io::stdout();
364        let mut out = stdout_io.lock();
365
366        match cmd.parser_type {
367            JsonParserType::Claude => {
368                let p = crate::json_parser::ClaudeParser::new(
369                    *runtime.colors,
370                    runtime.config.verbosity,
371                )
372                .with_display_name(cmd.display_name)
373                .with_log_file(cmd.logfile)
374                .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
375                p.parse_stream(reader)?;
376            }
377            JsonParserType::Codex => {
378                let p =
379                    crate::json_parser::CodexParser::new(*runtime.colors, runtime.config.verbosity)
380                        .with_display_name(cmd.display_name)
381                        .with_log_file(cmd.logfile)
382                        .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
383                p.parse_stream(reader)?;
384            }
385            JsonParserType::Gemini => {
386                let p = crate::json_parser::GeminiParser::new(
387                    *runtime.colors,
388                    runtime.config.verbosity,
389                )
390                .with_display_name(cmd.display_name)
391                .with_log_file(cmd.logfile)
392                .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
393                p.parse_stream(reader)?;
394            }
395            JsonParserType::OpenCode => {
396                let p = crate::json_parser::OpenCodeParser::new(
397                    *runtime.colors,
398                    runtime.config.verbosity,
399                )
400                .with_display_name(cmd.display_name)
401                .with_log_file(cmd.logfile)
402                .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
403                p.parse_stream(reader)?;
404            }
405            JsonParserType::Generic => {
406                let mut buf = String::new();
407                for line in reader.lines() {
408                    let line = line?;
409                    buf.push_str(&line);
410                    buf.push('\n');
411                }
412                let formatted = format_generic_json_for_display(&buf, runtime.config.verbosity);
413                out.write_all(formatted.as_bytes())?;
414            }
415        }
416    } else {
417        let mut logfile = OpenOptions::new()
418            .create(true)
419            .append(true)
420            .open(cmd.logfile)?;
421
422        let stdout_io = io::stdout();
423        let mut out = stdout_io.lock();
424
425        for line in reader.lines() {
426            let line = line?;
427            writeln!(out, "{line}")?;
428            writeln!(logfile, "{line}")?;
429        }
430    }
431    Ok(())
432}
433
434/// Waits for process completion and collects stderr output.
435fn wait_for_completion_and_collect_stderr(
436    mut child: Child,
437    stderr_join_handle: Option<std::thread::JoinHandle<io::Result<String>>>,
438    runtime: &PipelineRuntime<'_>,
439) -> io::Result<(i32, String)> {
440    let status = child.wait()?;
441    let exit_code = status.code().unwrap_or(1);
442
443    if status.code().is_none() && runtime.config.verbosity.is_debug() {
444        runtime
445            .logger
446            .warn("Process terminated by signal (no exit code), treating as failure");
447    }
448
449    let stderr_output = match stderr_join_handle {
450        Some(handle) => match handle.join() {
451            Ok(result) => result?,
452            Err(panic_payload) => {
453                let panic_msg = panic_payload.downcast_ref::<String>().map_or_else(
454                    || {
455                        panic_payload.downcast_ref::<&str>().map_or_else(
456                            || "<unknown panic>".to_string(),
457                            std::string::ToString::to_string,
458                        )
459                    },
460                    std::clone::Clone::clone,
461                );
462                runtime.logger.warn(&format!(
463                    "Stderr collection thread panicked: {panic_msg}. This may indicate a bug."
464                ));
465                String::new()
466            }
467        },
468        None => String::new(),
469    };
470
471    if !stderr_output.is_empty() && runtime.config.verbosity.is_debug() {
472        runtime.logger.warn(&format!(
473            "Agent stderr output detected ({} bytes):",
474            stderr_output.len()
475        ));
476        for (i, line) in stderr_output.lines().take(5).enumerate() {
477            runtime.logger.info(&format!("  stderr[{i}]: {line}"));
478        }
479        if stderr_output.lines().count() > 5 {
480            runtime.logger.info(&format!(
481                "  ... ({} more lines, see log file for full output)",
482                stderr_output.lines().count() - 5
483            ));
484        }
485    }
486
487    Ok((exit_code, stderr_output))
488}
489
490/// Run a command with a prompt argument.
491///
492/// This is an internal helper for `run_with_fallback`.
493pub fn run_with_prompt(
494    cmd: &PromptCommand<'_>,
495    runtime: &mut PipelineRuntime<'_>,
496) -> io::Result<CommandResult> {
497    const ANTHROPIC_ENV_VARS_TO_SANITIZE: &[&str] = &[
498        "ANTHROPIC_API_KEY",
499        "ANTHROPIC_BASE_URL",
500        "ANTHROPIC_AUTH_TOKEN",
501        "ANTHROPIC_MODEL",
502        "ANTHROPIC_DEFAULT_HAIKU_MODEL",
503        "ANTHROPIC_DEFAULT_OPUS_MODEL",
504        "ANTHROPIC_DEFAULT_SONNET_MODEL",
505    ];
506
507    runtime.timer.start_phase();
508    runtime.logger.step(&format!(
509        "{}{}{}",
510        runtime.colors.bold(),
511        cmd.label,
512        runtime.colors.reset()
513    ));
514
515    save_prompt_to_file_and_clipboard(
516        cmd.prompt,
517        &runtime.config.prompt_path,
518        runtime.config.behavior.interactive,
519        runtime.logger,
520        *runtime.colors,
521    )?;
522
523    let (argv, command) = build_agent_command(
524        &CommandConfig {
525            cmd_str: cmd.cmd_str,
526            prompt: cmd.prompt,
527            env_vars: cmd.env_vars,
528            logfile: cmd.logfile,
529            parser_type: cmd.parser_type,
530        },
531        ANTHROPIC_ENV_VARS_TO_SANITIZE,
532        runtime.logger,
533        *runtime.colors,
534    )?;
535
536    let mut child = match spawn_agent_process(command, &argv)? {
537        Ok(child) => child,
538        Err(result) => return Ok(result),
539    };
540
541    let stdout = child
542        .stdout
543        .take()
544        .ok_or_else(|| io::Error::other("Failed to capture stdout"))?;
545
546    let stderr_join_handle = child.stderr.take().map(|stderr| {
547        std::thread::spawn(move || -> io::Result<String> {
548            const STDERR_MAX_BYTES: usize = 512 * 1024;
549
550            let mut reader = BufReader::new(stderr);
551            let mut buf = [0u8; 8192];
552            let mut collected = Vec::<u8>::new();
553            let mut truncated = false;
554
555            loop {
556                let n = reader.read(&mut buf)?;
557                if n == 0 {
558                    break;
559                }
560
561                let remaining = STDERR_MAX_BYTES.saturating_sub(collected.len());
562                if remaining == 0 {
563                    truncated = true;
564                    break;
565                }
566
567                let to_take = remaining.min(n);
568                collected.extend_from_slice(&buf[..to_take]);
569                if to_take < n {
570                    truncated = true;
571                    break;
572                }
573            }
574
575            let mut stderr_output = String::from_utf8_lossy(&collected).into_owned();
576            if truncated {
577                if !stderr_output.ends_with('\n') {
578                    stderr_output.push('\n');
579                }
580                stderr_output.push_str("<stderr truncated>");
581            }
582
583            Ok(stderr_output)
584        })
585    });
586
587    stream_agent_output(stdout, cmd, runtime)?;
588
589    let (exit_code, stderr_output) =
590        wait_for_completion_and_collect_stderr(child, stderr_join_handle, runtime)?;
591
592    if runtime.config.verbosity.is_verbose() {
593        runtime.logger.info(&format!(
594            "Phase elapsed: {}",
595            runtime.timer.phase_elapsed_formatted()
596        ));
597    }
598
599    Ok(CommandResult {
600        exit_code,
601        stderr: stderr_output,
602    })
603}