use std::env;
use std::path::{Path as FsPath, PathBuf};
// ANSI escape sequences used to colorize terminal log output.
const B: &str = "\x1b[1m"; // bold
const G: &str = "\x1b[32m"; // green
const C: &str = "\x1b[36m"; // cyan
// yellow / red / reset
const Y: &str = "\x1b[33m"; const R: &str = "\x1b[31m"; const Z: &str = "\x1b[0m";
use actix_web::{
HttpRequest, HttpResponse, delete, get, patch, post,
web::{self, Data, Json, Path},
};
use aws_config::BehaviorVersion;
use aws_sdk_s3::Client as S3Client;
use aws_sdk_s3::config::{Credentials, Region};
use aws_smithy_types::timeout::TimeoutConfig;
use chrono::Utc;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::{Connection, FromRow, PgPool, QueryBuilder, Row};
use std::collections::{HashMap, HashSet};
use std::sync::Mutex;
use tokio::process::Command;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::AppState;
use crate::api::auth::authorize_static_admin_key;
use crate::api::response::{
api_success, bad_request, internal_error, not_found, service_unavailable,
};
use crate::parser::resolve_postgres_uri;
use crate::utils::pg_tools::{PgToolsPaths, ensure_pg_tools, resolve_pg_tools_from_dir};
// Per-request header overrides for the pg_dump / pg_restore binary paths.
const HEADER_PG_DUMP_PATH: &str = "x-athena-pg-dump-path";
const HEADER_PG_RESTORE_PATH: &str = "x-athena-pg-restore-path";
// Process-wide cache of sanitized URIs known to require sslmode=require.
static SSLMODE_REQUIRE_CACHE: Lazy<Mutex<HashSet<String>>> =
    Lazy::new(|| Mutex::new(HashSet::new()));
