pgtemp/lib.rs
1#![warn(missing_docs)] // denied in CI
2
3//! pgtemp is a Rust library and cli tool that allows you to easily create temporary PostgreSQL servers for testing without using Docker.
4//!
5//! The pgtemp Rust library allows you to spawn a PostgreSQL server in a temporary directory and get back a full connection URI with the host, port, username, and password.
6//!
7//! The pgtemp cli tool allows you to even more simply make temporary connections, and works with any language: Run pgtemp and then use its connection URI when connecting to the database in your tests. **pgtemp will then spawn a new postgresql process for each connection it receives** and transparently proxy everything over that connection to the temporary database. Note that this means when you make multiple connections in a single test, changes made in one connection will not be visible in the other connections, unless you are using pgtemp's `--single` mode.
8
9use std::collections::HashMap;
10use std::fmt;
11use std::fmt::Debug;
12use std::path::{Path, PathBuf};
13use std::process::Child;
14
15use tempfile::TempDir;
16use tokio::task::spawn_blocking;
17
18mod daemon;
19mod run_db;
20
21pub use daemon::*;
22
23// temp db handle - actual db spawning code is in run_db mod
24
/// A struct representing a handle to a local PostgreSQL server that is currently running. Upon
/// drop or calling `shutdown`, the server is shut down and the directory its data is stored in
/// is deleted. See builder struct [`PgTempDBBuilder`] for options and settings.
pub struct PgTempDB {
    /// user name placed in the connection string/URI
    dbuser: String,
    /// password placed in the connection string/URI
    dbpass: String,
    /// port placed in the connection string/URI
    dbport: u16,
    /// database name placed in the connection string/URI
    dbname: String,
    /// persist the db data directory after shutdown
    persist: bool,
    /// dump the database to a script file at this path during shutdown, while the server is
    /// still running
    dump_path: Option<PathBuf>,
    // See shutdown implementation for why these are options: both are `take`n out of the struct
    // during shutdown so they can be consumed by value.
    temp_dir: Option<TempDir>,
    postgres_process: Option<Child>,
}
41
42impl PgTempDB {
43 /// Start a PgTempDB with the parameters configured from a PgTempDBBuilder
44 pub fn from_builder(mut builder: PgTempDBBuilder) -> PgTempDB {
45 let dbuser = builder.get_user();
46 let dbpass = builder.get_password();
47 let dbport = builder.get_port_or_set_random();
48 let dbname = builder.get_dbname();
49 let persist = builder.persist_data_dir;
50 let dump_path = builder.dump_path.clone();
51 let load_path = builder.load_path.clone();
52
53 let temp_dir = run_db::init_db(&mut builder);
54 let postgres_process = Some(run_db::run_db(&temp_dir, builder));
55 let temp_dir = Some(temp_dir);
56
57 let db = PgTempDB {
58 dbuser,
59 dbpass,
60 dbport,
61 dbname,
62 persist,
63 dump_path,
64 temp_dir,
65 postgres_process,
66 };
67
68 if let Some(path) = load_path {
69 db.load_database(path);
70 }
71 db
72 }
73
74 /// Creates a builder that can be used to configure the details of the temporary PostgreSQL
75 /// server
76 pub fn builder() -> PgTempDBBuilder {
77 PgTempDBBuilder::new()
78 }
79
80 /// Creates a new PgTempDB with default configuration and starts a PostgreSQL server.
81 pub fn new() -> PgTempDB {
82 PgTempDBBuilder::new().start()
83 }
84
85 /// Creates a new PgTempDB with default configuration and starts a PostgreSQL server in an
86 /// async context.
87 pub async fn async_new() -> PgTempDB {
88 PgTempDBBuilder::new().start_async().await
89 }
90
91 /// Use [pg_dump](https://www.postgresql.org/docs/current/backup-dump.html) to dump the
92 /// database to the provided path upon drop or [`Self::shutdown`].
93 pub fn dump_database(&self, path: impl AsRef<Path>) {
94 let path_str = path.as_ref().to_str().unwrap();
95
96 let dump_output = std::process::Command::new("pg_dump")
97 .arg(self.connection_uri())
98 .args(["--file", path_str])
99 .output()
100 .expect("failed to start pg_dump. Is it installed and on your path?");
101
102 if !dump_output.status.success() {
103 let stdout = dump_output.stdout;
104 let stderr = dump_output.stderr;
105 panic!(
106 "pg_dump failed! stdout: {}\n\nstderr: {}",
107 String::from_utf8_lossy(&stdout),
108 String::from_utf8_lossy(&stderr)
109 );
110 }
111 }
112
113 /// Use `psql` to load the database from the provided dump file. See [`Self::dump_database`].
114 pub fn load_database(&self, path: impl AsRef<Path>) {
115 let path_str = path.as_ref().to_str().unwrap();
116
117 let load_output = std::process::Command::new("psql")
118 .arg(self.connection_uri())
119 .args([
120 "--file",
121 path_str,
122 "--set",
123 "ON_ERROR_STOP=1",
124 "--single-transaction",
125 ])
126 .output()
127 .expect("failed to start psql. Is it installed and on your path?");
128
129 if !load_output.status.success() {
130 let stdout = load_output.stdout;
131 let stderr = load_output.stderr;
132 panic!(
133 "psql failed! stdout: {}\n\nstderr: {}",
134 String::from_utf8_lossy(&stdout),
135 String::from_utf8_lossy(&stderr)
136 );
137 }
138 }
139
140 /// Send a signal to the database to shutdown the server, then wait for the process to exit.
141 /// Equivalent to calling drop on this struct.
142 ///
143 /// We send SIGINT to the postgres process to initiate a fast shutdown
144 /// (<https://www.postgresql.org/docs/current/server-shutdown.html>), which causes all transactions to be aborted and
145 /// connections to be terminated.
146 ///
147 /// NOTE: This is currently a blocking function. It sends SIGINT to the postgres server, waits
148 /// for the process to exit, and also does IO to remove the temp directory.
149 ///
150 pub fn shutdown(self) {
151 drop(self);
152 }
153
154 /// See description of [`shutdown`]
155 fn shutdown_internal(&mut self) {
156 // if no process (e.g. due to calling `force_shutdown`), just skip the cleanup operations.
157 if self.postgres_process.is_none() {
158 return;
159 }
160
161 // do the dump while the postgres process is still running
162 if let Some(path) = &self.dump_path {
163 self.dump_database(path);
164 }
165
166 let postgres_process = self
167 .postgres_process
168 .take()
169 .expect("shutdown with no postgres process");
170 let temp_dir = self.temp_dir.take().unwrap();
171
172 // fast (not graceful) shutdown via SIGINT
173 // TODO: graceful shutdown via SIGTERM
174 // was having issues with using graceful shutdown by default and some tests/examples using
175 // pg connection pools - likely what was happening was that at the end of the test we hit
176 // drop for the connection pool, it tries to drop asynchronously (e.g. it probably sends a
177 // async signal), then we block indefinitely on the main thread in PgTempDB::shutdown
178 // waiting for the server to shut down and the pooler never gets a chance to shut down, so
179 // the postgres server says "we're still connected to a client, can't shut down yet" and we
180 // have a deadlock.
181 #[allow(clippy::cast_possible_wrap)]
182 let _ret = unsafe { libc::kill(postgres_process.id() as i32, libc::SIGINT) };
183 let _output = postgres_process
184 .wait_with_output()
185 .expect("postgres server failed to exit cleanly");
186
187 if self.persist {
188 // this prevents the dir from being deleted on drop
189 let _path = temp_dir.into_path();
190 } else {
191 // if we just used the default drop impl, errors would not be surfaced
192 temp_dir.close().expect("failed to clean up temp directory");
193 }
194 }
195
196 /// Returns the path to the data directory being used by this databaset.
197 pub fn data_dir(&self) -> PathBuf {
198 self.temp_dir.as_ref().unwrap().path().join("pg_data_dir")
199 }
200
201 /// Returns the database username used when connecting to the postgres server.
202 pub fn db_user(&self) -> &str {
203 &self.dbuser
204 }
205
206 /// Returns the database password used when connecting to the postgres server.
207 pub fn db_pass(&self) -> &str {
208 &self.dbpass
209 }
210
211 /// Returns the port the postgres server is running on.
212 pub fn db_port(&self) -> u16 {
213 self.dbport
214 }
215
216 /// Returns the the name of the database created.
217 pub fn db_name(&self) -> &str {
218 &self.dbname
219 }
220
221 /// Returns a connection string that can be passed to a libpq connection function.
222 ///
223 /// Example output:
224 /// `host=localhost port=15432 user=pgtemp password=pgtemppw-9485 dbname=pgtempdb-324`
225 pub fn connection_string(&self) -> String {
226 format!(
227 "host=localhost port={} user={} password={} dbname={}",
228 self.db_port(),
229 self.db_user(),
230 self.db_pass(),
231 self.db_name()
232 )
233 }
234
235 /// Returns a generic connection URI that can be passed to most SQL libraries' connect
236 /// methods.
237 ///
238 /// Example output:
239 /// `postgresql://pgtemp:pgtemppw-9485@localhost:15432/pgtempdb-324`
240 pub fn connection_uri(&self) -> String {
241 format!(
242 "postgresql://{}:{}@localhost:{}/{}",
243 self.db_user(),
244 self.db_pass(),
245 self.db_port(),
246 self.db_name()
247 )
248 }
249}
250
251impl Debug for PgTempDB {
252 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253 f.debug_struct("PgTempDB")
254 .field("base directory", self.temp_dir.as_ref().unwrap())
255 .field("connection string", &self.connection_string())
256 .field("persist data dir", &self.persist)
257 .field("dump path", &self.dump_path)
258 .field(
259 "db process",
260 &self.postgres_process.as_ref().map(Child::id).unwrap(),
261 )
262 .finish_non_exhaustive()
263 }
264}
265
266impl Drop for PgTempDB {
267 fn drop(&mut self) {
268 self.shutdown_internal();
269 }
270}
271
272// db config builder functions
273
/// Builder struct for PgTempDB.
///
/// Every `Option` field left as `None` falls back to the default documented on that field.
#[derive(Default, Debug, Clone)]
pub struct PgTempDBBuilder {
    /// The directory in which to store the temporary PostgreSQL data directory.
    pub temp_dir_prefix: Option<PathBuf>,
    /// The cluster superuser created with `initdb`. Default: `postgres`
    pub db_user: Option<String>,
    /// The password for the cluster superuser. Default: `password`
    pub password: Option<String>,
    /// The port the server should run on. Default: random unused port.
    pub port: Option<u16>,
    /// The name of the database to create on startup. Default: `postgres`.
    pub dbname: Option<String>,
    /// Do not delete the data dir when the `PgTempDB` is dropped. Default: `false`.
    pub persist_data_dir: bool,
    /// The path to dump the database to (via `pg_dump`) when the `PgTempDB` is dropped.
    pub dump_path: Option<PathBuf>,
    /// The path to load the database from (via `psql`) when the `PgTempDB` is started.
    pub load_path: Option<PathBuf>,
    /// Other server configuration data to be set in `postgresql.conf` via `initdb -c`
    pub server_configs: HashMap<String, String>,
    /// Direct arguments to pass to the `initdb` binary (e.g. --encoding=UTF8), distinct from postgres configs (-c)
    pub initdb_args: HashMap<String, String>,
    /// Prefix PostgreSQL binary names (`initdb`, `createdb`, and `postgres`) with this path, instead of searching $PATH
    pub bin_path: Option<PathBuf>,
}
300
301impl PgTempDBBuilder {
302 /// Create a new [`PgTempDBBuilder`]
303 pub fn new() -> PgTempDBBuilder {
304 PgTempDBBuilder::default()
305 }
306
307 /// Parses the parameters out of a PostgreSQL connection URI and inserts them into the builder.
308 #[must_use]
309 pub fn from_connection_uri(conn_uri: &str) -> Self {
310 let mut builder = PgTempDBBuilder::new();
311
312 let url = url::Url::parse(conn_uri)
313 .expect(&format!("Could not parse connection URI `{}`", conn_uri));
314
315 // TODO: error types
316 assert!(
317 url.scheme() == "postgresql",
318 "connection URI must start with `postgresql://` scheme: `{}`",
319 conn_uri
320 );
321 assert!(
322 url.host_str() == Some("localhost"),
323 "connection URI's host is not localhost: `{}`",
324 conn_uri,
325 );
326
327 let username = url.username();
328 let password = url.password();
329 let port = url.port();
330 let dbname = url.path().strip_prefix('/').unwrap_or("");
331
332 if !username.is_empty() {
333 builder = builder.with_username(username);
334 }
335 if let Some(password) = password {
336 builder = builder.with_password(password);
337 }
338 if let Some(port) = port {
339 builder = builder.with_port(port);
340 }
341 if !dbname.is_empty() {
342 builder = builder.with_dbname(dbname);
343 }
344
345 builder
346 }
347
348 // TODO: make an error type and `try_start` methods (and maybe similar for above shutdown etc
349 // functions)
350
351 /// Creates the temporary data directory and starts the PostgreSQL server with the configured
352 /// parameters.
353 ///
354 /// If the current user is root, will attempt to run the `initdb` and `postgres` commands as
355 /// the `postgres` user.
356 pub fn start(self) -> PgTempDB {
357 PgTempDB::from_builder(self)
358 }
359
360 /// Convenience function for calling `spawn_blocking(self.start())`
361 pub async fn start_async(self) -> PgTempDB {
362 spawn_blocking(move || self.start())
363 .await
364 .expect("failed to start pgtemp server")
365 }
366
367 /// Set the directory in which to put the (temporary) PostgreSQL data directory. This is not
368 /// the data directory itself: a new temporary directory is created inside this one.
369 #[must_use]
370 pub fn with_data_dir_prefix(mut self, prefix: impl AsRef<Path>) -> Self {
371 self.temp_dir_prefix = Some(PathBuf::from(prefix.as_ref()));
372 self
373 }
374
375 /// Set an arbitrary PostgreSQL server configuration parameter that will passed to the
376 /// postgresql process at runtime.
377 #[must_use]
378 pub fn with_config_param(mut self, key: &str, value: &str) -> Self {
379 let _old = self.server_configs.insert(key.into(), value.into());
380 self
381 }
382
383 /// Set an arbitrary argument that will be passed directly to the initdb binary during database
384 /// initialization. These are direct arguments like --encoding or --locale, not configuration
385 /// parameters that get written to postgresql.conf (use with_config_param for those).
386 #[must_use]
387 pub fn with_initdb_arg(mut self, key: &str, value: &str) -> Self {
388 let _old = self.initdb_args.insert(key.into(), value.into());
389 self
390 }
391
392 /// Set the directory that contains binaries like `initdb`, `createdb`, and `postgres`.
393 #[must_use]
394 pub fn with_bin_path(mut self, path: impl AsRef<Path>) -> Self {
395 self.bin_path = Some(PathBuf::from(path.as_ref()));
396 self
397 }
398
399 #[must_use]
400 /// Set the user name
401 pub fn with_username(mut self, username: &str) -> Self {
402 self.db_user = Some(username.to_string());
403 self
404 }
405
406 #[must_use]
407 /// Set the user password
408 pub fn with_password(mut self, password: &str) -> Self {
409 self.password = Some(password.to_string());
410 self
411 }
412
413 #[must_use]
414 /// Set the port
415 pub fn with_port(mut self, port: u16) -> Self {
416 self.port = Some(port);
417 self
418 }
419
420 #[must_use]
421 /// Set the database name
422 pub fn with_dbname(mut self, dbname: &str) -> Self {
423 self.dbname = Some(dbname.to_string());
424 self
425 }
426
427 /// If set, the postgres data directory will not be deleted when the `PgTempDB` is dropped.
428 #[must_use]
429 pub fn persist_data(mut self, persist: bool) -> Self {
430 self.persist_data_dir = persist;
431 self
432 }
433
434 /// If set, the database will be dumped via the `pg_dump` utility to the given location on drop
435 /// or upon calling [`PgTempDB::shutdown`].
436 #[must_use]
437 pub fn dump_database(mut self, path: &Path) -> Self {
438 self.dump_path = Some(path.into());
439 self
440 }
441
442 /// If set, the database will be loaded via `psql` from the given script on startup.
443 #[must_use]
444 pub fn load_database(mut self, path: &Path) -> Self {
445 self.load_path = Some(path.into());
446 self
447 }
448
449 /// Get user if set or return default
450 pub fn get_user(&self) -> String {
451 self.db_user.clone().unwrap_or(String::from("postgres"))
452 }
453
454 /// Get password if set or return default
455 pub fn get_password(&self) -> String {
456 self.password.clone().unwrap_or(String::from("password"))
457 }
458
459 /// Unlike the other getters, this getter will try to open a new socket to find an unused port,
460 /// and then set it as the current port.
461 pub fn get_port_or_set_random(&mut self) -> u16 {
462 let port = self.port.as_ref().copied().unwrap_or_else(get_unused_port);
463
464 self.port = Some(port);
465 port
466 }
467
468 /// Get dbname if set or return default
469 pub fn get_dbname(&self) -> String {
470 self.dbname.clone().unwrap_or(String::from("postgres"))
471 }
472}
473
/// Binds a listener to port 0 on `localhost` and returns the port the listener reports as its
/// local address. The listener is dropped when this function returns.
fn get_unused_port() -> u16 {
    // TODO: relies on Rust's stdlib setting SO_REUSEPORT by default so that postgres can still
    // bind to the port afterwards. Also there's a race condition/TOCTOU because there's lag
    // between when the port is checked here and when postgres actually tries to bind to it.
    let bind_result = std::net::TcpListener::bind("localhost:0");
    let listener = bind_result.expect("failed to bind to local port when getting unused port");

    let local_addr = listener
        .local_addr()
        .expect("failed to get local addr from socket");
    local_addr.port()
}
483}