// pg_embed/postgres.rs
//! Public API for embedding and managing a PostgreSQL server.
//!
//! The entry point is [`PgEmbed`].  A typical usage sequence is:
//!
//! ```rust,no_run
//! use pg_embed::postgres::{PgEmbed, PgSettings};
//! use pg_embed::pg_fetch::{PgFetchSettings, PG_V17};
//! use pg_embed::pg_enums::PgAuthMethod;
//! use std::path::PathBuf;
//! use std::time::Duration;
//!
//! # #[tokio::main]
//! # async fn main() -> pg_embed::pg_errors::Result<()> {
//! let pg_settings = PgSettings {
//!     database_dir: PathBuf::from("data/db"),
//!     port: 5432,
//!     user: "postgres".to_string(),
//!     password: "password".to_string(),
//!     auth_method: PgAuthMethod::Plain,
//!     persistent: false,
//!     timeout: Some(Duration::from_secs(15)),
//!     migration_dir: None,
//! };
//!
//! let fetch_settings = PgFetchSettings { version: PG_V17, ..Default::default() };
//!
//! let mut pg = PgEmbed::new(pg_settings, fetch_settings).await?;
//! pg.setup().await?;
//! pg.start_db().await?;
//!
//! let uri = pg.full_db_uri("mydb");   // postgres://postgres:password@localhost:5432/mydb
//!
//! pg.stop_db().await?;
//! # Ok(())
//! # }
//! ```

use std::io::BufRead;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;

use log::{error, info};
use tokio::sync::Mutex;

#[cfg(feature = "rt_tokio_migrate")]
use sqlx::migrate::{MigrateDatabase, Migrator};
#[cfg(feature = "rt_tokio_migrate")]
use sqlx::postgres::PgPoolOptions;
#[cfg(feature = "rt_tokio_migrate")]
use sqlx::Postgres;

use crate::command_executor::AsyncCommand;
use crate::pg_access::PgAccess;
use crate::pg_commands::PgCommand;
use crate::pg_enums::{PgAuthMethod, PgServerStatus};
use crate::pg_errors::{Error, Result};
use crate::pg_fetch;
61
/// Configuration for a single embedded PostgreSQL instance.
pub struct PgSettings {
    /// Directory that will hold the PostgreSQL cluster data files.
    ///
    /// Created automatically if it does not exist.  When [`Self::persistent`]
    /// is `false`, this directory and the generated password file (the same
    /// path with a `.pwfile` extension) are removed when [`PgEmbed`] is
    /// dropped.
    pub database_dir: PathBuf,

    /// TCP port PostgreSQL will listen on.
    pub port: u16,

    /// Name of the initial database superuser.
    pub user: String,

    /// Password for the superuser, written to a temporary password file and
    /// passed to `initdb` via `--pwfile`.
    pub password: String,

    /// Authentication method written to `pg_hba.conf` by `initdb`.
    pub auth_method: PgAuthMethod,

    /// If `false`, the cluster directory and password file are deleted when
    /// the [`PgEmbed`] instance is dropped.  Set to `true` to keep the data
    /// across runs.
    pub persistent: bool,

    /// Maximum time to wait for `initdb`, `pg_ctl start`, and `pg_ctl stop`
    /// to complete.
    ///
    /// `None` disables the timeout (the process is waited on indefinitely).
    /// Exceeding the timeout returns [`Error::PgTimedOutError`].
    pub timeout: Option<Duration>,

    /// Directory containing `.sql` migration files.
    ///
    /// When `Some`, [`PgEmbed::migrate`] will run all migrations found in
    /// this directory via sqlx.  `None` disables migrations.
    /// Requires the `rt_tokio_migrate` feature.
    pub migration_dir: Option<PathBuf>,
}

/// An embedded PostgreSQL server with full lifecycle management.
///
/// Dropping a [`PgEmbed`] instance that has not been explicitly stopped will
/// automatically call `pg_ctl stop` synchronously and, if
/// [`PgSettings::persistent`] is `false`, remove the cluster directory and
/// password file.
pub struct PgEmbed {
    /// Active configuration for this instance.
    pub pg_settings: PgSettings,
    /// Binary download settings used during [`Self::setup`].
    pub fetch_settings: pg_fetch::PgFetchSettings,
    /// Base connection URI: `postgres://{user}:{password}@localhost:{port}`.
    /// Built once in [`Self::new`]; [`Self::full_db_uri`] appends `/{db}`.
    pub db_uri: String,
    /// Current server lifecycle state, protected by an async mutex so it can
    /// be observed from concurrent tasks.
    pub server_status: Arc<Mutex<PgServerStatus>>,
    /// Set to `true` once a graceful stop has been initiated to prevent the
    /// `Drop` impl from issuing a duplicate stop.
    pub shutting_down: bool,
    /// File-system paths and I/O helpers for this instance.
    pub pg_access: PgAccess,
}

