1#![doc = include_str!("../README.md")]
2
3use std::{
4 env, fs, io, net,
5 net::TcpListener,
6 path, process,
7 sync::{Arc, Mutex, Weak},
8 thread,
9 time::{Duration, Instant},
10};
11
12use process_guard::ProcessGuard;
13use rand::{rngs::OsRng, Rng};
14use thiserror::Error;
15use url::Url;
16
/// Connection URL of a per-test fixture database, as returned by [`db_fixture`].
///
/// Dropping an `External` value runs best-effort `DROP DATABASE` / `DROP ROLE` commands against
/// the external server. Dropping a `Local` value only releases its reference to the shared local
/// instance.
#[derive(Debug)]
pub enum DbUrl {
    /// Fixture database on a locally spawned instance shared between live fixtures.
    Local {
        // Held only to keep the shared local instance alive while this URL is in use.
        _arc: Arc<Postgres>,
        /// URL of the fixture database, carrying the fixture role's credentials.
        url: Url,
    },
    /// Fixture database on the external server configured via `PGDB_TESTS_URL`.
    External {
        /// URL of the fixture database, carrying the fixture role's credentials.
        url: Url,
        /// Superuser URL used to create the database/role and to drop them again on `Drop`.
        superuser_url: Url,
    },
}
38
39impl DbUrl {
40 pub fn as_str(&self) -> &str {
42 match self {
43 DbUrl::Local { url, .. } => url.as_str(),
44 DbUrl::External { url, .. } => url.as_str(),
45 }
46 }
47
48 pub fn as_url(&self) -> &Url {
50 match self {
51 DbUrl::Local { url, .. } => url,
52 DbUrl::External { url, .. } => url,
53 }
54 }
55}
56
57impl AsRef<str> for DbUrl {
58 fn as_ref(&self) -> &str {
59 self.as_str()
60 }
61}
62
63impl Drop for DbUrl {
64 fn drop(&mut self) {
65 if let DbUrl::External { url, superuser_url } = self {
66 let db_name = url.path().trim_start_matches('/');
68 let db_user = url.username();
69
70 let psql_binary = which::which("psql").unwrap_or_else(|_| "psql".into());
72
73 let run_cleanup_sql = |sql: &str| {
75 let username = superuser_url.username();
76 let password = superuser_url.password().unwrap_or_default();
77 let host = superuser_url.host_str().unwrap_or("localhost");
78 let port = superuser_url.port().unwrap_or(5432);
79
80 let _ = process::Command::new(&psql_binary)
81 .arg("-h")
82 .arg(host)
83 .arg("-p")
84 .arg(port.to_string())
85 .arg("-U")
86 .arg(username)
87 .arg("-d")
88 .arg("postgres")
89 .arg("-c")
90 .arg(sql)
91 .env("PGPASSWORD", password)
92 .output();
93 };
94
95 run_cleanup_sql(&format!(
97 "DROP DATABASE IF EXISTS {};",
98 escape_ident(db_name)
99 ));
100
101 run_cleanup_sql(&format!("DROP ROLE IF EXISTS {};", escape_ident(db_user)));
103 }
104 }
105}
106
107fn parse_external_test_url() -> Result<Option<Url>, Error> {
112 match env::var("PGDB_TESTS_URL") {
113 Ok(url_str) => {
114 let url = Url::parse(&url_str)
115 .map_err(|e| Error::InvalidExternalUrl(ExternalUrlError::ParseError(e)))?;
116
117 if url.scheme() != "postgres" {
119 return Err(Error::InvalidExternalUrl(ExternalUrlError::InvalidScheme));
120 }
121
122 if url.host_str().is_none() {
123 return Err(Error::InvalidExternalUrl(ExternalUrlError::MissingHost));
124 }
125
126 if url.username().is_empty() {
127 return Err(Error::InvalidExternalUrl(ExternalUrlError::MissingUsername));
128 }
129
130 Ok(Some(url))
131 }
132 Err(_) => Ok(None),
133 }
134}
135
136pub fn run_psql_command(superuser_url: &Url, database: &str, sql: &str) -> Result<(), Error> {
138 let psql_binary = which::which("psql").unwrap_or_else(|_| "psql".into());
139 let username = superuser_url.username();
140 let password = superuser_url.password().unwrap_or_default();
141 let host = superuser_url.host_str().expect("URL must have a host");
142 let port = superuser_url.port().unwrap_or(5432);
143
144 let status = process::Command::new(&psql_binary)
145 .arg("-h")
146 .arg(host)
147 .arg("-p")
148 .arg(port.to_string())
149 .arg("-U")
150 .arg(username)
151 .arg("-d")
152 .arg(database)
153 .arg("-c")
154 .arg(sql)
155 .env("PGPASSWORD", password)
156 .status()
157 .map_err(Error::RunPsql)?;
158
159 if !status.success() {
160 return Err(Error::PsqlFailed(status));
161 }
162
163 Ok(())
164}
165
166pub fn create_user_and_database(
168 superuser_url: &Url,
169 db_name: &str,
170 db_user: &str,
171 db_pw: &str,
172) -> Result<(), Error> {
173 run_psql_command(
175 superuser_url,
176 "postgres",
177 &format!(
178 "CREATE ROLE {} LOGIN ENCRYPTED PASSWORD {};",
179 escape_ident(db_user),
180 escape_string(db_pw)
181 ),
182 )?;
183
184 run_psql_command(
186 superuser_url,
187 "postgres",
188 &format!(
189 "CREATE DATABASE {} OWNER {};",
190 escape_ident(db_name),
191 escape_ident(db_user)
192 ),
193 )?;
194
195 Ok(())
196}
197
198fn create_fixture_db(superuser_url: &Url) -> Result<Url, Error> {
200 let random_id = generate_random_string();
202 let db_name = format!("fixture_db_{}", random_id);
203 let db_user = format!("fixture_user_{}", random_id);
204 let db_pw = format!("fixture_pass_{}", random_id);
205
206 create_user_and_database(superuser_url, &db_name, &db_user, &db_pw)?;
208
209 let mut url = superuser_url.clone();
211 url.set_username(&db_user).expect("Failed to set username");
212 url.set_password(Some(&db_pw))
213 .expect("Failed to set password");
214 url.set_path(&db_name);
215
216 Ok(url)
217}
218
219pub fn db_fixture() -> DbUrl {
235 if let Some(external_url) = parse_external_test_url().expect("invalid PGDB_TESTS_URL") {
237 let url = create_fixture_db(&external_url).expect("failed to create external fixture DB");
238 return DbUrl::External {
239 url,
240 superuser_url: external_url,
241 };
242 }
243
244 static DB: Mutex<Weak<Postgres>> = Mutex::new(Weak::new());
245
246 let pg = {
247 let mut guard = DB.lock().expect("lock poisoned");
248 if let Some(arc) = guard.upgrade() {
249 arc
251 } else {
252 let arc = Arc::new(
253 Postgres::build()
254 .start()
255 .expect("failed to start global postgres DB"),
256 );
257 *guard = Arc::downgrade(&arc);
258 arc
259 }
260 };
261
262 let url = create_fixture_db(pg.superuser_url()).expect("failed to create local fixture DB");
264 DbUrl::Local { _arc: pg, url }
265}
266
/// Asks the OS for a currently free TCP port on 127.0.0.1.
///
/// Binding to port 0 makes the OS pick an ephemeral port. The listener is closed again before
/// returning, so another process could in principle grab the port before it is used.
fn find_unused_port() -> io::Result<u16> {
    TcpListener::bind("127.0.0.1:0")?
        .local_addr()
        .map(|addr| addr.port())
}
277
/// A running PostgreSQL instance, created by [`PostgresBuilder::start`].
///
/// The spawned server process and the temporary directory it uses are owned by this value.
#[derive(Debug)]
pub struct Postgres {
    /// Connection URL for the superuser (host, port, superuser name and password).
    superuser_url: Url,
    // Guard for the spawned `postgres` process. It is never read, only kept so it lives exactly
    // as long as this `Postgres` (hence `dead_code`). It was created with
    // `ProcessGuard::spawn_graceful`; presumably it stops the server on drop (see process_guard
    // docs).
    #[allow(dead_code)]
    instance: ProcessGuard,
    /// `psql` binary used by [`PostgresClient::psql`].
    psql_binary: path::PathBuf,
    // Temporary directory holding the superuser password file, the unix socket directory (`-k`)
    // and, unless a custom data dir was configured, the database cluster. Never read after
    // startup; it is held only to keep the directory alive (hence `dead_code`).
    #[allow(dead_code)]
    tmp_dir: tempfile::TempDir,
}
295
/// A handle for running `psql`-based operations against a [`Postgres`] instance as one user.
///
/// Obtained via [`Postgres::as_superuser`] or [`Postgres::as_user`].
#[derive(Debug)]
pub struct PostgresClient<'a> {
    /// The instance this client talks to (provides the `psql` binary).
    instance: &'a Postgres,
    /// Connection URL carrying this client's credentials.
    client_url: Url,
}
305
/// Configuration for starting a [`Postgres`] instance; create one with [`Postgres::build`].
#[derive(Debug)]
pub struct PostgresBuilder {
    /// Cluster directory; `None` means a `db` subdirectory of a fresh temporary directory.
    data_dir: Option<path::PathBuf>,
    /// TCP port; `None` means an unused port is picked at start.
    port: Option<u16>,
    /// Host used to probe for startup and placed in the superuser URL.
    host: String,
    /// Superuser name passed to `initdb --username`.
    superuser: String,
    /// Superuser password, handed to `initdb` through a temporary password file.
    superuser_pw: String,
    /// `postgres` binary; `None` means look it up on `PATH`.
    postgres_binary: Option<path::PathBuf>,
    /// `initdb` binary; `None` means look it up on `PATH`.
    initdb_binary: Option<path::PathBuf>,
    /// `psql` binary; `None` means look it up on `PATH`.
    psql_binary: Option<path::PathBuf>,
    /// Pause between TCP connection attempts while waiting for the server to come up.
    probe_delay: Duration,
    /// How long to keep probing before giving up with `Error::StartupTimeout`.
    startup_timeout: Duration,
}
334
/// Reasons why the `PGDB_TESTS_URL` value was rejected.
#[derive(Debug, Error)]
pub enum ExternalUrlError {
    /// The value could not be parsed as a URL.
    #[error("invalid URL: {0}")]
    ParseError(#[source] url::ParseError),
    /// The URL does not use an accepted PostgreSQL scheme.
    #[error("must use postgres:// scheme")]
    InvalidScheme,
    /// The URL has no host component.
    #[error("must include a host")]
    MissingHost,
    /// The URL has no username component.
    #[error("must include a username")]
    MissingUsername,
}
351
352#[derive(Debug, Error)]
354pub enum Error {
355 #[error("could not find `postgres` binary")]
356 FindPostgres(which::Error),
357 #[error("could not find `initdb` binary")]
359 FindInitdb(which::Error),
360 #[error("could not find `psql` binary")]
362 FindPsql(which::Error),
363 #[error("could not create temporary directory for database")]
365 CreateDatabaseDir(io::Error),
366 #[error("error writing temporary password")]
368 WriteTemporaryPw(io::Error),
369 #[error("failed to run `initdb`")]
371 RunInitDb(io::Error),
372 #[error("`initdb` exited with status {}", 0)]
374 InitDbFailed(process::ExitStatus),
375 #[error("failed to launch `postgres`")]
377 LaunchPostgres(io::Error),
378 #[error("timeout probing tcp socket")]
380 StartupTimeout,
381 #[error("failed to run `psql`")]
383 RunPsql(io::Error),
384 #[error("`psql` exited with status {}", 0)]
386 PsqlFailed(process::ExitStatus),
387 #[error("invalid PGDB_TESTS_URL")]
389 InvalidExternalUrl(#[source] ExternalUrlError),
390}
391
392impl Postgres {
393 #[inline]
395 pub fn build() -> PostgresBuilder {
396 PostgresBuilder {
397 data_dir: None,
398 port: None,
399 host: "127.0.0.1".to_string(),
400 superuser: "postgres".to_string(),
401 superuser_pw: generate_random_string(),
402 postgres_binary: None,
403 initdb_binary: None,
404 psql_binary: None,
405 probe_delay: Duration::from_millis(100),
406 startup_timeout: Duration::from_secs(10),
407 }
408 }
409
410 #[inline]
412 pub fn as_superuser(&self) -> PostgresClient<'_> {
413 PostgresClient {
414 instance: self,
415 client_url: self.superuser_url.clone(),
416 }
417 }
418
419 #[inline]
421 pub fn as_user(&self, username: &str, password: &str) -> PostgresClient<'_> {
422 let mut client_url = self.superuser_url.clone();
423 client_url
424 .set_username(username)
425 .expect("Failed to set username");
426 client_url
427 .set_password(Some(password))
428 .expect("Failed to set password");
429 PostgresClient {
430 instance: self,
431 client_url,
432 }
433 }
434
435 pub fn superuser_url(&self) -> &Url {
437 &self.superuser_url
438 }
439}
440
441impl<'a> PostgresClient<'a> {
442 pub fn psql(&self, database: &str) -> process::Command {
447 let mut cmd = process::Command::new(&self.instance.psql_binary);
448
449 let username = self.client_url.username();
450 let password = self.client_url.password().unwrap_or_default();
451
452 let host = self
453 .client_url
454 .host_str()
455 .expect("Client URL must have a host");
456 let port = self.client_url.port().expect("Client URL must have a port");
457
458 cmd.arg("-h")
459 .arg(host)
460 .arg("-p")
461 .arg(port.to_string())
462 .arg("-U")
463 .arg(username)
464 .arg("-d")
465 .arg(database)
466 .env("PGPASSWORD", password);
467
468 cmd
469 }
470
471 pub fn load_sql<P: AsRef<path::Path>>(&self, database: &str, filename: P) -> Result<(), Error> {
473 let status = self
474 .psql(database)
475 .arg("-f")
476 .arg(filename.as_ref())
477 .status()
478 .map_err(Error::RunPsql)?;
479
480 if !status.success() {
481 return Err(Error::PsqlFailed(status));
482 }
483
484 Ok(())
485 }
486
487 pub fn run_sql(&self, database: &str, sql: &str) -> Result<(), Error> {
489 let status = self
490 .psql(database)
491 .arg("-c")
492 .arg(sql)
493 .status()
494 .map_err(Error::RunPsql)?;
495
496 if !status.success() {
497 return Err(Error::PsqlFailed(status));
498 }
499
500 Ok(())
501 }
502
503 #[inline]
507 pub fn create_database(&self, database: &str, owner: &str) -> Result<(), Error> {
508 self.run_sql(
509 "postgres",
510 &format!(
511 "CREATE DATABASE {} OWNER {};",
512 escape_ident(database),
513 escape_ident(owner)
514 ),
515 )
516 }
517
518 #[inline]
522 pub fn create_user(&self, username: &str, password: &str) -> Result<(), Error> {
523 self.run_sql(
524 "postgres",
525 &format!(
526 "CREATE ROLE {} LOGIN ENCRYPTED PASSWORD {};",
527 escape_ident(username),
528 escape_string(password)
529 ),
530 )
531 }
532
533 #[inline]
535 pub fn instance(&self) -> &Postgres {
536 self.instance
537 }
538
539 pub fn url(&self, database: &str) -> Url {
541 let mut url = self.client_url.clone();
542 url.set_path(database);
543 url
544 }
545
546 pub fn client_url(&self) -> &Url {
548 &self.client_url
549 }
550}
551
552impl PostgresBuilder {
553 #[inline]
557 pub fn data_dir<T: Into<path::PathBuf>>(&mut self, data_dir: T) -> &mut Self {
558 self.data_dir = Some(data_dir.into());
559 self
560 }
561
562 #[inline]
564 pub fn initdb_binary<T: Into<path::PathBuf>>(&mut self, initdb_binary: T) -> &mut Self {
565 self.initdb_binary = Some(initdb_binary.into());
566 self
567 }
568
569 #[inline]
571 pub fn host(&mut self, host: String) -> &mut Self {
572 self.host = host;
573 self
574 }
575
576 #[inline]
582 pub fn port(&mut self, port: u16) -> &mut Self {
583 self.port = Some(port);
584 self
585 }
586
587 #[inline]
589 pub fn postgres_binary<T: Into<path::PathBuf>>(&mut self, postgres_binary: T) -> &mut Self {
590 self.postgres_binary = Some(postgres_binary.into());
591 self
592 }
593
594 #[inline]
598 pub fn probe_delay(&mut self, probe_delay: Duration) -> &mut Self {
599 self.probe_delay = probe_delay;
600 self
601 }
602
603 #[inline]
605 pub fn psql_binary<T: Into<path::PathBuf>>(&mut self, psql_binary: T) -> &mut Self {
606 self.psql_binary = Some(psql_binary.into());
607 self
608 }
609
610 #[inline]
612 pub fn startup_timeout(&mut self, startup_timeout: Duration) -> &mut Self {
613 self.startup_timeout = startup_timeout;
614 self
615 }
616
617 #[inline]
619 pub fn superuser_pw<T: Into<String>>(&mut self, superuser_pw: T) -> &mut Self {
620 self.superuser_pw = superuser_pw.into();
621 self
622 }
623
624 pub fn start(&self) -> Result<Postgres, Error> {
629 let port = self
630 .port
631 .unwrap_or_else(|| find_unused_port().expect("failed to find an unused port"));
632
633 let postgres_binary = self
634 .postgres_binary
635 .clone()
636 .map(Ok)
637 .unwrap_or_else(|| which::which("postgres").map_err(Error::FindPostgres))?;
638 let initdb_binary = self
639 .initdb_binary
640 .clone()
641 .map(Ok)
642 .unwrap_or_else(|| which::which("initdb").map_err(Error::FindInitdb))?;
643 let psql_binary = self
644 .psql_binary
645 .clone()
646 .map(Ok)
647 .unwrap_or_else(|| which::which("psql").map_err(Error::FindPsql))?;
648
649 let tmp_dir = tempfile::tempdir().map_err(Error::CreateDatabaseDir)?;
650 let data_dir = self
651 .data_dir
652 .clone()
653 .unwrap_or_else(|| tmp_dir.path().join("db"));
654
655 let superuser_pw_file = tmp_dir.path().join("superuser-pw");
656 fs::write(&superuser_pw_file, self.superuser_pw.as_bytes())
657 .map_err(Error::WriteTemporaryPw)?;
658
659 let initdb_status = process::Command::new(initdb_binary)
660 .args([
661 "--no-locale",
663 "--auth=md5",
665 "--encoding=UTF8",
667 "--nosync",
669 "--pgdata",
671 ])
672 .arg(&data_dir)
673 .arg("--pwfile")
674 .arg(&superuser_pw_file)
675 .arg("--username")
676 .arg(&self.superuser)
677 .status()
678 .map_err(Error::RunInitDb)?;
679
680 if !initdb_status.success() {
681 return Err(Error::InitDbFailed(initdb_status));
682 }
683
684 let mut postgres_command = process::Command::new(postgres_binary);
686 postgres_command
687 .arg("-D")
688 .arg(&data_dir)
689 .arg("-p")
690 .arg(port.to_string())
691 .arg("-k")
692 .arg(tmp_dir.path());
693
694 let instance = ProcessGuard::spawn_graceful(&mut postgres_command, Duration::from_secs(5))
695 .map_err(Error::LaunchPostgres)?;
696
697 let socket_addr = format!("{}:{}", self.host, port);
699 let started = Instant::now();
700 loop {
701 match net::TcpStream::connect(socket_addr.as_str()) {
702 Ok(_) => break,
703 Err(_) => {
704 let now = Instant::now();
705
706 if now.duration_since(started) >= self.startup_timeout {
707 return Err(Error::StartupTimeout);
708 }
709
710 thread::sleep(self.probe_delay);
711 }
712 }
713 }
714
715 let superuser_url = Url::parse(&format!(
716 "postgres://{}:{}@{}:{}",
717 self.superuser, self.superuser_pw, self.host, port
718 ))
719 .expect("Failed to construct base URL");
720
721 Ok(Postgres {
722 superuser_url,
723 instance,
724 psql_binary,
725 tmp_dir,
726 })
727 }
728}
729
730fn generate_random_string() -> String {
732 let raw: [u8; 16] = OsRng.gen();
733 format!("{:x}", hex_fmt::HexFmt(&raw))
734}
735
/// Wraps `unescaped` in `quote_char`, doubling every embedded occurrence of `quote_char`.
///
/// This is the SQL quoting rule used here for both identifiers (`"`) and string literals (`'`).
fn quote(quote_char: char, unescaped: &str) -> String {
    let mut quoted = String::with_capacity(unescaped.len() + 2);
    quoted.push(quote_char);
    for ch in unescaped.chars() {
        if ch == quote_char {
            // A quote character inside the quoted text is escaped by repeating it.
            quoted.push(quote_char);
        }
        quoted.push(ch);
    }
    quoted.push(quote_char);
    quoted
}
754
755fn escape_ident(unescaped: &str) -> String {
757 quote('"', unescaped)
758}
759
760fn escape_string(unescaped: &str) -> String {
762 quote('\'', unescaped)
763}
764
#[cfg(test)]
mod tests {
    use super::Postgres;

