Skip to main content

agent_io/agent/
service.rs

1//! Agent service - core execution loop
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use futures::{Stream, StreamExt};
7use tokio::sync::RwLock;
8use tracing::debug;
9
10use crate::agent::{
11    AgentEvent, ErrorEvent, FinalResponseEvent, StepCompleteEvent, StepStartEvent, UsageSummary,
12};
13use crate::llm::{
14    AssistantMessage, BaseChatModel, ChatCompletion, Message, ToolDefinition, ToolMessage,
15};
16use crate::memory::{MemoryManager, MemoryType};
17use crate::tools::Tool;
18use crate::{Error, Result};
19
20use super::builder::AgentBuilder;
21use super::config::{AgentConfig, EphemeralConfig, build_ephemeral_config};
22
23/// Agent - the main orchestrator for LLM interactions
24pub struct Agent {
25    /// LLM provider
26    llm: Arc<dyn BaseChatModel>,
27    /// Available tools
28    tools: Vec<Arc<dyn Tool>>,
29    /// Configuration
30    config: AgentConfig,
31    /// Message history
32    history: Arc<RwLock<Vec<Message>>>,
33    /// Usage tracking
34    usage: Arc<RwLock<UsageSummary>>,
35    /// Ephemeral config per tool name
36    ephemeral_config: HashMap<String, EphemeralConfig>,
37    /// Memory manager (optional)
38    memory: Option<Arc<RwLock<MemoryManager>>>,
39}
40
41impl Agent {
42    /// Create a new agent
43    pub fn new(llm: Arc<dyn BaseChatModel>, tools: Vec<Arc<dyn Tool>>) -> Self {
44        let ephemeral_config = build_ephemeral_config(&tools);
45
46        Self {
47            llm,
48            tools,
49            config: AgentConfig::default(),
50            history: Arc::new(RwLock::new(Vec::new())),
51            usage: Arc::new(RwLock::new(UsageSummary::new())),
52            ephemeral_config,
53            memory: None,
54        }
55    }
56
57    /// Create an agent builder
58    pub fn builder() -> AgentBuilder {
59        AgentBuilder::default()
60    }
61
62    /// Set configuration
63    pub fn with_config(mut self, config: AgentConfig) -> Self {
64        self.config = config;
65        self
66    }
67
68    /// Create agent with all components (used by builder)
69    pub(super) fn new_with_config(
70        llm: Arc<dyn BaseChatModel>,
71        tools: Vec<Arc<dyn Tool>>,
72        config: AgentConfig,
73        ephemeral_config: HashMap<String, EphemeralConfig>,
74        memory: Option<Arc<RwLock<MemoryManager>>>,
75    ) -> Self {
76        Self {
77            llm,
78            tools,
79            config,
80            history: Arc::new(RwLock::new(Vec::new())),
81            usage: Arc::new(RwLock::new(UsageSummary::new())),
82            ephemeral_config,
83            memory,
84        }
85    }
86
87    /// Query the agent synchronously (returns final response)
88    pub async fn query(&self, message: impl Into<String>) -> Result<String> {
89        // Add user message to history
90        {
91            let mut history = self.history.write().await;
92            history.push(Message::user(message.into()));
93        }
94
95        // Execute and collect result
96        let stream = self.execute_loop();
97        futures::pin_mut!(stream);
98
99        while let Some(event) = stream.next().await {
100            if let AgentEvent::FinalResponse(response) = event {
101                return Ok(response.content);
102            }
103        }
104
105        Err(Error::Agent("No final response received".into()))
106    }
107
108    /// Query with memory context
109    pub async fn query_with_memory(&self, message: impl Into<String>) -> Result<String> {
110        let message = message.into();
111        let context = self.recall_memory_context(&message).await?;
112
113        {
114            let mut history = self.history.write().await;
115            if let Some(context_message) = self.memory_context_message(context) {
116                history.push(context_message);
117            }
118            history.push(Message::user(message.clone()));
119        }
120
121        let stream = self.execute_loop();
122        futures::pin_mut!(stream);
123
124        while let Some(event) = stream.next().await {
125            if let AgentEvent::FinalResponse(response) = event {
126                self.remember_short_term(&message).await?;
127                return Ok(response.content);
128            }
129        }
130
131        Err(Error::Agent("No final response received".into()))
132    }
133
134    /// Query the agent with streaming events
135    pub async fn query_stream<'a, M: Into<String> + 'a>(
136        &'a self,
137        message: M,
138    ) -> Result<impl Stream<Item = AgentEvent> + 'a> {
139        // Add user message to history
140        {
141            let mut history = self.history.write().await;
142            history.push(Message::user(message.into()));
143        }
144
145        Ok(self.execute_loop())
146    }
147
148    async fn recall_memory_context(&self, message: &str) -> Result<String> {
149        if let Some(memory) = &self.memory {
150            let mem = memory.read().await;
151            mem.recall_context(message).await
152        } else {
153            Ok(String::new())
154        }
155    }
156
157    fn memory_context_message(&self, context: String) -> Option<Message> {
158        if context.is_empty() {
159            None
160        } else {
161            Some(Message::developer(format!(
162                "Relevant memory context:\n{}",
163                context
164            )))
165        }
166    }
167
168    async fn remember_short_term(&self, message: &str) -> Result<()> {
169        if let Some(memory) = &self.memory {
170            let mut mem = memory.write().await;
171            mem.remember(message, MemoryType::ShortTerm).await?;
172        }
173        Ok(())
174    }
175
176    async fn build_request_messages(&self) -> Vec<Message> {
177        let history = self.history.read().await;
178        let mut messages =
179            Vec::with_capacity(history.len() + usize::from(self.config.system_prompt.is_some()));
180        if let Some(ref prompt) = self.config.system_prompt {
181            messages.push(Message::system(prompt));
182        }
183        messages.extend(history.iter().cloned());
184        messages
185    }
186
187    /// Main execution loop
188    fn execute_loop(&self) -> impl Stream<Item = AgentEvent> + '_ {
189        async_stream::stream! {
190            let mut step = 0;
191
192            loop {
193                if step >= self.config.max_iterations {
194                    yield AgentEvent::Error(ErrorEvent::new("Max iterations exceeded"));
195                    break;
196                }
197
198                yield AgentEvent::StepStart(StepStartEvent::new(step));
199
200                // Destroy ephemeral messages from previous iteration
201                {
202                    let mut h = self.history.write().await;
203                    Self::destroy_ephemeral_messages(&mut h, &self.ephemeral_config);
204                }
205
206                let full_messages = self.build_request_messages().await;
207
208                // Build tool definitions
209                let tool_defs: Vec<ToolDefinition> = self.tools.iter()
210                    .map(|t| t.definition())
211                    .collect();
212                let tool_defs = if tool_defs.is_empty() { None } else { Some(tool_defs) };
213                let tool_choice = self.config.tool_choice.clone();
214
215                // Call LLM with retry
216                let completion = match Self::call_llm_with_retry(
217                    self.llm.as_ref(),
218                    &full_messages,
219                    tool_defs.as_deref(),
220                    Some(&tool_choice),
221                ).await {
222                    Ok(c) => c,
223                    Err(e) => {
224                        yield AgentEvent::Error(ErrorEvent::new(e.to_string()));
225                        break;
226                    }
227                };
228
229                // Track usage
230                if let Some(ref u) = completion.usage {
231                    let mut us = self.usage.write().await;
232                    us.add_usage(self.llm.model(), u);
233                }
234
235                // Yield thinking content
236                if let Some(ref thinking) = completion.thinking {
237                    yield AgentEvent::Thinking(crate::agent::ThinkingEvent::new(thinking));
238                }
239
240                // Yield text content
241                if let Some(ref content) = completion.content {
242                    yield AgentEvent::Text(crate::agent::TextEvent::new(content));
243                }
244
245                // Handle tool calls
246                if completion.has_tool_calls() {
247                    // Add assistant message to history
248                    {
249                        let mut h = self.history.write().await;
250                        h.push(Message::Assistant(AssistantMessage {
251                            role: "assistant".to_string(),
252                            content: completion.content.clone(),
253                            thinking: completion.thinking.clone(),
254                            redacted_thinking: None,
255                            tool_calls: completion.tool_calls.clone(),
256                            refusal: None,
257                        }));
258                    }
259
260                    // Execute tools
261                    for tool_call in &completion.tool_calls {
262                        yield AgentEvent::ToolCall(crate::agent::ToolCallEvent::new(tool_call, step));
263
264                        // Find tool
265                        let tool = self.tools.iter().find(|t| t.name() == tool_call.function.name);
266
267                        let result = if let Some(t) = tool {
268                            let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
269                                .unwrap_or(serde_json::json!({}));
270                            t.execute(args).await
271                        } else {
272                            Ok(crate::tools::ToolResult::new(&tool_call.id, format!("Unknown tool: {}", tool_call.function.name)))
273                        };
274
275                        match result {
276                            Ok(tool_result) => {
277                                yield AgentEvent::ToolResult(
278                                    crate::agent::ToolResultEvent::new(
279                                        &tool_call.id,
280                                        &tool_call.function.name,
281                                        &tool_result.content,
282                                        step,
283                                    ).with_ephemeral(tool_result.ephemeral)
284                                );
285
286                                // Add tool result to history with ephemeral metadata
287                                {
288                                    let mut h = self.history.write().await;
289                                    let mut msg = ToolMessage::new(&tool_call.id, tool_result.content);
290                                    msg.tool_name = Some(tool_call.function.name.clone());
291                                    msg.ephemeral = tool_result.ephemeral;
292                                    h.push(Message::Tool(msg));
293                                }
294                            }
295                            Err(e) => {
296                                yield AgentEvent::Error(ErrorEvent::new(format!(
297                                    "Tool execution failed: {}",
298                                    e
299                                )));
300                            }
301                        }
302                    }
303
304                    step += 1;
305                    yield AgentEvent::StepComplete(StepCompleteEvent::new(step - 1));
306                    continue;
307                }
308
309                // No tool calls - we're done
310                let final_response = FinalResponseEvent::new(completion.content.clone().unwrap_or_default())
311                    .with_steps(step);
312
313                yield AgentEvent::FinalResponse(final_response);
314                yield AgentEvent::StepComplete(StepCompleteEvent::new(step));
315                break;
316            }
317        }
318    }
319
320    /// Call LLM with exponential backoff retry
321    async fn call_llm_with_retry(
322        llm: &dyn BaseChatModel,
323        messages: &[Message],
324        tools: Option<&[ToolDefinition]>,
325        tool_choice: Option<&crate::llm::ToolChoice>,
326    ) -> Result<ChatCompletion> {
327        let max_retries = 10;
328        let mut delay = std::time::Duration::from_millis(500);
329
330        for attempt in 0..=max_retries {
331            let request_messages = messages.to_vec();
332            let request_tools = tools.map(|value| value.to_vec());
333            let request_tool_choice = tool_choice.cloned();
334
335            match llm
336                .invoke(request_messages, request_tools, request_tool_choice)
337                .await
338            {
339                Ok(completion) => return Ok(completion),
340                Err(crate::llm::LlmError::RateLimit) if attempt < max_retries => {
341                    tracing::warn!(
342                        "Rate limit or empty response, retrying in {:?} (attempt {}/{})",
343                        delay,
344                        attempt + 1,
345                        max_retries
346                    );
347                    tokio::time::sleep(delay).await;
348                    delay *= 2;
349                }
350                Err(e) => return Err(Error::Llm(e)),
351            }
352        }
353
354        Err(Error::Agent("Max retries exceeded".into()))
355    }
356
357    /// Get usage summary
358    pub async fn get_usage(&self) -> UsageSummary {
359        self.usage.read().await.clone()
360    }
361
362    /// Destroy old ephemeral message content, keeping the last N per tool.
363    fn destroy_ephemeral_messages(
364        history: &mut [Message],
365        ephemeral_config: &HashMap<String, EphemeralConfig>,
366    ) {
367        // First pass: collect indices of ephemeral messages by tool name
368        let mut ephemeral_indices_by_tool: HashMap<String, Vec<usize>> = HashMap::new();
369
370        for (idx, msg) in history.iter().enumerate() {
371            let tool_msg = match msg {
372                Message::Tool(t) => t,
373                _ => continue,
374            };
375
376            if !tool_msg.ephemeral || tool_msg.destroyed {
377                continue;
378            }
379
380            let tool_name = match &tool_msg.tool_name {
381                Some(name) => name.clone(),
382                None => continue,
383            };
384
385            ephemeral_indices_by_tool
386                .entry(tool_name)
387                .or_default()
388                .push(idx);
389        }
390
391        // Collect all indices to destroy
392        let mut indices_to_destroy: Vec<usize> = Vec::new();
393
394        for (tool_name, indices) in ephemeral_indices_by_tool {
395            let keep_count = ephemeral_config
396                .get(&tool_name)
397                .map(|c| c.keep_count)
398                .unwrap_or(1);
399
400            // Destroy messages beyond the keep limit (older ones first)
401            let to_destroy = if keep_count > 0 && indices.len() > keep_count {
402                &indices[..indices.len() - keep_count]
403            } else {
404                &indices[..]
405            };
406
407            indices_to_destroy.extend(to_destroy.iter().copied());
408        }
409
410        // Second pass: destroy the messages
411        for idx in indices_to_destroy {
412            if let Message::Tool(tool_msg) = &mut history[idx] {
413                debug!("Destroying ephemeral message at index {}", idx);
414                tool_msg.destroy();
415            }
416        }
417    }
418
419    /// Clear history
420    pub async fn clear_history(&self) {
421        let mut history = self.history.write().await;
422        history.clear();
423    }
424
425    /// Load history
426    pub async fn load_history(&self, messages: Vec<Message>) {
427        let mut history = self.history.write().await;
428        *history = messages;
429    }
430
431    /// Get current history
432    pub async fn get_history(&self) -> Vec<Message> {
433        self.history.read().await.clone()
434    }
435
436    /// Check if memory is enabled
437    pub fn has_memory(&self) -> bool {
438        self.memory.is_some()
439    }
440
441    /// Get memory manager reference
442    pub fn get_memory(&self) -> Option<&Arc<RwLock<MemoryManager>>> {
443        self.memory.as_ref()
444    }
445}