1#![doc = include_str!("../README.md")]
2
3use std::{
4 fs, io, net,
5 net::TcpListener,
6 path, process,
7 sync::{
8 atomic::{AtomicUsize, Ordering},
9 Arc, Mutex, Weak,
10 },
11 thread,
12 time::{Duration, Instant},
13};
14
15use process_guard::ProcessGuard;
16use rand::{rngs::OsRng, Rng};
17use thiserror::Error;
18
19#[derive(Debug)]
25pub struct DbUri {
26 _arc: Arc<Postgres>,
28 uri: String,
30}
31
32impl DbUri {
33 pub fn as_str(&self) -> &str {
35 self.uri.as_str()
36 }
37}
38
39impl AsRef<str> for DbUri {
40 fn as_ref(&self) -> &str {
41 self.as_str()
42 }
43}
44
45pub fn db_fixture() -> DbUri {
57 static DB: Mutex<Weak<Postgres>> = Mutex::new(Weak::new());
58
59 static FIXTURE_COUNT: AtomicUsize = AtomicUsize::new(1);
60
61 let pg = {
62 let mut guard = DB.lock().expect("lock poisoned");
63 if let Some(arc) = guard.upgrade() {
64 arc
66 } else {
67 let arc = Arc::new(
68 Postgres::build()
69 .start()
70 .expect("failed to start global postgres DB"),
71 );
72 *guard = Arc::downgrade(&arc);
73 arc
74 }
75 };
76
77 let count = FIXTURE_COUNT.fetch_add(1, Ordering::Relaxed);
78 let db_name = format!("fixture_db_{}", count);
79 let db_user = format!("fixture_user_{}", count);
80 let db_pw = format!("fixture_pass_{}", count);
81 pg.as_superuser()
82 .create_user(&db_user, &db_pw)
83 .expect("failed to create user for fixture DB");
84 pg.as_superuser()
85 .create_database(&db_name, &db_user)
86 .expect("failed to create database for fixture DB");
87 let uri = pg.as_user(&db_user, &db_pw).uri(&db_name);
88 DbUri { _arc: pg, uri }
89}
90
/// Asks the OS for a currently free TCP port on `127.0.0.1`.
///
/// The temporary listener bound to port 0 is dropped before this returns, which frees the
/// port again. Another process may grab it before the caller binds it, so the result is
/// only a best-effort guess.
fn find_unused_port() -> io::Result<u16> {
    TcpListener::bind("127.0.0.1:0")?
        .local_addr()
        .map(|addr| addr.port())
}
101
102#[derive(Debug)]
107pub struct Postgres {
108 host: String,
110 port: u16,
112 #[allow(dead_code)] instance: ProcessGuard,
115 psql_binary: path::PathBuf,
117 superuser: String,
119 superuser_pw: String,
121 #[allow(dead_code)] tmp_dir: tempfile::TempDir,
124}
125
126#[derive(Debug)]
130pub struct PostgresClient<'a> {
131 instance: &'a Postgres,
132 username: String,
134 password: String,
136}
137
/// Configuration for starting a [`Postgres`] instance; obtain one with [`Postgres::build`].
#[derive(Debug)]
pub struct PostgresBuilder {
    /// Data directory for `initdb`/`postgres`; `None` means a `db` dir inside a temp dir.
    data_dir: Option<path::PathBuf>,
    /// Fixed TCP port; `None` means an unused port is picked at start time.
    port: Option<u16>,
    /// Host used to probe for readiness and to build connection URIs.
    host: String,
    /// Name of the superuser role passed to `initdb --username`.
    superuser: String,
    /// Password of the superuser role, handed to `initdb` via a password file.
    superuser_pw: String,
    /// Explicit `postgres` binary; `None` means look it up on `PATH`.
    postgres_binary: Option<path::PathBuf>,
    /// Explicit `initdb` binary; `None` means look it up on `PATH`.
    initdb_binary: Option<path::PathBuf>,
    /// Explicit `psql` binary; `None` means look it up on `PATH`.
    psql_binary: Option<path::PathBuf>,
    /// Pause between readiness probes.
    probe_delay: Duration,
    /// How long to keep probing before giving up.
    startup_timeout: Duration,
}
166
167#[derive(Error, Debug)]
169pub enum Error {
170 #[error("could not find `postgres` binary")]
171 FindPostgres(which::Error),
172 #[error("could not find `initdb` binary")]
174 FindInitdb(which::Error),
175 #[error("could not find `psql` binary")]
177 FindPsql(which::Error),
178 #[error("could not create temporary directory for database")]
180 CreateDatabaseDir(io::Error),
181 #[error("error writing temporary password")]
183 WriteTemporaryPw(io::Error),
184 #[error("failed to run `initdb`")]
186 RunInitDb(io::Error),
187 #[error("`initdb` exited with status {}", 0)]
189 InitDbFailed(process::ExitStatus),
190 #[error("failed to launch `postgres`")]
192 LaunchPostgres(io::Error),
193 #[error("timeout probing tcp socket")]
195 StartupTimeout,
196 #[error("failed to run `psql`")]
198 RunPsql(io::Error),
199 #[error("`psql` exited with status {}", 0)]
201 PsqlFailed(process::ExitStatus),
202}
203
204impl Postgres {
205 #[inline]
207 pub fn build() -> PostgresBuilder {
208 PostgresBuilder {
209 data_dir: None,
210 port: None,
211 host: "127.0.0.1".to_string(),
212 superuser: "postgres".to_string(),
213 superuser_pw: generate_random_string(),
214 postgres_binary: None,
215 initdb_binary: None,
216 psql_binary: None,
217 probe_delay: Duration::from_millis(100),
218 startup_timeout: Duration::from_secs(10),
219 }
220 }
221
222 #[inline]
224 pub fn as_superuser(&self) -> PostgresClient<'_> {
225 self.as_user(&self.superuser, &self.superuser_pw)
226 }
227
228 #[inline]
230 pub fn as_user(&self, username: &str, password: &str) -> PostgresClient<'_> {
231 PostgresClient {
232 instance: self,
233 username: username.to_string(),
234 password: password.to_string(),
235 }
236 }
237
238 #[inline]
240 pub fn host(&self) -> &str {
241 self.host.as_str()
242 }
243
244 #[inline]
246 pub fn port(&self) -> u16 {
247 self.port
248 }
249}
250
251impl<'a> PostgresClient<'a> {
252 pub fn psql(&self, database: &str) -> process::Command {
257 let mut cmd = process::Command::new(&self.instance.psql_binary);
258
259 cmd.arg("-h")
260 .arg(&self.instance.host)
261 .arg("-p")
262 .arg(self.instance.port.to_string())
263 .arg("-U")
264 .arg(&self.username)
265 .arg("-d")
266 .arg(database)
267 .env("PGPASSWORD", &self.password);
268
269 cmd
270 }
271
272 pub fn load_sql<P: AsRef<path::Path>>(&self, database: &str, filename: P) -> Result<(), Error> {
274 let status = self
275 .psql(database)
276 .arg("-f")
277 .arg(filename.as_ref())
278 .status()
279 .map_err(Error::RunPsql)?;
280
281 if !status.success() {
282 return Err(Error::PsqlFailed(status));
283 }
284
285 Ok(())
286 }
287
288 pub fn run_sql(&self, database: &str, sql: &str) -> Result<(), Error> {
290 let status = self
291 .psql(database)
292 .arg("-c")
293 .arg(sql)
294 .status()
295 .map_err(Error::RunPsql)?;
296
297 if !status.success() {
298 return Err(Error::PsqlFailed(status));
299 }
300
301 Ok(())
302 }
303
304 #[inline]
308 pub fn create_database(&self, database: &str, owner: &str) -> Result<(), Error> {
309 self.run_sql(
310 "postgres",
311 &format!(
312 "CREATE DATABASE {} OWNER {};",
313 escape_ident(database),
314 escape_ident(owner)
315 ),
316 )
317 }
318
319 #[inline]
323 pub fn create_user(&self, username: &str, password: &str) -> Result<(), Error> {
324 self.run_sql(
325 "postgres",
326 &format!(
327 "CREATE ROLE {} LOGIN ENCRYPTED PASSWORD {};",
328 escape_ident(username),
329 escape_string(password)
330 ),
331 )
332 }
333
334 #[inline]
336 pub fn instance(&self) -> &Postgres {
337 self.instance
338 }
339
340 pub fn username(&self) -> &str {
342 self.username.as_str()
343 }
344
345 pub fn uri(&self, database: &str) -> String {
347 format!(
348 "postgres://{}:{}@{}:{}/{}",
349 self.username,
350 self.password,
351 self.instance.host(),
352 self.instance.port(),
353 database
354 )
355 }
356
357 #[inline]
359 pub fn password(&self) -> &str {
360 self.password.as_str()
361 }
362}
363
364impl PostgresBuilder {
365 #[inline]
369 pub fn data_dir<T: Into<path::PathBuf>>(&mut self, data_dir: T) -> &mut Self {
370 self.data_dir = Some(data_dir.into());
371 self
372 }
373
374 #[inline]
376 pub fn initdb_binary<T: Into<path::PathBuf>>(&mut self, initdb_binary: T) -> &mut Self {
377 self.initdb_binary = Some(initdb_binary.into());
378 self
379 }
380
381 #[inline]
383 pub fn host(&mut self, host: String) -> &mut Self {
384 self.host = host;
385 self
386 }
387
388 #[inline]
394 pub fn port(&mut self, port: u16) -> &mut Self {
395 self.port = Some(port);
396 self
397 }
398
399 #[inline]
401 pub fn postgres_binary<T: Into<path::PathBuf>>(&mut self, postgres_binary: T) -> &mut Self {
402 self.postgres_binary = Some(postgres_binary.into());
403 self
404 }
405
406 #[inline]
410 pub fn probe_delay(&mut self, probe_delay: Duration) -> &mut Self {
411 self.probe_delay = probe_delay;
412 self
413 }
414
415 #[inline]
417 pub fn psql_binary<T: Into<path::PathBuf>>(&mut self, psql_binary: T) -> &mut Self {
418 self.psql_binary = Some(psql_binary.into());
419 self
420 }
421
422 #[inline]
424 pub fn startup_timeout(&mut self, startup_timeout: Duration) -> &mut Self {
425 self.startup_timeout = startup_timeout;
426 self
427 }
428
429 #[inline]
431 pub fn superuser_pw<T: Into<String>>(&mut self, superuser_pw: T) -> &mut Self {
432 self.superuser_pw = superuser_pw.into();
433 self
434 }
435
436 pub fn start(&self) -> Result<Postgres, Error> {
441 let port = self
442 .port
443 .unwrap_or_else(|| find_unused_port().expect("failed to find an unused port"));
444
445 let postgres_binary = self
446 .postgres_binary
447 .clone()
448 .map(Ok)
449 .unwrap_or_else(|| which::which("postgres").map_err(Error::FindPostgres))?;
450 let initdb_binary = self
451 .initdb_binary
452 .clone()
453 .map(Ok)
454 .unwrap_or_else(|| which::which("initdb").map_err(Error::FindInitdb))?;
455 let psql_binary = self
456 .psql_binary
457 .clone()
458 .map(Ok)
459 .unwrap_or_else(|| which::which("psql").map_err(Error::FindPsql))?;
460
461 let tmp_dir = tempfile::tempdir().map_err(Error::CreateDatabaseDir)?;
462 let data_dir = self
463 .data_dir
464 .clone()
465 .unwrap_or_else(|| tmp_dir.path().join("db"));
466
467 let superuser_pw_file = tmp_dir.path().join("superuser-pw");
468 fs::write(&superuser_pw_file, self.superuser_pw.as_bytes())
469 .map_err(Error::WriteTemporaryPw)?;
470
471 let initdb_status = process::Command::new(initdb_binary)
472 .args([
473 "--no-locale",
475 "--auth=md5",
477 "--encoding=UTF8",
479 "--nosync",
481 "--pgdata",
483 ])
484 .arg(&data_dir)
485 .arg("--pwfile")
486 .arg(&superuser_pw_file)
487 .arg("--username")
488 .arg(&self.superuser)
489 .status()
490 .map_err(Error::RunInitDb)?;
491
492 if !initdb_status.success() {
493 return Err(Error::InitDbFailed(initdb_status));
494 }
495
496 let mut postgres_command = process::Command::new(postgres_binary);
498 postgres_command
499 .arg("-D")
500 .arg(&data_dir)
501 .arg("-p")
502 .arg(port.to_string())
503 .arg("-k")
504 .arg(tmp_dir.path());
505
506 let instance = ProcessGuard::spawn_graceful(&mut postgres_command, Duration::from_secs(5))
507 .map_err(Error::LaunchPostgres)?;
508
509 let socket_addr = format!("{}:{}", self.host, port);
511 let started = Instant::now();
512 loop {
513 match net::TcpStream::connect(socket_addr.as_str()) {
514 Ok(_) => break,
515 Err(_) => {
516 let now = Instant::now();
517
518 if now.duration_since(started) >= self.startup_timeout {
519 return Err(Error::StartupTimeout);
520 }
521
522 thread::sleep(self.probe_delay);
523 }
524 }
525 }
526
527 Ok(Postgres {
528 host: self.host.clone(),
529 port,
530 instance,
531 psql_binary,
532 superuser: self.superuser.clone(),
533 superuser_pw: self.superuser_pw.clone(),
534 tmp_dir,
535 })
536 }
537}
538
539fn generate_random_string() -> String {
541 let raw: [u8; 16] = OsRng.gen();
542 format!("{:x}", hex_fmt::HexFmt(&raw))
543}
544
/// Wraps `unescaped` in `quote_char` and doubles every embedded `quote_char`.
///
/// This is SQL's quoting convention, e.g. `'it''s'` for string literals and `"a""b"` for
/// identifiers. No other characters are escaped.
fn quote(quote_char: char, unescaped: &str) -> String {
    let doubled: String = [quote_char, quote_char].iter().collect();
    let escaped = unescaped.replace(quote_char, &doubled);
    format!("{}{}{}", quote_char, escaped, quote_char)
}
563
564fn escape_ident(unescaped: &str) -> String {
566 quote('"', unescaped)
567}
568
569fn escape_string(unescaped: &str) -> String {
571 quote('\'', unescaped)
572}
573
#[cfg(test)]
mod tests {
    use super::Postgres;

