1use crate::agents::{is_glm_like_agent, JsonParserType};
4use crate::common::{format_argv_for_log, split_command, truncate_text};
5use crate::config::Config;
6use crate::logger::Colors;
7use crate::logger::Logger;
8use crate::logger::{argv_requests_json, format_generic_json_for_display};
9use crate::pipeline::idle_timeout::{
10 monitor_idle_timeout, new_activity_timestamp, ActivityTrackingReader, MonitorResult,
11 SharedActivityTimestamp, IDLE_TIMEOUT_SECS,
12};
13use crate::pipeline::Timer;
14
15use std::io::{self, BufRead, BufReader, Read, Write};
16use std::path::Path;
17use std::sync::Arc;
18
/// Reader that layers a growable byte buffer on top of a [`BufReader`].
///
/// Bytes pulled from `inner` are appended to `buffer`; `consumed` is the
/// offset of the first byte in `buffer` that has not yet been handed out.
struct StreamingLineReader<R>
where
    R: Read,
{
    /// Underlying buffered source.
    inner: BufReader<R>,
    /// Bytes read from `inner` but not yet fully consumed by the caller.
    buffer: Vec<u8>,
    /// Number of leading bytes of `buffer` already handed to the caller.
    consumed: usize,
}
35
/// Upper bound, in bytes, on the unread data `StreamingLineReader` may hold.
const MAX_BUFFER_SIZE: usize = 1024 * 1024;

/// Largest prompt, in bytes, that `truncate_prompt_if_needed` returns untouched.
#[cfg(test)]
const MAX_PROMPT_SIZE: usize = 200 * 1024;

