use crate::error::Error::{DatabaseInitializationError, DatabaseStartError, DatabaseStopError};
use crate::error::Result;
use crate::settings::Settings;
use postgresql_archive::{extract, get_archive};
use postgresql_archive::{get_version, Version};
use postgresql_commands::initdb::InitDbBuilder;
use postgresql_commands::pg_ctl::Mode::{Start, Stop};
use postgresql_commands::pg_ctl::PgCtlBuilder;
use postgresql_commands::pg_ctl::ShutdownMode::Fast;
use postgresql_commands::psql::PsqlBuilder;
#[cfg(feature = "tokio")]
use postgresql_commands::AsyncCommandExecutor;
use postgresql_commands::CommandBuilder;
#[cfg(not(feature = "tokio"))]
use postgresql_commands::CommandExecutor;
use std::fs::{remove_dir_all, remove_file};
use std::io::prelude::*;
use std::net::TcpListener;
#[cfg(feature = "bundled")]
use std::ops::Deref;
#[cfg(feature = "bundled")]
use std::str::FromStr;
use tracing::{debug, instrument};
use crate::Error::{CreateDatabaseError, DatabaseExistsError, DropDatabaseError};
#[cfg(feature = "bundled")]
lazy_static::lazy_static! {
    /// Version of the PostgreSQL installation archive bundled into this crate at build time.
    ///
    /// Parsed from the `postgresql.version` file that the build script places in `OUT_DIR`
    /// next to the bundled `postgresql.tar.gz` archive.
    pub(crate) static ref ARCHIVE_VERSION: Version = {
        let version_string = include_str!(concat!(std::env!("OUT_DIR"), "/postgresql.version"));
        // The file is generated at build time, so a parse failure is a build bug, not a
        // recoverable runtime condition.
        let version = Version::from_str(version_string)
            .expect("bundled postgresql.version must contain a valid PostgreSQL version");
        debug!("Bundled installation archive version {version}");
        version
    };
}
/// Raw bytes of the PostgreSQL installation archive embedded at build time.
///
/// The archive's version is recorded alongside it in `postgresql.version` (see
/// `ARCHIVE_VERSION`); installation only uses these bytes when that version matches.
#[cfg(feature = "bundled")]
pub(crate) const ARCHIVE: &[u8] =
    include_bytes!(concat!(std::env!("OUT_DIR"), "/postgresql.tar.gz"));
/// Lifecycle state of a [`PostgreSQL`] instance, as reported by `PostgreSQL::status`.
///
/// The states are checked in order: running first, then initialized, then installed.
/// The first condition that holds determines the state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// The binaries for the (fully specified) version are not present in the installation
    /// directory.
    NotInstalled,
    /// The binaries are installed but the data directory has not been initialized.
    Installed,
    /// The server is considered running: a `postmaster.pid` file exists in the data directory.
    Started,
    /// The data directory is initialized (`postgresql.conf` exists) but the server is not running.
    Stopped,
}
/// An embedded PostgreSQL server.
///
/// Pairs the version to run with the [`Settings`] that describe the installation and data
/// directories, connection parameters and credentials.
///
/// When dropped, a running server is stopped. If the settings are temporary, the data
/// directory and password file are also removed.
#[derive(Debug, Clone)]
pub struct PostgreSQL {
    /// Version to install and run. It may be partially specified until installation resolves it.
    version: Version,
    /// Installation, data directory, connection and credential settings.
    settings: Settings,
}
impl PostgreSQL {
    /// Create a new instance for `version` using `settings`.
    ///
    /// If `version` has both a minor and a release component, the version string is appended
    /// to `settings.installation_dir` (unless the directory already ends with it). This gives
    /// each fully specified version its own installation directory. A partially specified
    /// version is left as-is here and is resolved during installation by [`PostgreSQL::setup`].
    pub fn new(version: Version, settings: Settings) -> Self {
        let mut postgresql = PostgreSQL { version, settings };
        if version.minor.is_some() && version.release.is_some() {
            let path = &postgresql.settings.installation_dir;
            let version_string = version.to_string();
            if !path.ends_with(&version_string) {
                postgresql.settings.installation_dir =
                    postgresql.settings.installation_dir.join(version_string);
            }
        }
        postgresql
    }

    /// Return the version used by [`Default`].
    ///
    /// This is the bundled archive version when the `bundled` feature is enabled, and
    /// `postgresql_archive::LATEST` otherwise.
    pub fn default_version() -> Version {
        #[cfg(feature = "bundled")]
        {
            *ARCHIVE_VERSION
        }
        #[cfg(not(feature = "bundled"))]
        {
            postgresql_archive::LATEST
        }
    }

