use anyhow::{Context, Result};
use dirs::cache_dir;
use flate2::read::GzDecoder;
use once_cell::sync::OnceCell;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::time::Duration;
use tar::Archive;
use tokio::process::Command;
use tokio::task::spawn_blocking;
use tokio::{fs as tokio_fs, io::AsyncWriteExt};
use which::which;
/// PostgreSQL release targeted by the automatic download. Its major component is
/// also the minimum `pg_dump` major version accepted without attempting a download.
const PG_VERSION: &str = "17.7";
/// Default location of the Linux x86_64 binaries archive; the `ATHENA_PG_TOOLS_URL`
/// environment variable takes precedence over it when set.
const DEFAULT_PG_LINUX_X64_URL: &str =
"https://get.enterprisedb.com/postgresql/postgresql-17.7-1-linux-x64-binaries.tar.gz";
/// File name under which the downloaded archive is stored inside the cache directory.
const PG_ARCHIVE_NAME: &str = "postgresql-17.7-1-linux-x64-binaries.tar.gz";
/// Process-wide cache of the resolved tool paths, populated by `ensure_pg_tools`.
static TOOLS_CELL: OnceCell<PgToolsPaths> = OnceCell::new();
/// Filesystem locations of the PostgreSQL client executables used by this module.
#[derive(Clone, Debug)]
pub struct PgToolsPaths {
/// Path to the `pg_dump` executable.
pub pg_dump: PathBuf,
/// Path to the `pg_restore` executable.
pub pg_restore: PathBuf,
}
/// Returns the on-disk file name for the tool called `base` on the current platform.
///
/// On Windows the two known tools (`pg_dump`, `pg_restore`) get an `.exe` suffix;
/// every other name, and every name on other platforms, is returned unchanged.
fn tool_filename(base: &str) -> String {
    let wants_exe_suffix =
        cfg!(target_os = "windows") && matches!(base, "pg_dump" | "pg_restore");
    if wants_exe_suffix {
        format!("{}.exe", base)
    } else {
        base.to_owned()
    }
}
/// Parses the text before the first `.` of `version` as a `u32`.
///
/// Returns `None` when that leading component is empty or not a valid `u32`.
fn parse_major_version(version: &str) -> Option<u32> {
    let major_text = match version.split_once('.') {
        Some((head, _rest)) => head,
        None => version,
    };
    major_text.parse::<u32>().ok()
}
/// Extracts a major version number from a single whitespace-free token.
///
/// Everything before the first ASCII digit is skipped, then the following run of
/// digits and dots is handed to `parse_major_version`. Tokens without any digit
/// yield `None`.
fn parse_version_token(token: &str) -> Option<u32> {
    // No digit at all means there is nothing that could be a version.
    let start = token.find(|ch: char| ch.is_ascii_digit())?;
    let tail = &token[start..];
    // The version candidate ends at the first character that is neither digit nor dot.
    let end = tail
        .find(|ch: char| !(ch.is_ascii_digit() || ch == '.'))
        .unwrap_or(tail.len());
    parse_major_version(&tail[..end])
}
fn parse_pg_dump_version_major(output: &str) -> Option<u32> {
output.split_whitespace().find_map(parse_version_token)
}
/// Runs `<path> --version` and returns the major version parsed from its output.
///
/// Standard output is preferred; standard error is used only when stdout is blank.
/// A failure to launch the process or to parse the output is logged as a warning
/// and reported as `None`.
async fn pg_dump_major_version_internal(path: &Path) -> Option<u32> {
    let finished = Command::new(path)
        .arg("--version")
        .output()
        .await
        .map_err(|err| {
            tracing::warn!(
                "Failed to execute pg_dump at {} (path may not exist or is not executable): {}",
                path.display(),
                err
            );
        })
        .ok()?;

    let from_stdout = String::from_utf8_lossy(&finished.stdout);
    let report = if from_stdout.trim().is_empty() {
        String::from_utf8_lossy(&finished.stderr)
    } else {
        from_stdout
    };

    let major = parse_pg_dump_version_major(&report);
    if major.is_none() {
        tracing::warn!(
            "Unable to parse pg_dump version output from {}: {}",
            path.display(),
            report.trim()
        );
    }
    major
}
/// Major version derived from `PG_VERSION`, falling back to 17 if it cannot be parsed.
pub fn required_pg_dump_major() -> u32 {
    match parse_major_version(PG_VERSION) {
        Some(major) => major,
        None => 17,
    }
}
/// Public entry point for querying the major version of the `pg_dump` binary at `path`.
///
/// Delegates to `pg_dump_major_version_internal`; returns `None` when the binary
/// cannot be executed or its `--version` output cannot be parsed.
pub async fn pg_dump_major_version(path: &Path) -> Option<u32> {
pg_dump_major_version_internal(path).await
}
/// Looks for `pg_dump`/`pg_restore` under `$ATHENA_PG_TOOLS_DIR/<server_major>/bin`.
///
/// Returns `None` when `ATHENA_PG_TOOLS_DIR` is unset, not valid Unicode, blank, or
/// when either executable is missing from the expected directory.
pub fn resolve_pg_tools_from_dir(server_major: u32) -> Option<PgToolsPaths> {
    let tools_root = match std::env::var("ATHENA_PG_TOOLS_DIR") {
        Ok(value) if !value.trim().is_empty() => value,
        _ => return None,
    };

    let mut versioned_bin = PathBuf::from(tools_root);
    versioned_bin.push(server_major.to_string());
    versioned_bin.push("bin");

    let pg_dump = versioned_bin.join(tool_filename("pg_dump"));
    let pg_restore = versioned_bin.join(tool_filename("pg_restore"));
    if !pg_dump.is_file() || !pg_restore.is_file() {
        return None;
    }
    Some(PgToolsPaths {
        pg_dump,
        pg_restore,
    })
}
/// Downloads the bundled pg tools when the `pg_dump` at `dump_path` is older than
/// the major version required by `PG_VERSION`.
///
/// Returns `Some(paths)` only when a download was needed, permitted, supported on
/// this platform and successful. In every other case (version unknown, version new
/// enough, download disabled/unsupported/failed) it returns `None` so the caller
/// keeps using the existing tools.
// Only called from the non-Windows branch of `ensure_pg_tools`, hence unused on Windows.
#[allow(dead_code)]
async fn maybe_download_newer_pg_tools(dump_path: &Path) -> Option<PgToolsPaths> {
    let wanted = parse_major_version(PG_VERSION)?;
    let found = pg_dump_major_version_internal(dump_path).await?;
    if found >= wanted {
        return None;
    }

    tracing::info!(
        "pg_dump on PATH is major version {} which is older than required major {} (PG_VERSION {}). Attempting to download pg tools.",
        found,
        wanted,
        PG_VERSION
    );

    if !allow_pg_tools_download() {
        tracing::warn!(
            "Automatic pg tools download disabled; using existing pg_dump major version {}.",
            found
        );
        return None;
    }

    if let Err(err) = ensure_linux_x64() {
        tracing::warn!(
            "Automatic pg tools download not available: {}. Using existing pg_dump major version {}.",
            err,
            found
        );
        return None;
    }

    let outcome = download_and_extract().await;
    if let Err(err) = &outcome {
        tracing::warn!(
            "Failed to download newer pg tools: {}. Using existing pg_dump major version {}.",
            err,
            found
        );
    }
    outcome.ok()
}
/// Resolves usable `pg_dump`/`pg_restore` executables, caching the result in `TOOLS_CELL`.
///
/// Resolution order:
/// 1. the value already cached in `TOOLS_CELL`;
/// 2. explicit paths from `env_overrides` (checked again after loading a `.env`
///    file from the current working directory, if one exists);
/// 3. the executables found on `PATH` (on Windows additionally a fixed
///    `C:\Program Files\PostgreSQL\18\bin` location; on other platforms the PATH
///    tools may be replaced by a downloaded set when `pg_dump` is too old);
/// 4. on Linux, a best-effort `apt-get` install followed by another `PATH` lookup;
/// 5. downloading and extracting the bundled archive.
///
/// # Errors
/// Returns an error when `env_overrides` rejects the configured paths, when no tools
/// are found and downloading is disabled, when the platform is not supported by
/// `ensure_linux_x64`, or when the download/extraction fails.
pub async fn ensure_pg_tools() -> Result<PgToolsPaths> {
// Fast path: a previous call already resolved the tools.
if let Some(cached) = TOOLS_CELL.get() {
return Ok(cached.clone());
}
// Explicit configuration wins over any discovery; a failed `set` (cell already
// filled) is deliberately ignored here and below.
if let Some(paths) = env_overrides()? {
let _ = TOOLS_CELL.set(paths.clone());
return Ok(paths);
}
{
// The overrides may live in a `.env` file next to the working directory; load it
// (ignoring load errors) and look at the environment a second time.
let dotenv_path: Option<PathBuf> = std::env::current_dir()
.ok()
.map(|dir| dir.join(".env"))
.filter(|path| path.exists());
if let Some(dotenv_file) = dotenv_path {
let _ = dotenv::from_path(&dotenv_file);
if let Some(paths) = env_overrides()? {
let _ = TOOLS_CELL.set(paths.clone());
return Ok(paths);
}
}
}
#[cfg(target_os = "windows")]
{
let candidates = [("pg_dump", "pg_dump.exe"), ("pg_restore", "pg_restore.exe")];
// Try the `.exe` name first, then the bare tool name.
let which_with_exe = |tool_name: &str, exe_name: &str| -> Option<PathBuf> {
which(exe_name).ok().or_else(|| which(tool_name).ok())
};
let dump_opt: Option<PathBuf> = which_with_exe(candidates[0].0, candidates[0].1);
let restore_opt: Option<PathBuf> = which_with_exe(candidates[1].0, candidates[1].1);
if let (Some(dump), Some(restore)) = (dump_opt.clone(), restore_opt.clone()) {
let paths: PgToolsPaths = PgToolsPaths {
pg_dump: dump,
pg_restore: restore,
};
let _ = TOOLS_CELL.set(paths.clone());
return Ok(paths);
}
// Fallback: a fixed PostgreSQL 18 installation directory.
let dump_path = PathBuf::from(r"C:\Program Files\PostgreSQL\18\bin\pg_dump.exe");
let restore_path = PathBuf::from(r"C:\Program Files\PostgreSQL\18\bin\pg_restore.exe");
if dump_path.is_file() && restore_path.is_file() {
let paths = PgToolsPaths {
pg_dump: dump_path,
pg_restore: restore_path,
};
let _ = TOOLS_CELL.set(paths.clone());
return Ok(paths);
}
}
#[cfg(not(target_os = "windows"))]
{
let candidates = [("pg_dump", "pg_dump"), ("pg_restore", "pg_restore")];
let which_with_exe =
|tool_name: &str, _exe_name: &str| -> Option<PathBuf> { which(tool_name).ok() };
let mut dump_opt: Option<PathBuf> = which_with_exe(candidates[0].0, candidates[0].1);
let mut restore_opt: Option<PathBuf> = which_with_exe(candidates[1].0, candidates[1].1);
if let (Some(dump), Some(restore)) = (dump_opt.clone(), restore_opt.clone()) {
// Both tools are on PATH; prefer a freshly downloaded set only when the helper
// decides the installed pg_dump is too old and the download succeeds.
if let Some(paths) = maybe_download_newer_pg_tools(&dump).await {
// If the cell was filled in the meantime, return the value that is already cached.
if TOOLS_CELL.set(paths.clone()).is_err() {
if let Some(cached) = TOOLS_CELL.get() {
return Ok(cached.clone());
}
tracing::warn!(
"pg tool cache already initialized but no cached value found; using freshly downloaded tools."
);
}
return Ok(paths);
}
let paths: PgToolsPaths = PgToolsPaths {
pg_dump: dump,
pg_restore: restore,
};
let _ = TOOLS_CELL.set(paths.clone());
return Ok(paths);
}
#[cfg(target_os = "linux")]
{
// Nothing on PATH: try a package install, then look again.
maybe_install_pg_tools_with_apt().await;
dump_opt = which_with_exe(candidates[0].0, candidates[0].1);
restore_opt = which_with_exe(candidates[1].0, candidates[1].1);
if let (Some(dump), Some(restore)) = (dump_opt, restore_opt) {
let paths: PgToolsPaths = PgToolsPaths {
pg_dump: dump,
pg_restore: restore,
};
let _ = TOOLS_CELL.set(paths.clone());
return Ok(paths);
}
}
}
// Last resort: download the archive, unless that has been disabled.
if !allow_pg_tools_download() {
anyhow::bail!(
"pg_dump/pg_restore not found on PATH and automatic download is disabled. \
Install PostgreSQL client tools (e.g. apt install postgresql-client postgresql-common), \
set ATHENA_PG_DUMP_PATH and ATHENA_PG_RESTORE_PATH, or use the official Athena Docker image which has them pre-installed."
);
}
ensure_linux_x64()?;
let paths: PgToolsPaths = download_and_extract().await?;
let _ = TOOLS_CELL.set(paths.clone());
Ok(paths)
}
/// Reports whether automatic download of the pg tools is permitted.
///
/// Downloads are enabled unless `ATHENA_PG_TOOLS_ALLOW_DOWNLOAD` is set to `0`,
/// `false` or `no`. The comparison ignores ASCII case and surrounding whitespace so
/// that e.g. `NO`, `No` or `" false "` disable downloads just like `no`/`false` do.
/// An unset (or non-Unicode) variable and any other value leave downloads enabled.
fn allow_pg_tools_download() -> bool {
    match std::env::var("ATHENA_PG_TOOLS_ALLOW_DOWNLOAD") {
        Ok(value) => {
            let normalized = value.trim().to_ascii_lowercase();
            !matches!(normalized.as_str(), "0" | "false" | "no")
        }
        Err(_) => true,
    }
}
fn env_overrides() -> Result<Option<PgToolsPaths>> {
let dump: Option<PathBuf> = std::env::var("ATHENA_PG_DUMP_PATH").ok().map(PathBuf::from);
let restore: Option<PathBuf> = std::env::var("ATHENA_PG_RESTORE_PATH")
.ok()
.map(PathBuf::from);
if dump.is_none() && restore.is_none() {
return Ok(None);
}
let dump_path: PathBuf = dump.context("ATHENA_PG_DUMP_PATH set but empty")?;
let restore_path: PathBuf = restore.unwrap_or_else(|| {
dump_path
.parent()
.map(|p| p.join("pg_restore"))
.unwrap_or_else(|| dump_path.clone())
});
if !dump_path.is_file() {
anyhow::bail!("pg_dump not found at {}", dump_path.display());
}
if !restore_path.is_file() {
anyhow::bail!("pg_restore not found at {}", restore_path.display());
}
Ok(Some(PgToolsPaths {
pg_dump: dump_path,
pg_restore: restore_path,
}))
}
/// Succeeds only when compiled for Linux on x86_64, the single target for which a
/// downloadable archive is configured; otherwise returns an explanatory error.
fn ensure_linux_x64() -> Result<()> {
    let supported_target = cfg!(all(target_os = "linux", target_arch = "x86_64"));
    if !supported_target {
        anyhow::bail!(
            "Automatic pg_dump download is only supported on Linux x86_64; set ATHENA_PG_DUMP_PATH and ATHENA_PG_RESTORE_PATH instead"
        );
    }
    Ok(())
}
async fn download_and_extract() -> Result<PgToolsPaths> {
let cache_root: PathBuf = cache_dir()
.unwrap_or(std::env::temp_dir())
.join("athena")
.join("pg_tools")
.join(format!("{}-linux-x64", PG_VERSION));
let bin_dir: PathBuf = cache_root.join("bin");
let pg_dump_path: PathBuf = bin_dir.join("pg_dump");
let pg_restore_path: PathBuf = bin_dir.join("pg_restore");
if pg_dump_path.is_file() && pg_restore_path.is_file() {
return Ok(PgToolsPaths {
pg_dump: pg_dump_path,
pg_restore: pg_restore_path,
});
}
tokio_fs::create_dir_all(&cache_root)
.await
.context("create cache dir")?;
tokio_fs::create_dir_all(&bin_dir).await?;
let archive_path: PathBuf = cache_root.join(PG_ARCHIVE_NAME);
if !archive_path.is_file() {
let url = std::env::var("ATHENA_PG_TOOLS_URL")
.unwrap_or_else(|_| DEFAULT_PG_LINUX_X64_URL.to_string());
download_archive(&url, &archive_path).await?;
}
extract_binaries(&archive_path, &bin_dir).await?;
Ok(PgToolsPaths {
pg_dump: pg_dump_path,
pg_restore: pg_restore_path,
})
}
async fn download_archive(url: &str, dest: &Path) -> Result<()> {
let client: reqwest::Client = reqwest::Client::builder()
.timeout(Duration::from_secs(60))
.build()?;
let resp: reqwest::Response = client
.get(url)
.send()
.await
.with_context(|| format!("downloading {}", url))?;
let status = resp.status();
if !status.is_success() {
let hint = if status.as_u16() == 403 {
" (403 Forbidden — use an image with PostgreSQL client pre-installed, or set ATHENA_PG_DUMP_PATH/ATHENA_PG_RESTORE_PATH)"
} else {
""
};
anyhow::bail!("download failed ({}){}: {}", status, hint, url);
}
let bytes = resp.bytes().await?;
let mut file = tokio_fs::File::create(dest).await?;
file.write_all(&bytes).await?;
file.flush().await?;
Ok(())
}
#[cfg(target_os = "linux")]
async fn maybe_install_pg_tools_with_apt() {
use tokio::process::Command;
if let Ok(val) = std::env::var("ATHENA_AUTO_INSTALL_PG_TOOLS") {
if matches!(val.as_str(), "0" | "false" | "FALSE" | "False") {
tracing::info!(
"ATHENA_AUTO_INSTALL_PG_TOOLS disabled; skipping automatic pg tools install via apt-get."
);
return;
}
}
if which("apt-get").is_err() {
tracing::info!("apt-get not found on PATH; skipping automatic pg tools install.");
return;
}
tracing::info!(
"Attempting to install PostgreSQL client tools via apt-get (postgresql-client, postgresql-common)."
);
match Command::new("apt-get").arg("update").status().await {
Ok(status) if status.success() => {
tracing::info!("apt-get update succeeded before installing PostgreSQL client tools.");
}
Ok(status) => {
tracing::warn!(
?status,
"apt-get update failed; skipping automatic pg tools install."
);
return;
}
Err(err) => {
tracing::warn!(error = %err, "Failed to invoke apt-get update; skipping automatic pg tools install.");
return;
}
}
match Command::new("apt-get")
.args(["install", "-y", "postgresql-client", "postgresql-common"])
.status()
.await
{
Ok(status) if status.success() => {
tracing::info!(
"apt-get install postgresql-client postgresql-common succeeded; pg_dump/pg_restore should now be on PATH."
);
}
Ok(status) => {
tracing::warn!(
?status,
"apt-get install postgresql-client postgresql-common failed; pg tools may still be missing."
);
}
Err(err) => {
tracing::warn!(error = %err, "Failed to invoke apt-get install; pg tools may still be missing.");
}
}
}
async fn extract_binaries(archive_path: &Path, bin_dir: &Path) -> Result<()> {
let archive_path = archive_path.to_owned();
let bin_dir = bin_dir.to_owned();
spawn_blocking(move || -> Result<()> {
let file: File = File::open(&archive_path)?;
let decoder: GzDecoder<File> = GzDecoder::new(file);
let mut archive: Archive<GzDecoder<File>> = Archive::new(decoder);
let mut extracted: usize = 0usize;
for entry in archive.entries()? {
let mut entry: tar::Entry<'_, GzDecoder<File>> = entry?;
let path = entry.path()?;
let name: &str = path
.file_name()
.and_then(|v| v.to_str())
.unwrap_or_default();
if name == "pg_dump" || name == "pg_restore" {
let dest = bin_dir.join(name);
entry.unpack(&dest)?;
#[cfg(unix)]
{
let mut perms: std::fs::Permissions = std::fs::metadata(&dest)?.permissions();
use std::os::unix::fs::PermissionsExt;
perms.set_mode(0o755);
std::fs::set_permissions(&dest, perms)?;
}
extracted += 1;
}
}
if extracted < 2 {
anyhow::bail!("pg_dump/pg_restore not found in downloaded archive");
}
Ok(())
})
.await?
}
// Unit tests for the version parsing helpers and the directory-based tool lookup.
#[cfg(test)]
mod tests {
use std::path::Path;
use super::{
maybe_download_newer_pg_tools, parse_major_version, parse_pg_dump_version_major,
parse_version_token, pg_dump_major_version,
};
// Plain upstream `pg_dump --version` output.
#[test]
fn parse_pg_dump_version_major_from_standard_output() {
let output = "pg_dump (PostgreSQL) 17.7";
assert_eq!(parse_pg_dump_version_major(output), Some(17));
}
// Distribution output with a trailing parenthesised package version.
#[test]
fn parse_pg_dump_version_major_from_ubuntu_output() {
let output = "pg_dump (PostgreSQL) 16.13 (Ubuntu 16.13-0ubuntu0.24.04.1)";
assert_eq!(parse_pg_dump_version_major(output), Some(16));
}
// Version token that carries a package suffix directly after the number.
#[test]
fn parse_pg_dump_version_major_from_token_suffix() {
let output = "pg_dump (PostgreSQL) 16.13-0ubuntu0.24.04.1";
assert_eq!(parse_pg_dump_version_major(output), Some(16));
}
// Empty and non-numeric input yield `None`; a bare major number is accepted.
#[test]
fn parse_major_version_handles_invalid() {
assert_eq!(parse_major_version(""), None);
assert_eq!(parse_major_version("abc"), None);
assert_eq!(parse_major_version("9"), Some(9));
assert_eq!(parse_major_version("17.7"), Some(17));
}
// Non-digit prefixes are skipped and anything after the digit/dot run is ignored.
#[test]
fn parse_version_token_handles_prefix_and_suffix() {
assert_eq!(parse_version_token("16.13-0ubuntu0.24.04.1"), Some(16));
assert_eq!(parse_version_token("v17.7"), Some(17));
assert_eq!(parse_version_token("PostgreSQL17.7"), Some(17));
assert_eq!(parse_version_token("postgresql"), None);
}
// A pg_dump path that cannot be executed must not trigger a download attempt.
#[tokio::test]
async fn maybe_download_returns_none_when_version_check_fails() {
let missing_path = Path::new("/nonexistent/pg_dump");
assert!(maybe_download_newer_pg_tools(missing_path).await.is_none());
}
// Runs against the real pg_dump when one is on PATH; silently passes otherwise.
#[tokio::test]
async fn pg_dump_major_version_parses_pg_dump_when_available() {
let Ok(pg_dump_path) = which::which("pg_dump") else {
return;
};
let parsed = pg_dump_major_version(&pg_dump_path).await;
assert!(parsed.is_some());
}
// Builds `<tmp>/17/bin` with empty placeholder files and checks they are resolved.
#[test]
fn resolve_pg_tools_from_dir_uses_expected_layout() {
let tmp =
std::env::temp_dir().join(format!("athena_pg_tools_test_{}", uuid::Uuid::new_v4()));
let bin_dir = tmp.join("17").join("bin");
std::fs::create_dir_all(&bin_dir).expect("create test bin dir");
let dump = bin_dir.join(super::tool_filename("pg_dump"));
let restore = bin_dir.join(super::tool_filename("pg_restore"));
std::fs::write(&dump, b"").expect("create dump file");
std::fs::write(&restore, b"").expect("create restore file");
// NOTE(review): the variable is never restored, and other tests in this module read
// the process environment (e.g. the PATH lookup above) while tests may run in
// parallel — confirm this mutation is acceptable or serialise the affected tests.
unsafe {
std::env::set_var("ATHENA_PG_TOOLS_DIR", &tmp);
}
let resolved = super::resolve_pg_tools_from_dir(17).expect("resolved tools");
assert_eq!(resolved.pg_dump, dump);
assert_eq!(resolved.pg_restore, restore);
// Cleanup is best effort and is skipped if an assertion above panics.
let _ = std::fs::remove_dir_all(&tmp);
}
}