/// Shrinks `prompt` when it is longer than [`MAX_PROMPT_SIZE`] bytes.
///
/// Prompts within the limit are returned unchanged. Otherwise a warning is
/// logged and one of two strategies is used:
///
/// 1. If one of the known section markers is present and enough text follows
///    it, everything up to and including the marker is kept, the oldest part of
///    the text after it is dropped (resuming at the next line start), and a
///    `[... N bytes truncated ...]` notice is inserted.
/// 2. Otherwise roughly the first and last third of the limit are kept, each
///    trimmed to a line boundary when one exists, with the notice in between.
///
/// All cut points are snapped to UTF-8 character boundaries, so prompts that
/// contain multi-byte characters never cause a slicing panic.
#[cfg(test)]
fn truncate_prompt_if_needed(prompt: &str, logger: &Logger) -> String {
    /// Largest char boundary of `s` that is `<= idx` (clamped to `s.len()`).
    fn floor_boundary(s: &str, idx: usize) -> usize {
        let mut i = idx.min(s.len());
        // Index 0 is always a boundary, so this terminates.
        while !s.is_char_boundary(i) {
            i -= 1;
        }
        i
    }

    /// Smallest char boundary of `s` that is `>= idx` (clamped to `s.len()`).
    fn ceil_boundary(s: &str, idx: usize) -> usize {
        let mut i = idx.min(s.len());
        // `s.len()` is always a boundary, so this terminates.
        while !s.is_char_boundary(i) {
            i += 1;
        }
        i
    }

    if prompt.len() <= MAX_PROMPT_SIZE {
        return prompt.to_string();
    }

    let excess = prompt.len() - MAX_PROMPT_SIZE;
    logger.warn(&format!(
        "Prompt exceeds safe limit ({} bytes > {} bytes), truncating {} bytes",
        prompt.len(),
        MAX_PROMPT_SIZE,
        excess
    ));

    let truncation_markers = [
        "\n---\n",
        "\n```\n",
        "\n<last-output>",
        "\nPrevious output:",
    ];

    for marker in truncation_markers {
        if let Some(marker_pos) = prompt.find(marker) {
            // `find` returns a char boundary and the markers are ASCII, so
            // `content_start` is a valid split point.
            let content_start = marker_pos + marker.len();
            if content_start < prompt.len() {
                let (before_marker, after_marker) = prompt.split_at(content_start);

                if after_marker.len() > excess + 100 {
                    // Drop the excess plus 100 bytes of headroom for the notice,
                    // rounding up so the cut never lands inside a character.
                    let keep_from = ceil_boundary(after_marker, excess + 100);
                    let truncated_content = &after_marker[keep_from..];

                    // Resume at the start of the next full line when there is one.
                    let clean_start = truncated_content.find('\n').map(|i| i + 1).unwrap_or(0);

                    return format!(
                        "{}\n[... {} bytes truncated to fit CLI argument limit ...]\n{}",
                        before_marker,
                        keep_from + clean_start,
                        &truncated_content[clean_start..]
                    );
                }
            }
        }
    }

    // Fallback: keep a head and a tail slice of about a third of the limit each.
    let keep_start = MAX_PROMPT_SIZE / 3;
    let keep_end = MAX_PROMPT_SIZE / 3;
    let start_part = &prompt[..floor_boundary(prompt, keep_start)];
    let end_part = &prompt[ceil_boundary(prompt, prompt.len() - keep_end)..];

    // End the head after its last newline and begin the tail after its first.
    let start_end = start_part
        .rfind('\n')
        .map(|i| i + 1)
        .unwrap_or(start_part.len());
    let end_start = end_part.find('\n').map(|i| i + 1).unwrap_or(0);

    format!(
        "{}\n\n[... {} bytes truncated to fit CLI argument limit ...]\n\n{}",
        &prompt[..start_end],
        prompt.len() - start_end - (end_part.len() - end_start),
        &end_part[end_start..]
    )
}
143
144impl<R: Read> StreamingLineReader<R> {
145 fn new(inner: R) -> Self {
147 const BUFFER_SIZE: usize = 1024;
150 Self {
151 inner: BufReader::with_capacity(BUFFER_SIZE, inner),
152 buffer: Vec::new(),
153 consumed: 0,
154 }
155 }
156
157 fn fill_buffer(&mut self) -> io::Result<usize> {
164 let current_size = self.buffer.len() - self.consumed;
166 if current_size >= MAX_BUFFER_SIZE {
167 return Err(io::Error::other(format!(
168 "StreamingLineReader buffer exceeded maximum size of {MAX_BUFFER_SIZE} bytes. \
169 This may indicate malformed input or an agent that is not sending newlines."
170 )));
171 }
172
173 let mut read_buf = [0u8; 256];
174 let n = self.inner.read(&mut read_buf)?;
175 if n > 0 {
176 let new_size = current_size + n;
178 if new_size > MAX_BUFFER_SIZE {
179 return Err(io::Error::other(format!(
180 "StreamingLineReader buffer would exceed maximum size of {MAX_BUFFER_SIZE} bytes. \
181 This may indicate malformed input or an agent that is not sending newlines."
182 )));
183 }
184 self.buffer.extend_from_slice(&read_buf[..n]);
185 }
186 Ok(n)
187 }
188}
189
190impl<R: Read> Read for StreamingLineReader<R> {
191 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
192 let available = self.buffer.len() - self.consumed;
194 if available > 0 {
195 let to_copy = available.min(buf.len());
196 buf[..to_copy].copy_from_slice(&self.buffer[self.consumed..self.consumed + to_copy]);
197 self.consumed += to_copy;
198
199 if self.consumed == self.buffer.len() {
201 self.buffer.clear();
202 self.consumed = 0;
203 }
204 return Ok(to_copy);
205 }
206
207 self.inner.read(buf)
209 }
210}
211
212impl<R: Read> BufRead for StreamingLineReader<R> {
213 fn fill_buf(&mut self) -> io::Result<&[u8]> {
214 const MAX_ATTEMPTS: usize = 8; if self.consumed < self.buffer.len() {
218 return Ok(&self.buffer[self.consumed..]);
219 }
220
221 self.buffer.clear();
223 self.consumed = 0;
224
225 let mut total_read = 0;
227 for _ in 0..MAX_ATTEMPTS {
228 match self.fill_buffer()? {
229 0 if total_read == 0 => return Ok(&[]), 0 => break, n => {
232 total_read += n;
233 if self.buffer.contains(&b'\n') {
235 break;
236 }
237 }
238 }
239 }
240
241 Ok(&self.buffer[self.consumed..])
242 }
243
244 fn consume(&mut self, amt: usize) {
245 self.consumed = (self.consumed + amt).min(self.buffer.len());
246
247 if self.consumed == self.buffer.len() {
249 self.buffer.clear();
250 self.consumed = 0;
251 }
252 }
253}
254
255use super::clipboard::get_platform_clipboard_command;
256use super::types::CommandResult;
257
/// Description of one agent invocation: what to run, with which prompt, and
/// where its output goes.
pub struct PromptCommand<'a> {
    /// Heading logged as a step before the command runs.
    pub label: &'a str,
    /// Name handed to the JSON parsers via `with_display_name`.
    pub display_name: &'a str,
    /// Agent command line; it is split into argv before spawning.
    pub cmd_str: &'a str,
    /// Prompt text that is saved to the prompt file and passed to the agent.
    pub prompt: &'a str,
    /// Path of the log file that agent output is written to.
    pub logfile: &'a str,
    /// Selects which parser processes the agent's stdout.
    pub parser_type: JsonParserType,
    /// Agent-specific environment variables, layered over the process environment.
    pub env_vars: &'a std::collections::HashMap<String, String>,
}
268
/// Borrowed services an agent run needs: timing, logging, configuration,
/// process execution and workspace file access.
pub struct PipelineRuntime<'a> {
    /// Phase timer; started at the beginning of a run and read for the elapsed-time log.
    pub timer: &'a mut Timer,
    /// Destination for step, info and warning messages.
    pub logger: &'a Logger,
    /// Color codes used when formatting log output.
    pub colors: &'a Colors,
    /// Pipeline configuration (prompt path, verbosity, interactive flag, ...).
    pub config: &'a Config,
    /// Executor used to spawn the agent and the clipboard helper.
    pub executor: &'a dyn crate::executor::ProcessExecutor,
    /// Shared-ownership executor handle; cloned into the idle-timeout monitor thread.
    pub executor_arc: std::sync::Arc<dyn crate::executor::ProcessExecutor>,
    /// File access used for the prompt file and the agent log file.
    pub workspace: &'a dyn crate::workspace::Workspace,
}
282
283fn save_prompt_to_file_and_clipboard(
285 prompt: &str,
286 prompt_path: &std::path::Path,
287 interactive: bool,
288 logger: &Logger,
289 colors: Colors,
290 executor: &dyn crate::executor::ProcessExecutor,
291 workspace: &dyn crate::workspace::Workspace,
292) -> io::Result<()> {
293 workspace.write(prompt_path, prompt)?;
295 logger.info(&format!(
296 "Prompt saved to {}{}{}",
297 colors.cyan(),
298 prompt_path.display(),
299 colors.reset()
300 ));
301
302 if interactive {
304 if let Some(clipboard_cmd) = get_platform_clipboard_command() {
305 match executor.spawn(clipboard_cmd.binary, clipboard_cmd.args, &[], None) {
306 Ok(mut child) => {
307 if let Some(mut stdin) = child.stdin.take() {
308 let _ = stdin.write_all(prompt.as_bytes());
309 }
310 let _ = child.wait();
311 logger.info(&format!(
312 "Prompt copied to clipboard {}({}){}",
313 colors.dim(),
314 clipboard_cmd.paste_hint,
315 colors.reset()
316 ));
317 }
318 Err(e) => {
319 logger.warn(&format!("Failed to copy to clipboard: {}", e));
320 }
321 }
322 }
323 }
324 Ok(())
325}
326
/// Removes every variable named in `vars_to_sanitize` from `env_vars`, except
/// those the agent configured itself (keys present in `agent_env_vars`).
///
/// Variables not listed in `vars_to_sanitize` are never touched.
pub fn sanitize_command_env(
    env_vars: &mut std::collections::HashMap<String, String>,
    agent_env_vars: &std::collections::HashMap<String, String>,
    vars_to_sanitize: &[&str],
) {
    vars_to_sanitize
        .iter()
        .copied()
        .filter(|name| !agent_env_vars.contains_key(*name))
        .for_each(|name| {
            env_vars.remove(name);
        });
}
365
366fn wait_for_completion_and_collect_stderr(
368 mut child: Box<dyn crate::executor::AgentChild>,
369 stderr_join_handle: Option<std::thread::JoinHandle<io::Result<String>>>,
370 runtime: &PipelineRuntime<'_>,
371) -> io::Result<(i32, String)> {
372 let status = child.wait()?;
373 let exit_code = status.code().unwrap_or(1);
374
375 if status.code().is_none() && runtime.config.verbosity.is_debug() {
376 runtime
377 .logger
378 .warn("Process terminated by signal (no exit code), treating as failure");
379 }
380
381 let stderr_output = match stderr_join_handle {
382 Some(handle) => match handle.join() {
383 Ok(result) => result?,
384 Err(panic_payload) => {
385 let panic_msg = panic_payload.downcast_ref::<String>().map_or_else(
386 || {
387 panic_payload.downcast_ref::<&str>().map_or_else(
388 || "<unknown panic>".to_string(),
389 std::string::ToString::to_string,
390 )
391 },
392 std::clone::Clone::clone,
393 );
394 runtime.logger.warn(&format!(
395 "Stderr collection thread panicked: {panic_msg}. This may indicate a bug."
396 ));
397 String::new()
398 }
399 },
400 None => String::new(),
401 };
402
403 if !stderr_output.is_empty() && runtime.config.verbosity.is_debug() {
404 runtime.logger.warn(&format!(
405 "Agent stderr output detected ({} bytes):",
406 stderr_output.len()
407 ));
408 for (i, line) in stderr_output.lines().take(5).enumerate() {
409 runtime.logger.info(&format!(" stderr[{i}]: {line}"));
410 }
411 if stderr_output.lines().count() > 5 {
412 runtime.logger.info(&format!(
413 " ... ({} more lines, see log file for full output)",
414 stderr_output.lines().count() - 5
415 ));
416 }
417 }
418
419 Ok((exit_code, stderr_output))
420}
421
422pub fn run_with_prompt(
426 cmd: &PromptCommand<'_>,
427 runtime: &mut PipelineRuntime<'_>,
428) -> io::Result<CommandResult> {
429 const ANTHROPIC_ENV_VARS_TO_SANITIZE: &[&str] = &[
430 "ANTHROPIC_API_KEY",
431 "ANTHROPIC_BASE_URL",
432 "ANTHROPIC_AUTH_TOKEN",
433 "ANTHROPIC_MODEL",
434 "ANTHROPIC_DEFAULT_HAIKU_MODEL",
435 "ANTHROPIC_DEFAULT_OPUS_MODEL",
436 "ANTHROPIC_DEFAULT_SONNET_MODEL",
437 ];
438
439 runtime.timer.start_phase();
440 runtime.logger.step(&format!(
441 "{}{}{}",
442 runtime.colors.bold(),
443 cmd.label,
444 runtime.colors.reset()
445 ));
446
447 save_prompt_to_file_and_clipboard(
448 cmd.prompt,
449 &runtime.config.prompt_path,
450 runtime.config.behavior.interactive,
451 runtime.logger,
452 *runtime.colors,
453 runtime.executor,
454 runtime.workspace,
455 )?;
456
457 run_with_agent_spawn(cmd, runtime, ANTHROPIC_ENV_VARS_TO_SANITIZE)
461}
462
463const SIGTERM_EXIT_CODE: i32 = 143;
465
466fn run_with_agent_spawn(
471 cmd: &PromptCommand<'_>,
472 runtime: &mut PipelineRuntime<'_>,
473 anthropic_env_vars_to_sanitize: &[&str],
474) -> io::Result<CommandResult> {
475 use std::sync::atomic::{AtomicBool, Ordering};
476
477 let argv = split_command(cmd.cmd_str)?;
479 if argv.is_empty() || cmd.cmd_str.trim().is_empty() {
480 return Err(io::Error::new(
481 io::ErrorKind::InvalidInput,
482 "Agent command is empty or contains only whitespace",
483 ));
484 }
485
486 let mut argv_for_log = argv.clone();
487 argv_for_log.push("<PROMPT>".to_string());
488 let display_cmd = truncate_text(&format_argv_for_log(&argv_for_log), 160);
489 runtime.logger.info(&format!(
490 "Executing: {}{}{}",
491 runtime.colors.dim(),
492 display_cmd,
493 runtime.colors.reset()
494 ));
495
496 let is_glm_cmd = is_glm_like_agent(cmd.cmd_str);
498 if is_glm_cmd {
499 runtime
500 .logger
501 .info(&format!("GLM command details: {display_cmd}"));
502 if argv.iter().any(|arg| arg == "-p") {
503 runtime
504 .logger
505 .info("GLM command includes '-p' flag (correct)");
506 } else {
507 runtime.logger.warn("GLM command may be missing '-p' flag");
508 }
509 }
510
511 let _uses_json = cmd.parser_type != JsonParserType::Generic || argv_requests_json(&argv);
512 runtime
513 .logger
514 .info(&format!("Using {} parser...", cmd.parser_type));
515
516 let logfile_path = Path::new(cmd.logfile);
518 runtime.workspace.write(logfile_path, "")?;
519
520 let mut complete_env: std::collections::HashMap<String, String> = std::env::vars().collect();
522 for (key, value) in cmd.env_vars.iter() {
523 complete_env.insert(key.clone(), value.clone());
524 }
525 sanitize_command_env(
526 &mut complete_env,
527 cmd.env_vars,
528 anthropic_env_vars_to_sanitize,
529 );
530
531 let config = crate::executor::AgentSpawnConfig {
533 command: argv[0].clone(),
534 args: argv[1..].to_vec(),
535 env: complete_env,
536 prompt: cmd.prompt.to_string(),
537 logfile: cmd.logfile.to_string(),
538 parser_type: cmd.parser_type,
539 };
540
541 let agent_handle = match runtime.executor.spawn_agent(&config) {
543 Ok(handle) => handle,
544 Err(e) => {
545 let (exit_code, detail) = match e.kind() {
548 io::ErrorKind::NotFound => (127, "command not found"),
549 io::ErrorKind::PermissionDenied => (126, "permission denied"),
550 io::ErrorKind::ArgumentListTooLong => {
551 (7, "argument list too long (prompt exceeds OS limit)")
552 }
553 io::ErrorKind::InvalidInput => (22, "invalid input"),
554 io::ErrorKind::OutOfMemory => (12, "out of memory"),
555 _ => (1, "spawn failed"),
556 };
557
558 return Ok(CommandResult {
559 exit_code,
560 stderr: format!("{}: {} - {}", argv[0], detail, e),
561 });
562 }
563 };
564
565 let child_id = agent_handle.inner.id();
567
568 let activity_timestamp = new_activity_timestamp();
570 let monitor_should_stop = Arc::new(AtomicBool::new(false));
571 let monitor_should_stop_clone = monitor_should_stop.clone();
572 let activity_timestamp_clone = activity_timestamp.clone();
573
574 let monitor_executor: Arc<dyn crate::executor::ProcessExecutor> =
577 std::sync::Arc::clone(&runtime.executor_arc);
578
579 let monitor_handle = std::thread::spawn(move || {
581 monitor_idle_timeout(
582 activity_timestamp_clone,
583 child_id,
584 IDLE_TIMEOUT_SECS,
585 monitor_should_stop_clone,
586 monitor_executor,
587 )
588 });
589
590 let stdout = agent_handle.stdout;
592 let stderr = agent_handle.stderr;
593 let inner = agent_handle.inner;
594
595 let stderr_join_handle = std::thread::spawn(move || -> io::Result<String> {
597 const STDERR_MAX_BYTES: usize = 512 * 1024;
598
599 let mut reader = BufReader::new(stderr);
600 let mut buf = [0u8; 8192];
601 let mut collected = Vec::<u8>::new();
602 let mut truncated = false;
603
604 loop {
605 let n = reader.read(&mut buf)?;
606 if n == 0 {
607 break;
608 }
609
610 let remaining = STDERR_MAX_BYTES.saturating_sub(collected.len());
611 if remaining == 0 {
612 truncated = true;
613 break;
614 }
615
616 let to_take = remaining.min(n);
617 collected.extend_from_slice(&buf[..to_take]);
618 if to_take < n {
619 truncated = true;
620 break;
621 }
622 }
623
624 let mut stderr_output = String::from_utf8_lossy(&collected).into_owned();
625 if truncated {
626 if !stderr_output.ends_with('\n') {
627 stderr_output.push('\n');
628 }
629 stderr_output.push_str("<stderr truncated>");
630 }
631
632 Ok(stderr_output)
633 });
634
635 stream_agent_output_from_handle(stdout, cmd, runtime, activity_timestamp)?;
637
638 monitor_should_stop.store(true, Ordering::Release);
640
641 let (exit_code, stderr_output) =
642 wait_for_completion_and_collect_stderr(inner, Some(stderr_join_handle), runtime)?;
643
644 let monitor_result = monitor_handle
646 .join()
647 .unwrap_or(MonitorResult::ProcessCompleted);
648
649 let final_exit_code = if monitor_result == MonitorResult::TimedOut {
651 runtime.logger.warn(&format!(
652 "Agent killed due to idle timeout (no output for {} seconds)",
653 IDLE_TIMEOUT_SECS
654 ));
655 SIGTERM_EXIT_CODE
656 } else {
657 exit_code
658 };
659
660 if runtime.config.verbosity.is_verbose() {
661 runtime.logger.info(&format!(
662 "Phase elapsed: {}",
663 runtime.timer.phase_elapsed_formatted()
664 ));
665 }
666
667 Ok(CommandResult {
668 exit_code: final_exit_code,
669 stderr: stderr_output,
670 })
671}
672
673fn stream_agent_output_from_handle(
678 stdout: Box<dyn io::Read + Send>,
679 cmd: &PromptCommand<'_>,
680 runtime: &PipelineRuntime<'_>,
681 activity_timestamp: SharedActivityTimestamp,
682) -> io::Result<()> {
683 let tracked_stdout = ActivityTrackingReader::new(stdout, activity_timestamp);
685 let reader = StreamingLineReader::new(tracked_stdout);
689
690 if cmd.parser_type != JsonParserType::Generic
691 || argv_requests_json(&split_command(cmd.cmd_str)?)
692 {
693 let stdout_io = io::stdout();
694 let mut out = stdout_io.lock();
695
696 match cmd.parser_type {
697 JsonParserType::Claude => {
698 let p = crate::json_parser::ClaudeParser::new(
699 *runtime.colors,
700 runtime.config.verbosity,
701 )
702 .with_display_name(cmd.display_name)
703 .with_log_file(cmd.logfile)
704 .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
705 p.parse_stream(reader, runtime.workspace)?;
706 }
707 JsonParserType::Codex => {
708 let p =
709 crate::json_parser::CodexParser::new(*runtime.colors, runtime.config.verbosity)
710 .with_display_name(cmd.display_name)
711 .with_log_file(cmd.logfile)
712 .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
713 p.parse_stream(reader, runtime.workspace)?;
714 }
715 JsonParserType::Gemini => {
716 let p = crate::json_parser::GeminiParser::new(
717 *runtime.colors,
718 runtime.config.verbosity,
719 )
720 .with_display_name(cmd.display_name)
721 .with_log_file(cmd.logfile)
722 .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
723 p.parse_stream(reader, runtime.workspace)?;
724 }
725 JsonParserType::OpenCode => {
726 let p = crate::json_parser::OpenCodeParser::new(
727 *runtime.colors,
728 runtime.config.verbosity,
729 )
730 .with_display_name(cmd.display_name)
731 .with_log_file(cmd.logfile)
732 .with_show_streaming_metrics(runtime.config.show_streaming_metrics);
733 p.parse_stream(reader, runtime.workspace)?;
734 }
735 JsonParserType::Generic => {
736 let logfile_path = Path::new(cmd.logfile);
737 let mut buf = String::new();
738 for line in reader.lines() {
739 let line = line?;
740 runtime
742 .workspace
743 .append_bytes(logfile_path, format!("{line}\n").as_bytes())?;
744 buf.push_str(&line);
745 buf.push('\n');
746 }
747
748 let formatted = format_generic_json_for_display(&buf, runtime.config.verbosity);
749 out.write_all(formatted.as_bytes())?;
750 }
751 }
752 } else {
753 let logfile_path = Path::new(cmd.logfile);
754 let stdout_io = io::stdout();
755 let mut out = stdout_io.lock();
756
757 for line in reader.lines() {
758 let line = line?;
759 writeln!(out, "{line}")?;
760 runtime
761 .workspace
762 .append_bytes(logfile_path, format!("{line}\n").as_bytes())?;
763 }
764 }
765 Ok(())
766}
767
#[cfg(test)]
mod tests {
    use super::*;

