use super::agent::{AgentConfig, AutoContinue, RetryConfig};
use crate::agent_spec::AgentSpec;
use crate::context::CompactionConfig;
use crate::core::{Agent, Prompt, PromptCache, Result};
use crate::events::{AgentMessage, Command};
use crate::mcp::run_mcp_task::McpCommand;
use aether_auth::OAuthCredentialStorage;
use llm::parser::ModelProviderParser;
use llm::types::IsoString;
use llm::{ChatMessage, Context, StreamingModelProvider, ToolDefinition};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::task::JoinHandle;
/// Handle to a spawned agent task.
///
/// Returned by [`AgentBuilder::spawn`]. It wraps the Tokio [`JoinHandle`] of the
/// task so callers can abort it, poll whether it has finished, or await it.
#[derive(Debug)]
pub struct AgentHandle {
    /// Join handle of the task created with `tokio::spawn` in [`AgentBuilder::spawn`].
    handle: JoinHandle<()>,
}
impl AgentHandle {
    /// Aborts the underlying task by calling `abort` on its join handle.
    pub fn abort(&self) {
        let Self { handle } = self;
        handle.abort();
    }

    /// Reports whether the underlying join handle considers the task finished.
    pub fn is_finished(&self) -> bool {
        let Self { handle } = self;
        handle.is_finished()
    }

    /// Consumes the handle and waits for the task to end.
    ///
    /// Whatever the join yields is intentionally discarded, so this never
    /// surfaces an error to the caller.
    pub async fn await_completion(self) {
        let Self { handle } = self;
        let _join_outcome = handle.await;
    }
}
/// Builder that collects everything needed to construct and spawn an [`Agent`].
///
/// Create one with [`AgentBuilder::new`] or [`AgentBuilder::from_spec`], adjust it
/// through the chained setters, and finish with [`AgentBuilder::spawn`].
pub struct AgentBuilder {
    /// Model provider handed to the agent.
    llm: Arc<dyn StreamingModelProvider>,
    /// System prompts in the order they were added; rendered through a `PromptCache` on spawn.
    prompts: Vec<Prompt>,
    /// Tool definitions passed to the conversation `Context`.
    tool_definitions: Vec<ToolDefinition>,
    /// Messages appended after the rendered system message when the context is created.
    initial_messages: Vec<ChatMessage>,
    /// Sender for MCP commands; `None` until [`AgentBuilder::tools`] is called.
    mcp_tx: Option<Sender<McpCommand>>,
    /// Capacity used for both the command channel and the agent-message channel.
    channel_capacity: usize,
    /// Tool timeout forwarded to the agent configuration.
    tool_timeout: Duration,
    /// Compaction settings forwarded to the agent configuration; `None` after
    /// [`AgentBuilder::disable_compaction`].
    compaction_config: Option<CompactionConfig>,
    /// Limit passed to `AutoContinue::new` on spawn.
    max_auto_continues: u32,
    /// Retry settings forwarded to the agent configuration.
    retry_config: RetryConfig,
    /// Optional prompt cache key set on the `Context` on spawn.
    prompt_cache_key: Option<String>,
    /// Optional context-window value forwarded to the agent configuration.
    context_window: Option<u32>,
}
impl AgentBuilder {
    /// Creates a builder around `llm` with the default settings.
    ///
    /// Defaults: no prompts, tools, or initial messages; no MCP channel; channel
    /// capacity of 1000; a 20-minute tool timeout; default compaction and retry
    /// configuration; at most 3 auto-continues; no prompt cache key; no
    /// context-window override.
    pub fn new(llm: Arc<dyn StreamingModelProvider>) -> Self {
        let default_compaction = Some(CompactionConfig::default());
        let default_retry = RetryConfig::default();
        Self {
            llm,
            prompts: Vec::new(),
            tool_definitions: Vec::new(),
            initial_messages: Vec::new(),
            mcp_tx: None,
            channel_capacity: 1000,
            tool_timeout: Duration::from_mins(20),
            compaction_config: default_compaction,
            max_auto_continues: 3,
            retry_config: default_retry,
            prompt_cache_key: None,
            context_window: None,
        }
    }

    /// Builds a builder from an [`AgentSpec`].
    ///
    /// The model provider is parsed from `spec.model` using the spec's provider
    /// connections; when `oauth_store` is supplied, the parser is additionally
    /// configured with the codex provider backed by that store. The resulting
    /// builder carries the spec's context window and has `base_prompts` followed
    /// by the spec's own prompts as system prompts.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the model parser when `spec.model` cannot
    /// be parsed into a provider.
    pub async fn from_spec(
        spec: &AgentSpec,
        base_prompts: Vec<Prompt>,
        oauth_store: Option<Arc<dyn OAuthCredentialStorage>>,
    ) -> Result<Self> {
        let mut parser =
            ModelProviderParser::default().with_provider_connections(spec.provider_connections.clone());
        if let Some(store) = oauth_store {
            parser = parser.with_codex_provider(store);
        }

        let (provider, _) = parser.parse(&spec.model).await?;

        let mut builder = Self::new(Arc::from(provider));
        builder.context_window = spec.context_window;
        // Base prompts come first so the spec's prompts are rendered after them.
        builder.prompts.extend(base_prompts);
        builder.prompts.extend(spec.prompts.iter().cloned());
        Ok(builder)
    }

