// athena_rs 2.0.2
//
// Database gateway API
// Documentation
use anyhow::{Context, Result};
use dirs::cache_dir;
use flate2::read::GzDecoder;
use once_cell::sync::OnceCell;
use std::fs::{self, File};
use std::path::{Path, PathBuf};
use std::time::Duration;
use tar::Archive;
use tokio::task::spawn_blocking;
use tokio::{fs as tokio_fs, io::AsyncWriteExt};
use which::which;

/// PostgreSQL release of the portable client bundle; also used to name the cache directory.
/// The same version is spelled out literally in the URL and archive name below, so keep
/// all three in sync when bumping it.
const PG_VERSION: &str = "16.3";
/// Default download location of the Linux x86_64 bundle, used when `ATHENA_PG_TOOLS_URL`
/// is not set.
const DEFAULT_PG_LINUX_X64_URL: &str =
    "https://get.enterprisedb.com/postgresql/postgresql-16.3-1-linux-x64-binaries.tar.gz";
/// File name under which the downloaded archive is stored inside the cache directory.
const PG_ARCHIVE_NAME: &str = "postgresql-16.3-1-linux-x64-binaries.tar.gz";

/// Process-wide cache of the resolved tool paths, filled after the first successful lookup.
static TOOLS_CELL: OnceCell<PgToolsPaths> = OnceCell::new();

/// Resolved filesystem locations of the PostgreSQL client tools.
#[derive(Debug, Clone)]
pub struct PgToolsPaths {
    /// Location of the `pg_dump` executable.
    pub pg_dump: PathBuf,
    /// Location of the `pg_restore` executable.
    pub pg_restore: PathBuf,
}

/// Ensure pg_dump and pg_restore are available, downloading a portable bundle if needed.
///
/// Resolution order:
/// 1) `ATHENA_PG_DUMP_PATH` / `ATHENA_PG_RESTORE_PATH`
/// 2) Binaries found on PATH
/// 3) Download Linux x86_64 portable client bundle into cache (~/.cache/athena/pg_tools/<ver>)
pub async fn ensure_pg_tools() -> Result<PgToolsPaths> {
    use std::fs;

    if let Some(cached) = TOOLS_CELL.get() {
        return Ok(cached.clone());
    }

    // 1) Explicit env overrides from current environment.
    if let Some(paths) = env_overrides()? {
        let _ = TOOLS_CELL.set(paths.clone());
        return Ok(paths);
    }

    // 1.5) Try to read from .env if not already set
    {
        // Load .env if present and re-try env_overrides
        let dotenv_path: Option<PathBuf> = std::env::current_dir()
            .ok()
            .map(|dir| dir.join(".env"))
            .filter(|path| path.exists());
        if let Some(dotenv_file) = dotenv_path {
            // Load .env into process env once; ignore errors if file is malformed or missing.
            let _ = dotenv::from_path(&dotenv_file);
            // After loading .env, re-check for explicit env overrides
            if let Some(paths) = env_overrides()? {
                let _ = TOOLS_CELL.set(paths.clone());
                return Ok(paths);
            }
        }
    }

    // 2) PATH lookup.
    #[cfg(target_os = "windows")]
    let candidates = [("pg_dump", "pg_dump.exe"), ("pg_restore", "pg_restore.exe")];
    #[cfg(not(target_os = "windows"))]
    let candidates = [("pg_dump", "pg_dump"), ("pg_restore", "pg_restore")];

    #[cfg(target_os = "windows")]
    let which_with_exe = |tool_name: &str, exe_name: &str| -> Option<PathBuf> {
        // Prefer .exe if found, otherwise fallback to tool_name
        which(exe_name).ok().or_else(|| which(tool_name).ok())
    };
    #[cfg(not(target_os = "windows"))]
    let which_with_exe =
        |tool_name: &str, _exe_name: &str| -> Option<PathBuf> { which(tool_name).ok() };

    let dump_opt: Option<PathBuf> = which_with_exe(candidates[0].0, candidates[0].1);
    let restore_opt: Option<PathBuf> = which_with_exe(candidates[1].0, candidates[1].1);

    if let (Some(dump), Some(restore)) = (dump_opt, restore_opt) {
        let paths: PgToolsPaths = PgToolsPaths {
            pg_dump: dump,
            pg_restore: restore,
        };
        let _ = TOOLS_CELL.set(paths.clone());
        return Ok(paths);
    }

    // 3) Download portable bundle (Linux x86_64 only).
    ensure_linux_x64()?;
    let paths: PgToolsPaths = download_and_extract().await?;
    let _ = TOOLS_CELL.set(paths.clone());
    Ok(paths)
}

