1use crate::error::Result;
7use async_trait::async_trait;
8use std::process::Stdio;
9use tokio::io::{AsyncBufReadExt, BufReader};
10use tokio::process::Command as TokioCommand;
11use tokio::sync::mpsc;
12
/// One line of output read from a child process, tagged with the stream it
/// came from.
#[derive(Clone, Debug)]
pub enum OutputLine {
    /// A line read from the child's standard output.
    Stdout(String),
    /// A line read from the child's standard error.
    Stderr(String),
}
21
/// Final outcome of running a command whose output was streamed.
#[derive(Clone, Debug)]
pub struct StreamResult {
    /// Exit code of the process. The streaming functions in this module store
    /// `-1` when the OS reports no code (`ExitStatus::code()` is `None`).
    pub exit_code: i32,
    /// Whether the process exited successfully.
    pub success: bool,
    /// Captured standard output; the streaming functions in this module fill
    /// this with every stdout line joined by `\n`.
    pub stdout: Option<String>,
    /// Captured standard error; the streaming functions in this module fill
    /// this with every stderr line joined by `\n`.
    pub stderr: Option<String>,
}

impl StreamResult {
    /// Reports whether the command exited successfully (the `success` flag).
    #[must_use]
    pub fn is_success(&self) -> bool {
        self.success
    }
}
42
43#[async_trait]
45pub trait StreamableCommand: Send + Sync {
46 async fn stream<F>(&self, handler: F) -> Result<StreamResult>
52 where
53 F: FnMut(OutputLine) + Send + 'static;
54
55 async fn stream_channel(&self) -> Result<(mpsc::Receiver<OutputLine>, StreamResult)>;
61}
62
63pub struct StreamHandler;
65
66impl StreamHandler {
67 pub fn print() -> impl FnMut(OutputLine) {
69 move |line| match line {
70 OutputLine::Stdout(s) => println!("{s}"),
71 OutputLine::Stderr(s) => eprintln!("{s}"),
72 }
73 }
74
75 pub fn tee<F>(mut handler: F) -> impl FnMut(OutputLine) -> (Vec<String>, Vec<String>)
77 where
78 F: FnMut(&OutputLine),
79 {
80 let mut stdout_lines = Vec::new();
81 let mut stderr_lines = Vec::new();
82
83 move |line| {
84 handler(&line);
85 match line {
86 OutputLine::Stdout(s) => stdout_lines.push(s),
87 OutputLine::Stderr(s) => stderr_lines.push(s),
88 }
89 (stdout_lines.clone(), stderr_lines.clone())
90 }
91 }
92
93 pub fn filter(pattern: String) -> impl FnMut(OutputLine) -> Option<String> {
95 move |line| {
96 let text = match &line {
97 OutputLine::Stdout(s) | OutputLine::Stderr(s) => s,
98 };
99 if text.contains(&pattern) {
100 Some(text.clone())
101 } else {
102 None
103 }
104 }
105 }
106
107 pub fn with_prefix(prefix: String) -> impl FnMut(OutputLine) {
109 move |line| match line {
110 OutputLine::Stdout(s) => println!("{prefix}: {s}"),
111 OutputLine::Stderr(s) => eprintln!("{prefix} (error): {s}"),
112 }
113 }
114}
115
116pub(crate) async fn stream_command(
118 mut cmd: TokioCommand,
119 mut handler: impl FnMut(OutputLine) + Send + 'static,
120) -> Result<StreamResult> {
121 cmd.stdout(Stdio::piped());
122 cmd.stderr(Stdio::piped());
123
124 let mut child = cmd
125 .spawn()
126 .map_err(|e| crate::error::Error::custom(format!("Failed to spawn command: {e}")))?;
127
128 let stdout = child
129 .stdout
130 .take()
131 .ok_or_else(|| crate::error::Error::custom("Failed to capture stdout"))?;
132 let stderr = child
133 .stderr
134 .take()
135 .ok_or_else(|| crate::error::Error::custom("Failed to capture stderr"))?;
136
137 let stdout_reader = BufReader::new(stdout);
138 let stderr_reader = BufReader::new(stderr);
139 let mut stdout_lines = stdout_reader.lines();
140 let mut stderr_lines = stderr_reader.lines();
141
142 let mut stdout_accumulator = Vec::new();
143 let mut stderr_accumulator = Vec::new();
144
145 loop {
146 tokio::select! {
147 line = stdout_lines.next_line() => {
148 match line {
149 Ok(Some(text)) => {
150 stdout_accumulator.push(text.clone());
151 handler(OutputLine::Stdout(text));
152 }
153 Ok(None) => break,
154 Err(e) => {
155 return Err(crate::error::Error::custom(
156 format!("Error reading stdout: {e}")
157 ));
158 }
159 }
160 }
161 line = stderr_lines.next_line() => {
162 match line {
163 Ok(Some(text)) => {
164 stderr_accumulator.push(text.clone());
165 handler(OutputLine::Stderr(text));
166 }
167 Ok(None) => break,
168 Err(e) => {
169 return Err(crate::error::Error::custom(
170 format!("Error reading stderr: {e}")
171 ));
172 }
173 }
174 }
175 }
176 }
177
178 let status = child
179 .wait()
180 .await
181 .map_err(|e| crate::error::Error::custom(format!("Failed to wait for command: {e}")))?;
182
183 Ok(StreamResult {
184 exit_code: status.code().unwrap_or(-1),
185 success: status.success(),
186 stdout: Some(stdout_accumulator.join("\n")),
187 stderr: Some(stderr_accumulator.join("\n")),
188 })
189}
190
191pub(crate) async fn stream_command_channel(
193 mut cmd: TokioCommand,
194) -> Result<(mpsc::Receiver<OutputLine>, StreamResult)> {
195 let (tx, rx) = mpsc::channel(100);
196
197 cmd.stdout(Stdio::piped());
198 cmd.stderr(Stdio::piped());
199
200 let mut child = cmd
201 .spawn()
202 .map_err(|e| crate::error::Error::custom(format!("Failed to spawn command: {e}")))?;
203
204 let stdout = child
205 .stdout
206 .take()
207 .ok_or_else(|| crate::error::Error::custom("Failed to capture stdout"))?;
208 let stderr = child
209 .stderr
210 .take()
211 .ok_or_else(|| crate::error::Error::custom("Failed to capture stderr"))?;
212
213 let tx_clone = tx.clone();
214
215 let stdout_task = tokio::spawn(async move {
217 let reader = BufReader::new(stdout);
218 let mut reader_lines = reader.lines();
219 let mut lines = Vec::new();
220 while let Ok(Some(line)) = reader_lines.next_line().await {
221 lines.push(line.clone());
222 let _ = tx.send(OutputLine::Stdout(line)).await;
223 }
224 lines
225 });
226
227 let stderr_task = tokio::spawn(async move {
229 let reader = BufReader::new(stderr);
230 let mut reader_lines = reader.lines();
231 let mut lines = Vec::new();
232 while let Ok(Some(line)) = reader_lines.next_line().await {
233 lines.push(line.clone());
234 let _ = tx_clone.send(OutputLine::Stderr(line)).await;
235 }
236 lines
237 });
238
239 let status_future = child.wait();
241 let (stdout_lines, stderr_lines, status) =
242 tokio::join!(stdout_task, stderr_task, status_future);
243
244 let stdout_lines = stdout_lines.unwrap_or_default();
245 let stderr_lines = stderr_lines.unwrap_or_default();
246 let status = status
247 .map_err(|e| crate::error::Error::custom(format!("Failed to wait for command: {e}")))?;
248
249 Ok((
250 rx,
251 StreamResult {
252 exit_code: status.code().unwrap_or(-1),
253 success: status.success(),
254 stdout: Some(stdout_lines.join("\n")),
255 stderr: Some(stderr_lines.join("\n")),
256 },
257 ))
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_output_line() {
        let stdout = OutputLine::Stdout("test".to_string());
        let stderr = OutputLine::Stderr("error".to_string());

        match stdout {
            OutputLine::Stdout(s) => assert_eq!(s, "test"),
            OutputLine::Stderr(_) => panic!("Wrong variant"),
        }

        match stderr {
            OutputLine::Stderr(s) => assert_eq!(s, "error"),
            OutputLine::Stdout(_) => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_stream_result() {
        let result = StreamResult {
            exit_code: 0,
            success: true,
            stdout: Some("output".to_string()),
            stderr: None,
        };

        assert!(result.is_success());
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.stdout, Some("output".to_string()));
        assert!(result.stderr.is_none());
    }

    #[test]
    fn test_stream_handler_filter() {
        let mut filter = StreamHandler::filter("error".to_string());

        let result1 = filter(OutputLine::Stdout(
            "this contains error message".to_string(),
        ));
        assert_eq!(result1, Some("this contains error message".to_string()));

        let result2 = filter(OutputLine::Stdout("normal message".to_string()));
        assert!(result2.is_none());
    }

    /// `filter` must match lines from stderr as well as from stdout.
    #[test]
    fn test_stream_handler_filter_stderr() {
        let mut filter = StreamHandler::filter("warn".to_string());

        assert_eq!(
            filter(OutputLine::Stderr("warn: low disk".to_string())),
            Some("warn: low disk".to_string())
        );
        assert!(filter(OutputLine::Stderr("all good".to_string())).is_none());
    }

    /// `tee` forwards every line to the inner handler and returns the
    /// per-stream history accumulated so far.
    #[test]
    fn test_stream_handler_tee() {
        let mut forwarded = Vec::new();
        {
            let mut tee = StreamHandler::tee(|line: &OutputLine| forwarded.push(line.clone()));

            let (out, err) = tee(OutputLine::Stdout("one".to_string()));
            assert_eq!(out, vec!["one".to_string()]);
            assert!(err.is_empty());

            let (out, err) = tee(OutputLine::Stderr("two".to_string()));
            assert_eq!(out, vec!["one".to_string()]);
            assert_eq!(err, vec!["two".to_string()]);
        }

        assert_eq!(forwarded.len(), 2);
        assert!(matches!(&forwarded[0], OutputLine::Stdout(s) if s == "one"));
        assert!(matches!(&forwarded[1], OutputLine::Stderr(s) if s == "two"));
    }
}
307}