agent 0.0.1

A flexible AI Agent SDK for building intelligent agents
Documentation
use crate::error::{AgentError, Result};
use crate::memory::{Memory, MemoryStore};
use crate::message::{Message, MessageRole};
use crate::tool::ToolRegistry;
use crate::provider::{ModelConfig, ModelProvider};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;

/// Configuration for an [`Agent`]: identity, prompting, loop bound and model settings.
///
/// Derives `Serialize`/`Deserialize`, so it can be round-tripped through any serde format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Human-readable agent name (default `"Agent"`).
    pub name: String,
    /// Free-form description of the agent (default `"An AI agent"`).
    pub description: String,
    /// Optional system prompt. On the provider path, `Agent::run` prepends it as a
    /// system message unless the history already starts with one (default `None`).
    pub system_prompt: Option<String>,
    /// Upper bound on backend calls made by a single `Agent::run` invocation (default `10`).
    pub max_iterations: usize,
    /// Settings passed to the `ModelProvider` on every completion call.
    pub model_config: ModelConfig,
}

impl Default for AgentConfig {
    fn default() -> Self {
        Self {
            name: "Agent".to_string(),
            description: "An AI agent".to_string(),
            system_prompt: None,
            max_iterations: 10,
            model_config: ModelConfig::default(),
        }
    }
}

/// Pluggable backend that turns a conversation history into a reply string.
///
/// `Agent::run` uses an executor only when no `ModelProvider` is configured.
/// Unlike the provider path, the agent does not inject its system prompt here:
/// `messages` is the history exactly as returned by memory.
#[async_trait]
pub trait AgentExecutor: Send + Sync {
    /// Produces the assistant's next reply for `messages`.
    ///
    /// A reply containing `[DONE]` or `[FINAL]` ends the agent loop; any other
    /// reply makes the agent call `execute` again, up to `max_iterations` times.
    async fn execute(&self, messages: Vec<Message>) -> Result<String>;
}

/// An agent that repeatedly queries a [`ModelProvider`] or [`AgentExecutor`]
/// against a shared conversation memory.
///
/// Construct one with [`Agent::builder`].
pub struct Agent {
    /// Agent configuration (system prompt, iteration bound, model settings).
    config: AgentConfig,
    /// Conversation history, shared via `Arc` and guarded by an async `RwLock`.
    memory: Arc<RwLock<Box<dyn Memory>>>,
    /// Registered tools, exposed through `Agent::tools`.
    /// NOTE(review): `run` in this file does not consult the registry — confirm intended.
    tools: ToolRegistry,
    /// Fallback backend; used only when `provider` is `None`.
    executor: Option<Arc<dyn AgentExecutor>>,
    /// Preferred backend; takes priority over `executor` when both are set.
    provider: Option<Arc<dyn ModelProvider>>,
}

impl Agent {
    /// Returns an [`AgentBuilder`] with default settings.
    pub fn builder() -> AgentBuilder {
        AgentBuilder::default()
    }

    /// Sends `input` as a user message and loops until the backend signals completion.
    ///
    /// Each iteration asks the active backend for one reply and appends it to memory
    /// as an assistant message. The loop stops as soon as a reply contains `[DONE]`
    /// or `[FINAL]`, and that reply is returned. The provider takes priority over the
    /// executor. Only the provider path prepends `system_prompt`, and it does so only
    /// when the history does not already start with a system message.
    ///
    /// # Errors
    ///
    /// - [`AgentError::InvalidConfig`] if neither a provider nor an executor is
    ///   configured. This is checked before anything is written to memory.
    /// - [`AgentError::ExecutionError`] (`"Max iterations reached"`) if no reply
    ///   signalled completion within `max_iterations` calls, including when
    ///   `max_iterations` is `0`.
    /// - Any error returned by the memory backend, provider or executor.
    pub async fn run(&mut self, input: impl Into<String>) -> Result<String> {
        // Validate before touching memory so a misconfigured agent does not leave
        // an unanswered user message in the history.
        if self.provider.is_none() && self.executor.is_none() {
            return Err(AgentError::InvalidConfig(
                "No provider or executor configured".to_string(),
            ));
        }

        self.memory.write().await.add(Message::user(input)).await?;

        for _ in 0..self.config.max_iterations {
            let response = self.next_response().await?;
            self.memory
                .write()
                .await
                .add(Message::assistant(&response))
                .await?;

            if !self.should_continue(&response) {
                return Ok(response);
            }
        }

        // Reached when every allowed iteration asked to continue, or when
        // `max_iterations` is 0 and no iteration was allowed at all.
        Err(AgentError::ExecutionError(
            "Max iterations reached".to_string(),
        ))
    }

    /// Reads the current history and asks the active backend for a single reply.
    ///
    /// Prefers the provider, which receives the system prompt if configured. It
    /// falls back to the executor, which receives the raw history.
    async fn next_response(&self) -> Result<String> {
        let mut messages = self.memory.read().await.get_all().await?;

        if let Some(provider) = &self.provider {
            if let Some(system_prompt) = &self.config.system_prompt {
                let starts_with_system = matches!(
                    messages.first(),
                    Some(first) if first.role == MessageRole::System
                );
                if !starts_with_system {
                    messages.insert(0, Message::system(system_prompt));
                }
            }
            let response = provider
                .complete(messages, &self.config.model_config)
                .await?;
            return Ok(response.content);
        }

        match &self.executor {
            Some(executor) => executor.execute(messages).await,
            // `run` validates this up front; kept so the helper is safe on its own.
            None => Err(AgentError::InvalidConfig(
                "No provider or executor configured".to_string(),
            )),
        }
    }