    /// Logger with default colors for use in tests.
    fn test_logger() -> Logger {
        Logger::new(Colors::new())
    }

    /// Runs `truncate_prompt_if_needed` with a throwaway logger.
    fn truncate(content: &str) -> String {
        let logger = test_logger();
        truncate_prompt_if_needed(content, &logger)
    }

    #[test]
    fn test_truncate_prompt_small_content() {
        let content = "This is a small prompt that fits within limits.";
        assert_eq!(truncate(content), content);
    }

    #[test]
    fn test_truncate_prompt_large_content_with_marker() {
        let mut content = String::from("Task: Do something\n\n---\n");
        content.push_str(&"x".repeat(MAX_PROMPT_SIZE + 50000));

        let result = truncate(&content);

        assert!(result.len() < content.len());
        assert!(result.contains("truncated"));
        assert!(result.starts_with("Task:"));
    }

    #[test]
    fn test_truncate_prompt_large_content_fallback() {
        // No marker and no newline, so the head/tail fallback is used.
        let content = "a".repeat(MAX_PROMPT_SIZE + 50000);

        let result = truncate(&content);

        assert!(result.len() < content.len());
        assert!(result.contains("truncated"));
    }

    #[test]
    fn test_truncate_prompt_preserves_end() {
        let content = [
            "Instructions\n\n---\n",
            "m".repeat(MAX_PROMPT_SIZE).as_str(),
            "\nIMPORTANT_END_MARKER",
        ]
        .concat();

        let result = truncate(&content);

        assert!(result.contains("IMPORTANT_END_MARKER"));
    }
}
831
#[cfg(test)]
mod sanitize_env_tests {
    use super::*;
    use std::collections::HashMap;

