use crate::error::{AgentError, Result};
use crate::memory::{Memory, MemoryStore};
use crate::message::{Message, MessageRole};
use crate::tool::ToolRegistry;
use crate::provider::{ModelConfig, ModelProvider};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Configuration for an [`Agent`].
///
/// Derives `Serialize`/`Deserialize` and `Clone`; see the `Default` impl for the
/// built-in values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Human-readable agent name.
    pub name: String,
    /// Free-form description of the agent.
    pub description: String,
    /// Optional system prompt. On the model-provider path of `Agent::run` it is
    /// prepended to the messages sent to the provider unless the stored history
    /// already starts with a system message.
    pub system_prompt: Option<String>,
    /// Upper bound on the number of backend calls a single `Agent::run` may make
    /// before giving up with an error.
    pub max_iterations: usize,
    /// Model parameters (model name, temperature, max tokens, ...) passed to the
    /// provider on every call.
    pub model_config: ModelConfig,
}
impl Default for AgentConfig {
fn default() -> Self {
Self {
name: "Agent".to_string(),
description: "An AI agent".to_string(),
system_prompt: None,
max_iterations: 10,
model_config: ModelConfig::default(),
}
}
}
/// Custom backend that turns a conversation history into the next response.
///
/// `Agent::run` uses an executor only when no `ModelProvider` is configured. It
/// receives the stored history as-is; the configured system prompt is not
/// prepended on this path. A response containing `[DONE]` or `[FINAL]` ends the run.
#[async_trait]
pub trait AgentExecutor: Send + Sync {
    /// Produces the next assistant response for `messages`.
    async fn execute(&self, messages: Vec<Message>) -> Result<String>;
}
/// An agent that records a conversation in memory and repeatedly queries either a
/// [`ModelProvider`] or an [`AgentExecutor`] until a response signals completion.
///
/// Construct one with [`Agent::builder`].
pub struct Agent {
    config: AgentConfig,
    /// Conversation history behind a tokio `RwLock`, shared via `Arc`.
    memory: Arc<RwLock<Box<dyn Memory>>>,
    /// Tool registry, exposed through `Agent::tools` (not consulted by `run` here).
    tools: ToolRegistry,
    /// Fallback backend; used only when `provider` is `None`.
    executor: Option<Arc<dyn AgentExecutor>>,
    /// Preferred backend; takes precedence over `executor`.
    provider: Option<Arc<dyn ModelProvider>>,
}
impl Agent {
    /// Returns a new [`AgentBuilder`] with default configuration.
    pub fn builder() -> AgentBuilder {
        AgentBuilder::default()
    }

    /// Runs the agent loop for one user input and returns the final response.
    ///
    /// The input is appended to memory as a user message. The agent then asks the
    /// configured backend for a response, preferring the provider and falling back to
    /// the executor. Each response is appended to memory as an assistant message. The
    /// loop stops as soon as a response contains `[DONE]` or `[FINAL]`.
    ///
    /// # Errors
    ///
    /// - [`AgentError::InvalidConfig`] if neither a provider nor an executor is
    ///   configured. This is checked *before* the input is recorded, so a
    ///   misconfigured agent does not leave a dangling user message in its history.
    /// - [`AgentError::ExecutionError`] ("Max iterations reached") if
    ///   `max_iterations` responses are produced without a completion marker. This
    ///   also applies when `max_iterations` is 0.
    /// - Any error returned by the memory backend, the provider or the executor.
    pub async fn run(&mut self, input: impl Into<String>) -> Result<String> {
        // Validate up front so a configuration error has no side effects on memory.
        if self.provider.is_none() && self.executor.is_none() {
            return Err(AgentError::InvalidConfig(
                "No provider or executor configured".to_string(),
            ));
        }

        let user_message = Message::user(input);
        self.memory.write().await.add(user_message).await?;

        for _ in 0..self.config.max_iterations {
            let response = self.next_response().await?;
            let assistant_message = Message::assistant(&response);
            self.memory.write().await.add(assistant_message).await?;
            if !self.should_continue(&response) {
                return Ok(response);
            }
        }

        // Only reachable once the iteration budget is exhausted (or was zero).
        Err(AgentError::ExecutionError(
            "Max iterations reached".to_string(),
        ))
    }

    /// Produces one response from the configured backend using the current history.
    ///
    /// The provider takes precedence over the executor. Only on the provider path is
    /// the configured system prompt prepended, and only when the history does not
    /// already start with a system message. The stored history itself is not modified.
    async fn next_response(&self) -> Result<String> {
        let mut messages = self.memory.read().await.get_all().await?;

        if let Some(provider) = &self.provider {
            if let Some(system_prompt) = &self.config.system_prompt {
                if messages.is_empty() || messages[0].role != MessageRole::System {
                    messages.insert(0, Message::system(system_prompt));
                }
            }
            let response = provider
                .complete(messages, &self.config.model_config)
                .await?;
            Ok(response.content)
        } else if let Some(executor) = &self.executor {
            executor.execute(messages).await
        } else {
            // `run` validates this up front; kept so the helper is total on its own.
            Err(AgentError::InvalidConfig(
                "No provider or executor configured".to_string(),
            ))
        }
    }

