#![doc = include_str!("../README.md")]
mod db_instance;
mod error;
use std::{
env, fs, io,
net::TcpListener,
path, process, thread,
time::{Duration, Instant},
};
pub use db_instance::{db_fixture, DbInstance};
pub use error::{Error, ExternalUrlError};
use process_guard::ProcessGuard;
use rand::{rngs::OsRng, Rng};
use url::Url;
pub fn run_psql_command(superuser_url: &Url, database: &str, sql: &str) -> Result<(), Error> {
let psql_binary = which::which("psql").unwrap_or_else(|_| "psql".into());
let username = superuser_url.username();
let password = superuser_url.password().unwrap_or_default();
let host = superuser_url.host_str().expect("URL must have a host");
let port = superuser_url.port().unwrap_or(5432);
let status = process::Command::new(&psql_binary)
.arg("-h")
.arg(host)
.arg("-p")
.arg(port.to_string())
.arg("-U")
.arg(username)
.arg("-d")
.arg(database)
.arg("-c")
.arg(sql)
.env("PGPASSWORD", password)
.status()
.map_err(Error::RunPsql)?;
if !status.success() {
return Err(Error::PsqlFailed(status));
}
Ok(())
}
/// Creates a login role `db_user` (password `db_pw`) and then a database `db_name` owned
/// by that role, using the superuser credentials in `superuser_url`.
///
/// Each statement is issued as its own `psql -c` invocation against the `postgres`
/// database. Names are quoted as identifiers and the password as a string literal.
///
/// # Errors
///
/// Propagates the first failure from [`run_psql_command`]. Nothing is rolled back, so if
/// database creation fails the role has already been created.
pub fn create_user_and_database(
    superuser_url: &Url,
    db_name: &str,
    db_user: &str,
    db_pw: &str,
) -> Result<(), Error> {
    let quoted_user = escape_ident(db_user);
    let statements = [
        format!(
            "CREATE ROLE {} LOGIN ENCRYPTED PASSWORD {};",
            quoted_user,
            escape_string(db_pw)
        ),
        format!(
            "CREATE DATABASE {} OWNER {};",
            escape_ident(db_name),
            quoted_user
        ),
    ];
    for statement in &statements {
        run_psql_command(superuser_url, "postgres", statement)?;
    }
    Ok(())
}
/// Creates a uniquely named role and database on the server behind `superuser_url` and
/// returns a copy of that URL rewritten to connect as the new role to the new database.
///
/// The role name, password and database name share one random suffix, prefixed with
/// `fixture_user_`, `fixture_pass_` and `fixture_db_` respectively.
fn create_fixture_db(superuser_url: &Url) -> Result<Url, Error> {
    let suffix = generate_random_string();
    let (db_name, db_user, db_pw) = (
        format!("fixture_db_{}", suffix),
        format!("fixture_user_{}", suffix),
        format!("fixture_pass_{}", suffix),
    );

    create_user_and_database(superuser_url, &db_name, &db_user, &db_pw)?;

    let mut fixture_url = superuser_url.clone();
    fixture_url
        .set_username(&db_user)
        .expect("Failed to set username");
    fixture_url
        .set_password(Some(&db_pw))
        .expect("Failed to set password");
    fixture_url.set_path(&db_name);
    Ok(fixture_url)
}
/// Asks the OS for a free TCP port on 127.0.0.1 by binding to port 0, then releases it.
///
/// The listener is dropped before returning, so another process could grab the port
/// before the caller binds it. This is a best-effort choice, not a reservation.
fn find_unused_port() -> io::Result<u16> {
    TcpListener::bind("127.0.0.1:0")
        .and_then(|listener| listener.local_addr())
        .map(|bound| bound.port())
}
/// A running PostgreSQL server started by [`PostgresBuilder::start`].
///
/// Use [`Postgres::as_superuser`] or [`Postgres::as_user`] to obtain a
/// [`PostgresClient`] for issuing commands against it.
#[derive(Debug)]
pub struct Postgres {
    /// Superuser connection URL (credentials, host and port; no database path).
    superuser_url: Url,
    // Never read; held only so the server process lives as long as this value.
    // Field order matters: Rust drops struct fields in declaration order, so this guard
    // is dropped before `tmp_dir` below. NOTE(review): assumes `ProcessGuard` stops the
    // process on drop (external `process_guard` crate) — confirm.
    #[allow(dead_code)] instance: ProcessGuard,
    /// Path of the `psql` binary used by [`PostgresClient::psql`].
    psql_binary: path::PathBuf,
    // Never read; keeps the temporary directory alive for the server's lifetime. It holds
    // the Unix socket directory, initdb's password file and, unless a custom data
    // directory was configured, the cluster data directory.
    #[allow(dead_code)] tmp_dir: tempfile::TempDir,
}
/// A handle for running commands against a [`Postgres`] instance as a particular user.
///
/// Obtained from [`Postgres::as_superuser`] or [`Postgres::as_user`]. It borrows the
/// instance, so it cannot outlive the server.
#[derive(Debug)]
pub struct PostgresClient<'a> {
    /// The server this client talks to.
    instance: &'a Postgres,
    /// Connection URL carrying this client's credentials, host and port (no database path).
    client_url: Url,
}
/// Configuration for starting a [`Postgres`] instance.
///
/// Create one with [`Postgres::build`], adjust it via the setter methods, then call
/// [`PostgresBuilder::start`].
#[derive(Debug)]
pub struct PostgresBuilder {
    /// Cluster data directory; `None` uses a `db` subdirectory of a fresh temporary directory.
    data_dir: Option<path::PathBuf>,
    /// TCP port for the server; `None` picks a currently unused port when starting.
    port: Option<u16>,
    /// Host used in the superuser URL and probed by `pg_isready` (default `127.0.0.1`).
    host: String,
    /// Superuser name passed to `initdb --username` (default `postgres`).
    superuser: String,
    /// Superuser password written to initdb's `--pwfile` (default: random lowercase hex).
    superuser_pw: String,
    /// Explicit `postgres` binary; `None` searches `PATH`.
    postgres_binary: Option<path::PathBuf>,
    /// Explicit `initdb` binary; `None` searches `PATH`.
    initdb_binary: Option<path::PathBuf>,
    /// Explicit `pg_isready` binary; `None` searches `PATH`.
    pg_isready_binary: Option<path::PathBuf>,
    /// Explicit `psql` binary; `None` searches `PATH`.
    psql_binary: Option<path::PathBuf>,
    /// Pause between readiness probes while waiting for startup (default 100 ms).
    probe_delay: Duration,
    /// Maximum time to wait for the server to become ready (default 10 s).
    startup_timeout: Duration,
}
impl Postgres {
    /// Returns a [`PostgresBuilder`] preloaded with default settings.
    ///
    /// Defaults:
    /// - host `127.0.0.1`;
    /// - superuser `postgres` with a freshly generated random password;
    /// - port chosen automatically at start;
    /// - binaries looked up on `PATH`;
    /// - temporary data directory;
    /// - 100 ms probe interval and 10 s startup timeout.
    #[inline]
    pub fn build() -> PostgresBuilder {
        let superuser_pw = generate_random_string();
        PostgresBuilder {
            host: String::from("127.0.0.1"),
            superuser: String::from("postgres"),
            superuser_pw,
            port: None,
            data_dir: None,
            postgres_binary: None,
            initdb_binary: None,
            pg_isready_binary: None,
            psql_binary: None,
            probe_delay: Duration::from_millis(100),
            startup_timeout: Duration::from_secs(10),
        }
    }

