Skip to main content

aether_core/core/
agent_builder.rs

1use super::agent::{AgentConfig, AutoContinue, RetryConfig};
2use crate::agent_spec::AgentSpec;
3use crate::context::CompactionConfig;
4use crate::core::{Agent, Prompt, PromptCache, Result};
5use crate::events::{AgentMessage, Command};
6use crate::mcp::run_mcp_task::McpCommand;
7use aether_auth::OAuthCredentialStorage;
8use llm::parser::ModelProviderParser;
9use llm::types::IsoString;
10use llm::{ChatMessage, Context, StreamingModelProvider, ToolDefinition};
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::mpsc::{self, Receiver, Sender};
14use tokio::task::JoinHandle;
15
16/// Handle for communicating with a running Agent
17pub struct AgentHandle {
18    handle: JoinHandle<()>,
19}
20
21impl AgentHandle {
22    /// Abort the agent task immediately.
23    pub fn abort(&self) {
24        self.handle.abort();
25    }
26
27    /// Returns `true` if the agent task has finished.
28    pub fn is_finished(&self) -> bool {
29        self.handle.is_finished()
30    }
31
32    /// Wait for the agent task to complete.
33    pub async fn await_completion(self) {
34        let _ = self.handle.await;
35    }
36}
37
38pub struct AgentBuilder {
39    llm: Arc<dyn StreamingModelProvider>,
40    prompts: Vec<Prompt>,
41    tool_definitions: Vec<ToolDefinition>,
42    initial_messages: Vec<ChatMessage>,
43    mcp_tx: Option<Sender<McpCommand>>,
44    channel_capacity: usize,
45    tool_timeout: Duration,
46    compaction_config: Option<CompactionConfig>,
47    max_auto_continues: u32,
48    retry_config: RetryConfig,
49    prompt_cache_key: Option<String>,
50    context_window: Option<u32>,
51}
52
53impl AgentBuilder {
54    pub fn new(llm: Arc<dyn StreamingModelProvider>) -> Self {
55        Self {
56            llm,
57            prompts: Vec::new(),
58            tool_definitions: Vec::new(),
59            initial_messages: Vec::new(),
60            mcp_tx: None,
61            channel_capacity: 1000,
62            tool_timeout: Duration::from_mins(20),
63            compaction_config: Some(CompactionConfig::default()),
64            max_auto_continues: 3,
65            retry_config: RetryConfig::default(),
66            prompt_cache_key: None,
67            context_window: None,
68        }
69    }
70
71    /// Create a builder from a resolved `AgentSpec`.
72    ///
73    /// The LLM provider is derived from `spec.model` via `ModelProviderParser`.
74    /// `base_prompts` are prepended before the spec's own prompts.
75    pub async fn from_spec(
76        spec: &AgentSpec,
77        base_prompts: Vec<Prompt>,
78        oauth_store: Option<Arc<dyn OAuthCredentialStorage>>,
79    ) -> Result<Self> {
80        let parser = ModelProviderParser::default().with_provider_connections(spec.provider_connections.clone());
81        let parser = match oauth_store {
82            Some(store) => parser.with_codex_provider(store),
83            None => parser,
84        };
85        let (provider, _) = parser.parse(&spec.model).await?;
86        let mut builder = Self::new(Arc::from(provider)).context_window(spec.context_window);
87        for prompt in base_prompts {
88            builder = builder.system_prompt(prompt);
89        }
90        for prompt in &spec.prompts {
91            builder = builder.system_prompt(prompt.clone());
92        }
93        Ok(builder)
94    }
95
96    /// Add a prompt to the system prompt.
97    ///
98    /// Multiple prompts are concatenated with double newlines.
99    pub fn system_prompt(mut self, prompt: Prompt) -> Self {
100        self.prompts.push(prompt);
101        self
102    }
103
104    pub fn tools(mut self, tx: Sender<McpCommand>, tools: Vec<ToolDefinition>) -> Self {
105        self.tool_definitions = tools;
106        self.mcp_tx = Some(tx);
107        self
108    }
109
110    /// Set the timeout for tool execution
111    ///
112    /// If a tool does not return a result within this duration, it will be marked as failed
113    /// and the agent will continue processing.
114    ///
115    /// Default: 20 minutes
116    pub fn tool_timeout(mut self, timeout: Duration) -> Self {
117        self.tool_timeout = timeout;
118        self
119    }
120
121    /// Configure context compaction settings.
122    ///
123    /// By default, agents automatically compact context when token usage exceeds
124    /// 85% of the context window, preventing overflow during long-running tasks.
125    ///
126    /// # Examples
127    /// ```ignore
128    /// // Custom threshold
129    /// agent(llm).compaction(CompactionConfig::with_threshold(0.9))
130    ///
131    /// // Disable compaction entirely
132    /// agent(llm).compaction(CompactionConfig::disabled())
133    ///
134    /// // Full customization
135    /// agent(llm).compaction(
136    ///     CompactionConfig::with_threshold(0.85)
137    ///         .keep_recent_tool_results(3)
138    ///         .min_messages(20)
139    /// )
140    /// ```
141    pub fn compaction(mut self, config: CompactionConfig) -> Self {
142        self.compaction_config = Some(config);
143        self
144    }
145
146    /// Disable context compaction entirely.
147    ///
148    /// Overflow errors from the model will be surfaced directly to callers.
149    pub fn disable_compaction(mut self) -> Self {
150        self.compaction_config = None;
151        self
152    }
153
154    /// Configure the maximum number of auto-continue attempts.
155    ///
156    /// When the LLM stops without making tool calls, the agent may inject a
157    /// continuation prompt and restart the LLM stream for resumable stop
158    /// reasons (for example, token length limits).
159    ///
160    /// This setting limits how many times the agent will attempt to continue
161    /// before giving up and returning `AgentMessage::Done`.
162    ///
163    /// Default: 3
164    ///
165    /// # Example
166    /// ```ignore
167    /// // Allow up to 5 auto-continue attempts
168    /// agent(llm).max_auto_continues(5)
169    ///
170    /// // Disable auto-continue entirely
171    /// agent(llm).max_auto_continues(0)
172    /// ```
173    pub fn max_auto_continues(mut self, max: u32) -> Self {
174        self.max_auto_continues = max;
175        self
176    }
177
178    /// Configure retry behavior for transient LLM provider failures.
179    pub fn retry(mut self, config: RetryConfig) -> Self {
180        self.retry_config = config;
181        self
182    }
183
184    /// Set a prompt cache key for LLM provider request routing.
185    ///
186    /// This is typically a session ID (UUID) that remains stable across all
187    /// turns within a conversation, improving prompt cache hit rates.
188    pub fn prompt_cache_key(mut self, key: String) -> Self {
189        self.prompt_cache_key = Some(key);
190        self
191    }
192
193    /// Override the effective model context window in tokens.
194    pub fn context_window(mut self, context_window: Option<u32>) -> Self {
195        self.context_window = context_window;
196        self
197    }
198
199    /// Pre-populate the context with conversation history (e.g. from a restored session).
200    ///
201    /// These messages are inserted after the system prompt.
202    pub fn messages(mut self, messages: Vec<ChatMessage>) -> Self {
203        self.initial_messages = messages;
204        self
205    }
206
207    pub async fn spawn(self) -> Result<(Sender<Command>, Receiver<AgentMessage>, AgentHandle)> {
208        let mut prompt_cache = PromptCache::new(self.prompts);
209        let system_content = prompt_cache.render().await?;
210
211        let mut messages = Vec::new();
212        if !system_content.is_empty() {
213            messages.push(ChatMessage::System { content: system_content, timestamp: IsoString::now() });
214        }
215
216        messages.extend(self.initial_messages);
217
218        let (command_tx, command_rx) = mpsc::channel::<Command>(self.channel_capacity);
219
220        let (message_tx, agent_message_rx) = mpsc::channel::<AgentMessage>(self.channel_capacity);
221
222        let mut context = Context::new(messages, self.tool_definitions);
223        context.set_prompt_cache_key(self.prompt_cache_key);
224
225        let config = AgentConfig {
226            llm: self.llm,
227            context,
228            mcp_command_tx: self.mcp_tx,
229            tool_timeout: self.tool_timeout,
230            compaction_config: self.compaction_config,
231            auto_continue: AutoContinue::new(self.max_auto_continues),
232            retry_config: self.retry_config,
233            context_window: self.context_window,
234            prompt_cache,
235        };
236
237        let agent = Agent::new(config, command_rx, message_tx);
238
239        let agent_handle = tokio::spawn(agent.run());
240
241        Ok((command_tx, agent_message_rx, AgentHandle { handle: agent_handle }))
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_spec::{AgentSpecExposure, ToolFilter};
    use crate::events::{AgentCommand, UserCommand};
    use llm::testing::FakeLlmProvider;
    use llm::{LlmResponse, ProviderConnectionOverrides};

    /// Context usage fields extracted from `AgentMessage::ContextUsageUpdate`.
    struct ContextUsageUpdate {
        usage_ratio: Option<f64>,
        context_limit: Option<u32>,
        input_tokens: u32,
    }

    /// Receive messages until the agent emits a context usage update.
    ///
    /// Panics if the channel closes before such an update arrives.
    async fn next_context_usage(rx: &mut Receiver<AgentMessage>) -> ContextUsageUpdate {
        loop {
            if let AgentMessage::ContextUsageUpdate { usage_ratio, context_limit, input_tokens, .. } =
                rx.recv().await.expect("agent should emit context usage")
            {
                return ContextUsageUpdate { usage_ratio, context_limit, input_tokens };
            }
        }
    }

    /// Build the two-provider "alloy" spec shared by the `from_spec_*` tests.
    fn alloy_spec(context_window: Option<u32>) -> AgentSpec {
        AgentSpec {
            name: "alloy".to_string(),
            description: "alloy".to_string(),
            model: "ollama:llama3.2,llamacpp:local".to_string(),
            reasoning_effort: None,
            context_window,
            prompts: vec![],
            provider_connections: ProviderConnectionOverrides::default(),
            mcp_config_sources: Vec::new(),
            exposure: AgentSpecExposure::both(),
            tools: ToolFilter::default(),
        }
    }

    #[tokio::test]
    async fn test_agent_handle_is_finished() {
        let mut handle = AgentHandle { handle: tokio::spawn(async {}) };
        // Drive the task to completion through a mutable borrow so the handle
        // is still available for the `is_finished` check afterwards.
        (&mut handle.handle).await.expect("empty task should complete without panicking");
        assert!(handle.is_finished());
    }

    #[tokio::test]
    async fn test_agent_handle_await_completion() {
        let handle = AgentHandle { handle: tokio::spawn(async {}) };
        handle.await_completion().await;
    }

    #[tokio::test]
    async fn test_agent_handle_abort() {
        let mut handle = AgentHandle {
            handle: tokio::spawn(async {
                tokio::time::sleep(Duration::from_mins(1)).await;
            }),
        };
        assert!(!handle.is_finished());
        handle.abort();
        // Await the join handle instead of sleeping for a fixed interval so the
        // assertion does not depend on scheduler timing.
        let join_error = (&mut handle.handle).await.expect_err("aborted task should not complete normally");
        assert!(join_error.is_cancelled());
        assert!(handle.is_finished());
    }

    #[tokio::test]
    async fn context_window_override_supplies_unknown_provider_limit() {
        let llm = Arc::new(FakeLlmProvider::with_single_response(vec![
            LlmResponse::start("msg"),
            LlmResponse::usage(100_000, 10),
            LlmResponse::done(),
        ]));

        let (tx, mut rx, handle) = AgentBuilder::new(llm).context_window(Some(200_000)).spawn().await.unwrap();
        tx.send(Command::UserCommand(UserCommand::Text { content: vec![llm::ContentBlock::text("hello")] }))
            .await
            .unwrap();

        let update = next_context_usage(&mut rx).await;
        assert_eq!(update.context_limit, Some(200_000));
        assert_eq!(update.usage_ratio, Some(0.5));
        assert_eq!(update.input_tokens, 100_000);
        handle.abort();
    }

    #[tokio::test]
    async fn context_window_override_beats_provider_limit() {
        let llm = Arc::new(
            FakeLlmProvider::with_single_response(vec![
                LlmResponse::start("msg"),
                LlmResponse::usage(100_000, 10),
                LlmResponse::done(),
            ])
            .with_context_window(Some(128_000)),
        );

        let (tx, mut rx, handle) = AgentBuilder::new(llm).context_window(Some(200_000)).spawn().await.unwrap();
        tx.send(Command::UserCommand(UserCommand::Text { content: vec![llm::ContentBlock::text("hello")] }))
            .await
            .unwrap();

        let update = next_context_usage(&mut rx).await;
        assert_eq!(update.context_limit, Some(200_000));
        assert_eq!(update.usage_ratio, Some(0.5));
        handle.abort();
    }

    #[tokio::test]
    async fn context_window_override_survives_model_switch() {
        let llm = Arc::new(FakeLlmProvider::new(vec![]).with_context_window(Some(128_000)));
        let (tx, mut rx, handle) = AgentBuilder::new(llm).context_window(Some(200_000)).spawn().await.unwrap();

        tx.send(Command::AgentCommand(AgentCommand::SwitchModel(Box::new(
            FakeLlmProvider::new(vec![]).with_display_name("new fake").with_context_window(Some(32_000)),
        ))))
        .await
        .unwrap();

        let update = next_context_usage(&mut rx).await;
        assert_eq!(update.context_limit, Some(200_000));
        assert_eq!(update.usage_ratio, Some(0.0));
        handle.abort();
    }

    #[tokio::test]
    async fn system_prompt_preserves_add_order() {
        let builder = AgentBuilder::new(Arc::new(FakeLlmProvider::new(vec![])))
            .system_prompt(Prompt::text("first"))
            .system_prompt(Prompt::text("second"))
            .system_prompt(Prompt::text("third"));

        let rendered = Prompt::build_all(&builder.prompts).await.unwrap();

        assert_eq!(rendered, "first\n\nsecond\n\nthird");
    }

    #[tokio::test]
    async fn from_spec_applies_context_window() {
        let spec = alloy_spec(Some(200_000));

        let builder = AgentBuilder::from_spec(&spec, vec![], None).await.unwrap();

        assert_eq!(builder.context_window, Some(200_000));
    }

    #[tokio::test]
    async fn from_spec_accepts_alloy_model_specs() {
        let spec = alloy_spec(None);

        let builder = AgentBuilder::from_spec(&spec, vec![], None).await;
        assert!(builder.is_ok());
    }
}