use adk_studio::{
AppState, FileStorage, api_routes, cleanup_stale_sessions, embedded, start_scheduler,
};
use axum::{Router, extract::Path as AxumPath, routing::get};
use std::net::SocketAddr;
use std::path::PathBuf;
use tower_http::{
cors::{AllowOrigin, Any, CorsLayer},
services::{ServeDir, ServeFile},
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
/// Runtime settings for the ADK Studio server, as resolved from the command line.
///
/// Built by [`parse_args`], starting from [`CliConfig::default`] and overriding
/// one field per recognised flag.
#[derive(Debug, Clone)]
pub struct CliConfig {
    /// TCP port the HTTP server binds to (`--port` / `-p`).
    pub port: u16,
    /// IPv4 address the server binds to, as four octets (`--host` / `-h`).
    pub host: [u8; 4],
    /// Directory handed to `FileStorage::new` for project data (`--dir` / `-d`).
    pub projects_dir: PathBuf,
    /// Optional directory to serve the UI from (`--static` / `-s`).
    /// When `None`, the embedded static assets are served instead.
    pub static_dir: Option<PathBuf>,
}
impl Default for CliConfig {
    /// Returns the settings used when no flags are given: loopback on port
    /// 3000, projects under `<local data dir>/adk-studio/projects`, and no
    /// static directory override.
    fn default() -> Self {
        // If the platform reports no local data directory, fall back to an
        // empty base path, which makes the projects directory relative.
        let data_root = dirs::data_local_dir().unwrap_or_default();
        let projects_dir = data_root.join("adk-studio/projects");

        Self {
            port: 3000,
            host: std::net::Ipv4Addr::LOCALHOST.octets(),
            projects_dir,
            static_dir: None,
        }
    }
}
/// What the program should do after the command line has been parsed.
#[derive(Debug)]
pub enum CliAction {
    /// `--version` / `-V` was present: print the version and exit.
    PrintVersion,
    /// `--help` was present: print the usage text and exit.
    PrintHelp,
    /// Start the server with the given configuration.
    Run(CliConfig),
}
/// Parses command-line arguments (without the program name) into a [`CliAction`].
///
/// `--version` / `-V` anywhere in `args` wins over everything else, followed by
/// `--help`. Otherwise the flags `--port`/`-p`, `--host`/`-h`, `--dir`/`-d` and
/// `--static`/`-s` each consume the following argument as their value and
/// override the corresponding [`CliConfig`] default. Arguments that do not start
/// with `-` are ignored.
///
/// # Errors
///
/// Returns a human-readable message when a flag is missing its value, when the
/// port or host value cannot be parsed, or when an unrecognised flag is seen.
pub fn parse_args(args: &[String]) -> Result<CliAction, String> {
    // Fetches the argument following `flag`, or explains that it is missing.
    fn value_of<'a>(
        rest: &mut std::slice::Iter<'a, String>,
        flag: &str,
    ) -> Result<&'a String, String> {
        rest.next()
            .ok_or_else(|| format!("flag '{flag}' requires a value"))
    }

    // Informational flags short-circuit regardless of their position.
    if args
        .iter()
        .any(|a| matches!(a.as_str(), "--version" | "-V"))
    {
        return Ok(CliAction::PrintVersion);
    }
    if args.iter().any(|a| a.as_str() == "--help") {
        return Ok(CliAction::PrintHelp);
    }

    let mut config = CliConfig::default();
    let mut rest = args.iter();

    // `while let` rather than `for`: value-taking flags advance the iterator
    // a second time from inside the loop body.
    while let Some(arg) = rest.next() {
        match arg.as_str() {
            "--port" | "-p" => {
                let raw = value_of(&mut rest, arg)?;
                config.port = raw
                    .parse::<u16>()
                    .map_err(|_| format!("invalid port value '{raw}'"))?;
            }
            "--host" | "-h" => {
                let raw = value_of(&mut rest, arg)?;
                config.host = parse_host(raw)?;
            }
            "--dir" | "-d" => {
                let raw = value_of(&mut rest, arg)?;
                config.projects_dir = PathBuf::from(raw);
            }
            "--static" | "-s" => {
                let raw = value_of(&mut rest, arg)?;
                config.static_dir = Some(PathBuf::from(raw));
            }
            unknown if unknown.starts_with('-') => {
                return Err(format!(
                    "unknown flag '{unknown}'\nRun 'adk-studio --help' for usage information."
                ));
            }
            // Bare positional arguments are skipped.
            _ => {}
        }
    }

    Ok(CliAction::Run(config))
}
/// Parses a dotted-quad IPv4 address such as `127.0.0.1` into its four octets.
///
/// Each of the four dot-separated parts must consist solely of ASCII digits and
/// fit in a `u8`.
///
/// # Errors
///
/// Returns a message naming the offending input when `s` does not have exactly
/// four parts, or when any part is not a valid octet.
fn parse_host(s: &str) -> Result<[u8; 4], String> {
    let parts: Vec<&str> = s.split('.').collect();
    if parts.len() != 4 {
        return Err(format!("invalid host '{s}': expected format a.b.c.d"));
    }

    let mut octets = [0u8; 4];
    for (slot, part) in octets.iter_mut().zip(parts) {
        // `u8::from_str` accepts a leading '+', which is not valid in an IP
        // address, so insist on digits only before parsing. An empty part
        // passes the digit check but is then rejected by `parse`.
        let parsed = if part.bytes().all(|b| b.is_ascii_digit()) {
            part.parse::<u8>().ok()
        } else {
            None
        };
        *slot = parsed
            .ok_or_else(|| format!("invalid host '{s}': '{part}' is not a valid octet"))?;
    }
    Ok(octets)
}
/// Writes `adk-studio <version>` to stdout, using the crate version baked in at
/// compile time.
pub fn print_version() {
    let version = env!("CARGO_PKG_VERSION");
    println!("adk-studio {version}");
}
/// Writes the usage text to stdout: a version banner followed by the list of
/// supported options.
pub fn print_help() {
    // Everything after the version banner, one entry per output line.
    const BODY: &[&str] = &[
        "",
        "Visual development environment for ADK-Rust agents",
        "",
        "USAGE:",
        " adk-studio [OPTIONS]",
        "",
        "OPTIONS:",
        " -p, --port <PORT> Server port [default: 3000]",
        " -h, --host <HOST> Server host [default: 127.0.0.1]",
        " -d, --dir <DIR> Projects directory [default: user-local]",
        " Use --dir ./.adk-studio/projects for repo-local storage",
        " -s, --static <DIR> Static files directory",
        " -V, --version Print version and exit",
        " --help Print this help message and exit",
    ];

    println!("adk-studio {}", env!("CARGO_PKG_VERSION"));
    for line in BODY {
        println!("{line}");
    }
}
/// Handler for `/*path`: passes the captured path to `embedded::serve_embedded`
/// and returns its response unchanged.
async fn serve_static(requested: AxumPath<String>) -> axum::response::Response {
    let AxumPath(asset_path) = requested;
    embedded::serve_embedded(asset_path)
}
/// Handler for `/`: returns whatever `embedded::serve_index` produces
/// (presumably the embedded UI's index page — confirm in `embedded`).
async fn serve_root() -> axum::response::Response {
    embedded::serve_index()
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args: Vec<String> = std::env::args().skip(1).collect();
let config = match parse_args(&args) {
Ok(CliAction::PrintVersion) => {
print_version();
return Ok(());
}
Ok(CliAction::PrintHelp) => {
print_help();
return Ok(());
}
Err(msg) => {
eprintln!("error: {msg}");
eprintln!("Run 'adk-studio --help' for usage information.");
std::process::exit(1);
}
Ok(CliAction::Run(config)) => config,
};
dotenvy::dotenv().ok();
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let port = config.port;
let host = config.host;
let projects_dir = config.projects_dir;
let static_dir = config.static_dir;
let storage = FileStorage::new(projects_dir.clone()).await?;
let state = AppState::new(storage);
let scheduler_state = state.clone();
tokio::spawn(async move {
start_scheduler(scheduler_state).await;
});
tokio::spawn(async {
let cleanup_interval = std::time::Duration::from_secs(600);
let max_session_age = std::time::Duration::from_secs(3600);
loop {
tokio::time::sleep(cleanup_interval).await;
cleanup_stale_sessions(max_session_age).await;
}
});
let cors = CorsLayer::new()
.allow_origin(AllowOrigin::predicate(|origin, _| {
let origin_str = origin.to_str().unwrap_or("");
if let Some(rest) = origin_str
.strip_prefix("http://")
.or_else(|| origin_str.strip_prefix("https://"))
{
const ALLOWED_HOSTS: &[&str] = &["localhost", "127.0.0.1", "[::1]"];
ALLOWED_HOSTS
.iter()
.any(|host| rest == *host || rest.starts_with(&format!("{}:", host)))
} else {
false
}
}))
.allow_methods(Any)
.allow_headers(Any);
let mut app = Router::new()
.nest("/api", api_routes())
.layer(cors)
.with_state(state);
if let Some(dir) = static_dir {
let index = dir.join("index.html");
app = app.fallback_service(ServeDir::new(&dir).fallback(ServeFile::new(index)));
tracing::info!("📂 Serving static files from: {}", dir.display());
} else {
let static_router = Router::new()
.route("/", get(serve_root))
.route("/*path", get(serve_static));
app = app.merge(static_router);
tracing::info!("📦 Serving embedded static files");
}
let addr = SocketAddr::from((host, port));
tracing::info!("🚀 ADK Studio starting on http://{}", addr);
tracing::info!("📁 Projects directory: {}", projects_dir.display());
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
Ok(())
}