pgdb/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    fs, io, net,
5    net::TcpListener,
6    path, process,
7    sync::{
8        atomic::{AtomicUsize, Ordering},
9        Arc, Mutex, Weak,
10    },
11    thread,
12    time::{Duration, Instant},
13};
14
15use process_guard::ProcessGuard;
16use rand::{rngs::OsRng, Rng};
17use thiserror::Error;
18
19/// A database URI keeping a database alive.
20///
21/// Contains the output of [`PostgresClient::uri`] and a reference to the database it points to. As
22/// a result, as long as the [`DbUri`] is alive, the database it points to will also be kept
23/// running.
24#[derive(Debug)]
25pub struct DbUri {
26    /// A reference to the running Postgres instance where this URI points.
27    _arc: Arc<Postgres>,
28    /// The actual URI.
29    uri: String,
30}
31
32impl DbUri {
33    /// Returns the
34    pub fn as_str(&self) -> &str {
35        self.uri.as_str()
36    }
37}
38
39impl AsRef<str> for DbUri {
40    fn as_ref(&self) -> &str {
41        self.as_str()
42    }
43}
44
45/// A convenience function for regular applications.
46///
47/// Some applications just need a clean database instance and can afford to share the underlying
48/// database.
49///
50/// Uses a shared database instance if multiple tests are running at the same time (see [`DbUri`]
51/// for details). The database may be shut down and recreated if the last [`DbUri`] is dropped
52/// during testing, e.g. when parallel tests are not spawned quick enough.
53///
54/// This construction is necessary because `static` variables will not have `Drop` called on them,
55/// without this construction, the spawned Postgres server would not be stopped.
56pub fn db_fixture() -> DbUri {
57    static DB: Mutex<Weak<Postgres>> = Mutex::new(Weak::new());
58
59    static FIXTURE_COUNT: AtomicUsize = AtomicUsize::new(1);
60
61    let pg = {
62        let mut guard = DB.lock().expect("lock poisoned");
63        if let Some(arc) = guard.upgrade() {
64            // We still have an instance we can reuse.
65            arc
66        } else {
67            let arc = Arc::new(
68                Postgres::build()
69                    .start()
70                    .expect("failed to start global postgres DB"),
71            );
72            *guard = Arc::downgrade(&arc);
73            arc
74        }
75    };
76
77    let count = FIXTURE_COUNT.fetch_add(1, Ordering::Relaxed);
78    let db_name = format!("fixture_db_{}", count);
79    let db_user = format!("fixture_user_{}", count);
80    let db_pw = format!("fixture_pass_{}", count);
81    pg.as_superuser()
82        .create_user(&db_user, &db_pw)
83        .expect("failed to create user for fixture DB");
84    pg.as_superuser()
85        .create_database(&db_name, &db_user)
86        .expect("failed to create database for fixture DB");
87    let uri = pg.as_user(&db_user, &db_pw).uri(&db_name);
88    DbUri { _arc: pg, uri }
89}
90
/// Asks the operating system for a currently free TCP port on the loopback interface.
///
/// A listener is bound to `127.0.0.1:0`, the port chosen by the OS is read back, and the
/// listener is released again before returning.
///
/// This is inherently racy: once the port has been released, nothing prevents the OS (or another
/// process) from handing it out again. Unfortunately it is the best we can do, because Postgres
/// does not accept `0` as its port number.
fn find_unused_port() -> io::Result<u16> {
    TcpListener::bind("127.0.0.1:0")
        .and_then(|listener| listener.local_addr())
        .map(|addr| addr.port())
}
101
/// A wrapped postgres instance.
///
/// Contains a handle to a running Postgres process. Once dropped, the instance will be shut down
/// and the temporary directory containing all of its data removed.
///
/// Usually created via [`PostgresBuilder::start`]; use [`Postgres::as_superuser`] or
/// [`Postgres::as_user`] to obtain a client for it.
// NOTE: Rust drops struct fields in declaration order. `instance` is declared before `tmp_dir`,
// so the server process is dropped before the temporary directory (which holds its Unix socket
// and, unless a custom data dir was configured, its data) is deleted. Keep this order when adding
// or moving fields.
#[derive(Debug)]
pub struct Postgres {
    /// Host address of the instance; used for `psql -h` and in connection URIs.
    host: String,
    /// Port the instance is running on.
    port: u16,
    /// Instance of the postgres process.
    #[allow(dead_code)] // Only used for its `Drop` implementation.
    instance: ProcessGuard,
    /// Path to the `psql` binary.
    psql_binary: path::PathBuf,
    /// Superuser name.
    superuser: String,
    /// Superuser's password.
    superuser_pw: String,
    /// Directory holding all the temporary data.
    #[allow(dead_code)] // Only used for its `Drop` implementation.
    tmp_dir: tempfile::TempDir,
}
125
126/// A virtual client for a running postgres.
127///
128/// Contains credentials and enough information to connect to its parent instance.
129#[derive(Debug)]
130pub struct PostgresClient<'a> {
131    instance: &'a Postgres,
132    /// Superuser name.
133    username: String,
134    /// Superuser password.
135    password: String,
136}
137
/// Builder for a postgres instance.
///
/// Usually constructed via [`Postgres::build`], which also supplies the default value of every
/// field. Adjust the settings with the setter methods, then call [`PostgresBuilder::start`] to
/// launch the server.
#[derive(Debug)]
pub struct PostgresBuilder {
    /// Data directory.
    ///
    /// If not set, a `db` subdirectory of the instance's temporary directory is used.
    data_dir: Option<path::PathBuf>,
    /// Listening port.
    ///
    /// If not set, [`find_unused_port`] will be used to determine the port.
    port: Option<u16>,
    /// Host address of the server.
    ///
    /// Used to probe for startup and stored in the resulting [`Postgres`], which uses it for
    /// `psql` connections and connection URIs.
    host: String,
    /// Name of the superuser, passed to `initdb` as `--username`.
    superuser: String,
    /// Password for the superuser, handed to `initdb` through a temporary password file.
    superuser_pw: String,
    /// Path to `postgres` binary. If `None`, it is looked up with `which::which` on start.
    postgres_binary: Option<path::PathBuf>,
    /// Path to `initdb` binary. If `None`, it is looked up with `which::which` on start.
    initdb_binary: Option<path::PathBuf>,
    /// Path to `psql` binary. If `None`, it is looked up with `which::which` on start.
    psql_binary: Option<path::PathBuf>,
    /// How long to wait between startup probe attempts.
    probe_delay: Duration,
    /// Time until giving up waiting for startup.
    startup_timeout: Duration,
}
166
167/// A Postgres server error.
168#[derive(Error, Debug)]
169pub enum Error {
170    #[error("could not find `postgres` binary")]
171    FindPostgres(which::Error),
172    /// Failed to find the `initdb` binary.
173    #[error("could not find `initdb` binary")]
174    FindInitdb(which::Error),
175    /// Failed to find the `postgres` binary.
176    #[error("could not find `psql` binary")]
177    FindPsql(which::Error),
178    /// Could not create the temporary directory.
179    #[error("could not create temporary directory for database")]
180    CreateDatabaseDir(io::Error),
181    /// Could not write the temporary password to a file.
182    #[error("error writing temporary password")]
183    WriteTemporaryPw(io::Error),
184    /// Starting `initdb` failed.
185    #[error("failed to run `initdb`")]
186    RunInitDb(io::Error),
187    /// Running `initdb` was not successful.
188    #[error("`initdb` exited with status {}", 0)]
189    InitDbFailed(process::ExitStatus),
190    /// Postgres could not be launched.
191    #[error("failed to launch `postgres`")]
192    LaunchPostgres(io::Error),
193    /// Postgres was launched but failed to bring up a TCP-connection accepting socket in time.
194    #[error("timeout probing tcp socket")]
195    StartupTimeout,
196    /// `psql` could not be launched.
197    #[error("failed to run `psql`")]
198    RunPsql(io::Error),
199    /// Running `psql` returned an error.
200    #[error("`psql` exited with status {}", 0)]
201    PsqlFailed(process::ExitStatus),
202}
203
204impl Postgres {
205    /// Creates a new Postgres database builder.
206    #[inline]
207    pub fn build() -> PostgresBuilder {
208        PostgresBuilder {
209            data_dir: None,
210            port: None,
211            host: "127.0.0.1".to_string(),
212            superuser: "postgres".to_string(),
213            superuser_pw: generate_random_string(),
214            postgres_binary: None,
215            initdb_binary: None,
216            psql_binary: None,
217            probe_delay: Duration::from_millis(100),
218            startup_timeout: Duration::from_secs(10),
219        }
220    }
221
222    /// Returns a postgres client with superuser credentials.
223    #[inline]
224    pub fn as_superuser(&self) -> PostgresClient<'_> {
225        self.as_user(&self.superuser, &self.superuser_pw)
226    }
227
228    /// Returns a postgres client that uses the given credentials.
229    #[inline]
230    pub fn as_user(&self, username: &str, password: &str) -> PostgresClient<'_> {
231        PostgresClient {
232            instance: self,
233            username: username.to_string(),
234            password: password.to_string(),
235        }
236    }
237
238    /// Returns the hostname the Postgres database can be connected to at.
239    #[inline]
240    pub fn host(&self) -> &str {
241        self.host.as_str()
242    }
243
244    /// Returns the port the Postgres database is bound to.
245    #[inline]
246    pub fn port(&self) -> u16 {
247        self.port
248    }
249}
250
251impl<'a> PostgresClient<'a> {
252    /// Runs a `psql` command against the database.
253    ///
254    /// Creates a command that runs `psql -h (host) -p (port) -U (username) -d (database)` with
255    /// `PGPASSWORD` set.
256    pub fn psql(&self, database: &str) -> process::Command {
257        let mut cmd = process::Command::new(&self.instance.psql_binary);
258
259        cmd.arg("-h")
260            .arg(&self.instance.host)
261            .arg("-p")
262            .arg(self.instance.port.to_string())
263            .arg("-U")
264            .arg(&self.username)
265            .arg("-d")
266            .arg(database)
267            .env("PGPASSWORD", &self.password);
268
269        cmd
270    }
271
272    /// Runs the given SQL commands from an input file via `psql`.
273    pub fn load_sql<P: AsRef<path::Path>>(&self, database: &str, filename: P) -> Result<(), Error> {
274        let status = self
275            .psql(database)
276            .arg("-f")
277            .arg(filename.as_ref())
278            .status()
279            .map_err(Error::RunPsql)?;
280
281        if !status.success() {
282            return Err(Error::PsqlFailed(status));
283        }
284
285        Ok(())
286    }
287
288    /// Runs the given SQL command through `psql`.
289    pub fn run_sql(&self, database: &str, sql: &str) -> Result<(), Error> {
290        let status = self
291            .psql(database)
292            .arg("-c")
293            .arg(sql)
294            .status()
295            .map_err(Error::RunPsql)?;
296
297        if !status.success() {
298            return Err(Error::PsqlFailed(status));
299        }
300
301        Ok(())
302    }
303
304    /// Creates a new database with the given owner.
305    ///
306    /// This typically requires superuser credentials, see [`Postgres::as_superuser`].
307    #[inline]
308    pub fn create_database(&self, database: &str, owner: &str) -> Result<(), Error> {
309        self.run_sql(
310            "postgres",
311            &format!(
312                "CREATE DATABASE {} OWNER {};",
313                escape_ident(database),
314                escape_ident(owner)
315            ),
316        )
317    }
318
319    /// Creates a new user on the system that is allowed to login.
320    ///
321    /// This typically requires superuser credentials, see [`Postgres::as_superuser`].
322    #[inline]
323    pub fn create_user(&self, username: &str, password: &str) -> Result<(), Error> {
324        self.run_sql(
325            "postgres",
326            &format!(
327                "CREATE ROLE {} LOGIN ENCRYPTED PASSWORD {};",
328                escape_ident(username),
329                escape_string(password)
330            ),
331        )
332    }
333
334    /// Returns the `Postgres` instance associated with this client.
335    #[inline]
336    pub fn instance(&self) -> &Postgres {
337        self.instance
338    }
339
340    /// Returns the username used by this client.
341    pub fn username(&self) -> &str {
342        self.username.as_str()
343    }
344
345    /// Returns a libpq-style connection URI.
346    pub fn uri(&self, database: &str) -> String {
347        format!(
348            "postgres://{}:{}@{}:{}/{}",
349            self.username,
350            self.password,
351            self.instance.host(),
352            self.instance.port(),
353            database
354        )
355    }
356
357    /// Returns the password used by this client.
358    #[inline]
359    pub fn password(&self) -> &str {
360        self.password.as_str()
361    }
362}
363
364impl PostgresBuilder {
365    /// Sets the postgres data directory.
366    ///
367    /// If not set, a temporary directory will be used.
368    #[inline]
369    pub fn data_dir<T: Into<path::PathBuf>>(&mut self, data_dir: T) -> &mut Self {
370        self.data_dir = Some(data_dir.into());
371        self
372    }
373
374    /// Sets the location of the `initdb` binary.
375    #[inline]
376    pub fn initdb_binary<T: Into<path::PathBuf>>(&mut self, initdb_binary: T) -> &mut Self {
377        self.initdb_binary = Some(initdb_binary.into());
378        self
379    }
380
381    /// Sets the bind address.
382    #[inline]
383    pub fn host(&mut self, host: String) -> &mut Self {
384        self.host = host;
385        self
386    }
387
388    /// Sets listening port.
389    ///
390    /// If no port is set, the builder will attempt to find an unused through binding port `0`. This
391    /// is somewhat racing, but the only recourse, since Postgres does not support binding to port
392    /// `0`.
393    #[inline]
394    pub fn port(&mut self, port: u16) -> &mut Self {
395        self.port = Some(port);
396        self
397    }
398
399    /// Sets the location of the `postgres` binary.
400    #[inline]
401    pub fn postgres_binary<T: Into<path::PathBuf>>(&mut self, postgres_binary: T) -> &mut Self {
402        self.postgres_binary = Some(postgres_binary.into());
403        self
404    }
405
406    /// Sets the startup probe delay.
407    ///
408    /// Between two startup probes, waits this long.
409    #[inline]
410    pub fn probe_delay(&mut self, probe_delay: Duration) -> &mut Self {
411        self.probe_delay = probe_delay;
412        self
413    }
414
415    /// Sets the location of the `psql` binary.
416    #[inline]
417    pub fn psql_binary<T: Into<path::PathBuf>>(&mut self, psql_binary: T) -> &mut Self {
418        self.psql_binary = Some(psql_binary.into());
419        self
420    }
421
422    /// Sets the maximum time to probe for startup.
423    #[inline]
424    pub fn startup_timeout(&mut self, startup_timeout: Duration) -> &mut Self {
425        self.startup_timeout = startup_timeout;
426        self
427    }
428
429    /// Sets the password for the superuser.
430    #[inline]
431    pub fn superuser_pw<T: Into<String>>(&mut self, superuser_pw: T) -> &mut Self {
432        self.superuser_pw = superuser_pw.into();
433        self
434    }
435
436    /// Starts the Postgres server.
437    ///
438    /// Postgres will start using a newly created temporary directory as its data dir. The function
439    /// will only return once a TCP connection to postgres has been made successfully.
440    pub fn start(&self) -> Result<Postgres, Error> {
441        let port = self
442            .port
443            .unwrap_or_else(|| find_unused_port().expect("failed to find an unused port"));
444
445        let postgres_binary = self
446            .postgres_binary
447            .clone()
448            .map(Ok)
449            .unwrap_or_else(|| which::which("postgres").map_err(Error::FindPostgres))?;
450        let initdb_binary = self
451            .initdb_binary
452            .clone()
453            .map(Ok)
454            .unwrap_or_else(|| which::which("initdb").map_err(Error::FindInitdb))?;
455        let psql_binary = self
456            .psql_binary
457            .clone()
458            .map(Ok)
459            .unwrap_or_else(|| which::which("psql").map_err(Error::FindPsql))?;
460
461        let tmp_dir = tempfile::tempdir().map_err(Error::CreateDatabaseDir)?;
462        let data_dir = self
463            .data_dir
464            .clone()
465            .unwrap_or_else(|| tmp_dir.path().join("db"));
466
467        let superuser_pw_file = tmp_dir.path().join("superuser-pw");
468        fs::write(&superuser_pw_file, self.superuser_pw.as_bytes())
469            .map_err(Error::WriteTemporaryPw)?;
470
471        let initdb_status = process::Command::new(initdb_binary)
472            .args([
473                // No default locale (== 'C').
474                "--no-locale",
475                // Require a password for all users.
476                "--auth=md5",
477                // Set default encoding to UTF8.
478                "--encoding=UTF8",
479                // Do not sync data, which is fine for tests.
480                "--nosync",
481                // Path to data directory.
482                "--pgdata",
483            ])
484            .arg(&data_dir)
485            .arg("--pwfile")
486            .arg(&superuser_pw_file)
487            .arg("--username")
488            .arg(&self.superuser)
489            .status()
490            .map_err(Error::RunInitDb)?;
491
492        if !initdb_status.success() {
493            return Err(Error::InitDbFailed(initdb_status));
494        }
495
496        // Start the database.
497        let mut postgres_command = process::Command::new(postgres_binary);
498        postgres_command
499            .arg("-D")
500            .arg(&data_dir)
501            .arg("-p")
502            .arg(port.to_string())
503            .arg("-k")
504            .arg(tmp_dir.path());
505
506        let instance = ProcessGuard::spawn_graceful(&mut postgres_command, Duration::from_secs(5))
507            .map_err(Error::LaunchPostgres)?;
508
509        // Wait for the server to come up.
510        let socket_addr = format!("{}:{}", self.host, port);
511        let started = Instant::now();
512        loop {
513            match net::TcpStream::connect(socket_addr.as_str()) {
514                Ok(_) => break,
515                Err(_) => {
516                    let now = Instant::now();
517
518                    if now.duration_since(started) >= self.startup_timeout {
519                        return Err(Error::StartupTimeout);
520                    }
521
522                    thread::sleep(self.probe_delay);
523                }
524            }
525        }
526
527        Ok(Postgres {
528            host: self.host.clone(),
529            port,
530            instance,
531            psql_binary,
532            superuser: self.superuser.clone(),
533            superuser_pw: self.superuser_pw.clone(),
534            tmp_dir,
535        })
536    }
537}
538
539/// Generates a random hex string string 32 characters long.
540fn generate_random_string() -> String {
541    let raw: [u8; 16] = OsRng.gen();
542    format!("{:x}", hex_fmt::HexFmt(&raw))
543}
544
/// Wraps `unescaped` in `quote_char`, doubling every occurrence of `quote_char` inside it.
///
/// Doubling the quote character is how SQL escapes it, both in quoted identifiers (`"`) and in
/// string literals (`'`). For example, quoting `a"b` with `"` yields `"a""b"`.
fn quote(quote_char: char, unescaped: &str) -> String {
    let doubled = quote_char.to_string().repeat(2);
    let escaped = unescaped.replace(quote_char, &doubled);
    format!("{quote_char}{escaped}{quote_char}")
}

