1use std::sync::Arc;
2
3use futures::StreamExt;
4use tokio::sync::mpsc;
5
6use crate::approval::{AllowAllApprover, ApprovalDecision, ToolApprover};
7use crate::error::{Error, Result};
8use crate::llm::{LlmClient, LlmRequest, LlmStreamEvent, Usage};
9use crate::message::{Message, Role, ToolCall, ToolResult};
10use crate::tool::ToolRegistry;
11
/// Configuration for an [`Agent`]: which model to call, sampling overrides, and loop limits.
#[derive(Clone, Debug)]
pub struct AgentOptions {
    /// Model identifier passed verbatim in every `LlmRequest`.
    pub model: String,
    /// Sampling temperature override; `None` sends no override.
    pub temperature: Option<f32>,
    /// Per-response token limit override; `None` sends no override.
    pub max_tokens: Option<u32>,
    /// Maximum number of model turns a single `Agent::run` may take before giving up.
    pub max_iterations: u32,
    /// Size limit applied to each tool result before it is added to the conversation;
    /// `0` disables truncation.
    pub max_tool_result_chars: usize,
}
27
28impl Default for AgentOptions {
29 fn default() -> Self {
30 Self {
31 model: "gpt-4o-mini".into(),
32 temperature: None,
33 max_tokens: None,
34 max_iterations: 32,
35 max_tool_result_chars: 16 * 1024,
36 }
37 }
38}
39
40#[derive(Debug, Clone)]
43pub enum AgentEvent {
44 AssistantDelta(String),
45 AssistantMessage(Message),
46 ToolCallStart {
47 id: String,
48 name: String,
49 arguments: serde_json::Value,
50 },
51 ToolCallFinish {
52 id: String,
53 name: String,
54 content: String,
55 is_error: bool,
56 },
57 Usage(Usage),
58 IterationBudgetExhausted,
59 Done,
60}
61
62pub struct Agent {
65 llm: Arc<dyn LlmClient>,
66 tools: ToolRegistry,
67 options: AgentOptions,
68 approver: Arc<dyn ToolApprover>,
69}
70
71impl Agent {
72 pub fn new(llm: Arc<dyn LlmClient>, tools: ToolRegistry, options: AgentOptions) -> Self {
73 Self {
74 llm,
75 tools,
76 options,
77 approver: Arc::new(AllowAllApprover),
78 }
79 }
80
81 pub fn with_approver(mut self, approver: Arc<dyn ToolApprover>) -> Self {
85 self.approver = approver;
86 self
87 }
88
89 pub fn options(&self) -> &AgentOptions {
90 &self.options
91 }
92
93 pub fn options_mut(&mut self) -> &mut AgentOptions {
94 &mut self.options
95 }
96
97 pub fn tools(&self) -> &ToolRegistry {
98 &self.tools
99 }
100
101 pub async fn run(
105 &self,
106 messages: &mut Vec<Message>,
107 events: mpsc::Sender<AgentEvent>,
108 ) -> Result<()> {
109 for _ in 0..self.options.max_iterations {
110 let req = LlmRequest {
111 model: self.options.model.clone(),
112 messages: messages.clone(),
113 tools: self.tools.schemas(),
114 temperature: self.options.temperature,
115 max_tokens: self.options.max_tokens,
116 };
117
118 let mut stream = self.llm.stream(req).await?;
119 let mut text_buf = String::new();
120 let mut tool_calls: Vec<ToolCall> = Vec::new();
121
122 while let Some(ev) = stream.next().await {
123 match ev? {
124 LlmStreamEvent::Delta(s) => {
125 text_buf.push_str(&s);
126 let _ = events.send(AgentEvent::AssistantDelta(s)).await;
127 }
128 LlmStreamEvent::ToolCalls(calls) => {
129 tool_calls = calls;
130 }
131 LlmStreamEvent::Usage(u) => {
132 let _ = events.send(AgentEvent::Usage(u)).await;
133 }
134 LlmStreamEvent::Done(_) => break,
135 }
136 }
137
138 let assistant_msg = if tool_calls.is_empty() {
139 Message::assistant_text(text_buf.clone())
140 } else if text_buf.is_empty() {
141 Message::assistant_tool_calls(tool_calls.clone())
142 } else {
143 Message {
145 role: Role::Assistant,
146 content: Some(text_buf.clone()),
147 tool_calls: tool_calls.clone(),
148 tool_call_id: None,
149 name: None,
150 }
151 };
152 messages.push(assistant_msg.clone());
153 let _ = events
154 .send(AgentEvent::AssistantMessage(assistant_msg))
155 .await;
156
157 if tool_calls.is_empty() {
158 let _ = events.send(AgentEvent::Done).await;
159 return Ok(());
160 }
161
162 for call in tool_calls {
163 let decision = self.approver.approve(&call.name, &call.arguments).await;
164
165 let _ = events
168 .send(AgentEvent::ToolCallStart {
169 id: call.id.clone(),
170 name: call.name.clone(),
171 arguments: call.arguments.clone(),
172 })
173 .await;
174
175 let mut result = match decision {
176 ApprovalDecision::Deny { reason } => ToolResult {
177 tool_call_id: call.id.clone(),
178 name: call.name.clone(),
179 content: format!("tool rejected by user: {reason}"),
180 is_error: true,
181 },
182 ApprovalDecision::Allow => match self.tools.get(&call.name) {
183 Ok(tool) => tool.call(&call.id, call.arguments.clone()).await,
184 Err(e) => ToolResult {
185 tool_call_id: call.id.clone(),
186 name: call.name.clone(),
187 content: format!("error: {e}"),
188 is_error: true,
189 },
190 },
191 };
192 if self.options.max_tool_result_chars > 0
193 && result.content.len() > self.options.max_tool_result_chars
194 {
195 result.content.truncate(self.options.max_tool_result_chars);
196 result.content.push_str("\n…[truncated tool output]");
197 }
198
199 let _ = events
200 .send(AgentEvent::ToolCallFinish {
201 id: result.tool_call_id.clone(),
202 name: result.name.clone(),
203 content: result.content.clone(),
204 is_error: result.is_error,
205 })
206 .await;
207
208 messages.push(Message::tool_response(result));
209 }
210 }
211
212 let _ = events.send(AgentEvent::IterationBudgetExhausted).await;
213 Err(Error::BudgetExhausted)
214 }
215}