use crate::error::Error::{DatabaseInitializationError, DatabaseStartError, DatabaseStopError};
use crate::error::Result;
use crate::settings::{BOOTSTRAP_DATABASE, BOOTSTRAP_SUPERUSER, Settings};
use postgresql_archive::extract;
#[cfg(not(feature = "bundled"))]
use postgresql_archive::get_archive;
use postgresql_archive::get_version;
use postgresql_archive::{ExactVersion, ExactVersionReq};
#[cfg(feature = "tokio")]
use postgresql_commands::AsyncCommandExecutor;
use postgresql_commands::CommandBuilder;
#[cfg(not(feature = "tokio"))]
use postgresql_commands::CommandExecutor;
use postgresql_commands::initdb::InitDbBuilder;
use postgresql_commands::pg_ctl::Mode::{Start, Stop};
use postgresql_commands::pg_ctl::PgCtlBuilder;
use postgresql_commands::pg_ctl::ShutdownMode::Fast;
use semver::Version;
use sqlx::{PgPool, Row};
use std::fs::{read_dir, remove_dir_all, remove_file};
use std::io::prelude::*;
use std::net::TcpListener;
use std::path::PathBuf;
use tracing::{debug, instrument};
use crate::Error::{CreateDatabaseError, DatabaseExistsError, DropDatabaseError};
/// Name of the environment variable that `start` passes to `pg_ctl` with an empty value.
///
/// NOTE(review): presumably cleared so that a `PGDATABASE` inherited from the parent
/// environment cannot influence the server start — confirm.
const PGDATABASE: &str = "PGDATABASE";

/// Lifecycle status of a PostgreSQL instance, as reported by `PostgreSQL::status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Status {
    /// No matching PostgreSQL installation directory was found.
    NotInstalled,
    /// A matching installation exists but the data directory is not initialized.
    Installed,
    /// A `postmaster.pid` file exists in the data directory.
    Started,
    /// The data directory is initialized (`postgresql.conf` exists) but no pid file exists.
    Stopped,
}
/// An embedded PostgreSQL server instance driven by its [`Settings`].
///
/// Dropping an instance stops the server if it is running and, for temporary settings,
/// removes the data directory, password file and socket directory.
///
/// Note: the running state is detected from files in the data directory, so dropping a
/// *clone* of an instance also stops the server shared with the original.
#[derive(Debug, Clone)]
pub struct PostgreSQL {
    /// Configuration for installation, initialization and connections.
    settings: Settings,
}
impl PostgreSQL {
    /// Creates a new [`PostgreSQL`] instance from `settings`.
    ///
    /// Unless `settings.trust_installation_dir` is set, an exact configured version is appended
    /// to `settings.installation_dir` (when the directory does not already end with that
    /// version), so each version is installed into its own sub-directory.
    #[must_use]
    pub fn new(settings: Settings) -> Self {
        let mut postgresql = PostgreSQL { settings };
        if !postgresql.settings.trust_installation_dir
            && let Some(version) = postgresql.settings.version.exact_version()
        {
            let path = &postgresql.settings.installation_dir;
            let version_string = version.to_string();
            if !path.ends_with(&version_string) {
                postgresql.settings.installation_dir =
                    postgresql.settings.installation_dir.join(version_string);
            }
        }
        postgresql
    }

    /// Returns the current [`Status`] of this instance, derived from the file system.
    ///
    /// The checks are made in order:
    /// - [`Status::Started`] if `postmaster.pid` exists in the data directory,
    /// - [`Status::Stopped`] if `postgresql.conf` exists there,
    /// - [`Status::Installed`] if a matching installation directory is found,
    /// - [`Status::NotInstalled`] otherwise.
    #[instrument(level = "debug", skip(self))]
    pub fn status(&self) -> Status {
        if self.is_running() {
            Status::Started
        } else if self.is_initialized() {
            Status::Stopped
        } else if self.installed_dir().is_some() {
            Status::Installed
        } else {
            Status::NotInstalled
        }
    }

    /// Returns the settings used by this instance.
    #[must_use]
    pub fn settings(&self) -> &Settings {
        &self.settings
    }

