Skip to main content

agent_io/agent/
service.rs

1//! Agent service - core execution loop
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use futures::{Stream, StreamExt};
7use tokio::sync::RwLock;
8use tracing::debug;
9
10use crate::agent::{
11    AgentEvent, ErrorEvent, FinalResponseEvent, StepCompleteEvent, StepStartEvent, UsageSummary,
12};
13use crate::llm::{
14    AssistantMessage, BaseChatModel, ChatCompletion, Message, ToolDefinition, ToolMessage,
15};
16use crate::memory::{MemoryManager, MemoryType};
17use crate::tools::Tool;
18use crate::{Error, Result};
19
20use super::builder::AgentBuilder;
21use super::config::{AgentConfig, EphemeralConfig};
22
23/// Agent - the main orchestrator for LLM interactions
24pub struct Agent {
25    /// LLM provider
26    llm: Arc<dyn BaseChatModel>,
27    /// Available tools
28    tools: Vec<Arc<dyn Tool>>,
29    /// Configuration
30    config: AgentConfig,
31    /// Message history
32    history: Arc<RwLock<Vec<Message>>>,
33    /// Usage tracking
34    usage: Arc<RwLock<UsageSummary>>,
35    /// Ephemeral config per tool name
36    ephemeral_config: HashMap<String, EphemeralConfig>,
37    /// Memory manager (optional)
38    memory: Option<Arc<RwLock<MemoryManager>>>,
39}
40
41impl Agent {
42    /// Create a new agent
43    pub fn new(llm: Arc<dyn BaseChatModel>, tools: Vec<Arc<dyn Tool>>) -> Self {
44        // Build ephemeral config from tools
45        let ephemeral_config = tools
46            .iter()
47            .filter_map(|t| {
48                let cfg = t.ephemeral();
49                if cfg != crate::tools::EphemeralConfig::None {
50                    let keep_count = match cfg {
51                        crate::tools::EphemeralConfig::Single => 1,
52                        crate::tools::EphemeralConfig::Count(n) => n,
53                        crate::tools::EphemeralConfig::None => 0,
54                    };
55                    Some((t.name().to_string(), EphemeralConfig { keep_count }))
56                } else {
57                    None
58                }
59            })
60            .collect();
61
62        Self {
63            llm,
64            tools,
65            config: AgentConfig::default(),
66            history: Arc::new(RwLock::new(Vec::new())),
67            usage: Arc::new(RwLock::new(UsageSummary::new())),
68            ephemeral_config,
69            memory: None,
70        }
71    }
72
73    /// Create an agent builder
74    pub fn builder() -> AgentBuilder {
75        AgentBuilder::default()
76    }
77
78    /// Set configuration
79    pub fn with_config(mut self, config: AgentConfig) -> Self {
80        self.config = config;
81        self
82    }
83
84    /// Create agent with all components (used by builder)
85    pub(super) fn new_with_config(
86        llm: Arc<dyn BaseChatModel>,
87        tools: Vec<Arc<dyn Tool>>,
88        config: AgentConfig,
89        ephemeral_config: HashMap<String, EphemeralConfig>,
90        memory: Option<Arc<RwLock<MemoryManager>>>,
91    ) -> Self {
92        Self {
93            llm,
94            tools,
95            config,
96            history: Arc::new(RwLock::new(Vec::new())),
97            usage: Arc::new(RwLock::new(UsageSummary::new())),
98            ephemeral_config,
99            memory,
100        }
101    }
102
103    /// Query the agent synchronously (returns final response)
104    pub async fn query(&self, message: impl Into<String>) -> Result<String> {
105        // Add user message to history
106        {
107            let mut history = self.history.write().await;
108            history.push(Message::user(message.into()));
109        }
110
111        // Execute and collect result
112        let stream = self.execute_loop();
113        futures::pin_mut!(stream);
114
115        while let Some(event) = stream.next().await {
116            if let AgentEvent::FinalResponse(response) = event {
117                return Ok(response.content);
118            }
119        }
120
121        Err(Error::Agent("No final response received".into()))
122    }
123
124    /// Query with memory context
125    pub async fn query_with_memory(&self, message: impl Into<String>) -> Result<String> {
126        let message = message.into();
127
128        // Recall relevant memories
129        let context = if let Some(memory) = &self.memory {
130            let mem = memory.read().await;
131            mem.recall_context(&message).await?
132        } else {
133            String::new()
134        };
135
136        // Build enhanced prompt with memory context
137        let enhanced_message = if context.is_empty() {
138            message.clone()
139        } else {
140            format!(
141                "Relevant context from memory:\n{}\n\nUser query: {}",
142                context, message
143            )
144        };
145
146        // Execute query
147        let result = self.query(enhanced_message).await?;
148
149        // Store this interaction in memory
150        if let Some(memory) = &self.memory {
151            let mut mem = memory.write().await;
152            mem.remember(&message, MemoryType::ShortTerm).await?;
153        }
154
155        Ok(result)
156    }
157
158    /// Query the agent with streaming events
159    pub async fn query_stream<'a, M: Into<String> + 'a>(
160        &'a self,
161        message: M,
162    ) -> Result<impl Stream<Item = AgentEvent> + 'a> {
163        // Add user message to history
164        {
165            let mut history = self.history.write().await;
166            history.push(Message::user(message.into()));
167        }
168
169        Ok(self.execute_loop())
170    }
171
172    /// Main execution loop
173    fn execute_loop(&self) -> impl Stream<Item = AgentEvent> + '_ {
174        async_stream::stream! {
175            let mut step = 0;
176
177            loop {
178                if step >= self.config.max_iterations {
179                    yield AgentEvent::Error(ErrorEvent::new("Max iterations exceeded"));
180                    break;
181                }
182
183                yield AgentEvent::StepStart(StepStartEvent::new(step));
184
185                // Destroy ephemeral messages from previous iteration
186                {
187                    let mut h = self.history.write().await;
188                    Self::destroy_ephemeral_messages(&mut h, &self.ephemeral_config);
189                }
190
191                // Get current history
192                let messages = {
193                    let h = self.history.read().await;
194                    h.clone()
195                };
196
197                // Build system prompt + messages
198                let mut full_messages = Vec::new();
199                if let Some(ref prompt) = self.config.system_prompt {
200                    full_messages.push(Message::system(prompt));
201                }
202                full_messages.extend(messages);
203
204                // Build tool definitions
205                let tool_defs: Vec<ToolDefinition> = self.tools.iter()
206                    .map(|t| t.definition())
207                    .collect();
208
209                // Call LLM with retry
210                let completion = match Self::call_llm_with_retry(
211                    self.llm.as_ref(),
212                    full_messages.clone(),
213                    if tool_defs.is_empty() { None } else { Some(tool_defs) },
214                    Some(self.config.tool_choice.clone()),
215                ).await {
216                    Ok(c) => c,
217                    Err(e) => {
218                        yield AgentEvent::Error(ErrorEvent::new(e.to_string()));
219                        break;
220                    }
221                };
222
223                // Track usage
224                if let Some(ref u) = completion.usage {
225                    let mut us = self.usage.write().await;
226                    us.add_usage(self.llm.model(), u);
227                }
228
229                // Yield thinking content
230                if let Some(ref thinking) = completion.thinking {
231                    yield AgentEvent::Thinking(crate::agent::ThinkingEvent::new(thinking));
232                }
233
234                // Yield text content
235                if let Some(ref content) = completion.content {
236                    yield AgentEvent::Text(crate::agent::TextEvent::new(content));
237                }
238
239                // Handle tool calls
240                if completion.has_tool_calls() {
241                    // Add assistant message to history
242                    {
243                        let mut h = self.history.write().await;
244                        h.push(Message::Assistant(AssistantMessage {
245                            role: "assistant".to_string(),
246                            content: completion.content.clone(),
247                            thinking: completion.thinking.clone(),
248                            redacted_thinking: None,
249                            tool_calls: completion.tool_calls.clone(),
250                            refusal: None,
251                        }));
252                    }
253
254                    // Execute tools
255                    for tool_call in &completion.tool_calls {
256                        yield AgentEvent::ToolCall(crate::agent::ToolCallEvent::new(tool_call, step));
257
258                        // Find tool
259                        let tool = self.tools.iter().find(|t| t.name() == tool_call.function.name);
260
261                        let result = if let Some(t) = tool {
262                            let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
263                                .unwrap_or(serde_json::json!({}));
264                            t.execute(args).await
265                        } else {
266                            Ok(crate::tools::ToolResult::new(&tool_call.id, format!("Unknown tool: {}", tool_call.function.name)))
267                        };
268
269                        match result {
270                            Ok(tool_result) => {
271                                yield AgentEvent::ToolResult(
272                                    crate::agent::ToolResultEvent::new(
273                                        &tool_call.id,
274                                        &tool_call.function.name,
275                                        &tool_result.content,
276                                        step,
277                                    ).with_ephemeral(tool_result.ephemeral)
278                                );
279
280                                // Add tool result to history with ephemeral metadata
281                                {
282                                    let mut h = self.history.write().await;
283                                    let mut msg = ToolMessage::new(&tool_call.id, tool_result.content);
284                                    msg.tool_name = Some(tool_call.function.name.clone());
285                                    msg.ephemeral = tool_result.ephemeral;
286                                    h.push(Message::Tool(msg));
287                                }
288                            }
289                            Err(e) => {
290                                yield AgentEvent::Error(ErrorEvent::new(format!(
291                                    "Tool execution failed: {}",
292                                    e
293                                )));
294                            }
295                        }
296                    }
297
298                    step += 1;
299                    yield AgentEvent::StepComplete(StepCompleteEvent::new(step - 1));
300                    continue;
301                }
302
303                // No tool calls - we're done
304                let final_response = FinalResponseEvent::new(completion.content.clone().unwrap_or_default())
305                    .with_steps(step);
306
307                yield AgentEvent::FinalResponse(final_response);
308                yield AgentEvent::StepComplete(StepCompleteEvent::new(step));
309                break;
310            }
311        }
312    }
313
314    /// Call LLM with exponential backoff retry
315    async fn call_llm_with_retry(
316        llm: &dyn BaseChatModel,
317        messages: Vec<Message>,
318        tools: Option<Vec<ToolDefinition>>,
319        tool_choice: Option<crate::llm::ToolChoice>,
320    ) -> Result<ChatCompletion> {
321        let max_retries = 10;
322        let mut delay = std::time::Duration::from_millis(500);
323
324        for attempt in 0..=max_retries {
325            match llm
326                .invoke(messages.clone(), tools.clone(), tool_choice.clone())
327                .await
328            {
329                Ok(completion) => return Ok(completion),
330                Err(crate::llm::LlmError::RateLimit) if attempt < max_retries => {
331                    tracing::warn!(
332                        "Rate limit or empty response, retrying in {:?} (attempt {}/{})",
333                        delay,
334                        attempt + 1,
335                        max_retries
336                    );
337                    tokio::time::sleep(delay).await;
338                    delay *= 2;
339                }
340                Err(e) => return Err(Error::Llm(e)),
341            }
342        }
343
344        Err(Error::Agent("Max retries exceeded".into()))
345    }
346
347    /// Get usage summary
348    pub async fn get_usage(&self) -> UsageSummary {
349        self.usage.read().await.clone()
350    }
351
352    /// Destroy old ephemeral message content, keeping the last N per tool.
353    fn destroy_ephemeral_messages(
354        history: &mut [Message],
355        ephemeral_config: &HashMap<String, EphemeralConfig>,
356    ) {
357        // First pass: collect indices of ephemeral messages by tool name
358        let mut ephemeral_indices_by_tool: HashMap<String, Vec<usize>> = HashMap::new();
359
360        for (idx, msg) in history.iter().enumerate() {
361            let tool_msg = match msg {
362                Message::Tool(t) => t,
363                _ => continue,
364            };
365
366            if !tool_msg.ephemeral || tool_msg.destroyed {
367                continue;
368            }
369
370            let tool_name = match &tool_msg.tool_name {
371                Some(name) => name.clone(),
372                None => continue,
373            };
374
375            ephemeral_indices_by_tool
376                .entry(tool_name)
377                .or_default()
378                .push(idx);
379        }
380
381        // Collect all indices to destroy
382        let mut indices_to_destroy: Vec<usize> = Vec::new();
383
384        for (tool_name, indices) in ephemeral_indices_by_tool {
385            let keep_count = ephemeral_config
386                .get(&tool_name)
387                .map(|c| c.keep_count)
388                .unwrap_or(1);
389
390            // Destroy messages beyond the keep limit (older ones first)
391            let to_destroy = if keep_count > 0 && indices.len() > keep_count {
392                &indices[..indices.len() - keep_count]
393            } else {
394                &indices[..]
395            };
396
397            indices_to_destroy.extend(to_destroy.iter().copied());
398        }
399
400        // Second pass: destroy the messages
401        for idx in indices_to_destroy {
402            if let Message::Tool(tool_msg) = &mut history[idx] {
403                debug!("Destroying ephemeral message at index {}", idx);
404                tool_msg.destroy();
405            }
406        }
407    }
408
409    /// Clear history
410    pub async fn clear_history(&self) {
411        let mut history = self.history.write().await;
412        history.clear();
413    }
414
415    /// Load history
416    pub async fn load_history(&self, messages: Vec<Message>) {
417        let mut history = self.history.write().await;
418        *history = messages;
419    }
420
421    /// Get current history
422    pub async fn get_history(&self) -> Vec<Message> {
423        self.history.read().await.clone()
424    }
425
426    /// Check if memory is enabled
427    pub fn has_memory(&self) -> bool {
428        self.memory.is_some()
429    }
430
431    /// Get memory manager reference
432    pub fn get_memory(&self) -> Option<&Arc<RwLock<MemoryManager>>> {
433        self.memory.as_ref()
434    }
435}