127impl Drop for PgEmbed {
128    fn drop(&mut self) {
129        if !self.shutting_down {
130            if let Err(e) = self.stop_db_sync() {
131                log::warn!("pg_ctl stop failed during drop: {e}");
132            }
133        }
134        if !self.pg_settings.persistent {
135            if let Err(e) = self.pg_access.clean() {
136                log::warn!("cleanup failed during drop: {e}");
137            }
138        }
139    }
140}
141
impl PgEmbed {
    /// Creates a new [`PgEmbed`] instance and prepares the directory structure.
    ///
    /// Does **not** download binaries or start the server.  Call
    /// [`Self::setup`] followed by [`Self::start_db`] to bring the server up.
    ///
    /// # Arguments
    ///
    /// * `pg_settings` — Server configuration (port, auth, directories, …).
    /// * `fetch_settings` — Which PostgreSQL version/platform to download.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DirCreationError`] if the cache or database directories
    /// cannot be created.
    /// Returns [`Error::InvalidPgUrl`] if the OS cache directory is unavailable.
    pub async fn new(
        pg_settings: PgSettings,
        fetch_settings: pg_fetch::PgFetchSettings,
    ) -> Result<Self> {
        // Base URI without a database name; `full_db_uri` appends "/{db}".
        // NOTE(review): the password is embedded verbatim, without
        // percent-encoding — confirm callers only use URL-safe passwords.
        let db_uri = format!(
            "postgres://{user}:{password}@localhost:{port}",
            user = pg_settings.user,
            password = pg_settings.password,
            port = pg_settings.port,
        );
        let pg_access = PgAccess::new(&fetch_settings, &pg_settings.database_dir).await?;
        Ok(Self {
            pg_settings,
            fetch_settings,
            db_uri,
            server_status: Arc::new(Mutex::new(PgServerStatus::Uninitialized)),
            shutting_down: false,
            pg_access,
        })
    }

177    /// Downloads the binaries (if needed), writes the password file, and runs
178    /// `initdb` (if the cluster does not already exist).
179    ///
180    /// This method is idempotent: if the binaries are already cached and the
181    /// cluster is already initialised it returns immediately after verifying
182    /// both.
183    ///
184    /// # Errors
185    ///
186    /// Returns any error from [`PgAccess::maybe_acquire_postgres`],
187    /// [`PgAccess::create_password_file`], or [`Self::init_db`].
188    pub async fn setup(&mut self) -> Result<()> {
189        self.pg_access.maybe_acquire_postgres().await?;
190        self.pg_access
191            .create_password_file(self.pg_settings.password.as_bytes())
192            .await?;
193        if self.pg_access.db_files_exist().await? {
194            let mut server_status = self.server_status.lock().await;
195            *server_status = PgServerStatus::Initialized;
196        } else {
197            self.init_db().await?;
198        }
199        Ok(())
200    }
201
202    /// Installs a third-party PostgreSQL extension into the binary cache.
203    ///
204    /// Must be called **after** [`Self::setup`] (so the cache directory exists)
205    /// and **before** [`Self::start_db`] (so the server loads the shared
206    /// library on startup).  Once the server is running, activate the extension
207    /// in a specific database with:
208    ///
209    /// ```sql
210    /// CREATE EXTENSION IF NOT EXISTS <extension_name>;
211    /// ```
212    ///
213    /// Delegates to [`PgAccess::install_extension`].  See that method for the
214    /// file-routing rules (`.so`/`.dylib`/`.dll` → `lib/`;
215    /// `.control`/`.sql` → the PostgreSQL share extension directory).
216    ///
217    /// # Arguments
218    ///
219    /// * `extension_dir` — Directory containing the pre-compiled extension
220    ///   files (shared library + control + SQL scripts).
221    ///
222    /// # Errors
223    ///
224    /// Returns [`Error::DirCreationError`] if the target directories cannot be
225    /// created.
226    /// Returns [`Error::ReadFileError`] if `extension_dir` cannot be read.
227    /// Returns [`Error::WriteFileError`] if a file cannot be copied.
228    pub async fn install_extension(&self, extension_dir: &Path) -> Result<()> {
229        self.pg_access.install_extension(extension_dir).await
230    }
231
232    /// Runs `initdb` to create a new database cluster.
233    ///
234    /// Updates [`Self::server_status`] to [`PgServerStatus::Initializing`]
235    /// before the call and to [`PgServerStatus::Initialized`] on success.
236    ///
237    /// # Errors
238    ///
239    /// Returns [`Error::InvalidPgUrl`] if any path cannot be converted to UTF-8.
240    /// Returns [`Error::PgInitFailure`] if `initdb` cannot be spawned.
241    /// Returns [`Error::PgTimedOutError`] if the process exceeds
242    /// [`PgSettings::timeout`].
243    pub async fn init_db(&mut self) -> Result<()> {
244        {
245            let mut server_status = self.server_status.lock().await;
246            *server_status = PgServerStatus::Initializing;
247        }
248
249        let mut executor = PgCommand::init_db_executor(
250            &self.pg_access.init_db_exe,
251            &self.pg_access.database_dir,
252            &self.pg_access.pw_file_path,
253            &self.pg_settings.user,
254            &self.pg_settings.auth_method,
255        )?;
256        let exit_status = executor.execute(self.pg_settings.timeout).await?;
257        let mut server_status = self.server_status.lock().await;
258        *server_status = exit_status;
259        Ok(())
260    }
261
262    /// Starts the PostgreSQL server with `pg_ctl start -w`.
263    ///
264    /// Updates [`Self::server_status`] to [`PgServerStatus::Starting`] before
265    /// the call and to [`PgServerStatus::Started`] on success.
266    ///
267    /// # Errors
268    ///
269    /// Returns [`Error::InvalidPgUrl`] if the cluster path cannot be converted
270    /// to UTF-8.
271    /// Returns [`Error::PgStartFailure`] if the process exits with a non-zero
272    /// status or cannot be spawned.
273    /// Returns [`Error::PgTimedOutError`] if the process exceeds
274    /// [`PgSettings::timeout`].
275    pub async fn start_db(&mut self) -> Result<()> {
276        {
277            let mut server_status = self.server_status.lock().await;
278            *server_status = PgServerStatus::Starting;
279        }
280        self.shutting_down = false;
281        let mut executor = PgCommand::start_db_executor(
282            &self.pg_access.pg_ctl_exe,
283            &self.pg_access.database_dir,
284            &self.pg_settings.port,
285        )?;
286        let exit_status = executor.execute(self.pg_settings.timeout).await?;
287        let mut server_status = self.server_status.lock().await;
288        *server_status = exit_status;
289        Ok(())
290    }
291
292    /// Stops the PostgreSQL server with `pg_ctl stop -w`.
293    ///
294    /// Updates [`Self::server_status`] to [`PgServerStatus::Stopping`] before
295    /// the call and to [`PgServerStatus::Stopped`] on success.  Sets
296    /// [`Self::shutting_down`] to `true` so the `Drop` impl does not issue a
297    /// duplicate stop.
298    ///
299    /// # Errors
300    ///
301    /// Returns [`Error::InvalidPgUrl`] if the cluster path cannot be converted
302    /// to UTF-8.
303    /// Returns [`Error::PgStopFailure`] if `pg_ctl stop` fails.
304    /// Returns [`Error::PgTimedOutError`] if the process exceeds
305    /// [`PgSettings::timeout`].
306    pub async fn stop_db(&mut self) -> Result<()> {
307        {
308            let mut server_status = self.server_status.lock().await;
309            *server_status = PgServerStatus::Stopping;
310        }
311        self.shutting_down = true;
312        let mut executor =
313            PgCommand::stop_db_executor(&self.pg_access.pg_ctl_exe, &self.pg_access.database_dir)?;
314        let exit_status = executor.execute(self.pg_settings.timeout).await?;
315        let mut server_status = self.server_status.lock().await;
316        *server_status = exit_status;
317        Ok(())
318    }
319
320    /// Stops the PostgreSQL server synchronously.
321    ///
322    /// Used by the `Drop` impl where async is unavailable.  Stdout and stderr
323    /// of the `pg_ctl stop` process are forwarded to the [`log`] crate.
324    ///
325    /// # Errors
326    ///
327    /// Returns [`Error::PgError`] if the process cannot be spawned.
328    pub fn stop_db_sync(&mut self) -> Result<()> {
329        self.shutting_down = true;
330        let mut stop_db_command = self
331            .pg_access
332            .stop_db_command_sync(&self.pg_settings.database_dir);
333        let process = stop_db_command
334            .get_mut()
335            .stdout(Stdio::piped())
336            .stderr(Stdio::piped())
337            .spawn()
338            .map_err(|e| Error::PgError(e.to_string(), "".to_string()))?;
339
340        self.handle_process_io_sync(process)
341    }
342
343    /// Drains stdout and stderr of `process`, logging each line.
344    ///
345    /// Lines from stdout are logged at `info` level; lines from stderr at
346    /// `error` level.  Read errors are silently ignored (the line is skipped).
347    ///
348    /// # Arguments
349    ///
350    /// * `process` — A child process with piped stdout/stderr.
351    pub fn handle_process_io_sync(&self, mut process: std::process::Child) -> Result<()> {
352        if let Some(stdout) = process.stdout.take() {
353            std::io::BufReader::new(stdout)
354                .lines()
355                .for_each(|line| {
356                    if let Ok(l) = line {
357                        info!("{}", l);
358                    }
359                });
360        }
361        if let Some(stderr) = process.stderr.take() {
362            std::io::BufReader::new(stderr)
363                .lines()
364                .for_each(|line| {
365                    if let Ok(l) = line {
366                        error!("{}", l);
367                    }
368                });
369        }
370        Ok(())
371    }
372
373    /// Creates a new PostgreSQL database.
374    ///
375    /// Requires the `rt_tokio_migrate` feature.
376    ///
377    /// # Arguments
378    ///
379    /// * `db_name` — Name of the database to create.
380    ///
381    /// # Errors
382    ///
383    /// Returns [`Error::PgTaskJoinError`] if the sqlx operation fails.
384    #[cfg(feature = "rt_tokio_migrate")]
385    pub async fn create_database(&self, db_name: &str) -> Result<()> {
386        Postgres::create_database(&self.full_db_uri(db_name))
387            .await
388            .map_err(|e| Error::PgTaskJoinError(e.to_string()))?;
389        Ok(())
390    }
391
392    /// Drops a PostgreSQL database if it exists.
393    ///
394    /// Uses `DROP DATABASE IF EXISTS` semantics: if the database does not
395    /// exist the call succeeds silently.
396    /// Requires the `rt_tokio_migrate` feature.
397    ///
398    /// # Arguments
399    ///
400    /// * `db_name` — Name of the database to drop.
401    ///
402    /// # Errors
403    ///
404    /// Returns [`Error::PgTaskJoinError`] if the sqlx operation fails.
405    #[cfg(feature = "rt_tokio_migrate")]
406    pub async fn drop_database(&self, db_name: &str) -> Result<()> {
407        Postgres::drop_database(&self.full_db_uri(db_name))
408            .await
409            .map_err(|e| Error::PgTaskJoinError(e.to_string()))?;
410        Ok(())
411    }
412
413    /// Returns `true` if a database named `db_name` exists.
414    ///
415    /// Requires the `rt_tokio_migrate` feature.
416    ///
417    /// # Arguments
418    ///
419    /// * `db_name` — Name of the database to check.
420    ///
421    /// # Errors
422    ///
423    /// Returns [`Error::PgTaskJoinError`] if the sqlx operation fails.
424    #[cfg(feature = "rt_tokio_migrate")]
425    pub async fn database_exists(&self, db_name: &str) -> Result<bool> {
426        Postgres::database_exists(&self.full_db_uri(db_name))
427            .await
428            .map_err(|e| Error::PgTaskJoinError(e.to_string()))
429    }
430
431    /// Returns the full connection URI for a specific database.
432    ///
433    /// Format: `postgres://{user}:{password}@localhost:{port}/{db_name}`.
434    ///
435    /// # Arguments
436    ///
437    /// * `db_name` — Database name to append to the base URI.
438    pub fn full_db_uri(&self, db_name: &str) -> String {
439        format!("{}/{}", &self.db_uri, db_name)
440    }
441
442    /// Runs sqlx migrations from [`PgSettings::migration_dir`] against `db_name`.
443    ///
444    /// Does nothing if [`PgSettings::migration_dir`] is `None`.
445    /// Requires the `rt_tokio_migrate` feature.
446    ///
447    /// # Arguments
448    ///
449    /// * `db_name` — Name of the target database.
450    ///
451    /// # Errors
452    ///
453    /// Returns [`Error::MigrationError`] if the migrator cannot be created or
454    /// if a migration fails.
455    /// Returns [`Error::SqlQueryError`] if the database connection fails.
456    #[cfg(feature = "rt_tokio_migrate")]
457    pub async fn migrate(&self, db_name: &str) -> Result<()> {
458        if let Some(migration_dir) = &self.pg_settings.migration_dir {
459            let m = Migrator::new(migration_dir.as_path())
460                .await
461                .map_err(|e| Error::MigrationError(e.to_string()))?;
462            let pool = PgPoolOptions::new()
463                .connect(&self.full_db_uri(db_name))
464                .await
465                .map_err(|e| Error::SqlQueryError(e.to_string()))?;
466            m.run(&pool)
467                .await
468                .map_err(|e| Error::MigrationError(e.to_string()))?;
469        }
470        Ok(())
471    }
472}