sfo_io/
qa_process.rs

1use tokio::io::{AsyncReadExt, AsyncWriteExt};
2use shlex::Shlex;
3use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout};
4use crate::error::{into_sfoio_err, sfoio_err, SfoIOErrorCode, SfoIOResult};
5
6pub struct QAProcess {
7    stdin: Option<ChildStdin>,
8    stdout: Option<ChildStdout>,
9    stderr: Option<ChildStderr>,
10    child: Child,
11}
12
13impl QAProcess {
14    pub fn new(mut child: Child) -> Self {
15        Self {
16            stdin: child.stdin.take(),
17            stdout: child.stdout.take(),
18            stderr: child.stderr.take(),
19            child,
20        }
21    }
22
23    pub async fn answer(&mut self, question: &str, answer: &str) -> SfoIOResult<()> {
24        // self.stdin.as_mut().unwrap().write_all(answer.as_bytes()).await.map_err(into_sfoio_err!(SfoIOErrorCode::Failed))?;
25        let mut offset = 0;
26        let mut buf = [0u8; 4096];
27        let mut error_buf = [0u8; 4096];
28        let mut error_offset = 0;
29        loop {
30            if offset == buf.len() || error_offset == error_buf.len() {
31                return Err(sfoio_err!(SfoIOErrorCode::Failed, "Buffer overflow"));
32            }
33
34            tokio::select! {
35                ret = self.stderr.as_mut().unwrap().read(&mut error_buf[error_offset..error_offset+1]) => {
36                    match ret {
37                        Ok(len) => {
38                            if len == 0 {
39                                return Err(sfoio_err!(SfoIOErrorCode::Failed, "EOF"));
40                            }
41                            error_offset += len;
42                            let current = String::from_utf8_lossy(&error_buf[..error_offset]).to_string();
43                            // log::info!("current err:{}", current);
44                            if current.ends_with(question) {
45                                let stdin = self.stdin.as_mut().ok_or_else(||sfoio_err!(SfoIOErrorCode::Failed, "Failed to get stdin"))?;
46                                // log::info!("write:{}", answer);
47                                stdin.write_all(answer.as_bytes()).await.map_err(into_sfoio_err!(SfoIOErrorCode::Failed))?;
48                                stdin.write_all("\n".as_bytes()).await.map_err(into_sfoio_err!(SfoIOErrorCode::Failed))?;
49                                // log::info!("write:{} finish", answer);
50                                break;
51                            }
52                        },
53                        Err(e) => {
54                            return Err(into_sfoio_err!(SfoIOErrorCode::Failed)(e))
55                        }
56                    }
57                },
58                ret = self.stdout.as_mut().unwrap().read(&mut buf[offset..offset+1]) => {
59                    match ret {
60                        Ok(len) => {
61                            if len == 0 {
62                                return Err(sfoio_err!(SfoIOErrorCode::Failed, "EOF"));
63                            }
64                            offset += len;
65                            let current = String::from_utf8_lossy(&buf[..offset]).to_string();
66                            // log::info!("current:{}", current);
67                            if current.ends_with(question) {
68                                let stdin = self.stdin.as_mut().ok_or_else(||sfoio_err!(SfoIOErrorCode::Failed, "Failed to get stdin"))?;
69                                // log::info!("write:{}", answer);
70                                stdin.write_all(answer.as_bytes()).await.map_err(into_sfoio_err!(SfoIOErrorCode::Failed))?;
71                                stdin.write_all("\n".as_bytes()).await.map_err(into_sfoio_err!(SfoIOErrorCode::Failed))?;
72                                // log::info!("write:{} finish", answer);
73                                break;
74                            }
75                        },
76                        Err(e) => {
77                            return Err(into_sfoio_err!(SfoIOErrorCode::Failed)(e))
78                        }
79                    }
80                }
81                _ = self.child.wait() => {
82                    break;
83                }
84            }
85        }
86        Ok(())
87    }
88
89    pub async fn wait(&mut self) -> SfoIOResult<()> {
90        let status = self.child.wait().await.map_err(into_sfoio_err!(SfoIOErrorCode::Failed))?;
91        if status.success() {
92            Ok(())
93        } else {
94            let stderr = self.stderr.as_mut().ok_or_else(||sfoio_err!(SfoIOErrorCode::Failed, "Failed to get stderr"))?;
95            let mut error = Vec::new();
96            stderr.read_to_end(&mut error).await.map_err(into_sfoio_err!(SfoIOErrorCode::Failed))?;
97
98            Err(sfoio_err!(SfoIOErrorCode::Failed, "{}", String::from_utf8_lossy(error.as_slice())))
99        }
100    }
101}
102
103pub async fn execute(cmd: &str) -> SfoIOResult<Vec<u8>> {
104    let mut lexer = Shlex::new(cmd);
105    let args: Vec<String> = lexer.by_ref().collect();
106    let output = tokio::process::Command::new(args[0].as_str())
107        .args(&args[1..])
108        .output()
109        .await
110        .map_err(into_sfoio_err!(SfoIOErrorCode::Failed))?;
111    if output.status.success() {
112        Ok(output.stdout)
113    } else {
114        Err(sfoio_err!(SfoIOErrorCode::CmdReturnFailed, "{}", String::from_utf8_lossy(output.stderr.as_slice())))
115    }
116}
117
118pub async fn spawn(cmd: &str) -> SfoIOResult<QAProcess> {
119    let mut lexer = Shlex::new(cmd);
120    let args: Vec<String> = lexer.by_ref().collect();
121    let child = tokio::process::Command::new(args[0].as_str())
122        .args(&args[1..])
123        .stdin(std::process::Stdio::piped())
124        .stdout(std::process::Stdio::piped())
125        .stderr(std::process::Stdio::piped())
126        .spawn()
127        .map_err(into_sfoio_err!(SfoIOErrorCode::Failed))?;
128
129    Ok(QAProcess::new(child))
130}