pg_embed/
command_executor.rs

//!
//! Process command creation and execution
//!
4use std::error::Error;
5use std::ffi::OsStr;
6use std::marker;
7use std::process::Stdio;
8
9use async_trait::async_trait;
10use log;
11use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader};
12use tokio::process::Child;
13use tokio::sync::mpsc::{Receiver, Sender};
14use tokio::time::Duration;
15
///
/// Level at which a captured line of child-process output is logged.
///
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogType {
    /// Emitted through `log::info!`.
    Info,
    /// Emitted through `log::error!`.
    Error,
}
24
///
/// Describes how the outcome of a child process is reported.
///
/// `T` is the status value handed back to callers and `E` the error type
/// used when the process fails.
///
pub trait ProcessStatus<T, E>: Send
where
    E: Error + Send,
{
    /// Status associated with the process entry.
    fn status_entry(&self) -> T;

    /// Status associated with the process exit.
    fn status_exit(&self) -> T;

    /// Error value describing a failure of this process.
    fn error_type(&self) -> E;

    /// Wraps an arbitrary `error`, plus an optional context `message`, into `E`.
    fn wrap_error<F>(&self, error: F, message: Option<String>) -> E
    where
        F: Error + Sync + Send + 'static;
}
42
43///
44/// Logging data
45///
46#[derive(Debug)]
47pub struct LogOutputData {
48    line: String,
49    log_type: LogType,
50}
51
52///
53/// Async command trait
54///
55#[async_trait]
56pub trait AsyncCommand<S, E, P>
57where
58    E: Error + Send,
59    P: ProcessStatus<S, E> + Send,
60    Self: Sized,
61{
62    ///
63    /// Create a new async command
64    ///
65    fn new<A, B>(executable_path: &OsStr, args: A, process_type: P) -> Result<Self, E>
66    where
67        A: IntoIterator<Item = B>,
68        B: AsRef<OsStr>;
69    ///
70    /// Execute command
71    ///
72    /// When timeout is Some(duration) the process execution will be timed out after duration,
73    /// if set to None the process execution will not be timed out.
74    ///
75    async fn execute(&mut self, timeout: Option<Duration>) -> Result<S, E>;
76}
77
78///
79/// Process command
80///
81pub struct AsyncCommandExecutor<S, E, P>
82where
83    S: Send,
84    E: Error + Send,
85    P: ProcessStatus<S, E>,
86    Self: Send,
87{
88    /// Process command
89    command: tokio::process::Command,
90    /// Process child
91    process: Child,
92    /// Process type
93    process_type: P,
94    _marker_s: marker::PhantomData<S>,
95    _marker_e: marker::PhantomData<E>,
96}
97
98impl<S, E, P> AsyncCommandExecutor<S, E, P>
99where
100    S: Send,
101    E: Error + Send,
102    P: ProcessStatus<S, E> + Send,
103{
104    /// Initialize command
105    fn init(command: &mut tokio::process::Command, process_type: &P) -> Result<Child, E> {
106        command
107            .stdout(Stdio::piped())
108            .stderr(Stdio::piped())
109            .spawn()
110            .map_err(|_| process_type.error_type())
111    }
112
113    /// Generate a command
114    fn generate_command<A, B>(executable_path: &OsStr, args: A) -> tokio::process::Command
115    where
116        A: IntoIterator<Item = B>,
117        B: AsRef<OsStr>,
118    {
119        let mut command = tokio::process::Command::new(executable_path);
120        command.args(args);
121        command
122    }
123
124    /// Handle process output
125    async fn handle_output<R: AsyncRead + Unpin>(data: R, sender: Sender<LogOutputData>) -> () {
126        let mut lines = BufReader::new(data).lines();
127        while let Some(line) = lines.next_line().await.expect("error handling output") {
128            let io_data = LogOutputData {
129                line,
130                log_type: LogType::Info,
131            };
132            sender
133                .send(io_data)
134                .await
135                .expect("error sending log output data");
136        }
137    }
138
139    /// Log process output
140    async fn log_output(mut receiver: Receiver<LogOutputData>) -> () {
141        while let Some(data) = receiver.recv().await {
142            match data.log_type {
143                LogType::Info => {
144                    log::info!("{}", data.line);
145                }
146                LogType::Error => {
147                    log::error!("{}", data.line);
148                }
149            }
150        }
151    }
152
153    /// Run process
154    async fn run_process(&mut self) -> Result<S, E> {
155        let exit_status = self
156            .process
157            .wait()
158            .await
159            .map_err(|e| self.process_type.wrap_error(e, None))?;
160        if exit_status.success() {
161            Ok(self.process_type.status_exit())
162        } else {
163            Err(self.process_type.error_type())
164        }
165    }
166
167    #[cfg(not(target_os = "windows"))]
168    async fn command_execution(&mut self) -> Result<S, E> {
169        let (sender, receiver) = tokio::sync::mpsc::channel::<LogOutputData>(1000);
170        let res = self.run_process().await;
171        let stdout = self.process.stdout.take().unwrap();
172        let stderr = self.process.stderr.take().unwrap();
173        let tx = sender.clone();
174        let _ = tokio::task::spawn(async { Self::handle_output(stdout, tx).await });
175        let _ = tokio::task::spawn(async { Self::handle_output(stderr, sender).await });
176        let _ = tokio::task::spawn(async { Self::log_output(receiver).await });
177        res
178    }
179
180    #[cfg(target_os = "windows")]
181    async fn command_execution(&mut self) -> Result<S, E> {
182        //TODO: find another way to use stderr on windows
183        // let (sender, receiver) = tokio::sync::mpsc::channel::<LogOutputData>(1000);
184        let res = self.run_process().await;
185        // let stdout = self.process.stdout.take().unwrap();
186        // let stderr = self.process.stderr.take().unwrap();
187        // let tx = sender.clone();
188        // let _ = tokio::task::spawn(async { Self::handle_output(stdout, tx).await });
189        // let _ = tokio::task::spawn(async { Self::handle_output(stderr, sender).await });
190        // let _ = tokio::task::spawn(async { Self::log_output(receiver).await });
191        res
192    }
193
194}
195
196#[async_trait]
197impl<S, E, P> AsyncCommand<S, E, P> for AsyncCommandExecutor<S, E, P>
198where
199    S: Send,
200    E: Error + Send,
201    P: ProcessStatus<S, E> + Send,
202{
203    fn new<A, B>(executable_path: &OsStr, args: A, process_type: P) -> Result<Self, E>
204    where
205        A: IntoIterator<Item = B>,
206        B: AsRef<OsStr>,
207    {
208        let mut command = Self::generate_command(executable_path, args);
209        let process = Self::init(&mut command, &process_type)?;
210        Ok(AsyncCommandExecutor {
211            command,
212            process,
213            process_type,
214            _marker_s: Default::default(),
215            _marker_e: Default::default(),
216        })
217    }
218
219    async fn execute(&mut self, timeout: Option<Duration>) -> Result<S, E> {
220        match timeout {
221            None => self.command_execution().await,
222            Some(duration) => tokio::time::timeout(duration, self.command_execution())
223                .await
224                .map_err(|e| {
225                    self.process_type
226                        .wrap_error(e, Some(String::from("timed out")))
227                })?,
228        }
229    }
230}