    /// Appends `message` to the conversation history without calling any backend.
    pub async fn add_message(&mut self, message: Message) -> Result<()> {
        self.memory.write().await.add(message).await
    }

    /// Returns all messages currently held in memory, as reported by `Memory::get_all`.
    pub async fn get_history(&self) -> Result<Vec<Message>> {
        self.memory.read().await.get_all().await
    }

    /// Clears the conversation history via `Memory::clear`.
    pub async fn clear_history(&mut self) -> Result<()> {
        self.memory.write().await.clear().await
    }

    /// Returns the agent's configuration.
    pub fn config(&self) -> &AgentConfig {
        &self.config
    }

    /// Returns the agent's tool registry.
    pub fn tools(&self) -> &ToolRegistry {
        &self.tools
    }

    /// Returns `false` once a reply contains a completion marker (`[DONE]` or `[FINAL]`).
    fn should_continue(&self, response: &str) -> bool {
        !response.contains("[DONE]") && !response.contains("[FINAL]")
    }
}

/// Fluent builder for [`Agent`]; obtain one via [`Agent::builder`] or [`AgentBuilder::new`].
pub struct AgentBuilder {
    /// Configuration accumulated by the setter methods.
    config: AgentConfig,
    /// Custom memory backend; when `None`, `build` uses a new `MemoryStore`.
    memory: Option<Box<dyn Memory>>,
    /// Tool registry handed to the built agent.
    tools: ToolRegistry,
    /// Optional executor backend; the agent uses it only when no provider is set.
    executor: Option<Arc<dyn AgentExecutor>>,
    /// Optional model provider backend; preferred over the executor.
    provider: Option<Arc<dyn ModelProvider>>,
}

impl Default for AgentBuilder {
    fn default() -> Self {
        Self {
            config: AgentConfig::default(),
            memory: None,
            tools: ToolRegistry::new(),
            executor: None,
            provider: None,
        }
    }
}

impl AgentBuilder {
    /// Creates a builder with default settings; equivalent to `AgentBuilder::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the whole [`AgentConfig`].
    ///
    /// Any fields set earlier through individual setters (`name`, `temperature`, ...)
    /// are discarded, so call this first if combining both styles.
    #[must_use]
    pub fn config(mut self, config: AgentConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the agent's name.
    #[must_use]
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.config.name = name.into();
        self
    }

    /// Sets the agent's description.
    #[must_use]
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.config.description = description.into();
        self
    }

    /// Sets the system prompt that `Agent::run` injects on the provider path.
    #[must_use]
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.config.system_prompt = Some(prompt.into());
        self
    }

    /// Sets the maximum number of backend calls per `Agent::run` invocation.
    #[must_use]
    pub fn max_iterations(mut self, max: usize) -> Self {
        self.config.max_iterations = max;
        self
    }

    /// Sets `model_config.temperature`.
    #[must_use]
    pub fn temperature(mut self, temp: f32) -> Self {
        self.config.model_config.temperature = temp;
        self
    }

    /// Sets `model_config.model`.
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.config.model_config.model = model.into();
        self
    }

    /// Sets `model_config.max_tokens`.
    #[must_use]
    pub fn max_tokens(mut self, max_tokens: usize) -> Self {
        self.config.model_config.max_tokens = Some(max_tokens);
        self
    }

    /// Replaces the whole [`ModelConfig`].
    ///
    /// Earlier `temperature`, `model` and `max_tokens` calls are discarded.
    #[must_use]
    pub fn model_config(mut self, config: ModelConfig) -> Self {
        self.config.model_config = config;
        self
    }

    /// Sets the model provider; it takes priority over any executor at run time.
    #[must_use]
    pub fn provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Sets a custom memory backend instead of the default `MemoryStore`.
    #[must_use]
    pub fn memory(mut self, memory: Box<dyn Memory>) -> Self {
        self.memory = Some(memory);
        self
    }

    /// Replaces the tool registry.
    #[must_use]
    pub fn tools(mut self, tools: ToolRegistry) -> Self {
        self.tools = tools;
        self
    }

    /// Sets the executor, which is used only when no provider is configured.
    #[must_use]
    pub fn executor(mut self, executor: Arc<dyn AgentExecutor>) -> Self {
        self.executor = Some(executor);
        self
    }

    /// Consumes the builder and produces an [`Agent`].
    ///
    /// If no memory was supplied, a new `MemoryStore` is used. No validation is
    /// done here: an agent with neither a provider nor an executor builds
    /// successfully, and the error surfaces only when `Agent::run` is called.
    #[must_use]
    pub fn build(self) -> Agent {
        let memory = self.memory.unwrap_or_else(|| Box::new(MemoryStore::new()));

        Agent {
            config: self.config,
            memory: Arc::new(RwLock::new(memory)),
            tools: self.tools,
            executor: self.executor,
            provider: self.provider,
        }
    }
}