ralph_workflow/pipeline/
prompt.rs

1//! Prompt-based command execution.
2
3use crate::agents::{is_glm_like_agent, JsonParserType};
4use crate::common::{format_argv_for_log, split_command, truncate_text};
5use crate::config::Config;
6use crate::logger::Colors;
7use crate::logger::Logger;
8use crate::logger::{argv_requests_json, format_generic_json_for_display};
9use crate::pipeline::Timer;
10use std::fs::{self, File, OpenOptions};
11use std::io::{self, BufRead, BufReader, Read, Write};
12use std::path::Path;
13use std::process::{Child, ChildStdout, Command, Stdio};
14
/// A line-oriented reader that processes data as it arrives.
///
/// Data is pulled from the underlying reader in small chunks, and
/// [`BufRead::fill_buf`] returns as soon as a chunk containing a newline has
/// been read, so consumers such as `lines()` see each line promptly.
/// This enables real-time streaming for agents that output NDJSON gradually.
///
/// # Line Length Limit
///
/// The reader enforces a maximum length for a single unterminated line to
/// prevent memory exhaustion from malicious or malformed input that never
/// contains newlines. The count includes bytes that were already handed to
/// the consumer, because `read_until`/`lines()` keep accumulating those in
/// their own buffer until a newline shows up. Once the limit is reached,
/// further `fill_buf` calls fail with an error.
struct StreamingLineReader<R: Read> {
    /// Underlying reader, wrapped with a small buffer for low latency.
    inner: BufReader<R>,
    /// Bytes read from `inner` that have not all been consumed yet.
    buffer: Vec<u8>,
    /// Offset into `buffer` of the first unconsumed byte.
    consumed: usize,
    /// Bytes handed to the consumer since the last newline it received.
    unterminated: usize,
}

/// Maximum size in bytes of a single unterminated line.
///
/// This limits the impact of agents that output continuous data without newlines.
/// The value of 1 MiB was chosen to:
/// - Handle most legitimate JSON documents (typically < 100KB)
/// - Allow for reasonably long single-line JSON outputs
/// - Prevent memory exhaustion from malicious input
/// - Keep the buffer size manageable for most systems
///
/// If your use case requires larger single-line JSON, consider:
/// - Modifying your agent to output NDJSON (newline-delimited JSON)
/// - Adjusting this constant and rebuilding
const MAX_BUFFER_SIZE: usize = 1024 * 1024; // 1 MiB

impl<R: Read> StreamingLineReader<R> {
    /// Create a new streaming line reader with a small buffer for low latency.
    fn new(inner: R) -> Self {
        // Use a smaller buffer (1KB) than default (8KB) for lower latency.
        // This trades slightly more syscalls for faster response to newlines.
        const BUFFER_SIZE: usize = 1024;
        Self {
            inner: BufReader::with_capacity(BUFFER_SIZE, inner),
            buffer: Vec::new(),
            consumed: 0,
            unterminated: 0,
        }
    }

    /// Length of the unterminated line tail after appending `bytes` to a
    /// tail that was already `carried` bytes long.
    ///
    /// A newline inside `bytes` resets the count to whatever follows it.
    fn unterminated_after(carried: usize, bytes: &[u8]) -> usize {
        match bytes.iter().rposition(|&b| b == b'\n') {
            Some(idx) => bytes.len() - idx - 1,
            None => carried.saturating_add(bytes.len()),
        }
    }

    /// Mark `amt` buffered bytes as consumed and update the line-length tracker.
    ///
    /// `amt` must not exceed the number of unconsumed bytes in `buffer`.
    fn mark_consumed(&mut self, amt: usize) {
        let end = self.consumed + amt;
        self.unterminated =
            Self::unterminated_after(self.unterminated, &self.buffer[self.consumed..end]);
        self.consumed = end;

        // Compact the buffer if we've consumed everything
        if self.consumed == self.buffer.len() {
            self.buffer.clear();
            self.consumed = 0;
        }
    }