    /// Returns a client that connects with the superuser credentials.
    #[inline]
    pub fn as_superuser(&self) -> PostgresClient<'_> {
        let client_url = self.superuser_url.clone();
        PostgresClient {
            client_url,
            instance: self,
        }
    }

    /// Returns a client that connects as `username`/`password` to this server.
    ///
    /// The role is not created here; see [`PostgresClient::create_user`].
    ///
    /// # Panics
    ///
    /// Panics if the username or password cannot be stored in the URL.
    #[inline]
    pub fn as_user(&self, username: &str, password: &str) -> PostgresClient<'_> {
        let mut client = self.as_superuser();
        client
            .client_url
            .set_username(username)
            .expect("Failed to set username");
        client
            .client_url
            .set_password(Some(password))
            .expect("Failed to set password");
        client
    }

    /// The superuser connection URL (credentials, host and port; no database path).
    pub fn superuser_url(&self) -> &Url {
        &self.superuser_url
    }
}
impl<'a> PostgresClient<'a> {
pub fn psql(&self, database: &str) -> process::Command {
let mut cmd = process::Command::new(&self.instance.psql_binary);
let username = self.client_url.username();
let password = self.client_url.password().unwrap_or_default();
let host = self
.client_url
.host_str()
.expect("Client URL must have a host");
let port = self.client_url.port().expect("Client URL must have a port");
cmd.arg("-h")
.arg(host)
.arg("-p")
.arg(port.to_string())
.arg("-U")
.arg(username)
.arg("-d")
.arg(database)
.env("PGPASSWORD", password);
cmd
}
pub fn load_sql<P: AsRef<path::Path>>(&self, database: &str, filename: P) -> Result<(), Error> {
let status = self
.psql(database)
.arg("-f")
.arg(filename.as_ref())
.status()
.map_err(Error::RunPsql)?;
if !status.success() {
return Err(Error::PsqlFailed(status));
}
Ok(())
}
pub fn run_sql(&self, database: &str, sql: &str) -> Result<(), Error> {
let status = self
.psql(database)
.arg("-c")
.arg(sql)
.status()
.map_err(Error::RunPsql)?;
if !status.success() {
return Err(Error::PsqlFailed(status));
}
Ok(())
}
#[inline]
pub fn create_database(&self, database: &str, owner: &str) -> Result<(), Error> {
self.run_sql(
"postgres",
&format!(
"CREATE DATABASE {} OWNER {};",
escape_ident(database),
escape_ident(owner)
),
)
}
#[inline]
pub fn create_user(&self, username: &str, password: &str) -> Result<(), Error> {
self.run_sql(
"postgres",
&format!(
"CREATE ROLE {} LOGIN ENCRYPTED PASSWORD {};",
escape_ident(username),
escape_string(password)
),
)
}
#[inline]
pub fn instance(&self) -> &Postgres {
self.instance
}
pub fn url(&self, database: &str) -> Url {
let mut url = self.client_url.clone();
url.set_path(database);
url
}
pub fn client_url(&self) -> &Url {
&self.client_url
}
}
impl PostgresBuilder {
#[inline]
pub fn data_dir<T: Into<path::PathBuf>>(&mut self, data_dir: T) -> &mut Self {
self.data_dir = Some(data_dir.into());
self
}
#[inline]
pub fn initdb_binary<T: Into<path::PathBuf>>(&mut self, initdb_binary: T) -> &mut Self {
self.initdb_binary = Some(initdb_binary.into());
self
}
#[inline]
pub fn pg_isready_binary<T: Into<path::PathBuf>>(&mut self, pg_isready_binary: T) -> &mut Self {
self.pg_isready_binary = Some(pg_isready_binary.into());
self
}
#[inline]
pub fn host(&mut self, host: String) -> &mut Self {
self.host = host;
self
}
#[inline]
pub fn port(&mut self, port: u16) -> &mut Self {
self.port = Some(port);
self
}
#[inline]
pub fn postgres_binary<T: Into<path::PathBuf>>(&mut self, postgres_binary: T) -> &mut Self {
self.postgres_binary = Some(postgres_binary.into());
self
}
#[inline]
pub fn probe_delay(&mut self, probe_delay: Duration) -> &mut Self {
self.probe_delay = probe_delay;
self
}
#[inline]
pub fn psql_binary<T: Into<path::PathBuf>>(&mut self, psql_binary: T) -> &mut Self {
self.psql_binary = Some(psql_binary.into());
self
}
#[inline]
pub fn startup_timeout(&mut self, startup_timeout: Duration) -> &mut Self {
self.startup_timeout = startup_timeout;
self
}
#[inline]
pub fn superuser_pw<T: Into<String>>(&mut self, superuser_pw: T) -> &mut Self {
self.superuser_pw = superuser_pw.into();
self
}
pub fn start(&self) -> Result<Postgres, Error> {
let port = self
.port
.unwrap_or_else(|| find_unused_port().expect("failed to find an unused port"));
let postgres_binary = self
.postgres_binary
.clone()
.map(Ok)
.unwrap_or_else(|| which::which("postgres").map_err(Error::FindPostgres))?;
let initdb_binary = self
.initdb_binary
.clone()
.map(Ok)
.unwrap_or_else(|| which::which("initdb").map_err(Error::FindInitdb))?;
let pg_isready_binary = self
.pg_isready_binary
.clone()
.map(Ok)
.unwrap_or_else(|| which::which("pg_isready").map_err(Error::FindPgIsready))?;
let psql_binary = self
.psql_binary
.clone()
.map(Ok)
.unwrap_or_else(|| which::which("psql").map_err(Error::FindPsql))?;
let tmp_dir = tempfile::tempdir().map_err(Error::CreateDatabaseDir)?;
let data_dir = self
.data_dir
.clone()
.unwrap_or_else(|| tmp_dir.path().join("db"));
let superuser_pw_file = tmp_dir.path().join("superuser-pw");
fs::write(&superuser_pw_file, self.superuser_pw.as_bytes())
.map_err(Error::WriteTemporaryPw)?;
let initdb_status = process::Command::new(initdb_binary)
.args([
"--no-locale",
"--auth=md5",
"--encoding=UTF8",
"--nosync",
"--pgdata",
])
.arg(&data_dir)
.arg("--pwfile")
.arg(&superuser_pw_file)
.arg("--username")
.arg(&self.superuser)
.status()
.map_err(Error::RunInitDb)?;
if !initdb_status.success() {
return Err(Error::InitDbFailed(initdb_status));
}
let mut postgres_command = process::Command::new(postgres_binary);
postgres_command
.arg("-D")
.arg(&data_dir)
.arg("-p")
.arg(port.to_string())
.arg("-k")
.arg(tmp_dir.path());
let instance = ProcessGuard::spawn_graceful(&mut postgres_command, Duration::from_secs(5))
.map_err(Error::LaunchPostgres)?;
let started = Instant::now();
loop {
let status = process::Command::new(&pg_isready_binary)
.arg("-h")
.arg(&self.host)
.arg("-p")
.arg(port.to_string())
.stdout(process::Stdio::null())
.stderr(process::Stdio::null())
.status();
match status {
Ok(exit_status) if exit_status.success() => break,
_ => {
if started.elapsed() >= self.startup_timeout {
return Err(Error::StartupTimeout);
}
thread::sleep(self.probe_delay);
}
}
}
let superuser_url = Url::parse(&format!(
"postgres://{}:{}@{}:{}",
self.superuser, self.superuser_pw, self.host, port
))
.expect("Failed to construct base URL");
Ok(Postgres {
superuser_url,
instance,
psql_binary,
tmp_dir,
})
}
}
/// Returns 16 bytes drawn from the OS random number generator, formatted as lowercase hex.
///
/// Used for default superuser passwords and for unique fixture name suffixes.
fn generate_random_string() -> String {
    let entropy = OsRng.gen::<[u8; 16]>();
    format!("{:x}", hex_fmt::HexFmt(&entropy))
}
/// Wraps `unescaped` in `quote_char`, doubling every occurrence of `quote_char` inside it.
///
/// Doubling the delimiter is how SQL escapes it inside a quoted identifier (`"`) or a
/// string literal (`'`).
fn quote(quote_char: char, unescaped: &str) -> String {
    let doubled: String = [quote_char, quote_char].iter().collect();
    let body = unescaped.replace(quote_char, &doubled);
    let mut quoted = String::with_capacity(body.len() + 2 * quote_char.len_utf8());
    quoted.push(quote_char);
    quoted.push_str(&body);
    quoted.push(quote_char);
    quoted
}
/// Renders `unescaped` as a double-quoted SQL identifier, doubling any embedded `"`.
fn escape_ident(unescaped: &str) -> String {
    let delimiter = '"';
    quote(delimiter, unescaped)
}
/// Renders `unescaped` as a single-quoted SQL string literal, doubling any embedded `'`.
///
/// Backslashes are left untouched, which assumes the server runs with
/// `standard_conforming_strings = on` (the PostgreSQL default).
fn escape_string(unescaped: &str) -> String {
    let apostrophe = '\'';
    quote(apostrophe, unescaped)
}
/// Reads the URL of an externally managed PostgreSQL server from `PGDB_TESTS_URL`.
///
/// Returns `Ok(None)` when the variable is not set. Both the `postgres://` and
/// `postgresql://` schemes are accepted, matching the connection URIs libpq understands.
///
/// # Errors
///
/// Returns [`Error::InvalidExternalUrl`] if the value does not parse as a URL, uses any
/// other scheme, has no host, or has no username.
pub fn parse_external_test_url() -> Result<Option<Url>, Error> {
    let url_str = match env::var("PGDB_TESTS_URL") {
        Ok(value) => value,
        // NOTE(review): a set-but-non-Unicode value is treated the same as "unset".
        Err(_) => return Ok(None),
    };
    let url = Url::parse(&url_str)
        .map_err(|e| Error::InvalidExternalUrl(ExternalUrlError::ParseError(e)))?;
    if !matches!(url.scheme(), "postgres" | "postgresql") {
        return Err(Error::InvalidExternalUrl(ExternalUrlError::InvalidScheme));
    }
    if url.host_str().is_none() {
        return Err(Error::InvalidExternalUrl(ExternalUrlError::MissingHost));
    }
    if url.username().is_empty() {
        return Err(Error::InvalidExternalUrl(ExternalUrlError::MissingUsername));
    }
    Ok(Some(url))
}
#[cfg(test)]
mod tests {
    use super::Postgres;

