use std::io::Write as _;
use std::sync::Arc;
use agent_sdk::{
AgentConfig, AgentEvent, AgentEventEnvelope, AgentInput, CancellationToken, EventStore,
InMemoryEventStore, LlmProvider, ThreadId, ToolContext, builder,
providers::{
AnthropicProvider, CloudflareAIGatewayProvider, GeminiProvider, OpenAIProvider,
VertexProvider,
},
};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use clap::Args as ClapArgs;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
/// Default `--model` for `--provider anthropic`; an alias that `anthropic_provider` resolves.
const DEFAULT_ANTHROPIC_MODEL: &str = "sonnet";
/// Default model for `--provider openai`, also used for the Cloudflare `openai` upstream.
const DEFAULT_OPENAI_MODEL: &str = "gpt-5.4";
/// Default model for `--provider gemini` and `--provider vertex`, also used for the
/// Cloudflare `gemini` upstream.
const DEFAULT_GEMINI_MODEL: &str = "gemini-3-flash-preview";
/// Vertex region used when neither `--gcp-region` nor `VERTEX_REGION` /
/// `GOOGLE_CLOUD_REGION` supplies one.
const DEFAULT_VERTEX_REGION: &str = "global";
// LLM backend chosen with `--provider` (clap derives lowercase value names, e.g. `openai`).
// `//` rather than `///`: on a clap `ValueEnum`, variant doc comments become
// possible-value help text.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, clap::ValueEnum)]
pub enum Provider {
    // Direct Anthropic API; requires ANTHROPIC_API_KEY. `#[default]` makes this
    // `Provider::default()`.
    #[default]
    Anthropic,
    // Direct OpenAI API; requires OPENAI_API_KEY.
    Openai,
    // Direct Gemini API; requires GEMINI_API_KEY or GOOGLE_API_KEY.
    Gemini,
    // Google Vertex AI; requires VERTEX_ACCESS_TOKEN plus a GCP project id.
    Vertex,
    // Cloudflare AI Gateway in front of the upstream chosen by `--cf-upstream`.
    Cloudflare,
}
// Upstream API reached through the Cloudflare AI Gateway (`--cf-upstream`).
// `//` rather than `///`: variant doc comments would become clap help text.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, clap::ValueEnum)]
pub enum CloudflareUpstream {
    // Anthropic upstream; pass-through key ANTHROPIC_API_KEY, default model
    // `claude-sonnet-4-6`.
    #[default]
    Anthropic,
    // OpenAI upstream; pass-through key OPENAI_API_KEY, default DEFAULT_OPENAI_MODEL.
    Openai,
    // Gemini upstream; pass-through key GEMINI_API_KEY or GOOGLE_API_KEY,
    // default DEFAULT_GEMINI_MODEL.
    Gemini,
}
// Provider-selection flags flattened into both the `run` and `chat` commands.
// The field doc comments below are rendered by clap as `--help` text.
#[derive(ClapArgs, Debug, Clone)]
pub struct ProviderArgs {
    /// LLM provider backend to use
    #[arg(long, value_enum, default_value_t = Provider::Anthropic)]
    pub provider: Provider,
    /// Model id; the default depends on the provider. With --provider anthropic
    /// the aliases haiku, sonnet, opus and fable are also accepted
    #[arg(long)]
    pub model: Option<String>,
    /// Upstream API behind the Cloudflare AI Gateway (only used with --provider cloudflare)
    #[arg(long, value_enum, default_value_t = CloudflareUpstream::Anthropic)]
    pub cf_upstream: CloudflareUpstream,
    /// Cloudflare account id [fallback: CLOUDFLARE_ACCOUNT_ID]
    #[arg(long)]
    pub cf_account_id: Option<String>,
    /// Cloudflare AI Gateway id [fallback: CLOUDFLARE_GATEWAY_ID]
    #[arg(long)]
    pub cf_gateway_id: Option<String>,
    /// Google Cloud project for Vertex [fallback: GOOGLE_CLOUD_PROJECT, GCP_PROJECT]
    #[arg(long)]
    pub gcp_project: Option<String>,
    /// Vertex region [fallback: VERTEX_REGION, GOOGLE_CLOUD_REGION, then global]
    #[arg(long)]
    pub gcp_region: Option<String>,
}
// Arguments for the one-shot `run` command; field doc comments become `--help` text.
#[derive(ClapArgs, Debug)]
pub struct RunArgs {
    /// Prompt to send to the agent
    pub prompt: String,
    /// System prompt given to the agent
    #[arg(long, default_value = "You are a helpful assistant.")]
    pub system: String,
    // Provider selection and credential flags.
    #[command(flatten)]
    pub provider: ProviderArgs,
}
// Arguments for the interactive `chat` command; field doc comments become `--help` text.
#[derive(ClapArgs, Debug)]
pub struct ChatArgs {
    /// System prompt given to the agent for the whole session
    #[arg(long, default_value = "You are a helpful assistant.")]
    pub system: String,
    // Provider selection and credential flags.
    #[command(flatten)]
    pub provider: ProviderArgs,
}
/// Reads a mandatory environment variable, returning its value untrimmed.
///
/// `hint` is appended to every error message so the user knows how to fix it.
///
/// # Errors
/// Fails when `name` is unset or not valid Unicode (the underlying
/// `VarError` is kept as the error source), or when the value is empty or
/// whitespace-only.
fn require_env(name: &str, hint: &str) -> Result<String> {
    match std::env::var(name) {
        Ok(value) if !value.trim().is_empty() => Ok(value),
        Ok(_) => bail!("{name} is set but empty; {hint}"),
        Err(source) => Err(source).with_context(|| format!("{name} is not set; {hint}")),
    }
}
/// Returns the value of the first variable in `names` that is set to a
/// non-blank value, checking them in order.
///
/// Unset, non-Unicode, and whitespace-only variables are skipped. The
/// returned value is not trimmed.
fn first_env(names: &[&str]) -> Option<String> {
    for candidate in names {
        if let Ok(found) = std::env::var(candidate) {
            if !found.trim().is_empty() {
                return Some(found);
            }
        }
    }
    None
}
/// Builds an [`AnthropicProvider`] from either a short alias or a full model id.
///
/// The aliases `haiku`, `sonnet`, `opus`, and `fable` use the SDK's named
/// constructors; any other string is forwarded verbatim as the model id.
fn anthropic_provider(api_key: String, model: &str) -> AnthropicProvider {
    if model == "haiku" {
        AnthropicProvider::haiku(api_key)
    } else if model == "sonnet" {
        AnthropicProvider::sonnet(api_key)
    } else if model == "opus" {
        AnthropicProvider::opus(api_key)
    } else if model == "fable" {
        AnthropicProvider::fable(api_key)
    } else {
        AnthropicProvider::new(api_key, model.to_owned())
    }
}
/// Resolves a setting from an explicit CLI flag, falling back to env vars.
///
/// A blank flag is treated as absent. The env fallback takes the first
/// non-blank variable in `env_names`.
///
/// # Errors
/// Returns an error whose message is exactly `missing` when neither source
/// yields a value.
fn resolve_with_env(flag: Option<&str>, env_names: &[&str], missing: &str) -> Result<String> {
    let explicit = flag.filter(|raw| !raw.trim().is_empty());
    match explicit {
        Some(value) => Ok(value.to_owned()),
        None => first_env(env_names).with_context(|| missing.to_owned()),
    }
}
/// Builds a [`CloudflareAIGatewayProvider`] for the upstream chosen by `--cf-upstream`.
///
/// Account and gateway ids come from `--cf-account-id` / `--cf-gateway-id`,
/// falling back to `CLOUDFLARE_ACCOUNT_ID` / `CLOUDFLARE_GATEWAY_ID`.
///
/// Credentials come from one of two modes, and at least one must be present:
/// - pass-through: the upstream's own API key;
/// - BYOK: a gateway token from `CLOUDFLARE_GATEWAY_TOKEN` or `CLOUDFLARE_AIG_TOKEN`.
///
/// A gateway token, when found, is attached to the returned provider.
///
/// # Errors
/// Fails when the account id or gateway id cannot be resolved, or when
/// neither an API key nor a gateway token is available.
fn build_cloudflare(args: &ProviderArgs) -> Result<CloudflareAIGatewayProvider> {
    // This function does no alias resolution (unlike `anthropic_provider`), so
    // the Anthropic default must be a full model id rather than `sonnet`.
    const DEFAULT_CLOUDFLARE_ANTHROPIC_MODEL: &str = "claude-sonnet-4-6";

    let account_id = resolve_with_env(
        args.cf_account_id.as_deref(),
        &["CLOUDFLARE_ACCOUNT_ID"],
        "Cloudflare account id missing; pass --cf-account-id or set CLOUDFLARE_ACCOUNT_ID",
    )?;
    let gateway_id = resolve_with_env(
        args.cf_gateway_id.as_deref(),
        &["CLOUDFLARE_GATEWAY_ID"],
        "Cloudflare gateway id missing; pass --cf-gateway-id or set CLOUDFLARE_GATEWAY_ID",
    )?;
    let gateway_token = first_env(&["CLOUDFLARE_GATEWAY_TOKEN", "CLOUDFLARE_AIG_TOKEN"]);

    // Per-upstream settings, kept in one table so the upstreams cannot drift apart:
    // - the fallback model;
    // - the API-key env vars (first non-blank wins);
    // - the label used in the missing-credentials error.
    let (default_model, key_envs, key_label): (&str, &[&str], &str) = match args.cf_upstream {
        CloudflareUpstream::Anthropic => (
            DEFAULT_CLOUDFLARE_ANTHROPIC_MODEL,
            &["ANTHROPIC_API_KEY"],
            "ANTHROPIC_API_KEY",
        ),
        CloudflareUpstream::Openai => {
            (DEFAULT_OPENAI_MODEL, &["OPENAI_API_KEY"], "OPENAI_API_KEY")
        }
        CloudflareUpstream::Gemini => (
            DEFAULT_GEMINI_MODEL,
            &["GEMINI_API_KEY", "GOOGLE_API_KEY"],
            "GEMINI_API_KEY/GOOGLE_API_KEY",
        ),
    };

    let model = args
        .model
        .clone()
        .unwrap_or_else(|| default_model.to_owned());
    // An empty key is allowed through when a gateway token provides auth (BYOK).
    let api_key = first_env(key_envs).unwrap_or_default();
    ensure_cf_auth(&api_key, gateway_token.as_deref(), key_label)?;

    let base = match args.cf_upstream {
        CloudflareUpstream::Anthropic => {
            CloudflareAIGatewayProvider::anthropic(api_key, &account_id, &gateway_id, model)
        }
        CloudflareUpstream::Openai => {
            CloudflareAIGatewayProvider::openai(api_key, &account_id, &gateway_id, model)
        }
        CloudflareUpstream::Gemini => {
            CloudflareAIGatewayProvider::gemini(api_key, &account_id, &gateway_id, model)
        }
    };
    Ok(apply_gateway_token(base, gateway_token.as_deref()))
}
/// Checks that a Cloudflare gateway request has some form of credentials.
///
/// Passes when `api_key` is non-blank (pass-through mode) or a gateway token
/// is present (BYOK mode).
///
/// # Errors
/// Fails when both are missing; `key_env` names the API-key variable(s)
/// in the error message.
fn ensure_cf_auth(api_key: &str, gateway_token: Option<&str>, key_env: &str) -> Result<()> {
    let has_api_key = !api_key.trim().is_empty();
    if has_api_key || gateway_token.is_some() {
        return Ok(());
    }
    bail!(
        "Cloudflare requires credentials: set {key_env} for pass-through mode, \
         or CLOUDFLARE_GATEWAY_TOKEN for BYOK mode"
    )
}
/// Attaches `token` to `provider` via `with_gateway_token` when one is given.
/// Otherwise returns the provider unchanged.
fn apply_gateway_token(
    provider: CloudflareAIGatewayProvider,
    token: Option<&str>,
) -> CloudflareAIGatewayProvider {
    if let Some(value) = token {
        provider.with_gateway_token(value)
    } else {
        provider
    }
}
/// Builds a [`VertexProvider`] from CLI flags and environment variables.
///
/// - Access token: `VERTEX_ACCESS_TOKEN` (required).
/// - Project: `--gcp-project`, else `GOOGLE_CLOUD_PROJECT` / `GCP_PROJECT` (required).
/// - Region: `--gcp-region`, else `VERTEX_REGION` / `GOOGLE_CLOUD_REGION`,
///   else [`DEFAULT_VERTEX_REGION`].
/// - Model: `--model`, else [`DEFAULT_GEMINI_MODEL`].
///
/// # Errors
/// Fails when the access token is missing or blank, or when no project id
/// can be resolved.
fn build_vertex(args: &ProviderArgs) -> Result<VertexProvider> {
    let access_token = require_env(
        "VERTEX_ACCESS_TOKEN",
        "export an OAuth2 access token (e.g. `gcloud auth print-access-token`)",
    )?;
    let project_id = resolve_with_env(
        args.gcp_project.as_deref(),
        &["GOOGLE_CLOUD_PROJECT", "GCP_PROJECT"],
        "Vertex project id missing; pass --gcp-project or set GOOGLE_CLOUD_PROJECT",
    )?;
    // Treat a blank `--gcp-region` as absent, just as `resolve_with_env` does for
    // `--gcp-project`. Otherwise an empty region would be passed to
    // `VertexProvider::new`.
    let region = args
        .gcp_region
        .as_deref()
        .filter(|region| !region.trim().is_empty())
        .map(str::to_owned)
        .or_else(|| first_env(&["VERTEX_REGION", "GOOGLE_CLOUD_REGION"]))
        .unwrap_or_else(|| DEFAULT_VERTEX_REGION.to_owned());
    let model = args
        .model
        .clone()
        .unwrap_or_else(|| DEFAULT_GEMINI_MODEL.to_owned());
    Ok(VertexProvider::new(access_token, project_id, region, model))
}
/// [`EventStore`] that echoes streamed agent output to the terminal.
/// All storage is delegated to an [`InMemoryEventStore`].
struct StreamToStdout {
    /// Store that actually keeps the events; every trait call is forwarded here.
    inner: Arc<InMemoryEventStore>,
}

