Skip to main content

agent_io/agent/
service.rs

//! Agent service - core execution loop
use std::collections::HashMap;
use std::sync::Arc;

use futures::{Stream, StreamExt};
use tokio::sync::RwLock;
use tracing::{debug, warn};

use crate::agent::{
    AgentEvent, ErrorEvent, FinalResponseEvent, StepCompleteEvent, StepStartEvent, UsageSummary,
};
use crate::llm::{
    AssistantMessage, BaseChatModel, ChatCompletion, Message, ToolDefinition, ToolMessage,
};
use crate::memory::{MemoryManager, MemoryType};
use crate::tools::Tool;
use crate::{Error, Result};

use super::builder::AgentBuilder;
use super::config::{AgentConfig, EphemeralConfig};
23/// Agent - the main orchestrator for LLM interactions
24pub struct Agent {
25    /// LLM provider
26    llm: Arc<dyn BaseChatModel>,
27    /// Available tools
28    tools: Vec<Arc<dyn Tool>>,
29    /// Configuration
30    config: AgentConfig,
31    /// Message history
32    history: Arc<RwLock<Vec<Message>>>,
33    /// Usage tracking
34    usage: Arc<RwLock<UsageSummary>>,
35    /// Ephemeral config per tool name
36    ephemeral_config: HashMap<String, EphemeralConfig>,
37    /// Memory manager (optional)
38    memory: Option<Arc<RwLock<MemoryManager>>>,
39}
40
41impl Agent {
42    /// Create a new agent
43    pub fn new(llm: Arc<dyn BaseChatModel>, tools: Vec<Arc<dyn Tool>>) -> Self {
44        // Build ephemeral config from tools
45        let ephemeral_config = tools
46            .iter()
47            .filter_map(|t| {
48                let cfg = t.ephemeral();
49                if cfg != crate::tools::EphemeralConfig::None {
50                    let keep_count = match cfg {
51                        crate::tools::EphemeralConfig::Single => 1,
52                        crate::tools::EphemeralConfig::Count(n) => n,
53                        crate::tools::EphemeralConfig::None => 0,
54                    };
55                    Some((t.name().to_string(), EphemeralConfig { keep_count }))
56                } else {
57                    None
58                }
59            })
60            .collect();
61
62        Self {
63            llm,
64            tools,
65            config: AgentConfig::default(),
66            history: Arc::new(RwLock::new(Vec::new())),
67            usage: Arc::new(RwLock::new(UsageSummary::new())),
68            ephemeral_config,
69            memory: None,
70        }
71    }
72
73    /// Create an agent builder
74    pub fn builder() -> AgentBuilder {
75        AgentBuilder::default()
76    }
77
78    /// Set configuration
79    pub fn with_config(mut self, config: AgentConfig) -> Self {
80        self.config = config;
81        self
82    }
83
84    /// Create agent with all components (used by builder)
85    pub(super) fn new_with_config(
86        llm: Arc<dyn BaseChatModel>,
87        tools: Vec<Arc<dyn Tool>>,
88        config: AgentConfig,
89        ephemeral_config: HashMap<String, EphemeralConfig>,
90        memory: Option<Arc<RwLock<MemoryManager>>>,
91    ) -> Self {
92        Self {
93            llm,
94            tools,
95            config,
96            history: Arc::new(RwLock::new(Vec::new())),
97            usage: Arc::new(RwLock::new(UsageSummary::new())),
98            ephemeral_config,
99            memory,
100        }
101    }
102
103    /// Query the agent synchronously (returns final response)
104    pub async fn query(&self, message: impl Into<String>) -> Result<String> {
105        // Add user message to history
106        {
107            let mut history = self.history.write().await;
108            history.push(Message::user(message.into()));
109        }
110
111        // Execute and collect result
112        let stream = self.execute_loop();
113        futures::pin_mut!(stream);
114
115        while let Some(event) = stream.next().await {
116            if let AgentEvent::FinalResponse(response) = event {
117                return Ok(response.content);
118            }
119        }
120
121        Err(Error::Agent("No final response received".into()))
122    }
123
124    /// Query with memory context
125    pub async fn query_with_memory(&self, message: impl Into<String>) -> Result<String> {
126        let message = message.into();
127
128        // Recall relevant memories
129        let context = if let Some(memory) = &self.memory {
130            let mem = memory.read().await;
131            mem.recall_context(&message).await?
132        } else {
133            String::new()
134        };
135
136        // Build enhanced prompt with memory context
137        let enhanced_message = if context.is_empty() {
138            message.clone()
139        } else {
140            format!(
141                "Relevant context from memory:\n{}\n\nUser query: {}",
142                context, message
143            )
144        };
145
146        // Execute query
147        let result = self.query(enhanced_message).await?;
148
149        // Store this interaction in memory
150        if let Some(memory) = &self.memory {
151            let mut mem = memory.write().await;
152            mem.remember(&message, MemoryType::ShortTerm).await?;
153        }
154
155        Ok(result)
156    }
157
158    /// Query the agent with streaming events
159    pub async fn query_stream<'a, M: Into<String> + 'a>(
160        &'a self,
161        message: M,
162    ) -> Result<impl Stream<Item = AgentEvent> + 'a> {
163        // Add user message to history
164        {
165            let mut history = self.history.write().await;
166            history.push(Message::user(message.into()));
167        }
168
169        Ok(self.execute_loop())
170    }
171
172    /// Main execution loop
173    fn execute_loop(&self) -> impl Stream<Item = AgentEvent> + '_ {
174        async_stream::stream! {
175            let mut step = 0;
176
177            loop {
178                if step >= self.config.max_iterations {
179                    yield AgentEvent::Error(ErrorEvent::new("Max iterations exceeded"));
180                    break;
181                }
182
183                yield AgentEvent::StepStart(StepStartEvent::new(step));
184
185                // Destroy ephemeral messages from previous iteration
186                {
187                    let mut h = self.history.write().await;
188                    Self::destroy_ephemeral_messages(&mut h, &self.ephemeral_config);
189                }
190
191                // Get current history
192                let messages = {
193                    let h = self.history.read().await;
194                    h.clone()
195                };
196
197                // Build system prompt + messages
198                let mut full_messages = Vec::new();
199                if let Some(ref prompt) = self.config.system_prompt
200                    && step == 0 {
201                        full_messages.push(Message::system(prompt));
202                    }
203                full_messages.extend(messages);
204
205                // Build tool definitions
206                let tool_defs: Vec<ToolDefinition> = self.tools.iter()
207                    .map(|t| t.definition())
208                    .collect();
209
210                // Call LLM with retry
211                let completion = match Self::call_llm_with_retry(
212                    self.llm.as_ref(),
213                    full_messages.clone(),
214                    if tool_defs.is_empty() { None } else { Some(tool_defs) },
215                    Some(self.config.tool_choice.clone()),
216                ).await {
217                    Ok(c) => c,
218                    Err(e) => {
219                        yield AgentEvent::Error(ErrorEvent::new(e.to_string()));
220                        break;
221                    }
222                };
223
224                // Track usage
225                if let Some(ref u) = completion.usage {
226                    let mut us = self.usage.write().await;
227                    us.add_usage(self.llm.model(), u);
228                }
229
230                // Yield thinking content
231                if let Some(ref thinking) = completion.thinking {
232                    yield AgentEvent::Thinking(crate::agent::ThinkingEvent::new(thinking));
233                }
234
235                // Yield text content
236                if let Some(ref content) = completion.content {
237                    yield AgentEvent::Text(crate::agent::TextEvent::new(content));
238                }
239
240                // Handle tool calls
241                if completion.has_tool_calls() {
242                    // Add assistant message to history
243                    {
244                        let mut h = self.history.write().await;
245                        h.push(Message::Assistant(AssistantMessage {
246                            role: "assistant".to_string(),
247                            content: completion.content.clone(),
248                            thinking: completion.thinking.clone(),
249                            redacted_thinking: None,
250                            tool_calls: completion.tool_calls.clone(),
251                            refusal: None,
252                        }));
253                    }
254
255                    // Execute tools
256                    for tool_call in &completion.tool_calls {
257                        yield AgentEvent::ToolCall(crate::agent::ToolCallEvent::new(tool_call, step));
258
259                        // Find tool
260                        let tool = self.tools.iter().find(|t| t.name() == tool_call.function.name);
261
262                        let result = if let Some(t) = tool {
263                            let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
264                                .unwrap_or(serde_json::json!({}));
265                            t.execute(args).await
266                        } else {
267                            Ok(crate::tools::ToolResult::new(&tool_call.id, format!("Unknown tool: {}", tool_call.function.name)))
268                        };
269
270                        match result {
271                            Ok(tool_result) => {
272                                yield AgentEvent::ToolResult(
273                                    crate::agent::ToolResultEvent::new(
274                                        &tool_call.id,
275                                        &tool_call.function.name,
276                                        &tool_result.content,
277                                        step,
278                                    ).with_ephemeral(tool_result.ephemeral)
279                                );
280
281                                // Add tool result to history with ephemeral metadata
282                                {
283                                    let mut h = self.history.write().await;
284                                    let mut msg = ToolMessage::new(&tool_call.id, tool_result.content);
285                                    msg.tool_name = Some(tool_call.function.name.clone());
286                                    msg.ephemeral = tool_result.ephemeral;
287                                    h.push(Message::Tool(msg));
288                                }
289                            }
290                            Err(e) => {
291                                yield AgentEvent::Error(ErrorEvent::new(format!(
292                                    "Tool execution failed: {}",
293                                    e
294                                )));
295                            }
296                        }
297                    }
298
299                    step += 1;
300                    yield AgentEvent::StepComplete(StepCompleteEvent::new(step - 1));
301                    continue;
302                }
303
304                // No tool calls - we're done
305                let final_response = FinalResponseEvent::new(completion.content.clone().unwrap_or_default())
306                    .with_steps(step);
307
308                yield AgentEvent::FinalResponse(final_response);
309                yield AgentEvent::StepComplete(StepCompleteEvent::new(step));
310                break;
311            }
312        }
313    }
314
315    /// Call LLM with exponential backoff retry
316    async fn call_llm_with_retry(
317        llm: &dyn BaseChatModel,
318        messages: Vec<Message>,
319        tools: Option<Vec<ToolDefinition>>,
320        tool_choice: Option<crate::llm::ToolChoice>,
321    ) -> Result<ChatCompletion> {
322        let max_retries = 3;
323        let mut delay = std::time::Duration::from_millis(100);
324
325        for attempt in 0..=max_retries {
326            match llm
327                .invoke(messages.clone(), tools.clone(), tool_choice.clone())
328                .await
329            {
330                Ok(completion) => return Ok(completion),
331                Err(crate::llm::LlmError::RateLimit) if attempt < max_retries => {
332                    tokio::time::sleep(delay).await;
333                    delay *= 2;
334                }
335                Err(e) => return Err(Error::Llm(e)),
336            }
337        }
338
339        Err(Error::Agent("Max retries exceeded".into()))
340    }
341
342    /// Get usage summary
343    pub async fn get_usage(&self) -> UsageSummary {
344        self.usage.read().await.clone()
345    }
346
347    /// Destroy old ephemeral message content, keeping the last N per tool.
348    fn destroy_ephemeral_messages(
349        history: &mut [Message],
350        ephemeral_config: &HashMap<String, EphemeralConfig>,
351    ) {
352        // First pass: collect indices of ephemeral messages by tool name
353        let mut ephemeral_indices_by_tool: HashMap<String, Vec<usize>> = HashMap::new();
354
355        for (idx, msg) in history.iter().enumerate() {
356            let tool_msg = match msg {
357                Message::Tool(t) => t,
358                _ => continue,
359            };
360
361            if !tool_msg.ephemeral || tool_msg.destroyed {
362                continue;
363            }
364
365            let tool_name = match &tool_msg.tool_name {
366                Some(name) => name.clone(),
367                None => continue,
368            };
369
370            ephemeral_indices_by_tool
371                .entry(tool_name)
372                .or_default()
373                .push(idx);
374        }
375
376        // Collect all indices to destroy
377        let mut indices_to_destroy: Vec<usize> = Vec::new();
378
379        for (tool_name, indices) in ephemeral_indices_by_tool {
380            let keep_count = ephemeral_config
381                .get(&tool_name)
382                .map(|c| c.keep_count)
383                .unwrap_or(1);
384
385            // Destroy messages beyond the keep limit (older ones first)
386            let to_destroy = if keep_count > 0 && indices.len() > keep_count {
387                &indices[..indices.len() - keep_count]
388            } else {
389                &indices[..]
390            };
391
392            indices_to_destroy.extend(to_destroy.iter().copied());
393        }
394
395        // Second pass: destroy the messages
396        for idx in indices_to_destroy {
397            if let Message::Tool(tool_msg) = &mut history[idx] {
398                debug!("Destroying ephemeral message at index {}", idx);
399                tool_msg.destroy();
400            }
401        }
402    }
403
404    /// Clear history
405    pub async fn clear_history(&self) {
406        let mut history = self.history.write().await;
407        history.clear();
408    }
409
410    /// Load history
411    pub async fn load_history(&self, messages: Vec<Message>) {
412        let mut history = self.history.write().await;
413        *history = messages;
414    }
415
416    /// Get current history
417    pub async fn get_history(&self) -> Vec<Message> {
418        self.history.read().await.clone()
419    }
420
421    /// Check if memory is enabled
422    pub fn has_memory(&self) -> bool {
423        self.memory.is_some()
424    }
425
426    /// Get memory manager reference
427    pub fn get_memory(&self) -> Option<&Arc<RwLock<MemoryManager>>> {
428        self.memory.as_ref()
429    }
430}