use actix_web::dev::ServiceResponse;
use actix_web::http::StatusCode as ActixStatusCode;
use actix_web::web::Bytes;
use actix_web::{
App,
body::to_bytes,
test::{TestRequest, call_service, init_service},
};
use anyhow::{Context, Result, anyhow, bail};
use clap::{Args, Parser, Subcommand};
use colored::Colorize;
use reqwest::{Client as HttpClient, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::fs::{self, File};
use std::io::Write;
use std::path::{Path, PathBuf};
#[cfg(feature = "cdc")]
use std::time::Duration;
use crate::api::pipelines::run_pipeline;
use crate::bootstrap::Bootstrap;
#[cfg(feature = "cdc")]
use crate::cdc::postgres::sequin::{
AuditLogger, backfill_from_csv, load_table_configs, stream_events,
};
#[cfg(feature = "cdc")]
use crate::client::{AthenaClient, backend::BackendType};
#[cfg(feature = "cdc")]
use crate::config::Config;
use crate::provisioning::{ProvisioningError, run_provision_sql};
/// Default location of the JSON client registry; used as the `default_value`
/// of every `--path` / `--clients-path` option in this file.
const DEFAULT_CLIENTS_PATH: &str = "config/clients.json";
/// Default endpoint the `fetch` subcommand posts to when `--url` is not given.
const DEFAULT_FETCH_URL: &str = "https://athena-db.com/gateway/fetch";
// Top-level command line for the `athena_rs` binary, parsed with clap's derive API.
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments on derive
// items into help text, so doc comments here would alter the generated CLI help.
#[derive(Parser)]
#[command(name = "athena_rs")]
#[command(author, version, about = "Athena API server + CLI helpers")]
pub struct AthenaCli {
// `--config` (alias `--config-path`): optional configuration file override.
// `global = true` makes it accepted on every subcommand as well.
#[arg(
long = "config",
alias = "config-path",
global = true,
value_name = "PATH",
help = "Override the default configuration file path (default: config.yaml)"
)]
pub config_path: Option<PathBuf>,
// `--pipelines-path`: path to the pipelines file; defaults to `config/pipelines.yaml`.
#[arg(
long,
global = true,
value_name = "PATH",
default_value = "config/pipelines.yaml"
)]
pub pipelines_path: PathBuf,
// `--port`: optional port override. How it is applied is decided by the caller (not in this view).
#[arg(long, global = true, value_name = "PORT")]
pub port: Option<u16>,
// `--api-only`: boolean switch; presumably limits startup to the API server — confirm in the caller.
#[arg(long, global = true)]
pub api_only: bool,
// `--cdc-only`: boolean switch that only exists when the `cdc` feature is compiled in;
// presumably limits startup to CDC processing — confirm in the caller.
#[cfg(feature = "cdc")]
#[arg(long, global = true)]
pub cdc_only: bool,
// Selected subcommand; `None` when the binary is invoked without one.
#[command(subcommand)]
pub command: Option<Command>,
}
// Subcommands of `athena_rs`.
//
// The dispatch from a variant to its handler is not part of this view; the handlers
// named below are the functions in this file whose signatures accept the variant's payload.
// Plain `//` comments are deliberate: clap would turn `///` into help text.
#[derive(Subcommand)]
pub enum Command {
// No arguments; presumably runs the API server — confirm at the dispatch site.
Server,
// Arguments consumed by `run_pipeline_command`.
Pipeline(PipelineArgs),
// Client registry management; the nested subcommand is consumed by `run_clients_command`.
Clients {
#[command(subcommand)]
command: ClientCommand,
},
// Arguments consumed by `run_fetch_command`.
Fetch(FetchArgs),
// CDC tooling consumed by `run_cdc_command`; only compiled with the `cdc` feature.
#[cfg(feature = "cdc")]
Cdc {
#[command(subcommand)]
command: CdcCommand,
},
// Arguments consumed by `run_provision_command`.
Provision(ProvisionArgs),
// No arguments; matches `run_diag_command`, which takes none.
Diag,
// No arguments; matches `run_version_command`, which takes none.
Version,
}
// Options of the `provision` subcommand (see `run_provision_command`).
// Plain `//` comments are deliberate: clap would turn `///` into help text.
#[derive(Args)]
pub struct ProvisionArgs {
// `--uri` (required): database URI handed to `run_provision_sql`; also the value
// stored in the client registry when `--register` is set.
#[arg(long, value_name = "URI")]
pub uri: String,
// `--client-name`: registry name for the provisioned database. With `--register`
// and no name, `run_provision_command` falls back to "provisioned".
#[arg(long, value_name = "NAME")]
pub client_name: Option<String>,
// `--register`: when set, the URI is saved into the clients file after provisioning.
#[arg(long, default_value_t = false)]
pub register: bool,
// `--clients-path`: clients file to update; defaults to `DEFAULT_CLIENTS_PATH`.
#[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
pub clients_path: PathBuf,
}
// Options of the `pipeline` subcommand (see `run_pipeline_command`).
// Plain `//` comments are deliberate: clap would turn `///` into help text.
#[derive(Args)]
pub struct PipelineArgs {
// `-c` / `--client` (required): sent as the `X-Athena-Client` request header.
#[arg(short, long, value_name = "CLIENT")]
pub client: String,
// `--payload-json`: inline JSON payload; mutually exclusive with `--payload-file`.
#[arg(long, conflicts_with = "payload_file")]
pub payload_json: Option<String>,
// `--payload-file`: file containing the JSON payload; mutually exclusive with `--payload-json`.
#[arg(long, value_name = "PATH", conflicts_with = "payload_json")]
pub payload_file: Option<PathBuf>,
// `--pipeline`: when given, written into the payload under the "pipeline" key.
#[arg(long, value_name = "NAME")]
pub pipeline: Option<String>,
}
// Subcommands of `clients` (see `run_clients_command`). Every variant takes an
// optional `--path` to the clients file, defaulting to `DEFAULT_CLIENTS_PATH`.
// Fields without an `#[arg(long)]` attribute (`name`, `uri`) are positional arguments.
// Plain `//` comments are deliberate: clap would turn `///` into help text.
#[derive(Subcommand)]
pub enum ClientCommand {
// Print the registry: file path, default client and every name/URI pair.
List {
#[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
path: PathBuf,
},
// Insert (or overwrite) `name` with `uri`; becomes the default if none is set yet.
Add {
#[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
path: PathBuf,
name: String,
uri: String,
},
// Delete `name`; fails if it is unknown and clears the default if it pointed at `name`.
Remove {
#[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
path: PathBuf,
name: String,
},
// Mark an existing client `name` as the default; fails if it is not defined.
SetDefault {
#[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
path: PathBuf,
name: String,
},
}
// Options of the `fetch` subcommand (see `run_fetch_command`).
// Plain `//` comments are deliberate: clap would turn `///` into help text.
#[derive(Args)]
pub struct FetchArgs {
// `--clients-path`: clients file to read; defaults to `DEFAULT_CLIENTS_PATH`.
#[arg(long, value_name = "PATH", default_value = DEFAULT_CLIENTS_PATH)]
pub clients_path: PathBuf,
// `--client`: registry name to send; when omitted the registry's default client is used.
#[arg(long)]
pub client: Option<String>,
// `--url`: endpoint the request is posted to; defaults to `DEFAULT_FETCH_URL`.
#[arg(long, value_name = "URL", default_value = DEFAULT_FETCH_URL)]
pub url: String,
// `--body-json`: inline JSON request body; mutually exclusive with `--body-file`.
#[arg(long, conflicts_with = "body_file")]
pub body_json: Option<String>,
// `--body-file`: file containing the JSON request body; mutually exclusive with `--body-json`.
// clap does not require either flag, but `run_fetch_command` errors when both are absent.
#[arg(long, value_name = "PATH", conflicts_with = "body_json")]
pub body_file: Option<PathBuf>,
}
// Subcommands of `cdc`, dispatched by `run_cdc_command`. Only compiled with the `cdc` feature.
// Plain `//` comments are deliberate: clap would turn `///` into help text.
#[cfg(feature = "cdc")]
#[derive(Subcommand)]
pub enum CdcCommand {
// Handled by `run_cdc_backfill`, which feeds a CSV file to `backfill_from_csv`.
Backfill(CdcBackfillArgs),
// Handled by `run_cdc_stream`, which calls `stream_events`.
Stream(CdcStreamArgs),
}
// Options of `cdc backfill` (see `run_cdc_backfill`). Only compiled with the `cdc` feature.
// Plain `//` comments are deliberate: clap would turn `///` into help text.
#[cfg(feature = "cdc")]
#[derive(Args)]
pub struct CdcBackfillArgs {
// `--csv-path`: CSV file handed to `backfill_from_csv`; defaults to `src/cdc/sequin_events.csv`.
#[arg(long, value_name = "PATH", default_value = "src/cdc/sequin_events.csv")]
pub csv_path: PathBuf,
// `--table-config`: optional path handed to `load_table_configs`.
#[arg(long, value_name = "PATH")]
pub table_config: Option<PathBuf>,
// `--state-file`: state file path handed to `backfill_from_csv`; defaults to `cdc/state.json`.
#[arg(long, value_name = "PATH", default_value = "cdc/state.json")]
pub state_file: PathBuf,
// `--backend`: target backend name, resolved case-insensitively by `parse_backend_type`;
// omitted or unrecognized names resolve to the native backend.
#[arg(long)]
pub backend: Option<String>,
// `--url` (required): URL of the target client.
#[arg(long, value_name = "URL")]
pub url: String,
// `--key`: optional key set on the target client builder.
#[arg(long, value_name = "KEY")]
pub key: Option<String>,
// `--dry-run`: forwarded unchanged to `backfill_from_csv`.
#[arg(long)]
pub dry_run: bool,
// `--audit-backend`: backend of the audit client; PostgreSQL when omitted.
#[arg(long)]
pub audit_backend: Option<String>,
// `--audit-url`: URL of the audit client; falls back to the config's `athena_logging`
// Postgres URI. No audit logger is built when neither is available.
#[arg(long, value_name = "URL")]
pub audit_url: Option<String>,
// `--audit-key`: optional key set on the audit client builder.
#[arg(long, value_name = "KEY")]
pub audit_key: Option<String>,
}
// Options of `cdc stream` (see `run_cdc_stream`). Only compiled with the `cdc` feature.
// Plain `//` comments are deliberate: clap would turn `///` into help text.
#[cfg(feature = "cdc")]
#[derive(Args)]
pub struct CdcStreamArgs {
// `--table-config`: optional path handed to `load_table_configs`.
#[arg(long, value_name = "PATH")]
pub table_config: Option<PathBuf>,
// `--state-file`: state file path handed to `stream_events`; defaults to `cdc/state.json`.
#[arg(long, value_name = "PATH", default_value = "cdc/state.json")]
pub state_file: PathBuf,
// `--backend`: target backend name, resolved case-insensitively by `parse_backend_type`;
// omitted or unrecognized names resolve to the native backend.
#[arg(long)]
pub backend: Option<String>,
// `--url` (required): URL of the target client.
#[arg(long, value_name = "URL")]
pub url: String,
// `--key`: optional key set on the target client builder.
#[arg(long, value_name = "KEY")]
pub key: Option<String>,
// `--dry-run`: forwarded unchanged to `stream_events`.
#[arg(long)]
pub dry_run: bool,
// `--audit-backend`: backend of the audit client; PostgreSQL when omitted.
#[arg(long)]
pub audit_backend: Option<String>,
// `--audit-url`: URL of the audit client; falls back to the config's `athena_logging`
// Postgres URI. No audit logger is built when neither is available.
#[arg(long, value_name = "URL")]
pub audit_url: Option<String>,
// `--audit-key`: optional key set on the audit client builder.
#[arg(long, value_name = "KEY")]
pub audit_key: Option<String>,
// `--sequin-table`: table name handed to `stream_events`; defaults to `public.sequin_events`.
#[arg(long, default_value = "public.sequin_events")]
pub sequin_table: String,
// `--batch-size`: batch limit, default 512; `run_cdc_stream` raises 0 to 1.
#[arg(long, default_value = "512")]
pub batch_size: usize,
// `--poll-interval-secs`: poll interval in seconds, default 5; `run_cdc_stream` raises 0 to 1.
#[arg(long, default_value = "5")]
pub poll_interval_secs: u64,
}
#[derive(Debug, Deserialize, Serialize, Default)]
struct ClientStore {
default: Option<String>,
clients: HashMap<String, String>,
}
impl ClientStore {
fn load(path: &Path) -> Result<Self> {
if path.exists() {
let bytes = fs::read_to_string(path).context("reading clients file")?;
let store: ClientStore =
serde_json::from_str(&bytes).context("parsing clients file as JSON")?;
Ok(store)
} else {
Ok(Self::default())
}
}
fn persist(&self, path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).context("creating clients directory")?;
}
let mut file: File = File::create(path).context("creating clients file")?;
serde_json::to_writer_pretty(&mut file, &self).context("serializing clients file")?;
file.write_all(b"\n")
.context("writing newline to clients file")?;
Ok(())
}
}
pub async fn run_pipeline_command(bootstrap: &Bootstrap, args: PipelineArgs) -> Result<()> {
let mut body = match parse_json_input(args.payload_json, args.payload_file)? {
Some(value) => value,
None => json!({}),
};
if !body.is_object() {
bail!("pipeline payload must be a JSON object");
}
if let Some(name) = args.pipeline.as_ref() {
body["pipeline"] = json!(name);
}
if body.as_object().is_none_or(|map| map.is_empty()) && args.pipeline.is_none() {
bail!("provide either --payload-json/--payload-file or --pipeline");
}
let service = init_service(
App::new()
.app_data(bootstrap.app_state.clone())
.service(run_pipeline),
)
.await;
let request = TestRequest::post()
.uri("/pipelines")
.insert_header(("X-Athena-Client", args.client.as_str()))
.set_json(&body)
.to_request();
let response: ServiceResponse = call_service(&service, request).await;
let status: ActixStatusCode = response.status();
let body_bytes: Bytes = to_bytes(response.into_body())
.await
.map_err(|err| anyhow!("reading pipeline response body: {}", err))?;
let body_text = String::from_utf8_lossy(&body_bytes);
if status.is_success() {
if let Ok(parsed) = serde_json::from_slice::<Value>(&body_bytes) {
println!("{}", serde_json::to_string_pretty(&parsed)?);
} else {
println!("{}", body_text);
}
} else {
eprintln!("pipeline request failed ({})", status);
eprintln!("{}", body_text);
}
Ok(())
}
pub fn run_clients_command(command: ClientCommand) -> Result<()> {
match command {
ClientCommand::List { path } => {
let store: ClientStore = ClientStore::load(&path)?;
println!("{}", "Clients".bold().underline());
println!(
"{} {}",
"Clients file:".bright_cyan(),
path.display().to_string().bright_white()
);
match store.default.as_deref() {
Some(default) => println!(
"{} {}",
"Default client:".bright_cyan(),
default.bright_green().bold()
),
None => println!("{} {}", "Default client:".bright_cyan(), "none".dimmed()),
}
let mut names: Vec<&String> = store.clients.keys().collect();
names.sort();
for name in names {
if let Some(uri) = store.clients.get(name) {
let name_display = if store.default.as_deref() == Some(name.as_str()) {
format!("{} (default)", name)
} else {
name.to_string()
};
println!(
" {} {} {}",
name_display.yellow().bold(),
"->".dimmed(),
uri.cyan()
);
}
}
Ok(())
}
ClientCommand::Add { path, name, uri } => {
let mut store: ClientStore = ClientStore::load(&path)?;
store.clients.insert(name.clone(), uri.clone());
if store.default.is_none() {
store.default = Some(name.clone());
}
store.persist(&path)?;
println!(
"{} '{}' {} {} in {}",
"Saved client".bright_green().bold(),
name.bright_yellow(),
"->".dimmed(),
uri.cyan(),
path.display().to_string().bright_cyan()
);
Ok(())
}
ClientCommand::Remove { path, name } => {
let mut store: ClientStore = ClientStore::load(&path)?;
if store.clients.remove(&name).is_none() {
bail!("client '{}' not found", name);
}
if store.default.as_deref() == Some(name.as_str()) {
store.default = None;
}
store.persist(&path)?;
println!(
"{} {} {} {}",
"Removed client".bright_red().bold(),
format!("'{}'", name).bright_yellow(),
"from".white(),
path.display().to_string().bright_cyan()
);
Ok(())
}
ClientCommand::SetDefault { path, name } => {
let mut store: ClientStore = ClientStore::load(&path)?;
if !store.clients.contains_key(&name) {
bail!("client '{}' is not defined", name);
}
store.default = Some(name.clone());
store.persist(&path)?;
println!(
"{} {} {} {} {}",
"Default client set to".bright_green().bold(),
format!("'{}'", name).bright_yellow(),
"in".white(),
path.display().to_string().bright_cyan(),
"✅".bright_green()
);
Ok(())
}
}
}
pub async fn run_fetch_command(args: FetchArgs) -> Result<()> {
let body: Value =
parse_json_input(args.body_json.clone(), args.body_file.clone())?.context("empty body")?;
let store: ClientStore = ClientStore::load(&args.clients_path)?;
let client_name: String = if let Some(name) = args.client {
if !store.clients.contains_key(&name) {
bail!(
"client '{}' not found in {}",
name,
args.clients_path.display()
);
}
name
} else {
store
.default
.clone()
.ok_or_else(|| anyhow!("no client provided and no default set"))?
};
let http: HttpClient = HttpClient::new();
let resp: Response = http
.post(&args.url)
.header("Content-Type", "application/json")
.header("X-Athena-Client", client_name.clone())
.json(&body)
.send()
.await
.context("failed to execute fetch request")?;
let status: StatusCode = resp.status();
let text: String = resp.text().await.context("reading fetch response")?;
if status.is_success() {
if let Ok(parsed) = serde_json::from_str::<Value>(&text) {
println!("{}", serde_json::to_string_pretty(&parsed)?);
} else {
println!("{}", text);
}
} else {
eprintln!("fetch failed ({})", status);
eprintln!("{}", text);
}
Ok(())
}
/// Dispatches a `cdc` subcommand to its handler and returns that handler's result.
///
/// Only compiled with the `cdc` feature.
#[cfg(feature = "cdc")]
pub async fn run_cdc_command(config: &Config, command: CdcCommand) -> Result<()> {
    match command {
        CdcCommand::Stream(stream_args) => run_cdc_stream(config, stream_args).await,
        CdcCommand::Backfill(backfill_args) => run_cdc_backfill(config, backfill_args).await,
    }
}
/// Handles `cdc backfill`: loads the table configs, builds the target client
/// and the optional audit logger, runs `backfill_from_csv`, and prints the
/// last sequence number of the returned state (0 when it has none).
///
/// Any failure of those steps is returned to the caller.
#[cfg(feature = "cdc")]
async fn run_cdc_backfill(config: &Config, args: CdcBackfillArgs) -> Result<()> {
    let table_configs = load_table_configs(args.table_config.as_deref())?;
    let target =
        build_cdc_client(args.backend.as_deref(), &args.url, args.key.as_deref()).await?;
    let audit = build_audit_logger(
        config,
        args.audit_backend.as_deref(),
        args.audit_url.as_deref(),
        args.audit_key.as_deref(),
    )
    .await?;

    let final_state = backfill_from_csv(
        &target,
        &args.csv_path,
        &table_configs,
        args.state_file.as_path(),
        args.dry_run,
        audit.as_ref(),
    )
    .await?;

    let last_seq = final_state.last_seq.unwrap_or(0);
    println!("Backfill finished through seq {}", last_seq);
    Ok(())
}
/// Handles `cdc stream`: loads the table configs, builds the target client
/// and the optional audit logger, announces the run, and hands control to
/// `stream_events`, whose result is returned.
///
/// The batch size and the poll interval are both raised to at least 1.
#[cfg(feature = "cdc")]
async fn run_cdc_stream(config: &Config, args: CdcStreamArgs) -> Result<()> {
    let table_configs = load_table_configs(args.table_config.as_deref())?;
    let target =
        build_cdc_client(args.backend.as_deref(), &args.url, args.key.as_deref()).await?;
    let audit = build_audit_logger(
        config,
        args.audit_backend.as_deref(),
        args.audit_url.as_deref(),
        args.audit_key.as_deref(),
    )
    .await?;

    // Guard against zero values coming from the command line.
    let batch_limit = args.batch_size.max(1);
    let poll_every = Duration::from_secs(args.poll_interval_secs.max(1));

    println!(
        "Starting CDC stream against {} (batch {}, interval {}s)...",
        args.sequin_table,
        batch_limit,
        poll_every.as_secs()
    );

    stream_events(
        &target,
        &table_configs,
        &args.sequin_table,
        batch_limit,
        poll_every,
        args.state_file.as_path(),
        args.dry_run,
        audit.as_ref(),
    )
    .await
}
/// Builds the CDC target client: resolves the optional backend name with
/// `parse_backend_type` and delegates to `build_client`.
#[cfg(feature = "cdc")]
async fn build_cdc_client(
    backend: Option<&str>,
    url: &str,
    key: Option<&str>,
) -> Result<AthenaClient> {
    build_client(parse_backend_type(backend), url, key).await
}
/// Builds an `AthenaClient` for `backend` at `url`, setting `key` on the
/// builder only when one is provided.
///
/// A build failure is returned as an error prefixed with
/// "failed to build Athena client".
#[cfg(feature = "cdc")]
async fn build_client(backend: BackendType, url: &str, key: Option<&str>) -> Result<AthenaClient> {
    let base = AthenaClient::builder().backend(backend).url(url);
    let configured = match key {
        Some(secret) => base.key(secret),
        None => base,
    };
    match AthenaClient::build(configured).await {
        Ok(client) => Ok(client),
        Err(err) => Err(anyhow!("failed to build Athena client: {}", err)),
    }
}
/// Builds the optional audit logger for CDC runs.
///
/// The audit URL is `url` when given, otherwise the config's
/// `athena_logging` Postgres URI. When neither exists, `Ok(None)` is
/// returned and no client is built. The backend is resolved from `backend`
/// via `parse_backend_type`, or is PostgreSQL when `backend` is omitted.
#[cfg(feature = "cdc")]
async fn build_audit_logger(
    config: &Config,
    backend: Option<&str>,
    url: Option<&str>,
    key: Option<&str>,
) -> Result<Option<AuditLogger>> {
    let fallback_url = config.get_postgres_uri("athena_logging").cloned();
    let Some(target_url) = url.map(str::to_owned).or(fallback_url) else {
        return Ok(None);
    };

    let backend_type = match backend {
        Some(name) => parse_backend_type(Some(name)),
        None => BackendType::PostgreSQL,
    };
    let client = build_client(backend_type, &target_url, key).await?;
    Ok(Some(AuditLogger::new(client, "cdc", "cdc")))
}
/// Maps an optional backend name to a `BackendType`, ignoring case.
///
/// `None` and every unrecognized name resolve to `BackendType::Native`.
// NOTE(review): a misspelled `--backend` value is silently treated as native;
// consider rejecting unknown names at the call sites.
#[cfg(feature = "cdc")]
fn parse_backend_type(name: Option<&str>) -> BackendType {
    let normalized = name.map(str::to_lowercase);
    match normalized.as_deref() {
        Some("supabase") => BackendType::Supabase,
        Some("postgrest") => BackendType::Postgrest,
        Some("scylla") => BackendType::Scylla,
        Some("neon") => BackendType::Neon,
        Some("postgresql" | "postgres") => BackendType::PostgreSQL,
        _ => BackendType::Native,
    }
}
pub async fn run_provision_command(args: ProvisionArgs) -> Result<()> {
println!(
"{} {}",
"Provisioning".bright_cyan().bold(),
args.uri
.split('@')
.last()
.unwrap_or(&args.uri)
.bright_white()
);
let total = run_provision_sql(&args.uri)
.await
.map_err(|err| match err {
ProvisioningError::InvalidInput(message)
| ProvisioningError::Conflict(message)
| ProvisioningError::Unavailable(message)
| ProvisioningError::Execution(message) => anyhow!(message),
})?;
println!(
"{} Applied {} statements.",
"✔".bright_green().bold(),
total
);
if args.register {
let client_name = args
.client_name
.clone()
.unwrap_or_else(|| "provisioned".to_string());
let mut store: ClientStore = if args.clients_path.exists() {
ClientStore::load(&args.clients_path)?
} else {
ClientStore {
default: Some(client_name.clone()),
clients: HashMap::new(),
}
};
store.clients.insert(client_name.clone(), args.uri.clone());
if store.default.is_none() {
store.default = Some(client_name.clone());
}
store.persist(&args.clients_path)?;
println!(
"{} Registered client '{}' in {}",
"✔".bright_green().bold(),
client_name.bright_white(),
args.clients_path.display()
);
} else if let Some(name) = args.client_name.as_ref() {
println!(
"{} To register this client run: athena_rs provision --uri <URI> --client-name {} --register",
"ℹ".bright_cyan(),
name
);
}
Ok(())
}
/// Prints a pretty-printed JSON object with basic environment diagnostics:
/// OS, architecture, OS family, CPU count, and the `HOSTNAME` / `LANG`
/// environment variables (empty strings when unset or not valid Unicode).
///
/// Fails only if the report cannot be serialized.
pub fn run_diag_command() -> Result<()> {
    let env_or_empty = |key: &str| std::env::var(key).unwrap_or_default();

    let report: Value = json!({
        "os": std::env::consts::OS,
        "arch": std::env::consts::ARCH,
        "family": std::env::consts::FAMILY,
        "cpu_count": num_cpus::get(),
        // NOTE(review): despite the key name this is the crate's package
        // version (`CARGO_PKG_VERSION`), not the Rust toolchain version.
        "rust_version": env!("CARGO_PKG_VERSION"),
        "hostname": env_or_empty("HOSTNAME"),
        "locale": env_or_empty("LANG"),
    });

    let rendered = serde_json::to_string_pretty(&report)?;
    println!("{}", rendered);
    Ok(())
}
/// Prints the crate's package version (`CARGO_PKG_VERSION`, fixed at compile
/// time) on its own line.
pub fn run_version_command() {
    let version: &str = env!("CARGO_PKG_VERSION");
    println!("{}", version);
}
/// Resolves an optional JSON value from an inline string or a file.
///
/// The inline `json` text takes precedence; `file` is only read when `json`
/// is `None`. Returns `Ok(None)` when neither source is given.
///
/// # Errors
///
/// Fails when the file cannot be read or when the selected text is not valid JSON.
fn parse_json_input(json: Option<String>, file: Option<PathBuf>) -> Result<Option<Value>> {
    // Pick the source text together with the context message used if parsing fails.
    let (text, parse_context) = match (json, file) {
        (Some(inline), _) => (inline, "parsing inline JSON payload"),
        (None, Some(path)) => (
            fs::read_to_string(&path).context("reading payload file")?,
            "parsing payload file",
        ),
        (None, None) => return Ok(None),
    };
    let value: Value = serde_json::from_str(&text).context(parse_context)?;
    Ok(Some(value))
}