    /// Report the current [`Status`].
    ///
    /// Checks, in order, whether the server is running, whether the data directory is
    /// initialized, and whether the binaries are installed.
    #[instrument(level = "debug")]
    pub fn status(&self) -> Status {
        if self.is_running() {
            Status::Started
        } else if self.is_initialized() {
            Status::Stopped
        } else if self.is_installed() {
            Status::Installed
        } else {
            Status::NotInstalled
        }
    }

    /// Return the PostgreSQL version. It is fully resolved once installation has run.
    pub fn version(&self) -> &Version {
        &self.version
    }

    /// Return the settings.
    ///
    /// If the server was started with port `0`, this reflects the port that was actually
    /// chosen.
    pub fn settings(&self) -> &Settings {
        &self.settings
    }

    /// Whether binaries for a fully specified version exist in the installation directory.
    ///
    /// A version missing its minor or release component is never considered installed.
    fn is_installed(&self) -> bool {
        if self.version.minor.is_none() || self.version.release.is_none() {
            return false;
        }
        let path = &self.settings.installation_dir;
        path.ends_with(self.version.to_string()) && path.exists()
    }

    /// Whether the data directory has been initialized, i.e. it contains `postgresql.conf`.
    fn is_initialized(&self) -> bool {
        self.settings.data_dir.join("postgresql.conf").exists()
    }

    /// Whether the server appears to be running, i.e. `postmaster.pid` exists in the data
    /// directory.
    ///
    /// Only the file's existence is checked, so a leftover pid file is also reported as running.
    fn is_running(&self) -> bool {
        let pid_file = self.settings.data_dir.join("postmaster.pid");
        pid_file.exists()
    }

    /// Install the binaries and initialize the data directory, skipping each step that has
    /// already been done.
    ///
    /// # Errors
    ///
    /// Returns an error if resolving, fetching or extracting the archive fails, if the
    /// password file cannot be written, or if `initdb` fails.
    #[instrument]
    pub async fn setup(&mut self) -> Result<()> {
        if !self.is_installed() {
            self.install().await?;
        }
        if !self.is_initialized() {
            self.initialize().await?;
        }
        Ok(())
    }

    /// Install the PostgreSQL binaries into the installation directory.
    ///
    /// A partially specified version is first resolved to a concrete one with `get_version`,
    /// and that version is appended to the installation directory. Nothing is fetched if the
    /// directory already exists. With the `bundled` feature, the embedded archive is used when
    /// its version matches; otherwise the archive is obtained with `get_archive`.
    #[instrument]
    async fn install(&mut self) -> Result<()> {
        debug!("Starting installation process for version {}", self.version);
        if self.version.minor.is_none() || self.version.release.is_none() {
            let version = get_version(&self.version).await?;
            self.version = version;
            self.settings.installation_dir = self
                .settings
                .installation_dir
                .join(self.version.to_string());
        }
        if self.settings.installation_dir.exists() {
            debug!("Installation directory already exists");
            return Ok(());
        }
        #[cfg(feature = "bundled")]
        let (version, bytes) = if ARCHIVE_VERSION.deref() == &self.version {
            debug!("Using bundled installation archive");
            (self.version, bytes::Bytes::copy_from_slice(ARCHIVE))
        } else {
            get_archive(&self.version).await?
        };
        #[cfg(not(feature = "bundled"))]
        let (version, bytes) = { get_archive(&self.version).await? };
        self.version = version;
        extract(&bytes, &self.settings.installation_dir).await?;
        debug!(
            "Installed PostgreSQL version {} to {}",
            self.version,
            self.settings.installation_dir.to_string_lossy()
        );
        Ok(())
    }

    /// Initialize the data directory.
    ///
    /// Writes the password file if it is missing, then runs `initdb` with password
    /// authentication and UTF8 encoding.
    ///
    /// # Errors
    ///
    /// I/O errors from writing the password file are propagated. A failing `initdb` is
    /// reported as `DatabaseInitializationError`.
    #[instrument]
    async fn initialize(&mut self) -> Result<()> {
        if !self.settings.password_file.exists() {
            // NOTE(review): the file is created with the process's default permissions;
            // consider restricting it to the owner since it holds the superuser password.
            let mut file = std::fs::File::create(&self.settings.password_file)?;
            file.write_all(self.settings.password.as_bytes())?;
        }
        debug!(
            "Initializing database {}",
            self.settings.data_dir.to_string_lossy()
        );
        let initdb = InitDbBuilder::from(&self.settings)
            .pgdata(&self.settings.data_dir)
            .auth("password")
            .pwfile(&self.settings.password_file)
            .encoding("UTF8");
        match self.execute_command(initdb).await {
            Ok((_stdout, _stderr)) => {
                debug!(
                    "Initialized database {}",
                    self.settings.data_dir.to_string_lossy()
                );
                Ok(())
            }
            Err(error) => Err(DatabaseInitializationError(error.into())),
        }
    }

