1use std::collections::HashMap;
4use std::sync::Arc;
5
6use futures::{Stream, StreamExt};
7use tokio::sync::RwLock;
8use tracing::debug;
9
10use crate::agent::{
11 AgentEvent, ErrorEvent, FinalResponseEvent, StepCompleteEvent, StepStartEvent, UsageSummary,
12};
13use crate::llm::{
14 AssistantMessage, BaseChatModel, ChatCompletion, Message, ToolDefinition, ToolMessage,
15};
16use crate::memory::{MemoryManager, MemoryType};
17use crate::tools::Tool;
18use crate::{Error, Result};
19
20use super::builder::AgentBuilder;
21use super::config::{AgentConfig, EphemeralConfig};
22
23pub struct Agent {
25 llm: Arc<dyn BaseChatModel>,
27 tools: Vec<Arc<dyn Tool>>,
29 config: AgentConfig,
31 history: Arc<RwLock<Vec<Message>>>,
33 usage: Arc<RwLock<UsageSummary>>,
35 ephemeral_config: HashMap<String, EphemeralConfig>,
37 memory: Option<Arc<RwLock<MemoryManager>>>,
39}
40
41impl Agent {
42 pub fn new(llm: Arc<dyn BaseChatModel>, tools: Vec<Arc<dyn Tool>>) -> Self {
44 let ephemeral_config = tools
46 .iter()
47 .filter_map(|t| {
48 let cfg = t.ephemeral();
49 if cfg != crate::tools::EphemeralConfig::None {
50 let keep_count = match cfg {
51 crate::tools::EphemeralConfig::Single => 1,
52 crate::tools::EphemeralConfig::Count(n) => n,
53 crate::tools::EphemeralConfig::None => 0,
54 };
55 Some((t.name().to_string(), EphemeralConfig { keep_count }))
56 } else {
57 None
58 }
59 })
60 .collect();
61
62 Self {
63 llm,
64 tools,
65 config: AgentConfig::default(),
66 history: Arc::new(RwLock::new(Vec::new())),
67 usage: Arc::new(RwLock::new(UsageSummary::new())),
68 ephemeral_config,
69 memory: None,
70 }
71 }
72
73 pub fn builder() -> AgentBuilder {
75 AgentBuilder::default()
76 }
77
78 pub fn with_config(mut self, config: AgentConfig) -> Self {
80 self.config = config;
81 self
82 }
83
84 pub(super) fn new_with_config(
86 llm: Arc<dyn BaseChatModel>,
87 tools: Vec<Arc<dyn Tool>>,
88 config: AgentConfig,
89 ephemeral_config: HashMap<String, EphemeralConfig>,
90 memory: Option<Arc<RwLock<MemoryManager>>>,
91 ) -> Self {
92 Self {
93 llm,
94 tools,
95 config,
96 history: Arc::new(RwLock::new(Vec::new())),
97 usage: Arc::new(RwLock::new(UsageSummary::new())),
98 ephemeral_config,
99 memory,
100 }
101 }
102
103 pub async fn query(&self, message: impl Into<String>) -> Result<String> {
105 {
107 let mut history = self.history.write().await;
108 history.push(Message::user(message.into()));
109 }
110
111 let stream = self.execute_loop();
113 futures::pin_mut!(stream);
114
115 while let Some(event) = stream.next().await {
116 if let AgentEvent::FinalResponse(response) = event {
117 return Ok(response.content);
118 }
119 }
120
121 Err(Error::Agent("No final response received".into()))
122 }
123
124 pub async fn query_with_memory(&self, message: impl Into<String>) -> Result<String> {
126 let message = message.into();
127
128 let context = if let Some(memory) = &self.memory {
130 let mem = memory.read().await;
131 mem.recall_context(&message).await?
132 } else {
133 String::new()
134 };
135
136 let enhanced_message = if context.is_empty() {
138 message.clone()
139 } else {
140 format!(
141 "Relevant context from memory:\n{}\n\nUser query: {}",
142 context, message
143 )
144 };
145
146 let result = self.query(enhanced_message).await?;
148
149 if let Some(memory) = &self.memory {
151 let mut mem = memory.write().await;
152 mem.remember(&message, MemoryType::ShortTerm).await?;
153 }
154
155 Ok(result)
156 }
157
158 pub async fn query_stream<'a, M: Into<String> + 'a>(
160 &'a self,
161 message: M,
162 ) -> Result<impl Stream<Item = AgentEvent> + 'a> {
163 {
165 let mut history = self.history.write().await;
166 history.push(Message::user(message.into()));
167 }
168
169 Ok(self.execute_loop())
170 }
171
172 fn execute_loop(&self) -> impl Stream<Item = AgentEvent> + '_ {
174 async_stream::stream! {
175 let mut step = 0;
176
177 loop {
178 if step >= self.config.max_iterations {
179 yield AgentEvent::Error(ErrorEvent::new("Max iterations exceeded"));
180 break;
181 }
182
183 yield AgentEvent::StepStart(StepStartEvent::new(step));
184
185 {
187 let mut h = self.history.write().await;
188 Self::destroy_ephemeral_messages(&mut h, &self.ephemeral_config);
189 }
190
191 let messages = {
193 let h = self.history.read().await;
194 h.clone()
195 };
196
197 let mut full_messages = Vec::new();
199 if let Some(ref prompt) = self.config.system_prompt
200 && step == 0 {
201 full_messages.push(Message::system(prompt));
202 }
203 full_messages.extend(messages);
204
205 let tool_defs: Vec<ToolDefinition> = self.tools.iter()
207 .map(|t| t.definition())
208 .collect();
209
210 let completion = match Self::call_llm_with_retry(
212 self.llm.as_ref(),
213 full_messages.clone(),
214 if tool_defs.is_empty() { None } else { Some(tool_defs) },
215 Some(self.config.tool_choice.clone()),
216 ).await {
217 Ok(c) => c,
218 Err(e) => {
219 yield AgentEvent::Error(ErrorEvent::new(e.to_string()));
220 break;
221 }
222 };
223
224 if let Some(ref u) = completion.usage {
226 let mut us = self.usage.write().await;
227 us.add_usage(self.llm.model(), u);
228 }
229
230 if let Some(ref thinking) = completion.thinking {
232 yield AgentEvent::Thinking(crate::agent::ThinkingEvent::new(thinking));
233 }
234
235 if let Some(ref content) = completion.content {
237 yield AgentEvent::Text(crate::agent::TextEvent::new(content));
238 }
239
240 if completion.has_tool_calls() {
242 {
244 let mut h = self.history.write().await;
245 h.push(Message::Assistant(AssistantMessage {
246 role: "assistant".to_string(),
247 content: completion.content.clone(),
248 thinking: completion.thinking.clone(),
249 redacted_thinking: None,
250 tool_calls: completion.tool_calls.clone(),
251 refusal: None,
252 }));
253 }
254
255 for tool_call in &completion.tool_calls {
257 yield AgentEvent::ToolCall(crate::agent::ToolCallEvent::new(tool_call, step));
258
259 let tool = self.tools.iter().find(|t| t.name() == tool_call.function.name);
261
262 let result = if let Some(t) = tool {
263 let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
264 .unwrap_or(serde_json::json!({}));
265 t.execute(args).await
266 } else {
267 Ok(crate::tools::ToolResult::new(&tool_call.id, format!("Unknown tool: {}", tool_call.function.name)))
268 };
269
270 match result {
271 Ok(tool_result) => {
272 yield AgentEvent::ToolResult(
273 crate::agent::ToolResultEvent::new(
274 &tool_call.id,
275 &tool_call.function.name,
276 &tool_result.content,
277 step,
278 ).with_ephemeral(tool_result.ephemeral)
279 );
280
281 {
283 let mut h = self.history.write().await;
284 let mut msg = ToolMessage::new(&tool_call.id, tool_result.content);
285 msg.tool_name = Some(tool_call.function.name.clone());
286 msg.ephemeral = tool_result.ephemeral;
287 h.push(Message::Tool(msg));
288 }
289 }
290 Err(e) => {
291 yield AgentEvent::Error(ErrorEvent::new(format!(
292 "Tool execution failed: {}",
293 e
294 )));
295 }
296 }
297 }
298
299 step += 1;
300 yield AgentEvent::StepComplete(StepCompleteEvent::new(step - 1));
301 continue;
302 }
303
304 let final_response = FinalResponseEvent::new(completion.content.clone().unwrap_or_default())
306 .with_steps(step);
307
308 yield AgentEvent::FinalResponse(final_response);
309 yield AgentEvent::StepComplete(StepCompleteEvent::new(step));
310 break;
311 }
312 }
313 }
314
315 async fn call_llm_with_retry(
317 llm: &dyn BaseChatModel,
318 messages: Vec<Message>,
319 tools: Option<Vec<ToolDefinition>>,
320 tool_choice: Option<crate::llm::ToolChoice>,
321 ) -> Result<ChatCompletion> {
322 let max_retries = 3;
323 let mut delay = std::time::Duration::from_millis(100);
324
325 for attempt in 0..=max_retries {
326 match llm
327 .invoke(messages.clone(), tools.clone(), tool_choice.clone())
328 .await
329 {
330 Ok(completion) => return Ok(completion),
331 Err(crate::llm::LlmError::RateLimit) if attempt < max_retries => {
332 tokio::time::sleep(delay).await;
333 delay *= 2;
334 }
335 Err(e) => return Err(Error::Llm(e)),
336 }
337 }
338
339 Err(Error::Agent("Max retries exceeded".into()))
340 }
341
342 pub async fn get_usage(&self) -> UsageSummary {
344 self.usage.read().await.clone()
345 }
346
347 fn destroy_ephemeral_messages(
349 history: &mut [Message],
350 ephemeral_config: &HashMap<String, EphemeralConfig>,
351 ) {
352 let mut ephemeral_indices_by_tool: HashMap<String, Vec<usize>> = HashMap::new();
354
355 for (idx, msg) in history.iter().enumerate() {
356 let tool_msg = match msg {
357 Message::Tool(t) => t,
358 _ => continue,
359 };
360
361 if !tool_msg.ephemeral || tool_msg.destroyed {
362 continue;
363 }
364
365 let tool_name = match &tool_msg.tool_name {
366 Some(name) => name.clone(),
367 None => continue,
368 };
369
370 ephemeral_indices_by_tool
371 .entry(tool_name)
372 .or_default()
373 .push(idx);
374 }
375
376 let mut indices_to_destroy: Vec<usize> = Vec::new();
378
379 for (tool_name, indices) in ephemeral_indices_by_tool {
380 let keep_count = ephemeral_config
381 .get(&tool_name)
382 .map(|c| c.keep_count)
383 .unwrap_or(1);
384
385 let to_destroy = if keep_count > 0 && indices.len() > keep_count {
387 &indices[..indices.len() - keep_count]
388 } else {
389 &indices[..]
390 };
391
392 indices_to_destroy.extend(to_destroy.iter().copied());
393 }
394
395 for idx in indices_to_destroy {
397 if let Message::Tool(tool_msg) = &mut history[idx] {
398 debug!("Destroying ephemeral message at index {}", idx);
399 tool_msg.destroy();
400 }
401 }
402 }
403
404 pub async fn clear_history(&self) {
406 let mut history = self.history.write().await;
407 history.clear();
408 }
409
410 pub async fn load_history(&self, messages: Vec<Message>) {
412 let mut history = self.history.write().await;
413 *history = messages;
414 }
415
416 pub async fn get_history(&self) -> Vec<Message> {
418 self.history.read().await.clone()
419 }
420
421 pub fn has_memory(&self) -> bool {
423 self.memory.is_some()
424 }
425
426 pub fn get_memory(&self) -> Option<&Arc<RwLock<MemoryManager>>> {
428 self.memory.as_ref()
429 }
430}