    /// Locates the installation directory for the configured version requirement.
    ///
    /// Returns `installation_dir` itself in either of two cases:
    /// - it is trusted;
    /// - its last path component parses as a version matching the requirement and the
    ///   directory exists.
    ///
    /// Otherwise, the immediate sub-directories of `installation_dir` whose names parse as
    /// matching versions are considered, and the one with the highest version is returned.
    ///
    /// Returns `None` when there is no candidate or `installation_dir` cannot be read.
    fn installed_dir(&self) -> Option<PathBuf> {
        if self.settings.trust_installation_dir {
            return Some(self.settings.installation_dir.clone());
        }

        let path = &self.settings.installation_dir;
        let maybe_path_version = path
            .file_name()
            .and_then(|file_name| Version::parse(&file_name.to_string_lossy()).ok());
        if let Some(path_version) = maybe_path_version
            && self.settings.version.matches(&path_version)
            && path.exists()
        {
            return Some(path.clone());
        }

        read_dir(path)
            .ok()?
            .filter_map(|entry| {
                // Unreadable entries, non-directories and names that are not versions are skipped.
                let entry = entry.ok()?;
                if !entry.file_type().ok()?.is_dir() {
                    return None;
                }
                let version = Version::parse(&entry.file_name().to_string_lossy()).ok()?;
                self.settings
                    .version
                    .matches(&version)
                    .then(|| (version, entry.path()))
            })
            .max_by(|(a, _), (b, _)| a.cmp(b))
            .map(|(_, dir)| dir)
    }

    /// Returns `true` if the data directory contains `postgresql.conf`.
    fn is_initialized(&self) -> bool {
        self.settings.data_dir.join("postgresql.conf").exists()
    }

    /// Returns `true` if the data directory contains `postmaster.pid`.
    ///
    /// NOTE(review): only the pid file is checked, not the process, so a stale pid file left
    /// behind by a crashed server is also reported as running.
    fn is_running(&self) -> bool {
        let pid_file = self.settings.data_dir.join("postmaster.pid");
        pid_file.exists()
    }

    /// Ensures PostgreSQL is installed and the data directory is initialized.
    ///
    /// If a matching installation already exists, its directory is recorded in the settings;
    /// otherwise PostgreSQL is installed.
    ///
    /// # Errors
    ///
    /// Returns an error if the installation or the initialization fails.
    #[instrument(skip(self))]
    pub async fn setup(&mut self) -> Result<()> {
        match self.installed_dir() {
            Some(installed_dir) => {
                self.settings.installation_dir = installed_dir;
            }
            None => {
                self.install().await?;
            }
        }
        if !self.is_initialized() {
            self.initialize().await?;
        }
        Ok(())
    }

    /// Obtains and extracts a PostgreSQL archive into the installation directory.
    ///
    /// The archive is downloaded, or taken from the embedded archive with the `bundled`
    /// feature.
    ///
    /// A non-exact version requirement is first resolved to a concrete version. That version is
    /// pinned in the settings and appended to the installation directory. Nothing is extracted
    /// if the installation directory already exists.
    ///
    /// # Errors
    ///
    /// Returns an error if the version cannot be resolved, or the archive cannot be obtained or
    /// extracted.
    #[instrument(skip(self))]
    async fn install(&mut self) -> Result<()> {
        #[cfg(feature = "bundled")]
        {
            self.settings.version = crate::settings::ARCHIVE_VERSION.clone();
        }
        debug!(
            "Starting installation process for version {}",
            self.settings.version
        );
        if self.settings.version.exact_version().is_none() {
            let version = get_version(&self.settings.releases_url, &self.settings.version).await?;
            self.settings.version = version.exact_version_req()?;
            self.settings.installation_dir =
                self.settings.installation_dir.join(version.to_string());
        }
        if self.settings.installation_dir.exists() {
            debug!("Installation directory already exists");
            return Ok(());
        }
        let url = &self.settings.releases_url;
        #[cfg(feature = "bundled")]
        let bytes = {
            debug!("Using bundled installation archive");
            crate::settings::ARCHIVE.to_vec()
        };
        #[cfg(not(feature = "bundled"))]
        let bytes = {
            let (version, bytes) = get_archive(url, &self.settings.version).await?;
            self.settings.version = version.exact_version_req()?;
            bytes
        };
        extract(url, &bytes, &self.settings.installation_dir).await?;
        debug!(
            "Installed PostgreSQL version {} to {}",
            self.settings.version,
            self.settings.installation_dir.to_string_lossy()
        );
        Ok(())
    }

    /// Runs `initdb` on the data directory.
    ///
    /// The bootstrap superuser, `password` authentication and UTF8 encoding are used. The
    /// password file is created from the configured password first, if it does not exist yet.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the password file cannot be written.
    /// Returns `DatabaseInitializationError` if `initdb` fails.
    #[instrument(skip(self))]
    async fn initialize(&mut self) -> Result<()> {
        if !self.settings.password_file.exists() {
            let mut file = std::fs::File::create(&self.settings.password_file)?;
            file.write_all(self.settings.password.as_bytes())?;
        }
        debug!(
            "Initializing database {}",
            self.settings.data_dir.to_string_lossy()
        );
        let initdb = InitDbBuilder::from(&self.settings)
            .pgdata(&self.settings.data_dir)
            .username(BOOTSTRAP_SUPERUSER)
            .auth("password")
            .pwfile(&self.settings.password_file)
            .encoding("UTF8");
        match self.execute_command(initdb).await {
            Ok((_stdout, _stderr)) => {
                debug!(
                    "Initialized database {}",
                    self.settings.data_dir.to_string_lossy()
                );
                Ok(())
            }
            Err(error) => Err(DatabaseInitializationError(error.to_string())),
        }
    }

