use adk_gateway::channel;
use adk_gateway::config;
use adk_gateway::config_encryption;
use adk_gateway::gateway;
use adk_gateway::knowledge_graph;
use adk_gateway::mcp;
use adk_gateway::pairing;
use adk_gateway::rag;
use adk_gateway::telemetry;
use clap::{Parser, Subcommand};
use std::path::PathBuf;
use tracing_subscriber::EnvFilter;
// Top-level command-line interface, parsed with clap's derive API.
//
// The `///` comments on the fields below are user-facing: clap renders them
// as the per-flag help text in `--help`. This struct-level note is a plain
// `//` comment so it does not interfere with the explicit `about = ...`.
#[derive(Parser)]
#[command(
    name = "adk-gateway",
    version,
    about = "Multi-channel AI gateway for adk-rust agents"
)]
struct Cli {
    /// Path to the configuration file (a default location is used when omitted)
    #[arg(short, long, global = true)]
    config: Option<PathBuf>,
    /// Enable debug-level logging
    #[arg(short, long, global = true)]
    verbose: bool,
    // Optional: when no subcommand is given, `main` starts the gateway.
    #[command(subcommand)]
    command: Option<Commands>,
}
// Top-level subcommands.
//
// The `///` comments are user-facing: clap renders them as `--help` text.
// Maintainer notes use plain `//` so they stay out of the help output.
#[derive(Subcommand)]
enum Commands {
    /// Start the gateway (this is also what runs when no subcommand is given)
    Gateway {
        /// Port for the gateway; overrides the port from the configuration file
        #[arg(short, long)]
        port: Option<u16>,
        /// Kill any process found holding the port before starting
        #[arg(long)]
        force: bool,
    },
    /// Load the configuration file and print the parsed result as JSON
    ConfigValidate,
    /// Print the loaded configuration as JSON with known secrets redacted
    ConfigShow,
    /// Print the status of the configured channels
    ChannelsStatus {
        // NOTE(review): the flag is forwarded verbatim to
        // `channel::print_status`; its exact effect is defined there, so
        // confirm the wording of the help line below against that function.
        /// Enable probe mode for the status report
        #[arg(long)]
        probe: bool,
    },
    /// Inspect or delete per-user knowledge-graph memory
    #[command(subcommand)]
    Memory(MemoryCommands),
    /// Ingest documents into, or search, the RAG pipeline
    #[command(subcommand)]
    Rag(RagCommands),
    /// Pairing utilities
    #[command(subcommand)]
    Pairing(PairingCommands),
    /// Add, remove, or list MCP servers in the configuration file
    #[command(subcommand)]
    Mcp(McpCommands),
    /// Encrypt sensitive fields of the configuration file
    ConfigEncrypt {
        /// Path to the key file used for encryption
        #[arg(long)]
        key_file: PathBuf,
    },
}
// Subcommands of `memory`. `///` comments are rendered by clap as help text.
#[derive(Subcommand)]
enum MemoryCommands {
    /// Search a user's knowledge graph and print the matching nodes as JSON
    Search {
        /// Text to search for
        query: String,
        /// User whose knowledge graph is searched
        #[arg(long)]
        user_id: String,
    },
    /// Delete the knowledge graph stored for a user
    DeleteUser {
        /// User whose knowledge graph is deleted
        user_id: String,
    },
}
// Subcommands of `rag`. Both require a RAG section in the configuration file
// (`main` reports an error otherwise). `///` comments are rendered by clap as
// help text.
#[derive(Subcommand)]
enum RagCommands {
    /// Ingest content from PATH into the RAG pipeline and report the chunk count
    Ingest {
        /// Path to ingest
        path: PathBuf,
    },
    /// Search the RAG pipeline and print the results as JSON
    Search {
        /// Text to search for
        query: String,
        /// Number of top results to request from the pipeline
        #[arg(long, default_value = "5")]
        top_k: usize,
    },
}
// Subcommands of `pairing`. `///` comments are rendered by clap as help text.
#[derive(Subcommand)]
enum PairingCommands {
    /// Generate a new pairing code and print it to stdout
    GenerateCode,
}
// Subcommands of `mcp`. `///` comments are rendered by clap as help text.
//
// `add` has two mutually exclusive input styles: individual flags
// (`--name` plus `--command` or `--url`) or a single `--json` document.
#[derive(Subcommand)]
enum McpCommands {
    /// Add an MCP server to the configuration file, replacing any server with the same name
    Add {
        /// Server identifier (required unless --json is used)
        #[arg(long, required_unless_present = "json")]
        name: Option<String>,
        /// Executable to launch for a stdio-transport server
        #[arg(long)]
        command: Option<String>,
        /// Argument passed to the command; repeat the flag for several arguments
        #[arg(long)]
        args: Vec<String>,
        /// Environment variable for the command as KEY=VALUE; repeat the flag for several
        #[arg(long)]
        env: Vec<String>,
        /// URL of an SSE-transport server (used when --command is not given)
        #[arg(long)]
        url: Option<String>,
        /// Add the server in a disabled state
        #[arg(long)]
        disabled: bool,
        /// JSON object mapping server names to configs with the keys command, args, env, url, disabled
        #[arg(long, conflicts_with_all = ["command", "url", "args", "env"])]
        json: Option<String>,
    },
    /// Remove an MCP server from the configuration file
    Remove {
        /// Identifier of the server to remove
        name: String,
    },
    /// List the configured MCP servers
    List,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let cli = Cli::parse();
let config_path = cli.config.unwrap_or_else(config::default_config_path);
let cfg = config::load_config(&config_path)?;
let filter = if cli.verbose {
EnvFilter::new("adk_gateway=debug,tower_http=debug")
} else {
EnvFilter::new("adk_gateway=info,tower_http=info")
};
let telemetry_setup = telemetry::TelemetrySetup::from_config(&cfg.telemetry);
telemetry_setup.init(filter);
tracing::info!(telemetry = %telemetry_setup.describe(), "telemetry initialized");
match cli.command {
None | Some(Commands::Gateway { .. }) => {
let port = match &cli.command {
Some(Commands::Gateway { port, .. }) => port.unwrap_or(cfg.gateway.port),
_ => cfg.gateway.port,
};
let force = matches!(&cli.command, Some(Commands::Gateway { force: true, .. }));
if force {
kill_port_holder(port);
}
tracing::info!(port, "starting adk-gateway");
gateway::run(cfg, port, config_path).await?;
}
Some(Commands::ConfigValidate) => {
tracing::info!("configuration is valid");
println!("{}", serde_json::to_string_pretty(&cfg)?);
}
Some(Commands::ConfigShow) => {
let mut display_cfg = serde_json::to_value(&cfg)?;
if let Some(channels) = display_cfg.get_mut("channels") {
if let Some(tg) = channels.get_mut("telegram") {
if tg.get("bot_token").is_some() {
tg["bot_token"] = serde_json::Value::String("***REDACTED***".into());
}
}
if let Some(slack) = channels.get_mut("slack") {
for key in &["bot_token", "app_token", "signing_secret"] {
if slack.get(*key).is_some() {
slack[*key] = serde_json::Value::String("***REDACTED***".into());
}
}
}
}
if let Some(hooks) = display_cfg.get_mut("hooks") {
if hooks.get("token").is_some() {
hooks["token"] = serde_json::Value::String("***REDACTED***".into());
}
}
if let Some(auth) = display_cfg.get_mut("auth") {
if let Some(sso) = auth.get_mut("sso") {
if sso.get("client_secret").is_some() {
sso["client_secret"] = serde_json::Value::String("***REDACTED***".into());
}
}
}
println!("{}", serde_json::to_string_pretty(&display_cfg)?);
}
Some(Commands::ChannelsStatus { probe }) => {
channel::print_status(&cfg.channels, probe).await;
}
Some(Commands::Memory(mem_cmd)) => {
let kg = knowledge_graph::KnowledgeGraph::new();
match mem_cmd {
MemoryCommands::Search { query, user_id } => {
let results = kg.search_nodes(&user_id, &query);
println!("{}", serde_json::to_string_pretty(&results)?);
}
MemoryCommands::DeleteUser { user_id } => {
let deleted = kg.delete_user_graph(&user_id);
if deleted {
println!("Deleted knowledge graph for user '{user_id}'");
} else {
println!("No knowledge graph found for user '{user_id}'");
}
}
}
}
Some(Commands::Rag(rag_cmd)) => match rag_cmd {
RagCommands::Ingest { path } => {
let rag_config = cfg
.rag
.ok_or_else(|| anyhow::anyhow!("no RAG configuration found in config file"))?;
let pipeline = rag::RagPipelineBuilder::build(&rag_config)?;
let count = pipeline.ingest(&path)?;
println!("Ingested {count} chunks from {}", path.display());
}
RagCommands::Search { query, top_k } => {
let rag_config = cfg
.rag
.ok_or_else(|| anyhow::anyhow!("no RAG configuration found in config file"))?;
let pipeline = rag::RagPipelineBuilder::build(&rag_config)?;
let results = pipeline.search(&query, top_k);
println!("{}", serde_json::to_string_pretty(&results)?);
}
},
Some(Commands::Pairing(pairing_cmd)) => match pairing_cmd {
PairingCommands::GenerateCode => {
let service = pairing::DmPairingService::new();
let code = service.generate_code();
println!("{code}");
}
},
Some(Commands::Mcp(mcp_cmd)) => {
handle_mcp_command(mcp_cmd, &config_path, &cfg)?;
}
Some(Commands::ConfigEncrypt { key_file }) => {
match config_encryption::encrypt_config_file(&config_path, &key_file) {
Ok(count) => {
println!(
"Encrypted {count} sensitive field(s) in {}",
config_path.display()
);
}
Err(e) => {
eprintln!("Error: {e}");
std::process::exit(1);
}
}
}
}
Ok(())
}
fn handle_mcp_command(
cmd: McpCommands,
config_path: &std::path::Path,
cfg: &config::GatewayConfig,
) -> anyhow::Result<()> {
match cmd {
McpCommands::Add {
name,
command,
args,
env,
url,
disabled,
json,
} => {
let mut updated_cfg = cfg.clone();
if let Some(json_str) = json {
let parsed: serde_json::Value = serde_json::from_str(&json_str)
.map_err(|e| anyhow::anyhow!("invalid JSON: {e}"))?;
let obj = parsed
.as_object()
.ok_or_else(|| {
anyhow::anyhow!(
"JSON must be an object mapping server names to configs"
)
})?;
for (server_name, server_val) in obj {
let server_obj = server_val.as_object().ok_or_else(|| {
anyhow::anyhow!("config for '{server_name}' must be an object")
})?;
let transport = if let Some(cmd_val) = server_obj.get("command") {
let cmd_str = cmd_val
.as_str()
.ok_or_else(|| anyhow::anyhow!("'command' must be a string"))?
.to_string();
let args_val: Vec<String> = server_obj
.get("args")
.and_then(|a| a.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
let env_map: std::collections::HashMap<String, String> = server_obj
.get("env")
.and_then(|e| e.as_object())
.map(|obj| {
obj.iter()
.filter_map(|(k, v)| {
v.as_str().map(|s| (k.clone(), s.to_string()))
})
.collect()
})
.unwrap_or_default();
mcp::McpTransport::Stdio {
command: cmd_str,
args: args_val,
env: env_map,
}
} else if let Some(url_val) = server_obj.get("url") {
let url_str = url_val
.as_str()
.ok_or_else(|| anyhow::anyhow!("'url' must be a string"))?
.to_string();
mcp::McpTransport::Sse { url: url_str }
} else {
anyhow::bail!("server '{server_name}' must have 'command' or 'url'");
};
let is_disabled = server_obj
.get("disabled")
.and_then(|d| d.as_bool())
.unwrap_or(false);
let new_server = mcp::McpServerConfig {
server_id: server_name.clone(),
transport,
auth: None,
enabled: !is_disabled,
};
updated_cfg
.mcp_servers
.retain(|s| s.server_id != *server_name);
updated_cfg.mcp_servers.push(new_server);
println!("Added MCP server '{server_name}'");
}
} else {
let name = name.expect("name is required when --json is not used");
let transport = if let Some(cmd_str) = command {
let env_map: std::collections::HashMap<String, String> = env
.iter()
.filter_map(|e| {
let parts: Vec<&str> = e.splitn(2, '=').collect();
if parts.len() == 2 {
Some((parts[0].to_string(), parts[1].to_string()))
} else {
None
}
})
.collect();
mcp::McpTransport::Stdio {
command: cmd_str,
args,
env: env_map,
}
} else if let Some(url_str) = url {
mcp::McpTransport::Sse { url: url_str }
} else {
anyhow::bail!("either --command or --url must be provided");
};
let new_server = mcp::McpServerConfig {
server_id: name.clone(),
transport,
auth: None,
enabled: !disabled,
};
updated_cfg.mcp_servers.retain(|s| s.server_id != name);
updated_cfg.mcp_servers.push(new_server);
println!("Added MCP server '{name}'");
}
let output = serde_json::to_string_pretty(&updated_cfg)?;
std::fs::write(config_path, &output)?;
}
McpCommands::Remove { name } => {
let mut updated_cfg = cfg.clone();
let before = updated_cfg.mcp_servers.len();
updated_cfg.mcp_servers.retain(|s| s.server_id != name);
if updated_cfg.mcp_servers.len() == before {
println!("MCP server '{name}' not found");
} else {
let output = serde_json::to_string_pretty(&updated_cfg)?;
std::fs::write(config_path, &output)?;
println!("Removed MCP server '{name}'");
}
}
McpCommands::List => {
if cfg.mcp_servers.is_empty() {
println!("No MCP servers configured");
} else {
println!("{:<20} {:<10} {:<10}", "SERVER ID", "TRANSPORT", "ENABLED");
println!("{}", "-".repeat(42));
for server in &cfg.mcp_servers {
let transport_type = match &server.transport {
mcp::McpTransport::Stdio { .. } => "stdio",
mcp::McpTransport::Sse { .. } => "sse",
};
let enabled = if server.enabled { "yes" } else { "no" };
println!(
"{:<20} {:<10} {:<10}",
server.server_id, transport_type, enabled
);
}
}
}
}
Ok(())
}
fn kill_port_holder(port: u16) {
let output = std::process::Command::new("lsof")
.args(["-ti", &format!(":{}", port)])
.output();
if let Ok(out) = output {
let pids = String::from_utf8_lossy(&out.stdout);
for pid_str in pids.split_whitespace() {
if pid_str.parse::<u32>().is_ok() {
tracing::info!(pid = pid_str, port, "killing existing process on port");
let _ = std::process::Command::new("kill")
.args(["-9", pid_str])
.output();
}
}
if !pids.trim().is_empty() {
std::thread::sleep(std::time::Duration::from_millis(500));
}
}
}