use std::env;
use std::path::PathBuf;
use actix_web::{
HttpRequest, HttpResponse, delete, get, post,
web::{self, Data, Json, Path},
};
use aws_config::BehaviorVersion;
use aws_sdk_s3::Client as S3Client;
use aws_sdk_s3::config::{Credentials, Region};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::process::Command;
use uuid::Uuid;
use crate::AppState;
use crate::api::auth::authorize_static_admin_key;
use crate::api::response::{api_success, bad_request, internal_error, service_unavailable};
use crate::parser::resolve_postgres_uri;
use crate::utils::pg_tools::{PgToolsPaths, ensure_pg_tools};
/// Read an environment variable, treating unset or blank (whitespace-only)
/// values as absent.
fn s3_env(key: &str) -> Option<String> {
    match env::var(key) {
        Ok(value) if !value.trim().is_empty() => Some(value),
        _ => None,
    }
}
/// S3 connection settings for backup storage, sourced from the
/// `ATHENA_BACKUP_S3_*` environment variables (see `from_env`).
struct S3Config {
    // Target bucket (required — backups are disabled without it).
    bucket: String,
    // Region name; defaults to "us-east-1" when unset.
    region: String,
    // Key prefix under which backup objects live; defaults to "backups".
    prefix: String,
    // Optional static credentials; when absent the default AWS provider chain is used.
    access_key: Option<String>,
    secret_key: Option<String>,
    // Optional custom endpoint (e.g. MinIO); also triggers path-style addressing.
    endpoint: Option<String>,
}
impl S3Config {
    /// Build the backup S3 configuration from `ATHENA_BACKUP_S3_*` environment
    /// variables. Returns `None` when the mandatory bucket name is missing;
    /// region and prefix fall back to "us-east-1" and "backups" respectively.
    fn from_env() -> Option<Self> {
        let bucket = s3_env("ATHENA_BACKUP_S3_BUCKET")?;
        let region =
            s3_env("ATHENA_BACKUP_S3_REGION").unwrap_or_else(|| String::from("us-east-1"));
        let prefix =
            s3_env("ATHENA_BACKUP_S3_PREFIX").unwrap_or_else(|| String::from("backups"));
        Some(Self {
            bucket,
            region,
            prefix,
            access_key: s3_env("ATHENA_BACKUP_S3_ACCESS_KEY"),
            secret_key: s3_env("ATHENA_BACKUP_S3_SECRET_KEY"),
            endpoint: s3_env("ATHENA_BACKUP_S3_ENDPOINT"),
        })
    }
}
/// Construct an S3 client from the backup configuration.
///
/// Static credentials and a custom endpoint are applied only when present;
/// otherwise the default AWS provider chain / endpoint is used. A custom
/// endpoint also enables path-style addressing, which S3-compatible stores
/// (MinIO, Ceph RGW, etc.) typically require.
async fn build_s3_client(cfg: &S3Config) -> S3Client {
    // Build a single loader and attach optional pieces conditionally, instead
    // of duplicating the endpoint logic across two credential branches.
    let mut loader: aws_config::ConfigLoader = aws_config::defaults(BehaviorVersion::latest())
        .region(Region::new(cfg.region.clone()));
    if let (Some(ak), Some(sk)) = (&cfg.access_key, &cfg.secret_key) {
        loader = loader.credentials_provider(Credentials::new(ak, sk, None, None, "athena-env"));
    }
    if let Some(ep) = &cfg.endpoint {
        loader = loader.endpoint_url(ep);
    }
    let aws_config: aws_config::SdkConfig = loader.load().await;
    let mut s3_cfg_builder: aws_sdk_s3::config::Builder =
        aws_sdk_s3::config::Builder::from(&aws_config);
    if cfg.endpoint.is_some() {
        // Custom endpoints usually do not support virtual-hosted-style bucket
        // addressing, so force path-style when an override is set.
        s3_cfg_builder = s3_cfg_builder.force_path_style(true);
    }
    S3Client::from_conf(s3_cfg_builder.build())
}
/// Request body for POST /admin/backups.
#[derive(Debug, Deserialize)]
struct CreateBackupRequest {
    // Name of the registered Postgres client to back up.
    client_name: String,
    // Optional human-readable label stored as S3 object metadata.
    #[serde(default)]
    label: Option<String>,
}
/// Request body for POST /admin/backups/{key}/restore.
#[derive(Debug, Deserialize)]
struct RestoreBackupRequest {
    // Name of the registered Postgres client to restore into.
    client_name: String,
}
/// One backup entry as returned by GET /admin/backups.
#[derive(Debug, Serialize)]
struct BackupObject {
    // Backup UUID parsed from the object key (file stem without ".tar.gz").
    id: String,
    // Full S3 object key.
    key: String,
    // Client name parsed from the second-to-last key path segment.
    client_name: String,
    // Optional label read back from S3 object metadata via HeadObject.
    label: Option<String>,
    // Object size in bytes as reported by S3 (0 when unavailable).
    size_bytes: i64,
    // S3 LastModified timestamp rendered as a string (empty when unavailable).
    created_at: String,
}
/// Resolve a connection URI for the registered Postgres client `client_name`,
/// preferring the config URI template (expanded via `resolve_postgres_uri`)
/// over the raw stored URI. On failure, returns a ready-to-send 400 response.
fn resolve_pg_uri(state: &AppState, client_name: &str) -> Result<String, HttpResponse> {
    let registered: crate::drivers::postgresql::sqlx_driver::RegisteredClient =
        match state.pg_registry.registered_client(client_name) {
            Some(client) => client,
            None => {
                return Err(bad_request(
                    "Unknown client",
                    format!("No Postgres client named '{}' is registered.", client_name),
                ));
            }
        };
    let from_template = registered
        .config_uri_template
        .as_deref()
        .map(resolve_postgres_uri);
    match from_template.or(registered.pg_uri) {
        Some(uri) => Ok(uri),
        None => Err(bad_request(
            "Client URI unavailable",
            format!("No Postgres URI is available for client '{}'.", client_name),
        )),
    }
}
/// Split the password out of a `postgres://` / `postgresql://` URI.
///
/// Returns the URI with the password removed plus the password itself (if
/// any), so callers can pass the secret via `PGPASSWORD` instead of on the
/// command line. The userinfo is split at the first ':' (usernames cannot
/// contain an unencoded colon), and the host separator is the last '@' so
/// passwords containing '@' survive. Non-Postgres URIs pass through untouched.
fn extract_pg_password(pg_uri: &str) -> (String, Option<String>) {
    let prefix = ["postgresql://", "postgres://"]
        .iter()
        .find(|p| pg_uri.starts_with(*p));
    let Some(prefix) = prefix else {
        return (pg_uri.to_string(), None);
    };
    let rest = &pg_uri[prefix.len()..];
    if let Some(at_pos) = rest.rfind('@') {
        let (userinfo, tail) = rest.split_at(at_pos);
        if let Some((user, password)) = userinfo.split_once(':') {
            return (format!("{prefix}{user}{tail}"), Some(password.to_string()));
        }
    }
    (pg_uri.to_string(), None)
}
/// Dump the database at `pg_uri` with `pg_dump --format=directory` and pack
/// the result into a gzip-compressed tar archive.
///
/// Returns the path to `backup.tar.gz` inside a fresh temp directory; the
/// caller is responsible for removing that directory when done. On any error
/// the temp directory is removed here, so failed runs no longer leak disk
/// space (previously the spawn-error and archive-error paths leaked it).
async fn run_pg_dump(pg_uri: &str) -> Result<PathBuf, String> {
    let tmp_root = env::temp_dir().join(format!("athena_backup_{}", Uuid::new_v4()));
    let result = dump_and_archive(pg_uri, &tmp_root).await;
    if result.is_err() {
        // Best-effort cleanup: never leave a half-written dump behind.
        let _ = tokio::fs::remove_dir_all(&tmp_root).await;
    }
    result
}

