use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use clap::{Parser, Subcommand};
use futures::StreamExt;
use tracing::info;
use crate::agents::BaseAgent;
use crate::runner::Runner;
use crate::services::mem::InMemorySessionService;
use crate::telemetry::{LogFormat, TelemetryConfig};
// Top-level command line of the `adk` binary.
//
// `//` rather than `///` on purpose: clap derive turns doc comments into help
// text, and the explicit `about` below is the canonical one-line summary.
#[derive(Debug, Parser)]
#[command(name = "adk", version, about = "Agent Development Kit CLI")]
pub struct Cli {
    /// Log filter passed to telemetry initialization.
    #[arg(long, env = "ADK_LOG")]
    pub log: Option<String>,
    /// Log output format.
    #[arg(long, default_value = "compact")]
    pub log_format: LogFormatArg,
    // The subcommand to execute (documented per variant on `Command`).
    #[command(subcommand)]
    pub command: Command,
}
/// CLI-facing choice of log output format, selected with `--log-format`.
///
/// Converted into the telemetry layer's `LogFormat` via `From`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum LogFormatArg {
    /// Compact log output.
    Compact,
    /// Pretty, human-oriented log output.
    Pretty,
    /// JSON-formatted log output.
    Json,
}
impl From<LogFormatArg> for LogFormat {
    /// Translates the `--log-format` choice into the telemetry layer's format.
    ///
    /// The mapping is one-to-one: each CLI variant has a same-named counterpart.
    fn from(choice: LogFormatArg) -> Self {
        match choice {
            LogFormatArg::Json => Self::Json,
            LogFormatArg::Pretty => Self::Pretty,
            LogFormatArg::Compact => Self::Compact,
        }
    }
}
// Subcommands of the `adk` CLI. Each variant's `///` comment is its help text.
#[derive(Debug, Subcommand)]
pub enum Command {
    /// Run an agent on a single message and print its text output.
    Run {
        /// Name of the registered agent to run.
        #[arg(long)]
        agent: String,
        /// User id to run the agent as.
        #[arg(long, default_value = "anonymous")]
        user: String,
        /// Session id to use (optional).
        #[arg(long)]
        session: Option<String>,
        /// Message to send to the agent.
        message: String,
    },
    /// Serve all registered agents from the development web server.
    Web {
        /// Address and port to listen on.
        #[arg(long, default_value = "127.0.0.1:8000")]
        bind: SocketAddr,
        /// Bearer token clients must present; when omitted the server runs unauthenticated.
        #[arg(long, env = "ADK_WEB_TOKEN")]
        auth_token: Option<String>,
        /// Permit serving remote clients without authentication (dangerous).
        #[arg(long)]
        dangerously_allow_unauthenticated_remote: bool,
        /// Origin to add to the server's allow-list; may be repeated.
        #[arg(long = "allow-origins")]
        allow_origins: Vec<String>,
    },
    /// Evaluate an agent against an eval set and print the JSON report.
    Eval {
        /// Path to the eval set JSON file.
        #[arg(long)]
        set: std::path::PathBuf,
        /// Name of the registered agent to evaluate.
        #[arg(long)]
        agent: String,
    },
    /// Print version information.
    Version,
}
/// A named collection of agents exposed through the `adk` command line.
///
/// Create one with [`App::new`], add agents with [`App::register`], then call
/// [`App::run`] to parse the command line and dispatch the chosen subcommand.
pub struct App {
    /// Registered agents, keyed by the name used on the command line (`--agent`).
    agents: HashMap<String, Arc<dyn BaseAgent>>,
    /// Application name handed to every runner and to the eval runner.
    name: String,
}
impl std::fmt::Debug for App {
    /// Shows the app name and the names of the registered agents.
    ///
    /// Only agent names are listed, not the agent objects. The names are sorted
    /// so the output is deterministic despite `HashMap`'s randomized iteration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut agent_names: Vec<&String> = self.agents.keys().collect();
        agent_names.sort_unstable();
        f.debug_struct("App")
            .field("name", &self.name)
            .field("agents", &agent_names)
            .finish()
    }
}
impl App {
    /// Creates an app named `name` with no agents registered.
    ///
    /// The name becomes the app name of every [`Runner`] the CLI builds and is
    /// also passed to the eval runner.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            agents: HashMap::new(),
        }
    }

    /// Registers `agent` under `name` and returns the app for chaining.
    ///
    /// If an agent is already registered under `name`, it is replaced.
    #[must_use]
    pub fn register(mut self, name: impl Into<String>, agent: Arc<dyn BaseAgent>) -> Self {
        self.agents.insert(name.into(), agent);
        self
    }

    /// Parses the process command line, initializes telemetry and runs the
    /// selected subcommand to completion on a new multi-threaded Tokio runtime.
    ///
    /// Invalid arguments, `--help` and `--version` are handled by clap, which
    /// prints a message and exits the process instead of returning.
    ///
    /// # Errors
    ///
    /// Returns an error if telemetry initialization fails, the Tokio runtime
    /// cannot be built, or the subcommand fails (see [`App::run_async`]).
    pub fn run(self) -> crate::error::Result<()> {
        let cli = Cli::parse();
        crate::telemetry::init(TelemetryConfig {
            filter: cli.log,
            format: cli.log_format.into(),
            ..TelemetryConfig::default()
        })?;
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .map_err(|e| crate::error::Error::other(format!("tokio: {e}")))?;
        rt.block_on(self.run_async(cli.command))
    }

    /// Executes one already-parsed subcommand.
    ///
    /// - `Run`: runs the named agent on `message` and prints each non-empty text
    ///   payload from the event stream to stdout, one per line.
    /// - `Web`: builds one runner per registered agent and serves them all.
    /// - `Eval`: reads an eval set from a JSON file, evaluates the named agent and
    ///   prints the report as pretty-printed JSON.
    /// - `Version`: prints the crate version.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the requested agent is not registered or a runner cannot be built;
    /// - the agent's event stream yields an error;
    /// - the web auth token is blank, or the server fails;
    /// - the eval set cannot be read or parsed, evaluation fails, or the report
    ///   cannot be serialized.
    pub async fn run_async(self, cmd: Command) -> crate::error::Result<()> {
        match cmd {
            Command::Run {
                agent,
                user,
                session,
                message,
            } => {
                let runner = self.build_runner(&agent)?;
                let mut events = runner.run(&user, session.as_deref(), &message).await?;
                while let Some(event) = events.next().await {
                    let event = event?;
                    if let Some(content) = event.response.content.as_ref() {
                        let text = content.text_concat();
                        if !text.is_empty() {
                            // Printing the agent's reply is this subcommand's output.
                            #[allow(clippy::print_stdout)]
                            {
                                println!("{text}");
                            }
                        }
                    }
                }
                Ok(())
            }
            Command::Web {
                bind,
                auth_token,
                dangerously_allow_unauthenticated_remote,
                allow_origins,
            } => {
                // A blank token (e.g. `ADK_WEB_TOKEN=` left empty in the environment)
                // is almost certainly a misconfiguration. Refuse it rather than start
                // a server whose token check could accept an empty credential.
                if matches!(auth_token.as_deref(), Some(t) if t.trim().is_empty()) {
                    return Err(crate::error::Error::other(
                        "auth token (--auth-token / ADK_WEB_TOKEN) is set but empty".to_string(),
                    ));
                }
                let mut runners = HashMap::new();
                for (name, agent) in &self.agents {
                    runners.insert(name.clone(), Arc::new(self.runner_for(Arc::clone(agent))?));
                }
                let runners = Arc::new(runners);
                let state = match auth_token {
                    Some(t) => crate::server::AppState::with_bearer_token(runners, t),
                    None => crate::server::AppState::unauthenticated(runners),
                }
                .with_allow_origins(allow_origins);
                info!("starting dev server on http://{bind}");
                crate::server::serve_with(
                    bind,
                    state,
                    crate::server::ServeOptions {
                        dangerously_allow_unauthenticated_remote,
                    },
                )
                .await
            }
            Command::Eval { set, agent } => {
                let bytes = tokio::fs::read(set).await?;
                let set: crate::eval::EvalSet = serde_json::from_slice(&bytes)?;
                let agent = self.find_agent(&agent)?;
                let runner = crate::eval::EvalRunner::new(
                    agent,
                    self.name.clone(),
                    "eval-user",
                    vec![
                        Arc::new(crate::eval::TrajectoryMatch::new(1.0)),
                        Arc::new(crate::eval::ResponseMatch::new(0.5)),
                    ],
                );
                let report = runner.run_set(&set).await?;
                // Propagate serialization failures. Printing an empty line and
                // returning Ok would look like a successful (empty) report.
                let json = serde_json::to_string_pretty(&report)?;
                // The report on stdout is this subcommand's output.
                #[allow(clippy::print_stdout)]
                {
                    println!("{json}");
                }
                Ok(())
            }
            Command::Version => {
                // The version string on stdout is this subcommand's output.
                #[allow(clippy::print_stdout)]
                {
                    println!("adk-rs {}", env!("CARGO_PKG_VERSION"));
                }
                Ok(())
            }
        }
    }

    /// Looks up a registered agent by name.
    ///
    /// # Errors
    ///
    /// Returns a not-found error naming the agent if nothing is registered under `name`.
    fn find_agent(&self, name: &str) -> crate::error::Result<Arc<dyn BaseAgent>> {
        self.agents
            .get(name)
            .cloned()
            .ok_or_else(|| crate::error::Error::not_found(format!("agent {name}")))
    }

    /// Builds a runner for the agent registered under `agent_name`.
    ///
    /// # Errors
    ///
    /// Fails if the agent is not registered or the runner cannot be built.
    fn build_runner(&self, agent_name: &str) -> crate::error::Result<Runner> {
        let agent = self.find_agent(agent_name)?;
        self.runner_for(agent)
    }

    /// Builds a runner for `agent` that uses its own fresh
    /// [`InMemorySessionService`] and has automatic session creation enabled.
    ///
    /// Each call creates a separate session service, so runners built here do
    /// not share sessions with each other.
    fn runner_for(&self, agent: Arc<dyn BaseAgent>) -> crate::error::Result<Runner> {
        Runner::builder()
            .app_name(self.name.clone())
            .agent(agent)
            .session_service(Arc::new(InMemorySessionService::new()))
            .auto_create_session(true)
            .build()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::LlmAgent;
    use crate::core::Model;
    use crate::core::testing::MockModel;

    /// Builds an app with a single `greet` agent whose mock model replies "yo".
    fn greet_app() -> App {
        let model = Arc::new(MockModel::new("m"));
        model.push_text("yo");
        let agent: Arc<dyn BaseAgent> = Arc::new(
            LlmAgent::builder("greet")
                .model(model as Arc<dyn Model>)
                .instruction("greet")
                .build()
                .unwrap(),
        );
        App::new("hello").register("greet", agent)
    }

    /// Builds a `Run` command targeting `agent` with a fixed user and message.
    fn run_cmd(agent: &str) -> Command {
        Command::Run {
            agent: agent.into(),
            user: "u".into(),
            session: None,
            message: "hi".into(),
        }
    }

    #[tokio::test]
    async fn run_command_prints_text() {
        greet_app().run_async(run_cmd("greet")).await.unwrap();
    }

    #[tokio::test]
    async fn run_command_with_unknown_agent_fails() {
        let result = greet_app().run_async(run_cmd("missing")).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn eval_command_with_missing_set_file_fails() {
        let result = greet_app()
            .run_async(Command::Eval {
                set: std::path::PathBuf::from("definitely/not/a/real/eval_set.json"),
                agent: "greet".into(),
            })
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn version_command_succeeds() {
        App::new("hello").run_async(Command::Version).await.unwrap();
    }
}