aether_core/core/
agent_builder.rs1use super::agent::{AgentConfig, AutoContinue, RetryConfig};
2use crate::agent_spec::AgentSpec;
3use crate::context::CompactionConfig;
4use crate::core::{Agent, Prompt, Result};
5use crate::events::{AgentMessage, UserMessage};
6use crate::mcp::run_mcp_task::McpCommand;
7use aether_auth::OAuthCredentialStorage;
8use llm::parser::ModelProviderParser;
9use llm::types::IsoString;
10use llm::{ChatMessage, Context, StreamingModelProvider, ToolDefinition};
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::mpsc::{self, Receiver, Sender};
14use tokio::task::JoinHandle;
15
16pub struct AgentHandle {
18 handle: JoinHandle<()>,
19}
20
21impl AgentHandle {
22 pub fn abort(&self) {
24 self.handle.abort();
25 }
26
27 pub fn is_finished(&self) -> bool {
29 self.handle.is_finished()
30 }
31
32 pub async fn await_completion(self) {
34 let _ = self.handle.await;
35 }
36}
37
38pub struct AgentBuilder {
39 llm: Arc<dyn StreamingModelProvider>,
40 prompts: Vec<Prompt>,
41 tool_definitions: Vec<ToolDefinition>,
42 initial_messages: Vec<ChatMessage>,
43 mcp_tx: Option<Sender<McpCommand>>,
44 channel_capacity: usize,
45 tool_timeout: Duration,
46 compaction_config: Option<CompactionConfig>,
47 max_auto_continues: u32,
48 retry_config: RetryConfig,
49 prompt_cache_key: Option<String>,
50}
51
52impl AgentBuilder {
53 pub fn new(llm: Arc<dyn StreamingModelProvider>) -> Self {
54 Self {
55 llm,
56 prompts: Vec::new(),
57 tool_definitions: Vec::new(),
58 initial_messages: Vec::new(),
59 mcp_tx: None,
60 channel_capacity: 1000,
61 tool_timeout: Duration::from_mins(20),
62 compaction_config: Some(CompactionConfig::default()),
63 max_auto_continues: 3,
64 retry_config: RetryConfig::default(),
65 prompt_cache_key: None,
66 }
67 }
68
69 pub async fn from_spec(
74 spec: &AgentSpec,
75 base_prompts: Vec<Prompt>,
76 oauth_store: Option<Arc<dyn OAuthCredentialStorage>>,
77 ) -> Result<Self> {
78 let parser = ModelProviderParser::default();
79 let parser = match oauth_store {
80 Some(store) => parser.with_codex_provider(store),
81 None => parser,
82 };
83 let (provider, _) = parser.parse(&spec.model).await?;
84 let mut builder = Self::new(Arc::from(provider));
85 for prompt in base_prompts {
86 builder = builder.system_prompt(prompt);
87 }
88 for prompt in &spec.prompts {
89 builder = builder.system_prompt(prompt.clone());
90 }
91 Ok(builder)
92 }
93
94 pub fn system_prompt(mut self, prompt: Prompt) -> Self {
98 self.prompts.push(prompt);
99 self
100 }
101
102 pub fn tools(mut self, tx: Sender<McpCommand>, tools: Vec<ToolDefinition>) -> Self {
103 self.tool_definitions = tools;
104 self.mcp_tx = Some(tx);
105 self
106 }
107
108 pub fn tool_timeout(mut self, timeout: Duration) -> Self {
115 self.tool_timeout = timeout;
116 self
117 }
118
119 pub fn compaction(mut self, config: CompactionConfig) -> Self {
140 self.compaction_config = Some(config);
141 self
142 }
143
144 pub fn disable_compaction(mut self) -> Self {
148 self.compaction_config = None;
149 self
150 }
151
152 pub fn max_auto_continues(mut self, max: u32) -> Self {
172 self.max_auto_continues = max;
173 self
174 }
175
176 pub fn retry(mut self, config: RetryConfig) -> Self {
178 self.retry_config = config;
179 self
180 }
181
182 pub fn prompt_cache_key(mut self, key: String) -> Self {
187 self.prompt_cache_key = Some(key);
188 self
189 }
190
191 pub fn messages(mut self, messages: Vec<ChatMessage>) -> Self {
195 self.initial_messages = messages;
196 self
197 }
198
199 pub async fn spawn(self) -> Result<(Sender<UserMessage>, Receiver<AgentMessage>, AgentHandle)> {
200 let mut messages = Vec::new();
201
202 if !self.prompts.is_empty() {
203 let system_content = Prompt::build_all(&self.prompts).await?;
204 if !system_content.is_empty() {
205 messages.push(ChatMessage::System { content: system_content, timestamp: IsoString::now() });
206 }
207 }
208
209 messages.extend(self.initial_messages);
210
211 let (user_message_tx, user_message_rx) = mpsc::channel::<UserMessage>(self.channel_capacity);
212
213 let (message_tx, agent_message_rx) = mpsc::channel::<AgentMessage>(self.channel_capacity);
214
215 let mut context = Context::new(messages, self.tool_definitions);
216 context.set_prompt_cache_key(self.prompt_cache_key);
217
218 let config = AgentConfig {
219 llm: self.llm,
220 context,
221 mcp_command_tx: self.mcp_tx,
222 tool_timeout: self.tool_timeout,
223 compaction_config: self.compaction_config,
224 auto_continue: AutoContinue::new(self.max_auto_continues),
225 retry_config: self.retry_config,
226 };
227
228 let agent = Agent::new(config, user_message_rx, message_tx);
229
230 let agent_handle = tokio::spawn(agent.run());
231
232 Ok((user_message_tx, agent_message_rx, AgentHandle { handle: agent_handle }))
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_spec::{AgentSpecExposure, ToolFilter};

    /// Upper bound on how long the handle tests wait for a task to be reported finished.
    const FINISH_TIMEOUT: Duration = Duration::from_secs(5);

    /// Yields to the runtime until `handle` reports finished.
    ///
    /// Fails the test after `FINISH_TIMEOUT` instead of relying on a fixed sleep.
    async fn wait_until_finished(handle: &AgentHandle) {
        tokio::time::timeout(FINISH_TIMEOUT, async {
            while !handle.is_finished() {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("agent task did not finish in time");
    }

    /// Builds a spec whose model string combines two local providers (the "alloy" form).
    fn alloy_spec(prompts: Vec<Prompt>) -> AgentSpec {
        AgentSpec {
            name: "alloy".to_string(),
            description: "alloy".to_string(),
            model: "ollama:llama3.2,llamacpp:local".to_string(),
            reasoning_effort: None,
            prompts,
            mcp_config_sources: Vec::new(),
            exposure: AgentSpecExposure::both(),
            tools: ToolFilter::default(),
        }
    }

    #[tokio::test]
    async fn test_agent_handle_is_finished() {
        let handle = AgentHandle { handle: tokio::spawn(async {}) };
        wait_until_finished(&handle).await;
        assert!(handle.is_finished());
        handle.await_completion().await;
    }

    #[tokio::test]
    async fn test_agent_handle_abort() {
        let handle = AgentHandle {
            handle: tokio::spawn(async {
                tokio::time::sleep(Duration::from_mins(1)).await;
            }),
        };
        assert!(!handle.is_finished());
        handle.abort();
        wait_until_finished(&handle).await;
        assert!(handle.is_finished());
        // Awaiting an aborted task must resolve rather than surface the cancellation.
        handle.await_completion().await;
    }

    #[tokio::test]
    async fn system_prompt_preserves_add_order() {
        let builder = AgentBuilder::new(Arc::new(llm::testing::FakeLlmProvider::new(vec![])))
            .system_prompt(Prompt::text("first"))
            .system_prompt(Prompt::text("second"))
            .system_prompt(Prompt::text("third"));

        let rendered = Prompt::build_all(&builder.prompts).await.unwrap();

        assert_eq!(rendered, "first\n\nsecond\n\nthird");
    }

    #[tokio::test]
    async fn from_spec_accepts_alloy_model_specs() {
        let spec = alloy_spec(vec![]);

        let builder = AgentBuilder::from_spec(&spec, vec![], None).await;
        assert!(builder.is_ok());
    }

    #[tokio::test]
    async fn from_spec_places_base_prompts_before_spec_prompts() {
        let spec = alloy_spec(vec![Prompt::text("spec")]);

        let builder = AgentBuilder::from_spec(&spec, vec![Prompt::text("base")], None)
            .await
            .unwrap();
        let rendered = Prompt::build_all(&builder.prompts).await.unwrap();

        assert_eq!(rendered, "base\n\nspec");
    }
}