    #[test]
    fn can_change_superuser_pw() {
        let pg = Postgres::build()
            .superuser_pw("helloworld")
            .start()
            .expect("could not build postgres database");
        let su = pg.as_superuser();
        // Creating a user authenticates with the custom password, proving initdb used it.
        su.create_user("foo", "bar")
            .expect("could not create normal user");
        assert_eq!(su.client_url().password(), Some("helloworld"));
    }

    #[test]
    fn instances_use_different_port_by_default() {
        // Keep all instances alive at once so their ports are in use simultaneously.
        let instances: Vec<Postgres> = (0..3)
            .map(|_| {
                Postgres::build()
                    .start()
                    .expect("could not build postgres database")
            })
            .collect();
        let ports: Vec<u16> = instances
            .iter()
            .map(|pg| pg.superuser_url().port().expect("URL must have a port"))
            .collect();
        for (i, a) in ports.iter().enumerate() {
            for b in &ports[i + 1..] {
                assert_ne!(a, b);
            }
        }
    }

    #[test]
    fn ensure_proper_db_reuse_when_using_fixtures() {
        let db_url = crate::db_fixture();
        let db_url2 = crate::db_fixture();
        let both_local = matches!(
            (&db_url, &db_url2),
            (crate::DbInstance::Local { .. }, crate::DbInstance::Local { .. })
        );
        let both_external = matches!(
            (&db_url, &db_url2),
            (crate::DbInstance::External { .. }, crate::DbInstance::External { .. })
        );
        assert!(
            both_local || both_external,
            "Inconsistent DbUrl types returned from db_fixture"
        );
        for instance in [&db_url, &db_url2] {
            for marker in ["fixture_user_", "fixture_pass_", "fixture_db_"] {
                assert!(instance.as_str().contains(marker));
            }
        }
        assert_ne!(db_url.as_str(), db_url2.as_str());
        if both_external {
            // External fixtures must live on the same configured server.
            assert_eq!(db_url.as_url().host_str(), db_url2.as_url().host_str());
            assert_eq!(db_url.as_url().port(), db_url2.as_url().port());
        }
    }

