use futures::{TryFutureExt};
use std::process::{Command, Child};
use crate::fetch;
use crate::errors::PgEmbedError;
#[cfg(any(feature = "rt_tokio", feature = "rt_tokio_migrate"))]
use tokio::io::AsyncWriteExt;
use crate::errors::PgEmbedError::PgCleanUpFailure;
#[cfg(feature = "rt_tokio_migrate")]
use sqlx_tokio::{Postgres};
use std::time::Duration;
#[cfg(any(feature = "rt_tokio", feature = "rt_tokio_migrate"))]
use tokio::time::sleep;
#[cfg(feature = "rt_tokio_migrate")]
use sqlx_tokio::migrate::{Migrator, MigrateDatabase};
use std::path::PathBuf;
/// Configuration for an embedded PostgreSQL instance managed by `PgEmbed`.
pub struct PgSettings {
    /// Directory the PostgreSQL distribution is unpacked into; `bin/initdb`,
    /// `bin/pg_ctl` and the generated `pwfile` are resolved relative to it.
    pub executables_dir: PathBuf,
    /// Data directory handed to `initdb -D` and `pg_ctl -D`.
    pub database_dir: PathBuf,
    /// TCP port passed to the server as `-p` and used in the connection URI.
    /// NOTE(review): `i16` cannot represent ports above 32767.
    pub port: i16,
    /// Superuser name passed to `initdb -U` and used in the connection URI.
    pub user: String,
    /// Superuser password, written to `pwfile` and embedded in the connection URI.
    pub password: String,
    /// Authentication method passed to `initdb -A`.
    pub auth_method: PgAuthMethod,
    /// When `false`, dropping the `PgEmbed` also removes the database and binaries.
    pub persistent: bool,
    /// Time budget allowed after launching `initdb` and `pg_ctl start`.
    pub start_timeout: Duration,
    /// Optional directory of sqlx migrations applied by `PgEmbed::migrate`.
    pub migration_dir: Option<PathBuf>,
}
/// Authentication method configured for the cluster, passed to `initdb -A`.
///
/// Derives the standard value traits so callers can log, copy and compare it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PgAuthMethod {
    /// Cleartext password authentication (`password`).
    Plain,
    /// MD5-hashed password authentication (`md5`).
    MD5,
    /// SCRAM-SHA-256 authentication (`scram-sha-256`).
    ScramSha256,
}
/// Handle to an embedded PostgreSQL installation and the server it controls.
pub struct PgEmbed {
    /// Directories, credentials and timing for the server.
    pub pg_settings: PgSettings,
    /// Settings forwarded to `fetch::fetch_postgres` when downloading binaries.
    pub fetch_settings: fetch::FetchSettings,
    /// Child handle of the most recent `pg_ctl start` invocation, if any.
    pub process: Option<Child>,
    /// Base URI `postgres://{user}:{password}@localhost:{port}` without a database name.
    pub db_uri: String,
}
impl Drop for PgEmbed {
    /// Best-effort teardown: asks `pg_ctl` to stop the server and, for
    /// non-persistent setups, deletes the database and unpacked binaries.
    fn drop(&mut self) {
        // `drop` cannot return errors, so both results are intentionally discarded.
        let _ = self.stop_db();
        if self.pg_settings.persistent {
            return;
        }
        let _ = self.clean();
    }
}
impl PgEmbed {
    /// Creates a handle for an embedded server; touches no files and starts nothing.
    ///
    /// The base connection URI is `postgres://{user}:{password}@localhost:{port}`.
    ///
    /// NOTE(review): user and password are inserted verbatim; characters such as
    /// `@`, `:` or `/` are not percent-encoded and would produce an invalid URI.
    pub fn new(pg_settings: PgSettings, fetch_settings: fetch::FetchSettings) -> Self {
        let db_uri = format!(
            "postgres://{}:{}@localhost:{}",
            pg_settings.user, pg_settings.password, pg_settings.port
        );
        PgEmbed {
            pg_settings,
            fetch_settings,
            process: None,
            db_uri,
        }
    }