/// Helper for `run_pg_dump`: performs the dump + archive steps inside
/// `tmp_root` and returns the finished archive path.
async fn dump_and_archive(pg_uri: &str, tmp_root: &std::path::Path) -> Result<PathBuf, String> {
    let dump_dir = tmp_root.join("dump");
    let archive_path = tmp_root.join("backup.tar.gz");
    let pg_tools: PgToolsPaths = ensure_pg_tools()
        .await
        .map_err(|e| format!("pg_dump resolution failed: {e}"))?;
    tokio::fs::create_dir_all(&dump_dir)
        .await
        .map_err(|e| format!("Could not create temp directory: {e}"))?;
    // Pass the password via PGPASSWORD so it never appears in the process list.
    let (pg_uri_safe, pg_password) = extract_pg_password(pg_uri);
    let mut cmd = Command::new(&pg_tools.pg_dump);
    if let Some(pass) = pg_password {
        cmd.env("PGPASSWORD", pass);
    }
    let status = cmd
        .args(["--format=directory", "--file"])
        .arg(&dump_dir)
        .arg(&pg_uri_safe)
        .status()
        .await
        .map_err(|e| {
            if e.kind() == std::io::ErrorKind::NotFound {
                "pg_dump binary could not be resolved (PATH/env/download). Set ATHENA_PG_DUMP_PATH to override."
                    .to_string()
            } else {
                format!("Failed to start pg_dump: {e}")
            }
        })?;
    if !status.success() {
        return Err(format!("pg_dump exited with status {status}"));
    }
    // tar + gzip are blocking I/O; keep them off the async executor.
    let dump_dir_clone = dump_dir.clone();
    let archive_path_clone = archive_path.clone();
    tokio::task::spawn_blocking(move || -> Result<(), String> {
        let file = std::fs::File::create(&archive_path_clone)
            .map_err(|e| format!("Cannot create archive: {e}"))?;
        let enc = flate2::write::GzEncoder::new(file, flate2::Compression::default());
        let mut builder = tar::Builder::new(enc);
        builder
            .append_dir_all("dump", &dump_dir_clone)
            .map_err(|e| format!("Cannot archive dump directory: {e}"))?;
        builder
            .finish()
            .map_err(|e| format!("Cannot finalize archive: {e}"))?;
        Ok(())
    })
    .await
    .map_err(|e| format!("Archive task panicked: {e}"))??;
    // The raw directory dump is no longer needed once archived.
    let _ = tokio::fs::remove_dir_all(&dump_dir).await;
    Ok(archive_path)
}
async fn run_pg_restore(
s3_client: &S3Client,
bucket: &str,
key: &str,
pg_uri: &str,
) -> Result<(), String> {
let resp: aws_sdk_s3::operation::get_object::GetObjectOutput = s3_client
.get_object()
.bucket(bucket)
.key(key)
.send()
.await
.map_err(|e| format!("S3 download failed: {e}"))?;
let bytes: web::Bytes = resp
.body
.collect()
.await
.map_err(|e| format!("S3 read failed: {e}"))?
.into_bytes();
let tmp_root: PathBuf = env::temp_dir().join(format!("athena_restore_{}", Uuid::new_v4()));
tokio::fs::create_dir_all(&tmp_root)
.await
.map_err(|e| format!("Could not create temp dir: {e}"))?;
let archive_path: PathBuf = tmp_root.join("backup.tar.gz");
let restore_dir: PathBuf = tmp_root.join("dump");
tokio::fs::write(&archive_path, &bytes)
.await
.map_err(|e| format!("Could not write archive: {e}"))?;
let archive_path_clone: PathBuf = archive_path.clone();
let restore_dir_clone: PathBuf = restore_dir.clone();
tokio::task::spawn_blocking(move || -> Result<(), String> {
let file = std::fs::File::open(&archive_path_clone)
.map_err(|e| format!("Cannot open archive: {e}"))?;
let dec = flate2::read::GzDecoder::new(file);
let mut archive: tar::Archive<flate2::read::GzDecoder<std::fs::File>> =
tar::Archive::new(dec);
archive
.unpack(&restore_dir_clone)
.map_err(|e| format!("Cannot extract archive: {e}"))?;
Ok(())
})
.await
.map_err(|e| format!("Extract task panicked: {e}"))??;
let inner_dump_dir: PathBuf = restore_dir.join("dump");
let (pg_uri_safe, pg_password) = extract_pg_password(pg_uri);
let pg_tools: PgToolsPaths = ensure_pg_tools()
.await
.map_err(|e| format!("pg_restore resolution failed: {e}"))?;
let mut cmd: Command = Command::new(&pg_tools.pg_restore);
if let Some(pass) = pg_password {
cmd.env("PGPASSWORD", pass);
}
let status: std::process::ExitStatus = cmd
.args(["--format=directory", "--clean", "--if-exists", "--dbname"])
.arg(&pg_uri_safe)
.arg(&inner_dump_dir)
.status()
.await
.map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
"pg_restore binary not found in PATH — ensure PostgreSQL client tools are installed"
.to_string()
} else {
format!("Failed to start pg_restore: {e}")
}
})?;
let _ = tokio::fs::remove_dir_all(&tmp_root).await;
if !status.success() {
return Err(format!("pg_restore exited with status {status}"));
}
Ok(())
}
/// Upload the archive at `local_path` to S3 under
/// `<prefix>/<client_name>/<uuid>.tar.gz`, attaching client name, backup id,
/// creation timestamp, and optional label as object metadata.
/// Returns the object key on success.
///
/// NOTE: the whole archive is buffered in memory before upload.
async fn upload_to_s3(
    s3_client: &S3Client,
    cfg: &S3Config,
    // `&Path` instead of `&PathBuf`: callers' `&PathBuf` deref-coerces, and the
    // function never needs an owned buffer.
    local_path: &std::path::Path,
    client_name: &str,
    label: Option<&str>,
) -> Result<String, String> {
    let backup_id: String = Uuid::new_v4().to_string();
    let key: String = format!("{}/{}/{}.tar.gz", cfg.prefix, client_name, backup_id);
    let data: Vec<u8> = tokio::fs::read(local_path)
        .await
        .map_err(|e| format!("Cannot read archive file: {e}"))?;
    let mut req: aws_sdk_s3::operation::put_object::builders::PutObjectFluentBuilder = s3_client
        .put_object()
        .bucket(&cfg.bucket)
        .key(&key)
        .body(data.into())
        .content_type("application/gzip")
        .metadata("client_name", client_name)
        .metadata("backup_id", &backup_id)
        .metadata("created_at", Utc::now().to_rfc3339());
    if let Some(lbl) = label {
        req = req.metadata("label", lbl);
    }
    req.send()
        .await
        .map_err(|e| format!("S3 upload failed: {e}"))?;
    Ok(key)
}
/// POST /admin/backups — dump a registered client's database and upload the
/// archive to S3. Requires the static admin key.
#[post("/admin/backups")]
pub async fn admin_create_backup(
    req: HttpRequest,
    state: Data<AppState>,
    body: Json<CreateBackupRequest>,
) -> HttpResponse {
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    let Some(s3_cfg) = S3Config::from_env() else {
        return service_unavailable(
            "S3 not configured",
            "Set ATHENA_BACKUP_S3_BUCKET and related environment variables to enable backups.",
        );
    };
    let pg_uri = match resolve_pg_uri(&state, &body.client_name) {
        Ok(uri) => uri,
        Err(resp) => return resp,
    };
    let archive_path = match run_pg_dump(&pg_uri).await {
        Ok(p) => p,
        Err(err) => return internal_error("pg_dump failed", err),
    };
    let s3_client = build_s3_client(&s3_cfg).await;
    let upload_result = upload_to_s3(
        &s3_client,
        &s3_cfg,
        &archive_path,
        &body.client_name,
        body.label.as_deref(),
    )
    .await;
    // The archive lives in its own temp directory; remove the whole directory
    // on success AND failure (removing only the file, as before, left an empty
    // temp directory behind on upload errors).
    if let Some(parent) = archive_path.parent() {
        let _ = tokio::fs::remove_dir_all(parent).await;
    }
    let key = match upload_result {
        Ok(k) => k,
        Err(err) => return internal_error("S3 upload failed", err),
    };
    api_success(
        "Backup created",
        json!({
            "key": key,
            "client_name": body.client_name,
            "label": body.label,
        }),
    )
}
/// GET /admin/backups — list backup objects, optionally filtered by
/// `?client_name=...`, newest first. Requires the static admin key.
///
/// NOTE: one HeadObject request is issued per object to recover the label
/// metadata, so large listings incur N+1 S3 round trips.
#[get("/admin/backups")]
pub async fn admin_list_backups(
    req: HttpRequest,
    _state: Data<AppState>,
    query: web::Query<std::collections::HashMap<String, String>>,
) -> HttpResponse {
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    let Some(s3_cfg) = S3Config::from_env() else {
        return service_unavailable(
            "S3 not configured",
            "Set ATHENA_BACKUP_S3_BUCKET and related environment variables to enable backups.",
        );
    };
    let s3_client = build_s3_client(&s3_cfg).await;
    let filter_client = query.get("client_name").cloned();
    let prefix = match &filter_client {
        Some(cn) => format!("{}/{}/", s3_cfg.prefix, cn),
        None => format!("{}/", s3_cfg.prefix),
    };
    let mut backups: Vec<BackupObject> = Vec::new();
    // Paginate: a single ListObjectsV2 call returns at most 1000 keys, so
    // without following the continuation token backups beyond the first page
    // would silently be dropped.
    let mut continuation: Option<String> = None;
    loop {
        let mut list_req = s3_client
            .list_objects_v2()
            .bucket(&s3_cfg.bucket)
            .prefix(&prefix);
        if let Some(token) = &continuation {
            list_req = list_req.continuation_token(token);
        }
        let resp = match list_req.send().await {
            Ok(r) => r,
            Err(err) => return internal_error("Failed to list S3 objects", err.to_string()),
        };
        for obj in resp.contents() {
            let key = obj.key().unwrap_or_default().to_string();
            // Expected key shape: <prefix>/<client_name>/<uuid>.tar.gz
            let parts: Vec<&str> = key.split('/').collect();
            let (client_name, id) = if parts.len() >= 3 {
                let cn = parts[parts.len() - 2].to_string();
                let id = parts
                    .last()
                    .and_then(|s| s.strip_suffix(".tar.gz"))
                    .unwrap_or_else(|| parts.last().copied().unwrap_or(""))
                    .to_string();
                (cn, id)
            } else {
                tracing::warn!(key = %key, "S3 backup key does not match expected format <prefix>/<client_name>/<id>.tar.gz");
                (
                    "unknown".to_string(),
                    parts
                        .last()
                        .and_then(|s| s.strip_suffix(".tar.gz"))
                        .unwrap_or_else(|| parts.last().copied().unwrap_or(&key))
                        .to_string(),
                )
            };
            // The label is only stored as object metadata, which listing does
            // not return; fetch it via HeadObject (best effort).
            let label = match s3_client
                .head_object()
                .bucket(&s3_cfg.bucket)
                .key(&key)
                .send()
                .await
            {
                Ok(head) => head.metadata().and_then(|m| m.get("label")).cloned(),
                Err(_) => None,
            };
            let size_bytes = obj.size().unwrap_or(0);
            let created_at = obj
                .last_modified()
                .map(|t| t.to_string())
                .unwrap_or_default();
            backups.push(BackupObject {
                id,
                key,
                client_name,
                label,
                size_bytes,
                created_at,
            });
        }
        match resp.next_continuation_token() {
            Some(token) => continuation = Some(token.to_string()),
            None => break,
        }
    }
    // Newest first, by S3 LastModified.
    backups.sort_by(|a, b| b.created_at.cmp(&a.created_at));
    api_success("Listed backups", json!({ "backups": backups }))
}
/// POST /admin/backups/{key}/restore — restore the backup object `key` into
/// the database of the client named in the request body. Requires the static
/// admin key.
#[post("/admin/backups/{key:.*}/restore")]
pub async fn admin_restore_backup(
    req: HttpRequest,
    state: Data<AppState>,
    key_param: Path<String>,
    body: Json<RestoreBackupRequest>,
) -> HttpResponse {
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    let s3_cfg = match S3Config::from_env() {
        Some(cfg) => cfg,
        None => {
            return service_unavailable(
                "S3 not configured",
                "Set ATHENA_BACKUP_S3_BUCKET and related environment variables to enable backups.",
            );
        }
    };
    let pg_uri = match resolve_pg_uri(&state, &body.client_name) {
        Ok(uri) => uri,
        Err(resp) => return resp,
    };
    let key = key_param.into_inner();
    if key.is_empty() {
        return bad_request(
            "Missing backup key",
            "Provide the S3 object key as the path segment.",
        );
    }
    let s3_client = build_s3_client(&s3_cfg).await;
    if let Err(err) = run_pg_restore(&s3_client, &s3_cfg.bucket, &key, &pg_uri).await {
        return internal_error("pg_restore failed", err);
    }
    api_success(
        "Restore completed",
        json!({ "key": key, "client_name": body.client_name }),
    )
}
/// GET /admin/backups/{key}/download — stream the backup archive to the
/// caller as a gzip attachment. Requires the static admin key.
#[get("/admin/backups/{key:.*}/download")]
pub async fn admin_download_backup(
    req: HttpRequest,
    _state: Data<AppState>,
    key_param: Path<String>,
) -> HttpResponse {
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    let Some(s3_cfg) = S3Config::from_env() else {
        return service_unavailable(
            "S3 not configured",
            "Set ATHENA_BACKUP_S3_BUCKET and related environment variables to enable backups.",
        );
    };
    let key = key_param.into_inner();
    if key.is_empty() {
        return bad_request(
            "Missing backup key",
            "Provide the S3 object key as the path segment.",
        );
    }
    let s3_client = build_s3_client(&s3_cfg).await;
    let resp = match s3_client
        .get_object()
        .bucket(&s3_cfg.bucket)
        .key(&key)
        .send()
        .await
    {
        Ok(r) => r,
        Err(err) => return internal_error("S3 download failed", err.to_string()),
    };
    let bytes = match resp.body.collect().await {
        Ok(b) => b.into_bytes(),
        Err(err) => return internal_error("S3 read failed", err.to_string()),
    };
    // The key is caller-controlled: strip characters that could escape the
    // quoted Content-Disposition filename or inject header content.
    let raw_name = key.rsplit('/').next().unwrap_or("backup.tar.gz");
    let mut filename: String = raw_name
        .chars()
        .filter(|c| !matches!(c, '"' | '\\' | '\r' | '\n'))
        .collect();
    if filename.is_empty() {
        filename = "backup.tar.gz".to_string();
    }
    HttpResponse::Ok()
        .content_type("application/gzip")
        .insert_header((
            "Content-Disposition",
            format!("attachment; filename=\"{}\"", filename),
        ))
        .body(bytes.to_vec())
}
/// DELETE /admin/backups/{key} — remove a backup object from S3. Requires the
/// static admin key.
#[delete("/admin/backups/{key:.*}")]
pub async fn admin_delete_backup(
    req: HttpRequest,
    _state: Data<AppState>,
    key_param: Path<String>,
) -> HttpResponse {
    if let Err(resp) = authorize_static_admin_key(&req) {
        return resp;
    }
    let s3_cfg = match S3Config::from_env() {
        Some(cfg) => cfg,
        None => {
            return service_unavailable(
                "S3 not configured",
                "Set ATHENA_BACKUP_S3_BUCKET and related environment variables to enable backups.",
            );
        }
    };
    let key = key_param.into_inner();
    if key.is_empty() {
        return bad_request(
            "Missing backup key",
            "Provide the S3 object key as the path segment.",
        );
    }
    let s3_client: S3Client = build_s3_client(&s3_cfg).await;
    let delete_result = s3_client
        .delete_object()
        .bucket(&s3_cfg.bucket)
        .key(&key)
        .send()
        .await;
    match delete_result {
        Ok(_) => api_success("Backup deleted", json!({ "key": key })),
        Err(err) => internal_error("S3 delete failed", err.to_string()),
    }
}
/// Register all admin backup routes with the Actix service configuration.
pub fn services(cfg: &mut web::ServiceConfig) {
    cfg.service(admin_create_backup);
    cfg.service(admin_list_backups);
    cfg.service(admin_download_backup);
    cfg.service(admin_restore_backup);
    cfg.service(admin_delete_backup);
}