Skip to main content

merlion_core/
agent.rs

1use std::sync::Arc;
2
3use futures::StreamExt;
4use tokio::sync::mpsc;
5
6use crate::approval::{AllowAllApprover, ApprovalDecision, ToolApprover};
7use crate::error::{Error, Result};
8use crate::llm::{LlmClient, LlmRequest, LlmStreamEvent, Usage};
9use crate::message::{Message, Role, ToolCall, ToolResult};
10use crate::tool::ToolRegistry;
11
/// Tunables for an [`Agent`]: which model to call, optional sampling
/// overrides, and the safety caps applied inside [`Agent::run`].
#[derive(Debug, Clone)]
pub struct AgentOptions {
    /// Model identifier placed in every [`LlmRequest`] the agent sends.
    pub model: String,
    /// Sampling temperature forwarded unchanged to each request (`None`
    /// leaves the request field unset).
    pub temperature: Option<f32>,
    /// Completion-length limit forwarded unchanged to each request (`None`
    /// leaves the request field unset).
    pub max_tokens: Option<u32>,
    /// Upper bound on assistant ↔ tool round-trips within one
    /// [`Agent::run`] call. The name follows Hermes' "iteration budget" so
    /// the port reads like the original.
    pub max_iterations: u32,
    /// Backstop limit, in characters, on a tool result's `content`. Longer
    /// output is cut and suffixed with `…[truncated tool output]` before it
    /// is fed back to the model; tools may impose tighter limits of their
    /// own. A value of `0` switches truncation off.
    pub max_tool_result_chars: usize,
}
27
28impl Default for AgentOptions {
29    fn default() -> Self {
30        Self {
31            model: "gpt-4o-mini".into(),
32            temperature: None,
33            max_tokens: None,
34            max_iterations: 32,
35            max_tool_result_chars: 16 * 1024,
36        }
37    }
38}
39
40/// Events emitted while the agent runs — consumed by the CLI to render
41/// streaming output and tool activity.
42#[derive(Debug, Clone)]
43pub enum AgentEvent {
44    AssistantDelta(String),
45    AssistantMessage(Message),
46    ToolCallStart {
47        id: String,
48        name: String,
49        arguments: serde_json::Value,
50    },
51    ToolCallFinish {
52        id: String,
53        name: String,
54        content: String,
55        is_error: bool,
56    },
57    Usage(Usage),
58    IterationBudgetExhausted,
59    Done,
60}
61
62/// Drives the assistant ↔ tool loop. One [`Agent`] is reusable across many
63/// `run` calls; the conversation lives in the caller-owned `messages` Vec.
64pub struct Agent {
65    llm: Arc<dyn LlmClient>,
66    tools: ToolRegistry,
67    options: AgentOptions,
68    approver: Arc<dyn ToolApprover>,
69}
70
71impl Agent {
72    pub fn new(llm: Arc<dyn LlmClient>, tools: ToolRegistry, options: AgentOptions) -> Self {
73        Self {
74            llm,
75            tools,
76            options,
77            approver: Arc::new(AllowAllApprover),
78        }
79    }
80
81    /// Install a tool approver. Defaults to [`AllowAllApprover`] — replace
82    /// with a real implementation (e.g. the CLI's console prompter) to gate
83    /// sensitive tools.
84    pub fn with_approver(mut self, approver: Arc<dyn ToolApprover>) -> Self {
85        self.approver = approver;
86        self
87    }
88
89    pub fn options(&self) -> &AgentOptions {
90        &self.options
91    }
92
93    pub fn options_mut(&mut self) -> &mut AgentOptions {
94        &mut self.options
95    }
96
97    pub fn tools(&self) -> &ToolRegistry {
98        &self.tools
99    }
100
101    /// Run the agent until the model returns a turn with no tool calls or the
102    /// iteration budget is exhausted. `messages` is mutated in place to
103    /// reflect the new turns. Streaming events are pushed to `events`.
104    pub async fn run(
105        &self,
106        messages: &mut Vec<Message>,
107        events: mpsc::Sender<AgentEvent>,
108    ) -> Result<()> {
109        for _ in 0..self.options.max_iterations {
110            let req = LlmRequest {
111                model: self.options.model.clone(),
112                messages: messages.clone(),
113                tools: self.tools.schemas(),
114                temperature: self.options.temperature,
115                max_tokens: self.options.max_tokens,
116            };
117
118            let mut stream = self.llm.stream(req).await?;
119            let mut text_buf = String::new();
120            let mut tool_calls: Vec<ToolCall> = Vec::new();
121
122            while let Some(ev) = stream.next().await {
123                match ev? {
124                    LlmStreamEvent::Delta(s) => {
125                        text_buf.push_str(&s);
126                        let _ = events.send(AgentEvent::AssistantDelta(s)).await;
127                    }
128                    LlmStreamEvent::ToolCalls(calls) => {
129                        tool_calls = calls;
130                    }
131                    LlmStreamEvent::Usage(u) => {
132                        let _ = events.send(AgentEvent::Usage(u)).await;
133                    }
134                    LlmStreamEvent::Done(_) => break,
135                }
136            }
137
138            let assistant_msg = if tool_calls.is_empty() {
139                Message::assistant_text(text_buf.clone())
140            } else if text_buf.is_empty() {
141                Message::assistant_tool_calls(tool_calls.clone())
142            } else {
143                // Both content and tool calls — OpenAI allows this; keep both.
144                Message {
145                    role: Role::Assistant,
146                    content: Some(text_buf.clone()),
147                    tool_calls: tool_calls.clone(),
148                    tool_call_id: None,
149                    name: None,
150                }
151            };
152            messages.push(assistant_msg.clone());
153            let _ = events
154                .send(AgentEvent::AssistantMessage(assistant_msg))
155                .await;
156
157            if tool_calls.is_empty() {
158                let _ = events.send(AgentEvent::Done).await;
159                return Ok(());
160            }
161
162            for call in tool_calls {
163                let decision = self.approver.approve(&call.name, &call.arguments).await;
164
165                // Emit Start *after* approval so the CLI's render output and
166                // the approver's stdin prompt don't fight for the terminal.
167                let _ = events
168                    .send(AgentEvent::ToolCallStart {
169                        id: call.id.clone(),
170                        name: call.name.clone(),
171                        arguments: call.arguments.clone(),
172                    })
173                    .await;
174
175                let mut result = match decision {
176                    ApprovalDecision::Deny { reason } => ToolResult {
177                        tool_call_id: call.id.clone(),
178                        name: call.name.clone(),
179                        content: format!("tool rejected by user: {reason}"),
180                        is_error: true,
181                    },
182                    ApprovalDecision::Allow => match self.tools.get(&call.name) {
183                        Ok(tool) => tool.call(&call.id, call.arguments.clone()).await,
184                        Err(e) => ToolResult {
185                            tool_call_id: call.id.clone(),
186                            name: call.name.clone(),
187                            content: format!("error: {e}"),
188                            is_error: true,
189                        },
190                    },
191                };
192                if self.options.max_tool_result_chars > 0
193                    && result.content.len() > self.options.max_tool_result_chars
194                {
195                    result.content.truncate(self.options.max_tool_result_chars);
196                    result.content.push_str("\n…[truncated tool output]");
197                }
198
199                let _ = events
200                    .send(AgentEvent::ToolCallFinish {
201                        id: result.tool_call_id.clone(),
202                        name: result.name.clone(),
203                        content: result.content.clone(),
204                        is_error: result.is_error,
205                    })
206                    .await;
207
208                messages.push(Message::tool_response(result));
209            }
210        }
211
212        let _ = events.send(AgentEvent::IterationBudgetExhausted).await;
213        Err(Error::BudgetExhausted)
214    }
215}