//! hehe-agent 0.0.1
//!
//! Agent runtime for the hehe AI Agent framework.
//!
//! Documentation
use crate::config::AgentConfig;
use crate::error::{AgentError, Result};
use crate::event::AgentEvent;
use crate::executor::Executor;
use crate::response::AgentResponse;
use crate::session::Session;
use hehe_llm::LlmProvider;
use hehe_tools::{ToolExecutor, ToolRegistry};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::Stream;

/// An agent that bundles an [`AgentConfig`], an LLM provider, and an optional
/// tool executor.
///
/// Instances are assembled with [`Agent::builder`]. Every call to
/// [`Agent::process`] or [`Agent::chat_stream`] creates a new `Executor` from
/// clones of these three fields.
pub struct Agent {
    /// Configuration that is cloned into each `Executor`.
    config: AgentConfig,
    /// Shared handle to the LLM provider; the `Arc` is cloned for each `Executor`.
    llm: Arc<dyn LlmProvider>,
    /// Tool executor wrapping the builder's tool registry, or `None` when the
    /// builder was given no registry.
    tools: Option<Arc<ToolExecutor>>,
}

impl Agent {
    pub fn builder() -> AgentBuilder {
        AgentBuilder::new()
    }

    pub fn config(&self) -> &AgentConfig {
        &self.config
    }

    pub fn llm(&self) -> &Arc<dyn LlmProvider> {
        &self.llm
    }

    pub fn create_session(&self) -> Session {
        Session::new()
    }

    pub async fn chat(&self, session: &Session, message: &str) -> Result<String> {
        let response = self.process(session, message).await?;
        Ok(response.text)
    }

    pub async fn process(&self, session: &Session, message: &str) -> Result<AgentResponse> {
        let executor = Executor::new(self.config.clone(), self.llm.clone(), self.tools.clone());
        executor.execute(session, message).await
    }

    pub fn chat_stream(
        &self,
        session: &Session,
        message: &str,
    ) -> impl Stream<Item = AgentEvent> + Send {
        let (tx, rx) = mpsc::channel(100);
        let executor = Executor::new(self.config.clone(), self.llm.clone(), self.tools.clone());
        let session = session.clone();
        let message = message.to_string();

        tokio::spawn(async move {
            let _ = executor.execute_stream(&session, &message, tx).await;
        });

        ReceiverStream::new(rx)
    }
}

/// Builder for [`Agent`].
///
/// All fields start out as `None` (via the derived `Default`). In
/// [`AgentBuilder::build`], `config` (or `AgentConfig::default()` when unset)
/// is used as the base, and each individually set option overrides the
/// corresponding field of that base.
#[derive(Default)]
pub struct AgentBuilder {
    /// Base configuration; `AgentConfig::default()` is used when unset.
    config: Option<AgentConfig>,
    /// Overrides `config.name` when set.
    name: Option<String>,
    /// Overrides `config.system_prompt` when set.
    system_prompt: Option<String>,
    /// Overrides `config.model` when set.
    model: Option<String>,
    /// Overrides `config.temperature` when set.
    temperature: Option<f32>,
    /// Overrides `config.max_tokens` (stored as `Some(value)`) when set.
    max_tokens: Option<usize>,
    /// Overrides `config.max_iterations` when set.
    max_iterations: Option<usize>,
    /// Overrides `config.tools_enabled` when set.
    tools_enabled: Option<bool>,
    /// LLM provider; required, `build` fails without it.
    llm: Option<Arc<dyn LlmProvider>>,
    /// Tool registry; when set, `build` wraps it in a `ToolExecutor`.
    tool_registry: Option<Arc<ToolRegistry>>,
}

impl AgentBuilder {
    /// Creates a builder with every option unset.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the base configuration that individual options are applied on top of.
    pub fn config(self, config: AgentConfig) -> Self {
        Self {
            config: Some(config),
            ..self
        }
    }

    /// Sets the agent name, overriding the base configuration's name.
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Sets the system prompt, overriding the base configuration's prompt.
    pub fn system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Sets the model identifier, overriding the base configuration's model.
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Sets the temperature, overriding the base configuration's value.
    pub fn temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Sets the maximum token count, overriding the base configuration's value.
    pub fn max_tokens(self, max_tokens: usize) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Sets the maximum iteration count, overriding the base configuration's value.
    pub fn max_iterations(self, max_iterations: usize) -> Self {
        Self {
            max_iterations: Some(max_iterations),
            ..self
        }
    }

    /// Sets the `tools_enabled` flag, overriding the base configuration's value.
    pub fn tools_enabled(self, enabled: bool) -> Self {
        Self {
            tools_enabled: Some(enabled),
            ..self
        }
    }

    /// Sets the LLM provider. This is the only required option.
    pub fn llm(self, llm: Arc<dyn LlmProvider>) -> Self {
        Self {
            llm: Some(llm),
            ..self
        }
    }

    /// Sets the tool registry that `build` wraps in a `ToolExecutor`.
    pub fn tool_registry(self, registry: Arc<ToolRegistry>) -> Self {
        Self {
            tool_registry: Some(registry),
            ..self
        }
    }