/// Read explicit tool locations from `ATHENA_PG_DUMP_PATH` / `ATHENA_PG_RESTORE_PATH`.
///
/// Returns `Ok(None)` when neither variable is set. When only `ATHENA_PG_DUMP_PATH`
/// is set, `pg_restore` is expected next to `pg_dump` (with the platform's executable
/// suffix).
///
/// # Errors
///
/// Fails when only `ATHENA_PG_RESTORE_PATH` is set, when either variable is set to an
/// empty value, or when a resolved path does not point at an existing file.
fn env_overrides() -> Result<Option<PgToolsPaths>> {
    let dump: Option<PathBuf> = std::env::var("ATHENA_PG_DUMP_PATH").ok().map(PathBuf::from);
    let restore: Option<PathBuf> = std::env::var("ATHENA_PG_RESTORE_PATH")
        .ok()
        .map(PathBuf::from);

    let (dump_path, restore) = match (dump, restore) {
        (None, None) => return Ok(None),
        (None, Some(_)) => {
            // pg_dump cannot be derived from pg_restore here, so report what is
            // actually missing instead of claiming the variable is empty.
            anyhow::bail!(
                "ATHENA_PG_RESTORE_PATH is set but ATHENA_PG_DUMP_PATH is not; set ATHENA_PG_DUMP_PATH as well"
            )
        }
        (Some(dump_path), restore) => (dump_path, restore),
    };

    if dump_path.as_os_str().is_empty() {
        anyhow::bail!("ATHENA_PG_DUMP_PATH set but empty");
    }

    let restore_path: PathBuf = match restore {
        Some(path) if path.as_os_str().is_empty() => {
            anyhow::bail!("ATHENA_PG_RESTORE_PATH set but empty")
        }
        Some(path) => path,
        None => {
            // Default to the sibling of pg_dump; EXE_SUFFIX is "" on Unix and ".exe"
            // on Windows, so the derived name matches the platform's binary name.
            let sibling = format!("pg_restore{}", std::env::consts::EXE_SUFFIX);
            dump_path
                .parent()
                .map(|dir| dir.join(&sibling))
                .unwrap_or_else(|| dump_path.clone())
        }
    };

    if !dump_path.is_file() {
        anyhow::bail!("pg_dump not found at {}", dump_path.display());
    }
    if !restore_path.is_file() {
        anyhow::bail!("pg_restore not found at {}", restore_path.display());
    }

    Ok(Some(PgToolsPaths {
        pg_dump: dump_path,
        pg_restore: restore_path,
    }))
}

/// Guard for the automatic download: succeeds only when compiled for Linux x86_64.
///
/// On any other target an error is returned that points at the environment
/// variables to set instead.
fn ensure_linux_x64() -> Result<()> {
    let supported_target = cfg!(all(target_os = "linux", target_arch = "x86_64"));
    if !supported_target {
        anyhow::bail!(
            "Automatic pg_dump download is only supported on Linux x86_64; set ATHENA_PG_DUMP_PATH and ATHENA_PG_RESTORE_PATH instead"
        );
    }
    Ok(())
}

