1use std::collections::HashMap;
4use std::sync::Arc;
5
6use derive_builder::Builder;
7use futures::{Stream, StreamExt};
8use tokio::sync::RwLock;
9use tracing::debug;
10
11use crate::agent::{
12 AgentEvent, ErrorEvent, FinalResponseEvent, StepCompleteEvent, StepStartEvent, UsageSummary,
13};
14use crate::llm::{BaseChatModel, ChatCompletion, Message, ToolChoice, ToolDefinition, ToolMessage};
15use crate::tools::Tool;
16use crate::{Error, Result};
17
/// Default cap on the number of tool-calling rounds a single agent run may perform.
///
/// Used as the default for [`AgentConfig::max_iterations`].
pub const DEFAULT_MAX_ITERATIONS: usize = 200;
20
21#[derive(Builder, Clone)]
23#[builder(pattern = "owned")]
24pub struct AgentConfig {
25 #[builder(setter(into, strip_option), default = "None")]
27 pub system_prompt: Option<String>,
28
29 #[builder(default = "DEFAULT_MAX_ITERATIONS")]
31 pub max_iterations: usize,
32
33 #[builder(default = "ToolChoice::Auto")]
35 pub tool_choice: ToolChoice,
36
37 #[builder(default = "false")]
39 pub enable_compaction: bool,
40
41 #[builder(default = "0.80")]
43 pub compaction_threshold: f32,
44
45 #[builder(default = "false")]
47 pub include_cost: bool,
48}
49
50impl Default for AgentConfig {
51 fn default() -> Self {
52 Self {
53 system_prompt: None,
54 max_iterations: DEFAULT_MAX_ITERATIONS,
55 tool_choice: ToolChoice::Auto,
56 enable_compaction: false,
57 compaction_threshold: 0.80,
58 include_cost: false,
59 }
60 }
61}
62
/// Retention rule for the ephemeral output of one tool.
#[derive(Debug, Clone, Copy)]
pub struct EphemeralConfig {
    /// Number of that tool's ephemeral messages that should stay intact.
    pub keep_count: usize,
}

