pg_embed/command_executor.rs
1//! Generic async process runner used by the pg_ctl and initdb wrappers.
2//!
3//! The core abstraction is [`AsyncCommand`], a trait with two methods:
4//! [`AsyncCommand::new`] spawns the OS process and [`AsyncCommand::execute`]
5//! waits for it to finish (with an optional timeout).
6//!
7//! [`ProcessStatus`] is a companion trait that maps a process type (initdb,
8//! start, stop) to the status values and errors it should produce.
9//!
10//! The only concrete implementation is [`AsyncCommandExecutor`].
11
12use std::error::Error;
13use std::ffi::OsStr;
14use std::marker;
15use std::process::Stdio;
16
17use log;
18use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader};
19use tokio::process::Child;
20use tokio::sync::mpsc::{Receiver, Sender};
21use tokio::time::Duration;
22
/// Indicates which output category a captured log line belongs to.
///
/// The tag selects the [`log`] level a line is emitted at: `Info` lines go to
/// `log::info!`, `Error` lines to `log::error!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogType {
    /// Standard output line.
    Info,
    /// Standard error line.
    Error,
}
31
/// Describes one kind of external process (for example `InitDb`, `StartDb`
/// or `StopDb`) in terms of the status and error values it yields.
///
/// The executor is generic over the status type `T` and the error type `E`.
/// It never builds these values itself; it asks the implementor through the
/// methods below. Implement this on an enum with one variant per process
/// kind.
pub trait ProcessStatus<T, E: Error + Send>: Send {
    /// Status reported once the process has *entered* execution.
    fn status_entry(&self) -> T;

    /// Status reported once the process has *exited successfully*.
    fn status_exit(&self) -> T;

    /// Error reported for a generic process failure, such as a non-zero exit
    /// status or a spawn failure.
    fn error_type(&self) -> E;

    /// Error reported when the process runs past its timeout.
    ///
    /// Falls back to [`Self::error_type`]. Override it to supply a dedicated
    /// timeout variant.
    fn timeout_error(&self) -> E {
        self.error_type()
    }