    #[test]
    fn can_change_superuser_pw() {
        let pg = Postgres::build()
            .superuser_pw("helloworld")
            .start()
            .expect("could not build postgres database");

        let su = pg.as_superuser();
        su.create_user("foo", "bar")
            .expect("could not create normal user");

        assert_eq!(su.client_url().password(), Some("helloworld"));
    }

    #[test]
    fn instances_use_different_port_by_default() {
        // Keep all instances alive at once so their ports are genuinely in use concurrently.
        let instances: Vec<Postgres> = (0..3)
            .map(|_| {
                Postgres::build()
                    .start()
                    .expect("could not build postgres database")
            })
            .collect();

        let ports: Vec<u16> = instances
            .iter()
            .map(|pg| pg.superuser_url().port().expect("URL must have a port"))
            .collect();

        for (i, a) in ports.iter().enumerate() {
            for b in &ports[i + 1..] {
                assert_ne!(a, b);
            }
        }
    }

    #[test]
    fn ensure_proper_db_reuse_when_using_fixtures() {
        let db_url = crate::db_fixture();
        let db_url2 = crate::db_fixture();

        assert_eq!(
            std::mem::discriminant(&db_url),
            std::mem::discriminant(&db_url2),
            "Inconsistent DbUrl types returned from db_fixture"
        );

        for url in [&db_url, &db_url2] {
            assert!(url.as_str().contains("fixture_user_"));
            assert!(url.as_str().contains("fixture_pass_"));
            assert!(url.as_str().contains("fixture_db_"));
        }

        // Each fixture gets its own database and credentials...
        assert_ne!(db_url.as_str(), db_url2.as_str());

        // ...but both live on the same server. For local fixtures, `db_url` keeps the shared
        // instance alive, so the second fixture must have reused it.
        assert_eq!(db_url.as_url().host_str(), db_url2.as_url().host_str());
        assert_eq!(db_url.as_url().port(), db_url2.as_url().port());
    }

