1use std::collections::HashMap;
4use std::sync::Arc;
5
6use futures::{Stream, StreamExt};
7use tokio::sync::RwLock;
8use tracing::debug;
9
10use crate::agent::{
11 AgentEvent, ErrorEvent, FinalResponseEvent, StepCompleteEvent, StepStartEvent, UsageSummary,
12};
13use crate::llm::{
14 AssistantMessage, BaseChatModel, ChatCompletion, Message, ToolDefinition, ToolMessage,
15};
16use crate::memory::{MemoryManager, MemoryType};
17use crate::tools::Tool;
18use crate::{Error, Result};
19
20use super::builder::AgentBuilder;
21use super::config::{AgentConfig, EphemeralConfig};
22
/// An LLM-driven agent: a chat model, the tools it may call, and the running
/// conversation state.
pub struct Agent {
    /// Chat model invoked on every step of the agent loop.
    llm: Arc<dyn BaseChatModel>,
    /// Tools offered to the model; looked up by name when a tool call arrives.
    tools: Vec<Arc<dyn Tool>>,
    /// Loop settings read by the agent (system prompt, tool choice, iteration cap).
    config: AgentConfig,
    /// Conversation history, shared behind an async read/write lock.
    history: Arc<RwLock<Vec<Message>>>,
    /// Usage totals accumulated from the completions returned by the model.
    usage: Arc<RwLock<UsageSummary>>,
    /// Retention settings for ephemeral tool results, keyed by tool name.
    ephemeral_config: HashMap<String, EphemeralConfig>,
    /// Optional memory manager consulted by `query_with_memory`.
    memory: Option<Arc<RwLock<MemoryManager>>>,
}
40
41impl Agent {
42 pub fn new(llm: Arc<dyn BaseChatModel>, tools: Vec<Arc<dyn Tool>>) -> Self {
44 let ephemeral_config = tools
46 .iter()
47 .filter_map(|t| {
48 let cfg = t.ephemeral();
49 if cfg != crate::tools::EphemeralConfig::None {
50 let keep_count = match cfg {
51 crate::tools::EphemeralConfig::Single => 1,
52 crate::tools::EphemeralConfig::Count(n) => n,
53 crate::tools::EphemeralConfig::None => 0,
54 };
55 Some((t.name().to_string(), EphemeralConfig { keep_count }))
56 } else {
57 None
58 }
59 })
60 .collect();
61
62 Self {
63 llm,
64 tools,
65 config: AgentConfig::default(),
66 history: Arc::new(RwLock::new(Vec::new())),
67 usage: Arc::new(RwLock::new(UsageSummary::new())),
68 ephemeral_config,
69 memory: None,
70 }
71 }
72
/// Returns a default-constructed [`AgentBuilder`].
pub fn builder() -> AgentBuilder {
    AgentBuilder::default()
}
77
78 pub fn with_config(mut self, config: AgentConfig) -> Self {
80 self.config = config;
81 self
82 }
83
84 pub(super) fn new_with_config(
86 llm: Arc<dyn BaseChatModel>,
87 tools: Vec<Arc<dyn Tool>>,
88 config: AgentConfig,
89 ephemeral_config: HashMap<String, EphemeralConfig>,
90 memory: Option<Arc<RwLock<MemoryManager>>>,
91 ) -> Self {
92 Self {
93 llm,
94 tools,
95 config,
96 history: Arc::new(RwLock::new(Vec::new())),
97 usage: Arc::new(RwLock::new(UsageSummary::new())),
98 ephemeral_config,
99 memory,
100 }
101 }
102
103 pub async fn query(&self, message: impl Into<String>) -> Result<String> {
105 {
107 let mut history = self.history.write().await;
108 history.push(Message::user(message.into()));
109 }
110
111 let stream = self.execute_loop();
113 futures::pin_mut!(stream);
114
115 while let Some(event) = stream.next().await {
116 if let AgentEvent::FinalResponse(response) = event {
117 return Ok(response.content);
118 }
119 }
120
121 Err(Error::Agent("No final response received".into()))
122 }
123
124 pub async fn query_with_memory(&self, message: impl Into<String>) -> Result<String> {
126 let message = message.into();
127
128 let context = if let Some(memory) = &self.memory {
130 let mem = memory.read().await;
131 mem.recall_context(&message).await?
132 } else {
133 String::new()
134 };
135
136 let enhanced_message = if context.is_empty() {
138 message.clone()
139 } else {
140 format!(
141 "Relevant context from memory:\n{}\n\nUser query: {}",
142 context, message
143 )
144 };
145
146 let result = self.query(enhanced_message).await?;
148
149 if let Some(memory) = &self.memory {
151 let mut mem = memory.write().await;
152 mem.remember(&message, MemoryType::ShortTerm).await?;
153 }
154
155 Ok(result)
156 }
157
158 pub async fn query_stream<'a, M: Into<String> + 'a>(
160 &'a self,
161 message: M,
162 ) -> Result<impl Stream<Item = AgentEvent> + 'a> {
163 {
165 let mut history = self.history.write().await;
166 history.push(Message::user(message.into()));
167 }
168
169 Ok(self.execute_loop())
170 }
171
172 fn execute_loop(&self) -> impl Stream<Item = AgentEvent> + '_ {
174 async_stream::stream! {
175 let mut step = 0;
176
177 loop {
178 if step >= self.config.max_iterations {
179 yield AgentEvent::Error(ErrorEvent::new("Max iterations exceeded"));
180 break;
181 }
182
183 yield AgentEvent::StepStart(StepStartEvent::new(step));
184
185 {
187 let mut h = self.history.write().await;
188 Self::destroy_ephemeral_messages(&mut h, &self.ephemeral_config);
189 }
190
191 let messages = {
193 let h = self.history.read().await;
194 h.clone()
195 };
196
197 let mut full_messages = Vec::new();
199 if let Some(ref prompt) = self.config.system_prompt {
200 full_messages.push(Message::system(prompt));
201 }
202 full_messages.extend(messages);
203
204 let tool_defs: Vec<ToolDefinition> = self.tools.iter()
206 .map(|t| t.definition())
207 .collect();
208
209 let completion = match Self::call_llm_with_retry(
211 self.llm.as_ref(),
212 full_messages.clone(),
213 if tool_defs.is_empty() { None } else { Some(tool_defs) },
214 Some(self.config.tool_choice.clone()),
215 ).await {
216 Ok(c) => c,
217 Err(e) => {
218 yield AgentEvent::Error(ErrorEvent::new(e.to_string()));
219 break;
220 }
221 };
222
223 if let Some(ref u) = completion.usage {
225 let mut us = self.usage.write().await;
226 us.add_usage(self.llm.model(), u);
227 }
228
229 if let Some(ref thinking) = completion.thinking {
231 yield AgentEvent::Thinking(crate::agent::ThinkingEvent::new(thinking));
232 }
233
234 if let Some(ref content) = completion.content {
236 yield AgentEvent::Text(crate::agent::TextEvent::new(content));
237 }
238
239 if completion.has_tool_calls() {
241 {
243 let mut h = self.history.write().await;
244 h.push(Message::Assistant(AssistantMessage {
245 role: "assistant".to_string(),
246 content: completion.content.clone(),
247 thinking: completion.thinking.clone(),
248 redacted_thinking: None,
249 tool_calls: completion.tool_calls.clone(),
250 refusal: None,
251 }));
252 }
253
254 for tool_call in &completion.tool_calls {
256 yield AgentEvent::ToolCall(crate::agent::ToolCallEvent::new(tool_call, step));
257
258 let tool = self.tools.iter().find(|t| t.name() == tool_call.function.name);
260
261 let result = if let Some(t) = tool {
262 let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
263 .unwrap_or(serde_json::json!({}));
264 t.execute(args).await
265 } else {
266 Ok(crate::tools::ToolResult::new(&tool_call.id, format!("Unknown tool: {}", tool_call.function.name)))
267 };
268
269 match result {
270 Ok(tool_result) => {
271 yield AgentEvent::ToolResult(
272 crate::agent::ToolResultEvent::new(
273 &tool_call.id,
274 &tool_call.function.name,
275 &tool_result.content,
276 step,
277 ).with_ephemeral(tool_result.ephemeral)
278 );
279
280 {
282 let mut h = self.history.write().await;
283 let mut msg = ToolMessage::new(&tool_call.id, tool_result.content);
284 msg.tool_name = Some(tool_call.function.name.clone());
285 msg.ephemeral = tool_result.ephemeral;
286 h.push(Message::Tool(msg));
287 }
288 }
289 Err(e) => {
290 yield AgentEvent::Error(ErrorEvent::new(format!(
291 "Tool execution failed: {}",
292 e
293 )));
294 }
295 }
296 }
297
298 step += 1;
299 yield AgentEvent::StepComplete(StepCompleteEvent::new(step - 1));
300 continue;
301 }
302
303 let final_response = FinalResponseEvent::new(completion.content.clone().unwrap_or_default())
305 .with_steps(step);
306
307 yield AgentEvent::FinalResponse(final_response);
308 yield AgentEvent::StepComplete(StepCompleteEvent::new(step));
309 break;
310 }
311 }
312 }
313
314 async fn call_llm_with_retry(
316 llm: &dyn BaseChatModel,
317 messages: Vec<Message>,
318 tools: Option<Vec<ToolDefinition>>,
319 tool_choice: Option<crate::llm::ToolChoice>,
320 ) -> Result<ChatCompletion> {
321 let max_retries = 10;
322 let mut delay = std::time::Duration::from_millis(500);
323
324 for attempt in 0..=max_retries {
325 match llm
326 .invoke(messages.clone(), tools.clone(), tool_choice.clone())
327 .await
328 {
329 Ok(completion) => return Ok(completion),
330 Err(crate::llm::LlmError::RateLimit) if attempt < max_retries => {
331 tracing::warn!(
332 "Rate limit or empty response, retrying in {:?} (attempt {}/{})",
333 delay,
334 attempt + 1,
335 max_retries
336 );
337 tokio::time::sleep(delay).await;
338 delay *= 2;
339 }
340 Err(e) => return Err(Error::Llm(e)),
341 }
342 }
343
344 Err(Error::Agent("Max retries exceeded".into()))
345 }
346
347 pub async fn get_usage(&self) -> UsageSummary {
349 self.usage.read().await.clone()
350 }
351
352 fn destroy_ephemeral_messages(
354 history: &mut [Message],
355 ephemeral_config: &HashMap<String, EphemeralConfig>,
356 ) {
357 let mut ephemeral_indices_by_tool: HashMap<String, Vec<usize>> = HashMap::new();
359
360 for (idx, msg) in history.iter().enumerate() {
361 let tool_msg = match msg {
362 Message::Tool(t) => t,
363 _ => continue,
364 };
365
366 if !tool_msg.ephemeral || tool_msg.destroyed {
367 continue;
368 }
369
370 let tool_name = match &tool_msg.tool_name {
371 Some(name) => name.clone(),
372 None => continue,
373 };
374
375 ephemeral_indices_by_tool
376 .entry(tool_name)
377 .or_default()
378 .push(idx);
379 }
380
381 let mut indices_to_destroy: Vec<usize> = Vec::new();
383
384 for (tool_name, indices) in ephemeral_indices_by_tool {
385 let keep_count = ephemeral_config
386 .get(&tool_name)
387 .map(|c| c.keep_count)
388 .unwrap_or(1);
389
390 let to_destroy = if keep_count > 0 && indices.len() > keep_count {
392 &indices[..indices.len() - keep_count]
393 } else {
394 &indices[..]
395 };
396
397 indices_to_destroy.extend(to_destroy.iter().copied());
398 }
399
400 for idx in indices_to_destroy {
402 if let Message::Tool(tool_msg) = &mut history[idx] {
403 debug!("Destroying ephemeral message at index {}", idx);
404 tool_msg.destroy();
405 }
406 }
407 }
408
409 pub async fn clear_history(&self) {
411 let mut history = self.history.write().await;
412 history.clear();
413 }
414
415 pub async fn load_history(&self, messages: Vec<Message>) {
417 let mut history = self.history.write().await;
418 *history = messages;
419 }
420
421 pub async fn get_history(&self) -> Vec<Message> {
423 self.history.read().await.clone()
424 }
425
/// Returns `true` when a memory manager is attached to this agent.
pub fn has_memory(&self) -> bool {
    self.memory.is_some()
}
430
/// Returns a reference to the shared memory manager, or `None` when the
/// agent has no memory attached.
pub fn get_memory(&self) -> Option<&Arc<RwLock<MemoryManager>>> {
    self.memory.as_ref()
}
435}