Skip to main content

agent_io/agent/
service.rs

1//! Agent service - core execution loop
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use derive_builder::Builder;
7use futures::{Stream, StreamExt};
8use tokio::sync::RwLock;
9use tracing::debug;
10
11use crate::agent::{
12    AgentEvent, ErrorEvent, FinalResponseEvent, StepCompleteEvent, StepStartEvent, UsageSummary,
13};
14use crate::llm::{BaseChatModel, ChatCompletion, Message, ToolChoice, ToolDefinition, ToolMessage};
15use crate::tools::Tool;
16use crate::{Error, Result};
17
/// Default maximum number of agent-loop iterations before the loop aborts
/// with a "Max iterations exceeded" error event.
pub const DEFAULT_MAX_ITERATIONS: usize = 200;
20
/// Agent configuration.
///
/// Construct via [`AgentConfig::default`] or through the derived
/// `AgentConfigBuilder` (owned-pattern `derive_builder`).
#[derive(Builder, Clone)]
#[builder(pattern = "owned")]
pub struct AgentConfig {
    /// Optional system prompt prepended to the messages sent to the LLM.
    #[builder(setter(into, strip_option), default = "None")]
    pub system_prompt: Option<String>,

    /// Maximum agent-loop iterations before stopping with an error event.
    #[builder(default = "DEFAULT_MAX_ITERATIONS")]
    pub max_iterations: usize,

    /// Tool choice strategy forwarded to the LLM on every call.
    #[builder(default = "ToolChoice::Auto")]
    pub tool_choice: ToolChoice,

    /// Enable context compaction.
    // NOTE(review): not read anywhere in this file — presumably consumed by
    // compaction logic elsewhere; confirm before removing.
    #[builder(default = "false")]
    pub enable_compaction: bool,

    /// Compaction threshold, expressed as a ratio of the context window.
    #[builder(default = "0.80")]
    pub compaction_threshold: f32,