    /// Appends `prompt` to the list of system prompts, preserving insertion order.
    pub fn system_prompt(self, prompt: Prompt) -> Self {
        let mut prompts = self.prompts;
        prompts.push(prompt);
        Self { prompts, ..self }
    }

    /// Sets the tool definitions and the MCP command sender, replacing any
    /// previously configured ones.
    pub fn tools(self, tx: Sender<McpCommand>, tools: Vec<ToolDefinition>) -> Self {
        Self { tool_definitions: tools, mcp_tx: Some(tx), ..self }
    }

    /// Sets the tool timeout forwarded to the agent configuration.
    pub fn tool_timeout(self, timeout: Duration) -> Self {
        Self { tool_timeout: timeout, ..self }
    }

    /// Uses `config` as the compaction configuration.
    pub fn compaction(self, config: CompactionConfig) -> Self {
        Self { compaction_config: Some(config), ..self }
    }

    /// Clears the compaction configuration so none is passed to the agent.
    pub fn disable_compaction(self) -> Self {
        Self { compaction_config: None, ..self }
    }

    /// Sets the limit handed to `AutoContinue::new` when the agent is spawned.
    pub fn max_auto_continues(self, max: u32) -> Self {
        Self { max_auto_continues: max, ..self }
    }

    /// Replaces the retry configuration.
    pub fn retry(self, config: RetryConfig) -> Self {
        Self { retry_config: config, ..self }
    }

    /// Sets the prompt cache key applied to the conversation context on spawn.
    pub fn prompt_cache_key(self, key: String) -> Self {
        Self { prompt_cache_key: Some(key), ..self }
    }

    /// Sets (or clears, with `None`) the context-window value forwarded to the agent.
    pub fn context_window(self, context_window: Option<u32>) -> Self {
        Self { context_window, ..self }
    }

    /// Replaces the initial messages that follow the system message in the context.
    pub fn messages(self, messages: Vec<ChatMessage>) -> Self {
        Self { initial_messages: messages, ..self }
    }