/// Quotes an SQL identifier (such as a role or database name) using double quotes.
fn escape_ident(unescaped: &str) -> String {
    quote('"', unescaped)
}

/// Quotes an SQL string literal using single quotes.
///
/// NOTE(review): backslashes are not escaped, which assumes `standard_conforming_strings` is on
/// (the Postgres default) — confirm if that setting can differ.
fn escape_string(unescaped: &str) -> String {
    quote('\'', unescaped)
}
573
#[cfg(test)]
mod tests {
    use super::Postgres;

    #[test]
    fn can_change_superuser_pw() {
        let pg = Postgres::build()
            .superuser_pw("helloworld")
            .start()
            .expect("could not build postgres database");

        let su = pg.as_superuser();
        su.create_user("foo", "bar")
            .expect("could not create normal user");

        // Command executed successfully, check we used the right password.
        assert_eq!(su.password, "helloworld");
    }

    #[test]
    fn instances_use_different_port_by_default() {
        let a = Postgres::build()
            .start()
            .expect("could not build postgres database");
        let b = Postgres::build()
            .start()
            .expect("could not build postgres database");
        let c = Postgres::build()
            .start()
            .expect("could not build postgres database");

        assert_ne!(a.port(), b.port());
        assert_ne!(a.port(), c.port());
        assert_ne!(b.port(), c.port());
    }