impl StreamToStdout {
    /// Creates an echoing store backed by a newly constructed in-memory store.
    fn new() -> Self {
        let backing = InMemoryEventStore::new();
        StreamToStdout {
            inner: Arc::new(backing),
        }
    }
}
#[async_trait]
impl EventStore for StreamToStdout {
    /// Echoes user-visible events to the terminal, then records the envelope
    /// in the inner store.
    ///
    /// `TextDelta` text goes to stdout and is flushed immediately so tokens
    /// appear as they stream. `Error` messages go to stderr.
    ///
    /// Terminal output is best-effort: write and flush errors are ignored.
    /// Using `print!`/`eprintln!` here would panic on a failed write (e.g.
    /// stdout piped into `head` and closed early) and abort the run.
    async fn append(
        &self,
        thread_id: &ThreadId,
        turn: usize,
        envelope: AgentEventEnvelope,
    ) -> Result<()> {
        match &envelope.event {
            AgentEvent::TextDelta { delta, .. } => {
                let mut stdout = std::io::stdout();
                let _ = write!(stdout, "{delta}");
                let _ = stdout.flush();
            }
            AgentEvent::Error { message, .. } => {
                let _ = writeln!(std::io::stderr(), "\nerror: {message}");
            }
            _ => {}
        }
        self.inner.append(thread_id, turn, envelope).await
    }