    /// Removes the database directory and the unpacked PostgreSQL installation.
    ///
    /// Deletes `database_dir`, then `bin`, `lib` and `share` under
    /// `executables_dir`, then the `pwfile`. Paths that are already missing are
    /// treated as removed, so cleaning after a partial setup still removes the rest.
    ///
    /// # Errors
    /// Returns `PgCleanUpFailure` on the first removal that fails for any reason
    /// other than the path not existing; later entries are not attempted.
    pub fn clean(&self) -> Result<(), PgEmbedError> {
        // Maps a removal result, treating an already-missing path as success.
        fn removed(result: std::io::Result<()>) -> Result<(), PgEmbedError> {
            match result {
                Err(e) if e.kind() != std::io::ErrorKind::NotFound => Err(PgCleanUpFailure(e)),
                _ => Ok(()),
            }
        }

        let exec_dir = &self.pg_settings.executables_dir;
        removed(std::fs::remove_dir_all(&self.pg_settings.database_dir))?;
        for sub_dir in ["bin", "lib", "share"].iter() {
            removed(std::fs::remove_dir_all(exec_dir.join(sub_dir)))?;
        }
        removed(std::fs::remove_file(exec_dir.join("pwfile")))?;
        Ok(())
    }

    /// Downloads and unpacks PostgreSQL, writes the password file, then runs `initdb`.
    ///
    /// # Errors
    /// Propagates the first error from any of the three steps.
    pub async fn setup(&self) -> Result<(), PgEmbedError> {
        self.aquire_postgres().await?;
        self.create_password_file().await?;
        self.init_db().await?;
        Ok(())
    }

    /// Fetches the PostgreSQL distribution and unpacks it into `executables_dir`.
    pub async fn aquire_postgres(&self) -> Result<(), PgEmbedError> {
        let pg_file =
            fetch::fetch_postgres(&self.fetch_settings, &self.pg_settings.executables_dir).await?;
        fetch::unpack_postgres(&pg_file, &self.pg_settings.executables_dir).await
    }

    /// Initializes the data directory with `initdb` unless it already exists.
    ///
    /// Returns `Ok(true)` once `initdb` has exited successfully, and `Ok(false)`
    /// when `database_dir` is already a directory (nothing is run).
    ///
    /// # Errors
    /// Returns `PgInitFailure` if `initdb` cannot be spawned, exits unsuccessfully,
    /// or is still running after `start_timeout`.
    pub async fn init_db(&self) -> Result<bool, PgEmbedError> {
        let settings = &self.pg_settings;
        if settings.database_dir.is_dir() {
            return Ok(false);
        }
        let auth_host = match settings.auth_method {
            PgAuthMethod::Plain => "password",
            PgAuthMethod::MD5 => "md5",
            PgAuthMethod::ScramSha256 => "scram-sha-256",
        };
        // Built as an OsString so non-UTF-8 paths do not panic.
        let mut password_file_arg = std::ffi::OsString::from("--pwfile=");
        password_file_arg.push(settings.executables_dir.join("pwfile"));

        let mut child = Command::new(settings.executables_dir.join("bin").join("initdb"))
            .arg("-A")
            .arg(auth_host)
            .arg("-U")
            .arg(&settings.user)
            .arg("-D")
            .arg(&settings.database_dir)
            .arg(&password_file_arg)
            .spawn()
            .map_err(PgEmbedError::PgInitFailure)?;
        // Wait for initdb to actually finish instead of sleeping a fixed time.
        Self::wait_for_exit(&mut child, settings.start_timeout, "initdb")
            .await
            .map_err(PgEmbedError::PgInitFailure)?;
        Ok(true)
    }

    /// Starts the server with `pg_ctl start -w` and waits for `pg_ctl` to return.
    ///
    /// The `pg_ctl` child handle is stored in `self.process` whether or not the
    /// wait succeeds.
    ///
    /// # Errors
    /// Returns `PgStartFailure` if `pg_ctl` cannot be spawned, exits unsuccessfully,
    /// or is still running after `start_timeout`.
    pub async fn start_db(&mut self) -> Result<(), PgEmbedError> {
        let port_arg = format!("-F -p {}", self.pg_settings.port);
        let mut child = Command::new(self.pg_settings.executables_dir.join("bin").join("pg_ctl"))
            .arg("-o")
            .arg(&port_arg)
            .arg("start")
            .arg("-w")
            .arg("-D")
            .arg(&self.pg_settings.database_dir)
            .spawn()
            .map_err(PgEmbedError::PgStartFailure)?;
        // `-w` makes pg_ctl exit only once the server accepts connections.
        let waited =
            Self::wait_for_exit(&mut child, self.pg_settings.start_timeout, "pg_ctl start").await;
        self.process = Some(child);
        waited.map_err(PgEmbedError::PgStartFailure)
    }