    #[test]
    fn ensure_proper_db_reuse_when_using_fixtures() {
        // The fixture counter is global and starts at 1, so this relies on being the only test
        // that calls `db_fixture`.
        let db_uri = crate::db_fixture();
        assert!(
            db_uri
                .as_str()
                .starts_with("postgres://fixture_user_1:fixture_pass_1@127.0.0.1:"),
            "unexpected fixture URI: {}",
            db_uri.as_str()
        );
        assert!(db_uri.as_str().ends_with("/fixture_db_1"));

        // Calling `db_fixture` multiple times reuses the postgres process, but creates a fresh
        // database instance and role.
        let db_uri2 = crate::db_fixture();
        assert!(
            db_uri2
                .as_str()
                .starts_with("postgres://fixture_user_2:fixture_pass_2@127.0.0.1:"),
            "unexpected fixture URI: {}",
            db_uri2.as_str()
        );
        assert!(db_uri2.as_str().ends_with("/fixture_db_2"));

        // Reuse means both URIs point at the same server, i.e. the same host and port.
        assert_eq!(server_of(db_uri.as_str()), server_of(db_uri2.as_str()));
    }

    /// Returns the `host:port` part of a `postgres://user:password@host:port/database` URI.
    fn server_of(uri: &str) -> &str {
        let (_, rest) = uri.rsplit_once('@').expect("URI without credentials");
        let (server, _) = rest.split_once('/').expect("URI without database");
        server
    }
}