    /// Enable cost tracking.
    // NOTE(review): also not read in this file — verify against usage code.
    #[builder(default = "false")]
    pub include_cost: bool,
}
49
50impl Default for AgentConfig {
51    fn default() -> Self {
52        Self {
53            system_prompt: None,
54            max_iterations: DEFAULT_MAX_ITERATIONS,
55            tool_choice: ToolChoice::Auto,
56            enable_compaction: false,
57            compaction_threshold: 0.80,
58            include_cost: false,
59        }
60    }
61}
62
/// Ephemeral message configuration for a single tool.
#[derive(Debug, Clone, Copy)]
pub struct EphemeralConfig {
    // Number of the most recent outputs to retain for this tool; older
    // ephemeral outputs are destroyed before each LLM call. A value of 0
    // destroys all of them. (Non-ephemeral tools are simply not tracked,
    // so there is no "not ephemeral" sentinel here.)
    pub keep_count: usize,
}
69
70impl Default for EphemeralConfig {
71    fn default() -> Self {
72        Self { keep_count: 1 }
73    }
74}
75
/// Agent - the main orchestrator for LLM interactions.
///
/// Holds the conversation history and usage counters behind async `RwLock`s,
/// so the agent can be queried through a shared reference.
pub struct Agent {
    /// LLM provider used for every completion call
    llm: Arc<dyn BaseChatModel>,
    /// Tools advertised to the model on each call
    tools: Vec<Arc<dyn Tool>>,
    /// Loop configuration (system prompt, iteration cap, tool choice, ...)
    config: AgentConfig,
    /// Message history, shared across queries on the same agent
    history: Arc<RwLock<Vec<Message>>>,
    /// Accumulated usage, keyed internally by model
    usage: Arc<RwLock<UsageSummary>>,
    /// Ephemeral retention (keep-count) per tool name
    ephemeral_config: HashMap<String, EphemeralConfig>,
}
91
92impl Agent {
93    /// Create a new agent
94    pub fn new(llm: Arc<dyn BaseChatModel>, tools: Vec<Arc<dyn Tool>>) -> Self {
95        // Build ephemeral config from tools
96        let ephemeral_config = tools
97            .iter()
98            .filter_map(|t| {
99                let cfg = t.ephemeral();
100                if cfg != crate::tools::EphemeralConfig::None {
101                    let keep_count = match cfg {
102                        crate::tools::EphemeralConfig::Single => 1,
103                        crate::tools::EphemeralConfig::Count(n) => n,
104                        crate::tools::EphemeralConfig::None => 0,
105                    };
106                    Some((t.name().to_string(), EphemeralConfig { keep_count }))
107                } else {
108                    None
109                }
110            })
111            .collect();
112
113        Self {
114            llm,
115            tools,
116            config: AgentConfig::default(),
117            history: Arc::new(RwLock::new(Vec::new())),
118            usage: Arc::new(RwLock::new(UsageSummary::new())),
119            ephemeral_config,
120        }
121    }
122
123    /// Create an agent builder
124    pub fn builder() -> AgentBuilder {
125        AgentBuilder::default()
126    }
127
128    /// Set configuration
129    pub fn with_config(mut self, config: AgentConfig) -> Self {
130        self.config = config;
131        self
132    }
133
134    /// Query the agent synchronously (returns final response)
135    pub async fn query(&self, message: impl Into<String>) -> Result<String> {
136        // Add user message to history
137        {
138            let mut history = self.history.write().await;
139            history.push(Message::user(message.into()));
140        }
141
142        // Execute and collect result
143        let stream = self.execute_loop();
144        futures::pin_mut!(stream);
145
146        while let Some(event) = stream.next().await {
147            if let AgentEvent::FinalResponse(response) = event {
148                return Ok(response.content);
149            }
150        }
151
152        Err(Error::Agent("No final response received".into()))
153    }
154
155    /// Query the agent with streaming events
156    pub async fn query_stream<'a, M: Into<String> + 'a>(
157        &'a self,
158        message: M,
159    ) -> Result<impl Stream<Item = AgentEvent> + 'a> {
160        // Add user message to history
161        {
162            let mut history = self.history.write().await;
163            history.push(Message::user(message.into()));
164        }
165
166        Ok(self.execute_loop())
167    }
168
169    /// Main execution loop
170    fn execute_loop(&self) -> impl Stream<Item = AgentEvent> + '_ {
171        async_stream::stream! {
172            let mut step = 0;
173
174            loop {
175                if step >= self.config.max_iterations {
176                    yield AgentEvent::Error(ErrorEvent::new("Max iterations exceeded"));
177                    break;
178                }
179
180                yield AgentEvent::StepStart(StepStartEvent::new(step));
181
182                // Destroy ephemeral messages from previous iteration
183                {
184                    let mut h = self.history.write().await;
185                    Self::destroy_ephemeral_messages(&mut h, &self.ephemeral_config);
186                }
187
188                // Get current history
189                let messages = {
190                    let h = self.history.read().await;
191                    h.clone()
192                };
193
194                // Build system prompt + messages
195                let mut full_messages = Vec::new();
196                if let Some(ref prompt) = self.config.system_prompt
197                    && step == 0 {
198                        full_messages.push(Message::system(prompt));
199                    }
200                full_messages.extend(messages);
201
202                // Build tool definitions
203                let tool_defs: Vec<ToolDefinition> = self.tools.iter()
204                    .map(|t| t.definition())
205                    .collect();
206
207                // Call LLM with retry
208                let completion = match Self::call_llm_with_retry(
209                    self.llm.as_ref(),
210                    full_messages.clone(),
211                    if tool_defs.is_empty() { None } else { Some(tool_defs) },
212                    Some(self.config.tool_choice.clone()),
213                ).await {
214                    Ok(c) => c,
215                    Err(e) => {
216                        yield AgentEvent::Error(ErrorEvent::new(e.to_string()));
217                        break;
218                    }
219                };
220
221                // Track usage
222                if let Some(ref u) = completion.usage {
223                    let mut us = self.usage.write().await;
224                    us.add_usage(self.llm.model(), u);
225                }
226
227                // Yield thinking content
228                if let Some(ref thinking) = completion.thinking {
229                    yield AgentEvent::Thinking(crate::agent::ThinkingEvent::new(thinking));
230                }
231
232                // Yield text content
233                if let Some(ref content) = completion.content {
234                    yield AgentEvent::Text(crate::agent::TextEvent::new(content));
235                }
236
237                // Handle tool calls
238                if completion.has_tool_calls() {
239                    // Add assistant message to history
240                    {
241                        let mut h = self.history.write().await;
242                        h.push(Message::Assistant(crate::llm::AssistantMessage {
243                            role: "assistant".to_string(),
244                            content: completion.content.clone(),
245                            thinking: completion.thinking.clone(),
246                            redacted_thinking: None,
247                            tool_calls: completion.tool_calls.clone(),
248                            refusal: None,
249                        }));
250                    }
251
252                    // Execute tools
253                    for tool_call in &completion.tool_calls {
254                        yield AgentEvent::ToolCall(crate::agent::ToolCallEvent::new(tool_call, step));
255
256                        // Find tool
257                        let tool = self.tools.iter().find(|t| t.name() == tool_call.function.name);
258
259                        let result = if let Some(t) = tool {
260                            let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
261                                .unwrap_or(serde_json::json!({}));
262                            t.execute(args, None).await
263                        } else {
264                            Ok(crate::tools::ToolResult::new(&tool_call.id, format!("Unknown tool: {}", tool_call.function.name)))
265                        };
266
267                        match result {
268                            Ok(tool_result) => {
269                                yield AgentEvent::ToolResult(
270                                    crate::agent::ToolResultEvent::new(
271                                        &tool_call.id,
272                                        &tool_call.function.name,
273                                        &tool_result.content,
274                                        step,
275                                    ).with_ephemeral(tool_result.ephemeral)
276                                );
277
278                                // Add tool result to history with ephemeral metadata
279                                {
280                                    let mut h = self.history.write().await;
281                                    let mut msg = ToolMessage::new(&tool_call.id, tool_result.content);
282                                    msg.tool_name = Some(tool_call.function.name.clone());
283                                    msg.ephemeral = tool_result.ephemeral;
284                                    h.push(Message::Tool(msg));
285                                }
286                            }
287                            Err(e) => {
288                                yield AgentEvent::Error(ErrorEvent::new(format!(
289                                    "Tool execution failed: {}",
290                                    e
291                                )));
292                            }
293                        }
294                    }
295
296                    step += 1;
297                    yield AgentEvent::StepComplete(StepCompleteEvent::new(step - 1));
298                    continue;
299                }
300
301                // No tool calls - we're done
302                let final_response = FinalResponseEvent::new(completion.content.clone().unwrap_or_default())
303                    .with_steps(step);
304
305                yield AgentEvent::FinalResponse(final_response);
306                yield AgentEvent::StepComplete(StepCompleteEvent::new(step));
307                break;
308            }
309        }
310    }
311
312    /// Call LLM with exponential backoff retry
313    async fn call_llm_with_retry(
314        llm: &dyn BaseChatModel,
315        messages: Vec<Message>,
316        tools: Option<Vec<ToolDefinition>>,
317        tool_choice: Option<ToolChoice>,
318    ) -> Result<ChatCompletion> {
319        let max_retries = 3;
320        let mut delay = std::time::Duration::from_millis(100);
321
322        for attempt in 0..=max_retries {
323            match llm
324                .invoke(messages.clone(), tools.clone(), tool_choice.clone())
325                .await
326            {
327                Ok(completion) => return Ok(completion),
328                Err(crate::llm::LlmError::RateLimit) if attempt < max_retries => {
329                    tokio::time::sleep(delay).await;
330                    delay *= 2;
331                }
332                Err(e) => return Err(Error::Llm(e)),
333            }
334        }
335
336        Err(Error::Agent("Max retries exceeded".into()))
337    }
338
339    /// Get usage summary
340    pub async fn get_usage(&self) -> UsageSummary {
341        self.usage.read().await.clone()
342    }
343
344    /// Destroy old ephemeral message content, keeping the last N per tool.
345    ///
346    /// This should be called before each LLM invocation to manage context size.
347    fn destroy_ephemeral_messages(
348        history: &mut [Message],
349        ephemeral_config: &HashMap<String, EphemeralConfig>,
350    ) {
351        // First pass: collect indices of ephemeral messages by tool name
352        let mut ephemeral_indices_by_tool: HashMap<String, Vec<usize>> = HashMap::new();
353
354        for (idx, msg) in history.iter().enumerate() {
355            let tool_msg = match msg {
356                Message::Tool(t) => t,
357                _ => continue,
358            };
359
360            if !tool_msg.ephemeral || tool_msg.destroyed {
361                continue;
362            }
363
364            let tool_name = match &tool_msg.tool_name {
365                Some(name) => name.clone(),
366                None => continue,
367            };
368
369            ephemeral_indices_by_tool
370                .entry(tool_name)
371                .or_default()
372                .push(idx);
373        }
374
375        // Collect all indices to destroy
376        let mut indices_to_destroy: Vec<usize> = Vec::new();
377
378        for (tool_name, indices) in ephemeral_indices_by_tool {
379            let keep_count = ephemeral_config
380                .get(&tool_name)
381                .map(|c| c.keep_count)
382                .unwrap_or(1);
383
384            // Destroy messages beyond the keep limit (older ones first)
385            let to_destroy = if keep_count > 0 && indices.len() > keep_count {
386                &indices[..indices.len() - keep_count]
387            } else {
388                &indices[..]
389            };
390
391            indices_to_destroy.extend(to_destroy.iter().copied());
392        }
393
394        // Second pass: destroy the messages
395        for idx in indices_to_destroy {
396            if let Message::Tool(tool_msg) = &mut history[idx] {
397                debug!("Destroying ephemeral message at index {}", idx);
398                tool_msg.destroy();
399            }
400        }
401    }
402
403    /// Clear history
404    pub async fn clear_history(&self) {
405        let mut history = self.history.write().await;
406        history.clear();
407    }
408
409    /// Load history
410    pub async fn load_history(&self, messages: Vec<Message>) {
411        let mut history = self.history.write().await;
412        *history = messages;
413    }
414
415    /// Get current history
416    pub async fn get_history(&self) -> Vec<Message> {
417        self.history.read().await.clone()
418    }
419}
420
/// Agent builder - fluent construction of an [`Agent`].
#[derive(Default)]
pub struct AgentBuilder {
    /// LLM provider; required, checked in `build`
    llm: Option<Arc<dyn BaseChatModel>>,
    /// Tools to expose to the model
    tools: Vec<Arc<dyn Tool>>,
    /// Optional configuration; `build` falls back to `AgentConfig::default()`
    config: Option<AgentConfig>,
}
428
429impl AgentBuilder {
430    pub fn with_llm(mut self, llm: Arc<dyn BaseChatModel>) -> Self {
431        self.llm = Some(llm);
432        self
433    }
434
435    pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
436        self.tools.push(tool);
437        self
438    }
439
440    pub fn tools(mut self, tools: Vec<Arc<dyn Tool>>) -> Self {
441        self.tools = tools;
442        self
443    }
444
445    pub fn config(mut self, config: AgentConfig) -> Self {
446        self.config = Some(config);
447        self
448    }
449
450    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
451        let mut config = self.config.unwrap_or_default();
452        config.system_prompt = Some(prompt.into());
453        self.config = Some(config);
454        self
455    }
456
457    pub fn max_iterations(mut self, max: usize) -> Self {
458        let mut config = self.config.unwrap_or_default();
459        config.max_iterations = max;
460        self.config = Some(config);
461        self
462    }
463
464    pub fn build(self) -> Result<Agent> {
465        let llm = self
466            .llm
467            .ok_or_else(|| Error::Config("LLM is required".into()))?;
468
469        // Build ephemeral config from tools
470        let ephemeral_config = self
471            .tools
472            .iter()
473            .filter_map(|t| {
474                let cfg = t.ephemeral();
475                if cfg != crate::tools::EphemeralConfig::None {
476                    let keep_count = match cfg {
477                        crate::tools::EphemeralConfig::Single => 1,
478                        crate::tools::EphemeralConfig::Count(n) => n,
479                        crate::tools::EphemeralConfig::None => 0,
480                    };
481                    Some((t.name().to_string(), EphemeralConfig { keep_count }))
482                } else {
483                    None
484                }
485            })
486            .collect();
487
488        Ok(Agent {
489            llm,
490            tools: self.tools,
491            config: self.config.unwrap_or_default(),
492            history: Arc::new(RwLock::new(Vec::new())),
493            usage: Arc::new(RwLock::new(UsageSummary::new())),
494            ephemeral_config,
495        })
496    }
497}