1use std::collections::HashMap;
4use std::sync::Arc;
5
6use futures::{Stream, StreamExt};
7use tokio::sync::RwLock;
8use tracing::debug;
9
10use crate::agent::{
11 AgentEvent, ErrorEvent, FinalResponseEvent, StepCompleteEvent, StepStartEvent, UsageSummary,
12};
13use crate::llm::{
14 AssistantMessage, BaseChatModel, ChatCompletion, Message, ToolDefinition, ToolMessage,
15};
16use crate::memory::{MemoryManager, MemoryType};
17use crate::tools::Tool;
18use crate::{Error, Result};
19
20use super::builder::AgentBuilder;
21use super::config::{AgentConfig, EphemeralConfig, build_ephemeral_config};
22
23pub struct Agent {
25 llm: Arc<dyn BaseChatModel>,
27 tools: Vec<Arc<dyn Tool>>,
29 config: AgentConfig,
31 history: Arc<RwLock<Vec<Message>>>,
33 usage: Arc<RwLock<UsageSummary>>,
35 ephemeral_config: HashMap<String, EphemeralConfig>,
37 memory: Option<Arc<RwLock<MemoryManager>>>,
39}
40
41impl Agent {
42 pub fn new(llm: Arc<dyn BaseChatModel>, tools: Vec<Arc<dyn Tool>>) -> Self {
44 let ephemeral_config = build_ephemeral_config(&tools);
45
46 Self {
47 llm,
48 tools,
49 config: AgentConfig::default(),
50 history: Arc::new(RwLock::new(Vec::new())),
51 usage: Arc::new(RwLock::new(UsageSummary::new())),
52 ephemeral_config,
53 memory: None,
54 }
55 }
56
57 pub fn builder() -> AgentBuilder {
59 AgentBuilder::default()
60 }
61
62 pub fn with_config(mut self, config: AgentConfig) -> Self {
64 self.config = config;
65 self
66 }
67
68 pub(super) fn new_with_config(
70 llm: Arc<dyn BaseChatModel>,
71 tools: Vec<Arc<dyn Tool>>,
72 config: AgentConfig,
73 ephemeral_config: HashMap<String, EphemeralConfig>,
74 memory: Option<Arc<RwLock<MemoryManager>>>,
75 ) -> Self {
76 Self {
77 llm,
78 tools,
79 config,
80 history: Arc::new(RwLock::new(Vec::new())),
81 usage: Arc::new(RwLock::new(UsageSummary::new())),
82 ephemeral_config,
83 memory,
84 }
85 }
86
87 pub async fn query(&self, message: impl Into<String>) -> Result<String> {
89 {
91 let mut history = self.history.write().await;
92 history.push(Message::user(message.into()));
93 }
94
95 let stream = self.execute_loop();
97 futures::pin_mut!(stream);
98
99 while let Some(event) = stream.next().await {
100 if let AgentEvent::FinalResponse(response) = event {
101 return Ok(response.content);
102 }
103 }
104
105 Err(Error::Agent("No final response received".into()))
106 }
107
108 pub async fn query_with_memory(&self, message: impl Into<String>) -> Result<String> {
110 let message = message.into();
111 let context = self.recall_memory_context(&message).await?;
112
113 {
114 let mut history = self.history.write().await;
115 if let Some(context_message) = self.memory_context_message(context) {
116 history.push(context_message);
117 }
118 history.push(Message::user(message.clone()));
119 }
120
121 let stream = self.execute_loop();
122 futures::pin_mut!(stream);
123
124 while let Some(event) = stream.next().await {
125 if let AgentEvent::FinalResponse(response) = event {
126 self.remember_short_term(&message).await?;
127 return Ok(response.content);
128 }
129 }
130
131 Err(Error::Agent("No final response received".into()))
132 }
133
134 pub async fn query_stream<'a, M: Into<String> + 'a>(
136 &'a self,
137 message: M,
138 ) -> Result<impl Stream<Item = AgentEvent> + 'a> {
139 {
141 let mut history = self.history.write().await;
142 history.push(Message::user(message.into()));
143 }
144
145 Ok(self.execute_loop())
146 }
147
148 async fn recall_memory_context(&self, message: &str) -> Result<String> {
149 if let Some(memory) = &self.memory {
150 let mem = memory.read().await;
151 mem.recall_context(message).await
152 } else {
153 Ok(String::new())
154 }
155 }
156
157 fn memory_context_message(&self, context: String) -> Option<Message> {
158 if context.is_empty() {
159 None
160 } else {
161 Some(Message::developer(format!(
162 "Relevant memory context:\n{}",
163 context
164 )))
165 }
166 }
167
168 async fn remember_short_term(&self, message: &str) -> Result<()> {
169 if let Some(memory) = &self.memory {
170 let mut mem = memory.write().await;
171 mem.remember(message, MemoryType::ShortTerm).await?;
172 }
173 Ok(())
174 }
175
176 async fn build_request_messages(&self) -> Vec<Message> {
177 let history = self.history.read().await;
178 let mut messages =
179 Vec::with_capacity(history.len() + usize::from(self.config.system_prompt.is_some()));
180 if let Some(ref prompt) = self.config.system_prompt {
181 messages.push(Message::system(prompt));
182 }
183 messages.extend(history.iter().cloned());
184 messages
185 }
186
187 fn execute_loop(&self) -> impl Stream<Item = AgentEvent> + '_ {
189 async_stream::stream! {
190 let mut step = 0;
191
192 loop {
193 if step >= self.config.max_iterations {
194 yield AgentEvent::Error(ErrorEvent::new("Max iterations exceeded"));
195 break;
196 }
197
198 yield AgentEvent::StepStart(StepStartEvent::new(step));
199
200 {
202 let mut h = self.history.write().await;
203 Self::destroy_ephemeral_messages(&mut h, &self.ephemeral_config);
204 }
205
206 let full_messages = self.build_request_messages().await;
207
208 let tool_defs: Vec<ToolDefinition> = self.tools.iter()
210 .map(|t| t.definition())
211 .collect();
212 let tool_defs = if tool_defs.is_empty() { None } else { Some(tool_defs) };
213 let tool_choice = self.config.tool_choice.clone();
214
215 let completion = match Self::call_llm_with_retry(
217 self.llm.as_ref(),
218 &full_messages,
219 tool_defs.as_deref(),
220 Some(&tool_choice),
221 ).await {
222 Ok(c) => c,
223 Err(e) => {
224 yield AgentEvent::Error(ErrorEvent::new(e.to_string()));
225 break;
226 }
227 };
228
229 if let Some(ref u) = completion.usage {
231 let mut us = self.usage.write().await;
232 us.add_usage(self.llm.model(), u);
233 }
234
235 if let Some(ref thinking) = completion.thinking {
237 yield AgentEvent::Thinking(crate::agent::ThinkingEvent::new(thinking));
238 }
239
240 if let Some(ref content) = completion.content {
242 yield AgentEvent::Text(crate::agent::TextEvent::new(content));
243 }
244
245 if completion.has_tool_calls() {
247 {
249 let mut h = self.history.write().await;
250 h.push(Message::Assistant(AssistantMessage {
251 role: "assistant".to_string(),
252 content: completion.content.clone(),
253 thinking: completion.thinking.clone(),
254 redacted_thinking: None,
255 tool_calls: completion.tool_calls.clone(),
256 refusal: None,
257 }));
258 }
259
260 for tool_call in &completion.tool_calls {
262 yield AgentEvent::ToolCall(crate::agent::ToolCallEvent::new(tool_call, step));
263
264 let tool = self.tools.iter().find(|t| t.name() == tool_call.function.name);
266
267 let result = if let Some(t) = tool {
268 let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
269 .unwrap_or(serde_json::json!({}));
270 t.execute(args).await
271 } else {
272 Ok(crate::tools::ToolResult::new(&tool_call.id, format!("Unknown tool: {}", tool_call.function.name)))
273 };
274
275 match result {
276 Ok(tool_result) => {
277 yield AgentEvent::ToolResult(
278 crate::agent::ToolResultEvent::new(
279 &tool_call.id,
280 &tool_call.function.name,
281 &tool_result.content,
282 step,
283 ).with_ephemeral(tool_result.ephemeral)
284 );
285
286 {
288 let mut h = self.history.write().await;
289 let mut msg = ToolMessage::new(&tool_call.id, tool_result.content);
290 msg.tool_name = Some(tool_call.function.name.clone());
291 msg.ephemeral = tool_result.ephemeral;
292 h.push(Message::Tool(msg));
293 }
294 }
295 Err(e) => {
296 yield AgentEvent::Error(ErrorEvent::new(format!(
297 "Tool execution failed: {}",
298 e
299 )));
300 }
301 }
302 }
303
304 step += 1;
305 yield AgentEvent::StepComplete(StepCompleteEvent::new(step - 1));
306 continue;
307 }
308
309 let final_response = FinalResponseEvent::new(completion.content.clone().unwrap_or_default())
311 .with_steps(step);
312
313 yield AgentEvent::FinalResponse(final_response);
314 yield AgentEvent::StepComplete(StepCompleteEvent::new(step));
315 break;
316 }
317 }
318 }
319
320 async fn call_llm_with_retry(
322 llm: &dyn BaseChatModel,
323 messages: &[Message],
324 tools: Option<&[ToolDefinition]>,
325 tool_choice: Option<&crate::llm::ToolChoice>,
326 ) -> Result<ChatCompletion> {
327 let max_retries = 10;
328 let mut delay = std::time::Duration::from_millis(500);
329
330 for attempt in 0..=max_retries {
331 let request_messages = messages.to_vec();
332 let request_tools = tools.map(|value| value.to_vec());
333 let request_tool_choice = tool_choice.cloned();
334
335 match llm
336 .invoke(request_messages, request_tools, request_tool_choice)
337 .await
338 {
339 Ok(completion) => return Ok(completion),
340 Err(crate::llm::LlmError::RateLimit) if attempt < max_retries => {
341 tracing::warn!(
342 "Rate limit or empty response, retrying in {:?} (attempt {}/{})",
343 delay,
344 attempt + 1,
345 max_retries
346 );
347 tokio::time::sleep(delay).await;
348 delay *= 2;
349 }
350 Err(e) => return Err(Error::Llm(e)),
351 }
352 }
353
354 Err(Error::Agent("Max retries exceeded".into()))
355 }
356
357 pub async fn get_usage(&self) -> UsageSummary {
359 self.usage.read().await.clone()
360 }
361
362 fn destroy_ephemeral_messages(
364 history: &mut [Message],
365 ephemeral_config: &HashMap<String, EphemeralConfig>,
366 ) {
367 let mut ephemeral_indices_by_tool: HashMap<String, Vec<usize>> = HashMap::new();
369
370 for (idx, msg) in history.iter().enumerate() {
371 let tool_msg = match msg {
372 Message::Tool(t) => t,
373 _ => continue,
374 };
375
376 if !tool_msg.ephemeral || tool_msg.destroyed {
377 continue;
378 }
379
380 let tool_name = match &tool_msg.tool_name {
381 Some(name) => name.clone(),
382 None => continue,
383 };
384
385 ephemeral_indices_by_tool
386 .entry(tool_name)
387 .or_default()
388 .push(idx);
389 }
390
391 let mut indices_to_destroy: Vec<usize> = Vec::new();
393
394 for (tool_name, indices) in ephemeral_indices_by_tool {
395 let keep_count = ephemeral_config
396 .get(&tool_name)
397 .map(|c| c.keep_count)
398 .unwrap_or(1);
399
400 let to_destroy = if keep_count > 0 && indices.len() > keep_count {
402 &indices[..indices.len() - keep_count]
403 } else {
404 &indices[..]
405 };
406
407 indices_to_destroy.extend(to_destroy.iter().copied());
408 }
409
410 for idx in indices_to_destroy {
412 if let Message::Tool(tool_msg) = &mut history[idx] {
413 debug!("Destroying ephemeral message at index {}", idx);
414 tool_msg.destroy();
415 }
416 }
417 }
418
419 pub async fn clear_history(&self) {
421 let mut history = self.history.write().await;
422 history.clear();
423 }
424
425 pub async fn load_history(&self, messages: Vec<Message>) {
427 let mut history = self.history.write().await;
428 *history = messages;
429 }
430
431 pub async fn get_history(&self) -> Vec<Message> {
433 self.history.read().await.clone()
434 }
435
436 pub fn has_memory(&self) -> bool {
438 self.memory.is_some()
439 }
440
441 pub fn get_memory(&self) -> Option<&Arc<RwLock<MemoryManager>>> {
443 self.memory.as_ref()
444 }
445}