Skip to main content

aster_test/mcp/stdio/
record.rs

1use std::fs::OpenOptions;
2use std::io::{self, BufRead, BufReader, Write};
3use std::process::{ChildStdin, Command, Stdio};
4use std::sync::mpsc;
5use std::thread::{self, JoinHandle};
6
/// Tags a recorded line with the standard stream it travelled over.
///
/// `PartialEq`/`Eq` are derived so tagged messages can be compared directly
/// (for example in assertions on what was sent over the logging channel).
#[derive(Debug, Clone, PartialEq, Eq)]
enum StreamType {
    /// A line read from this process's stdin and forwarded to the child.
    Stdin,
    /// A line the child process wrote to its stdout.
    Stdout,
    /// A line the child process wrote to its stderr.
    Stderr,
}
13
14fn handle_output_stream<R: BufRead + Send + 'static>(
15    reader: R,
16    sender: mpsc::Sender<(StreamType, String)>,
17    stream_type: StreamType,
18    mut output_writer: Box<dyn Write + Send>,
19) -> JoinHandle<()> {
20    thread::spawn(move || {
21        for line in reader.lines() {
22            match line {
23                Ok(line) => {
24                    let _ = sender.send((stream_type.clone(), line.clone()));
25
26                    if writeln!(output_writer, "{}", line).is_err() {
27                        break;
28                    }
29                }
30                Err(_) => break,
31            }
32        }
33    })
34}
35
36fn handle_stdin_stream(
37    mut child_stdin: ChildStdin,
38    sender: mpsc::Sender<(StreamType, String)>,
39) -> JoinHandle<()> {
40    thread::spawn(move || {
41        let stdin = io::stdin();
42
43        for line in stdin.lock().lines() {
44            match line {
45                Ok(line) => {
46                    let _ = sender.send((StreamType::Stdin, line.clone()));
47
48                    if writeln!(child_stdin, "{}", line).is_err() {
49                        break;
50                    }
51                }
52                Err(_) => break,
53            }
54        }
55    })
56}
57
58pub fn record(log_file_path: &String, cmd: &String, cmd_args: &[String]) -> io::Result<()> {
59    let (tx, rx) = mpsc::channel();
60
61    let log_file = OpenOptions::new()
62        .create(true)
63        .write(true)
64        .truncate(true)
65        .open(log_file_path)?;
66
67    let mut child = Command::new(cmd)
68        .args(cmd_args.iter())
69        .stdin(Stdio::piped())
70        .stdout(Stdio::piped())
71        .stderr(Stdio::piped())
72        .spawn()
73        .inspect_err(|e| eprintln!("Failed to execute command '{}': {}", &cmd, e))?;
74
75    let child_stdin = child.stdin.take().unwrap();
76    let child_stdout = child.stdout.take().unwrap();
77    let child_stderr = child.stderr.take().unwrap();
78
79    let stdin_handle = handle_stdin_stream(child_stdin, tx.clone());
80    let stdout_handle = handle_output_stream(
81        BufReader::new(child_stdout),
82        tx.clone(),
83        StreamType::Stdout,
84        Box::new(io::stdout()),
85    );
86    let stderr_handle = handle_output_stream(
87        BufReader::new(child_stderr),
88        tx.clone(),
89        StreamType::Stderr,
90        Box::new(io::stderr()),
91    );
92
93    thread::spawn(move || {
94        let mut log_file = log_file;
95        for (stream_type, line) in rx {
96            let prefix = match stream_type {
97                StreamType::Stdin => "STDIN",
98                StreamType::Stdout => "STDOUT",
99                StreamType::Stderr => "STDERR",
100            };
101            if let Err(e) = writeln!(log_file, "{}: {}", prefix, line) {
102                eprintln!("Error writing to log file: {}", e);
103            }
104            log_file.flush().ok();
105        }
106    });
107
108    child.wait()?;
109
110    stdin_handle.join().ok();
111    stdout_handle.join().ok();
112    stderr_handle.join().ok();
113
114    Ok(())
115}