    /// Consumes the builder and produces an [`Agent`].
    ///
    /// The base configuration (or `AgentConfig::default()` when none was given)
    /// is updated with every option that was set individually. A tool registry,
    /// if present, is wrapped in a new `ToolExecutor`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `AgentError::config("LLM provider is required")`
    /// when no LLM provider was supplied.
    pub fn build(self) -> Result<Agent> {
        let Self {
            config: base_config,
            name,
            system_prompt,
            model,
            temperature,
            max_tokens,
            max_iterations,
            tools_enabled,
            llm,
            tool_registry,
        } = self;

        // Validate the only mandatory option before doing any other work.
        let llm = llm.ok_or_else(|| AgentError::config("LLM provider is required"))?;

        let mut config = base_config.unwrap_or_default();

        // Individually set options take precedence over the base configuration.
        if let Some(value) = name {
            config.name = value;
        }
        if let Some(value) = system_prompt {
            config.system_prompt = value;
        }
        if let Some(value) = model {
            config.model = value;
        }
        if let Some(value) = temperature {
            config.temperature = value;
        }
        if let Some(value) = max_tokens {
            config.max_tokens = Some(value);
        }
        if let Some(value) = max_iterations {
            config.max_iterations = value;
        }
        if let Some(value) = tools_enabled {
            config.tools_enabled = value;
        }

        let tools = tool_registry.map(|registry| {
            let tool_executor = ToolExecutor::new(registry);
            Arc::new(tool_executor)
        });

        Ok(Agent { config, llm, tools })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use hehe_core::capability::Capabilities;
    use hehe_core::stream::StreamChunk;
    use hehe_core::Message;
    use hehe_llm::{BoxStream, CompletionRequest, CompletionResponse, LlmError, ModelInfo};
    use std::sync::OnceLock;

    /// Result type used by the `LlmProvider` methods implemented below.
    type LlmResult<T> = std::result::Result<T, LlmError>;

    /// Text the mock provider replies with for every completion request.
    const MOCK_REPLY: &str = "Hello from mock!";

    /// Minimal provider that answers every completion with [`MOCK_REPLY`].
    struct MockLlm;

    #[async_trait]
    impl LlmProvider for MockLlm {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> &Capabilities {
            // Initialized once and shared, since the trait returns a reference.
            static CAPS: OnceLock<Capabilities> = OnceLock::new();
            CAPS.get_or_init(Capabilities::text_basic)
        }

        async fn complete(&self, _request: CompletionRequest) -> LlmResult<CompletionResponse> {
            let reply = Message::assistant(MOCK_REPLY);
            Ok(CompletionResponse::new("id", "mock", reply))
        }

        async fn complete_stream(
            &self,
            _request: CompletionRequest,
        ) -> LlmResult<BoxStream<StreamChunk>> {
            Ok(Box::pin(futures::stream::empty()))
        }

        async fn list_models(&self) -> LlmResult<Vec<ModelInfo>> {
            Ok(Vec::new())
        }

        fn default_model(&self) -> &str {
            "mock"
        }
    }

    /// Returns the mock provider as the trait object the builder expects.
    fn mock_llm() -> Arc<dyn LlmProvider> {
        Arc::new(MockLlm)
    }

    /// Building without an LLM provider must fail.
    #[test]
    fn test_builder_missing_llm() {
        let outcome = Agent::builder().system_prompt("You are helpful").build();

        assert!(outcome.is_err());
    }

    /// `chat` returns the mock reply and the session records two messages.
    #[tokio::test]
    async fn test_agent_chat() {
        let agent = Agent::builder()
            .llm(mock_llm())
            .model("mock")
            .system_prompt("You are helpful.")
            .build()
            .unwrap();

        let session = agent.create_session();
        let reply = agent.chat(&session, "Hi").await.unwrap();

        assert_eq!(reply, MOCK_REPLY);
        assert_eq!(session.message_count(), 2);
    }

    /// Builder overrides land in the config and `process` finishes in one iteration.
    #[tokio::test]
    async fn test_agent_process() {
        let agent = Agent::builder()
            .llm(mock_llm())
            .name("test-agent")
            .model("mock")
            .system_prompt("You are helpful.")
            .temperature(0.5)
            .max_iterations(5)
            .build()
            .unwrap();

        let config = agent.config();
        assert_eq!(config.name, "test-agent");
        assert_eq!(config.temperature, 0.5);
        assert_eq!(config.max_iterations, 5);

        let session = agent.create_session();
        let outcome = agent.process(&session, "Hi").await.unwrap();

        assert_eq!(outcome.text(), MOCK_REPLY);
        assert_eq!(outcome.iterations, 1);
    }

    /// Two chats on the same session leave four messages in it.
    #[tokio::test]
    async fn test_session_persistence() {
        let agent = Agent::builder()
            .llm(mock_llm())
            .system_prompt("You are helpful.")
            .build()
            .unwrap();

        let session = agent.create_session();

        for prompt in ["First message", "Second message"] {
            agent.chat(&session, prompt).await.unwrap();
        }

        assert_eq!(session.message_count(), 4);
    }
}