use std::collections::HashMap;
use std::sync::Arc;
use futures::{Stream, StreamExt};
use tokio::sync::RwLock;
use tracing::debug;
use crate::agent::{
AgentEvent, ErrorEvent, FinalResponseEvent, StepCompleteEvent, StepStartEvent, UsageSummary,
};
use crate::llm::{
AssistantMessage, BaseChatModel, ChatCompletion, Message, ToolDefinition, ToolMessage,
};
use crate::memory::{MemoryManager, MemoryType};
use crate::tools::Tool;
use crate::{Error, Result};
use super::builder::AgentBuilder;
use super::config::{AgentConfig, EphemeralConfig, build_ephemeral_config};
/// An agent that drives an LLM in a tool-use loop, maintaining
/// conversation history, per-model usage accounting, and an optional
/// memory backend.
pub struct Agent {
/// Chat model used for every completion request.
llm: Arc<dyn BaseChatModel>,
/// Tools the model may call; matched by name when executing tool calls.
tools: Vec<Arc<dyn Tool>>,
/// Behavioral settings (system prompt, max iterations, tool choice, ...).
config: AgentConfig,
/// Shared, mutable conversation history.
history: Arc<RwLock<Vec<Message>>>,
/// Accumulated token usage, keyed per model.
usage: Arc<RwLock<UsageSummary>>,
/// Per-tool settings controlling how many ephemeral tool results to keep,
/// keyed by tool name.
ephemeral_config: HashMap<String, EphemeralConfig>,
/// Optional memory manager used by `query_with_memory`.
memory: Option<Arc<RwLock<MemoryManager>>>,
}
impl Agent {
/// Creates an agent with the default configuration and no memory backend.
///
/// Per-tool ephemeral-message settings are derived from `tools`.
pub fn new(llm: Arc<dyn BaseChatModel>, tools: Vec<Arc<dyn Tool>>) -> Self {
    Self {
        // Computed before `tools` is moved into the struct below.
        ephemeral_config: build_ephemeral_config(&tools),
        llm,
        tools,
        config: AgentConfig::default(),
        history: Arc::new(RwLock::new(Vec::new())),
        usage: Arc::new(RwLock::new(UsageSummary::new())),
        memory: None,
    }
}
/// Returns an [`AgentBuilder`] for fluent construction of an agent.
pub fn builder() -> AgentBuilder {
AgentBuilder::default()
}
pub fn with_config(mut self, config: AgentConfig) -> Self {
self.config = config;
self
}
/// Creates a fully-specified agent; used by [`AgentBuilder`].
///
/// History and usage tracking always start empty.
pub(super) fn new_with_config(
    llm: Arc<dyn BaseChatModel>,
    tools: Vec<Arc<dyn Tool>>,
    config: AgentConfig,
    ephemeral_config: HashMap<String, EphemeralConfig>,
    memory: Option<Arc<RwLock<MemoryManager>>>,
) -> Self {
    let history = Arc::new(RwLock::new(Vec::new()));
    let usage = Arc::new(RwLock::new(UsageSummary::new()));
    Self {
        llm,
        tools,
        config,
        history,
        usage,
        ephemeral_config,
        memory,
    }
}
/// Appends a user message and runs the agent loop to completion,
/// returning the content of the final response.
///
/// # Errors
/// Returns an error if the loop terminates without producing a
/// `FinalResponse` event (e.g. max iterations or an LLM error).
pub async fn query(&self, message: impl Into<String>) -> Result<String> {
    // Write guard is dropped at the end of the statement, before streaming.
    self.history.write().await.push(Message::user(message.into()));
    let stream = self.execute_loop();
    futures::pin_mut!(stream);
    while let Some(event) = stream.next().await {
        if let AgentEvent::FinalResponse(final_event) = event {
            return Ok(final_event.content);
        }
    }
    Err(Error::Agent("No final response received".into()))
}
/// Like [`Agent::query`], but first recalls relevant memory context
/// (injected as a developer message) and, after a final response,
/// stores the user message in short-term memory.
///
/// # Errors
/// Fails if memory recall/store fails or no final response is produced.
/// NOTE(review): a store failure after a successful completion currently
/// discards the response — confirm this is intended.
pub async fn query_with_memory(&self, message: impl Into<String>) -> Result<String> {
    let message = message.into();
    let context = self.recall_memory_context(&message).await?;
    {
        let mut history = self.history.write().await;
        // `memory_context_message` yields None for empty context; extending
        // with the Option pushes at most one developer message.
        history.extend(self.memory_context_message(context));
        history.push(Message::user(message.clone()));
    }
    let stream = self.execute_loop();
    futures::pin_mut!(stream);
    while let Some(event) = stream.next().await {
        if let AgentEvent::FinalResponse(final_event) = event {
            self.remember_short_term(&message).await?;
            return Ok(final_event.content);
        }
    }
    Err(Error::Agent("No final response received".into()))
}
/// Appends a user message and returns the raw event stream of the agent
/// loop, letting the caller observe every intermediate [`AgentEvent`].
pub async fn query_stream<'a, M: Into<String> + 'a>(
    &'a self,
    message: M,
) -> Result<impl Stream<Item = AgentEvent> + 'a> {
    let user_message = Message::user(message.into());
    self.history.write().await.push(user_message);
    Ok(self.execute_loop())
}
/// Asks the memory manager for context relevant to `message`.
/// Returns an empty string when no memory backend is attached.
async fn recall_memory_context(&self, message: &str) -> Result<String> {
    match &self.memory {
        Some(memory) => memory.read().await.recall_context(message).await,
        None => Ok(String::new()),
    }
}
/// Wraps non-empty memory context in a developer message; empty context
/// yields `None` so nothing is injected into the history.
fn memory_context_message(&self, context: String) -> Option<Message> {
    (!context.is_empty())
        .then(|| Message::developer(format!("Relevant memory context:\n{}", context)))
}
/// Stores `message` in short-term memory; a no-op when no memory
/// backend is attached.
async fn remember_short_term(&self, message: &str) -> Result<()> {
    match &self.memory {
        None => Ok(()),
        Some(memory) => {
            memory
                .write()
                .await
                .remember(message, MemoryType::ShortTerm)
                .await?;
            Ok(())
        }
    }
}
/// Builds the message list for an LLM request: the configured system
/// prompt (if any) followed by a clone of the full history.
async fn build_request_messages(&self) -> Vec<Message> {
    let history = self.history.read().await;
    let system = self.config.system_prompt.as_ref().map(Message::system);
    // Reserve exactly history length plus one optional system message.
    let mut messages = Vec::with_capacity(history.len() + usize::from(system.is_some()));
    messages.extend(system);
    messages.extend(history.iter().cloned());
    messages
}
/// Core agent loop, exposed as a stream of [`AgentEvent`]s.
///
/// Each iteration: prunes stale ephemeral tool messages, sends the full
/// conversation (plus tool definitions) to the LLM with rate-limit retry,
/// then either executes the requested tool calls and continues, or yields
/// the model's content as the final response and stops. The loop is
/// bounded by `config.max_iterations`.
fn execute_loop(&self) -> impl Stream<Item = AgentEvent> + '_ {
async_stream::stream! {
let mut step = 0;
loop {
// Hard stop so a looping model cannot run forever.
if step >= self.config.max_iterations {
yield AgentEvent::Error(ErrorEvent::new("Max iterations exceeded"));
break;
}
yield AgentEvent::StepStart(StepStartEvent::new(step));
// Prune ephemeral tool results beyond each tool's keep_count before
// building the request (scoped block so the write lock drops early).
{
let mut h = self.history.write().await;
Self::destroy_ephemeral_messages(&mut h, &self.ephemeral_config);
}
let full_messages = self.build_request_messages().await;
let tool_defs: Vec<ToolDefinition> = self.tools.iter()
.map(|t| t.definition())
.collect();
// Distinguish "no tools" (None) from an empty list for the request.
let tool_defs = if tool_defs.is_empty() { None } else { Some(tool_defs) };
let tool_choice = self.config.tool_choice.clone();
// Any non-retryable LLM error terminates the stream with an Error event.
let completion = match Self::call_llm_with_retry(
self.llm.as_ref(),
&full_messages,
tool_defs.as_deref(),
Some(&tool_choice),
).await {
Ok(c) => c,
Err(e) => {
yield AgentEvent::Error(ErrorEvent::new(e.to_string()));
break;
}
};
// Accumulate per-model token usage when the provider reports it.
if let Some(ref u) = completion.usage {
let mut us = self.usage.write().await;
us.add_usage(self.llm.model(), u);
}
if let Some(ref thinking) = completion.thinking {
yield AgentEvent::Thinking(crate::agent::ThinkingEvent::new(thinking));
}
if let Some(ref content) = completion.content {
yield AgentEvent::Text(crate::agent::TextEvent::new(content));
}
if completion.has_tool_calls() {
// Record the assistant turn (with its tool calls) before executing them,
// so the history stays well-formed for the next request.
{
let mut h = self.history.write().await;
h.push(Message::Assistant(AssistantMessage {
role: "assistant".to_string(),
content: completion.content.clone(),
thinking: completion.thinking.clone(),
redacted_thinking: None,
tool_calls: completion.tool_calls.clone(),
refusal: None,
}));
}
for tool_call in &completion.tool_calls {
yield AgentEvent::ToolCall(crate::agent::ToolCallEvent::new(tool_call, step));
// Tools are matched by name against the registered tool list.
let tool = self.tools.iter().find(|t| t.name() == tool_call.function.name);
let result = if let Some(t) = tool {
// Malformed JSON arguments degrade to an empty object rather than aborting.
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
.unwrap_or(serde_json::json!({}));
t.execute(args).await
} else {
// An unknown tool name is reported back as a tool result, not an error,
// so the model can recover.
Ok(crate::tools::ToolResult::new(&tool_call.id, format!("Unknown tool: {}", tool_call.function.name)))
};
match result {
Ok(tool_result) => {
yield AgentEvent::ToolResult(
crate::agent::ToolResultEvent::new(
&tool_call.id,
&tool_call.function.name,
&tool_result.content,
step,
).with_ephemeral(tool_result.ephemeral)
);
// Persist the tool result so the model sees it next iteration.
{
let mut h = self.history.write().await;
let mut msg = ToolMessage::new(&tool_call.id, tool_result.content);
msg.tool_name = Some(tool_call.function.name.clone());
msg.ephemeral = tool_result.ephemeral;
h.push(Message::Tool(msg));
}
}
Err(e) => {
// Execution failure is surfaced as an event and the loop continues;
// note no tool message is pushed, so the model sees no result for
// this call — NOTE(review): confirm downstream providers accept that.
yield AgentEvent::Error(ErrorEvent::new(format!(
"Tool execution failed: {}",
e
)));
}
}
}
step += 1;
yield AgentEvent::StepComplete(StepCompleteEvent::new(step - 1));
continue;
}
// No tool calls: the model's content (possibly empty) is the final answer.
let final_response = FinalResponseEvent::new(completion.content.clone().unwrap_or_default())
.with_steps(step);
yield AgentEvent::FinalResponse(final_response);
yield AgentEvent::StepComplete(StepCompleteEvent::new(step));
break;
}
}
}
/// Invokes the LLM, retrying with exponential backoff (500ms, doubling)
/// on rate-limit errors, up to 10 retries.
///
/// # Errors
/// Any non-rate-limit LLM error is returned immediately; a rate limit on
/// the final attempt is returned as an LLM error as well.
async fn call_llm_with_retry(
    llm: &dyn BaseChatModel,
    messages: &[Message],
    tools: Option<&[ToolDefinition]>,
    tool_choice: Option<&crate::llm::ToolChoice>,
) -> Result<ChatCompletion> {
    const MAX_RETRIES: u32 = 10;
    let mut backoff = std::time::Duration::from_millis(500);
    for attempt in 0..=MAX_RETRIES {
        // `invoke` takes owned values, so clone the borrowed request parts
        // fresh for every attempt.
        let outcome = llm
            .invoke(
                messages.to_vec(),
                tools.map(<[ToolDefinition]>::to_vec),
                tool_choice.cloned(),
            )
            .await;
        match outcome {
            Ok(completion) => return Ok(completion),
            Err(crate::llm::LlmError::RateLimit) if attempt < MAX_RETRIES => {
                tracing::warn!(
                    "Rate limit or empty response, retrying in {:?} (attempt {}/{})",
                    backoff,
                    attempt + 1,
                    MAX_RETRIES
                );
                tokio::time::sleep(backoff).await;
                backoff *= 2;
            }
            Err(e) => return Err(Error::Llm(e)),
        }
    }
    // Unreachable in practice: the last attempt always returns above.
    Err(Error::Agent("Max retries exceeded".into()))
}
/// Returns a snapshot of the token usage accumulated across all LLM calls.
pub async fn get_usage(&self) -> UsageSummary {
self.usage.read().await.clone()
}
/// Destroys stale ephemeral tool messages in `history`, keeping only the
/// most recent `keep_count` per tool (default 1 when the tool has no
/// entry in `ephemeral_config`; `keep_count == 0` keeps none).
///
/// Bug fix: the previous version destroyed ALL of a tool's ephemeral
/// messages whenever their count was <= `keep_count` (the `else` branch
/// took `&indices[..]`), so with the default `keep_count = 1` the single
/// most recent result was wrongly destroyed. Now fewer-than-keep_count
/// messages are all retained.
fn destroy_ephemeral_messages(
    history: &mut [Message],
    ephemeral_config: &HashMap<String, EphemeralConfig>,
) {
    // Group indices of live (not yet destroyed) ephemeral tool messages by
    // tool name, in history order (oldest first).
    let mut ephemeral_indices_by_tool: HashMap<String, Vec<usize>> = HashMap::new();
    for (idx, msg) in history.iter().enumerate() {
        let tool_msg = match msg {
            Message::Tool(t) => t,
            _ => continue,
        };
        if !tool_msg.ephemeral || tool_msg.destroyed {
            continue;
        }
        // Messages without a recorded tool name cannot be matched to a
        // keep_count policy, so they are left alone.
        let tool_name = match &tool_msg.tool_name {
            Some(name) => name.clone(),
            None => continue,
        };
        ephemeral_indices_by_tool
            .entry(tool_name)
            .or_default()
            .push(idx);
    }
    for (tool_name, indices) in ephemeral_indices_by_tool {
        let keep_count = ephemeral_config
            .get(&tool_name)
            .map(|c| c.keep_count)
            .unwrap_or(1);
        // Destroy everything except the newest `keep_count` entries.
        // saturating_sub yields 0 (destroy none) when count <= keep_count,
        // and the full length (destroy all) when keep_count == 0.
        let destroy_upto = indices.len().saturating_sub(keep_count);
        for &idx in &indices[..destroy_upto] {
            if let Message::Tool(tool_msg) = &mut history[idx] {
                debug!("Destroying ephemeral message at index {}", idx);
                tool_msg.destroy();
            }
        }
    }
}
/// Removes every message from the conversation history.
pub async fn clear_history(&self) {
    self.history.write().await.clear();
}
/// Replaces the entire conversation history with `messages`.
pub async fn load_history(&self, messages: Vec<Message>) {
    *self.history.write().await = messages;
}
/// Returns a cloned copy of the full conversation history.
pub async fn get_history(&self) -> Vec<Message> {
self.history.read().await.clone()
}
/// Returns `true` if a memory manager is attached to this agent.
pub fn has_memory(&self) -> bool {
self.memory.is_some()
}
/// Returns a shared handle to the memory manager, if one is attached.
pub fn get_memory(&self) -> Option<&Arc<RwLock<MemoryManager>>> {
self.memory.as_ref()
}
}