    /// Forwards to the inner store.
    async fn finish_turn(&self, thread_id: &ThreadId, turn: usize) -> Result<()> {
        self.inner.finish_turn(thread_id, turn).await
    }

    /// Forwards to the inner store.
    async fn get_turn(
        &self,
        thread_id: &ThreadId,
        turn: usize,
    ) -> Result<Option<agent_sdk::StoredTurnEvents>> {
        self.inner.get_turn(thread_id, turn).await
    }

    /// Forwards to the inner store.
    async fn get_turns(&self, thread_id: &ThreadId) -> Result<Vec<agent_sdk::StoredTurnEvents>> {
        self.inner.get_turns(thread_id).await
    }

    /// Forwards to the inner store.
    async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
        self.inner.clear(thread_id).await
    }
}
/// Entry point for the one-shot `run` command.
///
/// Creates a Tokio runtime and drives `run_async` to completion on it.
///
/// # Errors
/// Fails if the runtime cannot be created or the agent run fails.
pub fn run(args: RunArgs) -> Result<()> {
    tokio::runtime::Runtime::new()
        .context("failed to start async runtime")?
        .block_on(run_async(args))
}
/// Entry point for the interactive `chat` command.
///
/// Creates a Tokio runtime and drives `chat_async` to completion on it.
///
/// # Errors
/// Fails if the runtime cannot be created or the chat session errors out.
pub fn chat(args: ChatArgs) -> Result<()> {
    tokio::runtime::Runtime::new()
        .context("failed to start async runtime")?
        .block_on(chat_async(args))
}
/// Builds the provider selected on the command line and runs one prompt with it.
///
/// Credentials are resolved (and missing ones reported) before `run_with`
/// is invoked.
async fn run_async(args: RunArgs) -> Result<()> {
    let RunArgs {
        prompt,
        system,
        provider: cli,
    } = args;
    let requested_model = cli.model.as_deref();

    match cli.provider {
        Provider::Anthropic => {
            let api_key = require_env(
                "ANTHROPIC_API_KEY",
                "export your Anthropic API key to run an agent",
            )?;
            let llm = anthropic_provider(api_key, requested_model.unwrap_or(DEFAULT_ANTHROPIC_MODEL));
            run_with(llm, prompt, system).await
        }
        Provider::Openai => {
            let api_key = require_env(
                "OPENAI_API_KEY",
                "export your OpenAI API key to run an agent",
            )?;
            let model = requested_model.unwrap_or(DEFAULT_OPENAI_MODEL).to_owned();
            run_with(OpenAIProvider::new(api_key, model), prompt, system).await
        }
        Provider::Gemini => {
            let api_key = require_gemini_key()?;
            let model = requested_model.unwrap_or(DEFAULT_GEMINI_MODEL).to_owned();
            run_with(GeminiProvider::new(api_key, model), prompt, system).await
        }
        Provider::Vertex => {
            let llm = build_vertex(&cli)?;
            run_with(llm, prompt, system).await
        }
        Provider::Cloudflare => {
            let llm = build_cloudflare(&cli)?;
            run_with(llm, prompt, system).await
        }
    }
}
/// Builds the provider selected on the command line and starts an interactive chat.
///
/// Credentials are resolved (and missing ones reported) before the chat
/// loop begins.
async fn chat_async(args: ChatArgs) -> Result<()> {
    let ChatArgs {
        system,
        provider: cli,
    } = args;
    let requested_model = cli.model.as_deref();

    match cli.provider {
        Provider::Anthropic => {
            let api_key = require_env(
                "ANTHROPIC_API_KEY",
                "export your Anthropic API key to run an agent",
            )?;
            let llm = anthropic_provider(api_key, requested_model.unwrap_or(DEFAULT_ANTHROPIC_MODEL));
            chat_with(llm, system).await
        }
        Provider::Openai => {
            let api_key = require_env(
                "OPENAI_API_KEY",
                "export your OpenAI API key to run an agent",
            )?;
            let model = requested_model.unwrap_or(DEFAULT_OPENAI_MODEL).to_owned();
            chat_with(OpenAIProvider::new(api_key, model), system).await
        }
        Provider::Gemini => {
            let api_key = require_gemini_key()?;
            let model = requested_model.unwrap_or(DEFAULT_GEMINI_MODEL).to_owned();
            chat_with(GeminiProvider::new(api_key, model), system).await
        }
        Provider::Vertex => {
            let llm = build_vertex(&cli)?;
            chat_with(llm, system).await
        }
        Provider::Cloudflare => {
            let llm = build_cloudflare(&cli)?;
            chat_with(llm, system).await
        }
    }
}
/// Returns the Gemini API key, preferring `GEMINI_API_KEY` over `GOOGLE_API_KEY`.
///
/// # Errors
/// Fails when neither variable holds a non-blank value.
fn require_gemini_key() -> Result<String> {
    const KEY_VARS: [&str; 2] = ["GEMINI_API_KEY", "GOOGLE_API_KEY"];
    const MISSING: &str =
        "neither GEMINI_API_KEY nor GOOGLE_API_KEY is set; export a Gemini API key to run an agent";
    first_env(&KEY_VARS).context(MISSING)
}
/// Sends a single prompt to `provider` on a fresh thread.
///
/// Streamed text is printed by the agent's event store as it arrives. A
/// trailing newline is printed once the run completes, and the run's return
/// value is discarded.
///
/// # Errors
/// Propagates agent failures wrapped with the context "agent run failed".
async fn run_with<P: LlmProvider + 'static>(
    provider: P,
    prompt: String,
    system: String,
) -> Result<()> {
    let agent = build_agent(provider, system);
    let thread = ThreadId::new();
    let input = AgentInput::Text(prompt);
    let tool_ctx = ToolContext::new(());
    let cancel = CancellationToken::new();

    let _ = agent
        .run(thread, input, tool_ctx, cancel)
        .await
        .context("agent run failed")?;
    println!();
    Ok(())
}
/// Runs an interactive read-eval-print loop against `provider`.
///
/// Each non-blank stdin line is sent as one agent run. All runs share a
/// single thread id (presumably so the agent sees earlier turns). The loop
/// ends on `exit`, `quit`, or EOF (Ctrl-D).
///
/// # Errors
/// Stdin/stdout I/O errors and agent failures end the session with an error.
async fn chat_with<P: LlmProvider + 'static>(provider: P, system: String) -> Result<()> {
    let agent = build_agent(provider, system);
    let conversation = ThreadId::new();

    let mut out = tokio::io::stdout();
    out.write_all(b"agent-sdk chat - type a message, or 'exit' / Ctrl-D to quit.\n")
        .await?;
    out.flush().await?;

    let mut input = BufReader::new(tokio::io::stdin()).lines();
    loop {
        out.write_all(b"\nyou> ").await?;
        out.flush().await?;

        let line = match input.next_line().await? {
            Some(line) => line,
            None => {
                // EOF: terminate the dangling prompt line before leaving.
                out.write_all(b"\n").await?;
                break;
            }
        };

        match line.trim() {
            "" => continue,
            "exit" | "quit" => break,
            message => {
                out.write_all(b"\nagent> ").await?;
                out.flush().await?;
                let _ = agent
                    .run(
                        conversation.clone(),
                        AgentInput::Text(message.to_owned()),
                        ToolContext::new(()),
                        CancellationToken::new(),
                    )
                    .await
                    .context("agent run failed")?;
                out.write_all(b"\n").await?;
                out.flush().await?;
            }
        }
    }
    Ok(())
}
/// Concrete agent type produced by `build_agent` for provider `P`.
///
/// The leading `()` matches `builder::<()>()` and the `ToolContext::new(())`
/// passed on every run. It is presumably the tool-context type; confirm
/// against `AgentLoop`.
///
/// The remaining slots are the SDK's default hooks and two in-memory stores.
/// Which store fills which role is not visible from this file.
type CliAgent<P> = agent_sdk::AgentLoop<
    (),
    P,
    agent_sdk::DefaultHooks,
    agent_sdk::InMemoryStore,
    agent_sdk::InMemoryStore,
