use anyhow::{Context, Result};
use dirs::cache_dir;
use flate2::read::GzDecoder;
use once_cell::sync::OnceCell;
use std::fs::{self, File};
use std::path::{Path, PathBuf};
use std::time::Duration;
use tar::Archive;
use tokio::task::spawn_blocking;
use tokio::{fs as tokio_fs, io::AsyncWriteExt};
use which::which;
/// Expands to the PostgreSQL version whose client tools are downloaded.
///
/// Kept as a macro rather than a `const` so it can be spliced into `concat!`;
/// bumping it here updates the archive name and the default URL in lockstep.
macro_rules! pg_version {
    () => {
        "16.3"
    };
}

/// Expands to the file name of the Linux x86_64 binaries archive for `pg_version!()`.
macro_rules! pg_archive_name {
    () => {
        concat!("postgresql-", pg_version!(), "-1-linux-x64-binaries.tar.gz")
    };
}

/// PostgreSQL version of the downloaded tools; also names the cache directory.
const PG_VERSION: &str = pg_version!();
/// Default archive download location; `ATHENA_PG_TOOLS_URL` overrides it at runtime.
///
/// NOTE(review): confirm this URL still serves the archive.
const DEFAULT_PG_LINUX_X64_URL: &str =
    concat!("https://get.enterprisedb.com/postgresql/", pg_archive_name!());
