mod builder;
pub use builder::{AgentBuilder, AgentConfig};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::error::{Error, Result};
use crate::llm::{LlmProvider, Message};
use crate::memory::Memory;
use crate::tools::{Tool, ToolRegistry, ToolResult};
/// An LLM-backed agent: system instructions, a model provider, a tool registry
/// and a conversation memory, driven through [`Agent::chat`].
pub struct Agent {
// Random v4 UUID string assigned at construction (see `from_builder`).
id: String,
// Human-readable name supplied by the builder.
name: String,
// System prompt sent as the first message of every `chat` request.
instructions: String,
// Model provider used for every completion request.
llm: Arc<dyn LlmProvider>,
// Shared behind an async lock so `add_tool(&self)` can register tools
// without requiring `&mut self`.
tools: Arc<RwLock<ToolRegistry>>,
// Conversation history read at the start of `chat` and appended to during it.
memory: Arc<RwLock<Memory>>,
// Agent settings; `chat` reads `max_iterations` from here.
config: AgentConfig,
}
impl Agent {
#[allow(clippy::new_ret_no_self)]
pub fn new() -> AgentBuilder {
AgentBuilder::new()
}
pub fn simple(instructions: impl Into<String>) -> Result<Self> {
AgentBuilder::new().instructions(instructions).build()
}
pub fn id(&self) -> &str {
&self.id
}
pub fn name(&self) -> &str {
&self.name
}
pub fn instructions(&self) -> &str {
&self.instructions
}
pub fn model(&self) -> &str {
self.llm.model()
}
pub async fn chat(&self, prompt: &str) -> Result<String> {
let mut messages = vec![Message::system(&self.instructions)];
{
let memory = self.memory.read().await;
let history = memory.history().await?;
messages.extend(history);
}
let user_msg = Message::user(prompt);
messages.push(user_msg.clone());
{
let mut memory = self.memory.write().await;
memory.store(user_msg).await?;
}
let tool_defs = {
let tools = self.tools.read().await;
if tools.is_empty() {
None
} else {
Some(tools.definitions())
}
};
let mut iterations = 0;
let max_iterations = self.config.max_iterations;
loop {
iterations += 1;
if iterations > max_iterations {
return Err(Error::agent(format!(
"Max iterations ({}) exceeded",
max_iterations
)));
}
let response = self.llm.chat(&messages, tool_defs.as_deref()).await?;
if response.tool_calls.is_empty() {
let assistant_msg = Message::assistant(&response.content);
{
let mut memory = self.memory.write().await;
memory.store(assistant_msg).await?;
}
return Ok(response.content);
}
let mut assistant_msg = Message::assistant(&response.content);
assistant_msg.tool_calls = Some(response.tool_calls.clone());
messages.push(assistant_msg);
for tool_call in &response.tool_calls {
let args: serde_json::Value = serde_json::from_str(tool_call.arguments())
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
let result = {
let tools = self.tools.read().await;
tools.execute(tool_call.name(), args).await
};
let tool_result = match result {
Ok(r) => r,
Err(e) => ToolResult::failure(tool_call.name(), e.to_string()),
};
let result_str = if tool_result.success {
serde_json::to_string(&tool_result.value).unwrap_or_default()
} else {
format!("Error: {}", tool_result.error.unwrap_or_default())
};
messages.push(Message::tool(&tool_call.id, result_str));
}
}
}
pub async fn start(&self, prompt: &str) -> Result<String> {
self.chat(prompt).await
}
pub async fn run(&self, task: &str) -> Result<String> {
self.chat(task).await
}
pub async fn add_tool(&self, tool: impl Tool + 'static) {
let mut tools = self.tools.write().await;
tools.register(tool);
}
pub async fn tool_count(&self) -> usize {
let tools = self.tools.read().await;
tools.len()
}
pub async fn clear_memory(&self) -> Result<()> {
let mut memory = self.memory.write().await;
memory.clear().await
}
pub async fn history(&self) -> Result<Vec<Message>> {
let memory = self.memory.read().await;
memory.history().await
}
}
impl Default for Agent {
    /// Builds an agent with every builder setting left at its default value.
    ///
    /// # Panics
    ///
    /// Panics if `AgentBuilder::build` fails for the default configuration.
    fn default() -> Self {
        let builder = AgentBuilder::new();
        builder.build().expect("Default agent should build")
    }
}
impl std::fmt::Debug for Agent {
    /// Formats the agent's identity and model.
    ///
    /// Tools, memory, instructions and config are deliberately omitted. The
    /// output ends with `..` to make clear that it is not exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Agent")
            .field("id", &self.id)
            .field("name", &self.name)
            .field("model", &self.llm.model())
            .finish_non_exhaustive()
    }
}
impl Agent {
    /// Assembles an `Agent` from already-validated parts.
    ///
    /// Presumably called from `AgentBuilder::build` (crate-internal). A fresh
    /// random (v4) UUID string becomes the agent id. The tool registry and
    /// memory are wrapped in `Arc<RwLock<_>>` so `&self` methods can mutate
    /// them.
    pub(crate) fn from_builder(
        name: String,
        instructions: String,
        llm: Arc<dyn LlmProvider>,
        tools: ToolRegistry,
        memory: Memory,
        config: AgentConfig,
    ) -> Self {
        let id = Uuid::new_v4().to_string();
        let tools = Arc::new(RwLock::new(tools));
        let memory = Arc::new(RwLock::new(memory));
        Self {
            id,
            name,
            instructions,
            llm,
            tools,
            memory,
            config,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder-set name and instructions are reflected by the getters.
    #[test]
    fn test_agent_builder() {
        let agent = Agent::new()
            .name("test")
            .instructions("Be helpful")
            .build()
            .unwrap();
        assert_eq!(agent.name(), "test");
        assert_eq!(agent.instructions(), "Be helpful");
    }

    /// A default agent still receives a generated id.
    #[test]
    fn test_default_agent() {
        let agent = Agent::default();
        assert!(!agent.id().is_empty());
    }

    /// Every constructed agent gets its own id, and that id is a valid UUID.
    #[test]
    fn test_agent_ids_are_unique_uuids() {
        let first = Agent::default();
        let second = Agent::default();
        assert_ne!(first.id(), second.id());
        assert!(Uuid::parse_str(first.id()).is_ok());
    }

    /// `Agent::simple` sets only the instructions.
    #[test]
    fn test_simple_sets_instructions() {
        let agent = Agent::simple("Answer briefly").unwrap();
        assert_eq!(agent.instructions(), "Answer briefly");
    }
}