use std::path::PathBuf;
use std::sync::Arc;
use clap::{Parser, ValueEnum};
use tracing_subscriber::EnvFilter;
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
#[clap(rename_all = "kebab-case")]
enum Mode {
LegacyGrpc,
Local,
Remote,
}
/// Resolves which [`Mode`] the gateway should run in.
///
/// Precedence, highest first:
/// 1. `cli_mode` (the `--mode` flag), if given;
/// 2. the `AA_MODE` variable returned by `env_lookup`. It is matched
///    case-insensitively, and surrounding whitespace is ignored;
/// 3. [`Mode::LegacyGrpc`] as the default.
///
/// `env_lookup` abstracts environment access so the precedence rules can be
/// exercised without mutating the real process environment.
///
/// # Errors
///
/// Returns a human-readable message when `AA_MODE` is set to an unrecognised
/// value. The message quotes the value exactly as it was set and lists the
/// accepted modes.
fn resolve_mode(
    cli_mode: Option<Mode>,
    env_lookup: impl Fn(&str) -> Option<String>,
) -> Result<Mode, String> {
    if let Some(mode) = cli_mode {
        return Ok(mode);
    }
    let raw = match env_lookup("AA_MODE") {
        Some(raw) => raw,
        None => return Ok(Mode::LegacyGrpc),
    };
    // Normalise only for matching: values from `.env`/YAML files often carry
    // stray whitespace or different casing.
    match raw.trim().to_ascii_lowercase().as_str() {
        "legacy-grpc" => Ok(Mode::LegacyGrpc),
        "local" => Ok(Mode::Local),
        "remote" => Ok(Mode::Remote),
        // Report the original (un-normalised) value so the user can find it
        // verbatim in their environment.
        _ => Err(format!(
            "invalid AA_MODE={raw:?} — expected one of: legacy-grpc, local, remote"
        )),
    }
}
// Command-line interface of the `aa-gateway` binary. `about` comes from the
// crate description; the `///` comments on the fields become `--help` text.
#[derive(Parser)]
#[command(name = "aa-gateway", version, about)]
struct Cli {
    /// Run mode (overrides AA_MODE; defaults to legacy-grpc)
    #[arg(long, value_enum)]
    mode: Option<Mode>,
    /// Path to the policy to load (required in legacy-grpc mode)
    #[arg(long)]
    policy: Option<PathBuf>,
    /// TCP address to serve on in legacy-grpc mode (ignored when --socket is given)
    #[arg(long, default_value = "127.0.0.1:50051")]
    listen: String,
    /// Unix domain socket path to serve on in legacy-grpc mode (takes precedence over --listen)
    #[arg(long)]
    socket: Option<PathBuf>,
    /// Audit directory; exported to the process as AA_AUDIT_DIR
    #[arg(long)]
    audit_dir: Option<PathBuf>,
}
/// Binary entry point.
///
/// Steps:
/// 1. Install the tracing subscriber (`RUST_LOG`, defaulting to `info`).
/// 2. Parse the CLI and export `--audit-dir` as `AA_AUDIT_DIR`.
/// 3. Resolve the run mode.
/// 4. Drive the matching runner to completion on a multi-threaded Tokio runtime.
///
/// The runtime is built by hand instead of via `#[tokio::main]`. That macro
/// builds the runtime (spawning its worker threads) before the body runs, but
/// `std::env::set_var` is only sound while the process is still
/// single-threaded on non-Windows platforms. Building the runtime after the
/// environment mutation avoids that.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
        )
        .init();
    let cli = Cli::parse();
    if let Some(ref dir) = cli.audit_dir {
        // No runtime (and so no worker threads) exists yet at this point.
        std::env::set_var("AA_AUDIT_DIR", dir);
    }
    let mode = resolve_mode(cli.mode, |k| std::env::var(k).ok())?;
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()?;
    runtime.block_on(async move {
        match mode {
            Mode::LegacyGrpc => run_legacy_grpc(cli).await,
            Mode::Local => run_local().await,
            Mode::Remote => run_remote().await,
        }
    })
}
/// Runs the gateway in local mode.
///
/// Starts the local-mode gateway from the `local` section of the loaded
/// `GatewayConfig`, then awaits `run_until_shutdown` on the returned handle.
///
/// # Errors
///
/// Propagates failures from config loading, from `start_local`, and from
/// `run_until_shutdown`.
async fn run_local() -> Result<(), Box<dyn std::error::Error>> {
    use aa_gateway::local_mode::{run_until_shutdown, start_local};

    let gateway_cfg = aa_core::config::GatewayConfig::load()?;
    let local_handle = start_local(&gateway_cfg.local).await?;
    run_until_shutdown(local_handle).await?;
    Ok(())
}
/// Runs the gateway in legacy gRPC mode.
///
/// Steps:
/// 1. Load `GatewayConfig`.
/// 2. Open the SQLite storage backend at `cfg.local.storage_path`.
/// 3. Build an `AgentRegistry` backed by that storage and rehydrate it.
/// 4. Start the retention engine. If the retention config is rejected, log a
///    warning and continue without it.
/// 5. Call `maybe_spawn_webhook` (presumably a no-op when no webhook is
///    configured — confirm).
/// 6. Serve on the Unix socket given by `--socket` or, if absent, on the TCP
///    address given by `--listen`.
///
/// Once the server returns, the retention engine's shutdown token is cancelled.
///
/// # Errors
///
/// Fails if `--policy` is missing, or if config loading, storage setup,
/// rehydration, or the server itself returns an error.
async fn run_legacy_grpc(cli: Cli) -> Result<(), Box<dyn std::error::Error>> {
    // `cli` is owned, so move the path out rather than cloning it; the other
    // fields stay usable after this partial move.
    let policy = cli.policy.ok_or("--policy is required in legacy-grpc mode")?;
    tracing::info!(policy = %policy.display(), "loading policy");
    let cfg = aa_core::config::GatewayConfig::load()?;
    let storage = aa_gateway::storage::open_sqlite_backend(&cfg.local.storage_path).await?;
    let registry = Arc::new(aa_gateway::AgentRegistry::new().with_storage(storage.clone()));
    let restored = registry.rehydrate_from_storage().await?;
    if restored > 0 {
        tracing::info!(restored, "rehydrated agents from durable storage");
    }
    let retention_shutdown = tokio_util::sync::CancellationToken::new();
    // Bound (not `_`) so the handle lives until the end of this function.
    let _retention_handle = match aa_gateway::storage::spawn_retention_engine(
        storage.clone(),
        &cfg.storage.retention,
        retention_shutdown.clone(),
    ) {
        Ok((_engine, handle)) => {
            tracing::info!(
                schedule = %cfg.storage.retention.schedule,
                hot_days = cfg.storage.retention.hot_days,
                warm_days = cfg.storage.retention.warm_days,
                "retention engine started"
            );
            Some(handle)
        }
        Err(err) => {
            tracing::warn!(error = %err, "retention engine disabled — config rejected by validator");
            None
        }
    };
    let approval_queue = aa_runtime::approval::ApprovalQueue::new();
    let (budget_alert_tx, budget_alert_rx) =
        tokio::sync::broadcast::channel::<aa_gateway::budget::BudgetAlert>(64);
    let _webhook_handle =
        aa_gateway::events::startup::maybe_spawn_webhook(&approval_queue, budget_alert_rx);
    let serve_result = if let Some(socket_path) = &cli.socket {
        aa_gateway::server::serve_uds(
            &policy,
            socket_path,
            registry,
            approval_queue,
            budget_alert_tx,
            Some(storage),
        )
        .await
    } else {
        aa_gateway::server::serve_tcp(
            &policy,
            &cli.listen,
            registry,
            approval_queue,
            budget_alert_tx,
            Some(storage),
        )
        .await
    };
    // The server has returned (cleanly or not): signal the retention engine to
    // stop. Without this the token is never triggered and the engine only
    // stops when the runtime is torn down.
    // NOTE(review): if `_retention_handle` is a join handle, awaiting it here
    // would also guarantee an in-flight retention pass finishes — confirm type.
    retention_shutdown.cancel();
    serve_result
}
/// Runs the gateway in remote mode.
///
/// Hands the `remote` section of the loaded `GatewayConfig` to `start_remote`
/// and awaits it.
///
/// # Errors
///
/// Propagates failures from config loading and from `start_remote`.
async fn run_remote() -> Result<(), Box<dyn std::error::Error>> {
    use aa_gateway::remote_mode::start_remote;

    let gateway_cfg = aa_core::config::GatewayConfig::load()?;
    start_remote(&gateway_cfg.remote).await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fake environment lookup that knows only `AA_MODE`, bound to `value`.
    fn env_with(value: &'static str) -> impl Fn(&str) -> Option<String> {
        move |k| (k == "AA_MODE").then(|| value.to_string())
    }

    #[test]
    fn cli_flag_overrides_env() {
        let resolved = resolve_mode(Some(Mode::Remote), env_with("local")).expect("resolve");
        assert_eq!(resolved, Mode::Remote);
    }

    #[test]
    fn cli_flag_used_when_env_unset() {
        let resolved = resolve_mode(Some(Mode::Local), |_| None).expect("resolve");
        assert_eq!(resolved, Mode::Local);
    }

    #[test]
    fn falls_back_to_aa_mode_env() {
        let resolved = resolve_mode(None, env_with("remote")).expect("resolve");
        assert_eq!(resolved, Mode::Remote);
    }

    #[test]
    fn aa_mode_env_is_case_insensitive() {
        let resolved = resolve_mode(None, env_with("LEGACY-GRPC")).expect("resolve");
        assert_eq!(resolved, Mode::LegacyGrpc);
        let resolved = resolve_mode(None, env_with("Local")).expect("resolve");
        assert_eq!(resolved, Mode::Local);
    }

    #[test]
    fn defaults_to_legacy_grpc() {
        let resolved = resolve_mode(None, |_| None).expect("resolve");
        assert_eq!(resolved, Mode::LegacyGrpc);
    }

    #[test]
    fn rejects_invalid_aa_mode_value() {
        let err = resolve_mode(None, env_with("foobar")).expect_err("expected error");
        assert!(err.contains("foobar"), "error must echo the invalid value: {err}");
        assert!(
            err.contains("legacy-grpc") && err.contains("local") && err.contains("remote"),
            "error must list valid modes: {err}"
        );
    }
}