/// File name under which the downloaded archive is stored in the cache directory.
const PG_ARCHIVE_NAME: &str = pg_archive_name!();
/// Process-wide cache of the resolved tool paths.
///
/// `ensure_pg_tools` returns the cached value when present and stores whatever it
/// resolves, so env/PATH lookups and downloads are not repeated within the process.
static TOOLS_CELL: OnceCell<PgToolsPaths> = OnceCell::new();
/// Absolute or `PATH`-resolved locations of the PostgreSQL client tools.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PgToolsPaths {
    /// Path to the `pg_dump` executable.
    pub pg_dump: PathBuf,
    /// Path to the `pg_restore` executable.
    pub pg_restore: PathBuf,
}
pub async fn ensure_pg_tools() -> Result<PgToolsPaths> {
use std::fs;
if let Some(cached) = TOOLS_CELL.get() {
return Ok(cached.clone());
}
if let Some(paths) = env_overrides()? {
let _ = TOOLS_CELL.set(paths.clone());
return Ok(paths);
}
{
let dotenv_path: Option<PathBuf> = std::env::current_dir()
.ok()
.map(|dir| dir.join(".env"))
.filter(|path| path.exists());
if let Some(dotenv_file) = dotenv_path {
let _ = dotenv::from_path(&dotenv_file);
if let Some(paths) = env_overrides()? {
let _ = TOOLS_CELL.set(paths.clone());
return Ok(paths);
}
}
}
#[cfg(target_os = "windows")]
let candidates = [("pg_dump", "pg_dump.exe"), ("pg_restore", "pg_restore.exe")];
#[cfg(not(target_os = "windows"))]
let candidates = [("pg_dump", "pg_dump"), ("pg_restore", "pg_restore")];
#[cfg(target_os = "windows")]
let which_with_exe = |tool_name: &str, exe_name: &str| -> Option<PathBuf> {
which(exe_name).ok().or_else(|| which(tool_name).ok())
};
#[cfg(not(target_os = "windows"))]
let which_with_exe =
|tool_name: &str, _exe_name: &str| -> Option<PathBuf> { which(tool_name).ok() };
let dump_opt: Option<PathBuf> = which_with_exe(candidates[0].0, candidates[0].1);
let restore_opt: Option<PathBuf> = which_with_exe(candidates[1].0, candidates[1].1);
if let (Some(dump), Some(restore)) = (dump_opt, restore_opt) {
let paths: PgToolsPaths = PgToolsPaths {
pg_dump: dump,
pg_restore: restore,
};
let _ = TOOLS_CELL.set(paths.clone());
return Ok(paths);
}
ensure_linux_x64()?;
let paths: PgToolsPaths = download_and_extract().await?;
let _ = TOOLS_CELL.set(paths.clone());
Ok(paths)
}
/// Reads explicit tool locations from `ATHENA_PG_DUMP_PATH` / `ATHENA_PG_RESTORE_PATH`.
///
/// Returns `Ok(None)` when neither variable is set; empty values count as unset.
/// When only `ATHENA_PG_DUMP_PATH` is set, `pg_restore` (with the platform's
/// executable suffix) is assumed to sit in the same directory.
///
/// # Errors
/// Fails if only `ATHENA_PG_RESTORE_PATH` is set, or if either resolved path is not an
/// existing file.
fn env_overrides() -> Result<Option<PgToolsPaths>> {
    // `var_os` keeps non-UTF-8 paths usable; an empty value is treated as unset.
    let read_path = |key: &str| -> Option<PathBuf> {
        std::env::var_os(key)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    };
    let dump = read_path("ATHENA_PG_DUMP_PATH");
    let restore = read_path("ATHENA_PG_RESTORE_PATH");
    if dump.is_none() && restore.is_none() {
        return Ok(None);
    }
    let dump_path: PathBuf =
        dump.context("ATHENA_PG_RESTORE_PATH is set but ATHENA_PG_DUMP_PATH is not")?;
    let restore_path: PathBuf = restore.unwrap_or_else(|| {
        // EXE_SUFFIX is ".exe" on Windows and "" elsewhere.
        dump_path.with_file_name(format!("pg_restore{}", std::env::consts::EXE_SUFFIX))
    });
    if !dump_path.is_file() {
        anyhow::bail!("pg_dump not found at {}", dump_path.display());
    }
    if !restore_path.is_file() {
        anyhow::bail!("pg_restore not found at {}", restore_path.display());
    }
    Ok(Some(PgToolsPaths {
        pg_dump: dump_path,
        pg_restore: restore_path,
    }))
}
/// Errors out unless compiled for Linux on x86_64, the only target the bundled
/// download URL provides binaries for.
fn ensure_linux_x64() -> Result<()> {
    let is_supported_target = cfg!(target_os = "linux") && cfg!(target_arch = "x86_64");
    anyhow::ensure!(
        is_supported_target,
        "Automatic pg_dump download is only supported on Linux x86_64; set ATHENA_PG_DUMP_PATH and ATHENA_PG_RESTORE_PATH instead"
    );
    Ok(())
}
/// Downloads (if needed) and unpacks `pg_dump`/`pg_restore` into the user cache.
///
/// Layout: `<cache>/athena/pg_tools/<PG_VERSION>-linux-x64/{<archive>, bin/}`, where
/// `<cache>` is the platform cache directory or, if none, the temp directory. Binaries
/// already present in `bin/` are reused without network access. The archive URL is
/// `ATHENA_PG_TOOLS_URL` when set and non-empty, else `DEFAULT_PG_LINUX_X64_URL`.
///
/// # Errors
/// Fails if the cache directory cannot be created, the download fails, or the archive
/// does not yield both tools.
async fn download_and_extract() -> Result<PgToolsPaths> {
    let cache_root: PathBuf = cache_dir()
        .unwrap_or_else(std::env::temp_dir)
        .join("athena")
        .join("pg_tools")
        .join(format!("{}-linux-x64", PG_VERSION));
    let bin_dir: PathBuf = cache_root.join("bin");
    let paths = PgToolsPaths {
        pg_dump: bin_dir.join("pg_dump"),
        pg_restore: bin_dir.join("pg_restore"),
    };
    if paths.pg_dump.is_file() && paths.pg_restore.is_file() {
        return Ok(paths);
    }
    // `bin_dir` lives inside `cache_root`, so this creates both.
    tokio_fs::create_dir_all(&bin_dir)
        .await
        .with_context(|| format!("create cache dir {}", bin_dir.display()))?;

    let archive_path: PathBuf = cache_root.join(PG_ARCHIVE_NAME);
    if archive_path.is_file() {
        // A cached archive may be truncated or corrupt (e.g. from an interrupted
        // earlier download). Rather than failing on every run, fall through and
        // fetch a fresh copy once if extracting it fails.
        if extract_binaries(&archive_path, &bin_dir).await.is_ok() {
            return Ok(paths);
        }
    }
    let url = std::env::var("ATHENA_PG_TOOLS_URL")
        .ok()
        .filter(|value| !value.is_empty())
        .unwrap_or_else(|| DEFAULT_PG_LINUX_X64_URL.to_string());
    download_archive(&url, &archive_path).await?;
    extract_binaries(&archive_path, &bin_dir).await?;
    Ok(paths)
}
async fn download_archive(url: &str, dest: &Path) -> Result<()> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(60))
.build()?;
let resp: reqwest::Response = client
.get(url)
.send()
.await
.with_context(|| format!("downloading {}", url))?;
let status = resp.status();
if !status.is_success() {
anyhow::bail!("download failed ({}): {}", status, url);
}
let bytes = resp.bytes().await?;
let mut file = tokio_fs::File::create(dest).await?;
file.write_all(&bytes).await?;
file.flush().await?;
Ok(())
}
async fn extract_binaries(archive_path: &Path, bin_dir: &Path) -> Result<()> {
let archive_path = archive_path.to_owned();
let bin_dir = bin_dir.to_owned();
spawn_blocking(move || -> Result<()> {
let file = File::open(&archive_path)?;
let decoder = GzDecoder::new(file);
let mut archive = Archive::new(decoder);
let mut extracted = 0usize;
for entry in archive.entries()? {
let mut entry = entry?;
let path = entry.path()?;
let name = path
.file_name()
.and_then(|v| v.to_str())
.unwrap_or_default();
if name == "pg_dump" || name == "pg_restore" {
let dest = bin_dir.join(name);
entry.unpack(&dest)?;
let mut perms = fs::metadata(&dest)?.permissions();
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
perms.set_mode(0o755);
fs::set_permissions(&dest, perms)?;
}
extracted += 1;
}
}
if extracted < 2 {
anyhow::bail!("pg_dump/pg_restore not found in downloaded archive");
}
Ok(())
})
.await?
}