    /// Appends `message` to the conversation history without invoking any backend.
    pub async fn add_message(&mut self, message: Message) -> Result<()> {
        self.memory.write().await.add(message).await
    }

    /// Returns an owned copy of the full conversation history as reported by memory.
    pub async fn get_history(&self) -> Result<Vec<Message>> {
        self.memory.read().await.get_all().await
    }

    /// Clears the conversation history via the memory backend's `clear`.
    pub async fn clear_history(&mut self) -> Result<()> {
        self.memory.write().await.clear().await
    }

    /// Returns the agent's configuration.
    pub fn config(&self) -> &AgentConfig {
        &self.config
    }

    /// Returns the agent's tool registry.
    pub fn tools(&self) -> &ToolRegistry {
        &self.tools
    }

    /// Returns `false` once `response` contains a completion marker (`[DONE]` or
    /// `[FINAL]`) anywhere in its text; `true` otherwise.
    fn should_continue(&self, response: &str) -> bool {
        !response.contains("[DONE]") && !response.contains("[FINAL]")
    }
}
/// Builder for [`Agent`]; obtain one via [`Agent::builder`] or [`AgentBuilder::new`].
pub struct AgentBuilder {
    config: AgentConfig,
    /// Memory backend; `build` substitutes `MemoryStore::new()` when this is `None`.
    memory: Option<Box<dyn Memory>>,
    tools: ToolRegistry,
    executor: Option<Arc<dyn AgentExecutor>>,
    provider: Option<Arc<dyn ModelProvider>>,
}
impl Default for AgentBuilder {
fn default() -> Self {
Self {
config: AgentConfig::default(),
memory: None,
tools: ToolRegistry::new(),
executor: None,
provider: None,
}
}
}
impl AgentBuilder {
    /// Creates a builder with default settings; equivalent to `AgentBuilder::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the whole [`AgentConfig`], discarding any config fields set earlier
    /// on this builder (name, prompt, model settings, ...).
    #[must_use]
    pub fn config(mut self, config: AgentConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the agent's name.
    #[must_use]
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.config.name = name.into();
        self
    }

    /// Sets the agent's description.
    #[must_use]
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.config.description = description.into();
        self
    }

    /// Sets the system prompt (see [`AgentConfig::system_prompt`]).
    #[must_use]
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.config.system_prompt = Some(prompt.into());
        self
    }

    /// Sets the maximum number of backend calls allowed per `Agent::run`.
    #[must_use]
    pub fn max_iterations(mut self, max: usize) -> Self {
        self.config.max_iterations = max;
        self
    }

    /// Sets the sampling temperature in the model configuration.
    #[must_use]
    pub fn temperature(mut self, temp: f32) -> Self {
        self.config.model_config.temperature = temp;
        self
    }

    /// Sets the model identifier in the model configuration.
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.config.model_config.model = model.into();
        self
    }

    /// Sets the maximum token count in the model configuration.
    #[must_use]
    pub fn max_tokens(mut self, max_tokens: usize) -> Self {
        self.config.model_config.max_tokens = Some(max_tokens);
        self
    }

    /// Replaces the whole model configuration, discarding earlier `temperature`,
    /// `model` and `max_tokens` calls.
    #[must_use]
    pub fn model_config(mut self, config: ModelConfig) -> Self {
        self.config.model_config = config;
        self
    }

    /// Sets the model provider. When present, it takes precedence over any executor.
    #[must_use]
    pub fn provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Sets the memory backend. If never called, `build` uses `MemoryStore::new()`.
    #[must_use]
    pub fn memory(mut self, memory: Box<dyn Memory>) -> Self {
        self.memory = Some(memory);
        self
    }

    /// Sets the tool registry, replacing the one created by `ToolRegistry::new()`.
    #[must_use]
    pub fn tools(mut self, tools: ToolRegistry) -> Self {
        self.tools = tools;
        self
    }

    /// Sets a custom executor, which `Agent::run` uses only when no provider is set.
    #[must_use]
    pub fn executor(mut self, executor: Arc<dyn AgentExecutor>) -> Self {
        self.executor = Some(executor);
        self
    }

    /// Consumes the builder and produces an [`Agent`].
    ///
    /// No validation happens here. An agent with neither a provider nor an executor
    /// builds successfully, but `Agent::run` will return `AgentError::InvalidConfig`.
    #[must_use]
    pub fn build(self) -> Agent {
        let memory = self.memory.unwrap_or_else(|| Box::new(MemoryStore::new()));
        Agent {
            config: self.config,
            memory: Arc::new(RwLock::new(memory)),
            tools: self.tools,
            executor: self.executor,
            provider: self.provider,
        }
    }
}