use crate::{AgentHistory, Result};
use rig::{
agent::Agent,
completion::{Chat, CompletionModel, Message},
};
use serde_json::json;
use std::collections::HashMap;
use std::time::Instant;
/// A chat agent paired with a persistent history that is consulted before,
/// and updated after, every turn.
///
/// The type parameter `M` is the completion model driving the wrapped
/// [`Agent`].
pub struct SmartAgent<M>
where
    M: CompletionModel,
{
    /// The underlying agent that produces completions.
    agent: Agent<M>,
    /// Store that is searched for relevant traces and receives each logged turn.
    history: AgentHistory,
    /// Count handed to the history search on every turn.
    recall_top_k: usize,
    /// A session summary is attempted whenever `turn_count` is a multiple of this.
    summarize_every: usize,
    /// Number of chat turns that have completed successfully so far.
    turn_count: usize,
}
impl<M: CompletionModel + 'static> SmartAgent<M> {
    /// Creates a `SmartAgent` from an agent and a history store.
    ///
    /// Defaults: `recall_top_k` is 4, `summarize_every` is 20, and the turn
    /// counter starts at zero.
    pub fn new(agent: Agent<M>, history: AgentHistory) -> Self {
        Self {
            agent,
            history,
            recall_top_k: 4,
            summarize_every: 20,
            turn_count: 0,
        }
    }

    /// Sets the count passed to the history search on each turn and returns
    /// the updated agent (builder style).
    #[must_use]
    pub fn with_recall_top_k(mut self, k: usize) -> Self {
        self.recall_top_k = k;
        self
    }

    /// Sets the summarization interval, in completed turns, and returns the
    /// updated agent (builder style).
    ///
    /// A session summary is attempted whenever the turn counter is a multiple
    /// of `n`.
    #[must_use]
    pub fn with_summarize_every(mut self, n: usize) -> Self {
        self.summarize_every = n;
        self
    }

    /// Runs one chat turn: recall, log the user turn, complete, log the reply.
    ///
    /// Steps, in order:
    /// 1. Search the history with `user_input` and `recall_top_k`.
    /// 2. If anything was found, fold the traces into a single `system`
    ///    message that is passed to the agent as chat context.
    /// 3. Log the user turn with a `recalled_traces` metadata entry.
    /// 4. Ask the wrapped agent for a completion.
    /// 5. Log the assistant turn with `duration_ms`, `success` and
    ///    `tokens_used` metadata.
    /// 6. Increment the turn counter and, when it is a multiple of
    ///    `summarize_every`, attempt a session summary.
    ///
    /// Returns the assistant's response text.
    ///
    /// # Errors
    ///
    /// Returns an error if the history search or either `log_turn` call
    /// fails, or if the completion fails (reported as `crate::Error::Rig`
    /// carrying the underlying error's text). The user turn is logged before
    /// the completion is requested, so a failed completion leaves that user
    /// turn in the history without a matching assistant turn. A failure of
    /// the periodic summary is ignored and never surfaces here.
    pub async fn chat(&mut self, user_input: &str) -> Result<String> {
        let start = Instant::now();

        // NOTE(review): the meaning of the trailing `false` flag is defined by
        // `AgentHistory::search`, which is not visible here — confirm.
        let relevant_traces = self
            .history
            .search(user_input, self.recall_top_k, false)
            .await?;

        let mut context_messages = Vec::new();
        if !relevant_traces.is_empty() {
            let mut recall_context = String::from("Relevant past experiences:\n\n");
            for (i, trace) in relevant_traces.iter().enumerate() {
                recall_context.push_str(&format!(
                    "{}. [{}] {}: {}\n",
                    i + 1,
                    trace.created_at.format("%Y-%m-%d %H:%M"),
                    trace.role,
                    trace.content
                ));
            }
            context_messages.push(Message {
                role: "system".to_string(),
                content: recall_context,
            });
        }

        let user_message = Message {
            role: "user".to_string(),
            content: user_input.to_string(),
        };
        let mut user_metadata = HashMap::new();
        user_metadata.insert(
            "recalled_traces".to_string(),
            json!(relevant_traces.len()),
        );
        self.history.log_turn(&user_message, user_metadata).await?;

        let response = self
            .agent
            .chat(user_input, context_messages)
            .await
            .map_err(|e| crate::Error::Rig(e.to_string()))?;
        let duration = start.elapsed();

        // `as_millis` yields a `u128`, which not every JSON value
        // representation can hold; saturate into `u64` instead (identical for
        // any realistic duration).
        let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);

        // The response is moved into the message and handed back from it at
        // the end, so the text is never cloned.
        let assistant_message = Message {
            role: "assistant".to_string(),
            content: response,
        };
        let mut metadata = HashMap::new();
        metadata.insert("duration_ms".to_string(), json!(duration_ms));
        metadata.insert("success".to_string(), json!(true));
        metadata.insert("tokens_used".to_string(), json!(null));
        self.history.log_turn(&assistant_message, metadata).await?;

        self.turn_count += 1;
        if self.turn_count.is_multiple_of(self.summarize_every) {
            // Best effort: a failed summary must not fail a turn whose
            // response has already been produced and logged.
            let _ = self.history.summarize_session(&self.agent).await;
        }

        Ok(assistant_message.content)
    }

    /// Returns a shared reference to the wrapped agent.
    pub fn agent(&self) -> &Agent<M> {
        &self.agent
    }

    /// Returns a mutable reference to the wrapped agent.
    pub fn agent_mut(&mut self) -> &mut Agent<M> {
        &mut self.agent
    }

    /// Returns a shared reference to the history store.
    pub fn history(&self) -> &AgentHistory {
        &self.history
    }

    /// Returns the number of chat turns completed successfully so far.
    pub fn turn_count(&self) -> usize {
        self.turn_count
    }

    /// Requests a session summary from the history, using the wrapped agent.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `AgentHistory::summarize_session`.
    pub async fn summarize(&self) -> Result<String> {
        self.history.summarize_session(&self.agent).await
    }
}