use std::cell::Cell;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, LazyLock};
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;
use crate::pg_enums::{OperationSystem, PgAcquisitionStatus};
use crate::pg_errors::Error;
use crate::pg_fetch::PgFetchSettings;
use crate::pg_types::PgCommandSync;
use crate::pg_unpack;
use crate::pg_errors::Result;
/// Process-wide record of binary acquisition state, keyed by the cache
/// directory an acquisition targets.
///
/// Guarded by an async mutex: `maybe_acquire_postgres` locks it while it
/// works and `acquisition_status` locks it to read an entry. The map starts
/// with room for five entries.
static ACQUIRED_PG_BINS: LazyLock<Arc<Mutex<HashMap<PathBuf, PgAcquisitionStatus>>>> =
LazyLock::new(|| Arc::new(Mutex::new(HashMap::with_capacity(5))));
/// Name of the directory, placed under the user cache directory, that holds
/// everything this module caches.
const PG_EMBED_CACHE_DIR_NAME: &str = "pg-embed";
/// Name of the file looked up inside a database directory to decide whether
/// database files exist.
const PG_VERSION_FILE_NAME: &str = "PG_VERSION";
/// File-system locations and fetch settings used to acquire the PostgreSQL
/// binaries, set up a database directory and clean it up again.
///
/// The field descriptions below state how `PgAccess::new` fills them in.
pub struct PgAccess {
/// Cache directory the binaries archive is unpacked into.
pub cache_dir: PathBuf,
/// Directory holding the database files.
pub database_dir: PathBuf,
/// Path `bin/pg_ctl` inside `cache_dir`.
pub pg_ctl_exe: PathBuf,
/// Path `bin/initdb` inside `cache_dir`; its existence is what
/// `pg_executables_cached` checks.
pub init_db_exe: PathBuf,
/// `database_dir` with its extension set to `pwfile`; written by
/// `create_password_file`.
pub pw_file_path: PathBuf,
/// Download target for the binaries archive, `<platform>-<version>.zip`
/// inside `cache_dir`.
pub zip_file_path: PathBuf,
/// Path of the `PG_VERSION` file inside `database_dir`.
pg_version_file: PathBuf,
/// Copy of the settings used to fetch the binaries.
fetch_settings: PgFetchSettings,
}
impl PgAccess {
pub async fn new(
fetch_settings: &PgFetchSettings,
database_dir: &Path,
) -> Result<Self> {
let cache_dir = Self::create_cache_dir_structure(fetch_settings).await?;
Self::create_db_dir_structure(database_dir).await?;
let platform = fetch_settings.platform();
let pg_ctl = cache_dir.join("bin/pg_ctl");
let init_db = cache_dir.join("bin/initdb");
let zip_file_path = cache_dir.join(format!("{}-{}.zip", platform, fetch_settings.version.0));
let mut pw_file = database_dir.to_path_buf();
pw_file.set_extension("pwfile");
let pg_version_file = database_dir.join(PG_VERSION_FILE_NAME);
Ok(PgAccess {
cache_dir,
database_dir: database_dir.to_path_buf(),
pg_ctl_exe: pg_ctl,
init_db_exe: init_db,
pw_file_path: pw_file,
zip_file_path,
pg_version_file,
fetch_settings: fetch_settings.clone(),
})
}
async fn create_cache_dir_structure(fetch_settings: &PgFetchSettings) -> Result<PathBuf> {
let cache_dir = dirs::cache_dir().ok_or(Error::InvalidPgUrl)?;
let os_string = match fetch_settings.operating_system {
OperationSystem::Darwin | OperationSystem::Windows | OperationSystem::Linux => {
fetch_settings.operating_system.to_string()
}
OperationSystem::AlpineLinux => {
format!("arch_{}", fetch_settings.operating_system)
}
};
let pg_path = format!(
"{}/{}/{}/{}",
PG_EMBED_CACHE_DIR_NAME,
os_string,
fetch_settings.architecture,
fetch_settings.version.0
);
let mut cache_pg_embed = cache_dir;
cache_pg_embed.push(pg_path);
tokio::fs::create_dir_all(&cache_pg_embed)
.await
.map_err(|e| Error::DirCreationError(e.to_string()))?;
Ok(cache_pg_embed)
}
async fn create_db_dir_structure(db_dir: &Path) -> Result<()> {
tokio::fs::create_dir_all(db_dir)
.await
.map_err(|e| Error::DirCreationError(e.to_string()))
}
/// Downloads and unpacks the PostgreSQL binaries unless they are already
/// present in the cache directory.
///
/// The global acquisition map is locked for the whole call, so concurrent
/// callers run one after another. While the download/unpack is running the
/// entry for this cache directory is `InProgress`; it becomes `Finished` on
/// success and is removed again on failure, so a failed attempt does not
/// leave a stale `InProgress` status behind.
///
/// # Errors
///
/// Propagates errors from the cache check, the download and the unpacking.
pub async fn maybe_acquire_postgres(&self) -> Result<()> {
    let mut lock = ACQUIRED_PG_BINS.lock().await;
    if self.pg_executables_cached().await? {
        return Ok(());
    }

    lock.insert(self.cache_dir.clone(), PgAcquisitionStatus::InProgress);
    let outcome = self.fetch_and_unpack_binaries().await;
    if outcome.is_ok() {
        lock.insert(self.cache_dir.clone(), PgAcquisitionStatus::Finished);
    } else {
        // Nothing usable was acquired: drop the entry so the status reads
        // as `Undefined` instead of a never-ending `InProgress`.
        lock.remove(&self.cache_dir);
    }
    outcome
}

/// Downloads the binaries archive to `zip_file_path` and unpacks it into
/// `cache_dir`.
async fn fetch_and_unpack_binaries(&self) -> Result<()> {
    self.fetch_settings
        .fetch_postgres_to_file(&self.zip_file_path)
        .await?;
    log::debug!(
        "Unpacking postgres binaries {} {}",
        self.zip_file_path.display(),
        self.cache_dir.display()
    );
    pg_unpack::unpack_postgres(&self.zip_file_path, &self.cache_dir).await?;
    Ok(())
}
pub async fn pg_executables_cached(&self) -> Result<bool> {
Self::path_exists(self.init_db_exe.as_path()).await
}
/// Reports whether both the cached executables and the database's
/// `PG_VERSION` file exist.
///
/// The version file is only checked when the executables are cached.
///
/// # Errors
///
/// Returns `Error::ReadFileError` when an existence check fails.
pub async fn db_files_exist(&self) -> Result<bool> {
    if !self.pg_executables_cached().await? {
        return Ok(false);
    }
    Self::path_exists(&self.pg_version_file).await
}
/// Reports whether `db_dir` contains a `PG_VERSION` file.
///
/// # Errors
///
/// Returns `Error::ReadFileError` when the existence check fails.
pub async fn pg_version_file_exists(db_dir: &Path) -> Result<bool> {
    let version_marker = db_dir.join(PG_VERSION_FILE_NAME);
    Self::path_exists(version_marker.as_path()).await
}
async fn path_exists(file: &Path) -> Result<bool> {
tokio::fs::try_exists(file)
.await
.map_err(|e| Error::ReadFileError(e.to_string()))
}
pub async fn acquisition_status(&self) -> PgAcquisitionStatus {
let lock = ACQUIRED_PG_BINS.lock().await;
let acquisition_status = lock.get(&self.cache_dir);
match acquisition_status {
None => PgAcquisitionStatus::Undefined,
Some(status) => *status,
}
}
pub fn clean(&self) -> Result<()> {
let dir_result = std::fs::remove_dir_all(&self.database_dir)
.map_err(|e| Error::PgCleanUpFailure(e.to_string()));
let file_result = std::fs::remove_file(&self.pw_file_path)
.map_err(|e| Error::PgCleanUpFailure(e.to_string()));
dir_result.and(file_result)
}
pub async fn purge() -> Result<()> {
let mut cache_dir = dirs::cache_dir()
.ok_or_else(|| Error::ReadFileError("cache dir not found".into()))?;
cache_dir.push(PG_EMBED_CACHE_DIR_NAME);
let _ = tokio::fs::remove_dir_all(&cache_dir).await;
Ok(())
}
pub async fn clean_up(database_dir: PathBuf, pw_file: PathBuf) -> Result<()> {
tokio::fs::remove_dir_all(&database_dir)
.await
.map_err(|e| Error::PgCleanUpFailure(e.to_string()))?;
tokio::fs::remove_file(&pw_file)
.await
.map_err(|e| Error::PgCleanUpFailure(e.to_string()))
}
/// Creates (or truncates) the password file and writes `password` to it.
///
/// # Errors
///
/// Returns `Error::WriteFileError` carrying the I/O error text when the file
/// cannot be created, written or flushed.
pub async fn create_password_file(&self, password: &[u8]) -> Result<()> {
    let mut file = tokio::fs::File::create(self.pw_file_path.as_path())
        .await
        .map_err(|e| Error::WriteFileError(e.to_string()))?;
    file.write_all(password)
        .await
        .map_err(|e| Error::WriteFileError(e.to_string()))?;
    // A tokio `File` may still have the write in flight when `write_all`
    // returns; flushing waits for it to finish and surfaces its error, so the
    // file is complete before anything else reads it.
    file.flush()
        .await
        .map_err(|e| Error::WriteFileError(e.to_string()))
}
async fn share_extension_dir(cache_dir: &Path) -> PathBuf {
let candidates = [
cache_dir.join("share/postgresql/extension"),
cache_dir.join("share/extension"),
];
for candidate in &candidates {
if tokio::fs::try_exists(candidate).await.unwrap_or(false) {
return candidate.clone();
}
}
candidates[0].clone()
}
pub async fn install_extension(&self, extension_dir: &Path) -> Result<()> {
let lib_dir = self.cache_dir.join("lib");
let share_ext_dir = Self::share_extension_dir(&self.cache_dir).await;
tokio::fs::create_dir_all(&lib_dir)
.await
.map_err(|e| Error::DirCreationError(e.to_string()))?;
tokio::fs::create_dir_all(&share_ext_dir)
.await
.map_err(|e| Error::DirCreationError(e.to_string()))?;
let mut entries = tokio::fs::read_dir(extension_dir)
.await
.map_err(|e| Error::ReadFileError(e.to_string()))?;
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| Error::ReadFileError(e.to_string()))?
{
let file_type = entry
.file_type()
.await
.map_err(|e| Error::ReadFileError(e.to_string()))?;
if !file_type.is_file() {
continue;
}
let path = entry.path();
let file_name = match path.file_name() {
Some(n) => n,
None => continue,
};
let dest_dir = match path.extension().and_then(|e| e.to_str()) {
Some("so") | Some("dylib") | Some("dll") => &lib_dir,
Some("control") | Some("sql") => &share_ext_dir,
_ => continue,
};
tokio::fs::copy(&path, dest_dir.join(file_name))
.await
.map_err(|e| Error::WriteFileError(e.to_string()))?;
}
Ok(())
}
/// Builds (without running) the synchronous command
/// `pg_ctl stop -w -D <database_dir>`.
pub fn stop_db_command_sync(&self, database_dir: &Path) -> PgCommandSync {
    let mut stop = std::process::Command::new(&self.pg_ctl_exe);
    stop.args(["stop", "-w", "-D"]).arg(database_dir);
    Box::new(Cell::new(stop))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pg_fetch::{PgFetchSettings, PG_V17};

    /// `install_extension` must route libraries to `lib/`, control and SQL
    /// files to `share/postgresql/extension/`, and ignore everything else.
    #[tokio::test]
    async fn test_install_extension() {
        // Source directory with one file per routed extension plus one
        // file that must be skipped.
        let src_dir = tempfile::TempDir::new().unwrap();
        let src_path = src_dir.path();
        let fixtures = [
            ("myvec.so", "fake so"),
            ("myvec.dylib", "fake dylib"),
            ("myvec.control", "# control"),
            ("myvec--1.0.sql", "-- sql"),
            ("README.txt", "readme"),
        ];
        for (name, contents) in fixtures {
            std::fs::write(src_path.join(name), contents).unwrap();
        }

        // Empty cache directory, so the preferred share path gets created.
        let cache_dir = tempfile::TempDir::new().unwrap();
        let cache_path = cache_dir.path().to_path_buf();
        let pg_access = PgAccess {
            cache_dir: cache_path.clone(),
            database_dir: cache_path.join("db"),
            pg_ctl_exe: cache_path.join("bin/pg_ctl"),
            init_db_exe: cache_path.join("bin/initdb"),
            pw_file_path: cache_path.join("db.pwfile"),
            zip_file_path: cache_path.join("pg.zip"),
            pg_version_file: cache_path.join("db/PG_VERSION"),
            fetch_settings: PgFetchSettings {
                version: PG_V17,
                ..Default::default()
            },
        };

        pg_access.install_extension(src_path).await.unwrap();

        let lib = cache_path.join("lib");
        let share_ext = cache_path.join("share/postgresql/extension");

        assert!(lib.join("myvec.so").exists(), "lib/myvec.so missing");
        assert!(lib.join("myvec.dylib").exists(), "lib/myvec.dylib missing");
        assert!(
            share_ext.join("myvec.control").exists(),
            "share/postgresql/extension/myvec.control missing"
        );
        assert!(
            share_ext.join("myvec--1.0.sql").exists(),
            "share/postgresql/extension/myvec--1.0.sql missing"
        );
        assert!(
            !lib.join("README.txt").exists(),
            "README.txt should not be in lib/"
        );
        assert!(
            !share_ext.join("README.txt").exists(),
            "README.txt should not be in share/postgresql/extension/"
        );
    }
}