    /// Returns `" with socket dir <dir>"` when a socket directory is configured, or an empty
    /// string otherwise. Used to build the start log messages.
    fn socket_dir_suffix(&self) -> String {
        self.settings.socket_dir.as_ref().map_or_else(String::new, |dir| {
            format!(" with socket dir {}", dir.to_string_lossy())
        })
    }

    /// Starts the server with `pg_ctl start` and waits for it.
    ///
    /// If the configured port is `0`, a free port is chosen by binding a temporary listener,
    /// and that port is stored in the settings. On Unix, a configured socket directory is
    /// created if missing. Each entry of `settings.configuration` is passed as a `-c key=value`
    /// server option.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the port lookup or socket directory creation fails.
    /// Returns `DatabaseStartError` if `pg_ctl` fails.
    #[instrument(skip(self))]
    pub async fn start(&mut self) -> Result<()> {
        if self.settings.port == 0 {
            // The listener is dropped at the end of this block, releasing the port for the server.
            let listener = TcpListener::bind(("0.0.0.0", 0))?;
            self.settings.port = listener.local_addr()?.port();
        }
        #[cfg(unix)]
        if let Some(ref socket_dir) = self.settings.socket_dir
            && !socket_dir.exists()
        {
            std::fs::create_dir_all(socket_dir)?;
        }
        debug!(
            "Starting database {} on port {}{}",
            self.settings.data_dir.to_string_lossy(),
            self.settings.port,
            self.socket_dir_suffix()
        );
        let start_log = self.settings.data_dir.join("start.log");
        let mut options = Vec::new();
        options.push(format!("-F -p {}", self.settings.port));
        #[cfg(unix)]
        if let Some(ref socket_dir) = self.settings.socket_dir {
            options.push(format!("-k {}", socket_dir.to_string_lossy()));
        }
        for (key, value) in &self.settings.configuration {
            options.push(format!("-c {key}={value}"));
        }
        let pg_ctl = PgCtlBuilder::from(&self.settings)
            .env(PGDATABASE, "")
            .mode(Start)
            .pgdata(&self.settings.data_dir)
            .log(start_log)
            .options(options.as_slice())
            .wait();
        match self.execute_command(pg_ctl).await {
            Ok((_stdout, _stderr)) => {
                debug!(
                    "Started database {} on port {}{}",
                    self.settings.data_dir.to_string_lossy(),
                    self.settings.port,
                    self.socket_dir_suffix()
                );
                Ok(())
            }
            Err(error) => Err(DatabaseStartError(error.to_string())),
        }
    }

    /// Stops the server with `pg_ctl stop` in fast shutdown mode and waits for it.
    ///
    /// # Errors
    ///
    /// Returns `DatabaseStopError` if `pg_ctl` fails.
    #[instrument(skip(self))]
    pub async fn stop(&self) -> Result<()> {
        debug!(
            "Stopping database {}",
            self.settings.data_dir.to_string_lossy()
        );
        let pg_ctl = PgCtlBuilder::from(&self.settings)
            .mode(Stop)
            .pgdata(&self.settings.data_dir)
            .shutdown_mode(Fast)
            .wait();
        match self.execute_command(pg_ctl).await {
            Ok((_stdout, _stderr)) => {
                debug!(
                    "Stopped database {}",
                    self.settings.data_dir.to_string_lossy()
                );
                Ok(())
            }
            Err(error) => Err(DatabaseStopError(error.to_string())),
        }
    }

    /// Opens a connection pool to the bootstrap database as the bootstrap superuser.
    async fn get_pool(&self) -> Result<PgPool> {
        let mut settings = self.settings.clone();
        settings.username = BOOTSTRAP_SUPERUSER.to_string();
        let database_url = settings.url(BOOTSTRAP_DATABASE);
        let pool = PgPool::connect(database_url.as_str()).await?;
        Ok(pool)
    }

    /// Quotes `identifier` as a PostgreSQL delimited identifier.
    ///
    /// Embedded double quotes are doubled, so a name containing `"` cannot terminate the quoted
    /// identifier early and inject SQL.
    fn quote_identifier(identifier: &str) -> String {
        format!("\"{}\"", identifier.replace('"', "\"\""))
    }

