Skip to main content

pg_embed/
command_executor.rs

1//! Generic async process runner used by the pg_ctl and initdb wrappers.
2//!
3//! The core abstraction is [`AsyncCommand`], a trait with two methods:
4//! [`AsyncCommand::new`] spawns the OS process and [`AsyncCommand::execute`]
5//! waits for it to finish (with an optional timeout).
6//!
7//! [`ProcessStatus`] is a companion trait that maps a process type (initdb,
8//! start, stop) to the status values and errors it should produce.
9//!
10//! The only concrete implementation is [`AsyncCommandExecutor`].
11
12use std::error::Error;
13use std::ffi::OsStr;
14use std::marker;
15use std::process::Stdio;
16
17use log;
18use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader};
19use tokio::process::Child;
20use tokio::sync::mpsc::{Receiver, Sender};
21use tokio::time::Duration;
22
23/// Indicates whether a log line came from stdout or stderr.
24#[derive(Debug)]
25pub enum LogType {
26    /// Standard output line.
27    Info,
28    /// Standard error line.
29    Error,
30}
31
32/// Maps a process type to the status values and errors it should produce.
33///
34/// Implement this trait on an enum whose variants represent the distinct
35/// processes you want to run (e.g. `InitDb`, `StartDb`, `StopDb`).  The
36/// executor calls these methods to produce typed status/error values without
37/// knowing the concrete error type.
38pub trait ProcessStatus<T, E>
39where
40    E: Error + Send,
41    Self: Send,
42{
43    /// Returns the status value that signals the process has *entered* execution.
44    fn status_entry(&self) -> T;
45
46    /// Returns the status value that signals the process *exited successfully*.
47    fn status_exit(&self) -> T;
48
49    /// Returns the error value for a generic process failure (non-zero exit,
50    /// spawn error, etc.).
51    fn error_type(&self) -> E;
52
53    /// Returns the error value to use when the process exceeds its timeout.
54    ///
55    /// Defaults to [`Self::error_type`].  Override to return a distinct
56    /// timeout-specific error variant.
57    fn timeout_error(&self) -> E {
58        self.error_type()
59    }
60
61    /// Wraps a foreign error `F` (e.g. an OS I/O error) into `E`, optionally
62    /// attaching a context `message`.
63    fn wrap_error<F: Error + Sync + Send + 'static>(&self, error: F, message: Option<String>) -> E;
64}
65
66/// A single log line captured from a child process stream.
67#[derive(Debug)]
68pub struct LogOutputData {
69    line: String,
70    log_type: LogType,
71}
72
73/// Trait for types that can spawn and execute an OS process asynchronously.
74///
75/// The type parameter `S` is the success-status type (e.g. [`crate::pg_enums::PgServerStatus`]),
76/// `E` is the error type, and `P` is the [`ProcessStatus`] implementation that
77/// provides status/error mappings.
78///
79/// `Self: Sized` means this trait cannot be used as a `dyn` trait object.  Use
80/// the concrete [`AsyncCommandExecutor`] directly.
81#[allow(async_fn_in_trait)]
82pub trait AsyncCommand<S, E, P>
83where
84    E: Error + Send,
85    P: ProcessStatus<S, E> + Send,
86    Self: Sized,
87{
88    /// Creates and spawns a new OS process.
89    ///
90    /// # Arguments
91    ///
92    /// * `executable_path` — Path to the executable (e.g. `initdb`, `pg_ctl`).
93    /// * `args` — Command-line arguments to pass to the executable.
94    /// * `process_type` — The [`ProcessStatus`] value describing this process.
95    ///
96    /// # Errors
97    ///
98    /// Returns `E::error_type()` if the process cannot be spawned.
99    fn new<A, B>(executable_path: &OsStr, args: A, process_type: P) -> Result<Self, E>
100    where
101        A: IntoIterator<Item = B>,
102        B: AsRef<OsStr>;
103
104    /// Waits for the process to finish, optionally enforcing a deadline.
105    ///
106    /// Stdout and stderr are captured and forwarded to the [`log`] crate
107    /// (at `info` level) in background tasks.
108    ///
109    /// # Arguments
110    ///
111    /// * `timeout` — If `Some(duration)`, the process is killed and an error
112    ///   is returned if it does not finish within `duration`.  `None` waits
113    ///   indefinitely.
114    ///
115    /// # Returns
116    ///
117    /// The [`ProcessStatus::status_exit`] value on success.
118    ///
119    /// # Errors
120    ///
121    /// Returns [`ProcessStatus::timeout_error`] if the deadline is exceeded.
122    /// Returns [`ProcessStatus::error_type`] if the process exits with a
123    /// non-zero status.
124    /// Returns a wrapped error from [`ProcessStatus::wrap_error`] if waiting
125    /// on the process fails.
126    async fn execute(&mut self, timeout: Option<Duration>) -> Result<S, E>;
127}
128
129/// Concrete implementation of [`AsyncCommand`] built on [`tokio::process`].
130///
131/// Created through [`AsyncCommand::new`]; the process is spawned immediately
132/// and stdout/stderr are piped.  Call [`AsyncCommand::execute`] to wait for
133/// completion.
134pub struct AsyncCommandExecutor<S, E, P>
135where
136    S: Send,
137    E: Error + Send,
138    P: ProcessStatus<S, E>,
139    Self: Send,
140{
141    /// The Tokio command handle (kept alive so the process is not killed on drop).
142    _command: tokio::process::Command,
143    /// The spawned child process.
144    process: Child,
145    /// Determines status/error values for this specific process type.
146    process_type: P,
147    _marker_s: marker::PhantomData<S>,
148    _marker_e: marker::PhantomData<E>,
149}
150
151impl<S, E, P> AsyncCommandExecutor<S, E, P>
152where
153    S: Send,
154    E: Error + Send,
155    P: ProcessStatus<S, E> + Send,
156{
157    /// Spawns `command` with piped stdout/stderr.
158    fn init(command: &mut tokio::process::Command, process_type: &P) -> Result<Child, E> {
159        command
160            .stdout(Stdio::piped())
161            .stderr(Stdio::piped())
162            .spawn()
163            .map_err(|_| process_type.error_type())
164    }
165
166    /// Builds a [`tokio::process::Command`] from `executable_path` and `args`.
167    fn generate_command<A, B>(executable_path: &OsStr, args: A) -> tokio::process::Command
168    where
169        A: IntoIterator<Item = B>,
170        B: AsRef<OsStr>,
171    {
172        let mut command = tokio::process::Command::new(executable_path);
173        command.args(args);
174        command
175    }
176
177    /// Reads lines from `data` and forwards them to `sender` until EOF or error.
178    async fn handle_output<R: AsyncRead + Unpin>(data: R, sender: Sender<LogOutputData>) {
179        let mut lines = BufReader::new(data).lines();
180        loop {
181            match lines.next_line().await {
182                Ok(Some(line)) => {
183                    let io_data = LogOutputData {
184                        line,
185                        log_type: LogType::Info,
186                    };
187                    if sender.send(io_data).await.is_err() {
188                        log::warn!("process output channel closed before stream ended");
189                        break;
190                    }
191                }
192                Ok(None) => break,
193                Err(e) => {
194                    log::error!("Error reading process output: {}", e);
195                    break;
196                }
197            }
198        }
199    }
200
201    /// Drains `receiver` and writes each line to the [`log`] crate.
202    async fn log_output(mut receiver: Receiver<LogOutputData>) {
203        while let Some(data) = receiver.recv().await {
204            match data.log_type {
205                LogType::Info => {
206                    log::info!("{}", data.line);
207                }
208                LogType::Error => {
209                    log::error!("{}", data.line);
210                }
211            }
212        }
213    }
214
215    /// Awaits the child process exit status.
216    async fn run_process(&mut self) -> Result<S, E> {
217        let exit_status = self
218            .process
219            .wait()
220            .await
221            .map_err(|e| self.process_type.wrap_error(e, None))?;
222        if exit_status.success() {
223            Ok(self.process_type.status_exit())
224        } else {
225            Err(self.process_type.error_type())
226        }
227    }
228
229    /// Waits for the process and drains its output in background tasks.
230    async fn command_execution(&mut self) -> Result<S, E> {
231        let (sender, receiver) = tokio::sync::mpsc::channel::<LogOutputData>(1000);
232        let res = self.run_process().await;
233        if let Some(stdout) = self.process.stdout.take() {
234            let tx = sender.clone();
235            drop(tokio::task::spawn(async move {
236                Self::handle_output(stdout, tx).await;
237            }));
238        }
239        if let Some(stderr) = self.process.stderr.take() {
240            let tx = sender.clone();
241            drop(tokio::task::spawn(async move {
242                Self::handle_output(stderr, tx).await;
243            }));
244        }
245        drop(sender);
246        drop(tokio::task::spawn(async {
247            Self::log_output(receiver).await;
248        }));
249        res
250    }
251}
252
253impl<S, E, P> AsyncCommand<S, E, P> for AsyncCommandExecutor<S, E, P>
254where
255    S: Send,
256    E: Error + Send,
257    P: ProcessStatus<S, E> + Send,
258{
259    fn new<A, B>(executable_path: &OsStr, args: A, process_type: P) -> Result<Self, E>
260    where
261        A: IntoIterator<Item = B>,
262        B: AsRef<OsStr>,
263    {
264        let mut _command = Self::generate_command(executable_path, args);
265        let process = Self::init(&mut _command, &process_type)?;
266        Ok(AsyncCommandExecutor {
267            _command,
268            process,
269            process_type,
270            _marker_s: Default::default(),
271            _marker_e: Default::default(),
272        })
273    }
274
275    async fn execute(&mut self, timeout: Option<Duration>) -> Result<S, E> {
276        match timeout {
277            None => self.command_execution().await,
278            Some(duration) => tokio::time::timeout(duration, self.command_execution())
279                .await
280                .map_err(|_| self.process_type.timeout_error())?,
281        }
282    }
283}