    /// A custom superuser password set on the builder is the one the superuser client uses.
    /// `create_user` logs in as the superuser, so its success also shows the server accepts
    /// that password.
    #[test]
    fn can_change_superuser_pw() {
        let pg = Postgres::build()
            .superuser_pw("helloworld")
            .start()
            .expect("could not build postgres database");

        let su = pg.as_superuser();
        su.create_user("foo", "bar")
            .expect("could not create normal user");

        assert_eq!(su.password, "helloworld");
    }

    /// Without an explicit port, each instance picks its own unused port.
    #[test]
    fn instances_use_different_port_by_default() {
        let a = Postgres::build()
            .start()
            .expect("could not build postgres database");
        let b = Postgres::build()
            .start()
            .expect("could not build postgres database");
        let c = Postgres::build()
            .start()
            .expect("could not build postgres database");

        assert_ne!(a.port(), b.port());
        assert_ne!(a.port(), c.port());
        assert_ne!(b.port(), c.port());
    }

    /// Fixtures alive at the same time get distinct users and databases on one shared server.
    ///
    /// The fixture counter starts at 1, and this is the only test in this module that calls
    /// `db_fixture`, so the numbering below is deterministic.
    #[test]
    fn ensure_proper_db_reuse_when_using_fixtures() {
        /// Extracts the `host:port` part of a `postgres://user:pw@host:port/db` URI.
        fn server_of(uri: &str) -> &str {
            let (_, after_credentials) = uri
                .split_once('@')
                .expect("fixture URI contains credentials");
            let (server, _) = after_credentials
                .split_once('/')
                .expect("fixture URI contains a database path");
            server
        }

        let db_uri = crate::db_fixture();
        assert_eq!(
            &db_uri.as_str()[..51],
            "postgres://fixture_user_1:fixture_pass_1@127.0.0.1:"
        );

        // `db_uri` is still alive, so this call must reuse the same server.
        let db_uri2 = crate::db_fixture();
        assert_eq!(
            &db_uri2.as_str()[..51],
            "postgres://fixture_user_2:fixture_pass_2@127.0.0.1:"
        );

        assert_eq!(server_of(db_uri.as_str()), server_of(db_uri2.as_str()));
    }
}