    /// Creates the database `database_name`.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting fails.
    /// Returns `CreateDatabaseError` if the `CREATE DATABASE` statement fails.
    #[instrument(skip(self))]
    pub async fn create_database<S>(&self, database_name: S) -> Result<()>
    where
        S: AsRef<str> + std::fmt::Debug,
    {
        let database_name = database_name.as_ref();
        debug!(
            "Creating database {database_name} for {host}:{port}",
            host = self.settings.host,
            port = self.settings.port
        );
        let pool = self.get_pool().await?;
        let statement = format!("CREATE DATABASE {}", Self::quote_identifier(database_name));
        let result = sqlx::query(statement.as_str()).execute(&pool).await;
        // Close the pool before surfacing any error so it is shut down on every path.
        pool.close().await;
        result.map_err(|error| CreateDatabaseError(error.to_string()))?;
        debug!(
            "Created database {database_name} for {host}:{port}",
            host = self.settings.host,
            port = self.settings.port
        );
        Ok(())
    }

    /// Returns whether a database named `database_name` exists in `pg_database`.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting fails.
    /// Returns `DatabaseExistsError` if the lookup query fails.
    #[instrument(skip(self))]
    pub async fn database_exists<S>(&self, database_name: S) -> Result<bool>
    where
        S: AsRef<str> + std::fmt::Debug,
    {
        let database_name = database_name.as_ref();
        debug!(
            "Checking if database {database_name} exists for {host}:{port}",
            host = self.settings.host,
            port = self.settings.port
        );
        let pool = self.get_pool().await?;
        let result = sqlx::query("SELECT COUNT(*) FROM pg_database WHERE datname = $1")
            .bind(database_name.to_string())
            .fetch_one(&pool)
            .await;
        // Close the pool before surfacing any error so it is shut down on every path.
        pool.close().await;
        let row = result.map_err(|error| DatabaseExistsError(error.to_string()))?;
        let count: i64 = row.get(0);
        Ok(count == 1)
    }

    /// Drops the database `database_name` if it exists.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting fails.
    /// Returns `DropDatabaseError` if the `DROP DATABASE` statement fails.
    #[instrument(skip(self))]
    pub async fn drop_database<S>(&self, database_name: S) -> Result<()>
    where
        S: AsRef<str> + std::fmt::Debug,
    {
        let database_name = database_name.as_ref();
        debug!(
            "Dropping database {database_name} for {host}:{port}",
            host = self.settings.host,
            port = self.settings.port
        );
        let pool = self.get_pool().await?;
        let statement = format!(
            "DROP DATABASE IF EXISTS {}",
            Self::quote_identifier(database_name)
        );
        let result = sqlx::query(statement.as_str()).execute(&pool).await;
        // Close the pool before surfacing any error so it is shut down on every path.
        pool.close().await;
        result.map_err(|error| DropDatabaseError(error.to_string()))?;
        debug!(
            "Dropped database {database_name} for {host}:{port}",
            host = self.settings.host,
            port = self.settings.port
        );
        Ok(())
    }

    /// Builds and runs `command_builder` synchronously, returning its stdout and stderr.
    #[cfg(not(feature = "tokio"))]
    #[instrument(level = "debug", skip(self, command_builder), fields(program = ?command_builder.get_program()))]
    async fn execute_command<B: CommandBuilder>(
        &self,
        command_builder: B,
    ) -> postgresql_commands::Result<(String, String)> {
        let mut command = command_builder.build();
        command.execute()
    }

    /// Builds and runs `command_builder` as a tokio command, bounded by `settings.timeout`,
    /// returning its stdout and stderr.
    #[cfg(feature = "tokio")]
    #[instrument(level = "debug", skip(self, command_builder), fields(program = ?command_builder.get_program()))]
    async fn execute_command<B: CommandBuilder>(
        &self,
        command_builder: B,
    ) -> postgresql_commands::Result<(String, String)> {
        let mut command = command_builder.build_tokio();
        command.execute(self.settings.timeout).await
    }
}
impl Default for PostgreSQL {
fn default() -> Self {
Self::new(Settings::default())
}
}
impl Drop for PostgreSQL {
    /// Best-effort cleanup when the instance goes out of scope.
    ///
    /// If the server is running, it is stopped with `pg_ctl stop` in fast mode, waiting for
    /// completion. For temporary settings, the data directory, password file and optional
    /// socket directory are then removed.
    ///
    /// All errors are ignored because `drop` cannot report them.
    fn drop(&mut self) {
        if matches!(self.status(), Status::Started) {
            let mut stop_command = PgCtlBuilder::from(&self.settings)
                .mode(Stop)
                .pgdata(&self.settings.data_dir)
                .shutdown_mode(Fast)
                .wait()
                .build();
            let _ = stop_command.output();
        }

        if !self.settings.temporary {
            return;
        }

        let _ = remove_dir_all(&self.settings.data_dir);
        let _ = remove_file(&self.settings.password_file);
        if let Some(socket_dir) = self.settings.socket_dir.as_ref() {
            let _ = remove_dir_all(socket_dir);
        }
    }
}