impl Default for EphemeralConfig {
    /// Keeps a single message by default.
    fn default() -> Self {
        EphemeralConfig { keep_count: 1 }
    }
}
75
76pub struct Agent {
78 llm: Arc<dyn BaseChatModel>,
80 tools: Vec<Arc<dyn Tool>>,
82 config: AgentConfig,
84 history: Arc<RwLock<Vec<Message>>>,
86 usage: Arc<RwLock<UsageSummary>>,
88 ephemeral_config: HashMap<String, EphemeralConfig>,
90}
91
92impl Agent {
93 pub fn new(llm: Arc<dyn BaseChatModel>, tools: Vec<Arc<dyn Tool>>) -> Self {
95 let ephemeral_config = tools
97 .iter()
98 .filter_map(|t| {
99 let cfg = t.ephemeral();
100 if cfg != crate::tools::EphemeralConfig::None {
101 let keep_count = match cfg {
102 crate::tools::EphemeralConfig::Single => 1,
103 crate::tools::EphemeralConfig::Count(n) => n,
104 crate::tools::EphemeralConfig::None => 0,
105 };
106 Some((t.name().to_string(), EphemeralConfig { keep_count }))
107 } else {
108 None
109 }
110 })
111 .collect();
112
113 Self {
114 llm,
115 tools,
116 config: AgentConfig::default(),
117 history: Arc::new(RwLock::new(Vec::new())),
118 usage: Arc::new(RwLock::new(UsageSummary::new())),
119 ephemeral_config,
120 }
121 }
122
123 pub fn builder() -> AgentBuilder {
125 AgentBuilder::default()
126 }
127
128 pub fn with_config(mut self, config: AgentConfig) -> Self {
130 self.config = config;
131 self
132 }
133
134 pub async fn query(&self, message: impl Into<String>) -> Result<String> {
136 {
138 let mut history = self.history.write().await;
139 history.push(Message::user(message.into()));
140 }
141
142 let stream = self.execute_loop();
144 futures::pin_mut!(stream);
145
146 while let Some(event) = stream.next().await {
147 if let AgentEvent::FinalResponse(response) = event {
148 return Ok(response.content);
149 }
150 }
151
152 Err(Error::Agent("No final response received".into()))
153 }
154
155 pub async fn query_stream<'a, M: Into<String> + 'a>(
157 &'a self,
158 message: M,
159 ) -> Result<impl Stream<Item = AgentEvent> + 'a> {
160 {
162 let mut history = self.history.write().await;
163 history.push(Message::user(message.into()));
164 }
165
166 Ok(self.execute_loop())
167 }
168
169 fn execute_loop(&self) -> impl Stream<Item = AgentEvent> + '_ {
171 async_stream::stream! {
172 let mut step = 0;
173
174 loop {
175 if step >= self.config.max_iterations {
176 yield AgentEvent::Error(ErrorEvent::new("Max iterations exceeded"));
177 break;
178 }
179
180 yield AgentEvent::StepStart(StepStartEvent::new(step));
181
182 {
184 let mut h = self.history.write().await;
185 Self::destroy_ephemeral_messages(&mut h, &self.ephemeral_config);
186 }
187
188 let messages = {
190 let h = self.history.read().await;
191 h.clone()
192 };
193
194 let mut full_messages = Vec::new();
196 if let Some(ref prompt) = self.config.system_prompt
197 && step == 0 {
198 full_messages.push(Message::system(prompt));
199 }
200 full_messages.extend(messages);
201
202 let tool_defs: Vec<ToolDefinition> = self.tools.iter()
204 .map(|t| t.definition())
205 .collect();
206
207 let completion = match Self::call_llm_with_retry(
209 self.llm.as_ref(),
210 full_messages.clone(),
211 if tool_defs.is_empty() { None } else { Some(tool_defs) },
212 Some(self.config.tool_choice.clone()),
213 ).await {
214 Ok(c) => c,
215 Err(e) => {
216 yield AgentEvent::Error(ErrorEvent::new(e.to_string()));
217 break;
218 }
219 };
220
221 if let Some(ref u) = completion.usage {
223 let mut us = self.usage.write().await;
224 us.add_usage(self.llm.model(), u);
225 }
226
227 if let Some(ref thinking) = completion.thinking {
229 yield AgentEvent::Thinking(crate::agent::ThinkingEvent::new(thinking));
230 }
231
232 if let Some(ref content) = completion.content {
234 yield AgentEvent::Text(crate::agent::TextEvent::new(content));
235 }
236
237 if completion.has_tool_calls() {
239 {
241 let mut h = self.history.write().await;
242 h.push(Message::Assistant(crate::llm::AssistantMessage {
243 role: "assistant".to_string(),
244 content: completion.content.clone(),
245 thinking: completion.thinking.clone(),
246 redacted_thinking: None,
247 tool_calls: completion.tool_calls.clone(),
248 refusal: None,
249 }));
250 }
251
252 for tool_call in &completion.tool_calls {
254 yield AgentEvent::ToolCall(crate::agent::ToolCallEvent::new(tool_call, step));
255
256 let tool = self.tools.iter().find(|t| t.name() == tool_call.function.name);
258
259 let result = if let Some(t) = tool {
260 let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
261 .unwrap_or(serde_json::json!({}));
262 t.execute(args, None).await
263 } else {
264 Ok(crate::tools::ToolResult::new(&tool_call.id, format!("Unknown tool: {}", tool_call.function.name)))
265 };
266
267 match result {
268 Ok(tool_result) => {
269 yield AgentEvent::ToolResult(
270 crate::agent::ToolResultEvent::new(
271 &tool_call.id,
272 &tool_call.function.name,
273 &tool_result.content,
274 step,
275 ).with_ephemeral(tool_result.ephemeral)
276 );
277
278 {
280 let mut h = self.history.write().await;
281 let mut msg = ToolMessage::new(&tool_call.id, tool_result.content);
282 msg.tool_name = Some(tool_call.function.name.clone());
283 msg.ephemeral = tool_result.ephemeral;
284 h.push(Message::Tool(msg));
285 }
286 }
287 Err(e) => {
288 yield AgentEvent::Error(ErrorEvent::new(format!(
289 "Tool execution failed: {}",
290 e
291 )));
292 }
293 }
294 }
295
296 step += 1;
297 yield AgentEvent::StepComplete(StepCompleteEvent::new(step - 1));
298 continue;
299 }
300
301 let final_response = FinalResponseEvent::new(completion.content.clone().unwrap_or_default())
303 .with_steps(step);
304
305 yield AgentEvent::FinalResponse(final_response);
306 yield AgentEvent::StepComplete(StepCompleteEvent::new(step));
307 break;
308 }
309 }
310 }
311
312 async fn call_llm_with_retry(
314 llm: &dyn BaseChatModel,
315 messages: Vec<Message>,
316 tools: Option<Vec<ToolDefinition>>,
317 tool_choice: Option<ToolChoice>,
318 ) -> Result<ChatCompletion> {
319 let max_retries = 3;
320 let mut delay = std::time::Duration::from_millis(100);
321
322 for attempt in 0..=max_retries {
323 match llm
324 .invoke(messages.clone(), tools.clone(), tool_choice.clone())
325 .await
326 {
327 Ok(completion) => return Ok(completion),
328 Err(crate::llm::LlmError::RateLimit) if attempt < max_retries => {
329 tokio::time::sleep(delay).await;
330 delay *= 2;
331 }
332 Err(e) => return Err(Error::Llm(e)),
333 }
334 }
335
336 Err(Error::Agent("Max retries exceeded".into()))
337 }
338
339 pub async fn get_usage(&self) -> UsageSummary {
341 self.usage.read().await.clone()
342 }
343
344 fn destroy_ephemeral_messages(
348 history: &mut [Message],
349 ephemeral_config: &HashMap<String, EphemeralConfig>,
350 ) {
351 let mut ephemeral_indices_by_tool: HashMap<String, Vec<usize>> = HashMap::new();
353
354 for (idx, msg) in history.iter().enumerate() {
355 let tool_msg = match msg {
356 Message::Tool(t) => t,
357 _ => continue,
358 };
359
360 if !tool_msg.ephemeral || tool_msg.destroyed {
361 continue;
362 }
363
364 let tool_name = match &tool_msg.tool_name {
365 Some(name) => name.clone(),
366 None => continue,
367 };
368
369 ephemeral_indices_by_tool
370 .entry(tool_name)
371 .or_default()
372 .push(idx);
373 }
374
375 let mut indices_to_destroy: Vec<usize> = Vec::new();
377
378 for (tool_name, indices) in ephemeral_indices_by_tool {
379 let keep_count = ephemeral_config
380 .get(&tool_name)
381 .map(|c| c.keep_count)
382 .unwrap_or(1);
383
384 let to_destroy = if keep_count > 0 && indices.len() > keep_count {
386 &indices[..indices.len() - keep_count]
387 } else {
388 &indices[..]
389 };
390
391 indices_to_destroy.extend(to_destroy.iter().copied());
392 }
393
394 for idx in indices_to_destroy {
396 if let Message::Tool(tool_msg) = &mut history[idx] {
397 debug!("Destroying ephemeral message at index {}", idx);
398 tool_msg.destroy();
399 }
400 }
401 }
402
403 pub async fn clear_history(&self) {
405 let mut history = self.history.write().await;
406 history.clear();
407 }
408
409 pub async fn load_history(&self, messages: Vec<Message>) {
411 let mut history = self.history.write().await;
412 *history = messages;
413 }
414
415 pub async fn get_history(&self) -> Vec<Message> {
417 self.history.read().await.clone()
418 }
419}
420
421#[derive(Default)]
423pub struct AgentBuilder {
424 llm: Option<Arc<dyn BaseChatModel>>,
425 tools: Vec<Arc<dyn Tool>>,
426 config: Option<AgentConfig>,
427}
428
429impl AgentBuilder {
430 pub fn with_llm(mut self, llm: Arc<dyn BaseChatModel>) -> Self {
431 self.llm = Some(llm);
432 self
433 }
434
435 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
436 self.tools.push(tool);
437 self
438 }
439
440 pub fn tools(mut self, tools: Vec<Arc<dyn Tool>>) -> Self {
441 self.tools = tools;
442 self
443 }
444
445 pub fn config(mut self, config: AgentConfig) -> Self {
446 self.config = Some(config);
447 self
448 }
449
450 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
451 let mut config = self.config.unwrap_or_default();
452 config.system_prompt = Some(prompt.into());
453 self.config = Some(config);
454 self
455 }
456
457 pub fn max_iterations(mut self, max: usize) -> Self {
458 let mut config = self.config.unwrap_or_default();
459 config.max_iterations = max;
460 self.config = Some(config);
461 self
462 }
463
464 pub fn build(self) -> Result<Agent> {
465 let llm = self
466 .llm
467 .ok_or_else(|| Error::Config("LLM is required".into()))?;
468
469 let ephemeral_config = self
471 .tools
472 .iter()
473 .filter_map(|t| {
474 let cfg = t.ephemeral();
475 if cfg != crate::tools::EphemeralConfig::None {
476 let keep_count = match cfg {
477 crate::tools::EphemeralConfig::Single => 1,
478 crate::tools::EphemeralConfig::Count(n) => n,
479 crate::tools::EphemeralConfig::None => 0,
480 };
481 Some((t.name().to_string(), EphemeralConfig { keep_count }))
482 } else {
483 None
484 }
485 })
486 .collect();
487
488 Ok(Agent {
489 llm,
490 tools: self.tools,
491 config: self.config.unwrap_or_default(),
492 history: Arc::new(RwLock::new(Vec::new())),
493 usage: Arc::new(RwLock::new(UsageSummary::new())),
494 ephemeral_config,
495 })
496 }
497}