    /// Fill the internal buffer from the underlying reader.
    ///
    /// Returns the number of bytes read; `0` means end of input.
    ///
    /// # Errors
    ///
    /// Returns an error if the current unterminated line has reached (or the
    /// new data would push it past) `MAX_BUFFER_SIZE`.
    /// This prevents memory exhaustion from malicious input that never contains newlines.
    fn fill_buffer(&mut self) -> io::Result<usize> {
        // Length of the line in progress: bytes already delivered to the
        // consumer plus whatever is still waiting in our buffer.
        let current_size =
            Self::unterminated_after(self.unterminated, &self.buffer[self.consumed..]);
        if current_size >= MAX_BUFFER_SIZE {
            return Err(io::Error::other(format!(
                "StreamingLineReader buffer exceeded maximum size of {MAX_BUFFER_SIZE} bytes. \
                This may indicate malformed input or an agent that is not sending newlines."
            )));
        }

        let mut read_buf = [0u8; 256];
        let n = self.inner.read(&mut read_buf)?;
        if n > 0 {
            let chunk = &read_buf[..n];
            // Only the bytes up to the first newline extend the current line.
            let line_growth = chunk.iter().position(|&b| b == b'\n').unwrap_or(n);
            if current_size + line_growth > MAX_BUFFER_SIZE {
                return Err(io::Error::other(format!(
                    "StreamingLineReader buffer would exceed maximum size of {MAX_BUFFER_SIZE} bytes. \
                    This may indicate malformed input or an agent that is not sending newlines."
                )));
            }
            self.buffer.extend_from_slice(chunk);
        }
        Ok(n)
    }
}

impl<R: Read> Read for StreamingLineReader<R> {
    /// Serve bytes from the internal buffer first; fall through to the
    /// underlying reader only when nothing is buffered.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // First, consume from the buffer
        let available = self.buffer.len() - self.consumed;
        if available > 0 {
            let to_copy = available.min(buf.len());
            buf[..to_copy].copy_from_slice(&self.buffer[self.consumed..self.consumed + to_copy]);
            self.mark_consumed(to_copy);
            return Ok(to_copy);
        }

        // Buffer empty - read directly from underlying reader, still keeping
        // the line-length tracker in sync with what the consumer received.
        let n = self.inner.read(buf)?;
        self.unterminated = Self::unterminated_after(self.unterminated, &buf[..n]);
        Ok(n)
    }
}

impl<R: Read> BufRead for StreamingLineReader<R> {
    /// Return unconsumed data, reading more from the underlying reader if
    /// the buffer is empty. An empty slice signals end of input.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        const MAX_ATTEMPTS: usize = 8; // Bound the work done per call

        // If we have unconsumed data, return it
        if self.consumed < self.buffer.len() {
            return Ok(&self.buffer[self.consumed..]);
        }

        // Buffer was fully consumed - clear and try to read more
        self.buffer.clear();
        self.consumed = 0;

        // Read chunks until one contains a newline, input ends, or the
        // attempt budget is spent. Only the new chunk is scanned: earlier
        // chunks in this call are already known to have no newline.
        for _ in 0..MAX_ATTEMPTS {
            let before = self.buffer.len();
            let n = self.fill_buffer()?;
            if n == 0 || self.buffer[before..].contains(&b'\n') {
                break;
            }
        }

