use actix_web::dev::ServiceResponse;
use actix_web::http::StatusCode as ActixStatusCode;
use actix_web::web::Bytes;
use actix_web::{
App,
body::to_bytes,
test::{TestRequest, call_service, init_service},
};
use anyhow::{Context, Result, anyhow, bail};
use clap::{Args, Parser, Subcommand};
use colored::Colorize;
use reqwest::{Client as HttpClient, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::{BTreeSet, HashMap};
use std::fs::{self, File, OpenOptions};
use std::io::Write;
use std::path::{Path, PathBuf};
use std::process::{Command as ProcessCommand, Stdio};
#[cfg(feature = "cdc")]
use std::time::Duration;
#[cfg(unix)]
use std::{
os::unix::fs::{MetadataExt, PermissionsExt},
time::{SystemTime, UNIX_EPOCH},
};
use crate::api::pipelines::run_pipeline;
use crate::bootstrap::Bootstrap;
#[cfg(feature = "cdc")]
use crate::cdc::postgres::sequin::{
AuditLogger, backfill_from_csv, load_table_configs, stream_events,
};
#[cfg(feature = "cdc")]
use crate::client::{AthenaClient, backend::BackendType};
#[cfg(feature = "cdc")]
use crate::config::Config;
use crate::provisioning::{ProvisioningError, run_provision_sql};
const DEFAULT_CLIENTS_PATH: &str = "config/clients.json";
const DEFAULT_FETCH_URL: &str = "https://athena-db.com/gateway/fetch";
// Top-level command-line interface for the `athena_rs` binary.
// NOTE(review): plain `//` comments are used on all clap-derived items in
// this file so the generated --help text stays unchanged (`///` doc
// comments would be picked up by clap as help strings).
#[derive(Parser)]
#[command(name = "athena_rs")]
#[command(author, version, about = "Athena API server + CLI helpers")]
pub struct AthenaCli {
    // Main configuration file; settable via --config, its alias, or the
    // ATHENA_CONFIG_PATH environment variable.
    #[arg(
        long = "config",
        alias = "config-path",
        global = true,
        value_name = "PATH",
        env = "ATHENA_CONFIG_PATH",
        help = "Override the default configuration file path (default: config.yaml). May be set via ATHENA_CONFIG_PATH."
    )]
    pub config_path: Option<PathBuf>,
    // Pipeline definitions file.
    #[arg(
        long,
        global = true,
        value_name = "PATH",
        default_value = "config/pipelines.yaml"
    )]
    pub pipelines_path: PathBuf,
    // Optional HTTP port override.
    #[arg(long, global = true, value_name = "PORT")]
    pub port: Option<u16>,
    // Optional worker-count override.
    #[arg(long = "max-workers", global = true, value_name = "COUNT")]
    pub max_workers: Option<usize>,
    // Run only the API component.
    #[arg(long, global = true)]
    pub api_only: bool,
    // Run only the CDC component (exists only with the `cdc` feature).
    #[cfg(feature = "cdc")]
    #[arg(long, global = true)]
    pub cdc_only: bool,
    // Optional subcommand; `None` is handled by the caller (not visible in
    // this file — presumably defaults to serving).
    #[command(subcommand)]
    pub command: Option<Command>,
}
// All top-level subcommands; argument-carrying variants delegate to the
// Args/Subcommand structs defined below.
#[derive(Subcommand)]
pub enum Command {
    // Run the API server.
    Server,
    // Execute a pipeline for a client via the in-process handler.
    Pipeline(PipelineArgs),
    // Manage the on-disk client registry.
    Clients {
        #[command(subcommand)]
        command: ClientCommand,
    },
    // POST a JSON body to the gateway fetch endpoint as a stored client.
    Fetch(FetchArgs),
    // Inspect connection pools on a running server via /admin/pools.
    Pools(PoolsArgs),
    // Change-data-capture tooling (only with the `cdc` feature).
    #[cfg(feature = "cdc")]
    Cdc {
        #[command(subcommand)]
        command: CdcCommand,
    },
    // Run provisioning SQL against a database URI.
    Provision(ProvisionArgs),
    // Print a JSON host-diagnostics report.
    Diag,
    // Provision a host (apt packages, PG tool symlinks, systemd drop-in).
    Install(InstallArgs),
    // Print the crate version.
    Version,
}
// Options for `athena_rs install` (host provisioning).
#[derive(Args, Debug, Clone)]
pub struct InstallArgs {
    // Root directory that receives per-major pg_dump/pg_restore symlinks.
    #[arg(long, value_name = "PATH", default_value = "/opt/athena/pg_tools")]
    pub pg_tools_dir: PathBuf,
    // PostgreSQL major versions to install clients and symlinks for.
    #[arg(long, value_delimiter = ',', default_value = "14,15,16,17,18")]
    pub pg_majors: Vec<u16>,
    // systemd unit that receives the ATHENA_PG_TOOLS_DIR drop-in.
    #[arg(long, value_name = "UNIT", default_value = "athena.service")]
    pub service: String,
    // Restart the unit after installation.
    #[arg(long, default_value_t = false)]
    pub restart: bool,
    // Rename pm2 packagecloud apt sources to *.disabled before apt runs.
    #[arg(long, default_value_t = false)]
    pub disable_pm2_packagecloud: bool,
    // Record every step without executing anything.
    #[arg(long, default_value_t = false)]
    pub dry_run: bool,
}
// Options for `athena_rs provision`.
#[derive(Args)]
pub struct ProvisionArgs {
    // Target database URI the provisioning SQL runs against.
    #[arg(long, value_name = "URI")]
    pub uri: String,
    // Optional name under which to register the provisioned database.
    #[arg(long, value_name = "NAME")]
    pub client_name: Option<String>,
    // When set, also save the URI into the clients registry file.
    #[arg(long, default_value_t = false)]
    pub register: bool,
    // Location of the clients registry file.
    #[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
    pub clients_path: PathBuf,
}
// Options for `athena_rs pipeline`. The payload may come inline or from a
// file (mutually exclusive via `conflicts_with`); `--pipeline` is merged
// into the payload object by the command handler.
#[derive(Args)]
pub struct PipelineArgs {
    // Client name, sent as the X-Athena-Client header.
    #[arg(short, long, value_name = "CLIENT")]
    pub client: String,
    // Inline JSON payload (conflicts with --payload-file).
    #[arg(long, conflicts_with = "payload_file")]
    pub payload_json: Option<String>,
    // Path to a JSON payload file (conflicts with --payload-json).
    #[arg(long, value_name = "PATH", conflicts_with = "payload_json")]
    pub payload_file: Option<PathBuf>,
    // Pipeline name, injected into the payload as "pipeline".
    #[arg(long, value_name = "NAME")]
    pub pipeline: Option<String>,
}
// Subcommands for managing the JSON client registry; every variant takes
// the registry path (defaulting to DEFAULT_CLIENTS_PATH).
#[derive(Subcommand)]
pub enum ClientCommand {
    // Print all clients plus the current default.
    List {
        #[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
        path: PathBuf,
    },
    // Add (or overwrite) a name -> URI mapping; the first-ever add also
    // becomes the default.
    Add {
        #[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
        path: PathBuf,
        name: String,
        uri: String,
    },
    // Delete a client; clears the default if it pointed at the removed one.
    Remove {
        #[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
        path: PathBuf,
        name: String,
    },
    // Mark an existing client as the default.
    SetDefault {
        #[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
        path: PathBuf,
        name: String,
    },
}
// Options for `athena_rs fetch` (gateway POST helper).
#[derive(Args)]
pub struct FetchArgs {
    // Clients registry used to resolve the client name or default.
    #[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
    pub clients_path: PathBuf,
    // Client to act as; falls back to the registry default when omitted.
    #[arg(long)]
    pub client: Option<String>,
    // Gateway endpoint to POST to.
    #[arg(long, value_name = "URL", default_value = DEFAULT_FETCH_URL)]
    pub url: String,
    // Inline JSON body (conflicts with --body-file).
    #[arg(long, conflicts_with = "body_file")]
    pub body_json: Option<String>,
    // Path to a JSON body file (conflicts with --body-json).
    #[arg(long, value_name = "PATH", conflicts_with = "body_json")]
    pub body_file: Option<PathBuf>,
}
// Options for `athena_rs pools` (/admin/pools inspection).
#[derive(Args)]
pub struct PoolsArgs {
    // Base URL of the running Athena server.
    #[arg(long, value_name = "URL", default_value = "http://localhost:4052")]
    pub url: String,
    // Admin key; falls back to the ATHENA_ADMIN_KEY env var when omitted.
    #[arg(long, value_name = "KEY")]
    pub admin_key: Option<String>,
}
// CDC subcommands (compiled only with the `cdc` feature).
#[cfg(feature = "cdc")]
#[derive(Subcommand)]
pub enum CdcCommand {
    // One-shot replay of a Sequin CSV export.
    Backfill(CdcBackfillArgs),
    // Continuous polling of the sequin events table.
    Stream(CdcStreamArgs),
}
// Options for `athena_rs cdc backfill`.
#[cfg(feature = "cdc")]
#[derive(Args)]
pub struct CdcBackfillArgs {
    // Sequin CSV export to replay.
    #[arg(long, value_name = "PATH", default_value = "src/cdc/sequin_events.csv")]
    pub csv_path: PathBuf,
    // Optional per-table CDC configuration file.
    #[arg(long, value_name = "PATH")]
    pub table_config: Option<PathBuf>,
    // File tracking backfill progress (last processed sequence).
    #[arg(long, value_name = "PATH", default_value = "cdc/state.json")]
    pub state_file: PathBuf,
    // Target backend name (see parse_backend_type; unknown -> Native).
    #[arg(long)]
    pub backend: Option<String>,
    // Target backend URL.
    #[arg(long, value_name = "URL")]
    pub url: String,
    // Optional API key for the target backend.
    #[arg(long, value_name = "KEY")]
    pub key: Option<String>,
    // Process events without writing to the target.
    #[arg(long)]
    pub dry_run: bool,
    // Optional audit-log backend name (defaults to PostgreSQL when a URL
    // is available).
    #[arg(long)]
    pub audit_backend: Option<String>,
    // Optional audit-log URL; falls back to the configured athena_logging URI.
    #[arg(long, value_name = "URL")]
    pub audit_url: Option<String>,
    // Optional audit-log API key.
    #[arg(long, value_name = "KEY")]
    pub audit_key: Option<String>,
}
// Options for `athena_rs cdc stream`. Mirrors CdcBackfillArgs plus the
// polling parameters for continuous streaming.
#[cfg(feature = "cdc")]
#[derive(Args)]
pub struct CdcStreamArgs {
    // Optional per-table CDC configuration file.
    #[arg(long, value_name = "PATH")]
    pub table_config: Option<PathBuf>,
    // File tracking stream progress.
    #[arg(long, value_name = "PATH", default_value = "cdc/state.json")]
    pub state_file: PathBuf,
    // Target backend name (see parse_backend_type; unknown -> Native).
    #[arg(long)]
    pub backend: Option<String>,
    // Target backend URL.
    #[arg(long, value_name = "URL")]
    pub url: String,
    // Optional API key for the target backend.
    #[arg(long, value_name = "KEY")]
    pub key: Option<String>,
    // Process events without writing to the target.
    #[arg(long)]
    pub dry_run: bool,
    // Optional audit-log backend name.
    #[arg(long)]
    pub audit_backend: Option<String>,
    // Optional audit-log URL; falls back to the configured athena_logging URI.
    #[arg(long, value_name = "URL")]
    pub audit_url: Option<String>,
    // Optional audit-log API key.
    #[arg(long, value_name = "KEY")]
    pub audit_key: Option<String>,
    // Source table holding sequin events.
    #[arg(long, default_value = "public.sequin_events")]
    pub sequin_table: String,
    // Events fetched per poll (clamped to >= 1 by the handler).
    #[arg(long, default_value = "512")]
    pub batch_size: usize,
    // Seconds between polls (clamped to >= 1 by the handler).
    #[arg(long, default_value = "5")]
    pub poll_interval_secs: u64,
}
/// On-disk registry of named client connection URIs, persisted as pretty
/// JSON (default location: DEFAULT_CLIENTS_PATH).
#[derive(Debug, Deserialize, Serialize, Default)]
struct ClientStore {
    // Name of the default client, if one has been chosen; when set it is
    // expected to be a key of `clients`.
    default: Option<String>,
    // Map of client name -> connection URI.
    clients: HashMap<String, String>,
}
impl ClientStore {
    /// Read the registry from `path`. A missing file is not an error — it
    /// simply yields an empty store.
    fn load(path: &Path) -> Result<Self> {
        if !path.exists() {
            return Ok(Self::default());
        }
        let raw: String = fs::read_to_string(path).context("reading clients file")?;
        serde_json::from_str(&raw).context("parsing clients file as JSON")
    }
    /// Write the registry to `path` as pretty JSON (with a trailing
    /// newline), creating parent directories as needed.
    fn persist(&self, path: &Path) -> Result<()> {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).context("creating clients directory")?;
        }
        let mut handle: File = File::create(path).context("creating clients file")?;
        serde_json::to_writer_pretty(&mut handle, self).context("serializing clients file")?;
        handle
            .write_all(b"\n")
            .context("writing newline to clients file")?;
        Ok(())
    }
}
/// Run a pipeline for `--client` by invoking the HTTP `run_pipeline`
/// handler in-process through actix-web's test harness (no network
/// listener needed), then print the response body — pretty-printed when it
/// parses as JSON. Failures are reported on stderr with the HTTP status.
pub async fn run_pipeline_command(bootstrap: &Bootstrap, args: PipelineArgs) -> Result<()> {
    // Default to an empty JSON object when no payload was supplied.
    let mut body: Value = match parse_json_input(args.payload_json, args.payload_file)? {
        Some(value) => value,
        None => json!({}),
    };
    if !body.is_object() {
        bail!("pipeline payload must be a JSON object");
    }
    // `--pipeline` is merged into the payload under the "pipeline" key.
    if let Some(name) = args.pipeline.as_ref() {
        body["pipeline"] = json!(name);
    }
    // Reject a request with no payload and no pipeline name — nothing to run.
    if body.as_object().is_none_or(|map| map.is_empty()) && args.pipeline.is_none() {
        bail!("provide either --payload-json/--payload-file or --pipeline");
    }
    // Build a throwaway in-process service hosting only the pipeline route,
    // sharing the bootstrap's application state with the real server.
    let service = init_service(
        App::new()
            .app_data(bootstrap.app_state.clone())
            .service(run_pipeline),
    )
    .await;
    let request = TestRequest::post()
        .uri("/pipelines")
        .insert_header(("X-Athena-Client", args.client.as_str()))
        .set_json(&body)
        .to_request();
    let response: ServiceResponse = call_service(&service, request).await;
    let status: ActixStatusCode = response.status();
    let body_bytes: Bytes = to_bytes(response.into_body())
        .await
        .map_err(|err| anyhow!("reading pipeline response body: {}", err))?;
    let body_text = String::from_utf8_lossy(&body_bytes);
    if status.is_success() {
        // Pretty-print JSON responses; fall back to raw text otherwise.
        if let Ok(parsed) = serde_json::from_slice::<Value>(&body_bytes) {
            println!("{}", serde_json::to_string_pretty(&parsed)?);
        } else {
            println!("{}", body_text);
        }
    } else {
        eprintln!("pipeline request failed ({})", status);
        eprintln!("{}", body_text);
    }
    Ok(())
}
/// Handle the `clients` subcommands against the JSON registry file: load
/// the store, apply the mutation (List only reads), persist, and print a
/// colored confirmation.
pub fn run_clients_command(command: ClientCommand) -> Result<()> {
    match command {
        ClientCommand::List { path } => {
            let store: ClientStore = ClientStore::load(&path)?;
            println!("{}", "Clients".bold().underline());
            println!(
                "{} {}",
                "Clients file:".bright_cyan(),
                path.display().to_string().bright_white()
            );
            match store.default.as_deref() {
                Some(default) => println!(
                    "{} {}",
                    "Default client:".bright_cyan(),
                    default.bright_green().bold()
                ),
                None => println!("{} {}", "Default client:".bright_cyan(), "none".dimmed()),
            }
            // Sort names for stable output (HashMap iteration order is random).
            let mut names: Vec<&String> = store.clients.keys().collect();
            names.sort();
            for name in names {
                if let Some(uri) = store.clients.get(name) {
                    // Tag the default entry inline.
                    let name_display = if store.default.as_deref() == Some(name.as_str()) {
                        format!("{} (default)", name)
                    } else {
                        name.to_string()
                    };
                    println!(
                        " {} {} {}",
                        name_display.yellow().bold(),
                        "->".dimmed(),
                        uri.cyan()
                    );
                }
            }
            Ok(())
        }
        ClientCommand::Add { path, name, uri } => {
            let mut store: ClientStore = ClientStore::load(&path)?;
            store.clients.insert(name.clone(), uri.clone());
            // The very first client added becomes the default.
            if store.default.is_none() {
                store.default = Some(name.clone());
            }
            store.persist(&path)?;
            println!(
                "{} '{}' {} {} in {}",
                "Saved client".bright_green().bold(),
                name.bright_yellow(),
                "->".dimmed(),
                uri.cyan(),
                path.display().to_string().bright_cyan()
            );
            Ok(())
        }
        ClientCommand::Remove { path, name } => {
            let mut store: ClientStore = ClientStore::load(&path)?;
            if store.clients.remove(&name).is_none() {
                bail!("client '{}' not found", name);
            }
            // Removing the default leaves the store with no default set.
            if store.default.as_deref() == Some(name.as_str()) {
                store.default = None;
            }
            store.persist(&path)?;
            println!(
                "{} {} {} {}",
                "Removed client".bright_red().bold(),
                format!("'{}'", name).bright_yellow(),
                "from".white(),
                path.display().to_string().bright_cyan()
            );
            Ok(())
        }
        ClientCommand::SetDefault { path, name } => {
            let mut store: ClientStore = ClientStore::load(&path)?;
            // Only existing clients may become the default.
            if !store.clients.contains_key(&name) {
                bail!("client '{}' is not defined", name);
            }
            store.default = Some(name.clone());
            store.persist(&path)?;
            println!(
                "{} {} {} {} {}",
                "Default client set to".bright_green().bold(),
                format!("'{}'", name).bright_yellow(),
                "in".white(),
                path.display().to_string().bright_cyan(),
                "✅".bright_green()
            );
            Ok(())
        }
    }
}
/// POST a JSON body to the gateway fetch endpoint on behalf of a stored
/// client and print the response (pretty-printed when it parses as JSON).
/// The client comes from `--client` (must exist in the registry) or the
/// registry default.
pub async fn run_fetch_command(args: FetchArgs) -> Result<()> {
    let payload: Value =
        parse_json_input(args.body_json.clone(), args.body_file.clone())?.context("empty body")?;
    let store = ClientStore::load(&args.clients_path)?;
    let client_name: String = match args.client {
        Some(name) => {
            if !store.clients.contains_key(&name) {
                bail!(
                    "client '{}' not found in {}",
                    name,
                    args.clients_path.display()
                );
            }
            name
        }
        None => store
            .default
            .clone()
            .ok_or_else(|| anyhow!("no client provided and no default set"))?,
    };
    let resp: Response = HttpClient::new()
        .post(&args.url)
        .header("Content-Type", "application/json")
        .header("X-Athena-Client", client_name.clone())
        .json(&payload)
        .send()
        .await
        .context("failed to execute fetch request")?;
    let status: StatusCode = resp.status();
    let text: String = resp.text().await.context("reading fetch response")?;
    if status.is_success() {
        match serde_json::from_str::<Value>(&text) {
            Ok(parsed) => println!("{}", serde_json::to_string_pretty(&parsed)?),
            Err(_) => println!("{}", text),
        }
    } else {
        eprintln!("fetch failed ({})", status);
        eprintln!("{}", text);
    }
    Ok(())
}
/// Query `/admin/pools` on a running Athena server and print the pool
/// report (pretty-printed when it parses as JSON). The admin key comes
/// from `--admin-key` or, failing that, the ATHENA_ADMIN_KEY env var.
pub async fn run_pools_command(args: PoolsArgs) -> Result<()> {
    let admin_key = match args.admin_key.clone() {
        Some(key) => key,
        None => std::env::var("ATHENA_ADMIN_KEY")
            .map_err(|_| anyhow!("missing admin key: pass --admin-key or set ATHENA_ADMIN_KEY"))?,
    };
    let endpoint = format!("{}/admin/pools", args.url.trim_end_matches('/'));
    let resp = HttpClient::new()
        .get(&endpoint)
        .header("X-Athena-Admin-Key", admin_key)
        .send()
        .await
        .context("failed to call /admin/pools")?;
    let status = resp.status();
    let text = resp.text().await.context("reading /admin/pools response")?;
    if !status.is_success() {
        bail!("/admin/pools failed ({}): {}", status, text);
    }
    match serde_json::from_str::<Value>(&text) {
        Ok(parsed) => println!("{}", serde_json::to_string_pretty(&parsed)?),
        Err(_) => println!("{}", text),
    }
    Ok(())
}
/// Dispatch a CDC subcommand to its handler (compiled only with the `cdc`
/// feature).
#[cfg(feature = "cdc")]
pub async fn run_cdc_command(config: &Config, command: CdcCommand) -> Result<()> {
    match command {
        CdcCommand::Backfill(args) => run_cdc_backfill(config, args).await,
        CdcCommand::Stream(args) => run_cdc_stream(config, args).await,
    }
}
/// One-shot CDC backfill: replay events from a Sequin CSV export into the
/// target backend, tracking progress in `--state-file` and printing the
/// final sequence number reached.
#[cfg(feature = "cdc")]
async fn run_cdc_backfill(config: &Config, args: CdcBackfillArgs) -> Result<()> {
    let configs: HashMap<String, crate::cdc::postgres::CdcTableConfig> =
        load_table_configs(args.table_config.as_deref())?;
    let client: AthenaClient =
        build_cdc_client(args.backend.as_deref(), &args.url, args.key.as_deref()).await?;
    // Optional audit trail; `None` when no audit URL is configured anywhere.
    let audit_logger = build_audit_logger(
        config,
        args.audit_backend.as_deref(),
        args.audit_url.as_deref(),
        args.audit_key.as_deref(),
    )
    .await?;
    let state = backfill_from_csv(
        &client,
        &args.csv_path,
        &configs,
        args.state_file.as_path(),
        args.dry_run,
        audit_logger.as_ref(),
    )
    .await?;
    println!(
        "Backfill finished through seq {}",
        state.last_seq.unwrap_or(0)
    );
    Ok(())
}
/// Continuous CDC streaming: poll the sequin events table at a fixed
/// interval, applying batches to the target backend. Batch size and poll
/// interval are clamped to at least 1 to avoid a busy loop / empty fetch.
#[cfg(feature = "cdc")]
async fn run_cdc_stream(config: &Config, args: CdcStreamArgs) -> Result<()> {
    let configs: HashMap<String, crate::cdc::postgres::CdcTableConfig> =
        load_table_configs(args.table_config.as_deref())?;
    let client: AthenaClient =
        build_cdc_client(args.backend.as_deref(), &args.url, args.key.as_deref()).await?;
    // Optional audit trail; `None` when no audit URL is configured anywhere.
    let audit_logger = build_audit_logger(
        config,
        args.audit_backend.as_deref(),
        args.audit_url.as_deref(),
        args.audit_key.as_deref(),
    )
    .await?;
    let interval = Duration::from_secs(args.poll_interval_secs.max(1));
    let limit = args.batch_size.max(1);
    println!(
        "Starting CDC stream against {} (batch {}, interval {}s)...",
        args.sequin_table,
        limit,
        interval.as_secs()
    );
    stream_events(
        &client,
        &configs,
        &args.sequin_table,
        limit,
        interval,
        args.state_file.as_path(),
        args.dry_run,
        audit_logger.as_ref(),
    )
    .await
}
/// Build an `AthenaClient` for CDC work from CLI-style string arguments
/// (backend name is resolved via `parse_backend_type`).
#[cfg(feature = "cdc")]
async fn build_cdc_client(
    backend: Option<&str>,
    url: &str,
    key: Option<&str>,
) -> Result<AthenaClient> {
    let backend = parse_backend_type(backend);
    build_client(backend, url, key).await
}
/// Construct an `AthenaClient` against `url`, attaching the API key only
/// when one was provided. Build errors are flattened into anyhow errors.
#[cfg(feature = "cdc")]
async fn build_client(backend: BackendType, url: &str, key: Option<&str>) -> Result<AthenaClient> {
    let mut builder = AthenaClient::builder().backend(backend).url(url);
    if let Some(key) = key {
        builder = builder.key(key);
    }
    AthenaClient::build(builder)
        .await
        .map_err(|err| anyhow!("failed to build Athena client: {}", err))
}
/// Build the optional CDC audit logger. The audit URL is the CLI
/// `--audit-url` when given, otherwise the configured `athena_logging`
/// Postgres URI; when neither exists, auditing is disabled (`Ok(None)`).
#[cfg(feature = "cdc")]
async fn build_audit_logger(
    config: &Config,
    backend: Option<&str>,
    url: Option<&str>,
    key: Option<&str>,
) -> Result<Option<AuditLogger>> {
    let fallback_url = config.get_postgres_uri("athena_logging").cloned();
    let target_url = url.map(|value| value.to_string()).or(fallback_url);
    if let Some(url) = target_url {
        // Unlike parse_backend_type (which defaults to Native), an
        // unspecified audit backend defaults to PostgreSQL, matching the
        // Postgres-URI fallback above.
        let backend = backend
            .map(|value| parse_backend_type(Some(value)))
            .unwrap_or(BackendType::PostgreSQL);
        let client = build_client(backend, &url, key).await?;
        Ok(Some(AuditLogger::new(client, "cdc", "cdc")))
    } else {
        Ok(None)
    }
}
/// Map a CLI backend name to a `BackendType`. Matching is case-insensitive;
/// `None` and any unrecognized name fall back to `BackendType::Native`.
#[cfg(feature = "cdc")]
fn parse_backend_type(name: Option<&str>) -> BackendType {
    match name.unwrap_or("native").to_lowercase().as_str() {
        "supabase" => BackendType::Supabase,
        "postgrest" => BackendType::Postgrest,
        "scylla" => BackendType::Scylla,
        "neon" => BackendType::Neon,
        "postgresql" | "postgres" => BackendType::PostgreSQL,
        _ => BackendType::Native,
    }
}
/// Run the provisioning SQL against `--uri`, then (with `--register`) save
/// the URI into the clients registry, or otherwise print a hint on how to
/// register it.
pub async fn run_provision_command(args: ProvisionArgs) -> Result<()> {
    println!(
        "{} {}",
        "Provisioning".bright_cyan().bold(),
        // Print only the portion after the last '@' — presumably to avoid
        // echoing credentials embedded in the URI (TODO confirm).
        args.uri
            .split('@')
            .last()
            .unwrap_or(&args.uri)
            .bright_white()
    );
    // Flatten every provisioning error variant into a plain anyhow message.
    let total = run_provision_sql(&args.uri)
        .await
        .map_err(|err| match err {
            ProvisioningError::InvalidInput(message)
            | ProvisioningError::Conflict(message)
            | ProvisioningError::Unavailable(message)
            | ProvisioningError::Execution(message) => anyhow!(message),
        })?;
    println!(
        "{} Applied {} statements.",
        "✔".bright_green().bold(),
        total
    );
    if args.register {
        let client_name = args
            .client_name
            .clone()
            .unwrap_or_else(|| "provisioned".to_string());
        // A brand-new registry starts with this client as the default.
        let mut store: ClientStore = if args.clients_path.exists() {
            ClientStore::load(&args.clients_path)?
        } else {
            ClientStore {
                default: Some(client_name.clone()),
                clients: HashMap::new(),
            }
        };
        store.clients.insert(client_name.clone(), args.uri.clone());
        if store.default.is_none() {
            store.default = Some(client_name.clone());
        }
        store.persist(&args.clients_path)?;
        println!(
            "{} Registered client '{}' in {}",
            "✔".bright_green().bold(),
            client_name.bright_white(),
            args.clients_path.display()
        );
    } else if let Some(name) = args.client_name.as_ref() {
        println!(
            "{} To register this client run: athena_rs provision --uri <URI> --client-name {} --register",
            "ℹ".bright_cyan(),
            name
        );
    }
    Ok(())
}
/// Print a JSON diagnostic report: host metadata, memory/CPU usage,
/// expected path checks, listening ports, systemd service state, required
/// CLI tools, the per-major PostgreSQL tool matrix, and the ATHENA_* env
/// vars that influence pg tool resolution.
pub fn run_diag_command() -> Result<()> {
    // Resolve the home dir on unix (HOME) or Windows (USERPROFILE).
    let home_dir = std::env::var("HOME")
        .or_else(|_| std::env::var("USERPROFILE"))
        .unwrap_or_default();
    // Expected checkout location of the athena repository.
    let repo_expected_path = if home_dir.is_empty() {
        PathBuf::from("github/xylex-group/athena")
    } else {
        PathBuf::from(home_dir)
            .join("github")
            .join("xylex-group")
            .join("athena")
    };
    let mut folder_checks = vec![
        path_check("/var/log/athena", Path::new("/var/log/athena")),
        path_check("/etc/athena", Path::new("/etc/athena")),
        path_check("$HOME/github/xylex-group/athena", &repo_expected_path),
    ];
    // Also probe every supported PostgreSQL major's bin directory.
    folder_checks.extend((14..=18).map(|major| {
        path_check(
            &format!("/usr/lib/postgresql/{major}/bin"),
            Path::new(&format!("/usr/lib/postgresql/{major}/bin")),
        )
    }));
    let mem = memory_summary();
    let cpu = cpu_summary();
    let required_tools = required_tool_checks();
    let pg_tools = postgres_tool_matrix();
    let ports = listening_ports_summary();
    let service_status = systemd_service_status("athena.service");
    let info: Value = json!({
        "meta": {
            "athena_version": env!("CARGO_PKG_VERSION"),
            "os": std::env::consts::OS,
            "arch": std::env::consts::ARCH,
            "family": std::env::consts::FAMILY,
            "hostname": resolve_hostname(),
            "locale": std::env::var("LANG").unwrap_or_default(),
            "cpu_logical_count": num_cpus::get(),
        },
        "resources": {
            "memory": mem,
            "cpu": cpu,
        },
        "paths": folder_checks,
        "ports": ports,
        "service": service_status,
        "tools": required_tools,
        "postgres": pg_tools,
        "env": {
            "ATHENA_PG_TOOLS_DIR": std::env::var("ATHENA_PG_TOOLS_DIR").ok(),
            "ATHENA_PG_DUMP_PATH": std::env::var("ATHENA_PG_DUMP_PATH").ok(),
            "ATHENA_PG_RESTORE_PATH": std::env::var("ATHENA_PG_RESTORE_PATH").ok(),
        }
    });
    println!("{}", serde_json::to_string_pretty(&info)?);
    Ok(())
}
/// Provision a host for Athena: apt packages, the PGDG apt repository,
/// per-major PostgreSQL client symlinks under `--pg-tools-dir`, and a
/// systemd drop-in exporting ATHENA_PG_TOOLS_DIR for `--service`. Emits a
/// JSON action log on success.
///
/// Requires root unless `--dry-run`; with `--dry-run` every step is
/// recorded in the action log but nothing is executed.
pub fn run_install_command(args: InstallArgs) -> Result<()> {
    if std::env::consts::FAMILY != "unix" {
        bail!("install command currently supports Linux/Unix hosts only");
    }
    if !args.dry_run && !is_running_as_root() {
        bail!("install requires root privileges (run with sudo or use --dry-run)");
    }
    // Accumulated JSON log of every step taken (or skipped in dry-run).
    let mut actions: Vec<Value> = Vec::new();
    // Fall back to the full supported major range when none were given.
    let majors: Vec<u16> = if args.pg_majors.is_empty() {
        vec![14, 15, 16, 17, 18]
    } else {
        args.pg_majors.clone()
    };
    if args.disable_pm2_packagecloud {
        // Rename any pm2 apt source to *.disabled so apt stops reading it.
        let mut disabled = Vec::new();
        let sources_dir = Path::new("/etc/apt/sources.list.d");
        if sources_dir.exists() {
            for entry in fs::read_dir(sources_dir).context("reading apt sources list directory")? {
                let entry = entry?;
                let p = entry.path();
                let name = p
                    .file_name()
                    .and_then(|n| n.to_str())
                    .unwrap_or_default()
                    .to_lowercase();
                if name.contains("pm2") {
                    let disabled_path = PathBuf::from(format!("{}.disabled", p.to_string_lossy()));
                    if args.dry_run {
                        disabled.push(format!(
                            "would move {} -> {}",
                            p.display(),
                            disabled_path.display()
                        ));
                    } else {
                        fs::rename(&p, &disabled_path)
                            .with_context(|| format!("disabling pm2 source {}", p.display()))?;
                        disabled.push(format!(
                            "moved {} -> {}",
                            p.display(),
                            disabled_path.display()
                        ));
                    }
                }
            }
        }
        actions.push(json!({"step": "disable_pm2_packagecloud", "details": disabled}));
    }
    // Refresh package lists, then install the tooling needed to add the
    // PGDG repository.
    run_command(
        args.dry_run,
        "apt-get",
        &["update"],
        &mut actions,
        "apt_update_initial",
    )?;
    run_command(
        args.dry_run,
        "apt-get",
        &[
            "install",
            "-y",
            "--no-install-recommends",
            "ca-certificates",
            "curl",
            "gnupg",
            "lsb-release",
        ],
        &mut actions,
        "apt_install_base",
    )?;
    // Install the PostgreSQL (PGDG) signing key into /etc/apt/keyrings.
    ensure_dir(
        args.dry_run,
        Path::new("/etc/apt/keyrings"),
        &mut actions,
        "create_keyrings",
    )?;
    run_command(
        args.dry_run,
        "curl",
        &[
            "-fsSL",
            "https://www.postgresql.org/media/keys/ACCC4CF8.asc",
            "-o",
            "/tmp/athena-pgdg-key.asc",
        ],
        &mut actions,
        "download_pgdg_key",
    )?;
    run_command(
        args.dry_run,
        "gpg",
        &[
            "--dearmor",
            "-o",
            "/etc/apt/keyrings/postgresql.gpg",
            "/tmp/athena-pgdg-key.asc",
        ],
        &mut actions,
        "install_pgdg_key",
    )?;
    // Register the PGDG apt repository for this distro codename and refresh.
    let codename = os_release_codename().unwrap_or_else(|| "bookworm".to_string());
    let pgdg_line = format!(
        "deb [signed-by=/etc/apt/keyrings/postgresql.gpg] http://apt.postgresql.org/pub/repos/apt {}-pgdg main\n",
        codename
    );
    write_file(
        args.dry_run,
        Path::new("/etc/apt/sources.list.d/pgdg.list"),
        &pgdg_line,
        &mut actions,
        "write_pgdg_repo",
    )?;
    run_command(
        args.dry_run,
        "apt-get",
        &["update"],
        &mut actions,
        "apt_update_after_pgdg",
    )?;
    // Install baseline CLI tools plus a postgresql-client-<major> for every
    // requested major version.
    let mut package_args: Vec<String> = vec![
        "install".to_string(),
        "-y".to_string(),
        "--no-install-recommends".to_string(),
        "postgresql-common".to_string(),
        "s3cmd".to_string(),
        "nginx".to_string(),
        "docker.io".to_string(),
        "curl".to_string(),
        "lsb-release".to_string(),
    ];
    for major in &majors {
        package_args.push(format!("postgresql-client-{major}"));
    }
    let package_arg_refs: Vec<&str> = package_args.iter().map(String::as_str).collect();
    run_command(
        args.dry_run,
        "apt-get",
        &package_arg_refs,
        &mut actions,
        "apt_install_required_cli_and_pg_clients",
    )?;
    // Mirror each major's pg_dump/pg_restore into <pg_tools_dir>/<major>/bin
    // via symlinks.
    ensure_dir(
        args.dry_run,
        &args.pg_tools_dir,
        &mut actions,
        "create_pg_tools_root",
    )?;
    for major in &majors {
        let base = PathBuf::from(format!("/usr/lib/postgresql/{major}/bin"));
        let dump = base.join("pg_dump");
        let restore = base.join("pg_restore");
        // In a real run the binaries must exist before we link to them;
        // dry-run skips the check since nothing was installed.
        if !args.dry_run {
            if !dump.exists() || !restore.exists() {
                bail!(
                    "missing pg_dump/pg_restore for major {} under {}",
                    major,
                    base.display()
                );
            }
        }
        let target_bin = args.pg_tools_dir.join(major.to_string()).join("bin");
        ensure_dir(
            args.dry_run,
            &target_bin,
            &mut actions,
            &format!("create_pg_tools_bin_dir_{major}"),
        )?;
        symlink_force(
            args.dry_run,
            &dump,
            &target_bin.join("pg_dump"),
            &mut actions,
            &format!("link_pg_dump_{major}"),
        )?;
        symlink_force(
            args.dry_run,
            &restore,
            &target_bin.join("pg_restore"),
            &mut actions,
            &format!("link_pg_restore_{major}"),
        )?;
    }
    // Point the systemd unit at the tools directory via a drop-in, then
    // reload (and optionally restart) the unit.
    let dropin_dir = Path::new("/etc/systemd/system").join(format!("{}.d", args.service));
    ensure_dir(
        args.dry_run,
        &dropin_dir,
        &mut actions,
        "create_systemd_dropin_dir",
    )?;
    let dropin_path = dropin_dir.join("pg-tools.conf");
    let dropin_contents = format!(
        "[Service]\nEnvironment=\"ATHENA_PG_TOOLS_DIR={}\"\nEnvironment=\"ATHENA_PG_TOOLS_ALLOW_DOWNLOAD=0\"\n",
        args.pg_tools_dir.display()
    );
    write_file(
        args.dry_run,
        &dropin_path,
        &dropin_contents,
        &mut actions,
        "write_systemd_dropin",
    )?;
    run_command(
        args.dry_run,
        "systemctl",
        &["daemon-reload"],
        &mut actions,
        "systemd_daemon_reload",
    )?;
    if args.restart {
        run_command(
            args.dry_run,
            "systemctl",
            &["restart", &args.service],
            &mut actions,
            "systemd_restart_service",
        )?;
    }
    // Final machine-readable summary of everything that happened.
    println!(
        "{}",
        serde_json::to_string_pretty(&json!({
            "status": "ok",
            "dry_run": args.dry_run,
            "service": args.service,
            "pg_majors": majors,
            "pg_tools_dir": args.pg_tools_dir,
            "actions": actions,
        }))?
    );
    Ok(())
}
/// Best-effort root check: shells out to `id -u` and treats any failure
/// (spawn error or non-zero exit) as "not root".
fn is_running_as_root() -> bool {
    let Ok(out) = ProcessCommand::new("id").arg("-u").output() else {
        return false;
    };
    out.status.success() && String::from_utf8_lossy(&out.stdout).trim() == "0"
}
/// Execute `program args...`, recording the outcome in `actions` under
/// `step`. Dry-run only records; a non-zero exit becomes an error carrying
/// the command line and trimmed stderr.
fn run_command(
    dry_run: bool,
    program: &str,
    args: &[&str],
    actions: &mut Vec<Value>,
    step: &str,
) -> Result<()> {
    // Render the full command line once; every message below reuses it.
    let rendered = format!("{} {}", program, args.join(" "));
    if dry_run {
        actions.push(json!({"step": step, "command": rendered, "result": "skipped (dry-run)"}));
        return Ok(());
    }
    let output = ProcessCommand::new(program)
        .args(args)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .output()
        .with_context(|| format!("running command: {}", rendered))?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
        bail!("step '{}' failed: {}: {}", step, rendered, stderr);
    }
    actions.push(json!({"step": step, "command": rendered, "result": "ok"}));
    Ok(())
}
fn ensure_dir(dry_run: bool, path: &Path, actions: &mut Vec<Value>, step: &str) -> Result<()> {
if dry_run {
actions.push(json!({"step": step, "path": path, "result": "skipped (dry-run)"}));
return Ok(());
}
fs::create_dir_all(path).with_context(|| format!("creating directory {}", path.display()))?;
actions.push(json!({"step": step, "path": path, "result": "ok"}));
Ok(())
}
/// Write `contents` to `path`, logging the step into `actions`; dry-run
/// only records the intent.
fn write_file(
    dry_run: bool,
    path: &Path,
    contents: &str,
    actions: &mut Vec<Value>,
    step: &str,
) -> Result<()> {
    if dry_run {
        actions.push(json!({"step": step, "path": path, "result": "skipped (dry-run)"}));
    } else {
        fs::write(path, contents).with_context(|| format!("writing {}", path.display()))?;
        actions.push(json!({"step": step, "path": path, "result": "ok"}));
    }
    Ok(())
}
/// Create (or replace) the symlink `target` -> `source`, logging the step
/// into `actions`. A pre-existing file at `target` is removed first; a
/// missing target is fine, any other removal error is propagated.
#[cfg(unix)]
fn symlink_force(
    dry_run: bool,
    source: &Path,
    target: &Path,
    actions: &mut Vec<Value>,
    step: &str,
) -> Result<()> {
    if dry_run {
        actions.push(json!({"step": step, "source": source, "target": target, "result": "skipped (dry-run)"}));
        return Ok(());
    }
    if let Err(err) = fs::remove_file(target) {
        if err.kind() != std::io::ErrorKind::NotFound {
            return Err(err).with_context(|| format!("removing {}", target.display()));
        }
    }
    std::os::unix::fs::symlink(source, target)
        .with_context(|| format!("symlinking {} -> {}", target.display(), source.display()))?;
    actions.push(json!({"step": step, "source": source, "target": target, "result": "ok"}));
    Ok(())
}
/// Non-unix stub: symlink provisioning is unsupported, so this always
/// errors (install already bails earlier on non-unix hosts).
#[cfg(not(unix))]
fn symlink_force(
    _dry_run: bool,
    _source: &Path,
    _target: &Path,
    _actions: &mut Vec<Value>,
    _step: &str,
) -> Result<()> {
    bail!("symlink setup is only supported on unix hosts")
}
/// Read the distro codename (e.g. "bookworm") from the VERSION_CODENAME
/// entry of /etc/os-release, stripping surrounding quotes. Returns `None`
/// when the file is unreadable or the key is absent.
fn os_release_codename() -> Option<String> {
    fs::read_to_string("/etc/os-release")
        .ok()?
        .lines()
        .find_map(|line| {
            line.strip_prefix("VERSION_CODENAME=")
                .map(|value| value.trim_matches('"').to_string())
        })
}
/// Resolve the host name: prefer the HOSTNAME env var, then the `hostname`
/// command's trimmed stdout, and finally an empty string.
fn resolve_hostname() -> String {
    if let Ok(name) = std::env::var("HOSTNAME") {
        return name;
    }
    ProcessCommand::new("hostname")
        .output()
        .ok()
        .filter(|o| o.status.success())
        .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
        .unwrap_or_default()
}
/// Probe a filesystem path and return a JSON report: existence, kind,
/// readability, writability, and (on unix) octal permissions plus
/// owner uid/gid.
fn path_check(label: &str, path: &Path) -> Value {
    let exists = path.exists();
    let meta = fs::metadata(path).ok();
    let is_dir = meta.as_ref().is_some_and(|m| m.is_dir());
    let is_file = meta.as_ref().is_some_and(|m| m.is_file());
    // Readability is tested empirically: list a directory, open a file.
    let readable = if !exists {
        false
    } else if is_dir {
        fs::read_dir(path).is_ok()
    } else {
        File::open(path).is_ok()
    };
    // Writability likewise: create a probe file in a directory, or open a
    // file for append (append avoids truncating existing content).
    let writable = if !exists {
        false
    } else if is_dir {
        can_write_dir(path)
    } else {
        OpenOptions::new().append(true).open(path).is_ok()
    };
    // Lower 12 permission bits (incl. setuid/setgid/sticky) in octal.
    #[cfg(unix)]
    let perms = meta
        .as_ref()
        .map(|m| format!("{:o}", m.permissions().mode() & 0o7777));
    #[cfg(not(unix))]
    let perms: Option<String> = None;
    #[cfg(unix)]
    let uid = meta.as_ref().map(MetadataExt::uid);
    #[cfg(unix)]
    let gid = meta.as_ref().map(MetadataExt::gid);
    #[cfg(not(unix))]
    let uid: Option<u32> = None;
    #[cfg(not(unix))]
    let gid: Option<u32> = None;
    json!({
        "label": label,
        "path": path,
        "exists": exists,
        "is_dir": is_dir,
        "is_file": is_file,
        "readable": readable,
        "writable": writable,
        "permissions_octal": perms,
        "uid": uid,
        "gid": gid,
    })
}
/// Empirically test whether `path` (a directory) is writable by creating
/// and immediately deleting a uniquely-named probe file. Always `false`
/// on non-unix hosts.
fn can_write_dir(path: &Path) -> bool {
    #[cfg(unix)]
    {
        // Millisecond timestamp keeps concurrent probes from colliding.
        let stamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_millis())
            .unwrap_or_default();
        let probe = path.join(format!(".athena_diag_write_probe_{}", stamp));
        if OpenOptions::new()
            .create_new(true)
            .write(true)
            .open(&probe)
            .is_ok()
        {
            // Best-effort cleanup; the answer is already known.
            let _ = fs::remove_file(&probe);
            return true;
        }
        false
    }
    #[cfg(not(unix))]
    {
        let _ = path;
        false
    }
}
fn memory_summary() -> Value {
if std::env::consts::OS == "linux" {
if let Ok(meminfo) = fs::read_to_string("/proc/meminfo") {
let mut total_kb: u64 = 0;
let mut avail_kb: u64 = 0;
for line in meminfo.lines() {
if let Some(value) = line.strip_prefix("MemTotal:") {
total_kb = value
.split_whitespace()
.next()
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(0);
} else if let Some(value) = line.strip_prefix("MemAvailable:") {
avail_kb = value
.split_whitespace()
.next()
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(0);
}
}
let used_kb = total_kb.saturating_sub(avail_kb);
return json!({
"total_bytes": total_kb * 1024,
"available_bytes": avail_kb * 1024,
"used_bytes": used_kb * 1024,
});
}
}
json!({"error": "memory metrics unavailable"})
}
fn cpu_summary() -> Value {
if std::env::consts::OS == "linux" {
if let Some((usage, window_ms)) = linux_cpu_usage_percent() {
return json!({
"usage_percent": usage,
"sample_window_ms": window_ms,
});
}
}
json!({"error": "cpu metrics unavailable"})
}
/// Sample aggregate CPU usage from /proc/stat over a fixed 250 ms window.
/// Returns `(usage_percent, window_ms)`, or `None` when /proc/stat cannot
/// be read or no jiffies elapsed between samples. Blocks the calling
/// thread for the window — acceptable for the one-shot diag command.
fn linux_cpu_usage_percent() -> Option<(f64, u64)> {
    let first = read_proc_stat_totals()?;
    let window_ms = 250_u64;
    std::thread::sleep(std::time::Duration::from_millis(window_ms));
    let second = read_proc_stat_totals()?;
    // Counters only grow; saturating_sub guards against an unexpected reset.
    let total_delta = second.0.saturating_sub(first.0);
    let idle_delta = second.1.saturating_sub(first.1);
    if total_delta == 0 {
        return None;
    }
    let usage = ((total_delta - idle_delta) as f64 / total_delta as f64) * 100.0;
    Some((usage, window_ms))
}
/// Read the aggregate "cpu " line of /proc/stat and return
/// `(total_jiffies, idle_jiffies)`, where idle includes iowait
/// (fields 4 and 5). `None` when the file or line is missing or fewer
/// than five numeric fields are present.
fn read_proc_stat_totals() -> Option<(u64, u64)> {
    let content = fs::read_to_string("/proc/stat").ok()?;
    let cpu_line = content.lines().find(|line| line.starts_with("cpu "))?;
    let mut fields: Vec<u64> = Vec::new();
    for token in cpu_line.split_whitespace().skip(1) {
        if let Ok(value) = token.parse::<u64>() {
            fields.push(value);
        }
    }
    if fields.len() < 5 {
        return None;
    }
    // idle + iowait; saturating_add out of caution against absurd values.
    let idle = fields[3].saturating_add(fields[4]);
    let total: u64 = fields.iter().sum();
    Some((total, idle))
}
/// Report availability of every CLI tool Athena's workflows rely on, plus
/// a dedicated probe for the `docker compose` subcommand (which `which`
/// alone cannot detect).
fn required_tool_checks() -> Value {
    let tools = [
        "s3cmd",
        "nginx",
        "curl",
        "lsb_release",
        "pg_dump",
        "pg_restore",
        "docker",
        "systemctl",
        "apt-get",
        "gpg",
        "ss",
        "netstat",
    ];
    let checks: Vec<Value> = tools.iter().map(|tool| tool_check(tool)).collect();
    // `docker compose` is a plugin subcommand, so it must be exercised
    // directly rather than located on PATH.
    let docker_compose = match ProcessCommand::new("docker")
        .args(["compose", "version"])
        .output()
    {
        Ok(out) => json!({
            "name": "docker compose",
            "installed": out.status.success(),
            "version": String::from_utf8_lossy(&out.stdout).trim(),
            "error": String::from_utf8_lossy(&out.stderr).trim(),
        }),
        Err(err) => json!({
            "name": "docker compose",
            "installed": false,
            "error": err.to_string(),
        }),
    };
    json!({
        "commands": checks,
        "docker_compose": docker_compose,
    })
}
/// Locate `name` on PATH and report it as JSON; when found, also try to
/// capture a version string via `--version` and then `-V` as a fallback.
fn tool_check(name: &str) -> Value {
    match which::which(name) {
        Ok(path) => {
            let version = command_first_line(name, &["--version"])
                .or_else(|| command_first_line(name, &["-V"]));
            json!({
                "name": name,
                "installed": true,
                "path": path,
                "version": version,
            })
        }
        Err(_) => json!({
            "name": name,
            "installed": false,
        }),
    }
}
/// Run `program args...` and return the first non-blank output line,
/// trimmed, preferring stdout over stderr. `None` when the process cannot
/// be spawned or produces only blank output.
fn command_first_line(program: &str, args: &[&str]) -> Option<String> {
    let output = ProcessCommand::new(program).args(args).output().ok()?;
    let stdout = String::from_utf8_lossy(&output.stdout);
    let stderr = String::from_utf8_lossy(&output.stderr);
    stdout
        .lines()
        .chain(stderr.lines())
        .find(|line| !line.trim().is_empty())
        .map(|line| line.trim().to_string())
}
/// Build a JSON matrix of pg_dump/pg_restore availability (and versions)
/// for PostgreSQL majors 14-18 under /usr/lib/postgresql, plus whatever
/// versions are reachable on PATH.
fn postgres_tool_matrix() -> Value {
    let mut matrix = Vec::new();
    for major in 14..=18 {
        let base = PathBuf::from(format!("/usr/lib/postgresql/{major}/bin"));
        let dump = base.join("pg_dump");
        let restore = base.join("pg_restore");
        matrix.push(json!({
            "major": major,
            "bin_dir": base,
            "bin_dir_exists": base.is_dir(),
            "pg_dump_exists": dump.exists(),
            "pg_restore_exists": restore.exists(),
            "pg_dump_version": if dump.exists() { command_first_line(dump.to_string_lossy().as_ref(), &["--version"]) } else { None },
            "pg_restore_version": if restore.exists() { command_first_line(restore.to_string_lossy().as_ref(), &["--version"]) } else { None },
        }));
    }
    json!({
        "pg_dump_on_path": command_first_line("pg_dump", &["--version"]),
        "pg_restore_on_path": command_first_line("pg_restore", &["--version"]),
        "major_matrix": matrix,
    })
}
/// Report the set of listening TCP/UDP ports as JSON, trying `ss` first
/// and falling back to `netstat`; an error object is returned when neither
/// command succeeds.
fn listening_ports_summary() -> Value {
    let candidates: [(&str, &[&str]); 2] = [("ss", &["-lntupH"]), ("netstat", &["-lntup"])];
    for (program, args) in candidates {
        if let Ok(output) = ProcessCommand::new(program).args(args).output()
            && output.status.success()
        {
            let text = String::from_utf8_lossy(&output.stdout);
            let ports = extract_port_numbers(&text);
            return json!({
                "source": format!("{} {}", program, args.join(" ")),
                "listening_ports": ports,
                "count": ports.len(),
            });
        }
    }
    json!({"error": "unable to query listening ports (ss/netstat unavailable)"})
}
/// Extract the set of listening port numbers from `ss`/`netstat` output.
///
/// A port is a run of 1-5 ASCII digits directly following a ':' and
/// terminated by whitespace or end-of-input (so `1.2.3.4:443 ` yields 443,
/// while `:123456 ` and `:80)` yield nothing), parsed as `u16` (values
/// above 65535 are dropped). Results are deduplicated and returned in
/// ascending order.
///
/// NOTE: previously implemented with a `regex` compiled on every call
/// (whose compile failure was silently swallowed into an empty result);
/// this std-only scan has the same matching behavior with no dependency
/// and no silent-failure path.
fn extract_port_numbers(text: &str) -> Vec<u16> {
    let mut ports = BTreeSet::new();
    let mut rest = text;
    while let Some(colon) = rest.find(':') {
        let after = &rest[colon + 1..];
        // Length of the maximal ASCII-digit run after the colon.
        let digits_len = after.bytes().take_while(u8::is_ascii_digit).count();
        // The run must be delimited by whitespace or the end of the text.
        let boundary_ok = after[digits_len..]
            .chars()
            .next()
            .is_none_or(char::is_whitespace);
        if (1..=5).contains(&digits_len) && boundary_ok {
            if let Ok(port) = after[..digits_len].parse::<u16>() {
                ports.insert(port);
            }
        }
        rest = after;
    }
    ports.into_iter().collect()
}
/// Query systemd for a unit's state and return it as JSON. Each field is
/// the first output line of the respective `systemctl` invocation, or
/// `null` when systemctl is unavailable or produced no output.
fn systemd_service_status(service: &str) -> Value {
    let is_active = command_first_line("systemctl", &["is-active", service]);
    let is_enabled = command_first_line("systemctl", &["is-enabled", service]);
    // --lines=0 suppresses journal output so the summary line comes first.
    let status_line =
        command_first_line("systemctl", &["status", service, "--no-pager", "--lines=0"]);
    json!({
        "service": service,
        "is_active": is_active,
        "is_enabled": is_enabled,
        "status_summary": status_line,
    })
}
/// Print the crate version (compile-time CARGO_PKG_VERSION) to stdout.
pub fn run_version_command() {
    let version: &str = env!("CARGO_PKG_VERSION");
    println!("{version}");
}
/// Parse an optional inline JSON string or an optional JSON file into a
/// `Value`. Inline JSON takes precedence when both are supplied (callers
/// prevent that via clap `conflicts_with`); `Ok(None)` when neither is
/// provided.
fn parse_json_input(json: Option<String>, file: Option<PathBuf>) -> Result<Option<Value>> {
    match (json, file) {
        (Some(text), _) => {
            let value: Value =
                serde_json::from_str(&text).context("parsing inline JSON payload")?;
            Ok(Some(value))
        }
        (None, Some(path)) => {
            let text: String = fs::read_to_string(&path).context("reading payload file")?;
            let value: Value = serde_json::from_str(&text).context("parsing payload file")?;
            Ok(Some(value))
        }
        (None, None) => Ok(None),
    }
}