>;
/// Assembles the CLI agent.
///
/// The agent uses `provider` and `system` as its system prompt (all other
/// config fields take their defaults). A `StreamToStdout` event store echoes
/// streamed output to the terminal.
fn build_agent<P: LlmProvider + 'static>(provider: P, system: String) -> CliAgent<P> {
    let stdout_echo = Arc::new(StreamToStdout::new());
    let config = AgentConfig {
        system_prompt: system,
        ..Default::default()
    };
    builder::<()>()
        .provider(provider)
        .config(config)
        .event_store(stdout_echo)
        .build()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Environment variable name that no test environment is expected to define.
    const UNSET_VAR: &str = "AGENT_SDK_DEFINITELY_UNSET_VAR";

    #[test]
    fn anthropic_aliases_resolve_to_expected_models() {
        assert_eq!(
            anthropic_provider("k".to_owned(), "haiku").model(),
            "claude-haiku-4-5-20251001"
        );
        assert_eq!(
            anthropic_provider("k".to_owned(), "sonnet").model(),
            "claude-sonnet-4-6"
        );
        assert_eq!(
            anthropic_provider("k".to_owned(), "opus").model(),
            "claude-opus-4-6"
        );
        assert_eq!(
            anthropic_provider("k".to_owned(), "fable").model(),
            "claude-fable-5"
        );
    }

    #[test]
    fn anthropic_full_model_id_passes_through() {
        assert_eq!(
            anthropic_provider("k".to_owned(), "claude-3-5-haiku-20241022").model(),
            "claude-3-5-haiku-20241022"
        );
    }

    #[test]
    fn resolve_with_env_prefers_explicit_flag() -> Result<()> {
        let got = resolve_with_env(Some("from-flag"), &["UNSET_VAR_X"], "missing")?;
        assert_eq!(got, "from-flag");
        Ok(())
    }

    #[test]
    fn resolve_with_env_ignores_blank_flag_then_errors() {
        let result = resolve_with_env(Some(" "), &[UNSET_VAR], "boom");
        match result {
            Ok(value) => panic!("expected error, got {value:?}"),
            Err(err) => assert!(err.to_string().contains("boom"), "unexpected error: {err}"),
        }
    }

    #[test]
    fn resolve_with_env_without_flag_reports_missing_message() {
        match resolve_with_env(None, &[UNSET_VAR], "nothing configured") {
            Ok(value) => panic!("expected error, got {value:?}"),
            Err(err) => assert!(
                err.to_string().contains("nothing configured"),
                "unexpected error: {err}"
            ),
        }
    }

    #[test]
    fn first_env_returns_none_when_no_candidate_is_set() {
        assert_eq!(first_env(&[]), None);
        assert_eq!(first_env(&[UNSET_VAR]), None);
    }

    #[test]
    fn cloudflare_auth_requires_a_credential() {
        match ensure_cf_auth("", None, "ANTHROPIC_API_KEY") {
            Ok(()) => panic!("missing credentials should error"),
            Err(err) => assert!(
                err.to_string().contains("Cloudflare requires credentials"),
                "unexpected error: {err}"
            ),
        }
    }

    #[test]
    fn cloudflare_auth_rejects_whitespace_only_api_key() {
        match ensure_cf_auth("   ", None, "OPENAI_API_KEY") {
            Ok(()) => panic!("a blank API key without a gateway token should error"),
            Err(err) => assert!(
                err.to_string().contains("OPENAI_API_KEY"),
                "unexpected error: {err}"
            ),
        }
    }

    #[test]
    fn cloudflare_auth_passes_with_gateway_token() -> Result<()> {
        ensure_cf_auth("", Some("cf-token"), "ANTHROPIC_API_KEY")?;
        Ok(())
    }

    #[test]
    fn cloudflare_auth_passes_with_api_key() -> Result<()> {
        ensure_cf_auth("sk-123", None, "ANTHROPIC_API_KEY")?;
        Ok(())
    }

    #[test]
    fn provider_default_is_anthropic() {
        assert_eq!(Provider::default(), Provider::Anthropic);
    }

    #[test]
    fn cloudflare_upstream_default_is_anthropic() {
        assert_eq!(CloudflareUpstream::default(), CloudflareUpstream::Anthropic);
    }
}