// Live cancellation tokens for running backup jobs, keyed by job id.
static BACKUP_JOB_CANCEL_TOKENS: Lazy<Mutex<HashMap<i64, CancellationToken>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
// Sentinel error strings surfaced to callers on cancellation / S3 failures.
const ERR_BACKUP_CANCELLED: &str = "BACKUP_CANCELLED";
const ERR_RESTORE_CANCELLED: &str = "RESTORE_CANCELLED";
const ERR_S3_DOWNLOAD_FAILED: &str = "S3_DOWNLOAD_FAILED";
const ERR_S3_READ_FAILED: &str = "S3_READ_FAILED";
// S3 download retry policy: attempt count and exponential-backoff base delay.
const S3_DOWNLOAD_MAX_ATTEMPTS: u32 = 3;
const S3_DOWNLOAD_RETRY_BASE_MS: u64 = 500;
// Progress percentages assigned to the fixed stages of a backup job.
const PG_DUMP_PROGRESS_MIN_PCT: i32 = 5;
const PG_DUMP_PROGRESS_MAX_PCT: i32 = 68;
const BACKUP_PROGRESS_ARCHIVING_PCT: i32 = 72;
const BACKUP_PROGRESS_UPLOADING_PCT: i32 = 82;
const BACKUP_PROGRESS_UPLOAD_STORED_PCT: i32 = 94;
/// Associate a cancellation token with a running backup job so it can later
/// be fired via `trigger_backup_cancel_token`.
fn register_backup_cancel_token(job_id: i64, token: CancellationToken) {
    BACKUP_JOB_CANCEL_TOKENS
        .lock()
        .unwrap()
        .insert(job_id, token);
}
/// Drop the cancellation token registered for `job_id`, if any.
fn unregister_backup_cancel_token(job_id: i64) {
    BACKUP_JOB_CANCEL_TOKENS.lock().unwrap().remove(&job_id);
}
/// Fire the cancellation token for `job_id`; a no-op when none is registered.
fn trigger_backup_cancel_token(job_id: i64) {
    let guard = BACKUP_JOB_CANCEL_TOKENS.lock().unwrap();
    if let Some(token) = guard.get(&job_id) {
        token.cancel();
    }
}
/// RAII guard holding a backup job id; unregisters the job's cancellation
/// token when dropped, so the token map is cleaned up even on early return.
struct BackupJobCancelGuard(i64);
impl Drop for BackupJobCancelGuard {
    fn drop(&mut self) {
        let BackupJobCancelGuard(job_id) = *self;
        unregister_backup_cancel_token(job_id);
    }
}
/// Check the database for whether the job row's status is `cancelled`.
/// A query failure or missing row is treated as "not cancelled".
async fn is_backup_job_cancelled(pool: &PgPool, job_id: i64) -> bool {
    sqlx::query_scalar::<_, String>("SELECT status::text FROM backup_jobs WHERE id = $1")
        .bind(job_id)
        .fetch_optional(pool)
        .await
        .ok()
        .flatten()
        .map(|status| status == "cancelled")
        .unwrap_or(false)
}
/// Best-effort termination of an external process by PID.
/// Does nothing when `pid` is `None`; errors from the kill command are ignored.
fn kill_pid_best_effort(pid: Option<u32>) {
    let pid = match pid {
        Some(p) => p,
        None => return,
    };
    #[cfg(unix)]
    let _ = std::process::Command::new("kill")
        .args(["-TERM", &pid.to_string()])
        .status();
    #[cfg(windows)]
    let _ = std::process::Command::new("taskkill")
        .args(["/PID", &pid.to_string(), "/F"])
        .status();
}
/// Run `cmd` to completion, capturing its output, while optionally honoring a
/// cancellation token and periodically reporting pg_dump progress.
///
/// * `cancel` — when triggered, the child is killed (best effort, by PID) and
///   `ERR_BACKUP_CANCELLED` is returned as the error string.
/// * `pg_dump_progress` — when set, the tracker is polled roughly once per
///   second with the dump directory path so it can emit progress updates.
///
/// A spawn error of kind NotFound is translated into a hint about setting
/// ATHENA_PG_DUMP_PATH; any other spawn/wait failure is formatted verbatim.
async fn command_output_cancellable(
    cmd: &mut Command,
    cancel: Option<&CancellationToken>,
    pg_dump_progress: Option<(&PgDumpProgressTracker, &FsPath)>,
) -> Result<std::process::Output, String> {
    // Fast path: nothing to watch for, so just await the process output.
    if cancel.is_none() && pg_dump_progress.is_none() {
        return cmd.output().await.map_err(|e| {
            if e.kind() == std::io::ErrorKind::NotFound {
                "pg_dump binary could not be resolved (PATH/env/download). Set ATHENA_PG_DUMP_PATH to override."
                    .to_string()
            } else {
                format!("Failed to start pg_dump: {e}")
            }
        });
    }
    let child = cmd.spawn().map_err(|e| {
        if e.kind() == std::io::ErrorKind::NotFound {
            "pg_dump binary could not be resolved (PATH/env/download). Set ATHENA_PG_DUMP_PATH to override."
                .to_string()
        } else {
            format!("Failed to start pg_dump: {e}")
        }
    })?;
    // Capture the PID before the child is consumed by wait_with_output, so we
    // can still kill the process on cancellation.
    let pid = child.id();
    let out_fut = child.wait_with_output();
    tokio::pin!(out_fut);
    // One-second ticker drives progress reporting; skip missed ticks rather
    // than bursting to catch up.
    let mut tick = tokio::time::interval(std::time::Duration::from_secs(1));
    tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    // The first interval tick completes immediately; consume it so the first
    // progress report happens ~1s after start.
    let _ = tick.tick().await;
    loop {
        match cancel {
            Some(c) => {
                tokio::select! {
                    res = &mut out_fut => {
                        return res.map_err(|e| format!("Failed to wait for pg_dump: {e}"));
                    }
                    _ = c.cancelled() => {
                        kill_pid_best_effort(pid);
                        return Err(ERR_BACKUP_CANCELLED.to_string());
                    }
                    _ = tick.tick(), if pg_dump_progress.is_some() => {
                        if let Some((tracker, dump_path)) = pg_dump_progress {
                            tracker.report(dump_path).await;
                        }
                    }
                }
            }
            None => {
                // No cancel token: only race completion against the progress
                // ticker (progress must be Some here, or we'd have fast-pathed).
                tokio::select! {
                    res = &mut out_fut => {
                        return res.map_err(|e| format!("Failed to wait for pg_dump: {e}"));
                    }
                    _ = tick.tick(), if pg_dump_progress.is_some() => {
                        if let Some((tracker, dump_path)) = pg_dump_progress {
                            tracker.report(dump_path).await;
                        }
                    }
                }
            }
        }
    }
}
/// Run `cmd` to completion (exit status only), optionally cancellable.
///
/// On cancellation the child is killed by PID (best effort) and
/// `ERR_RESTORE_CANCELLED` is returned. A NotFound spawn error is mapped to
/// the caller-supplied `not_found_msg`.
async fn command_status_cancellable(
    cmd: &mut Command,
    cancel: Option<&CancellationToken>,
    not_found_msg: &'static str,
) -> Result<std::process::ExitStatus, String> {
    match cancel {
        // No token: simply await the exit status.
        None => cmd.status().await.map_err(|e| {
            if e.kind() == std::io::ErrorKind::NotFound {
                not_found_msg.to_string()
            } else {
                format!("Failed to start command: {e}")
            }
        }),
        Some(c) => {
            let mut child = cmd.spawn().map_err(|e| {
                if e.kind() == std::io::ErrorKind::NotFound {
                    not_found_msg.to_string()
                } else {
                    format!("Failed to start command: {e}")
                }
            })?;
            // Grab the PID up front so we can kill the process on cancel.
            let pid = child.id();
            let wait_fut = child.wait();
            tokio::select! {
                res = wait_fut => res.map_err(|e| format!("Failed to wait for process: {e}")),
                _ = c.cancelled() => {
                    kill_pid_best_effort(pid);
                    Err(ERR_RESTORE_CANCELLED.to_string())
                }
            }
        }
    }
}
/// Exponential-backoff delay for S3 download retries: base * 2^(attempt-1),
/// with the exponent capped at 10 so the shift can never overflow u64.
fn s3_retry_delay(attempt: u32) -> std::time::Duration {
    let exponent = attempt.saturating_sub(1).min(10);
    let multiplier = 1_u64 << exponent;
    std::time::Duration::from_millis(S3_DOWNLOAD_RETRY_BASE_MS.saturating_mul(multiplier))
}
/// Download an S3 object fully into memory, retrying transient failures.
///
/// Both the `get_object` call and the body-stream read share a retry budget of
/// `max_attempts` (minimum 1), with exponential backoff between tries. Errors
/// are prefixed with ERR_S3_DOWNLOAD_FAILED / ERR_S3_READ_FAILED so callers
/// can classify which phase failed.
async fn download_s3_object_with_retry(
    s3_client: &S3Client,
    bucket: &str,
    key: &str,
    max_attempts: u32,
) -> Result<web::Bytes, String> {
    let attempts = max_attempts.max(1);
    for attempt in 1..=attempts {
        if attempt > 1 {
            tracing::warn!(
                "Retrying S3 object download for key='{}' (attempt {}/{})",
                key,
                attempt,
                attempts
            );
        }
        let resp: aws_sdk_s3::operation::get_object::GetObjectOutput =
            match s3_client.get_object().bucket(bucket).key(key).send().await {
                Ok(resp) => resp,
                Err(e) => {
                    tracing::error!(
                        error = ?e,
                        "S3 get_object failed for key='{}' on attempt {}/{}",
                        key,
                        attempt,
                        attempts
                    );
                    if attempt < attempts {
                        tokio::time::sleep(s3_retry_delay(attempt)).await;
                        continue;
                    }
                    return Err(format!("{ERR_S3_DOWNLOAD_FAILED}: {e}"));
                }
            };
        // The request succeeded, but reading the body can still fail
        // mid-stream; that read participates in the same retry budget.
        match resp.body.collect().await {
            Ok(body) => return Ok(body.into_bytes()),
            Err(e) => {
                tracing::error!(
                    error = ?e,
                    "S3 body stream read failed for key='{}' on attempt {}/{}",
                    key,
                    attempt,
                    attempts
                );
                if attempt < attempts {
                    tokio::time::sleep(s3_retry_delay(attempt)).await;
                    continue;
                }
                return Err(format!(
                    "{ERR_S3_READ_FAILED}: {e}. Retry the restore once; if persistent, validate endpoint/network path from this host."
                ));
            }
        }
    }
    // Defensive: every final attempt returns above, so this is unreachable in
    // practice, but it keeps the function total without a panic.
    Err(format!(
        "{ERR_S3_READ_FAILED}: exhausted retries unexpectedly for key '{key}'"
    ))
}
/// Read an environment variable, treating unset or blank values as absent.
fn s3_env(key: &str) -> Option<String> {
    match env::var(key) {
        Ok(value) if !value.trim().is_empty() => Some(value),
        _ => None,
    }
}
/// S3 destination settings for backups, sourced from ATHENA_BACKUP_S3_*
/// environment variables.
struct S3Config {
    bucket: String,
    region: String,
    prefix: String,
    access_key: Option<String>,
    secret_key: Option<String>,
    endpoint: Option<String>,
}
impl S3Config {
    /// Build the configuration from the environment; returns `None` unless a
    /// bucket is configured. Region defaults to us-east-1, prefix to "backups".
    fn from_env() -> Option<Self> {
        let bucket = s3_env("ATHENA_BACKUP_S3_BUCKET")?;
        let region =
            s3_env("ATHENA_BACKUP_S3_REGION").unwrap_or_else(|| String::from("us-east-1"));
        let prefix =
            s3_env("ATHENA_BACKUP_S3_PREFIX").unwrap_or_else(|| String::from("backups"));
        Some(Self {
            bucket,
            region,
            prefix,
            access_key: s3_env("ATHENA_BACKUP_S3_ACCESS_KEY"),
            secret_key: s3_env("ATHENA_BACKUP_S3_SECRET_KEY"),
            endpoint: s3_env("ATHENA_BACKUP_S3_ENDPOINT"),
        })
    }
}
/// Construct an S3 client from `cfg`.
///
/// Explicit credentials are used when both access and secret keys are set;
/// otherwise the default AWS provider chain applies. A custom endpoint forces
/// path-style bucket addressing. The read timeout is raised to 2 hours to
/// accommodate large backup transfers.
async fn build_s3_client(cfg: &S3Config) -> S3Client {
    tracing::info!(
        "Building S3 client with config: bucket={}, region={}, prefix={:?}, endpoint={:?}, access_key_set={}",
        cfg.bucket,
        cfg.region,
        cfg.prefix,
        cfg.endpoint,
        cfg.access_key.is_some()
    );
    let region: Region = Region::new(cfg.region.clone());
    let aws_config: aws_config::SdkConfig = if let (Some(ak), Some(sk)) =
        (&cfg.access_key, &cfg.secret_key)
    {
        tracing::info!("Using provided AWS credentials for S3 client.");
        let creds: Credentials = Credentials::new(ak, sk, None, None, "athena-env");
        let mut builder: aws_config::ConfigLoader = aws_config::defaults(BehaviorVersion::latest())
            .region(region)
            .credentials_provider(creds);
        if let Some(ep) = &cfg.endpoint {
            tracing::info!("S3 client will use custom endpoint: {ep}");
            builder = builder.endpoint_url(ep);
        }
        builder.load().await
    } else {
        tracing::info!("Using default AWS credentials/config for S3 client.");
        let mut builder: aws_config::ConfigLoader =
            aws_config::defaults(BehaviorVersion::latest()).region(region);
        if let Some(ep) = &cfg.endpoint {
            tracing::info!("S3 client will use custom endpoint: {ep}");
            builder = builder.endpoint_url(ep);
        }
        builder.load().await
    };
    let mut s3_cfg_builder: aws_sdk_s3::config::Builder =
        aws_sdk_s3::config::Builder::from(&aws_config);
    if cfg.endpoint.is_some() {
        // Non-AWS endpoints commonly require path-style addressing instead of
        // virtual-hosted-style bucket URLs.
        tracing::info!("Forcing path-style for S3 client.");
        s3_cfg_builder = s3_cfg_builder.force_path_style(true);
    }
    // Large archives can stream for a long time; extend the read timeout.
    let timeout_cfg = TimeoutConfig::builder()
        .read_timeout(std::time::Duration::from_secs(7200))
        .build();
    s3_cfg_builder = s3_cfg_builder.timeout_config(timeout_cfg);
    tracing::info!("{}✓{} S3 client built.", G, Z);
    S3Client::from_conf(s3_cfg_builder.build())
}
/// Find the size in bytes of the most recent `.tar.gz` backup stored for
/// `client_name`, by paging through the S3 listing under the client's prefix.
///
/// Returns `None` when no suitable object exists. On a listing error the best
/// candidate seen so far (if any) is returned, since the value only feeds
/// progress estimation and need not be exact.
async fn latest_client_backup_size_bytes(
    s3_client: &S3Client,
    cfg: &S3Config,
    client_name: &str,
) -> Option<i64> {
    let prefix = format!("{}/{}/", cfg.prefix.trim_end_matches('/'), client_name);
    let mut continuation: Option<String> = None;
    // Tracks (last_modified_secs, size_bytes) of the newest archive seen.
    let mut newest: Option<(i64, i64)> = None;
    loop {
        let mut request = s3_client
            .list_objects_v2()
            .bucket(&cfg.bucket)
            .prefix(&prefix)
            .max_keys(1000);
        if let Some(token) = continuation.as_deref() {
            request = request.continuation_token(token);
        }
        let response = match request.send().await {
            Ok(response) => response,
            Err(err) => {
                tracing::warn!(
                    client_name = %client_name,
                    prefix = %prefix,
                    error = %err,
                    "Failed to list previous backups for progress estimation; using fallback estimator"
                );
                return newest.map(|(_, size)| size);
            }
        };
        for object in response.contents() {
            let Some(key) = object.key() else {
                continue;
            };
            let size = object.size().unwrap_or(0);
            // Objects without a timestamp sort before everything else.
            let modified_secs = object
                .last_modified()
                .map(|timestamp| timestamp.secs())
                .unwrap_or(i64::MIN);
            maybe_record_backup_object(&mut newest, key, modified_secs, size);
        }
        if !response.is_truncated().unwrap_or(false) {
            break;
        }
        continuation = response.next_continuation_token().map(ToString::to_string);
        if continuation.is_none() {
            break;
        }
    }
    newest.map(|(_, size)| size)
}
/// Track the most recent backup object seen so far as `(modified_secs, size)`.
/// A later timestamp wins; on a timestamp tie, a larger-or-equal size wins.
fn update_latest_backup_candidate(
    newest: &mut Option<(i64, i64)>,
    modified_secs: i64,
    size_bytes: i64,
) {
    let replace = match *newest {
        None => true,
        Some((known_secs, known_size)) => {
            modified_secs > known_secs
                || (modified_secs == known_secs && size_bytes >= known_size)
        }
    };
    if replace {
        *newest = Some((modified_secs, size_bytes));
    }
}
/// Feed an S3 listing entry into the "latest backup" tracker, ignoring
/// anything that is not a non-empty `.tar.gz` archive.
fn maybe_record_backup_object(
    newest: &mut Option<(i64, i64)>,
    key: &str,
    modified_secs: i64,
    size_bytes: i64,
) {
    if key.ends_with(".tar.gz") && size_bytes > 0 {
        update_latest_backup_candidate(newest, modified_secs, size_bytes);
    }
}
/// Body for the create-backup endpoint. The source database is identified by
/// either a registered `client_name` or a raw `pg_uri`.
#[derive(Debug, Deserialize)]
struct CreateBackupRequest {
    #[serde(default)]
    client_name: Option<String>,
    #[serde(default)]
    pg_uri: Option<String>,
    // Optional human-readable label stored with the backup.
    #[serde(default)]
    label: Option<String>,
    // Optional timeout override, in seconds.
    #[serde(default)]
    timeout_seconds: Option<i32>,
    // How pg_dump failures should be handled (see BackupRecoveryStrategy).
    #[serde(default)]
    recovery_strategy: Option<BackupRecoveryStrategy>,
}
/// How `run_pg_dump` should react when pg_dump fails due to database-side
/// inconsistencies.
///
/// Replaces the former hand-written `Default` impl with
/// `#[derive(Default)]` + `#[default]` (stable since Rust 1.62); behavior is
/// unchanged — `None` remains the default variant.
#[derive(Debug, Deserialize, Clone, Copy, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
enum BackupRecoveryStrategy {
    /// Fail the backup immediately on pg_dump errors (the default).
    #[default]
    None,
    /// Attempt to repair objects owned by dropped roles, then retry pg_dump.
    RepairMissingRoleOids,
}
/// Body for the restore endpoint: the target database is identified by either
/// a registered `client_name` or a raw `pg_uri`.
#[derive(Debug, Deserialize)]
struct RestoreBackupRequest {
    #[serde(default)]
    client_name: Option<String>,
    #[serde(default)]
    pg_uri: Option<String>,
    // Optional timeout override, in seconds.
    #[serde(default)]
    timeout_seconds: Option<i32>,
}
/// Body for creating a backup schedule for a registered client.
#[derive(Debug, Deserialize)]
struct CreateScheduleRequest {
    client_name: String,
    // Optional explicit URI; otherwise resolved from the registered client.
    #[serde(default)]
    pg_uri: Option<String>,
    // Schedule cadence — exact accepted values are validated elsewhere.
    frequency: String,
    // Time of day; defaults to "02:00".
    #[serde(default = "default_time")]
    time: String,
    #[serde(default)]
    day_of_week: Option<i32>,
    #[serde(default)]
    day_of_month: Option<i32>,
    #[serde(default)]
    label: Option<String>,
    // Defaults to 3600 seconds (one hour).
    #[serde(default = "default_timeout")]
    timeout_seconds: i32,
}
/// Default schedule time of day (HH:MM, 24-hour clock).
fn default_time() -> String {
    String::from("02:00")
}
/// Default backup timeout: one hour, expressed in seconds.
fn default_timeout() -> i32 {
    60 * 60
}
/// Clamp a requested timeout into the allowed range: 60 seconds to 24 hours.
///
/// Uses `i32::clamp` instead of the previous hand-rolled if/else chain;
/// behavior is identical.
fn clamp_timeout(seconds: i32) -> i32 {
    seconds.clamp(60, 86_400)
}
/// Body for patching a schedule. All fields are optional; for the nested
/// `Option<Option<T>>` fields, serde yields outer `None` when the field was
/// omitted and `Some(None)` when the client sent an explicit JSON `null`.
#[derive(Debug, Deserialize)]
struct UpdateScheduleRequest {
    #[serde(default)]
    enabled: Option<bool>,
    #[serde(default)]
    frequency: Option<String>,
    #[serde(default)]
    time: Option<String>,
    #[serde(default)]
    day_of_week: Option<Option<i32>>,
    #[serde(default)]
    day_of_month: Option<Option<i32>>,
    #[serde(default)]
    label: Option<Option<String>>,
    #[serde(default)]
    timeout_seconds: Option<i32>,
}
/// A backup schedule as read from the database (sqlx row mapping) and
/// serialized directly in API responses.
#[derive(Debug, Serialize, FromRow)]
struct BackupScheduleRow {
    id: i64,
    client_name: String,
    pg_uri: Option<String>,
    frequency: String,
    cron_expression: String,
    time_of_day: chrono::NaiveTime,
    day_of_week: Option<i32>,
    day_of_month: Option<i32>,
    label: Option<String>,
    enabled: bool,
    timeout_seconds: i32,
    // Bookkeeping for the most recent and next scheduled runs.
    last_run_at: Option<chrono::DateTime<Utc>>,
    last_job_id: Option<i64>,
    next_run_at: Option<chrono::DateTime<Utc>>,
    created_at: chrono::DateTime<Utc>,
    updated_at: chrono::DateTime<Utc>,
}
/// API representation of a stored backup archive.
#[derive(Debug, Serialize)]
struct BackupObject {
    id: String,
    // Full S3 object key of the archive.
    key: String,
    client_name: String,
    label: Option<String>,
    size_bytes: i64,
    created_at: String,
}
/// Summary of a backup/restore job row for API listings (sqlx row mapping).
#[derive(Debug, Serialize, FromRow)]
struct BackupJobSummary {
    id: i64,
    job_type: String,
    client_name: Option<String>,
    status: String,
    // Progress reporting: percentage plus a coarse stage label.
    progress_pct: Option<i32>,
    progress_stage: Option<String>,
    // Populated once the archive has been stored in S3.
    s3_bucket: Option<String>,
    s3_key: Option<String>,
    label: Option<String>,
    size_bytes: Option<i64>,
    error_message: Option<String>,
    started_at: chrono::DateTime<Utc>,
    updated_at: chrono::DateTime<Utc>,
    completed_at: Option<chrono::DateTime<Utc>>,
}
/// A single log line attached to a backup job (sqlx row mapping).
#[derive(Debug, Serialize, FromRow)]
struct BackupJobLog {
    id: i64,
    job_id: i64,
    level: String,
    message: String,
    created_at: chrono::DateTime<Utc>,
}
fn resolve_pg_uri(state: &AppState, client_name: &str) -> Result<String, HttpResponse> {
tracing::info!("Resolving Postgres URI for client_name={}", client_name);
let registered: crate::drivers::postgresql::sqlx_driver::RegisteredClient = state
.pg_registry
.registered_client(client_name)
.ok_or_else(|| {
tracing::info!("Client '{}' not found in pg_registry.", client_name);
bad_request(
"Unknown client",
format!("No Postgres client named '{}' is registered.", client_name),
)
})?;
let uri: String = registered
.config_uri_template
.as_deref()
.map(resolve_postgres_uri)
.or(registered.pg_uri)
.ok_or_else(|| {
tracing::info!("No usable Postgres URI found for client '{}'", client_name);
bad_request(
"Client URI unavailable",
format!("No Postgres URI is available for client '{}'.", client_name),
)
})?;
tracing::info!("{}Resolved URI{} for {}: {}", C, Z, client_name, &uri);
Ok(uri)
}
/// Split a Postgres URI into a password-free URI plus the extracted password.
///
/// Only `postgres://` / `postgresql://` URIs are inspected; any other scheme
/// is returned unchanged with no password. The LAST `@` is treated as the
/// userinfo/host separator, so passwords containing `@` still parse, and the
/// FIRST `:` inside userinfo separates user from password.
fn extract_pg_password(pg_uri: &str) -> (String, Option<String>) {
    tracing::info!("Extracting password from Postgres URI for command safety.");
    let (prefix, rest) = if let Some(rest) = pg_uri.strip_prefix("postgresql://") {
        ("postgresql://", rest)
    } else if let Some(rest) = pg_uri.strip_prefix("postgres://") {
        ("postgres://", rest)
    } else {
        tracing::info!("Nonstandard Postgres URI, not extracting password.");
        return (pg_uri.to_string(), None);
    };
    if let Some(at_pos) = rest.rfind('@') {
        // `host_and_beyond` retains the leading '@'.
        let (userinfo, host_and_beyond) = rest.split_at(at_pos);
        if let Some((user, password)) = userinfo.split_once(':') {
            let sanitized = format!("{}{}{}", prefix, user, host_and_beyond);
            tracing::info!("Password extracted from URI (hidden); sanitized URI prepared.");
            return (sanitized, Some(password.to_string()));
        }
    }
    tracing::info!("No password found in Postgres URI.");
    (pg_uri.to_string(), None)
}
/// Heuristic: does this Postgres URI target the local machine (localhost,
/// loopback v4/v6, or an empty/unix-socket host)? Non-postgres schemes are
/// never considered local.
fn postgres_uri_target_is_local_pg_tools(pg_uri: &str) -> bool {
    let after_scheme = match pg_uri
        .strip_prefix("postgresql://")
        .or_else(|| pg_uri.strip_prefix("postgres://"))
    {
        Some(rest) => rest,
        None => return false,
    };
    // Drop userinfo, splitting on the LAST '@' so passwords with '@' work.
    let host_part = match after_scheme.rfind('@') {
        Some(at_pos) => &after_scheme[at_pos + 1..],
        None => after_scheme,
    };
    // An empty host or a path-style host denotes a local unix socket.
    if host_part.is_empty() || host_part.starts_with('/') {
        return true;
    }
    // Trim path/query, then the port; IPv6 literals are bracketed.
    let host_segment = host_part
        .split(|c| c == '/' || c == '?')
        .next()
        .unwrap_or(host_part);
    let host_only = if let Some(stripped) = host_segment.strip_prefix('[') {
        let end = stripped.find(']').unwrap_or(stripped.len());
        &stripped[..end]
    } else {
        host_segment
            .split_once(':')
            .map(|(host, _port)| host)
            .unwrap_or(host_segment)
    };
    matches!(
        host_only.to_lowercase().as_str(),
        "localhost" | "127.0.0.1" | "::1"
    )
}
/// Return `pg_uri` with its query string rewritten so exactly one
/// `sslmode=<value>` parameter is present; any pre-existing sslmode parameter
/// (matched case-insensitively by key) is removed first.
fn pg_uri_set_sslmode(pg_uri: &str, sslmode_value: &str) -> String {
    let (base, query) = pg_uri
        .split_once('?')
        .map(|(b, q)| (b, Some(q)))
        .unwrap_or((pg_uri, None));
    let mut params: Vec<String> = query
        .map(|q| {
            q.split('&')
                .filter(|seg| !seg.is_empty())
                .filter(|seg| {
                    // Keep every parameter except an existing sslmode.
                    !seg.split_once('=')
                        .map(|(key, _)| key.eq_ignore_ascii_case("sslmode"))
                        .unwrap_or(false)
                })
                .map(ToString::to_string)
                .collect()
        })
        .unwrap_or_default();
    params.push(format!("sslmode={sslmode_value}"));
    format!("{}?{}", base, params.join("&"))
}
/// True when the URI's query string already carries an `sslmode` parameter.
///
/// Fix: the previous implementation matched the substring "sslmode=" anywhere
/// in the URI, so a password or unrelated parameter containing that text
/// produced a false positive. Only query-string keys are now inspected,
/// matched case-insensitively — consistent with `pg_uri_set_sslmode`.
fn sslmode_is_present(pg_uri: &str) -> bool {
    match pg_uri.split_once('?') {
        Some((_base, query)) => query.split('&').any(|seg| {
            seg.split_once('=')
                .map(|(key, _)| key.eq_ignore_ascii_case("sslmode"))
                .unwrap_or(false)
        }),
        None => false,
    }
}
/// Append `sslmode=require` to the URI unless an sslmode is already set.
fn with_sslmode_require(pg_uri: &str) -> String {
    if sslmode_is_present(pg_uri) {
        return pg_uri.to_string();
    }
    if pg_uri.contains('?') {
        format!("{pg_uri}&sslmode=require")
    } else {
        format!("{pg_uri}?sslmode=require")
    }
}
/// Opt-in flag for caching "this target needs sslmode=require" decisions:
/// enabled when ATHENA_PG_SSLMODE_DEDUP is set to a recognized truthy value.
fn ssl_retry_dedup_enabled() -> bool {
    std::env::var("ATHENA_PG_SSLMODE_DEDUP")
        .map(|v| matches!(v.as_str(), "1" | "true" | "TRUE" | "yes" | "YES"))
        .unwrap_or(false)
}
/// Cache key for the sslmode-require cache: the URI with its password removed,
/// so credentials never sit in the process-wide set.
fn sslmode_cache_key(pg_uri: &str) -> String {
    extract_pg_password(pg_uri).0
}
/// Has this (non-local) target previously been found to need
/// `sslmode=require`? Always false for local targets or when dedup is
/// disabled via the environment.
fn sslmode_cached_require(pg_uri: &str) -> bool {
    if postgres_uri_target_is_local_pg_tools(pg_uri) || !ssl_retry_dedup_enabled() {
        return false;
    }
    let key = sslmode_cache_key(pg_uri);
    match SSLMODE_REQUIRE_CACHE.lock() {
        Ok(set) => set.contains(&key),
        // A poisoned lock degrades to "not cached".
        Err(_) => false,
    }
}
/// Record that this target needed `sslmode=require`, so future tool
/// invocations can apply it up front. No-op unless dedup is enabled.
fn sslmode_cache_mark_require(pg_uri: &str) {
    if !ssl_retry_dedup_enabled() {
        return;
    }
    let cache_key = sslmode_cache_key(pg_uri);
    if let Ok(mut set) = SSLMODE_REQUIRE_CACHE.lock() {
        set.insert(cache_key);
    }
}
/// Choose the URI handed to pg_dump/pg_restore: local targets always get
/// `sslmode=disable`; remote targets previously seen to need TLS (and lacking
/// an explicit sslmode) get `sslmode=require`; everything else is unchanged.
fn pg_uri_for_backup_tools(pg_uri: &str) -> String {
    if postgres_uri_target_is_local_pg_tools(pg_uri) {
        pg_uri_set_sslmode(pg_uri, "disable")
    } else if sslmode_cached_require(pg_uri) && !sslmode_is_present(pg_uri) {
        with_sslmode_require(pg_uri)
    } else {
        pg_uri.to_string()
    }
}
/// Heuristic match on server error text that indicates TLS is mandatory for
/// the connection (pg_hba "SSL off" rejections and sslmode/required hints).
fn looks_like_ssl_required_error(text: &str) -> bool {
    let lower = text.to_lowercase();
    let has = |needle: &str| lower.contains(needle);
    (has("ssl off") && has("no pg_hba.conf entry"))
        || (has("sslmode") && has("require"))
        || (has("ssl") && has("required"))
}
/// Condense pg_dump's stderr/stdout into a single human-readable failure line.
///
/// Warnings about missing referenced extension metadata are counted rather
/// than echoed; other warnings are deduplicated; errors are collected in
/// order. When no structured errors were found, raw stderr (or stdout) is
/// used as the detail. A repair hint is appended when the failure stems from
/// a dropped role that is still referenced.
fn summarize_pg_dump_failure(stderr: &str, stdout: &str, code: i32) -> String {
    let mut missing_extension_warnings = 0usize;
    let mut other_warnings: Vec<String> = Vec::new();
    let mut errors: Vec<String> = Vec::new();
    let mut orphan_role_oid: Option<String> = None;
    for line in stderr.lines().map(str::trim).filter(|l| !l.is_empty()) {
        if line.starts_with("pg_dump: warning: could not find referenced extension") {
            missing_extension_warnings += 1;
        } else if line.starts_with("pg_dump: warning:") {
            // Keep each distinct warning once, in first-seen order.
            if !other_warnings.iter().any(|w| w == line) {
                other_warnings.push(line.to_string());
            }
        } else if line.starts_with("pg_dump: error:") {
            // Remember the first orphaned-role OID for the trailing hint.
            if orphan_role_oid.is_none()
                && line.contains("role with OID")
                && line.contains("does not exist")
            {
                orphan_role_oid = line
                    .split("role with OID")
                    .nth(1)
                    .and_then(|rest| rest.split("does not exist").next())
                    .map(str::trim)
                    .filter(|v| !v.is_empty())
                    .map(ToOwned::to_owned);
            }
            errors.push(line.to_string());
        }
    }
    let mut parts = vec![format!("pg_dump exited with code {code}.")];
    if missing_extension_warnings > 0 {
        parts.push(format!(
            "{missing_extension_warnings} warning(s) about missing referenced extension metadata were suppressed."
        ));
    }
    if !other_warnings.is_empty() {
        parts.push(format!("Warnings: {}", other_warnings.join(" | ")));
    }
    if !errors.is_empty() {
        parts.push(format!("Errors: {}", errors.join(" | ")));
    } else if !stderr.trim().is_empty() {
        parts.push(stderr.trim().to_string());
    } else if !stdout.trim().is_empty() {
        parts.push(stdout.trim().to_string());
    }
    if let Some(oid) = orphan_role_oid {
        parts.push(format!(
            "Detected orphaned role reference (OID {oid}). The source database contains objects that still reference a dropped role; repair ownership/policy references in the source DB, then retry backup/clone."
        ));
    }
    parts.join(" ")
}
/// Extract every distinct OID from pg_dump stderr lines of the form
/// "... role with OID <n> does not exist ...", preserving first-seen order.
/// Fragments that do not parse as an integer are skipped.
fn parse_missing_role_oids(stderr: &str) -> Vec<i64> {
    let mut oids: Vec<i64> = Vec::new();
    for line in stderr.lines().map(str::trim) {
        if !line.contains("role with OID") || !line.contains("does not exist") {
            continue;
        }
        let parsed = line
            .split("role with OID")
            .nth(1)
            .and_then(|rest| rest.split("does not exist").next())
            .map(str::trim)
            .and_then(|v| v.parse::<i64>().ok());
        match parsed {
            Some(oid) if !oids.contains(&oid) => oids.push(oid),
            _ => {}
        }
    }
    oids
}
/// Best-effort repair for databases whose catalog objects still reference
/// dropped roles: for each missing role OID, reassign ownership across the
/// catalogs to the connecting user's role, then strip the missing role from
/// row-security policy role arrays.
///
/// Per-statement errors with SQLSTATE 42P01 (undefined table), 42703
/// (undefined column), or 42501 (insufficient privilege) are ignored, since
/// not every catalog exists or is writable on every server/hosting setup.
async fn recover_missing_role_oids(pg_uri: &str, missing_oids: &[i64]) -> Result<(), String> {
    if missing_oids.is_empty() {
        return Ok(());
    }
    let mut conn: sqlx::PgConnection = sqlx::PgConnection::connect(pg_uri)
        .await
        .map_err(|e| format!("Recovery connection failed: {e}"))?;
    // Everything gets reassigned to the role we are connected as.
    let target_owner_oid: i64 =
        sqlx::query_scalar("SELECT oid::bigint FROM pg_roles WHERE rolname = current_user")
            .fetch_one(&mut conn)
            .await
            .map_err(|e| format!("Recovery could not resolve current_user role OID: {e}"))?;
    for missing_oid in missing_oids {
        // Owner columns across the system catalogs; $1 = new owner, $2 = missing role.
        let statements: [&str; 15] = [
            "UPDATE pg_catalog.pg_class SET relowner = $1::oid WHERE relowner = $2::oid",
            "UPDATE pg_catalog.pg_namespace SET nspowner = $1::oid WHERE nspowner = $2::oid",
            "UPDATE pg_catalog.pg_proc SET proowner = $1::oid WHERE proowner = $2::oid",
            "UPDATE pg_catalog.pg_type SET typowner = $1::oid WHERE typowner = $2::oid",
            "UPDATE pg_catalog.pg_database SET datdba = $1::oid WHERE datname = current_database() AND datdba = $2::oid",
            "UPDATE pg_catalog.pg_foreign_data_wrapper SET fdwowner = $1::oid WHERE fdwowner = $2::oid",
            "UPDATE pg_catalog.pg_foreign_server SET srvowner = $1::oid WHERE srvowner = $2::oid",
            "UPDATE pg_catalog.pg_extension SET extowner = $1::oid WHERE extowner = $2::oid",
            "UPDATE pg_catalog.pg_event_trigger SET evtowner = $1::oid WHERE evtowner = $2::oid",
            "UPDATE pg_catalog.pg_publication SET pubowner = $1::oid WHERE pubowner = $2::oid",
            "UPDATE pg_catalog.pg_subscription SET subowner = $1::oid WHERE subowner = $2::oid",
            "UPDATE pg_catalog.pg_conversion SET conowner = $1::oid WHERE conowner = $2::oid",
            "UPDATE pg_catalog.pg_operator SET oprowner = $1::oid WHERE oprowner = $2::oid",
            "UPDATE pg_catalog.pg_opclass SET opcowner = $1::oid WHERE opcowner = $2::oid",
            "UPDATE pg_catalog.pg_opfamily SET opfowner = $1::oid WHERE opfowner = $2::oid",
        ];
        for statement in statements {
            if let Err(err) = sqlx::query(statement)
                .bind(target_owner_oid)
                .bind(*missing_oid)
                .execute(&mut conn)
                .await
            {
                let mut ignore = false;
                if let sqlx::Error::Database(db_err) = &err
                    && let Some(code) = db_err.code()
                {
                    // 42P01/42703: catalog or column absent on this server
                    // version; 42501: not privileged — all non-fatal here.
                    if code == "42P01" || code == "42703" || code == "42501" {
                        ignore = true;
                    }
                }
                if !ignore {
                    return Err(format!(
                        "Recovery failed while repairing role OID {} with statement '{}': {}",
                        missing_oid, statement, err
                    ));
                }
            }
        }
        // Remove the dropped role from policy role arrays; failures here are
        // deliberately swallowed (best effort).
        let _ = sqlx::query(
            "UPDATE pg_catalog.pg_policy SET polroles = array_remove(polroles, $1::oid) WHERE $1::oid = ANY(polroles)",
        )
        .bind(*missing_oid)
        .execute(&mut conn)
        .await;
    }
    Ok(())
}
/// Read a request header as a trimmed, non-empty string.
/// Returns `None` when the header is absent, not valid UTF-8, or blank.
fn header_path_hint(req: &HttpRequest, header_name: &str) -> Option<String> {
    let raw = req.headers().get(header_name)?.to_str().ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
async fn resolve_pg_tools_with_overrides(
server_major: Option<u32>,
dump_override: Option<&str>,
restore_override: Option<&str>,
) -> Result<PgToolsPaths, String> {
let mut pg_tools: PgToolsPaths = if dump_override.is_none() && restore_override.is_none() {
if let Some(major) = server_major {
if let Some(paths) = resolve_pg_tools_from_dir(major) {
paths
} else {
ensure_pg_tools()
.await
.map_err(|e| format!("PostgreSQL tools resolution failed: {e}"))?
}
} else {
ensure_pg_tools()
.await
.map_err(|e| format!("PostgreSQL tools resolution failed: {e}"))?
}
} else {
ensure_pg_tools()
.await
.map_err(|e| format!("PostgreSQL tools resolution failed: {e}"))?
};
if let Some(path) = dump_override {
let p = PathBuf::from(path);
if !p.is_file() {
return Err(format!(
"Configured pg_dump path does not exist or is not a file: {}",
p.display()
));
}
pg_tools.pg_dump = p;
}
if let Some(path) = restore_override {
let p = PathBuf::from(path);
if !p.is_file() {
return Err(format!(
"Configured pg_restore path does not exist or is not a file: {}",
p.display()
));
}
pg_tools.pg_restore = p;
}
Ok(pg_tools)
}
/// Detect the Postgres server's major version (e.g. 16) by connecting and
/// reading `server_version_num`.
///
/// Connects using the backup-tools URI policy first. For remote targets
/// without an explicit sslmode, one retry with `sslmode=require` is attempted
/// and, on success, remembered via the sslmode cache.
async fn postgres_server_major(pg_uri: &str) -> Result<u32, String> {
    let primary = pg_uri_for_backup_tools(pg_uri);
    let mut conn = match sqlx::PgConnection::connect(&primary).await {
        Ok(c) => c,
        Err(first_err) => {
            // Local targets never get an SSL retry.
            if postgres_uri_target_is_local_pg_tools(pg_uri) {
                return Err(format!(
                    "Failed to connect to Postgres to detect version: {first_err}"
                ));
            }
            if !sslmode_is_present(pg_uri) {
                let retry = with_sslmode_require(pg_uri);
                if let Ok(c) = sqlx::PgConnection::connect(&retry).await {
                    // Remember that this target needs TLS for later calls.
                    sslmode_cache_mark_require(pg_uri);
                    c
                } else {
                    // Surface the original error, not the retry's.
                    return Err(format!(
                        "Failed to connect to Postgres to detect version: {first_err}"
                    ));
                }
            } else {
                return Err(format!(
                    "Failed to connect to Postgres to detect version: {first_err}"
                ));
            }
        }
    };
    let version_num: String = sqlx::query_scalar("SELECT current_setting('server_version_num')")
        .fetch_one(&mut conn)
        .await
        .map_err(|e| format!("Failed to read Postgres server version: {e}"))?;
    let parsed: u32 = version_num
        .trim()
        .parse()
        .map_err(|_| format!("Invalid server_version_num returned by Postgres: '{version_num}'"))?;
    // server_version_num encodes MAJOR*10000 + MINOR (e.g. 160002 -> 16).
    Ok(parsed / 10_000)
}
/// Map the dump directory's current size onto a progress percentage inside
/// [PG_DUMP_PROGRESS_MIN_PCT, PG_DUMP_PROGRESS_MAX_PCT].
///
/// With a known previous-backup size, progress follows a concave curve that
/// reaches ~92% of the span at 1.2x that size, then approaches the full span
/// asymptotically. Without a reference, an exponential-saturation curve
/// scaled to 256 MiB is used instead.
fn estimate_pg_dump_progress_pct(
    current_dump_bytes: i64,
    latest_s3_backup_size_bytes: Option<i64>,
) -> i32 {
    let current = current_dump_bytes.max(0) as f64;
    let span = (PG_DUMP_PROGRESS_MAX_PCT - PG_DUMP_PROGRESS_MIN_PCT) as f64;
    let normalized = match latest_s3_backup_size_bytes.filter(|size| *size > 0) {
        Some(reference_size) => {
            // Expect the new dump to land near 1.2x the previous backup size.
            let expected_upper = (reference_size as f64 * 1.2).max(1.0);
            if current <= expected_upper {
                // Concave ramp: fast early progress, slowing near the target.
                0.92 * (current / expected_upper).powf(0.6)
            } else {
                // Beyond the expected size, asymptotically approach 1.0.
                let overflow = (current - expected_upper) / expected_upper;
                0.92 + (1.0 - (-2.0 * overflow).exp()) * 0.08
            }
        }
        None => {
            // No prior backup to compare against: saturate toward 0.9
            // against a fixed 256 MiB scale.
            let fallback_scale_bytes = 256.0 * 1024.0 * 1024.0;
            0.9 * (1.0 - (-current / fallback_scale_bytes).exp())
        }
    };
    let bounded = normalized.clamp(0.0, 1.0);
    (PG_DUMP_PROGRESS_MIN_PCT as f64 + bounded * span).round() as i32
}
/// Sum file sizes under `path` (a single file or a directory tree) without
/// following symlinks. Unreadable entries are skipped; a missing path is 0;
/// the result saturates at `i64::MAX`.
fn path_size_bytes_blocking(path: &FsPath) -> i64 {
    if !path.exists() {
        return 0;
    }
    let mut total: u128 = 0;
    let mut pending: Vec<PathBuf> = vec![path.to_path_buf()];
    while let Some(entry_path) = pending.pop() {
        // symlink_metadata so links are classified instead of followed.
        let Ok(metadata) = std::fs::symlink_metadata(&entry_path) else {
            continue;
        };
        let file_type = metadata.file_type();
        if file_type.is_symlink() {
            continue;
        }
        if file_type.is_file() {
            total = total.saturating_add(u128::from(metadata.len()));
        } else if file_type.is_dir() {
            if let Ok(entries) = std::fs::read_dir(&entry_path) {
                pending.extend(entries.flatten().map(|e| e.path()));
            }
        }
    }
    total.min(i64::MAX as u128) as i64
}
/// Async wrapper around `path_size_bytes_blocking`: runs the directory walk on
/// a blocking thread so it never stalls the async runtime; a failed join
/// degrades to 0.
async fn path_size_bytes(path: &FsPath) -> i64 {
    let owned = path.to_path_buf();
    tokio::task::spawn_blocking(move || path_size_bytes_blocking(&owned))
        .await
        .unwrap_or(0)
}
/// Periodically estimates pg_dump progress from the dump directory's on-disk
/// size and emits monotonically increasing progress updates via the job logger.
#[derive(Clone)]
struct PgDumpProgressTracker {
    job: JobLogger,
    // Size of the newest previous backup for this client, when known; used as
    // the reference for percentage estimation.
    latest_s3_size_bytes: Option<i64>,
    // Highest percentage reported so far. Shared (Arc) across clones so
    // progress never appears to move backwards.
    last_reported_pct: std::sync::Arc<Mutex<i32>>,
}
impl PgDumpProgressTracker {
    fn new(job: JobLogger, latest_s3_size_bytes: Option<i64>) -> Self {
        Self {
            job,
            latest_s3_size_bytes,
            last_reported_pct: std::sync::Arc::new(Mutex::new(PG_DUMP_PROGRESS_MIN_PCT)),
        }
    }
    /// Sample the dump directory size and emit a progress update only when
    /// the estimate advanced past the last reported percentage.
    async fn report(&self, dump_path: &FsPath) {
        let current_size = path_size_bytes(dump_path).await;
        let estimated = estimate_pg_dump_progress_pct(current_size, self.latest_s3_size_bytes);
        // Compare-and-bump under the lock, but emit outside it so the guard
        // is never held across an await point.
        let should_emit = {
            let mut g = self.last_reported_pct.lock().unwrap();
            if estimated > *g {
                *g = estimated;
                true
            } else {
                false
            }
        };
        if should_emit {
            self.job
                .progress(
                    None,
                    Some("pg_dump"),
                    Some(estimated),
                    None,
                    None,
                    Some(current_size),
                    None,
                    None,
                )
                .await;
        }
    }
}
async fn run_pg_dump(
pg_uri: &str,
pg_dump_override: Option<&str>,
pg_restore_override: Option<&str>,
recovery_strategy: BackupRecoveryStrategy,
cancel: Option<&CancellationToken>,
pg_dump_progress: Option<&PgDumpProgressTracker>,
) -> Result<PathBuf, String> {
let tmp_root: PathBuf = env::temp_dir().join(format!("athena_backup_{}", Uuid::new_v4()));
let dump_dir: PathBuf = tmp_root.join("dump");
let archive_path: PathBuf = tmp_root.join("backup.tar.gz");
tracing::info!("Ensuring PostgreSQL tools (pg_dump, etc.) are available.");
let effective_pg_uri = pg_uri_for_backup_tools(pg_uri);
let server_major = postgres_server_major(&effective_pg_uri).await.ok();
let pg_tools: PgToolsPaths =
resolve_pg_tools_with_overrides(server_major, pg_dump_override, pg_restore_override)
.await
.map_err(|e| format!("pg_dump resolution failed: {e}"))?;
tracing::info!("Creating dump directory at {:?}", &dump_dir);
tokio::fs::create_dir_all(&dump_dir)
.await
.map_err(|e| format!("Could not create temp directory: {e}"))?;
let (pg_uri_safe, pg_password) = extract_pg_password(&effective_pg_uri);
tracing::info!(
"Invoking pg_dump for dump_dir={:?}, sanitized_uri={:?}",
&dump_dir,
&pg_uri_safe
);
let mut cmd: Command = Command::new(&pg_tools.pg_dump);
let pg_password_for_first_dump = pg_password.clone();
if let Some(pass) = pg_password_for_first_dump {
cmd.env("PGPASSWORD", pass);
}
cmd.args(["--format=directory", "--no-owner", "--no-acl", "--file"])
.arg(&dump_dir)
.arg(&pg_uri_safe)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
let output = match command_output_cancellable(
&mut cmd,
cancel,
pg_dump_progress.map(|tracker| (tracker, dump_dir.as_path())),
)
.await
{
Ok(o) => o,
Err(e) => {
let _ = tokio::fs::remove_dir_all(&tmp_root).await;
return Err(e);
}
};
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let code = output.status.code().unwrap_or(-1);
tracing::info!("pg_dump failed with exit code {}", code);
if !stderr.is_empty() {
tracing::info!("pg_dump stderr: {}", stderr.trim());
}
if !stdout.is_empty() {
tracing::info!("pg_dump stdout: {}", stdout.trim());
}
if recovery_strategy == BackupRecoveryStrategy::RepairMissingRoleOids {
let missing_oids = parse_missing_role_oids(&stderr);
if !missing_oids.is_empty() {
tracing::warn!(
"pg_dump detected missing role OIDs {:?}; attempting autonomous recovery",
missing_oids
);
match recover_missing_role_oids(&effective_pg_uri, &missing_oids).await {
Ok(_) => {
tracing::info!(
"Autonomous recovery completed; retrying pg_dump after role OID repair"
);
let mut retry_after_recovery: Command = Command::new(&pg_tools.pg_dump);
if let Some(pass) = pg_password.as_ref() {
retry_after_recovery.env("PGPASSWORD", pass);
}
retry_after_recovery
.args(["--format=directory", "--no-owner", "--no-acl", "--file"])
.arg(&dump_dir)
.arg(&pg_uri_safe)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
let retry_output = match command_output_cancellable(
&mut retry_after_recovery,
cancel,
pg_dump_progress.map(|tracker| (tracker, dump_dir.as_path())),
)
.await
{
Ok(o) => o,
Err(e) => {
let _ = tokio::fs::remove_dir_all(&tmp_root).await;
return Err(e);
}
};
if retry_output.status.success() {
tracing::info!("pg_dump retry succeeded after autonomous recovery");
} else {
let retry_stderr =
String::from_utf8_lossy(&retry_output.stderr).to_string();
let retry_stdout =
String::from_utf8_lossy(&retry_output.stdout).to_string();
let retry_code = retry_output.status.code().unwrap_or(-1);
tracing::info!(
"pg_dump retry after recovery failed with exit code {}",
retry_code
);
let _ = tokio::fs::remove_dir_all(&tmp_root).await;
return Err(summarize_pg_dump_failure(
&retry_stderr,
&retry_stdout,
retry_code,
));
}
}
Err(recovery_err) => {
tracing::warn!(
"Autonomous recovery failed for missing role OID scenario: {}",
recovery_err
);
let _ = tokio::fs::remove_dir_all(&tmp_root).await;
return Err(format!(
"{} Recovery strategy error: {}",
summarize_pg_dump_failure(&stderr, &stdout, code),
recovery_err
));
}
}
tracing::info!("pg_dump recovered via recovery_strategy; continuing backup flow");
} else {
tracing::info!(
"recovery_strategy=repair_missing_role_oids enabled, but no missing role OID error detected"
);
}
}
if recovery_strategy == BackupRecoveryStrategy::RepairMissingRoleOids
&& !parse_missing_role_oids(&stderr).is_empty()
{
} else {
if !postgres_uri_target_is_local_pg_tools(pg_uri)
&& !sslmode_is_present(pg_uri)
&& !sslmode_is_present(&effective_pg_uri)
&& looks_like_ssl_required_error(&stderr)
{
tracing::warn!("pg_dump failed; retrying once with sslmode=require");
sslmode_cache_mark_require(pg_uri);
let retry_uri = with_sslmode_require(pg_uri);
let (retry_safe, retry_pw) = extract_pg_password(&retry_uri);
let mut retry_cmd: Command = Command::new(&pg_tools.pg_dump);
if let Some(pass) = retry_pw {
retry_cmd.env("PGPASSWORD", pass);
}
retry_cmd
.args(["--format=directory", "--no-owner", "--no-acl", "--file"])
.arg(&dump_dir)
.arg(&retry_safe)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
let retry_out = match command_output_cancellable(
&mut retry_cmd,
cancel,
pg_dump_progress.map(|tracker| (tracker, dump_dir.as_path())),
)
.await
{
Ok(o) => o,
Err(e) => {
let _ = tokio::fs::remove_dir_all(&tmp_root).await;
return Err(e);
}
};
if retry_out.status.success() {
tracing::info!("pg_dump retry succeeded with sslmode=require");
} else {
let retry_stderr: std::borrow::Cow<'_, str> =
String::from_utf8_lossy(&retry_out.stderr);
let retry_stdout: std::borrow::Cow<'_, str> =
String::from_utf8_lossy(&retry_out.stdout);
let retry_code = retry_out.status.code().unwrap_or(-1);
tracing::info!("pg_dump retry failed with exit code {}", retry_code);
if !retry_stderr.is_empty() {
tracing::info!("pg_dump retry stderr: {}", retry_stderr.trim());
}
if !retry_stdout.is_empty() {
tracing::info!("pg_dump retry stdout: {}", retry_stdout.trim());
}
let _ = tokio::fs::remove_dir_all(&tmp_root).await;
return Err(summarize_pg_dump_failure(
&retry_stderr,
&retry_stdout,
retry_code,
));
}
} else {
let _ = tokio::fs::remove_dir_all(&tmp_root).await;
let detail = if stderr.is_empty() && stdout.is_empty() {
format!(
"pg_dump exited with code {}. No output from pg_dump. Ensure pg_dump major version matches the Postgres server (e.g. pg_dump 17 for server 17). Check server logs for more.",
code
)
} else {
summarize_pg_dump_failure(&stderr, &stdout, code)
};
return Err(detail);
}
}
}
if cancel.is_some_and(|c| c.is_cancelled()) {
let _ = tokio::fs::remove_dir_all(&tmp_root).await;
return Err(ERR_BACKUP_CANCELLED.to_string());
}
tracing::info!("pg_dump completed successfully, archiving result.");
let dump_dir_clone: PathBuf = dump_dir.clone();
let archive_path_clone: PathBuf = archive_path.clone();
tokio::task::spawn_blocking(move || -> Result<(), String> {
let file = std::fs::File::create(&archive_path_clone)
.map_err(|e| format!("Cannot create archive: {e}"))?;
let enc = flate2::write::GzEncoder::new(file, flate2::Compression::default());
let mut builder = tar::Builder::new(enc);
builder
.append_dir_all("dump", &dump_dir_clone)
.map_err(|e| format!("Cannot archive dump directory: {e}"))?;
builder
.finish()
.map_err(|e| format!("Cannot finalize archive: {e}"))?;
Ok(())
})
.await
.map_err(|e| format!("Archive task panicked: {e}"))??;
tracing::info!(
"pg_dump directory archived, removing uncompressed dump dir {:?}",
dump_dir
);
let _ = tokio::fs::remove_dir_all(&dump_dir).await;
tracing::info!("Backup archive produced at {:?}", archive_path);
Ok(archive_path)
}
/// Handle for recording a backup/restore job's progress and log lines in
/// the logging database. Cheap to clone; carries only a pool and a job id.
#[derive(Clone)]
struct JobLogger {
// Connection pool to the athena_logging database that owns backup_jobs.
pool: PgPool,
// Primary key of the backup_jobs row this logger writes to.
job_id: i64,
}
impl JobLogger {
/// Partially update this job's backup_jobs row.
///
/// `None` arguments leave their columns unchanged (the underlying UPDATE
/// uses COALESCE — see update_job_progress). Failures are logged as
/// warnings and never surfaced, so progress reporting cannot fail the
/// job itself.
async fn progress(
&self,
status: Option<&str>,
stage: Option<&str>,
progress_pct: Option<i32>,
s3_bucket: Option<&str>,
s3_key: Option<&str>,
size_bytes: Option<i64>,
error_message: Option<&str>,
completed_at: Option<chrono::DateTime<Utc>>,
) {
if let Err(err) = update_job_progress(
&self.pool,
self.job_id,
status,
stage,
progress_pct,
s3_bucket,
s3_key,
size_bytes,
error_message,
completed_at,
)
.await
{
tracing::warn!(job_id = self.job_id, error = %err, "Failed to update backup job progress");
}
}
/// Append a log line for this job. Best-effort: insert errors are only
/// logged as warnings so the job keeps running.
async fn log(&self, level: &str, message: &str) {
if let Err(err) =
sqlx::query("INSERT INTO backup_job_logs (job_id, level, message) VALUES ($1, $2, $3)")
.bind(self.job_id)
.bind(level)
.bind(message)
.execute(&self.pool)
.await
{
tracing::warn!(job_id = self.job_id, error = %err, "Failed to insert backup job log");
}
}
}
/// Resolve the Postgres pool used for backup job bookkeeping.
///
/// Produces a ready-to-return 503 response when no `athena_logging`
/// client is configured or when the configured client's pool is not
/// currently connected.
async fn logging_pool(state: &AppState) -> Result<PgPool, HttpResponse> {
    let client_name = match state.logging_client_name.as_ref() {
        Some(name) => name,
        None => {
            return Err(service_unavailable(
                "Logging unavailable",
                "No athena_logging client is configured.",
            ));
        }
    };
    match state.pg_registry.get_pool(client_name) {
        Some(pool) => Ok(pool),
        None => Err(service_unavailable(
            "Logging unavailable",
            format!("Logging client '{}' is not connected.", client_name),
        )),
    }
}
/// Insert a new backup_jobs row in the 'running' state and return its id.
///
/// `label` is optional free-form text attached by the operator; both
/// started_at and updated_at are stamped with now() on insert.
async fn create_backup_job(
    pool: &PgPool,
    job_type: &str,
    client_name: &str,
    initial_stage: &str,
    label: Option<&str>,
) -> Result<i64, sqlx::Error> {
    let job_id: i64 = sqlx::query_scalar(
        r#"
INSERT INTO backup_jobs (job_type, client_name, status, progress_stage, label, started_at, updated_at)
VALUES ($1, $2, 'running', $3, $4, now(), now())
RETURNING id
"#,
    )
    .bind(job_type)
    .bind(client_name)
    .bind(initial_stage)
    .bind(label)
    .fetch_one(pool)
    .await?;
    Ok(job_id)
}
/// Apply a partial update to a backup_jobs row.
///
/// Every `Option` argument is routed through COALESCE, so `None` leaves
/// the corresponding column untouched while `Some` overwrites it.
/// `updated_at` is always bumped to now(). Errors are returned to the
/// caller (JobLogger::progress downgrades them to warnings).
async fn update_job_progress(
pool: &PgPool,
job_id: i64,
status: Option<&str>,
stage: Option<&str>,
progress_pct: Option<i32>,
s3_bucket: Option<&str>,
s3_key: Option<&str>,
size_bytes: Option<i64>,
error_message: Option<&str>,
completed_at: Option<chrono::DateTime<Utc>>,
) -> Result<(), sqlx::Error> {
sqlx::query(
r#"
UPDATE backup_jobs
SET status = COALESCE($2, status),
progress_stage = COALESCE($3, progress_stage),
progress_pct = COALESCE($4, progress_pct),
s3_bucket = COALESCE($5, s3_bucket),
s3_key = COALESCE($6, s3_key),
size_bytes = COALESCE($7, size_bytes),
error_message = COALESCE($8, error_message),
completed_at = COALESCE($9, completed_at),
updated_at = now()
WHERE id = $1
"#,
)
.bind(job_id)
.bind(status)
.bind(stage)
.bind(progress_pct)
.bind(s3_bucket)
.bind(s3_key)
.bind(size_bytes)
.bind(error_message)
.bind(completed_at)
.execute(pool)
.await?;
Ok(())
}
/// Restore a database from a backup archive stored in S3.
///
/// Pipeline: download `s3://bucket/key` (with retries) → stage in a temp
/// dir → extract the tar.gz → locate the pg_dump directory payload → run
/// `pg_restore --clean --if-exists` against `pg_uri`. Progress and log
/// rows are written through `job` when present; `cancel` aborts the
/// external command.
///
/// If the first pg_restore fails against a non-local target whose URI has
/// no explicit sslmode, a single retry with `sslmode=require` is made.
/// The temp directory is removed on every exit path.
///
/// BUGFIX: cleanup of `tmp_root` previously happened *before* the sslmode
/// retry, so the retry invoked pg_restore against an already-deleted
/// `dump_dir` (and the retry error path removed `tmp_root` twice). The
/// temp directory is now removed only once the final outcome is known.
async fn run_pg_restore(
    s3_client: &S3Client,
    bucket: &str,
    key: &str,
    pg_uri: &str,
    pg_dump_override: Option<&str>,
    pg_restore_override: Option<&str>,
    job: Option<JobLogger>,
    cancel: Option<&CancellationToken>,
) -> Result<(), String> {
    tracing::info!("Starting pg_restore from S3 bucket={} key={}", bucket, key);
    if let Some(logger) = &job {
        logger
            .progress(
                Some("running"),
                Some("downloading"),
                Some(10),
                Some(bucket),
                Some(key),
                None,
                None,
                None,
            )
            .await;
        logger
            .log(
                "info",
                "Starting restore: downloading backup from object storage",
            )
            .await;
    }
    let bytes: web::Bytes =
        download_s3_object_with_retry(s3_client, bucket, key, S3_DOWNLOAD_MAX_ATTEMPTS).await?;
    let size_bytes: i64 = bytes.len() as i64;
    if let Some(logger) = &job {
        logger
            .progress(
                None,
                Some("writing"),
                Some(25),
                Some(bucket),
                Some(key),
                Some(size_bytes),
                None,
                None,
            )
            .await;
    }
    // Stage everything under a unique temp dir so concurrent restores
    // cannot collide.
    let tmp_root: PathBuf = env::temp_dir().join(format!("athena_restore_{}", Uuid::new_v4()));
    tracing::info!("Creating temp directory for restore: {:?}", &tmp_root);
    tokio::fs::create_dir_all(&tmp_root)
        .await
        .map_err(|e| format!("Could not create temp dir: {e}"))?;
    let archive_path: PathBuf = tmp_root.join("backup.tar.gz");
    let restore_dir: PathBuf = tmp_root.join("dump");
    tracing::info!("Writing downloaded archive to {:?}", &archive_path);
    tokio::fs::write(&archive_path, &bytes)
        .await
        .map_err(|e| format!("Could not write archive: {e}"))?;
    if let Some(logger) = &job {
        logger
            .progress(
                None,
                Some("extracting"),
                Some(50),
                Some(bucket),
                Some(key),
                Some(size_bytes),
                None,
                None,
            )
            .await;
        logger
            .log("info", "Downloaded backup, extracting archive")
            .await;
    }
    tracing::info!("Extracting backup archive for restore.");
    // tar/flate2 are synchronous; run the extraction on the blocking pool.
    let archive_path_clone: PathBuf = archive_path.clone();
    let restore_dir_clone: PathBuf = restore_dir.clone();
    tokio::task::spawn_blocking(move || -> Result<(), String> {
        let file = std::fs::File::open(&archive_path_clone)
            .map_err(|e| format!("Cannot open archive: {e}"))?;
        let dec = flate2::read::GzDecoder::new(file);
        let mut archive: tar::Archive<flate2::read::GzDecoder<std::fs::File>> =
            tar::Archive::new(dec);
        archive
            .unpack(&restore_dir_clone)
            .map_err(|e| format!("Cannot extract archive: {e}"))?;
        Ok(())
    })
    .await
    .map_err(|e| format!("Extract task panicked: {e}"))??;
    let dump_dir = resolve_restore_dump_dir(&restore_dir).await?;
    tracing::info!("Restore will use backup contents at {:?}", dump_dir);
    let effective_pg_uri = pg_uri_for_backup_tools(pg_uri);
    let (pg_uri_safe, pg_password) = extract_pg_password(&effective_pg_uri);
    // Pick a pg_restore matching the server's major version when detectable.
    let server_major = postgres_server_major(&effective_pg_uri).await.ok();
    let pg_tools: PgToolsPaths =
        resolve_pg_tools_with_overrides(server_major, pg_dump_override, pg_restore_override)
            .await
            .map_err(|e| {
                tracing::info!("pg_restore tool resolution failed: {e}");
                format!("pg_restore resolution failed: {e}")
            })?;
    tracing::info!(
        "Launching pg_restore for database restore, using pg_restore at {:?}",
        &pg_tools.pg_restore
    );
    if let Some(logger) = &job {
        logger
            .progress(
                None,
                Some("pg_restore"),
                Some(80),
                Some(bucket),
                Some(key),
                Some(size_bytes),
                None,
                None,
            )
            .await;
        logger
            .log("info", "Running pg_restore against target database")
            .await;
    }
    let mut cmd: Command = Command::new(&pg_tools.pg_restore);
    // Pass the password via PGPASSWORD so it never appears in argv.
    if let Some(pass) = pg_password {
        cmd.env("PGPASSWORD", pass);
    }
    const PG_RESTORE_NOT_FOUND: &str =
        "pg_restore binary not found in PATH — ensure PostgreSQL client tools are installed";
    cmd.args(["--format=directory", "--clean", "--if-exists", "--dbname"])
        .arg(&pg_uri_safe)
        .arg(&dump_dir);
    let status = match command_status_cancellable(&mut cmd, cancel, PG_RESTORE_NOT_FOUND).await {
        Ok(s) => s,
        Err(e) => {
            let _ = tokio::fs::remove_dir_all(&tmp_root).await;
            return Err(e);
        }
    };
    if !status.success() {
        tracing::info!("pg_restore finished with error status: {:?}", status);
        if !postgres_uri_target_is_local_pg_tools(pg_uri)
            && !sslmode_is_present(pg_uri)
            && !sslmode_is_present(&effective_pg_uri)
        {
            tracing::warn!("pg_restore failed; retrying once with sslmode=require");
            sslmode_cache_mark_require(pg_uri);
            let retry_uri = with_sslmode_require(pg_uri);
            let (retry_safe, retry_pw) = extract_pg_password(&retry_uri);
            let mut retry_cmd: Command = Command::new(&pg_tools.pg_restore);
            if let Some(pass) = retry_pw {
                retry_cmd.env("PGPASSWORD", pass);
            }
            // dump_dir is still on disk here — cleanup is deferred until the
            // retry outcome is known (see BUGFIX note in the doc comment).
            retry_cmd
                .args(["--format=directory", "--clean", "--if-exists", "--dbname"])
                .arg(&retry_safe)
                .arg(&dump_dir);
            let retry_status =
                match command_status_cancellable(&mut retry_cmd, cancel, PG_RESTORE_NOT_FOUND)
                    .await
                {
                    Ok(s) => s,
                    Err(e) => {
                        let _ = tokio::fs::remove_dir_all(&tmp_root).await;
                        return Err(e);
                    }
                };
            if retry_status.success() {
                tracing::info!("pg_restore retry succeeded with sslmode=require");
                let _ = tokio::fs::remove_dir_all(&tmp_root).await;
                tracing::info!("Cleanup: removed temp restore directory {:?}", &tmp_root);
                return Ok(());
            }
        }
        let _ = tokio::fs::remove_dir_all(&tmp_root).await;
        tracing::info!("Cleanup: removed temp restore directory {:?}", &tmp_root);
        return Err(format!("pg_restore exited with status {status}"));
    }
    let _ = tokio::fs::remove_dir_all(&tmp_root).await;
    tracing::info!("Cleanup: removed temp restore directory {:?}", &tmp_root);
    tracing::info!("pg_restore completed successfully!");
    Ok(())
}
/// Locate the pg_dump directory-format payload inside an extracted archive.
///
/// Accepts `toc.dat` directly in the extraction root, under a `dump/`
/// subdirectory, or inside any direct child directory; errors when none
/// of those locations contains one.
async fn resolve_restore_dump_dir(restore_dir: &FsPath) -> Result<PathBuf, String> {
    // Fast paths: toc.dat at the root, or under the conventional dump/.
    if tokio::fs::metadata(restore_dir.join("toc.dat")).await.is_ok() {
        return Ok(restore_dir.to_path_buf());
    }
    let nested = restore_dir.join("dump");
    if tokio::fs::metadata(nested.join("toc.dat")).await.is_ok() {
        return Ok(nested);
    }
    // Fall back to scanning every direct child directory for a toc.dat.
    let mut entries = tokio::fs::read_dir(restore_dir)
        .await
        .map_err(|e| format!("Cannot inspect extracted restore directory: {e}"))?;
    loop {
        let entry = match entries.next_entry().await {
            Ok(Some(entry)) => entry,
            Ok(None) => break,
            Err(e) => {
                return Err(format!(
                    "Cannot inspect extracted restore directory entries: {e}"
                ));
            }
        };
        let kind = entry
            .file_type()
            .await
            .map_err(|e| format!("Cannot inspect extracted restore directory entry type: {e}"))?;
        if kind.is_dir() {
            let candidate = entry.path();
            if tokio::fs::metadata(candidate.join("toc.dat")).await.is_ok() {
                return Ok(candidate);
            }
        }
    }
    Err(format!(
        "Could not locate pg_dump directory format in {:?}; expected toc.dat in root, dump/, or a direct child directory",
        restore_dir
    ))
}
async fn upload_to_s3(
s3_client: &S3Client,
cfg: &S3Config,
local_path: &PathBuf,
client_name: &str,
label: Option<&str>,
) -> Result<(String, i64), String> {
let backup_id: String = Uuid::new_v4().to_string();
let key: String = format!("{}/{}/{}.tar.gz", cfg.prefix, client_name, backup_id);
tracing::info!("Reading local backup archive from {:?}", local_path);
let data: Vec<u8> = tokio::fs::read(local_path)
.await
.map_err(|e| format!("Cannot read archive file: {e}"))?;
let size_bytes: i64 = data.len() as i64;
tracing::info!(
"Putting object to S3: bucket='{}', key='{}', client_name='{}', label={:?}",
&cfg.bucket,
&key,
client_name,
label
);
let mut req: aws_sdk_s3::operation::put_object::builders::PutObjectFluentBuilder = s3_client
.put_object()
.bucket(&cfg.bucket)
.key(&key)
.body(data.into())
.content_type("application/gzip")
.metadata("client_name", client_name)
.metadata("backup_id", &backup_id)
.metadata("created_at", Utc::now().to_rfc3339());
if let Some(lbl) = label {
req = req.metadata("label", lbl);
}
req.send().await.map_err(|e| {
tracing::info!("S3 upload to key '{}' failed: {e}", key, e = &e);
format!("S3 upload failed: {e}")
})?;
tracing::info!("S3 backup upload complete for key {}", key);
Ok((key, size_bytes))
}
/// POST /admin/backups — run a full backup (pg_dump → tar.gz → S3 upload)
/// for a target database, tracking progress in a backup_jobs row.
///
/// Target selection: an explicit `pg_uri` in the body wins (client_name
/// defaults to "custom"); otherwise `client_name` is resolved through the
/// registry. Optional body fields: label, timeout_seconds (clamped),
/// recovery_strategy. pg tool paths may be overridden via the
/// x-athena-pg-dump-path / x-athena-pg-restore-path headers.
/// Responds with the job id and S3 key on success.
#[post("/admin/backups")]
pub async fn admin_create_backup(
req: HttpRequest,
state: Data<AppState>,
body: Json<CreateBackupRequest>,
) -> HttpResponse {
tracing::info!(
"{}admin_create_backup{} called. client_name: {:?}, pg_uri_provided: {}, label: {:?}, recovery_strategy: {:?}",
B,
Z,
body.client_name,
body.pg_uri.is_some(),
body.label,
body.recovery_strategy
);
if let Err(resp) = authorize_static_admin_key(&req) {
tracing::info!("{}Authorization failed{} for admin_create_backup", R, Z);
return resp;
}
// Job bookkeeping lives in the logging database; without it we cannot
// record progress, so the request is rejected up front.
let logging_pool = match logging_pool(&state).await {
Ok(pool) => pool,
Err(resp) => return resp,
};
// Backups require S3 storage configuration from the environment.
let Some(s3_cfg) = S3Config::from_env() else {
tracing::info!("S3Config could not be constructed from env for admin_create_backup");
return service_unavailable(
"S3 not configured",
"Set ATHENA_BACKUP_S3_BUCKET and related environment variables to enable backups.",
);
};
// Resolve the dump target: explicit pg_uri first, then registered client.
let (pg_uri, effective_client_name) = if let Some(uri) = &body.pg_uri {
(
uri.clone(),
body.client_name
.clone()
.unwrap_or_else(|| "custom".to_string()),
)
} else if let Some(cn) = &body.client_name {
match resolve_pg_uri(&state, cn) {
Ok(uri) => (uri, cn.clone()),
Err(resp) => {
tracing::info!("Could not resolve pg_uri for client_name='{}'", cn);
return resp;
}
}
} else {
return bad_request("Missing target", "Provide either client_name or pg_uri.");
};
// Normalize optional knobs from the body and headers.
let timeout_secs = clamp_timeout(body.timeout_seconds.unwrap_or(3600));
let recovery_strategy = body
.recovery_strategy
.unwrap_or(BackupRecoveryStrategy::None);
let pg_dump_override = header_path_hint(&req, HEADER_PG_DUMP_PATH);
let pg_restore_override = header_path_hint(&req, HEADER_PG_RESTORE_PATH);
// Create the backup_jobs row first so every later step can report
// progress against it.
let job_id = match create_backup_job(
&logging_pool,
"backup",
&effective_client_name,
"pg_dump",
body.label.as_deref(),
)
.await
{
Ok(id) => id,
Err(err) => {
tracing::warn!("{}Failed to create backup_job row{}: {}", Y, Z, err);
return internal_error("Logging unavailable", "Could not create backup job record");
}
};
let job_logger = JobLogger {
pool: logging_pool.clone(),
job_id,
};
let s3_client = build_s3_client(&s3_cfg).await;
// Use the previous backup's size (if any) as the baseline for the
// pg_dump progress estimator.
let latest_backup_size =
latest_client_backup_size_bytes(&s3_client, &s3_cfg, &effective_client_name).await;
if let Some(size) = latest_backup_size {
job_logger
.log(
"info",
&format!("Using latest S3 backup size ({size} bytes) as pg_dump progress baseline"),
)
.await;
} else {
job_logger
.log(
"info",
"No prior S3 backup size baseline found; using fallback pg_dump estimator",
)
.await;
}
let pg_dump_progress = PgDumpProgressTracker::new(job_logger.clone(), latest_backup_size);
// Register a cancellation token keyed by job id so the cancel endpoint
// can abort this run; the guard unregisters it on drop.
let cancel = CancellationToken::new();
register_backup_cancel_token(job_id, cancel.clone());
let _cancel_guard = BackupJobCancelGuard(job_id);
job_logger
.log("info", "Starting backup job and running pg_dump")
.await;
job_logger
.progress(
None,
Some("pg_dump"),
Some(PG_DUMP_PROGRESS_MIN_PCT),
None,
None,
None,
None,
None,
)
.await;
tracing::info!(
"Running pg_dump for client_name='{}'",
effective_client_name
);
// Run pg_dump under the caller-supplied deadline; Err(_) from timeout()
// means the deadline elapsed, the inner Result is pg_dump's own outcome.
let archive_path = match tokio::time::timeout(
std::time::Duration::from_secs(timeout_secs as u64),
run_pg_dump(
&pg_uri,
pg_dump_override.as_deref(),
pg_restore_override.as_deref(),
recovery_strategy,
Some(&cancel),
Some(&pg_dump_progress),
),
)
.await
{
Err(_) => {
let msg = format!("pg_dump timed out after {}s", timeout_secs);
tracing::info!("{}", msg);
job_logger
.progress(
Some("failed"),
Some("pg_dump"),
Some(0),
None,
None,
None,
Some(&msg),
Some(Utc::now()),
)
.await;
job_logger.log("error", &msg).await;
return internal_error("pg_dump timed out", msg);
}
Ok(result) => match result {
Ok(p) => {
tracing::info!("pg_dump completed. Archive at {:?}", p);
job_logger
.progress(
None,
Some("archiving"),
Some(BACKUP_PROGRESS_ARCHIVING_PCT),
None,
None,
None,
None,
None,
)
.await;
job_logger
.log("info", "pg_dump completed, archiving dump")
.await;
p
}
// run_pg_dump signals operator cancellation with a sentinel error
// string; report "cancelled" rather than "failed".
Err(err) if err == ERR_BACKUP_CANCELLED => {
job_logger
.progress(
Some("cancelled"),
Some("pg_dump"),
Some(0),
None,
None,
None,
Some("Cancelled by operator"),
Some(Utc::now()),
)
.await;
job_logger.log("info", "Backup cancelled").await;
return api_success(
"Backup cancelled",
json!({ "job_id": job_id, "status": "cancelled" }),
);
}
Err(err) => {
tracing::info!("pg_dump failed: {}", err);
job_logger
.progress(
Some("failed"),
Some("pg_dump"),
Some(0),
None,
None,
None,
Some(&err),
Some(Utc::now()),
)
.await;
job_logger
.log("error", &format!("pg_dump failed: {}", err))
.await;
return internal_error("pg_dump failed", err);
}
},
};
// A cancel may have landed after pg_dump finished but before upload;
// honor it and remove the produced archive.
if is_backup_job_cancelled(&logging_pool, job_id).await {
if let Some(parent) = archive_path.parent() {
let _ = tokio::fs::remove_dir_all(parent).await;
}
job_logger
.progress(
Some("cancelled"),
Some("cancelled"),
Some(0),
None,
None,
None,
Some("Cancelled before upload"),
Some(Utc::now()),
)
.await;
job_logger
.log("info", "Backup cancelled before upload")
.await;
return api_success(
"Backup cancelled",
json!({ "job_id": job_id, "status": "cancelled" }),
);
}
tracing::info!("Uploading archive to S3...");
job_logger
.progress(
None,
Some("uploading"),
Some(BACKUP_PROGRESS_UPLOADING_PCT),
Some(&s3_cfg.bucket),
None,
None,
None,
None,
)
.await;
let (key, size_bytes) = match upload_to_s3(
&s3_client,
&s3_cfg,
&archive_path,
&effective_client_name,
body.label.as_deref(),
)
.await
{
Ok((key, size_bytes)) => {
tracing::info!("S3 upload succeeded. backup key: {}", key);
job_logger
.progress(
None,
Some("uploading"),
Some(BACKUP_PROGRESS_UPLOAD_STORED_PCT),
Some(&s3_cfg.bucket),
Some(&key),
Some(size_bytes),
None,
None,
)
.await;
(key, size_bytes)
}
Err(err) => {
tracing::info!("S3 upload failed: {}", err);
let _ = tokio::fs::remove_file(&archive_path).await;
job_logger
.progress(
Some("failed"),
Some("uploading"),
Some(BACKUP_PROGRESS_UPLOADING_PCT),
Some(&s3_cfg.bucket),
None,
None,
Some(&err),
Some(Utc::now()),
)
.await;
job_logger
.log("error", &format!("S3 upload failed: {}", err))
.await;
return internal_error("S3 upload failed", err);
}
};
// Remove the temp directory that held the archive now that it is stored.
if let Some(parent) = archive_path.parent() {
tracing::info!("Cleaning up archive directory: {:?}", parent);
let _ = tokio::fs::remove_dir_all(parent).await;
}
tracing::info!("admin_create_backup successful for key {}", key);
job_logger
.progress(
Some("completed"),
Some("completed"),
Some(100),
Some(&s3_cfg.bucket),
Some(&key),
Some(size_bytes),
None,
Some(Utc::now()),
)
.await;
job_logger
.log("info", "Backup completed and stored in object storage")
.await;
api_success(
"Backup created",
json!({
"job_id": job_id,
"key": key,
"client_name": effective_client_name,
"label": body.label,
}),
)
}
/// GET /admin/backups/config — report the active S3 backup storage
/// settings (bucket, region, prefix, endpoint) resolved from environment.
#[get("/admin/backups/config")]
pub async fn admin_backup_config(req: HttpRequest) -> HttpResponse {
    tracing::info!("admin_backup_config called");
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    // Without a bucket configured there is nothing to report.
    let cfg = match S3Config::from_env() {
        Some(cfg) => cfg,
        None => {
            return service_unavailable(
                "S3 not configured",
                "Set ATHENA_BACKUP_S3_BUCKET and related environment variables to enable backups.",
            );
        }
    };
    api_success(
        "Backup storage configuration",
        json!({
            "bucket": cfg.bucket,
            "region": cfg.region,
            "prefix": cfg.prefix,
            "endpoint": cfg.endpoint,
        }),
    )
}
/// GET /admin/backups — list backup archives stored in S3, optionally
/// filtered to one client via the `client_name` query parameter.
///
/// Keys are expected to look like `<prefix>/<client_name>/<id>.tar.gz`;
/// anything else is reported with client_name "unknown". Results are
/// sorted newest-first by last-modified timestamp.
///
/// FIX: ListObjectsV2 returns at most 1000 keys per call; the previous
/// single-call implementation silently truncated larger listings. The
/// continuation token is now followed until the listing is exhausted.
#[get("/admin/backups")]
pub async fn admin_list_backups(
    req: HttpRequest,
    _state: Data<AppState>,
    query: web::Query<std::collections::HashMap<String, String>>,
) -> HttpResponse {
    tracing::info!(
        "{}admin_list_backups{} called with query: {:?}",
        B,
        Z,
        query
    );
    if let Err(resp) = authorize_static_admin_key(&req) {
        tracing::info!("Authorization failed for admin_list_backups");
        return resp;
    }
    let Some(s3_cfg) = S3Config::from_env() else {
        tracing::info!("S3Config could not be constructed from env for admin_list_backups");
        return service_unavailable(
            "S3 not configured",
            "Set ATHENA_BACKUP_S3_BUCKET and related environment variables to enable backups.",
        );
    };
    let s3_client = build_s3_client(&s3_cfg).await;
    let filter_client = query.get("client_name").cloned();
    let prefix = match &filter_client {
        Some(cn) => format!("{}/{}/", s3_cfg.prefix, cn),
        None => format!("{}/", s3_cfg.prefix),
    };
    tracing::info!("Listing S3 backups with prefix: {}", prefix);
    let mut backups: Vec<BackupObject> = Vec::new();
    let mut continuation_token: Option<String> = None;
    loop {
        let mut list_req = s3_client
            .list_objects_v2()
            .bucket(&s3_cfg.bucket)
            .prefix(&prefix);
        if let Some(token) = continuation_token.take() {
            list_req = list_req.continuation_token(token);
        }
        let page = match list_req.send().await {
            Ok(r) => r,
            Err(err) => {
                tracing::info!("Failed to list S3 objects: {}", err);
                return internal_error("Failed to list S3 objects", err.to_string());
            }
        };
        for obj in page.contents() {
            let key = obj.key().unwrap_or_default().to_string();
            // Parse <prefix>/<client_name>/<id>.tar.gz out of the key.
            let parts: Vec<&str> = key.split('/').collect();
            let (client_name, id) = if parts.len() >= 3 {
                let cn = parts[parts.len() - 2].to_string();
                let id = parts
                    .last()
                    .and_then(|s| s.strip_suffix(".tar.gz"))
                    .unwrap_or_else(|| parts.last().copied().unwrap_or(""))
                    .to_string();
                (cn, id)
            } else {
                tracing::warn!(key = %key, "S3 backup key does not match expected format <prefix>/<client_name>/<id>.tar.gz");
                (
                    "unknown".to_string(),
                    parts
                        .last()
                        .and_then(|s| s.strip_suffix(".tar.gz"))
                        .unwrap_or_else(|| parts.last().copied().unwrap_or(&key))
                        .to_string(),
                )
            };
            // NOTE(review): one HEAD request per object to recover the
            // label metadata; fine for modest listings, a candidate for
            // caching if backup counts grow large.
            tracing::info!("Fetching S3 object metadata for label for key {}", key);
            let label = match s3_client
                .head_object()
                .bucket(&s3_cfg.bucket)
                .key(&key)
                .send()
                .await
            {
                Ok(head) => head.metadata().and_then(|m| m.get("label")).cloned(),
                Err(_) => None,
            };
            let size_bytes = obj.size().unwrap_or(0);
            let created_at = obj
                .last_modified()
                .map(|t| t.to_string())
                .unwrap_or_default();
            backups.push(BackupObject {
                id,
                key,
                client_name,
                label,
                size_bytes,
                created_at,
            });
        }
        // Keep paging while S3 reports more results.
        match page.next_continuation_token() {
            Some(token) => continuation_token = Some(token.to_string()),
            None => break,
        }
    }
    backups.sort_by(|a, b| b.created_at.cmp(&a.created_at));
    tracing::info!("Returning {} backup(s) from S3 list.", backups.len());
    api_success("Listed backups", json!({ "backups": backups }))
}
/// GET /admin/backups/jobs — list recent backup/restore jobs, newest
/// first, with optional client_name/status/job_type filters and a bounded
/// limit (default 50, maximum 500).
#[get("/admin/backups/jobs")]
pub async fn admin_list_backup_jobs(
    req: HttpRequest,
    state: Data<AppState>,
    query: web::Query<std::collections::HashMap<String, String>>,
) -> HttpResponse {
    tracing::info!("admin_list_backup_jobs called with query: {:?}", query);
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    let pool = match logging_pool(&state).await {
        Ok(pool) => pool,
        Err(resp) => return resp,
    };
    // Clamp the page size to a sane window; invalid input falls back to 50.
    let limit: i64 = query
        .get("limit")
        .and_then(|v| v.parse::<i64>().ok())
        .filter(|v| *v > 0 && *v <= 500)
        .unwrap_or(50);
    // Build the filtered query dynamically; every filter is bound, never
    // interpolated, so user input cannot inject SQL.
    let mut builder = QueryBuilder::new(
        r#"
SELECT id, job_type, client_name, status, progress_pct, progress_stage, s3_bucket, s3_key, label, size_bytes, error_message, started_at, updated_at, completed_at
FROM backup_jobs
WHERE 1=1
"#,
    );
    if let Some(client) = query.get("client_name") {
        builder.push(" AND client_name = ").push_bind(client.clone());
    }
    if let Some(status) = query.get("status") {
        builder.push(" AND status = ").push_bind(status.clone());
    }
    if let Some(job_type) = query.get("job_type") {
        builder.push(" AND job_type = ").push_bind(job_type.clone());
    }
    builder
        .push(" ORDER BY started_at DESC LIMIT ")
        .push_bind(limit);
    match builder
        .build_query_as::<BackupJobSummary>()
        .fetch_all(&pool)
        .await
    {
        Ok(jobs) => api_success("Listed backup jobs", json!({ "jobs": jobs })),
        Err(err) => {
            tracing::warn!("Failed to list backup jobs: {}", err);
            internal_error("Failed to list backup jobs", err.to_string())
        }
    }
}
#[get("/admin/backups/jobs/{id}")]
pub async fn admin_get_backup_job(
req: HttpRequest,
state: Data<AppState>,
job_id: Path<i64>,
) -> HttpResponse {
tracing::info!("admin_get_backup_job called: job_id={}", job_id);
if let Err(resp) = authorize_static_admin_key(&req) {
return resp;
}
let pool = match logging_pool(&state).await {
Ok(pool) => pool,
Err(resp) => return resp,
};
let job = match sqlx::query_as::<_, BackupJobSummary>(
r#"
SELECT id, job_type, client_name, status, progress_pct, progress_stage, s3_bucket, s3_key, label, size_bytes, error_message, started_at, updated_at, completed_at
FROM backup_jobs
WHERE id = $1
"#,
)
.bind(*job_id)
.fetch_optional(&pool)
.await
{
Ok(Some(row)) => row,
Ok(None) => {
return HttpResponse::NotFound().json(json!({
"status": "error",
"message": format!("Backup job {} not found", *job_id)
}));
}
Err(err) => {
tracing::warn!("Failed to fetch backup job {}: {}", *job_id, err);
return internal_error("Failed to fetch backup job", err.to_string());
}
};
let logs: Vec<BackupJobLog> = match sqlx::query_as::<_, BackupJobLog>(
r#"
SELECT id, job_id, level, message, created_at
FROM backup_job_logs
WHERE job_id = $1
ORDER BY created_at DESC
LIMIT 100
"#,
)
.bind(*job_id)
.fetch_all(&pool)
.await
{
Ok(rows) => rows,
Err(err) => {
tracing::warn!("Failed to fetch backup job logs for {}: {}", *job_id, err);
return internal_error("Failed to fetch backup job logs", err.to_string());
}
};
api_success("Backup job", json!({ "job": job, "logs": logs }))
}
/// POST /admin/backups/jobs/{id}/cancel — request cancellation of a
/// running or pending backup/restore job.
///
/// Marks the row cancelled in the database first (a guarded UPDATE so
/// only active jobs flip), then fires the in-process cancellation token
/// so the worker aborts its external pg_dump/pg_restore command.
#[post("/admin/backups/jobs/{id}/cancel")]
pub async fn admin_cancel_backup_job(
req: HttpRequest,
state: Data<AppState>,
job_id: Path<i64>,
) -> HttpResponse {
tracing::info!("admin_cancel_backup_job called: job_id={}", job_id);
if let Err(resp) = authorize_static_admin_key(&req) {
return resp;
}
let pool = match logging_pool(&state).await {
Ok(pool) => pool,
Err(resp) => return resp,
};
// NOTE(review): the ::text cast suggests the status column may be a
// Postgres enum rather than plain TEXT — confirm against the schema.
let status: Option<String> =
match sqlx::query_scalar("SELECT status::text FROM backup_jobs WHERE id = $1")
.bind(*job_id)
.fetch_optional(&pool)
.await
{
Ok(s) => s,
Err(err) => {
tracing::warn!("Failed to fetch backup job {}: {}", *job_id, err);
return internal_error("Failed to cancel backup job", err.to_string());
}
};
// No row means no such job id.
let Some(cur) = status else {
return not_found(
"Backup job not found",
format!("Backup job {} not found", *job_id),
);
};
if cur != "running" && cur != "pending" {
return bad_request(
"Job not active",
"Only running or pending jobs can be cancelled.",
);
}
// Guarded update: the WHERE status IN (...) clause makes this a no-op if
// the job finished between the read above and this write.
let updated = match sqlx::query(
r#"
UPDATE backup_jobs
SET status = 'cancelled',
progress_stage = 'cancelled',
progress_pct = 0,
error_message = $2,
completed_at = now(),
updated_at = now()
WHERE id = $1 AND status IN ('running', 'pending')
"#,
)
.bind(*job_id)
.bind("Cancelled by operator")
.execute(&pool)
.await
{
Ok(r) => r.rows_affected(),
Err(err) => {
tracing::warn!("Failed to cancel backup job {}: {}", *job_id, err);
return internal_error("Failed to cancel backup job", err.to_string());
}
};
if updated == 0 {
return bad_request(
"Job not active",
"Job was not running or could not be cancelled.",
);
}
// Only after the database reflects the cancel do we abort the worker.
trigger_backup_cancel_token(*job_id);
// Best-effort audit log; a logging failure must not fail the cancel.
if let Err(err) =
sqlx::query("INSERT INTO backup_job_logs (job_id, level, message) VALUES ($1, $2, $3)")
.bind(*job_id)
.bind("info")
.bind("Cancellation requested")
.execute(&pool)
.await
{
tracing::warn!("Failed to log cancel for job {}: {}", *job_id, err);
}
api_success(
"Backup job cancelled",
json!({ "job_id": *job_id, "status": "cancelled" }),
)
}
/// DELETE /admin/backups/jobs/{id} — remove a finished backup/restore job
/// record.
///
/// Active (running/pending) jobs must be cancelled first. Any
/// backup_schedules row still pointing at this job has its last_job_id
/// cleared before the delete so the reference cannot dangle.
///
/// Consistency fix: the status read now casts with `::text`, matching
/// admin_cancel_backup_job; if the status column is a Postgres enum,
/// decoding it into a String requires the cast.
#[delete("/admin/backups/jobs/{id}")]
pub async fn admin_delete_backup_job(
    req: HttpRequest,
    state: Data<AppState>,
    job_id: Path<i64>,
) -> HttpResponse {
    tracing::info!("admin_delete_backup_job called: job_id={}", job_id);
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    let pool = match logging_pool(&state).await {
        Ok(pool) => pool,
        Err(resp) => return resp,
    };
    let status: Option<String> = match sqlx::query_scalar(
        r#"
SELECT status::text
FROM backup_jobs
WHERE id = $1
"#,
    )
    .bind(*job_id)
    .fetch_optional(&pool)
    .await
    {
        Ok(row) => row,
        Err(err) => {
            tracing::warn!("Failed to fetch backup job {}: {}", *job_id, err);
            return internal_error("Failed to delete backup job", err.to_string());
        }
    };
    let Some(status) = status else {
        return not_found(
            "Backup job not found",
            format!("Backup job {} not found", *job_id),
        );
    };
    if matches!(status.as_str(), "running" | "pending") {
        return bad_request(
            "Backup job still running",
            "Only completed, failed, or cancelled jobs can be deleted.",
        );
    }
    // Detach any schedule that references this job as its most recent run.
    if let Err(err) = sqlx::query(
        r#"
UPDATE backup_schedules
SET last_job_id = NULL
WHERE last_job_id = $1
"#,
    )
    .bind(*job_id)
    .execute(&pool)
    .await
    {
        tracing::warn!(
            "Failed to clear schedule last_job_id for {}: {}",
            *job_id,
            err
        );
        return internal_error("Failed to delete backup job", err.to_string());
    }
    match sqlx::query("DELETE FROM backup_jobs WHERE id = $1")
        .bind(*job_id)
        .execute(&pool)
        .await
    {
        Ok(_) => {}
        Err(err) => {
            tracing::warn!("Failed to delete backup job {}: {}", *job_id, err);
            return internal_error("Failed to delete backup job", err.to_string());
        }
    }
    api_success("Backup job deleted", json!({ "job_id": *job_id }))
}
/// POST `/admin/backups/{key:.*}/restore` — restore a Postgres database from
/// a backup archive stored in S3.
///
/// Target selection: the body must provide either `pg_uri` directly (client
/// name defaults to "custom" unless `client_name` is also given) or
/// `client_name`, which is resolved to a URI via `resolve_pg_uri`. A
/// `backup_jobs` row is created up front so the restore is observable and
/// cancellable, and the whole download + pg_restore pipeline is bounded by
/// `timeout_seconds` (default 3600, clamped by `clamp_timeout`).
#[post("/admin/backups/{key:.*}/restore")]
pub async fn admin_restore_backup(
    req: HttpRequest,
    state: Data<AppState>,
    key_param: Path<String>,
    body: Json<RestoreBackupRequest>,
) -> HttpResponse {
    tracing::info!(
        "admin_restore_backup called, client_name={:?}, pg_uri_provided={}, key_param={}",
        body.client_name,
        body.pg_uri.is_some(),
        key_param
    );
    // Static admin key gates every backup endpoint.
    if let Err(resp) = authorize_static_admin_key(&req) {
        tracing::info!("Authorization failed for admin_restore_backup");
        return resp;
    }
    // Pool used for job bookkeeping (backup_jobs row + log/progress lines).
    let logging_pool = match logging_pool(&state).await {
        Ok(pool) => pool,
        Err(resp) => return resp,
    };
    let Some(s3_cfg) = S3Config::from_env() else {
        tracing::info!("S3Config could not be constructed from env for admin_restore_backup");
        return service_unavailable(
            "S3 not configured",
            "Set ATHENA_BACKUP_S3_BUCKET and related environment variables to enable backups.",
        );
    };
    // Resolve the restore target: an explicit pg_uri wins; otherwise map
    // client_name to a URI. Neither present -> 400.
    let (pg_uri, effective_client_name) = if let Some(uri) = &body.pg_uri {
        (
            uri.clone(),
            body.client_name
                .clone()
                .unwrap_or_else(|| "custom".to_string()),
        )
    } else if let Some(cn) = &body.client_name {
        match resolve_pg_uri(&state, cn) {
            Ok(uri) => (uri, cn.clone()),
            Err(resp) => {
                tracing::info!("Could not resolve pg_uri for client_name='{}'", cn);
                return resp;
            }
        }
    } else {
        return bad_request("Missing target", "Provide either client_name or pg_uri.");
    };
    let timeout_secs: i32 = clamp_timeout(body.timeout_seconds.unwrap_or(3600));
    // Optional request headers may point at non-default pg_dump/pg_restore
    // binaries on this host.
    let pg_dump_override: Option<String> = header_path_hint(&req, HEADER_PG_DUMP_PATH);
    let pg_restore_override: Option<String> = header_path_hint(&req, HEADER_PG_RESTORE_PATH);
    let key: String = key_param.into_inner();
    if key.is_empty() {
        tracing::info!("No backup key provided to admin_restore_backup");
        return bad_request(
            "Missing backup key",
            "Provide the S3 object key as the path segment.",
        );
    }
    let s3_client: S3Client = build_s3_client(&s3_cfg).await;
    // Create the job row first (stage "downloading") so even early failures
    // are visible in the job listing.
    let job_id: i64 = match create_backup_job(
        &logging_pool,
        "restore",
        &effective_client_name,
        "downloading",
        None,
    )
    .await
    {
        Ok(id) => id,
        Err(err) => {
            tracing::warn!("Failed to create restore job row: {}", err);
            return internal_error("Logging unavailable", "Could not create restore job record");
        }
    };
    let job_logger = JobLogger {
        pool: logging_pool.clone(),
        job_id,
    };
    // Register a cancellation token so the cancel endpoint can abort this
    // job; the guard presumably unregisters the token on drop (defined
    // elsewhere in this file) — TODO confirm.
    let cancel: CancellationToken = CancellationToken::new();
    register_backup_cancel_token(job_id, cancel.clone());
    let _cancel_guard: BackupJobCancelGuard = BackupJobCancelGuard(job_id);
    job_logger
        .log("info", "Starting restore job, downloading backup")
        .await;
    tracing::info!(
        "Calling run_pg_restore for key='{}', client_name='{}'",
        key,
        effective_client_name
    );
    // Bound the whole download + restore pipeline with one timeout.
    let restore_result = tokio::time::timeout(
        std::time::Duration::from_secs(timeout_secs as u64),
        run_pg_restore(
            &s3_client,
            &s3_cfg.bucket,
            &key,
            &pg_uri,
            pg_dump_override.as_deref(),
            pg_restore_override.as_deref(),
            Some(job_logger.clone()),
            Some(&cancel),
        ),
    )
    .await;
    match restore_result {
        // Outer Err: the tokio timeout elapsed before run_pg_restore finished.
        Err(_) => {
            let msg = format!("pg_restore timed out after {}s", timeout_secs);
            tracing::info!("{}", msg);
            job_logger
                .progress(
                    Some("failed"),
                    Some("pg_restore"),
                    Some(90),
                    Some(&s3_cfg.bucket),
                    Some(&key),
                    None,
                    Some(&msg),
                    Some(Utc::now()),
                )
                .await;
            job_logger.log("error", &msg).await;
            internal_error("pg_restore timed out", msg)
        }
        // Finished within the timeout and succeeded.
        Ok(Ok(())) => {
            tracing::info!("Restore succeeded for key='{}'", key);
            job_logger
                .progress(
                    Some("completed"),
                    Some("completed"),
                    Some(100),
                    Some(&s3_cfg.bucket),
                    Some(&key),
                    None,
                    None,
                    Some(Utc::now()),
                )
                .await;
            job_logger
                .log("info", "Restore completed successfully")
                .await;
            api_success(
                "Restore completed",
                json!({ "key": key, "client_name": effective_client_name, "job_id": job_id }),
            )
        }
        // Sentinel error string means the operator cancelled the job; this is
        // reported as success with status "cancelled", not as a failure.
        Ok(Err(err)) if err == ERR_RESTORE_CANCELLED => {
            job_logger
                .progress(
                    Some("cancelled"),
                    Some("pg_restore"),
                    Some(0),
                    Some(&s3_cfg.bucket),
                    Some(&key),
                    None,
                    Some("Cancelled by operator"),
                    Some(Utc::now()),
                )
                .await;
            job_logger.log("info", "Restore cancelled").await;
            api_success(
                "Restore cancelled",
                json!({
                    "key": key,
                    "client_name": effective_client_name,
                    "job_id": job_id,
                    "status": "cancelled"
                }),
            )
        }
        // Any other error: classify the failed stage from the sentinel prefix
        // so the job row shows whether the S3 download or pg_restore failed.
        Ok(Err(err)) => {
            tracing::info!("Restore failed for key='{}': {}", key, err);
            let progress_stage =
                if err.starts_with(ERR_S3_DOWNLOAD_FAILED) || err.starts_with(ERR_S3_READ_FAILED) {
                    "downloading"
                } else {
                    "pg_restore"
                };
            job_logger
                .progress(
                    Some("failed"),
                    Some(progress_stage),
                    Some(90),
                    Some(&s3_cfg.bucket),
                    Some(&key),
                    None,
                    Some(&err),
                    Some(Utc::now()),
                )
                .await;
            job_logger
                .log("error", &format!("Restore failed: {}", err))
                .await;
            internal_error("pg_restore failed", err)
        }
    }
}
/// GET `/admin/backups/{key:.*}/download` — return a backup archive from S3
/// as a gzip attachment.
///
/// The object is fetched with retry (`S3_DOWNLOAD_MAX_ATTEMPTS`); the
/// attachment filename is the last path segment of the key, sanitized for
/// safe use inside a quoted `Content-Disposition` value.
#[get("/admin/backups/{key:.*}/download")]
pub async fn admin_download_backup(
    req: HttpRequest,
    _state: Data<AppState>,
    key_param: Path<String>,
) -> HttpResponse {
    tracing::info!("admin_download_backup called: key_param={}", key_param);
    // Static admin key gates every backup endpoint.
    if let Err(resp) = authorize_static_admin_key(&req) {
        tracing::info!("Authorization failed for admin_download_backup");
        return resp;
    }
    let Some(s3_cfg) = S3Config::from_env() else {
        tracing::info!("S3Config could not be constructed from env for admin_download_backup");
        return service_unavailable(
            "S3 not configured",
            "Set ATHENA_BACKUP_S3_BUCKET and related environment variables to enable backups.",
        );
    };
    let key = key_param.into_inner();
    if key.is_empty() {
        tracing::info!("No backup key provided to admin_download_backup");
        return bad_request(
            "Missing backup key",
            "Provide the S3 object key as the path segment.",
        );
    }
    let s3_client = build_s3_client(&s3_cfg).await;
    tracing::info!("Requesting backup archive from S3 for key='{}'", key);
    let bytes = match download_s3_object_with_retry(
        &s3_client,
        &s3_cfg.bucket,
        &key,
        S3_DOWNLOAD_MAX_ATTEMPTS,
    )
    .await
    {
        Ok(b) => b,
        Err(err) => {
            // `err` is already a String; branch on the sentinel prefix so the
            // client sees the right error title. (Previously this called the
            // redundant `err.to_string()` on a String.)
            if err.starts_with(ERR_S3_DOWNLOAD_FAILED) {
                tracing::info!("S3 download failed for key '{}': {}", key, err);
                return internal_error("S3 download failed", err);
            }
            tracing::info!("S3 read failed for key '{}': {}", key, err);
            return internal_error("S3 read failed", err);
        }
    };
    // Attachment name is the last path segment. `rsplit(..).next()` always
    // yields Some, so filter out the empty segment a trailing '/' produces
    // before applying the fallback name.
    let raw_name = key
        .rsplit('/')
        .next()
        .filter(|segment| !segment.is_empty())
        .unwrap_or("backup.tar.gz");
    // Sanitize for the quoted Content-Disposition value: a quote, backslash,
    // or control character in the key would corrupt the header.
    let filename: String = raw_name
        .chars()
        .map(|c| {
            if c == '"' || c == '\\' || c.is_control() {
                '_'
            } else {
                c
            }
        })
        .collect();
    tracing::info!(
        "Download backup returning S3 object for key='{}', filename='{}'",
        key,
        filename
    );
    HttpResponse::Ok()
        .content_type("application/gzip")
        .insert_header((
            "Content-Disposition",
            format!("attachment; filename=\"{}\"", filename),
        ))
        .body(bytes.to_vec())
}
/// DELETE `/admin/backups/{key:.*}` — remove a backup object from S3.
#[delete("/admin/backups/{key:.*}")]
pub async fn admin_delete_backup(
    req: HttpRequest,
    _state: Data<AppState>,
    key_param: Path<String>,
) -> HttpResponse {
    tracing::info!("admin_delete_backup called: key_param={}", key_param);
    // Static admin key gates every backup endpoint.
    if let Err(resp) = authorize_static_admin_key(&req) {
        tracing::info!("Authorization failed for admin_delete_backup");
        return resp;
    }
    // Without S3 configuration there is nothing to delete from.
    let cfg = match S3Config::from_env() {
        Some(cfg) => cfg,
        None => {
            tracing::info!("S3Config could not be constructed from env for admin_delete_backup");
            return service_unavailable(
                "S3 not configured",
                "Set ATHENA_BACKUP_S3_BUCKET and related environment variables to enable backups.",
            );
        }
    };
    let key = key_param.into_inner();
    if key.is_empty() {
        tracing::info!("No backup key provided to admin_delete_backup");
        return bad_request(
            "Missing backup key",
            "Provide the S3 object key as the path segment.",
        );
    }
    let client = build_s3_client(&cfg).await;
    tracing::info!("Deleting S3 object for key='{}'", key);
    let outcome = client
        .delete_object()
        .bucket(&cfg.bucket)
        .key(&key)
        .send()
        .await;
    if let Err(err) = outcome {
        tracing::info!("S3 delete failed for key '{}': {}", key, err);
        return internal_error("S3 delete failed", err.to_string());
    }
    tracing::info!("Successfully deleted S3 backup for key '{}'", key);
    api_success("Backup deleted", json!({ "key": key }))
}
/// Build a 5-field cron expression for a backup schedule.
///
/// `time` is expected as "HH:MM"; any other shape falls back to the
/// historical default of 02:00. `day_of_week` (0-6, weekly) and
/// `day_of_month` (monthly) default to 1 when absent.
///
/// # Errors
/// Returns `Err` for an unknown frequency, an unparseable hour/minute, or —
/// fixed here — an out-of-range hour (>23) or minute (>59), which previously
/// slipped through and produced an invalid cron expression.
fn build_cron_expression(
    frequency: &str,
    time: &str,
    day_of_week: Option<i32>,
    day_of_month: Option<i32>,
) -> Result<String, String> {
    let parts: Vec<&str> = time.split(':').collect();
    let (hour, minute) = if parts.len() == 2 {
        let hour = parts[0].parse::<u8>().map_err(|_| "Invalid hour")?;
        let minute = parts[1].parse::<u8>().map_err(|_| "Invalid minute")?;
        // Reject values that parse as u8 but are not valid clock fields.
        if hour > 23 {
            return Err("Invalid hour".to_string());
        }
        if minute > 59 {
            return Err("Invalid minute".to_string());
        }
        (hour, minute)
    } else {
        // No "HH:MM" shape: keep the legacy 02:00 default.
        (2, 0)
    };
    match frequency {
        "hourly" => Ok(format!("{minute} * * * *")),
        "daily" => Ok(format!("{minute} {hour} * * *")),
        "weekly" => Ok(format!("{minute} {hour} * * {}", day_of_week.unwrap_or(1))),
        "monthly" => Ok(format!("{minute} {hour} {} * *", day_of_month.unwrap_or(1))),
        _ => Err(format!("Invalid frequency: {frequency}")),
    }
}
/// Compute the next UTC run time for a schedule.
///
/// `time` is "HH:MM" (unparseable fields fall back to 02:00). For "hourly"
/// schedules only the minute matters — the cron built for them is
/// `"{minute} * * * *"` — so the next run is the next occurrence of that
/// minute. For "weekly"/"monthly" the candidate is walked forward a day at a
/// time until the weekday (0 = Sunday) or day-of-month matches. Any other
/// frequency behaves like "daily".
fn compute_next_run(
    frequency: &str,
    time: &str,
    day_of_week: Option<i32>,
    day_of_month: Option<i32>,
) -> chrono::DateTime<Utc> {
    let now = Utc::now();
    let parts: Vec<&str> = time.split(':').collect();
    let (hour, minute) = if parts.len() == 2 {
        (
            parts[0].parse::<u32>().unwrap_or(2),
            parts[1].parse::<u32>().unwrap_or(0),
        )
    } else {
        (2, 0)
    };
    use chrono::{Datelike, Duration, Timelike};
    if frequency == "hourly" {
        // BUGFIX: hourly schedules previously fell into the daily logic and
        // were scheduled up to ~24h out, inconsistent with their cron
        // expression that fires at `minute` past every hour.
        let mut candidate = now
            .with_minute(minute)
            .unwrap_or(now)
            .with_second(0)
            .unwrap_or(now);
        if candidate <= now {
            candidate = candidate + Duration::hours(1);
        }
        return candidate;
    }
    let mut candidate = now
        .with_hour(hour)
        .unwrap_or(now)
        .with_minute(minute)
        .unwrap_or(now)
        .with_second(0)
        .unwrap_or(now);
    // Today's occurrence already passed: start searching from tomorrow.
    if candidate <= now {
        candidate = candidate + Duration::days(1);
    }
    match frequency {
        "weekly" => {
            // Advance day by day until the weekday matches (0 = Sunday,
            // matching num_days_from_sunday()).
            let target_dow = day_of_week.unwrap_or(1) as u32;
            while candidate.weekday().num_days_from_sunday() != target_dow {
                candidate = candidate + Duration::days(1);
            }
        }
        "monthly" => {
            // Advance until the day-of-month matches. NOTE(review): a target
            // of 29-31 skips months lacking that day — confirm acceptable.
            let target_day = day_of_month.unwrap_or(1) as u32;
            while candidate.day() != target_day {
                candidate = candidate + Duration::days(1);
            }
        }
        _ => {}
    }
    candidate
}
/// GET `/admin/backups/schedules` — list every backup schedule row,
/// newest first.
#[get("/admin/backups/schedules")]
pub async fn admin_list_schedules(req: HttpRequest, state: Data<AppState>) -> HttpResponse {
    // Admin-key gate before touching the database.
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    let pool = match logging_pool(&state).await {
        Ok(pool) => pool,
        Err(resp) => return resp,
    };
    let query = sqlx::query_as::<_, BackupScheduleRow>(
        "SELECT id, client_name, pg_uri, frequency, cron_expression, time_of_day, day_of_week, day_of_month, label, enabled, timeout_seconds, last_run_at, last_job_id, next_run_at, created_at, updated_at FROM backup_schedules ORDER BY created_at DESC"
    );
    let schedules: Vec<BackupScheduleRow> = match query.fetch_all(&pool).await {
        Ok(rows) => rows,
        Err(err) => return internal_error("Failed to list schedules", err.to_string()),
    };
    api_success("Listed backup schedules", json!({ "schedules": schedules }))
}
/// POST `/admin/backups/schedules` — create an enabled backup schedule.
///
/// Validates the frequency, derives the cron expression and the first
/// `next_run_at`, and inserts the row, returning the full created schedule.
#[post("/admin/backups/schedules")]
pub async fn admin_create_schedule(
    req: HttpRequest,
    state: Data<AppState>,
    body: Json<CreateScheduleRequest>,
) -> HttpResponse {
    // Static admin key gates every backup endpoint.
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    let pool = match logging_pool(&state).await {
        Ok(p) => p,
        Err(resp) => return resp,
    };
    // Only these frequencies are understood by build_cron_expression and
    // compute_next_run.
    let valid_frequencies = ["hourly", "daily", "weekly", "monthly"];
    if !valid_frequencies.contains(&body.frequency.as_str()) {
        return bad_request(
            "Invalid frequency",
            "Must be hourly, daily, weekly, or monthly.",
        );
    }
    let cron_expr = match build_cron_expression(
        &body.frequency,
        &body.time,
        body.day_of_week,
        body.day_of_month,
    ) {
        Ok(c) => c,
        Err(e) => return bad_request("Invalid schedule", e),
    };
    let next_run = compute_next_run(
        &body.frequency,
        &body.time,
        body.day_of_week,
        body.day_of_month,
    );
    // body.time is "HH:MM"; append ":00" seconds for NaiveTime parsing and
    // fall back to 02:00:00 when the input is unparseable.
    let time_of_day = chrono::NaiveTime::parse_from_str(&format!("{}:00", body.time), "%H:%M:%S")
        .unwrap_or_else(|_| chrono::NaiveTime::from_hms_opt(2, 0, 0).unwrap());
    // NOTE: bind order below must match the placeholder order $1..$10.
    // Schedules are always created enabled (hard-coded `true`).
    let row = match sqlx::query_as::<_, BackupScheduleRow>(
        r#"INSERT INTO backup_schedules (client_name, pg_uri, frequency, cron_expression, time_of_day, day_of_week, day_of_month, label, enabled, timeout_seconds, next_run_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10)
        RETURNING id, client_name, pg_uri, frequency, cron_expression, time_of_day, day_of_week, day_of_month, label, enabled, timeout_seconds, last_run_at, last_job_id, next_run_at, created_at, updated_at"#
    )
    .bind(&body.client_name)
    .bind(&body.pg_uri)
    .bind(&body.frequency)
    .bind(&cron_expr)
    .bind(time_of_day)
    .bind(body.day_of_week)
    .bind(body.day_of_month)
    .bind(&body.label)
    .bind(body.timeout_seconds)
    .bind(next_run)
    .fetch_one(&pool)
    .await
    {
        Ok(r) => r,
        Err(err) => return internal_error("Failed to create schedule", err.to_string()),
    };
    api_success("Created backup schedule", json!({ "schedule": row }))
}
/// PATCH `/admin/backups/schedules/{id}` — partial update of a schedule.
///
/// Every body field is optional; absent fields keep the existing row's
/// value. The cron expression and `next_run_at` are recomputed from the
/// merged values (next_run_at is cleared when the schedule is disabled).
#[patch("/admin/backups/schedules/{id}")]
pub async fn admin_update_schedule(
    req: HttpRequest,
    state: Data<AppState>,
    schedule_id: Path<i64>,
    body: Json<UpdateScheduleRequest>,
) -> HttpResponse {
    // Static admin key gates every backup endpoint.
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    let pool = match logging_pool(&state).await {
        Ok(p) => p,
        Err(resp) => return resp,
    };
    // Load the current row so absent body fields can fall back to it.
    let existing = match sqlx::query_as::<_, BackupScheduleRow>(
        "SELECT id, client_name, pg_uri, frequency, cron_expression, time_of_day, day_of_week, day_of_month, label, enabled, timeout_seconds, last_run_at, last_job_id, next_run_at, created_at, updated_at FROM backup_schedules WHERE id = $1"
    ).bind(*schedule_id).fetch_optional(&pool).await {
        Ok(Some(r)) => r,
        Ok(None) => return HttpResponse::NotFound().json(json!({"status":"error","message":"Schedule not found"})),
        Err(err) => return internal_error("Failed to fetch schedule", err.to_string()),
    };
    // Merge request fields over the existing values.
    let freq = body.frequency.as_deref().unwrap_or(&existing.frequency);
    let existing_time_str = existing.time_of_day.format("%H:%M").to_string();
    let time_str = body.time.as_deref().unwrap_or(&existing_time_str);
    let dow = match &body.day_of_week {
        Some(v) => *v,
        None => existing.day_of_week,
    };
    let dom = match &body.day_of_month {
        Some(v) => *v,
        None => existing.day_of_month,
    };
    let lbl = match &body.label {
        Some(v) => v.clone(),
        None => existing.label.clone(),
    };
    let enabled = body.enabled.unwrap_or(existing.enabled);
    let timeout = body.timeout_seconds.unwrap_or(existing.timeout_seconds);
    let cron_expr = match build_cron_expression(freq, time_str, dow, dom) {
        Ok(c) => c,
        Err(e) => return bad_request("Invalid schedule", e),
    };
    // Disabled schedules carry no next_run_at.
    let next_run = if enabled {
        Some(compute_next_run(freq, time_str, dow, dom))
    } else {
        None
    };
    // "HH:MM" + ":00" for NaiveTime parsing; keep the stored value on error.
    let time_of_day = chrono::NaiveTime::parse_from_str(&format!("{}:00", time_str), "%H:%M:%S")
        .unwrap_or(existing.time_of_day);
    // NOTE: bind order below must match the placeholder order $1..$10.
    let row = match sqlx::query_as::<_, BackupScheduleRow>(
        r#"UPDATE backup_schedules SET frequency = $2, cron_expression = $3, time_of_day = $4, day_of_week = $5, day_of_month = $6, label = $7, enabled = $8, timeout_seconds = $9, next_run_at = $10, updated_at = now()
        WHERE id = $1
        RETURNING id, client_name, pg_uri, frequency, cron_expression, time_of_day, day_of_week, day_of_month, label, enabled, timeout_seconds, last_run_at, last_job_id, next_run_at, created_at, updated_at"#
    )
    .bind(*schedule_id)
    .bind(freq)
    .bind(&cron_expr)
    .bind(time_of_day)
    .bind(dow)
    .bind(dom)
    .bind(&lbl)
    .bind(enabled)
    .bind(timeout)
    .bind(next_run)
    .fetch_one(&pool)
    .await
    {
        Ok(r) => r,
        Err(err) => return internal_error("Failed to update schedule", err.to_string()),
    };
    api_success("Updated backup schedule", json!({ "schedule": row }))
}
/// DELETE `/admin/backups/schedules/{id}` — remove a backup schedule;
/// 404 when no row with that id exists.
#[delete("/admin/backups/schedules/{id}")]
pub async fn admin_delete_schedule(
    req: HttpRequest,
    state: Data<AppState>,
    schedule_id: Path<i64>,
) -> HttpResponse {
    // Admin-key gate before touching the database.
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    let pool = match logging_pool(&state).await {
        Ok(pool) => pool,
        Err(resp) => return resp,
    };
    let result = sqlx::query("DELETE FROM backup_schedules WHERE id = $1")
        .bind(*schedule_id)
        .execute(&pool)
        .await;
    match result {
        Err(err) => internal_error("Failed to delete schedule", err.to_string()),
        // Zero rows affected means the id did not exist.
        Ok(outcome) if outcome.rows_affected() == 0 => {
            HttpResponse::NotFound().json(json!({"status":"error","message":"Schedule not found"}))
        }
        Ok(_) => api_success("Deleted backup schedule", json!({"id": *schedule_id})),
    }
}
/// Register every backup admin route on the actix `ServiceConfig`.
///
/// Registration order matters: actix matches services in the order they are
/// registered, so the catch-all `{key:.*}` routes (download/restore/delete
/// backup) are added after the more specific `/schedules` and job routes.
pub fn services(cfg: &mut web::ServiceConfig) {
    tracing::debug!("Registering backup service routes.");
    cfg.service(admin_create_backup)
        .service(admin_backup_config)
        .service(admin_list_backups)
        .service(admin_list_backup_jobs)
        .service(admin_get_backup_job)
        .service(admin_cancel_backup_job)
        .service(admin_delete_backup_job)
        .service(admin_list_schedules)
        .service(admin_create_schedule)
        .service(admin_update_schedule)
        .service(admin_delete_schedule)
        .service(admin_download_backup)
        .service(admin_restore_backup)
        .service(admin_delete_backup);
}
#[cfg(test)]
/// Unit tests for the pure helpers in this module: URI locality/sslmode
/// handling, S3 retry backoff, pg_dump progress estimation, latest-backup
/// selection, and restore dump-directory layout resolution.
mod tests {
    use super::{
        PG_DUMP_PROGRESS_MAX_PCT, PG_DUMP_PROGRESS_MIN_PCT, S3_DOWNLOAD_RETRY_BASE_MS,
        estimate_pg_dump_progress_pct, looks_like_ssl_required_error, maybe_record_backup_object,
        pg_uri_for_backup_tools, pg_uri_set_sslmode, postgres_uri_target_is_local_pg_tools,
        resolve_restore_dump_dir, s3_retry_delay, sslmode_cache_mark_require,
        sslmode_cached_require, sslmode_is_present, update_latest_backup_candidate,
        with_sslmode_require,
    };
    use std::path::PathBuf;
    use uuid::Uuid;
    // Loopback hosts, localhost, IPv6 ::1, and unix-socket style URIs all
    // count as "local"; remote hostnames do not.
    #[test]
    fn postgres_uri_target_is_local_pg_tools_detects_loopback_and_sockets() {
        assert!(postgres_uri_target_is_local_pg_tools(
            "postgres://user@127.0.0.1:46035/db"
        ));
        assert!(postgres_uri_target_is_local_pg_tools(
            "postgres://user:pass@localhost:5432/db"
        ));
        assert!(postgres_uri_target_is_local_pg_tools(
            "postgres://user@[::1]:5432/db"
        ));
        assert!(postgres_uri_target_is_local_pg_tools(
            "postgresql:///dbname"
        ));
        assert!(postgres_uri_target_is_local_pg_tools(
            "postgres://user@/var/run/postgresql/.s.PGSQL.5432/db"
        ));
        assert!(!postgres_uri_target_is_local_pg_tools(
            "postgres://user@db.example.com:5432/db"
        ));
    }
    // Loopback targets get sslmode=require rewritten to disable, preserving
    // other query parameters.
    #[test]
    fn pg_uri_for_backup_tools_forces_disable_on_loopback() {
        assert_eq!(
            pg_uri_for_backup_tools("postgres://u@127.0.0.1:46035/db?sslmode=require"),
            "postgres://u@127.0.0.1:46035/db?sslmode=disable"
        );
        assert_eq!(
            pg_uri_set_sslmode(
                "postgres://u@127.0.0.1/db?connect_timeout=3&sslmode=require",
                "disable"
            ),
            "postgres://u@127.0.0.1/db?connect_timeout=3&sslmode=disable"
        );
    }
    #[test]
    fn sslmode_is_present_detects_query_param() {
        assert!(!sslmode_is_present("postgres://user@host/db"));
        assert!(sslmode_is_present(
            "postgres://user@host/db?sslmode=require"
        ));
        assert!(sslmode_is_present(
            "postgres://user@host/db?connect_timeout=5&sslmode=require"
        ));
    }
    // with_sslmode_require uses '?' vs '&' correctly and never overrides an
    // explicit sslmode already present.
    #[test]
    fn with_sslmode_require_appends_correctly() {
        assert_eq!(
            with_sslmode_require("postgres://user@host/db"),
            "postgres://user@host/db?sslmode=require"
        );
        assert_eq!(
            with_sslmode_require("postgres://user@host/db?connect_timeout=5"),
            "postgres://user@host/db?connect_timeout=5&sslmode=require"
        );
        assert_eq!(
            with_sslmode_require("postgres://user@host/db?sslmode=disable"),
            "postgres://user@host/db?sslmode=disable"
        );
    }
    #[test]
    fn looks_like_ssl_required_error_matches_common_patterns() {
        assert!(looks_like_ssl_required_error(
            "FATAL: no pg_hba.conf entry for host \"1.2.3.4\", user \"postgres\", database \"railway\", SSL off"
        ));
        assert!(looks_like_ssl_required_error(
            "SSL connection is required. Please set sslmode=require"
        ));
        assert!(!looks_like_ssl_required_error(
            "password authentication failed"
        ));
    }
    // The sslmode dedup cache is gated behind ATHENA_PG_SSLMODE_DEDUP and
    // never caches local targets. NOTE(review): this test mutates process
    // env vars, so it can race with concurrently running env-reading tests.
    #[test]
    fn sslmode_dedup_cache_is_opt_in() {
        // Random password keeps the cache key unique across test runs.
        let key = format!(
            "postgres://user:{}@host/db",
            Uuid::new_v4().to_string().replace('-', "")
        );
        unsafe {
            std::env::remove_var("ATHENA_PG_SSLMODE_DEDUP");
        }
        sslmode_cache_mark_require(&key);
        assert!(!sslmode_cached_require(&key));
        unsafe {
            std::env::set_var("ATHENA_PG_SSLMODE_DEDUP", "1");
        }
        assert!(!sslmode_cached_require(&key));
        sslmode_cache_mark_require(&key);
        assert!(sslmode_cached_require(&key));
        let local = "postgres://u@127.0.0.1:5432/db";
        sslmode_cache_mark_require(local);
        assert!(!sslmode_cached_require(local));
        unsafe {
            std::env::remove_var("ATHENA_PG_SSLMODE_DEDUP");
        }
    }
    // Delay doubles per attempt: base, 2x, 4x.
    #[test]
    fn s3_retry_delay_uses_exponential_backoff() {
        assert_eq!(
            s3_retry_delay(1),
            std::time::Duration::from_millis(S3_DOWNLOAD_RETRY_BASE_MS)
        );
        assert_eq!(
            s3_retry_delay(2),
            std::time::Duration::from_millis(S3_DOWNLOAD_RETRY_BASE_MS * 2)
        );
        assert_eq!(
            s3_retry_delay(3),
            std::time::Duration::from_millis(S3_DOWNLOAD_RETRY_BASE_MS * 4)
        );
    }
    #[test]
    fn pg_dump_progress_starts_at_minimum_when_empty() {
        let pct = estimate_pg_dump_progress_pct(0, Some(10_000_000));
        assert_eq!(pct, PG_DUMP_PROGRESS_MIN_PCT);
    }
    #[test]
    fn pg_dump_progress_with_reference_moves_high_near_expected_size() {
        let pct = estimate_pg_dump_progress_pct(1_000_000, Some(1_000_000));
        assert!(pct > 50, "expected >50, got {pct}");
        assert!(
            pct < PG_DUMP_PROGRESS_MAX_PCT,
            "expected below max, got {pct}"
        );
    }
    // Overshooting the reference size must clamp to the max percentage.
    #[test]
    fn pg_dump_progress_with_reference_caps_at_max() {
        let pct = estimate_pg_dump_progress_pct(50_000_000, Some(1_000_000));
        assert_eq!(pct, PG_DUMP_PROGRESS_MAX_PCT);
    }
    // Without a reference size the estimator must still grow with bytes
    // written while leaving headroom below the max.
    #[test]
    fn pg_dump_progress_without_reference_is_monotonic() {
        let small = estimate_pg_dump_progress_pct(8 * 1024 * 1024, None);
        let medium = estimate_pg_dump_progress_pct(64 * 1024 * 1024, None);
        let large = estimate_pg_dump_progress_pct(512 * 1024 * 1024, None);
        assert!(
            small >= PG_DUMP_PROGRESS_MIN_PCT,
            "small below min: {small}"
        );
        assert!(
            small < medium,
            "small ({small}) should be < medium ({medium})"
        );
        assert!(
            medium < large,
            "medium ({medium}) should be < large ({large})"
        );
        assert!(
            large < PG_DUMP_PROGRESS_MAX_PCT,
            "fallback estimator should leave headroom, got {large}"
        );
    }
    #[test]
    fn pg_dump_progress_with_reference_is_monotonic() {
        let checkpoints = [
            estimate_pg_dump_progress_pct(0, Some(1_000_000)),
            estimate_pg_dump_progress_pct(200_000, Some(1_000_000)),
            estimate_pg_dump_progress_pct(800_000, Some(1_000_000)),
            estimate_pg_dump_progress_pct(1_000_000, Some(1_000_000)),
            estimate_pg_dump_progress_pct(1_200_000, Some(1_000_000)),
            estimate_pg_dump_progress_pct(2_000_000, Some(1_000_000)),
        ];
        for pair in checkpoints.windows(2) {
            assert!(
                pair[0] <= pair[1],
                "progress should be monotonic but saw {} -> {}",
                pair[0],
                pair[1]
            );
        }
        assert!(checkpoints[5] <= PG_DUMP_PROGRESS_MAX_PCT);
    }
    // Candidates are (timestamp, size): newer timestamp wins; equal
    // timestamps tie-break on larger size.
    #[test]
    fn latest_backup_candidate_prefers_newer_timestamp_and_larger_tie_break() {
        let mut newest: Option<(i64, i64)> = None;
        update_latest_backup_candidate(&mut newest, 100, 50);
        assert_eq!(newest, Some((100, 50)));
        update_latest_backup_candidate(&mut newest, 99, 999);
        assert_eq!(newest, Some((100, 50)));
        update_latest_backup_candidate(&mut newest, 101, 40);
        assert_eq!(newest, Some((101, 40)));
        update_latest_backup_candidate(&mut newest, 101, 55);
        assert_eq!(newest, Some((101, 55)));
    }
    // Simulates paginated S3 listings: non-archive keys (.txt) are skipped
    // and the newest .tar.gz across pages is selected.
    #[test]
    fn mock_s3_paginated_selection_prefers_latest_timestamp() {
        let mut newest: Option<(i64, i64)> = None;
        let page_one = [
            ("backups/athena_logging/first.tar.gz", 1_000_i64, 120_i64),
            ("backups/athena_logging/skip.txt", 1_100_i64, 9_999_i64),
        ];
        let page_two = [("backups/athena_logging/second.tar.gz", 2_000_i64, 240_i64)];
        for (key, modified_secs, size) in page_one.into_iter().chain(page_two) {
            maybe_record_backup_object(&mut newest, key, modified_secs, size);
        }
        assert_eq!(newest, Some((2_000, 240)));
    }
    #[test]
    fn mock_s3_paginated_selection_breaks_same_timestamp_with_size() {
        let mut newest: Option<(i64, i64)> = None;
        let page_one = [
            (
                "backups/athena_logging/same-ts-small.tar.gz",
                5_000_i64,
                320_i64,
            ),
            ("backups/athena_logging/zero-size.tar.gz", 5_000_i64, 0_i64),
        ];
        let page_two = [(
            "backups/athena_logging/same-ts-large.tar.gz",
            5_000_i64,
            512_i64,
        )];
        for (key, modified_secs, size) in page_one.into_iter().chain(page_two) {
            maybe_record_backup_object(&mut newest, key, modified_secs, size);
        }
        assert_eq!(newest, Some((5_000, 512)));
    }
    // An extracted archive with toc.dat at the root resolves to the root.
    #[tokio::test]
    async fn resolve_restore_dump_dir_accepts_root_layout() {
        let root: PathBuf =
            std::env::temp_dir().join(format!("athena_restore_test_{}", Uuid::new_v4()));
        tokio::fs::create_dir_all(&root).await.unwrap();
        tokio::fs::write(root.join("toc.dat"), b"test")
            .await
            .unwrap();
        let resolved = resolve_restore_dump_dir(&root).await.unwrap();
        assert_eq!(resolved, root);
        let _ = tokio::fs::remove_dir_all(&resolved).await;
    }
    // An archive whose toc.dat lives under a "dump/" subdirectory resolves
    // to that nested directory instead.
    #[tokio::test]
    async fn resolve_restore_dump_dir_accepts_nested_dump_layout() {
        let root: PathBuf =
            std::env::temp_dir().join(format!("athena_restore_test_{}", Uuid::new_v4()));
        let nested = root.join("dump");
        tokio::fs::create_dir_all(&nested).await.unwrap();
        tokio::fs::write(nested.join("toc.dat"), b"test")
            .await
            .unwrap();
        let resolved = resolve_restore_dump_dir(&root).await.unwrap();
        assert_eq!(resolved, nested);
        let _ = tokio::fs::remove_dir_all(&root).await;
    }
}