1use super::agent::{AgentConfig, AutoContinue, RetryConfig};
2use crate::agent_spec::AgentSpec;
3use crate::context::CompactionConfig;
4use crate::core::{Agent, Prompt, PromptCache, Result};
5use crate::events::{AgentMessage, Command};
6use crate::mcp::run_mcp_task::McpCommand;
7use aether_auth::OAuthCredentialStorage;
8use llm::parser::ModelProviderParser;
9use llm::types::IsoString;
10use llm::{ChatMessage, Context, StreamingModelProvider, ToolDefinition};
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::mpsc::{self, Receiver, Sender};
14use tokio::task::JoinHandle;
15
16pub struct AgentHandle {
18 handle: JoinHandle<()>,
19}
20
21impl AgentHandle {
22 pub fn abort(&self) {
24 self.handle.abort();
25 }
26
27 pub fn is_finished(&self) -> bool {
29 self.handle.is_finished()
30 }
31
32 pub async fn await_completion(self) {
34 let _ = self.handle.await;
35 }
36}
37
38pub struct AgentBuilder {
39 llm: Arc<dyn StreamingModelProvider>,
40 prompts: Vec<Prompt>,
41 tool_definitions: Vec<ToolDefinition>,
42 initial_messages: Vec<ChatMessage>,
43 mcp_tx: Option<Sender<McpCommand>>,
44 channel_capacity: usize,
45 tool_timeout: Duration,
46 compaction_config: Option<CompactionConfig>,
47 max_auto_continues: u32,
48 retry_config: RetryConfig,
49 prompt_cache_key: Option<String>,
50 context_window: Option<u32>,
51}
52
53impl AgentBuilder {
54 pub fn new(llm: Arc<dyn StreamingModelProvider>) -> Self {
55 Self {
56 llm,
57 prompts: Vec::new(),
58 tool_definitions: Vec::new(),
59 initial_messages: Vec::new(),
60 mcp_tx: None,
61 channel_capacity: 1000,
62 tool_timeout: Duration::from_mins(20),
63 compaction_config: Some(CompactionConfig::default()),
64 max_auto_continues: 3,
65 retry_config: RetryConfig::default(),
66 prompt_cache_key: None,
67 context_window: None,
68 }
69 }
70
71 pub async fn from_spec(
76 spec: &AgentSpec,
77 base_prompts: Vec<Prompt>,
78 oauth_store: Option<Arc<dyn OAuthCredentialStorage>>,
79 ) -> Result<Self> {
80 let parser = ModelProviderParser::default().with_provider_connections(spec.provider_connections.clone());
81 let parser = match oauth_store {
82 Some(store) => parser.with_codex_provider(store),
83 None => parser,
84 };
85 let (provider, _) = parser.parse(&spec.model).await?;
86 let mut builder = Self::new(Arc::from(provider)).context_window(spec.context_window);
87 for prompt in base_prompts {
88 builder = builder.system_prompt(prompt);
89 }
90 for prompt in &spec.prompts {
91 builder = builder.system_prompt(prompt.clone());
92 }
93 Ok(builder)
94 }
95
96 pub fn system_prompt(mut self, prompt: Prompt) -> Self {
100 self.prompts.push(prompt);
101 self
102 }
103
104 pub fn tools(mut self, tx: Sender<McpCommand>, tools: Vec<ToolDefinition>) -> Self {
105 self.tool_definitions = tools;
106 self.mcp_tx = Some(tx);
107 self
108 }
109
110 pub fn tool_timeout(mut self, timeout: Duration) -> Self {
117 self.tool_timeout = timeout;
118 self
119 }
120
121 pub fn compaction(mut self, config: CompactionConfig) -> Self {
142 self.compaction_config = Some(config);
143 self
144 }
145
146 pub fn disable_compaction(mut self) -> Self {
150 self.compaction_config = None;
151 self
152 }
153
154 pub fn max_auto_continues(mut self, max: u32) -> Self {
174 self.max_auto_continues = max;
175 self
176 }
177
178 pub fn retry(mut self, config: RetryConfig) -> Self {
180 self.retry_config = config;
181 self
182 }
183
184 pub fn prompt_cache_key(mut self, key: String) -> Self {
189 self.prompt_cache_key = Some(key);
190 self
191 }
192
193 pub fn context_window(mut self, context_window: Option<u32>) -> Self {
195 self.context_window = context_window;
196 self
197 }
198
199 pub fn messages(mut self, messages: Vec<ChatMessage>) -> Self {
203 self.initial_messages = messages;
204 self
205 }
206
207 pub async fn spawn(self) -> Result<(Sender<Command>, Receiver<AgentMessage>, AgentHandle)> {
208 let mut prompt_cache = PromptCache::new(self.prompts);
209 let system_content = prompt_cache.render().await?;
210
211 let mut messages = Vec::new();
212 if !system_content.is_empty() {
213 messages.push(ChatMessage::System { content: system_content, timestamp: IsoString::now() });
214 }
215
216 messages.extend(self.initial_messages);
217
218 let (command_tx, command_rx) = mpsc::channel::<Command>(self.channel_capacity);
219
220 let (message_tx, agent_message_rx) = mpsc::channel::<AgentMessage>(self.channel_capacity);
221
222 let mut context = Context::new(messages, self.tool_definitions);
223 context.set_prompt_cache_key(self.prompt_cache_key);
224
225 let config = AgentConfig {
226 llm: self.llm,
227 context,
228 mcp_command_tx: self.mcp_tx,
229 tool_timeout: self.tool_timeout,
230 compaction_config: self.compaction_config,
231 auto_continue: AutoContinue::new(self.max_auto_continues),
232 retry_config: self.retry_config,
233 context_window: self.context_window,
234 prompt_cache,
235 };
236
237 let agent = Agent::new(config, command_rx, message_tx);
238
239 let agent_handle = tokio::spawn(agent.run());
240
241 Ok((command_tx, agent_message_rx, AgentHandle { handle: agent_handle }))
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_spec::{AgentSpecExposure, ToolFilter};
    use crate::events::{AgentCommand, UserCommand};
    use llm::testing::FakeLlmProvider;
    use llm::{LlmResponse, ProviderConnectionOverrides};

    /// Upper bound on how long a test waits for a task to stop before failing.
    const FINISH_TIMEOUT: Duration = Duration::from_secs(5);

    /// Yields to the runtime until `handle` reports completion.
    ///
    /// Fails the test if that takes longer than [`FINISH_TIMEOUT`] instead of hanging, and avoids
    /// relying on a fixed sleep being long enough.
    async fn wait_until_finished(handle: &AgentHandle) {
        tokio::time::timeout(FINISH_TIMEOUT, async {
            while !handle.is_finished() {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("task should finish before the timeout");
    }

    #[tokio::test]
    async fn test_agent_handle_is_finished() {
        let handle = AgentHandle { handle: tokio::spawn(async {}) };
        wait_until_finished(&handle).await;
        assert!(handle.is_finished());
        handle.await_completion().await;
    }

    #[tokio::test]
    async fn test_agent_handle_abort() {
        let handle = AgentHandle {
            handle: tokio::spawn(async {
                tokio::time::sleep(Duration::from_mins(1)).await;
            }),
        };
        assert!(!handle.is_finished());
        handle.abort();
        // Cancellation is processed asynchronously, so poll instead of sleeping a fixed amount.
        wait_until_finished(&handle).await;
        assert!(handle.is_finished());
    }

    #[tokio::test]
    async fn context_window_override_supplies_unknown_provider_limit() {
        let llm = Arc::new(FakeLlmProvider::with_single_response(vec![
            LlmResponse::start("msg"),
            LlmResponse::usage(100_000, 10),
            LlmResponse::done(),
        ]));

        let (tx, mut rx, handle) = AgentBuilder::new(llm).context_window(Some(200_000)).spawn().await.unwrap();
        tx.send(Command::UserCommand(UserCommand::Text { content: vec![llm::ContentBlock::text("hello")] }))
            .await
            .unwrap();

        let update = next_context_usage(&mut rx).await;
        assert_eq!(update.context_limit, Some(200_000));
        assert_eq!(update.usage_ratio, Some(0.5));
        assert_eq!(update.input_tokens, 100_000);
        handle.abort();
    }

    #[tokio::test]
    async fn context_window_override_beats_provider_limit() {
        let llm = Arc::new(
            FakeLlmProvider::with_single_response(vec![
                LlmResponse::start("msg"),
                LlmResponse::usage(100_000, 10),
                LlmResponse::done(),
            ])
            .with_context_window(Some(128_000)),
        );

        let (tx, mut rx, handle) = AgentBuilder::new(llm).context_window(Some(200_000)).spawn().await.unwrap();
        tx.send(Command::UserCommand(UserCommand::Text { content: vec![llm::ContentBlock::text("hello")] }))
            .await
            .unwrap();

        let update = next_context_usage(&mut rx).await;
        assert_eq!(update.context_limit, Some(200_000));
        assert_eq!(update.usage_ratio, Some(0.5));
        handle.abort();
    }

    #[tokio::test]
    async fn context_window_override_survives_model_switch() {
        let llm = Arc::new(FakeLlmProvider::new(vec![]).with_context_window(Some(128_000)));
        let (tx, mut rx, handle) = AgentBuilder::new(llm).context_window(Some(200_000)).spawn().await.unwrap();

        tx.send(Command::AgentCommand(AgentCommand::SwitchModel(Box::new(
            FakeLlmProvider::new(vec![]).with_display_name("new fake").with_context_window(Some(32_000)),
        ))))
        .await
        .unwrap();

        let update = next_context_usage(&mut rx).await;
        assert_eq!(update.context_limit, Some(200_000));
        assert_eq!(update.usage_ratio, Some(0.0));
        handle.abort();
    }

    /// Skips agent messages until the next `ContextUsageUpdate` and returns its fields.
    async fn next_context_usage(rx: &mut Receiver<AgentMessage>) -> ContextUsageUpdate {
        loop {
            if let AgentMessage::ContextUsageUpdate { usage_ratio, context_limit, input_tokens, .. } =
                rx.recv().await.expect("agent should emit context usage")
            {
                return ContextUsageUpdate { usage_ratio, context_limit, input_tokens };
            }
        }
    }

    /// The subset of a `ContextUsageUpdate` message these tests assert on.
    struct ContextUsageUpdate {
        usage_ratio: Option<f64>,
        context_limit: Option<u32>,
        input_tokens: u32,
    }

    #[tokio::test]
    async fn system_prompt_preserves_add_order() {
        let builder = AgentBuilder::new(Arc::new(llm::testing::FakeLlmProvider::new(vec![])))
            .system_prompt(Prompt::text("first"))
            .system_prompt(Prompt::text("second"))
            .system_prompt(Prompt::text("third"));

        let rendered = Prompt::build_all(&builder.prompts).await.unwrap();

        assert_eq!(rendered, "first\n\nsecond\n\nthird");
    }

    #[tokio::test]
    async fn from_spec_applies_context_window() {
        let spec = AgentSpec {
            name: "alloy".to_string(),
            description: "alloy".to_string(),
            model: "ollama:llama3.2,llamacpp:local".to_string(),
            reasoning_effort: None,
            context_window: Some(200_000),
            prompts: vec![],
            provider_connections: ProviderConnectionOverrides::default(),
            mcp_config_sources: Vec::new(),
            exposure: AgentSpecExposure::both(),
            tools: ToolFilter::default(),
        };

        let builder = AgentBuilder::from_spec(&spec, vec![], None).await.unwrap();

        assert_eq!(builder.context_window, Some(200_000));
    }

    #[tokio::test]
    async fn from_spec_accepts_alloy_model_specs() {
        let spec = AgentSpec {
            name: "alloy".to_string(),
            description: "alloy".to_string(),
            model: "ollama:llama3.2,llamacpp:local".to_string(),
            reasoning_effort: None,
            context_window: None,
            prompts: vec![],
            provider_connections: ProviderConnectionOverrides::default(),
            mcp_config_sources: Vec::new(),
            exposure: AgentSpecExposure::both(),
            tools: ToolFilter::default(),
        };

        let builder = AgentBuilder::from_spec(&spec, vec![], None).await;
        assert!(builder.is_ok());
    }
}