    #[test]
    fn external_db_cleanup_on_drop() {
        // Only meaningful against an external server.
        let Some(superuser_url) = crate::parse_external_test_url().unwrap() else {
            return;
        };
        let psql_binary = which::which("psql").unwrap_or_else(|_| "psql".into());

        // Dropping the fixture at the end of this block runs the cleanup `psql` commands
        // synchronously (`Drop` waits for each one), so no extra delay is needed afterwards.
        let (db_name, db_user) = {
            let db_url = crate::db_fixture();
            match &db_url {
                crate::DbUrl::External { url, .. } => (
                    url.path().trim_start_matches('/').to_string(),
                    url.username().to_string(),
                ),
                _ => panic!("Expected external database"),
            }
        };

        // Runs a `SELECT 1 ...` query as the superuser and reports whether it returned a row.
        let row_exists = |query: &str| -> bool {
            let output = std::process::Command::new(&psql_binary)
                .arg("-h")
                .arg(superuser_url.host_str().unwrap())
                .arg("-p")
                .arg(superuser_url.port().unwrap_or(5432).to_string())
                .arg("-U")
                .arg(superuser_url.username())
                .arg("-d")
                .arg("postgres")
                .arg("-t")
                .arg("-c")
                .arg(query)
                .env("PGPASSWORD", superuser_url.password().unwrap_or_default())
                .output()
                .expect("Failed to run existence check");

            String::from_utf8_lossy(&output.stdout).trim() == "1"
        };

        assert!(
            !row_exists(&format!(
                "SELECT 1 FROM pg_database WHERE datname = {}",
                crate::escape_string(&db_name)
            )),
            "Database should have been dropped"
        );
        assert!(
            !row_exists(&format!(
                "SELECT 1 FROM pg_roles WHERE rolname = {}",
                crate::escape_string(&db_user)
            )),
            "User should have been dropped"
        );
    }
}