    /// Renders the system prompts, assembles the agent, and spawns it on Tokio.
    ///
    /// Returns the sender for commands to the agent, the receiver for messages
    /// emitted by the agent, and an [`AgentHandle`] for the spawned task.
    ///
    /// # Errors
    ///
    /// Returns the error from rendering the system prompts; nothing is spawned
    /// in that case.
    pub async fn spawn(self) -> Result<(Sender<Command>, Receiver<AgentMessage>, AgentHandle)> {
        let mut prompt_cache = PromptCache::new(self.prompts);
        let rendered_system = prompt_cache.render().await?;

        // A system message is only emitted when the rendered prompt is non-empty.
        let system_message = (!rendered_system.is_empty())
            .then(|| ChatMessage::System { content: rendered_system, timestamp: IsoString::now() });
        let history: Vec<ChatMessage> = system_message.into_iter().chain(self.initial_messages).collect();

        let capacity = self.channel_capacity;
        let (command_tx, command_rx) = mpsc::channel::<Command>(capacity);
        let (event_tx, event_rx) = mpsc::channel::<AgentMessage>(capacity);

        let mut context = Context::new(history, self.tool_definitions);
        context.set_prompt_cache_key(self.prompt_cache_key);

        let agent = Agent::new(
            AgentConfig {
                llm: self.llm,
                context,
                mcp_command_tx: self.mcp_tx,
                tool_timeout: self.tool_timeout,
                compaction_config: self.compaction_config,
                auto_continue: AutoContinue::new(self.max_auto_continues),
                retry_config: self.retry_config,
                context_window: self.context_window,
                prompt_cache,
            },
            command_rx,
            event_tx,
        );
        let handle = tokio::spawn(agent.run());

        Ok((command_tx, event_rx, AgentHandle { handle }))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_spec::{AgentSpecExposure, ToolFilter};
    use crate::events::{AgentCommand, UserCommand};
    use llm::testing::FakeLlmProvider;
    use llm::{LlmResponse, ProviderConnectionOverrides};

    /// Upper bound on how long a test waits for a spawned task to settle.
    const SETTLE_TIMEOUT: Duration = Duration::from_secs(5);

    /// The subset of `AgentMessage::ContextUsageUpdate` fields the tests assert on.
    struct ContextUsageUpdate {
        usage_ratio: Option<f64>,
        context_limit: Option<u32>,
        input_tokens: u32,
    }

    /// Yields to the runtime until `handle` reports completion, panicking if
    /// that takes longer than [`SETTLE_TIMEOUT`]. Polling avoids depending on a
    /// fixed sleep being long enough.
    async fn wait_until_finished(handle: &AgentHandle) {
        tokio::time::timeout(SETTLE_TIMEOUT, async {
            while !handle.is_finished() {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("task should finish within the settle timeout");
    }

    /// Receives messages until the next `ContextUsageUpdate` and returns its fields.
    async fn next_context_usage(rx: &mut Receiver<AgentMessage>) -> ContextUsageUpdate {
        loop {
            if let AgentMessage::ContextUsageUpdate { usage_ratio, context_limit, input_tokens, .. } =
                rx.recv().await.expect("agent should emit context usage")
            {
                return ContextUsageUpdate { usage_ratio, context_limit, input_tokens };
            }
        }
    }

    /// Sends a single user text command containing "hello".
    async fn send_hello(tx: &Sender<Command>) {
        tx.send(Command::UserCommand(UserCommand::Text { content: vec![llm::ContentBlock::text("hello")] }))
            .await
            .unwrap();
    }

    /// Builds the alloy-model spec shared by the `from_spec` tests.
    fn alloy_spec(context_window: Option<u32>) -> AgentSpec {
        AgentSpec {
            name: "alloy".to_string(),
            description: "alloy".to_string(),
            model: "ollama:llama3.2,llamacpp:local".to_string(),
            reasoning_effort: None,
            context_window,
            prompts: vec![],
            provider_connections: ProviderConnectionOverrides::default(),
            mcp_config_sources: Vec::new(),
            exposure: AgentSpecExposure::both(),
            tools: ToolFilter::default(),
        }
    }

    #[tokio::test]
    async fn test_agent_handle_is_finished() {
        let handle = AgentHandle { handle: tokio::spawn(async {}) };
        wait_until_finished(&handle).await;
        assert!(handle.is_finished());
        // Awaiting an already-finished task must return rather than hang.
        handle.await_completion().await;
    }

    #[tokio::test]
    async fn test_agent_handle_abort() {
        let handle = AgentHandle {
            handle: tokio::spawn(async {
                tokio::time::sleep(Duration::from_mins(1)).await;
            }),
        };
        assert!(!handle.is_finished());
        handle.abort();
        wait_until_finished(&handle).await;
        assert!(handle.is_finished());
    }

    #[tokio::test]
    async fn context_window_override_supplies_unknown_provider_limit() {
        let llm = Arc::new(FakeLlmProvider::with_single_response(vec![
            LlmResponse::start("msg"),
            LlmResponse::usage(100_000, 10),
            LlmResponse::done(),
        ]));
        let (tx, mut rx, handle) = AgentBuilder::new(llm).context_window(Some(200_000)).spawn().await.unwrap();
        send_hello(&tx).await;
        let update = next_context_usage(&mut rx).await;
        assert_eq!(update.context_limit, Some(200_000));
        assert_eq!(update.usage_ratio, Some(0.5));
        assert_eq!(update.input_tokens, 100_000);
        handle.abort();
    }

    #[tokio::test]
    async fn context_window_override_beats_provider_limit() {
        let llm = Arc::new(
            FakeLlmProvider::with_single_response(vec![
                LlmResponse::start("msg"),
                LlmResponse::usage(100_000, 10),
                LlmResponse::done(),
            ])
            .with_context_window(Some(128_000)),
        );
        let (tx, mut rx, handle) = AgentBuilder::new(llm).context_window(Some(200_000)).spawn().await.unwrap();
        send_hello(&tx).await;
        let update = next_context_usage(&mut rx).await;
        assert_eq!(update.context_limit, Some(200_000));
        assert_eq!(update.usage_ratio, Some(0.5));
        handle.abort();
    }

    #[tokio::test]
    async fn context_window_override_survives_model_switch() {
        let llm = Arc::new(FakeLlmProvider::new(vec![]).with_context_window(Some(128_000)));
        let (tx, mut rx, handle) = AgentBuilder::new(llm).context_window(Some(200_000)).spawn().await.unwrap();
        tx.send(Command::AgentCommand(AgentCommand::SwitchModel(Box::new(
            FakeLlmProvider::new(vec![]).with_display_name("new fake").with_context_window(Some(32_000)),
        ))))
        .await
        .unwrap();
        let update = next_context_usage(&mut rx).await;
        assert_eq!(update.context_limit, Some(200_000));
        assert_eq!(update.usage_ratio, Some(0.0));
        handle.abort();
    }

    #[tokio::test]
    async fn system_prompt_preserves_add_order() {
        let builder = AgentBuilder::new(Arc::new(FakeLlmProvider::new(vec![])))
            .system_prompt(Prompt::text("first"))
            .system_prompt(Prompt::text("second"))
            .system_prompt(Prompt::text("third"));
        let rendered = Prompt::build_all(&builder.prompts).await.unwrap();
        assert_eq!(rendered, "first\n\nsecond\n\nthird");
    }

    #[tokio::test]
    async fn from_spec_applies_context_window() {
        let spec = alloy_spec(Some(200_000));
        let builder = AgentBuilder::from_spec(&spec, vec![], None).await.unwrap();
        assert_eq!(builder.context_window, Some(200_000));
    }

    #[tokio::test]
    async fn from_spec_accepts_alloy_model_specs() {
        let spec = alloy_spec(None);
        let builder = AgentBuilder::from_spec(&spec, vec![], None).await;
        assert!(builder.is_ok());
    }
}