    #[test]
    fn external_db_cleanup_on_drop() {
        // Only meaningful when an external server is configured via PGDB_TESTS_URL.
        let superuser_url = match crate::parse_external_test_url().unwrap() {
            Some(url) => url,
            None => return,
        };
        let psql_binary = which::which("psql").unwrap_or_else(|_| "psql".into());

        // Create a fixture and let it drop at the end of this block.
        let (db_name, db_user) = {
            let db_url = crate::db_fixture();
            match &db_url {
                crate::DbInstance::External { url, .. } => (
                    url.path().trim_start_matches('/').to_string(),
                    url.username().to_string(),
                ),
                _ => panic!("Expected external database"),
            }
        };
        // Short grace period in case cleanup on drop is not fully synchronous.
        std::thread::sleep(std::time::Duration::from_millis(100));

        // Runs `sql` as the superuser and reports whether it printed exactly `1`. The exit
        // status is checked so a connection/query failure (empty stdout) cannot masquerade
        // as "row not found" and let the test pass spuriously.
        let query_returns_one = |sql: &str| -> bool {
            let output = std::process::Command::new(&psql_binary)
                .arg("-h")
                .arg(
                    superuser_url
                        .host_str()
                        .expect("external URL must have a host"),
                )
                .arg("-p")
                .arg(superuser_url.port().unwrap_or(5432).to_string())
                .arg("-U")
                .arg(superuser_url.username())
                .arg("-d")
                .arg("postgres")
                .arg("-t")
                .arg("-c")
                .arg(sql)
                .env("PGPASSWORD", superuser_url.password().unwrap_or_default())
                .output()
                .expect("Failed to run psql");
            assert!(
                output.status.success(),
                "psql failed for `{}`: {}",
                sql,
                String::from_utf8_lossy(&output.stderr)
            );
            String::from_utf8_lossy(&output.stdout).trim() == "1"
        };

        let db_exists = query_returns_one(&format!(
            "SELECT 1 FROM pg_database WHERE datname = {}",
            super::escape_string(&db_name)
        ));
        let user_exists = query_returns_one(&format!(
            "SELECT 1 FROM pg_roles WHERE rolname = {}",
            super::escape_string(&db_user)
        ));
        assert!(!db_exists, "Database should have been dropped");
        assert!(!user_exists, "User should have been dropped");
    }
}