    /// Stops the server with `pg_ctl stop -w`, blocking until `pg_ctl` exits.
    ///
    /// Clears `self.process` on success.
    ///
    /// # Errors
    /// Returns `PgStopFailure` if `pg_ctl` cannot be run or exits unsuccessfully
    /// (for example, when no server is running on `database_dir`).
    pub fn stop_db(&mut self) -> Result<(), PgEmbedError> {
        let status = Command::new(self.pg_settings.executables_dir.join("bin").join("pg_ctl"))
            .arg("stop")
            .arg("-w")
            .arg("-D")
            .arg(&self.pg_settings.database_dir)
            .status()
            .map_err(PgEmbedError::PgStopFailure)?;
        if !status.success() {
            return Err(PgEmbedError::PgStopFailure(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("pg_ctl stop exited with {}", status),
            )));
        }
        self.process = None;
        Ok(())
    }

    /// Writes the superuser password to `executables_dir/pwfile` for `initdb --pwfile`.
    ///
    /// # Errors
    /// Returns `WriteFileError` if the file cannot be created, written or flushed.
    pub async fn create_password_file(&self) -> Result<(), PgEmbedError> {
        let file_path = self.pg_settings.executables_dir.join("pwfile");
        let mut file: tokio::fs::File = tokio::fs::File::create(&file_path)
            .map_err(PgEmbedError::WriteFileError)
            .await?;
        // `write` may write only part of the buffer; `write_all` writes everything.
        file.write_all(self.pg_settings.password.as_bytes())
            .map_err(PgEmbedError::WriteFileError)
            .await?;
        // tokio::fs::File completes writes on a background task; flush so the
        // password is in the file before initdb is launched to read it.
        file.flush().map_err(PgEmbedError::WriteFileError).await?;
        Ok(())
    }

    /// Creates database `db_name` on this server via sqlx.
    #[cfg(any(feature = "rt_tokio_migrate", feature = "rt_async_std_migrate", feature = "rt_actix_migrate"))]
    pub async fn create_database(&self, db_name: &str) -> Result<(), PgEmbedError> {
        Postgres::create_database(&self.full_db_uri(db_name)).await?;
        Ok(())
    }

    /// Drops database `db_name` on this server via sqlx.
    #[cfg(any(feature = "rt_tokio_migrate", feature = "rt_async_std_migrate", feature = "rt_actix_migrate"))]
    pub async fn drop_database(&self, db_name: &str) -> Result<(), PgEmbedError> {
        Postgres::drop_database(&self.full_db_uri(db_name)).await?;
        Ok(())
    }

    /// Reports whether database `db_name` exists on this server.
    #[cfg(any(feature = "rt_tokio_migrate", feature = "rt_async_std_migrate", feature = "rt_actix_migrate"))]
    pub async fn database_exists(&self, db_name: &str) -> Result<bool, PgEmbedError> {
        let result = Postgres::database_exists(&self.full_db_uri(db_name)).await?;
        Ok(result)
    }

    /// Returns the connection URI for `db_name`: `{db_uri}/{db_name}`.
    pub fn full_db_uri(&self, db_name: &str) -> String {
        format!("{}/{}", &self.db_uri, db_name)
    }

    /// Runs the migrations in `migration_dir` (if set) against database `db_name`.
    ///
    /// Does nothing when `migration_dir` is `None`.
    #[cfg(any(feature = "rt_tokio_migrate", feature = "rt_async_std_migrate", feature = "rt_actix_migrate"))]
    pub async fn migrate(&self, db_name: &str) -> Result<(), PgEmbedError> {
        if let Some(migration_dir) = &self.pg_settings.migration_dir {
            let m = Migrator::new(migration_dir.as_path()).await?;
            let pool = sqlx_tokio::postgres::PgPoolOptions::new()
                .connect(&self.full_db_uri(db_name))
                .await?;
            m.run(&pool).await?;
        }
        Ok(())
    }

    /// Waits for `child` to exit without blocking the async executor.
    ///
    /// Polls `Child::try_wait` every 100 ms for at most `timeout`. The child is
    /// not killed on timeout. `program` is used only in error messages.
    ///
    /// # Errors
    /// Returns an `io::Error` if polling fails, if the process exits with a
    /// non-success status, or with kind `TimedOut` if it is still running when
    /// `timeout` elapses.
    async fn wait_for_exit(
        child: &mut Child,
        timeout: Duration,
        program: &str,
    ) -> std::io::Result<()> {
        let poll_interval = Duration::from_millis(100);
        let deadline = std::time::Instant::now() + timeout;
        loop {
            if let Some(status) = child.try_wait()? {
                if status.success() {
                    return Ok(());
                }
                return Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("{} exited with {}", program, status),
                ));
            }
            if std::time::Instant::now() >= deadline {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    format!("{} still running after {:?}", program, timeout),
                ));
            }
            sleep(poll_interval).await;
        }
    }
}