    const ANTHROPIC_ENV_VARS_TO_SANITIZE: &[&str] = &[
        "ANTHROPIC_API_KEY",
        "ANTHROPIC_BASE_URL",
        "ANTHROPIC_AUTH_TOKEN",
        "ANTHROPIC_MODEL",
        "ANTHROPIC_DEFAULT_HAIKU_MODEL",
        "ANTHROPIC_DEFAULT_OPUS_MODEL",
        "ANTHROPIC_DEFAULT_SONNET_MODEL",
    ];

    /// Builds an owned environment map from borrowed key/value pairs.
    fn env_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_sanitize_command_env_removes_anthropic_vars_when_not_explicitly_set() {
        let mut env_vars = env_map(&[
            ("ANTHROPIC_API_KEY", "glm-test-key"),
            ("ANTHROPIC_BASE_URL", "https://glm.example.com"),
            ("PATH", "/usr/bin:/bin"),
            ("HOME", "/home/user"),
        ]);
        let agent_env_vars = env_map(&[]);

        sanitize_command_env(
            &mut env_vars,
            &agent_env_vars,
            ANTHROPIC_ENV_VARS_TO_SANITIZE,
        );

        assert!(
            !env_vars.contains_key("ANTHROPIC_API_KEY"),
            "ANTHROPIC_API_KEY should be removed when not explicitly set by agent"
        );
        assert!(
            !env_vars.contains_key("ANTHROPIC_BASE_URL"),
            "ANTHROPIC_BASE_URL should be removed when not explicitly set by agent"
        );
        assert_eq!(
            env_vars.get("PATH").map(String::as_str),
            Some("/usr/bin:/bin"),
            "Non-Anthropic vars should be preserved"
        );
        assert_eq!(
            env_vars.get("HOME").map(String::as_str),
            Some("/home/user"),
            "Non-Anthropic vars should be preserved"
        );
    }