    /// Start the server with `pg_ctl start` and wait for it to come up.
    ///
    /// If `settings.port` is `0`, a free port is chosen first and stored in the settings.
    /// Server start-up output is logged to `start.log` in the data directory.
    ///
    /// # Errors
    ///
    /// Returns an error if a free port cannot be obtained. A failing `pg_ctl` is reported as
    /// `DatabaseStartError`.
    #[instrument]
    pub async fn start(&mut self) -> Result<()> {
        if self.settings.port == 0 {
            // Ask the OS for an ephemeral port. The listener is dropped at the end of this
            // block, which frees the port for the server. Another process could grab it in
            // the meantime.
            let listener = TcpListener::bind(("0.0.0.0", 0))?;
            self.settings.port = listener.local_addr()?.port();
        }
        debug!(
            "Starting database {} on port {}",
            self.settings.data_dir.to_string_lossy(),
            self.settings.port
        );
        let start_log = self.settings.data_dir.join("start.log");
        // `-F` is passed to the server (per the `postgres` docs it disables fsync);
        // `-p` sets the port.
        let options = format!("-F -p {}", self.settings.port);
        let pg_ctl = PgCtlBuilder::from(&self.settings)
            .mode(Start)
            .pgdata(&self.settings.data_dir)
            .log(start_log)
            .options(options)
            .wait();
        match self.execute_command(pg_ctl).await {
            Ok((_stdout, _stderr)) => {
                debug!(
                    "Started database {} on port {}",
                    self.settings.data_dir.to_string_lossy(),
                    self.settings.port
                );
                Ok(())
            }
            Err(error) => Err(DatabaseStartError(error.into())),
        }
    }

    /// Stop the server with `pg_ctl stop` in fast shutdown mode and wait for it to exit.
    ///
    /// # Errors
    ///
    /// A failing `pg_ctl` is reported as `DatabaseStopError`.
    #[instrument]
    pub async fn stop(&self) -> Result<()> {
        debug!(
            "Stopping database {}",
            self.settings.data_dir.to_string_lossy()
        );
        let pg_ctl = PgCtlBuilder::from(&self.settings)
            .mode(Stop)
            .pgdata(&self.settings.data_dir)
            .shutdown_mode(Fast)
            .wait();
        match self.execute_command(pg_ctl).await {
            Ok((_stdout, _stderr)) => {
                debug!(
                    "Stopped database {}",
                    self.settings.data_dir.to_string_lossy()
                );
                Ok(())
            }
            Err(error) => Err(DatabaseStopError(error.into())),
        }
    }

    /// Create a database named exactly `database_name`.
    ///
    /// The name is emitted as a quoted identifier with embedded `"` doubled, so any name,
    /// including one containing quotes, is used verbatim and cannot alter the statement.
    ///
    /// # Errors
    ///
    /// A failing `psql` invocation is reported as `CreateDatabaseError`.
    #[instrument(skip(database_name))]
    pub async fn create_database<S: AsRef<str>>(&self, database_name: S) -> Result<()> {
        debug!(
            "Creating database {} for {}:{}",
            database_name.as_ref(),
            self.settings.host,
            self.settings.port
        );
        let psql = PsqlBuilder::from(&self.settings)
            .command(format!(
                "CREATE DATABASE {}",
                quote_identifier(database_name.as_ref())
            ))
            .no_psqlrc()
            .no_align()
            .tuples_only();
        match self.execute_command(psql).await {
            Ok((_stdout, _stderr)) => {
                debug!(
                    "Created database {} for {}:{}",
                    database_name.as_ref(),
                    self.settings.host,
                    self.settings.port
                );
                Ok(())
            }
            Err(error) => Err(CreateDatabaseError(error.into())),
        }
    }

    /// Check whether a database named exactly `database_name` exists.
    ///
    /// The check queries `pg_database`. The name is passed as an escaped string literal, so
    /// quotes and backslashes in it are matched literally.
    ///
    /// Returns `Ok(true)` only when `psql` prints `1`.
    ///
    /// # Errors
    ///
    /// A failing `psql` invocation is reported as `DatabaseExistsError`.
    #[instrument(skip(database_name))]
    pub async fn database_exists<S: AsRef<str>>(&self, database_name: S) -> Result<bool> {
        debug!(
            "Checking if database {} exists for {}:{}",
            database_name.as_ref(),
            self.settings.host,
            self.settings.port
        );
        let psql = PsqlBuilder::new()
            .program_dir(self.settings.binary_dir())
            .command(format!(
                "SELECT 1 FROM pg_database WHERE datname={}",
                quote_literal(database_name.as_ref())
            ))
            .host(&self.settings.host)
            .port(self.settings.port)
            .username(&self.settings.username)
            .pg_password(&self.settings.password)
            .no_psqlrc()
            .no_align()
            .tuples_only();
        match self.execute_command(psql).await {
            Ok((stdout, _stderr)) => match stdout.trim() {
                "1" => Ok(true),
                _ => Ok(false),
            },
            Err(error) => Err(DatabaseExistsError(error.into())),
        }
    }