/// Make sure the portable `pg_dump` / `pg_restore` binaries exist in the cache directory.
///
/// The cache lives in `<cache dir>/athena/pg_tools/<PG_VERSION>-linux-x64`, falling back
/// to the system temp directory when no user cache directory is known. Binaries already
/// present in `bin/` are returned as-is. Otherwise the archive is downloaded (from
/// `ATHENA_PG_TOOLS_URL` if set, else the default URL) unless a cached copy exists, and
/// the two tools are extracted from it.
///
/// An archive that fails to extract is removed so it cannot poison later calls; if it
/// was a previously cached copy, one fresh download is attempted before giving up.
///
/// # Errors
///
/// Fails when the cache directory cannot be created, the download fails, or the
/// archive does not yield both tools.
async fn download_and_extract() -> Result<PgToolsPaths> {
    let cache_root: PathBuf = cache_dir()
        // Lazy fallback: only query the temp dir when there is no user cache dir.
        .unwrap_or_else(std::env::temp_dir)
        .join("athena")
        .join("pg_tools")
        .join(format!("{}-linux-x64", PG_VERSION));
    let bin_dir: PathBuf = cache_root.join("bin");
    let tools = PgToolsPaths {
        pg_dump: bin_dir.join("pg_dump"),
        pg_restore: bin_dir.join("pg_restore"),
    };

    if tools.pg_dump.is_file() && tools.pg_restore.is_file() {
        return Ok(tools);
    }

    // `bin_dir` is below `cache_root`, so one recursive create covers both.
    tokio_fs::create_dir_all(&bin_dir)
        .await
        .context("create cache dir")?;

    let url = std::env::var("ATHENA_PG_TOOLS_URL")
        .unwrap_or_else(|_| DEFAULT_PG_LINUX_X64_URL.to_string());
    let archive_path: PathBuf = cache_root.join(PG_ARCHIVE_NAME);

    let reused_cached_archive = archive_path.is_file();
    if !reused_cached_archive {
        download_archive(&url, &archive_path).await?;
    }

    if let Err(first_err) = extract_binaries(&archive_path, &bin_dir).await {
        // Never keep an archive that could not be extracted: it would be picked up
        // again by the `is_file` check above on every later call.
        let _ = tokio_fs::remove_file(&archive_path).await;
        if !reused_cached_archive {
            return Err(first_err);
        }

        // The cached copy may stem from an interrupted earlier download; retry once.
        download_archive(&url, &archive_path).await?;
        if let Err(retry_err) = extract_binaries(&archive_path, &bin_dir).await {
            let _ = tokio_fs::remove_file(&archive_path).await;
            return Err(retry_err);
        }
    }

    Ok(tools)
}

async fn download_archive(url: &str, dest: &Path) -> Result<()> {
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(60))
        .build()?;

    let resp: reqwest::Response = client
        .get(url)
        .send()
        .await
        .with_context(|| format!("downloading {}", url))?;
    let status = resp.status();
    if !status.is_success() {
        anyhow::bail!("download failed ({}): {}", status, url);
    }

    let bytes = resp.bytes().await?;
    let mut file = tokio_fs::File::create(dest).await?;
    file.write_all(&bytes).await?;
    file.flush().await?;
    Ok(())
}

/// Extract `pg_dump` and `pg_restore` from the gzip-compressed tar at `archive_path`
/// into `bin_dir`, running the blocking work on the blocking thread pool.
///
/// Entries are matched by file name regardless of their directory inside the archive.
/// If extraction fails, any `pg_dump` / `pg_restore` already written to `bin_dir` is
/// removed, so that a partially written binary is not mistaken for a valid one later.
///
/// # Errors
///
/// Fails when the archive cannot be read, an entry cannot be unpacked, or the archive
/// does not contain both tools.
async fn extract_binaries(archive_path: &Path, bin_dir: &Path) -> Result<()> {
    let archive_path = archive_path.to_owned();
    let bin_dir = bin_dir.to_owned();

    spawn_blocking(move || -> Result<()> {
        let outcome = unpack_pg_tools(&archive_path, &bin_dir);
        if outcome.is_err() {
            // Best-effort cleanup of possibly truncated outputs.
            for tool in ["pg_dump", "pg_restore"] {
                let _ = fs::remove_file(bin_dir.join(tool));
            }
        }
        outcome
    })
    .await?
}

/// Blocking part of `extract_binaries`: unpack the two tools and mark them executable.
fn unpack_pg_tools(archive_path: &Path, bin_dir: &Path) -> Result<()> {
    let file = File::open(archive_path)?;
    let mut archive = Archive::new(GzDecoder::new(file));

    // Track each tool separately: counting matches would accept two entries named
    // `pg_dump` as "both tools found".
    let mut have_dump = false;
    let mut have_restore = false;

    for entry in archive.entries()? {
        let mut entry = entry?;
        let path = entry.path()?;
        let name = path
            .file_name()
            .and_then(|v| v.to_str())
            .unwrap_or_default();
        let is_dump = name == "pg_dump";
        if !is_dump && name != "pg_restore" {
            continue;
        }

        let dest = bin_dir.join(name);
        entry.unpack(&dest)?;
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mut perms = fs::metadata(&dest)?.permissions();
            perms.set_mode(0o755);
            fs::set_permissions(&dest, perms)?;
        }

        if is_dump {
            have_dump = true;
        } else {
            have_restore = true;
        }
        // Both tools are in place; no need to decompress the rest of the archive.
        if have_dump && have_restore {
            break;
        }
    }

    if !(have_dump && have_restore) {
        anyhow::bail!("pg_dump/pg_restore not found in downloaded archive");
    }

    Ok(())
}