    #[test]
    fn test_sanitize_command_env_preserves_explicitly_set_anthropic_vars() {
        let mut env_vars = env_map(&[
            ("ANTHROPIC_API_KEY", "parent-key"),
            ("ANTHROPIC_BASE_URL", "https://parent.example.com"),
            ("ANTHROPIC_AUTH_TOKEN", "parent-token"),
            ("PATH", "/usr/bin:/bin"),
        ]);
        let agent_env_vars = env_map(&[
            ("ANTHROPIC_API_KEY", "agent-specific-key"),
            ("ANTHROPIC_BASE_URL", "https://agent.example.com"),
        ]);

        // Overlay the agent's values the same way the runner does before sanitizing.
        env_vars.extend(
            agent_env_vars
                .iter()
                .map(|(key, value)| (key.clone(), value.clone())),
        );

        sanitize_command_env(
            &mut env_vars,
            &agent_env_vars,
            ANTHROPIC_ENV_VARS_TO_SANITIZE,
        );

        assert_eq!(
            env_vars.get("ANTHROPIC_API_KEY").map(String::as_str),
            Some("agent-specific-key"),
            "ANTHROPIC_API_KEY explicitly set by agent should be preserved"
        );
        assert_eq!(
            env_vars.get("ANTHROPIC_BASE_URL").map(String::as_str),
            Some("https://agent.example.com"),
            "ANTHROPIC_BASE_URL explicitly set by agent should be preserved"
        );
        assert!(
            !env_vars.contains_key("ANTHROPIC_AUTH_TOKEN"),
            "ANTHROPIC_AUTH_TOKEN not explicitly set by agent should be removed"
        );
        assert_eq!(
            env_vars.get("PATH").map(String::as_str),
            Some("/usr/bin:/bin"),
            "Non-Anthropic vars should be preserved"
        );
    }