    /// Drop the database named exactly `database_name`, if it exists.
    ///
    /// The name is emitted as a quoted identifier with embedded `"` doubled.
    ///
    /// # Errors
    ///
    /// A failing `psql` invocation is reported as `DropDatabaseError`.
    #[instrument(skip(database_name))]
    pub async fn drop_database<S: AsRef<str>>(&self, database_name: S) -> Result<()> {
        debug!(
            "Dropping database {} for {}:{}",
            database_name.as_ref(),
            self.settings.host,
            self.settings.port
        );
        let psql = PsqlBuilder::new()
            .program_dir(self.settings.binary_dir())
            .command(format!(
                "DROP DATABASE IF EXISTS {}",
                quote_identifier(database_name.as_ref())
            ))
            .host(&self.settings.host)
            .port(self.settings.port)
            .username(&self.settings.username)
            .pg_password(&self.settings.password)
            .no_psqlrc()
            .no_align()
            .tuples_only();
        match self.execute_command(psql).await {
            Ok((_stdout, _stderr)) => {
                debug!(
                    "Dropped database {} for {}:{}",
                    database_name.as_ref(),
                    self.settings.host,
                    self.settings.port
                );
                Ok(())
            }
            Err(error) => Err(DropDatabaseError(error.into())),
        }
    }

    /// Build and run a command synchronously, returning its `(stdout, stderr)`.
    ///
    /// Without the `tokio` feature this blocks the calling thread even though it is `async`.
    #[cfg(not(feature = "tokio"))]
    async fn execute_command<B: CommandBuilder>(
        &self,
        command_builder: B,
    ) -> postgresql_commands::Result<(String, String)> {
        let mut command = command_builder.build();
        command.execute()
    }

    /// Build and run a command as a tokio process bounded by `settings.timeout`, returning
    /// its `(stdout, stderr)`.
    #[cfg(feature = "tokio")]
    #[instrument(level = "debug")]
    async fn execute_command<B: CommandBuilder>(
        &self,
        command_builder: B,
    ) -> postgresql_commands::Result<(String, String)> {
        let mut command = command_builder.build_tokio();
        command.execute(self.settings.timeout).await
    }
}

/// Quote `identifier` as a PostgreSQL quoted identifier.
///
/// Wraps it in double quotes and doubles every embedded `"`, so arbitrary names cannot end
/// the identifier early and inject additional SQL.
fn quote_identifier(identifier: &str) -> String {
    format!("\"{}\"", identifier.replace('"', "\"\""))
}

/// Quote `value` as a PostgreSQL escape string constant (`E'...'`).
///
/// Backslashes are doubled and single quotes are doubled. Using the `E` form makes the
/// escaping independent of the server's `standard_conforming_strings` setting.
fn quote_literal(value: &str) -> String {
    let escaped = value.replace('\\', "\\\\").replace('\'', "''");
    format!("E'{escaped}'")
}
impl Default for PostgreSQL {
    /// Build an instance from `PostgreSQL::default_version` and `Settings::default`.
    fn default() -> Self {
        Self::new(Self::default_version(), Settings::default())
    }
}
impl Drop for PostgreSQL {
    /// Best-effort cleanup.
    ///
    /// Stops a running server with a fast shutdown. If the settings are temporary, also
    /// removes the data directory and password file. All failures are ignored because `drop`
    /// cannot report errors.
    fn drop(&mut self) {
        if matches!(self.status(), Status::Started) {
            // `drop` cannot `.await`, so pg_ctl is run as a blocking process here.
            let _ = PgCtlBuilder::from(&self.settings)
                .mode(Stop)
                .pgdata(&self.settings.data_dir)
                .shutdown_mode(Fast)
                .wait()
                .build()
                .output();
        }
        if !self.settings.temporary {
            return;
        }
        let _ = remove_dir_all(&self.settings.data_dir);
        let _ = remove_file(&self.settings.password_file);
    }
}
#[cfg(test)]
mod tests {
    /// The bundled archive version read at build time must render to a non-empty string.
    #[cfg(feature = "bundled")]
    #[test]
    fn test_archive_version() {
        let rendered = super::ARCHIVE_VERSION.to_string();
        assert!(!rendered.is_empty());
    }
}