use std::io::BufRead;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use log::{error, info};
use tokio::sync::Mutex;
#[cfg(feature = "rt_tokio_migrate")]
use sqlx::migrate::{MigrateDatabase, Migrator};
#[cfg(feature = "rt_tokio_migrate")]
use sqlx::postgres::PgPoolOptions;
#[cfg(feature = "rt_tokio_migrate")]
use sqlx::Postgres;
use crate::command_executor::AsyncCommand;
use crate::pg_access::PgAccess;
use crate::pg_commands::PgCommand;
use crate::pg_enums::{PgAuthMethod, PgServerStatus};
use crate::pg_errors::Error;
use crate::pg_errors::Result;
use crate::pg_fetch;
/// User-supplied configuration for an embedded PostgreSQL instance.
pub struct PgSettings {
/// Directory handed to `PgAccess::new` and to the synchronous stop command.
pub database_dir: PathBuf,
/// Port used in the connection URI and passed to the start command.
pub port: u16,
/// User name used in the connection URI and passed to the init command.
pub user: String,
/// Password used in the connection URI; its bytes are also written through
/// `PgAccess::create_password_file` during `PgEmbed::setup`.
pub password: String,
/// Authentication method passed to the init command.
pub auth_method: PgAuthMethod,
/// When `false`, dropping the `PgEmbed` calls `PgAccess::clean`.
pub persistent: bool,
/// Passed to every command executor's `execute` call
/// (presumably `None` means "no timeout" — confirm against the executor).
pub timeout: Option<Duration>,
/// Directory read by `PgEmbed::migrate`; when `None`, `migrate` does nothing.
pub migration_dir: Option<PathBuf>,
}
/// Handle for one embedded PostgreSQL instance.
///
/// Dropping the value attempts a synchronous stop (unless `shutting_down` is
/// already set) and, for non-persistent settings, calls `PgAccess::clean`.
pub struct PgEmbed {
/// Settings this instance was created with.
pub pg_settings: PgSettings,
/// Fetch settings that were passed to `PgAccess::new`.
pub fetch_settings: pg_fetch::PgFetchSettings,
/// Base connection URI without a database name, built in `new` from the
/// user, password and port; see `full_db_uri` for the per-database form.
pub db_uri: String,
/// Current server status; updated by `setup`, `init_db`, `start_db` and `stop_db`.
pub server_status: Arc<Mutex<PgServerStatus>>,
/// Set to `true` by `stop_db`/`stop_db_sync` and back to `false` by `start_db`;
/// `Drop` skips the stop command when it is `true`.
pub shutting_down: bool,
/// Accessor for the executables, data directory and password file used by the commands.
pub pg_access: PgAccess,
}
impl Drop for PgEmbed {
    /// Best-effort teardown: issues a synchronous stop unless one was already
    /// requested, then cleans up through `pg_access` for non-persistent
    /// instances. Failures are only logged as warnings, never propagated.
    fn drop(&mut self) {
        let stop_already_requested = self.shutting_down;
        if !stop_already_requested {
            match self.stop_db_sync() {
                Ok(_) => {}
                Err(e) => log::warn!("pg_ctl stop failed during drop: {e}"),
            }
        }

        // Persistent instances keep whatever `clean` would remove.
        if self.pg_settings.persistent {
            return;
        }
        match self.pg_access.clean() {
            Ok(_) => {}
            Err(e) => log::warn!("cleanup failed during drop: {e}"),
        }
    }
}
impl PgEmbed {
/// Creates a new `PgEmbed` from the given settings.
///
/// Builds the base connection URI `postgres://user:password@localhost:port`
/// and constructs the `PgAccess` for `pg_settings.database_dir`. The status
/// starts as `PgServerStatus::Uninitialized` and `shutting_down` as `false`.
///
/// The user name and password are percent-encoded before being placed in the
/// URI, so credentials containing URI delimiters such as `@`, `:`, `/`, `#`
/// or `?` no longer yield a malformed URI. Credentials made only of ASCII
/// letters, digits and `-._~` are emitted unchanged.
///
/// # Errors
///
/// Returns the error produced by `PgAccess::new`.
pub async fn new(
    pg_settings: PgSettings,
    fetch_settings: pg_fetch::PgFetchSettings,
) -> Result<Self> {
    // Percent-encodes every byte outside the RFC 3986 "unreserved" set so the
    // value can sit safely in the userinfo part of a URI.
    fn encode_userinfo(raw: &str) -> String {
        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            match byte {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                    encoded.push(char::from(byte));
                }
                _ => encoded.push_str(&format!("%{:02X}", byte)),
            }
        }
        encoded
    }

    let db_uri = format!(
        "postgres://{}:{}@localhost:{}",
        encode_userinfo(&pg_settings.user),
        encode_userinfo(&pg_settings.password),
        pg_settings.port
    );
    let pg_access = PgAccess::new(&fetch_settings, &pg_settings.database_dir).await?;
    Ok(PgEmbed {
        pg_settings,
        fetch_settings,
        db_uri,
        server_status: Arc::new(Mutex::new(PgServerStatus::Uninitialized)),
        shutting_down: false,
        pg_access,
    })
}
/// Prepares the instance for use.
///
/// Calls `maybe_acquire_postgres` on `pg_access`, writes the configured
/// password through `create_password_file`, and then either marks the server
/// as `Initialized` (database files already exist) or runs `init_db`.
///
/// # Errors
///
/// Propagates the first error returned by any of those steps.
pub async fn setup(&mut self) -> Result<()> {
    self.pg_access.maybe_acquire_postgres().await?;

    let password = self.pg_settings.password.as_bytes();
    self.pg_access.create_password_file(password).await?;

    let already_initialized = self.pg_access.db_files_exist().await?;
    if !already_initialized {
        return self.init_db().await;
    }

    *self.server_status.lock().await = PgServerStatus::Initialized;
    Ok(())
}
/// Installs the extension found in `extension_dir` by delegating to
/// `PgAccess::install_extension`, returning its result unchanged.
pub async fn install_extension(&self, extension_dir: &Path) -> Result<()> {
    let installation = self.pg_access.install_extension(extension_dir);
    installation.await
}
/// Runs the init-db command for this instance.
///
/// The status is set to `Initializing` first; once the executor finishes, the
/// status it returned is stored as the new server status.
///
/// # Errors
///
/// Propagates errors from building or executing the command. In that case the
/// status is left at `Initializing`.
pub async fn init_db(&mut self) -> Result<()> {
    *self.server_status.lock().await = PgServerStatus::Initializing;

    let access = &self.pg_access;
    let settings = &self.pg_settings;
    let mut init_executor = PgCommand::init_db_executor(
        &access.init_db_exe,
        &access.database_dir,
        &access.pw_file_path,
        &settings.user,
        &settings.auth_method,
    )?;
    let outcome = init_executor.execute(settings.timeout).await?;

    *self.server_status.lock().await = outcome;
    Ok(())
}
/// Runs the start command for this instance.
///
/// Sets the status to `Starting`, clears `shutting_down`, executes the command
/// with the configured timeout and stores the status the executor returned.
///
/// # Errors
///
/// Propagates errors from building or executing the command. In that case the
/// status is left at `Starting`.
pub async fn start_db(&mut self) -> Result<()> {
    *self.server_status.lock().await = PgServerStatus::Starting;
    self.shutting_down = false;

    let access = &self.pg_access;
    let settings = &self.pg_settings;
    let mut start_executor = PgCommand::start_db_executor(
        &access.pg_ctl_exe,
        &access.database_dir,
        &settings.port,
    )?;
    let outcome = start_executor.execute(settings.timeout).await?;

    *self.server_status.lock().await = outcome;
    Ok(())
}
/// Runs the stop command for this instance.
///
/// Sets the status to `Stopping`, marks the instance as shutting down (so that
/// `Drop` does not issue a second stop), executes the command with the
/// configured timeout and stores the status the executor returned.
///
/// # Errors
///
/// Propagates errors from building or executing the command. In that case the
/// status is left at `Stopping` and `shutting_down` stays `true`.
pub async fn stop_db(&mut self) -> Result<()> {
    *self.server_status.lock().await = PgServerStatus::Stopping;
    self.shutting_down = true;

    let access = &self.pg_access;
    let mut stop_executor =
        PgCommand::stop_db_executor(&access.pg_ctl_exe, &access.database_dir)?;
    let outcome = stop_executor.execute(self.pg_settings.timeout).await?;

    *self.server_status.lock().await = outcome;
    Ok(())
}
pub fn stop_db_sync(&mut self) -> Result<()> {
self.shutting_down = true;
let mut stop_db_command = self
.pg_access
.stop_db_command_sync(&self.pg_settings.database_dir);
let process = stop_db_command
.get_mut()
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| Error::PgError(e.to_string(), "".to_string()))?;
self.handle_process_io_sync(process)
}
pub fn handle_process_io_sync(&self, mut process: std::process::Child) -> Result<()> {
if let Some(stdout) = process.stdout.take() {
std::io::BufReader::new(stdout)
.lines()
.for_each(|line| {
if let Ok(l) = line {
info!("{}", l);
}
});
}
if let Some(stderr) = process.stderr.take() {
std::io::BufReader::new(stderr)
.lines()
.for_each(|line| {
if let Ok(l) = line {
error!("{}", l);
}
});
}
Ok(())
}
/// Creates the database `db_name` on this server.
///
/// # Errors
///
/// Any failure reported by sqlx is returned as `Error::PgTaskJoinError`
/// carrying the error's message.
#[cfg(feature = "rt_tokio_migrate")]
pub async fn create_database(&self, db_name: &str) -> Result<()> {
    let uri = self.full_db_uri(db_name);
    match Postgres::create_database(&uri).await {
        Ok(_) => Ok(()),
        Err(e) => Err(Error::PgTaskJoinError(e.to_string())),
    }
}
/// Drops the database `db_name` from this server.
///
/// # Errors
///
/// Any failure reported by sqlx is returned as `Error::PgTaskJoinError`
/// carrying the error's message.
#[cfg(feature = "rt_tokio_migrate")]
pub async fn drop_database(&self, db_name: &str) -> Result<()> {
    let uri = self.full_db_uri(db_name);
    let dropped = Postgres::drop_database(&uri).await;
    dropped
        .map(|_| ())
        .map_err(|e| Error::PgTaskJoinError(e.to_string()))
}
/// Reports whether the database `db_name` exists on this server.
///
/// # Errors
///
/// Any failure reported by sqlx is returned as `Error::PgTaskJoinError`
/// carrying the error's message.
#[cfg(feature = "rt_tokio_migrate")]
pub async fn database_exists(&self, db_name: &str) -> Result<bool> {
    let uri = self.full_db_uri(db_name);
    let exists = Postgres::database_exists(&uri)
        .await
        .map_err(|e| Error::PgTaskJoinError(e.to_string()))?;
    Ok(exists)
}
/// Returns the connection URI for `db_name`: the base `db_uri`, a `/`, and the
/// database name appended verbatim.
pub fn full_db_uri(&self, db_name: &str) -> String {
    let base = self.db_uri.as_str();
    let mut uri = String::with_capacity(base.len() + 1 + db_name.len());
    uri.push_str(base);
    uri.push('/');
    uri.push_str(db_name);
    uri
}
/// Runs the migrations from `pg_settings.migration_dir` against `db_name`.
///
/// When no migration directory is configured this is a no-op that returns
/// `Ok(())`.
///
/// # Errors
///
/// * `Error::MigrationError` when the migrator cannot be created or a
///   migration fails to run.
/// * `Error::SqlQueryError` when connecting to the database fails.
#[cfg(feature = "rt_tokio_migrate")]
pub async fn migrate(&self, db_name: &str) -> Result<()> {
    let migration_dir = match self.pg_settings.migration_dir.as_deref() {
        Some(dir) => dir,
        None => return Ok(()),
    };

    let migrator = Migrator::new(migration_dir)
        .await
        .map_err(|e| Error::MigrationError(e.to_string()))?;

    let target_uri = self.full_db_uri(db_name);
    let pool = PgPoolOptions::new()
        .connect(&target_uri)
        .await
        .map_err(|e| Error::SqlQueryError(e.to_string()))?;

    migrator
        .run(&pool)
        .await
        .map_err(|e| Error::MigrationError(e.to_string()))?;
    Ok(())
}
}