    /// Converts a foreign error `F` (for example an OS I/O error) into `E`,
    /// optionally attaching a context `message`.
    fn wrap_error<F>(&self, error: F, message: Option<String>) -> E
    where
        F: Error + Sync + Send + 'static;
}
65
66/// A single log line captured from a child process stream.
67#[derive(Debug)]
68pub struct LogOutputData {
69 line: String,
70 log_type: LogType,
71}
72
73/// Trait for types that can spawn and execute an OS process asynchronously.
74///
75/// The type parameter `S` is the success-status type (e.g. [`crate::pg_enums::PgServerStatus`]),
76/// `E` is the error type, and `P` is the [`ProcessStatus`] implementation that
77/// provides status/error mappings.
78///
79/// `Self: Sized` means this trait cannot be used as a `dyn` trait object. Use
80/// the concrete [`AsyncCommandExecutor`] directly.
81#[allow(async_fn_in_trait)]
82pub trait AsyncCommand<S, E, P>
83where
84 E: Error + Send,
85 P: ProcessStatus<S, E> + Send,
86 Self: Sized,
87{
88 /// Creates and spawns a new OS process.
89 ///
90 /// # Arguments
91 ///
92 /// * `executable_path` — Path to the executable (e.g. `initdb`, `pg_ctl`).
93 /// * `args` — Command-line arguments to pass to the executable.
94 /// * `process_type` — The [`ProcessStatus`] value describing this process.
95 ///
96 /// # Errors
97 ///
98 /// Returns `E::error_type()` if the process cannot be spawned.
99 fn new<A, B>(executable_path: &OsStr, args: A, process_type: P) -> Result<Self, E>
100 where
101 A: IntoIterator<Item = B>,
102 B: AsRef<OsStr>;
103
104 /// Waits for the process to finish, optionally enforcing a deadline.
105 ///
106 /// Stdout and stderr are captured and forwarded to the [`log`] crate
107 /// (at `info` level) in background tasks.
108 ///
109 /// # Arguments
110 ///
111 /// * `timeout` — If `Some(duration)`, the process is killed and an error
112 /// is returned if it does not finish within `duration`. `None` waits
113 /// indefinitely.
114 ///
115 /// # Returns
116 ///
117 /// The [`ProcessStatus::status_exit`] value on success.
118 ///
119 /// # Errors
120 ///
121 /// Returns [`ProcessStatus::timeout_error`] if the deadline is exceeded.
122 /// Returns [`ProcessStatus::error_type`] if the process exits with a
123 /// non-zero status.
124 /// Returns a wrapped error from [`ProcessStatus::wrap_error`] if waiting
125 /// on the process fails.
126 async fn execute(&mut self, timeout: Option<Duration>) -> Result<S, E>;
127}
128
129/// Concrete implementation of [`AsyncCommand`] built on [`tokio::process`].
130///
131/// Created through [`AsyncCommand::new`]; the process is spawned immediately
132/// and stdout/stderr are piped. Call [`AsyncCommand::execute`] to wait for
133/// completion.
134pub struct AsyncCommandExecutor<S, E, P>
135where
136 S: Send,
137 E: Error + Send,
138 P: ProcessStatus<S, E>,
139 Self: Send,
140{
141 /// The Tokio command handle (kept alive so the process is not killed on drop).
142 _command: tokio::process::Command,
143 /// The spawned child process.
144 process: Child,
145 /// Determines status/error values for this specific process type.
146 process_type: P,
147 _marker_s: marker::PhantomData<S>,
148 _marker_e: marker::PhantomData<E>,
149}
150
151impl<S, E, P> AsyncCommandExecutor<S, E, P>
152where
153 S: Send,
154 E: Error + Send,
155 P: ProcessStatus<S, E> + Send,
156{
157 /// Spawns `command` with piped stdout/stderr.
158 fn init(command: &mut tokio::process::Command, process_type: &P) -> Result<Child, E> {
159 command
160 .stdout(Stdio::piped())
161 .stderr(Stdio::piped())
162 .spawn()
163 .map_err(|_| process_type.error_type())
164 }
165
166 /// Builds a [`tokio::process::Command`] from `executable_path` and `args`.
167 fn generate_command<A, B>(executable_path: &OsStr, args: A) -> tokio::process::Command
168 where
169 A: IntoIterator<Item = B>,
170 B: AsRef<OsStr>,
171 {
172 let mut command = tokio::process::Command::new(executable_path);
173 command.args(args);
174 command
175 }
176
177 /// Reads lines from `data` and forwards them to `sender` until EOF or error.
178 async fn handle_output<R: AsyncRead + Unpin>(data: R, sender: Sender<LogOutputData>) {
179 let mut lines = BufReader::new(data).lines();
180 loop {
181 match lines.next_line().await {
182 Ok(Some(line)) => {
183 let io_data = LogOutputData {
184 line,
185 log_type: LogType::Info,
186 };
187 if sender.send(io_data).await.is_err() {
188 log::warn!("process output channel closed before stream ended");
189 break;
190 }
191 }
192 Ok(None) => break,
193 Err(e) => {
194 log::error!("Error reading process output: {}", e);
195 break;
196 }
197 }
198 }
199 }
200
201 /// Drains `receiver` and writes each line to the [`log`] crate.
202 async fn log_output(mut receiver: Receiver<LogOutputData>) {
203 while let Some(data) = receiver.recv().await {
204 match data.log_type {
205 LogType::Info => {
206 log::info!("{}", data.line);
207 }
208 LogType::Error => {
209 log::error!("{}", data.line);
210 }
211 }
212 }
213 }
214
215 /// Awaits the child process exit status.
216 async fn run_process(&mut self) -> Result<S, E> {
217 let exit_status = self
218 .process
219 .wait()
220 .await
221 .map_err(|e| self.process_type.wrap_error(e, None))?;
222 if exit_status.success() {
223 Ok(self.process_type.status_exit())
224 } else {
225 Err(self.process_type.error_type())
226 }
227 }
228
229 /// Waits for the process and drains its output in background tasks.
230 async fn command_execution(&mut self) -> Result<S, E> {
231 let (sender, receiver) = tokio::sync::mpsc::channel::<LogOutputData>(1000);
232 let res = self.run_process().await;
233 if let Some(stdout) = self.process.stdout.take() {
234 let tx = sender.clone();
235 drop(tokio::task::spawn(async move {
236 Self::handle_output(stdout, tx).await;
237 }));
238 }
239 if let Some(stderr) = self.process.stderr.take() {
240 let tx = sender.clone();
241 drop(tokio::task::spawn(async move {
242 Self::handle_output(stderr, tx).await;
243 }));
244 }
245 drop(sender);
246 drop(tokio::task::spawn(async {
247 Self::log_output(receiver).await;
248 }));
249 res
250 }
251}
252
253impl<S, E, P> AsyncCommand<S, E, P> for AsyncCommandExecutor<S, E, P>
254where
255 S: Send,
256 E: Error + Send,
257 P: ProcessStatus<S, E> + Send,
258{
259 fn new<A, B>(executable_path: &OsStr, args: A, process_type: P) -> Result<Self, E>
260 where
261 A: IntoIterator<Item = B>,
262 B: AsRef<OsStr>,
263 {
264 let mut _command = Self::generate_command(executable_path, args);
265 let process = Self::init(&mut _command, &process_type)?;
266 Ok(AsyncCommandExecutor {
267 _command,
268 process,
269 process_type,
270 _marker_s: Default::default(),
271 _marker_e: Default::default(),
272 })
273 }
274
275 async fn execute(&mut self, timeout: Option<Duration>) -> Result<S, E> {
276 match timeout {
277 None => self.command_execution().await,
278 Some(duration) => tokio::time::timeout(duration, self.command_execution())
279 .await
280 .map_err(|_| self.process_type.timeout_error())?,
281 }
282 }
283}