pgtemp/
lib.rs

1#![warn(missing_docs)] // denied in CI
2
//! pgtemp is a Rust library and CLI tool that lets you easily create temporary PostgreSQL servers for testing without using Docker.
//!
//! The pgtemp Rust library spawns a PostgreSQL server in a temporary directory and gives you back a full connection URI with the host, port, username, and password.
//!
//! The pgtemp CLI tool makes temporary connections even simpler and works with any language: run pgtemp, then use its connection URI when connecting to the database in your tests. **pgtemp will then spawn a new PostgreSQL process for each connection it receives** and transparently proxy everything over that connection to the temporary database. This means that if you make multiple connections in a single test, changes made in one connection will not be visible in the others, unless you are using pgtemp's `--single` mode.
8
9use std::collections::HashMap;
10use std::fmt;
11use std::fmt::Debug;
12use std::path::{Path, PathBuf};
13use std::process::Child;
14
15use tempfile::TempDir;
16use tokio::task::spawn_blocking;
17
18mod daemon;
19mod run_db;
20
21pub use daemon::*;
22
23// temp db handle - actual db spawning code is in run_db mod
24
/// A handle to a local PostgreSQL server that is currently running.
///
/// Upon drop or calling [`PgTempDB::shutdown`], the server is shut down and the directory its
/// data is stored in is deleted, unless the handle was configured to persist it. See the builder
/// struct [`PgTempDBBuilder`] for options and settings.
pub struct PgTempDB {
    /// Superuser name used when connecting to the server.
    dbuser: String,
    /// Password for `dbuser`.
    dbpass: String,
    /// Port the server was started on.
    dbport: u16,
    /// Name of the database to connect to.
    dbname: String,
    /// Keep the db data directory on disk after shutdown instead of deleting it.
    persist: bool,
    /// If set, dump the database (via `pg_dump`) to this script file when shutting down. The
    /// dump is taken while the server is still running, just before it is stopped.
    dump_path: Option<PathBuf>,
    // These two are `Option`s because shutdown runs from `Drop` with only `&mut self`, yet needs
    // to consume them by value (`TempDir::close`/`into_path`, `Child::wait_with_output`); it
    // `take()`s them out. `from_builder` always stores both as `Some`.
    /// Temporary directory that contains the PostgreSQL data directory.
    temp_dir: Option<TempDir>,
    /// The running `postgres` server process.
    postgres_process: Option<Child>,
}
41
42impl PgTempDB {
43    /// Start a PgTempDB with the parameters configured from a PgTempDBBuilder
44    pub fn from_builder(mut builder: PgTempDBBuilder) -> PgTempDB {
45        let dbuser = builder.get_user();
46        let dbpass = builder.get_password();
47        let dbport = builder.get_port_or_set_random();
48        let dbname = builder.get_dbname();
49        let persist = builder.persist_data_dir;
50        let dump_path = builder.dump_path.clone();
51        let load_path = builder.load_path.clone();
52
53        let temp_dir = run_db::init_db(&mut builder);
54        let postgres_process = Some(run_db::run_db(&temp_dir, builder));
55        let temp_dir = Some(temp_dir);
56
57        let db = PgTempDB {
58            dbuser,
59            dbpass,
60            dbport,
61            dbname,
62            persist,
63            dump_path,
64            temp_dir,
65            postgres_process,
66        };
67
68        if let Some(path) = load_path {
69            db.load_database(path);
70        }
71        db
72    }
73
74    /// Creates a builder that can be used to configure the details of the temporary PostgreSQL
75    /// server
76    pub fn builder() -> PgTempDBBuilder {
77        PgTempDBBuilder::new()
78    }
79
80    /// Creates a new PgTempDB with default configuration and starts a PostgreSQL server.
81    pub fn new() -> PgTempDB {
82        PgTempDBBuilder::new().start()
83    }
84
85    /// Creates a new PgTempDB with default configuration and starts a PostgreSQL server in an
86    /// async context.
87    pub async fn async_new() -> PgTempDB {
88        PgTempDBBuilder::new().start_async().await
89    }
90
91    /// Use [pg_dump](https://www.postgresql.org/docs/current/backup-dump.html) to dump the
92    /// database to the provided path upon drop or [`Self::shutdown`].
93    pub fn dump_database(&self, path: impl AsRef<Path>) {
94        let path_str = path.as_ref().to_str().unwrap();
95
96        let dump_output = std::process::Command::new("pg_dump")
97            .arg(self.connection_uri())
98            .args(["--file", path_str])
99            .output()
100            .expect("failed to start pg_dump. Is it installed and on your path?");
101
102        if !dump_output.status.success() {
103            let stdout = dump_output.stdout;
104            let stderr = dump_output.stderr;
105            panic!(
106                "pg_dump failed! stdout: {}\n\nstderr: {}",
107                String::from_utf8_lossy(&stdout),
108                String::from_utf8_lossy(&stderr)
109            );
110        }
111    }
112
113    /// Use `psql` to load the database from the provided dump file. See [`Self::dump_database`].
114    pub fn load_database(&self, path: impl AsRef<Path>) {
115        let path_str = path.as_ref().to_str().unwrap();
116
117        let load_output = std::process::Command::new("psql")
118            .arg(self.connection_uri())
119            .args([
120                "--file",
121                path_str,
122                "--set",
123                "ON_ERROR_STOP=1",
124                "--single-transaction",
125            ])
126            .output()
127            .expect("failed to start psql. Is it installed and on your path?");
128
129        if !load_output.status.success() {
130            let stdout = load_output.stdout;
131            let stderr = load_output.stderr;
132            panic!(
133                "psql failed! stdout: {}\n\nstderr: {}",
134                String::from_utf8_lossy(&stdout),
135                String::from_utf8_lossy(&stderr)
136            );
137        }
138    }
139
140    /// Send a signal to the database to shutdown the server, then wait for the process to exit.
141    /// Equivalent to calling drop on this struct.
142    ///
143    /// We send SIGINT to the postgres process to initiate a fast shutdown
144    /// (<https://www.postgresql.org/docs/current/server-shutdown.html>), which causes all transactions to be aborted and
145    /// connections to be terminated.
146    ///
147    /// NOTE: This is currently a blocking function. It sends SIGINT to the postgres server, waits
148    /// for the process to exit, and also does IO to remove the temp directory.
149    ///
150    pub fn shutdown(self) {
151        drop(self);
152    }
153
154    /// See description of [`shutdown`]
155    fn shutdown_internal(&mut self) {
156        // if no process (e.g. due to calling `force_shutdown`), just skip the cleanup operations.
157        if self.postgres_process.is_none() {
158            return;
159        }
160
161        // do the dump while the postgres process is still running
162        if let Some(path) = &self.dump_path {
163            self.dump_database(path);
164        }
165
166        let postgres_process = self
167            .postgres_process
168            .take()
169            .expect("shutdown with no postgres process");
170        let temp_dir = self.temp_dir.take().unwrap();
171
172        // fast (not graceful) shutdown via SIGINT
173        // TODO: graceful shutdown via SIGTERM
174        // was having issues with using graceful shutdown by default and some tests/examples using
175        // pg connection pools - likely what was happening was that at the end of the test we hit
176        // drop for the connection pool, it tries to drop asynchronously (e.g. it probably sends a
177        // async signal), then we block indefinitely on the main thread in PgTempDB::shutdown
178        // waiting for the server to shut down and the pooler never gets a chance to shut down, so
179        // the postgres server says "we're still connected to a client, can't shut down yet" and we
180        // have a deadlock.
181        #[allow(clippy::cast_possible_wrap)]
182        let _ret = unsafe { libc::kill(postgres_process.id() as i32, libc::SIGINT) };
183        let _output = postgres_process
184            .wait_with_output()
185            .expect("postgres server failed to exit cleanly");
186
187        if self.persist {
188            // this prevents the dir from being deleted on drop
189            let _path = temp_dir.into_path();
190        } else {
191            // if we just used the default drop impl, errors would not be surfaced
192            temp_dir.close().expect("failed to clean up temp directory");
193        }
194    }
195
196    /// Returns the path to the data directory being used by this databaset.
197    pub fn data_dir(&self) -> PathBuf {
198        self.temp_dir.as_ref().unwrap().path().join("pg_data_dir")
199    }
200
201    /// Returns the database username used when connecting to the postgres server.
202    pub fn db_user(&self) -> &str {
203        &self.dbuser
204    }
205
206    /// Returns the database password used when connecting to the postgres server.
207    pub fn db_pass(&self) -> &str {
208        &self.dbpass
209    }
210
211    /// Returns the port the postgres server is running on.
212    pub fn db_port(&self) -> u16 {
213        self.dbport
214    }
215
216    /// Returns the the name of the database created.
217    pub fn db_name(&self) -> &str {
218        &self.dbname
219    }
220
221    /// Returns a connection string that can be passed to a libpq connection function.
222    ///
223    /// Example output:
224    /// `host=localhost port=15432 user=pgtemp password=pgtemppw-9485 dbname=pgtempdb-324`
225    pub fn connection_string(&self) -> String {
226        format!(
227            "host=localhost port={} user={} password={} dbname={}",
228            self.db_port(),
229            self.db_user(),
230            self.db_pass(),
231            self.db_name()
232        )
233    }
234
235    /// Returns a generic connection URI that can be passed to most SQL libraries' connect
236    /// methods.
237    ///
238    /// Example output:
239    /// `postgresql://pgtemp:pgtemppw-9485@localhost:15432/pgtempdb-324`
240    pub fn connection_uri(&self) -> String {
241        format!(
242            "postgresql://{}:{}@localhost:{}/{}",
243            self.db_user(),
244            self.db_pass(),
245            self.db_port(),
246            self.db_name()
247        )
248    }
249}
250
251impl Debug for PgTempDB {
252    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253        f.debug_struct("PgTempDB")
254            .field("base directory", self.temp_dir.as_ref().unwrap())
255            .field("connection string", &self.connection_string())
256            .field("persist data dir", &self.persist)
257            .field("dump path", &self.dump_path)
258            .field(
259                "db process",
260                &self.postgres_process.as_ref().map(Child::id).unwrap(),
261            )
262            .finish_non_exhaustive()
263    }
264}
265
266impl Drop for PgTempDB {
267    fn drop(&mut self) {
268        self.shutdown_internal();
269    }
270}
271
272// db config builder functions
273
/// Builder struct for [`PgTempDB`].
///
/// Options left unset use the default noted on the corresponding field, where one is given.
/// Start the server with [`PgTempDBBuilder::start`] or [`PgTempDBBuilder::start_async`].
///
/// Note: the derived `Debug` output includes `password` in plain text.
#[derive(Default, Debug, Clone)]
pub struct PgTempDBBuilder {
    /// The directory in which to store the temporary PostgreSQL data directory.
    pub temp_dir_prefix: Option<PathBuf>,
    /// The cluster superuser created with `initdb`. Default: `postgres`
    pub db_user: Option<String>,
    /// The password for the cluster superuser. Default: `password`
    pub password: Option<String>,
    /// The port the server should run on. Default: random unused port.
    pub port: Option<u16>,
    /// The name of the database to create on startup. Default: `postgres`.
    pub dbname: Option<String>,
    /// Do not delete the data dir when the `PgTempDB` is dropped. Default: `false`.
    pub persist_data_dir: bool,
    /// The path to dump the database to (via `pg_dump`) when the `PgTempDB` is dropped.
    pub dump_path: Option<PathBuf>,
    /// The path to load the database from (via `psql`) when the `PgTempDB` is started.
    pub load_path: Option<PathBuf>,
    /// Other server configuration data to be set in `postgresql.conf` via `initdb -c`
    /// (key → value; a later insert for the same key replaces the earlier one).
    pub server_configs: HashMap<String, String>,
    /// Direct arguments to pass to the `initdb` binary (e.g. --encoding=UTF8), distinct from postgres configs (-c)
    pub initdb_args: HashMap<String, String>,
    /// Prefix PostgreSQL binary names (`initdb`, `createdb`, and `postgres`) with this path, instead of searching $PATH.
    /// `pg_dump` and `psql` (used for dumping/loading by `PgTempDB`) are still looked up on $PATH.
    pub bin_path: Option<PathBuf>,
}
300
301impl PgTempDBBuilder {
302    /// Create a new [`PgTempDBBuilder`]
303    pub fn new() -> PgTempDBBuilder {
304        PgTempDBBuilder::default()
305    }
306
307    /// Parses the parameters out of a PostgreSQL connection URI and inserts them into the builder.
308    #[must_use]
309    pub fn from_connection_uri(conn_uri: &str) -> Self {
310        let mut builder = PgTempDBBuilder::new();
311
312        let url = url::Url::parse(conn_uri)
313            .expect(&format!("Could not parse connection URI `{}`", conn_uri));
314
315        // TODO: error types
316        assert!(
317            url.scheme() == "postgresql",
318            "connection URI must start with `postgresql://` scheme: `{}`",
319            conn_uri
320        );
321        assert!(
322            url.host_str() == Some("localhost"),
323            "connection URI's host is not localhost: `{}`",
324            conn_uri,
325        );
326
327        let username = url.username();
328        let password = url.password();
329        let port = url.port();
330        let dbname = url.path().strip_prefix('/').unwrap_or("");
331
332        if !username.is_empty() {
333            builder = builder.with_username(username);
334        }
335        if let Some(password) = password {
336            builder = builder.with_password(password);
337        }
338        if let Some(port) = port {
339            builder = builder.with_port(port);
340        }
341        if !dbname.is_empty() {
342            builder = builder.with_dbname(dbname);
343        }
344
345        builder
346    }
347
348    // TODO: make an error type and `try_start` methods (and maybe similar for above shutdown etc
349    // functions)
350
351    /// Creates the temporary data directory and starts the PostgreSQL server with the configured
352    /// parameters.
353    ///
354    /// If the current user is root, will attempt to run the `initdb` and `postgres` commands as
355    /// the `postgres` user.
356    pub fn start(self) -> PgTempDB {
357        PgTempDB::from_builder(self)
358    }
359
360    /// Convenience function for calling `spawn_blocking(self.start())`
361    pub async fn start_async(self) -> PgTempDB {
362        spawn_blocking(move || self.start())
363            .await
364            .expect("failed to start pgtemp server")
365    }
366
367    /// Set the directory in which to put the (temporary) PostgreSQL data directory. This is not
368    /// the data directory itself: a new temporary directory is created inside this one.
369    #[must_use]
370    pub fn with_data_dir_prefix(mut self, prefix: impl AsRef<Path>) -> Self {
371        self.temp_dir_prefix = Some(PathBuf::from(prefix.as_ref()));
372        self
373    }
374
375    /// Set an arbitrary PostgreSQL server configuration parameter that will passed to the
376    /// postgresql process at runtime.
377    #[must_use]
378    pub fn with_config_param(mut self, key: &str, value: &str) -> Self {
379        let _old = self.server_configs.insert(key.into(), value.into());
380        self
381    }
382
383    /// Set an arbitrary argument that will be passed directly to the initdb binary during database
384    /// initialization. These are direct arguments like --encoding or --locale, not configuration
385    /// parameters that get written to postgresql.conf (use with_config_param for those).
386    #[must_use]
387    pub fn with_initdb_arg(mut self, key: &str, value: &str) -> Self {
388        let _old = self.initdb_args.insert(key.into(), value.into());
389        self
390    }
391
392    /// Set the directory that contains binaries like `initdb`, `createdb`, and `postgres`.
393    #[must_use]
394    pub fn with_bin_path(mut self, path: impl AsRef<Path>) -> Self {
395        self.bin_path = Some(PathBuf::from(path.as_ref()));
396        self
397    }
398
399    #[must_use]
400    /// Set the user name
401    pub fn with_username(mut self, username: &str) -> Self {
402        self.db_user = Some(username.to_string());
403        self
404    }
405
406    #[must_use]
407    /// Set the user password
408    pub fn with_password(mut self, password: &str) -> Self {
409        self.password = Some(password.to_string());
410        self
411    }
412
413    #[must_use]
414    /// Set the port
415    pub fn with_port(mut self, port: u16) -> Self {
416        self.port = Some(port);
417        self
418    }
419
420    #[must_use]
421    /// Set the database name
422    pub fn with_dbname(mut self, dbname: &str) -> Self {
423        self.dbname = Some(dbname.to_string());
424        self
425    }
426
427    /// If set, the postgres data directory will not be deleted when the `PgTempDB` is dropped.
428    #[must_use]
429    pub fn persist_data(mut self, persist: bool) -> Self {
430        self.persist_data_dir = persist;
431        self
432    }
433
434    /// If set, the database will be dumped via the `pg_dump` utility to the given location on drop
435    /// or upon calling [`PgTempDB::shutdown`].
436    #[must_use]
437    pub fn dump_database(mut self, path: &Path) -> Self {
438        self.dump_path = Some(path.into());
439        self
440    }
441
442    /// If set, the database will be loaded via `psql` from the given script on startup.
443    #[must_use]
444    pub fn load_database(mut self, path: &Path) -> Self {
445        self.load_path = Some(path.into());
446        self
447    }
448
449    /// Get user if set or return default
450    pub fn get_user(&self) -> String {
451        self.db_user.clone().unwrap_or(String::from("postgres"))
452    }
453
454    /// Get password if set or return default
455    pub fn get_password(&self) -> String {
456        self.password.clone().unwrap_or(String::from("password"))
457    }
458
459    /// Unlike the other getters, this getter will try to open a new socket to find an unused port,
460    /// and then set it as the current port.
461    pub fn get_port_or_set_random(&mut self) -> u16 {
462        let port = self.port.as_ref().copied().unwrap_or_else(get_unused_port);
463
464        self.port = Some(port);
465        port
466    }
467
468    /// Get dbname if set or return default
469    pub fn get_dbname(&self) -> String {
470        self.dbname.clone().unwrap_or(String::from("postgres"))
471    }
472}
473
/// Asks the OS for a currently unused local TCP port.
///
/// Binds an ephemeral listener to `localhost:0` and reads back the port it was given. The
/// listener is closed when it goes out of scope at the end of this function, so postgres can
/// bind the port afterwards. Only the first address `localhost` resolves to (and that binds
/// successfully) is checked.
///
/// TODO: there is a race condition/TOCTOU: another process may take the port between this check
/// and when postgres actually binds to it. Quick rebinding also depends on socket options.
/// NOTE(review): the original comment said std sets SO_REUSEPORT; as far as I know std's
/// `TcpListener::bind` sets SO_REUSEADDR (not SO_REUSEPORT) on non-Windows platforms — confirm.
fn get_unused_port() -> u16 {
    let listener = std::net::TcpListener::bind("localhost:0")
        .expect("failed to bind to local port when getting unused port");
    let local = listener
        .local_addr()
        .expect("failed to get local addr from socket");
    local.port()
}