    #[test]
    fn test_sanitize_command_env_handles_empty_env_vars() {
        let mut env_vars = env_map(&[]);
        let agent_env_vars = env_map(&[]);

        sanitize_command_env(
            &mut env_vars,
            &agent_env_vars,
            ANTHROPIC_ENV_VARS_TO_SANITIZE,
        );

        assert!(env_vars.is_empty(), "Empty environment should remain empty");
    }

    #[test]
    fn test_sanitize_command_env_handles_all_anthropic_vars() {
        let mut env_vars = HashMap::new();
        for &var in ANTHROPIC_ENV_VARS_TO_SANITIZE {
            env_vars.insert(var.to_string(), format!("value-{var}"));
        }
        env_vars.insert("OTHER_VAR".to_string(), "other-value".to_string());

        let agent_env_vars = env_map(&[]);

        sanitize_command_env(
            &mut env_vars,
            &agent_env_vars,
            ANTHROPIC_ENV_VARS_TO_SANITIZE,
        );

        for &var in ANTHROPIC_ENV_VARS_TO_SANITIZE {
            assert!(
                !env_vars.contains_key(var),
                "{var} should be removed when not explicitly set"
            );
        }
        assert_eq!(
            env_vars.get("OTHER_VAR").map(String::as_str),
            Some("other-value"),
            "Non-Anthropic vars should be preserved"
        );
    }
}