        Ok(&self.buffer[self.consumed..])
    }

    /// Mark up to `amt` buffered bytes as consumed (clamped to what is buffered).
    fn consume(&mut self, amt: usize) {
        let amt = amt.min(self.buffer.len() - self.consumed);
        self.mark_consumed(amt);
    }
}
156
157use super::clipboard::get_platform_clipboard_command;
158use super::types::CommandResult;
159
/// A single prompt-based agent invocation.
///
/// Bundles everything `run_with_prompt` needs to launch one agent process:
/// what to run, what to send it, where to log, and how to interpret output.
pub struct PromptCommand<'a> {
    /// Heading logged as the step title before the command runs.
    pub label: &'a str,
    /// Name handed to the JSON parsers via `with_display_name`.
    pub display_name: &'a str,
    /// Agent command line; split into argv before spawning.
    pub cmd_str: &'a str,
    /// Prompt text, appended as the final argument and saved to the prompt file.
    pub prompt: &'a str,
    /// Path of the log file that receives the agent's output.
    pub logfile: &'a str,
    /// Selects which parser handles the agent's stdout.
    pub parser_type: JsonParserType,
    /// Extra environment variables injected into the subprocess.
    pub env_vars: &'a std::collections::HashMap<String, String>,
}
170
/// Runtime services required for running agent commands.
pub struct PipelineRuntime<'a> {
    /// Phase timer; started at the beginning of each prompt run.
    pub timer: &'a mut Timer,
    /// Sink for step/info/warn messages.
    pub logger: &'a Logger,
    /// Color escape provider used when formatting log messages.
    pub colors: &'a Colors,
    /// Pipeline configuration (verbosity, prompt path, interactive flag, ...).
    pub config: &'a Config,
}
178
/// Command configuration for building an agent command.
///
/// The subset of `PromptCommand` that `build_agent_command` consumes.
struct CommandConfig<'a> {
    /// Agent command line; split into argv.
    cmd_str: &'a str,
    /// Prompt text appended as the last argument.
    prompt: &'a str,
    /// Environment variables to inject into the subprocess.
    env_vars: &'a std::collections::HashMap<String, String>,
    /// Log file path; the file is created while the command is built.
    logfile: &'a str,
    /// Parser selection, reported in the log.
    parser_type: JsonParserType,
}
187
188/// Saves the prompt to a file and optionally copies it to the clipboard.
189fn save_prompt_to_file_and_clipboard(
190    prompt: &str,
191    prompt_path: &std::path::PathBuf,
192    interactive: bool,
193    logger: &Logger,
194    colors: Colors,
195) -> io::Result<()> {
196    // Save prompt to file
197    if let Some(parent) = prompt_path.parent() {
198        fs::create_dir_all(parent)?;
199    }
200    fs::write(prompt_path, prompt)?;
201    logger.info(&format!(
202        "Prompt saved to {}{}{}",
203        colors.cyan(),
204        prompt_path.display(),
205        colors.reset()
206    ));
207
208    // Copy to clipboard if interactive
209    if interactive {
210        if let Some(clipboard_cmd) = get_platform_clipboard_command() {
211            if let Ok(mut child) = Command::new(clipboard_cmd.binary)
212                .args(clipboard_cmd.args)
213                .stdin(Stdio::piped())
214                .spawn()
215            {
216                if let Some(mut stdin) = child.stdin.take() {
217                    let _ = stdin.write_all(prompt.as_bytes());
218                }
219                let _ = child.wait();
220                logger.info(&format!(
221                    "Prompt copied to clipboard {}({}){}",
222                    colors.dim(),
223                    clipboard_cmd.paste_hint,
224                    colors.reset()
225                ));
226            }
227        }
228    }
229    Ok(())
230}
231
232/// Builds and configures the agent command with environment variables.
233fn build_agent_command(
234    config: &CommandConfig<'_>,
235    anthropic_env_vars_to_sanitize: &[&str],
236    logger: &Logger,
237    colors: Colors,
238) -> io::Result<(Vec<String>, Command)> {
239    let argv = split_command(config.cmd_str)?;
240    if argv.is_empty() || config.cmd_str.trim().is_empty() {
241        return Err(io::Error::new(
242            io::ErrorKind::InvalidInput,
243            "Agent command is empty or contains only whitespace",
244        ));
245    }
246
247    let mut argv_for_log = argv.clone();
248    argv_for_log.push("<PROMPT>".to_string());
249    let display_cmd = truncate_text(&format_argv_for_log(&argv_for_log), 160);
250    logger.info(&format!(
251        "Executing: {}{}{}",
252        colors.dim(),
253        display_cmd,
254        colors.reset()
255    ));
256
257    // GLM-specific debug logging
258    let is_glm_cmd = is_glm_like_agent(config.cmd_str);
259    if is_glm_cmd {
260        logger.info(&format!("GLM command details: {display_cmd}"));
261        if argv.iter().any(|arg| arg == "-p") {
262            logger.info("GLM command includes '-p' flag (correct)");
263        } else {
264            logger.warn("GLM command may be missing '-p' flag");
265        }
266    }
267
268    let _uses_json = config.parser_type != JsonParserType::Generic || argv_requests_json(&argv);
269    logger.info(&format!("Using {} parser...", config.parser_type));
270
271    if let Some(parent) = Path::new(config.logfile).parent() {
272        fs::create_dir_all(parent)?;
273    }
274    File::create(config.logfile)?;
275
276    let mut command = Command::new(&argv[0]);
277    command.args(&argv[1..]);
278    command.arg(config.prompt);
279
280    // Inject environment variables from agent config
281    if !config.env_vars.is_empty() {
282        logger.info(&format!(
283            "Injecting {} environment variable(s) into subprocess",
284            config.env_vars.len()
285        ));
286        for key in config.env_vars.keys() {
287            logger.info(&format!("  - {key}"));
288        }
289        for (key, value) in config.env_vars {
290            command.env(key, value);
291        }
292    }
293
294    // Set agent-side buffering disabling environment variables for real-time streaming.
295    // These are only set if not already explicitly configured by the user's env_vars.
296    // This mitigates the issue where AI agents buffer their stdout instead of streaming.
297    //
298    // Note: NODE_ENV is set to "production" (not "development") because production mode
299    // disables buffering in Node.js applications. This is necessary for real-time streaming
300    // but may affect error stack traces and logging levels in Node.js agents.
301    let buffering_vars = [("PYTHONUNBUFFERED", "1"), ("NODE_ENV", "production")];
302    for (key, value) in buffering_vars {
303        if !config.env_vars.contains_key(key) {
304            command.env(key, value);
305        }
306    }
307
308    // Clear problematic Anthropic env vars that weren't explicitly set by the agent.
309    for &var in anthropic_env_vars_to_sanitize {
310        if !config.env_vars.contains_key(var) {
311            command.env_remove(var);
312        }
313    }
314
315    Ok((argv, command))
316}
317
318/// Spawns the agent process with special error handling for `NotFound` and `PermissionDenied`.
319fn spawn_agent_process(
320    mut command: Command,
321    argv: &[String],
322) -> io::Result<Result<Child, CommandResult>> {
323    match command
324        .stdout(Stdio::piped())
325        .stderr(Stdio::piped())
326        .spawn()
327    {
328        Ok(child) => Ok(Ok(child)),
329        Err(e)
330            if matches!(
331                e.kind(),
332                io::ErrorKind::NotFound | io::ErrorKind::PermissionDenied
333            ) =>
334        {
335            let exit_code = if e.kind() == io::ErrorKind::NotFound {
336                127
337            } else {
338                126
339            };
340            Ok(Err(CommandResult {
341                exit_code,
342                stderr: format!("{}: {}", argv[0], e),
343            }))
344        }
345        Err(e) => Err(e),
346    }
347}
348
349/// Streams agent output based on parser type.
350fn stream_agent_output(
351    stdout: ChildStdout,
352    cmd: &PromptCommand<'_>,
353    runtime: &PipelineRuntime<'_>,
354) -> io::Result<()> {
355    // Use StreamingLineReader for real-time streaming instead of BufReader::lines().
356    // StreamingLineReader yields lines immediately when newlines are found,
357    // enabling character-by-character streaming for agents that output NDJSON gradually.
358    let reader = StreamingLineReader::new(stdout);
359
360    if cmd.parser_type != JsonParserType::Generic
361        || argv_requests_json(&split_command(cmd.cmd_str)?)
362    {
363        let stdout_io = io::stdout();
364        let mut out = stdout_io.lock();
365
366        match cmd.parser_type {
367            JsonParserType::Claude => {
368                let p = crate::json_parser::ClaudeParser::new(
369                    *runtime.colors,
370                    runtime.config.verbosity,
371                )
372                .with_display_name(cmd.display_name)
373                .with_log_file(cmd.logfile)
374                .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
375                p.parse_stream(reader)?;
376            }
377            JsonParserType::Codex => {
378                let p =
379                    crate::json_parser::CodexParser::new(*runtime.colors, runtime.config.verbosity)
380                        .with_display_name(cmd.display_name)
381                        .with_log_file(cmd.logfile)
382                        .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
383                p.parse_stream(reader)?;
384            }
385            JsonParserType::Gemini => {
386                let p = crate::json_parser::GeminiParser::new(
387                    *runtime.colors,
388                    runtime.config.verbosity,
389                )
390                .with_display_name(cmd.display_name)
391                .with_log_file(cmd.logfile)
392                .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
393                p.parse_stream(reader)?;
394            }
395            JsonParserType::OpenCode => {
396                let p = crate::json_parser::OpenCodeParser::new(
397                    *runtime.colors,
398                    runtime.config.verbosity,
399                )
400                .with_display_name(cmd.display_name)
401                .with_log_file(cmd.logfile)
402                .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
403                p.parse_stream(reader)?;
404            }
405            JsonParserType::Generic => {
406                let mut logfile = OpenOptions::new()
407                    .create(true)
408                    .append(true)
409                    .open(cmd.logfile)?;
410
411                let mut buf = String::new();
412                for line in reader.lines() {
413                    let line = line?;
414                    // Write raw line to log file for extraction
415                    writeln!(logfile, "{line}")?;
416                    buf.push_str(&line);
417                    buf.push('\n');
418                }
419                logfile.flush()?;
420                // Ensure data is written to disk before continuing
421                // This prevents race conditions where extraction runs before OS commits writes
422                let _ = logfile.sync_all();
423
424                let formatted = format_generic_json_for_display(&buf, runtime.config.verbosity);
425                out.write_all(formatted.as_bytes())?;
426            }
427        }
428    } else {
429        let mut logfile = OpenOptions::new()
430            .create(true)
431            .append(true)
432            .open(cmd.logfile)?;
433
434        let stdout_io = io::stdout();
435        let mut out = stdout_io.lock();
436
437        for line in reader.lines() {
438            let line = line?;
439            writeln!(out, "{line}")?;
440            writeln!(logfile, "{line}")?;
441        }
442        logfile.flush()?;
443        // Ensure data is written to disk before continuing
444        // This prevents race conditions where extraction runs before OS commits writes
445        let _ = logfile.sync_all();
446    }
447    Ok(())
448}
449
450/// Waits for process completion and collects stderr output.
451fn wait_for_completion_and_collect_stderr(
452    mut child: Child,
453    stderr_join_handle: Option<std::thread::JoinHandle<io::Result<String>>>,
454    runtime: &PipelineRuntime<'_>,
455) -> io::Result<(i32, String)> {
456    let status = child.wait()?;
457    let exit_code = status.code().unwrap_or(1);
458
459    if status.code().is_none() && runtime.config.verbosity.is_debug() {
460        runtime
461            .logger
462            .warn("Process terminated by signal (no exit code), treating as failure");
463    }
464
465    let stderr_output = match stderr_join_handle {
466        Some(handle) => match handle.join() {
467            Ok(result) => result?,
468            Err(panic_payload) => {
469                let panic_msg = panic_payload.downcast_ref::<String>().map_or_else(
470                    || {
471                        panic_payload.downcast_ref::<&str>().map_or_else(
472                            || "<unknown panic>".to_string(),
473                            std::string::ToString::to_string,
474                        )
475                    },
476                    std::clone::Clone::clone,
477                );
478                runtime.logger.warn(&format!(
479                    "Stderr collection thread panicked: {panic_msg}. This may indicate a bug."
480                ));
481                String::new()
482            }
483        },
484        None => String::new(),
485    };
486
487    if !stderr_output.is_empty() && runtime.config.verbosity.is_debug() {
488        runtime.logger.warn(&format!(
489            "Agent stderr output detected ({} bytes):",
490            stderr_output.len()
491        ));
492        for (i, line) in stderr_output.lines().take(5).enumerate() {
493            runtime.logger.info(&format!("  stderr[{i}]: {line}"));
494        }
495        if stderr_output.lines().count() > 5 {
496            runtime.logger.info(&format!(
497                "  ... ({} more lines, see log file for full output)",
498                stderr_output.lines().count() - 5
499            ));
500        }
501    }
502
503    Ok((exit_code, stderr_output))
504}
505
506/// Run a command with a prompt argument.
507///
508/// This is an internal helper for `run_with_fallback`.
509pub fn run_with_prompt(
510    cmd: &PromptCommand<'_>,
511    runtime: &mut PipelineRuntime<'_>,
512) -> io::Result<CommandResult> {
513    const ANTHROPIC_ENV_VARS_TO_SANITIZE: &[&str] = &[
514        "ANTHROPIC_API_KEY",
515        "ANTHROPIC_BASE_URL",
516        "ANTHROPIC_AUTH_TOKEN",
517        "ANTHROPIC_MODEL",
518        "ANTHROPIC_DEFAULT_HAIKU_MODEL",
519        "ANTHROPIC_DEFAULT_OPUS_MODEL",
520        "ANTHROPIC_DEFAULT_SONNET_MODEL",
521    ];
522
523    runtime.timer.start_phase();
524    runtime.logger.step(&format!(
525        "{}{}{}",
526        runtime.colors.bold(),
527        cmd.label,
528        runtime.colors.reset()
529    ));
530
531    save_prompt_to_file_and_clipboard(
532        cmd.prompt,
533        &runtime.config.prompt_path,
534        runtime.config.behavior.interactive,
535        runtime.logger,
536        *runtime.colors,
537    )?;
538
539    let (argv, command) = build_agent_command(
540        &CommandConfig {
541            cmd_str: cmd.cmd_str,
542            prompt: cmd.prompt,
543            env_vars: cmd.env_vars,
544            logfile: cmd.logfile,
545            parser_type: cmd.parser_type,
546        },
547        ANTHROPIC_ENV_VARS_TO_SANITIZE,
548        runtime.logger,
549        *runtime.colors,
550    )?;
551
552    let mut child = match spawn_agent_process(command, &argv)? {
553        Ok(child) => child,
554        Err(result) => return Ok(result),
555    };
556
557    let stdout = child
558        .stdout
559        .take()
560        .ok_or_else(|| io::Error::other("Failed to capture stdout"))?;
561
562    let stderr_join_handle = child.stderr.take().map(|stderr| {
563        std::thread::spawn(move || -> io::Result<String> {
564            const STDERR_MAX_BYTES: usize = 512 * 1024;
565
566            let mut reader = BufReader::new(stderr);
567            let mut buf = [0u8; 8192];
568            let mut collected = Vec::<u8>::new();
569            let mut truncated = false;
570
571            loop {
572                let n = reader.read(&mut buf)?;
573                if n == 0 {
574                    break;
575                }
576
577                let remaining = STDERR_MAX_BYTES.saturating_sub(collected.len());
578                if remaining == 0 {
579                    truncated = true;
580                    break;
581                }
582
583                let to_take = remaining.min(n);
584                collected.extend_from_slice(&buf[..to_take]);
585                if to_take < n {
586                    truncated = true;
587                    break;
588                }
589            }
590
591            let mut stderr_output = String::from_utf8_lossy(&collected).into_owned();
592            if truncated {
593                if !stderr_output.ends_with('\n') {
594                    stderr_output.push('\n');
595                }
596                stderr_output.push_str("<stderr truncated>");
597            }
598
599            Ok(stderr_output)
600        })
601    });
602
603    stream_agent_output(stdout, cmd, runtime)?;
604
605    let (exit_code, stderr_output) =
606        wait_for_completion_and_collect_stderr(child, stderr_join_handle, runtime)?;
607
608    if runtime.config.verbosity.is_verbose() {
609        runtime.logger.info(&format!(
610            "Phase elapsed: {}",
611            runtime.timer.phase_elapsed_formatted()
612        ));
613    }
614
615    Ok(CommandResult {
616        exit_code,
617        stderr: stderr_output,
618    })
619}