// hehe-agent 0.0.1
//
// Agent runtime for the hehe AI Agent framework.
use crate::config::AgentConfig;
use crate::error::{AgentError, Result};
use crate::event::AgentEvent;
use crate::response::{AgentResponse, ToolCallRecord};
use crate::session::Session;
use hehe_core::message::{ContentBlock, ToolResult, ToolUse};
use hehe_core::{Context, Message};
use hehe_llm::{CompletionRequest, LlmProvider};
use hehe_tools::ToolExecutor;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tracing::{debug, info, warn};

/// Drives the agent loop: sends the session's conversation to an LLM, runs any
/// tool calls the model requests, and feeds the results back until the model
/// answers without requesting tools.
pub struct Executor {
    /// Agent settings read by the executor (model, system prompt, temperature,
    /// token/iteration/context limits, tool enablement and tool timeout).
    config: AgentConfig,
    /// Provider used to obtain completions.
    llm: Arc<dyn LlmProvider>,
    /// Optional tool executor; when `None`, requested tool calls are reported
    /// back to the model as errors.
    tools: Option<Arc<ToolExecutor>>,
}

impl Executor {
    /// Creates an executor from an agent configuration, an LLM provider and an
    /// optional tool executor.
    ///
    /// When `tools` is `None`, every tool call requested by the model is
    /// answered with a "Tool execution not available" error result.
    pub fn new(
        config: AgentConfig,
        llm: Arc<dyn LlmProvider>,
        tools: Option<Arc<ToolExecutor>>,
    ) -> Self {
        Self { config, llm, tools }
    }

    /// Runs the agent loop for a single user input.
    ///
    /// The input is appended to `session` as a user message. Each iteration
    /// then requests a completion from the LLM:
    ///
    /// * if the reply contains no tool uses, its text is stored as the
    ///   assistant message and returned together with all tool calls recorded
    ///   so far and the number of iterations performed;
    /// * otherwise the assistant message (text plus tool uses) is stored, the
    ///   tools are executed, and their results are stored as a tool message
    ///   before the next iteration.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::MaxIterationsReached`] once `config.max_iterations`
    /// iterations have run without a final answer, and propagates any error
    /// returned by the LLM provider.
    pub async fn execute(&self, session: &Session, user_input: &str) -> Result<AgentResponse> {
        session.add_message(Message::user(user_input));

        let mut all_tool_calls = Vec::new();
        let mut iterations = 0;

        loop {
            // Check the budget *before* counting, so the session's iteration
            // counter only reflects iterations that actually ran.
            if iterations >= self.config.max_iterations {
                return Err(AgentError::MaxIterationsReached(self.config.max_iterations));
            }

            iterations += 1;
            session.increment_iterations();

            info!(iteration = iterations, "Starting agent loop iteration");

            let request = self.build_request(session);
            let response = self.llm.complete(request).await?;

            // Extract the text once; it is needed on both the final-answer and
            // the tool-call path.
            let text = response.text_content();
            let tool_uses = response.message.tool_uses();

            if tool_uses.is_empty() {
                session.add_message(Message::assistant(&text));

                return Ok(AgentResponse::new(session.id().clone(), text)
                    .with_tool_calls(all_tool_calls)
                    .with_iterations(iterations));
            }

            // Store the assistant turn: optional text followed by the requested
            // tool uses, in the order the model produced them.
            let mut assistant_content: Vec<ContentBlock> =
                Vec::with_capacity(tool_uses.len() + 1);
            if !text.is_empty() {
                assistant_content.push(ContentBlock::text(text));
            }
            assistant_content.extend(tool_uses.iter().map(|tu| {
                ContentBlock::tool_use(ToolUse::new(&tu.id, &tu.name, tu.input.clone()))
            }));
            session.add_message(Message::new(hehe_core::Role::Assistant, assistant_content));

            let tool_results = self.execute_tools(&tool_uses).await;
            session.increment_tool_calls(tool_results.len());

            // `execute_tools` yields exactly one result per tool use, in order,
            // so zipping pairs each call with its own outcome. One pass builds
            // both the caller-facing record and the message block.
            let mut tool_result_content: Vec<ContentBlock> =
                Vec::with_capacity(tool_results.len());
            for (tu, (output, duration_ms, is_error)) in tool_uses.iter().zip(&tool_results) {
                all_tool_calls.push(ToolCallRecord {
                    id: tu.id.clone(),
                    name: tu.name.clone(),
                    input: tu.input.clone(),
                    output: output.clone(),
                    is_error: *is_error,
                    duration_ms: *duration_ms,
                });

                let result = if *is_error {
                    ToolResult::error(&tu.id, output)
                } else {
                    ToolResult::success(&tu.id, output)
                };
                tool_result_content.push(ContentBlock::tool_result(result));
            }

            session.add_message(Message::tool(tool_result_content));
        }
    }

    /// Runs [`Executor::execute`] and reports its progress on `tx`.
    ///
    /// A `message_start` event is sent first. On success a `text_complete`
    /// event carrying the full response text and a `message_end` event follow;
    /// on failure a single `error` event is sent instead. The text is delivered
    /// in one piece after `execute` has finished, not incrementally.
    ///
    /// Send failures are deliberately ignored, so a closed receiver does not
    /// change the returned result.
    ///
    /// # Errors
    ///
    /// Returns whatever [`Executor::execute`] returns.
    pub async fn execute_stream(
        &self,
        session: &Session,
        user_input: &str,
        tx: mpsc::Sender<AgentEvent>,
    ) -> Result<AgentResponse> {
        let _ = tx.send(AgentEvent::message_start(session.id().clone())).await;

        let result = self.execute(session, user_input).await;

        match &result {
            Ok(response) => {
                let _ = tx.send(AgentEvent::text_complete(response.text.clone())).await;
                let _ = tx.send(AgentEvent::message_end(session.id().clone())).await;
            }
            Err(e) => {
                let _ = tx.send(AgentEvent::error(e.to_string())).await;
            }
        }

        result
    }

    /// Builds the completion request for the next LLM call.
    ///
    /// Uses the messages returned by
    /// `session.last_messages(config.max_context_messages)` together with the
    /// configured model, system prompt and temperature. `max_tokens` is applied
    /// only when configured. Tool definitions are attached only when tools are
    /// enabled, a tool executor is present, and its registry is non-empty.
    fn build_request(&self, session: &Session) -> CompletionRequest {
        let messages = session.last_messages(self.config.max_context_messages);

        let mut request = CompletionRequest::new(&self.config.model, messages)
            .with_system(&self.config.system_prompt)
            .with_temperature(self.config.temperature);

        if let Some(max_tokens) = self.config.max_tokens {
            // Saturate instead of silently wrapping when the configured value
            // does not fit in `u32` (assumes `max_tokens` is an unsigned
            // integer type — TODO confirm against `AgentConfig`).
            let max_tokens = u32::try_from(max_tokens).unwrap_or(u32::MAX);
            request = request.with_max_tokens(max_tokens);
        }

        if self.config.tools_enabled {
            if let Some(tools) = &self.tools {
                let definitions = tools.registry().definitions();
                if !definitions.is_empty() {
                    request = request.with_tools(definitions);
                }
            }
        }

        request
    }

    /// Executes the given tool uses one after another, in order.
    ///
    /// Returns one `(output, duration_ms, is_error)` tuple per tool use, in the
    /// same order as `tool_uses`. A failed execution is reported as an error
    /// tuple carrying the error's display text rather than aborting the batch.
    /// Without a tool executor every call yields an immediate
    /// "Tool execution not available" error with a duration of `0`.
    ///
    /// NOTE(review): `config.tools_enabled` only gates advertising tools in
    /// `build_request`; tool uses returned by the model are still executed
    /// here — confirm that is intended.
    async fn execute_tools(
        &self,
        tool_uses: &[&ToolUse],
    ) -> Vec<(String, u64, bool)> {
        let Some(tools) = &self.tools else {
            return tool_uses
                .iter()
                .map(|tu| (format!("Tool execution not available: {}", tu.name), 0, true))
                .collect();
        };

        // NOTE(review): a single context is shared by every call in the batch;
        // whether the timeout applies per call or to the batch as a whole
        // depends on `Context` — confirm.
        let ctx = Context::new().with_timeout(self.config.tool_timeout());
        let mut results = Vec::with_capacity(tool_uses.len());

        for tu in tool_uses {
            let start = Instant::now();
            debug!(tool = %tu.name, id = %tu.id, "Executing tool");

            let result = tools.execute(&ctx, &tu.name, tu.input.clone()).await;
            // `as_millis` returns `u128`; saturate rather than truncate.
            let duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

            match result {
                Ok(output) => {
                    info!(tool = %tu.name, duration_ms, is_error = output.is_error, "Tool completed");
                    results.push((output.content, duration_ms, output.is_error));
                }
                Err(e) => {
                    warn!(tool = %tu.name, error = %e, "Tool execution failed");
                    results.push((e.to_string(), duration_ms, true));
                }
            }
        }

        results
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use hehe_core::capability::Capabilities;
    use hehe_core::stream::StreamChunk;
    use hehe_llm::{BoxStream, CompletionResponse, LlmError, ModelInfo};
    use std::collections::VecDeque;
    use std::sync::{Mutex, OnceLock};

    /// Result type returned by the mocked provider methods.
    type LlmResult<T> = std::result::Result<T, LlmError>;

    /// LLM test double that replays a fixed script of responses in order and
    /// falls back to a canned assistant message once the script is exhausted.
    struct MockLlm {
        queue: Mutex<VecDeque<CompletionResponse>>,
    }

    impl MockLlm {
        /// Creates a mock that hands out `responses` front to back.
        fn new(responses: Vec<CompletionResponse>) -> Self {
            Self {
                queue: Mutex::new(VecDeque::from(responses)),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockLlm {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> &Capabilities {
            static CAPS: OnceLock<Capabilities> = OnceLock::new();
            CAPS.get_or_init(Capabilities::text_basic)
        }

        async fn complete(&self, _request: CompletionRequest) -> LlmResult<CompletionResponse> {
            // The guard is a temporary, so the lock is released right after
            // the next scripted response has been taken.
            let scripted = self.queue.lock().unwrap().pop_front();

            Ok(scripted.unwrap_or_else(|| {
                CompletionResponse::new("id", "model", Message::assistant("Default response"))
            }))
        }

        async fn complete_stream(
            &self,
            _request: CompletionRequest,
        ) -> LlmResult<BoxStream<StreamChunk>> {
            Ok(Box::pin(futures::stream::empty()))
        }

        async fn list_models(&self) -> LlmResult<Vec<ModelInfo>> {
            Ok(Vec::new())
        }

        fn default_model(&self) -> &str {
            "mock"
        }
    }

    /// A plain text reply ends the loop after a single iteration.
    #[tokio::test]
    async fn test_executor_simple_response() {
        let scripted = vec![CompletionResponse::new(
            "resp-1",
            "mock",
            Message::assistant("Hello!"),
        )];
        let llm = Arc::new(MockLlm::new(scripted));

        let executor = Executor::new(AgentConfig::new("mock", "You are helpful."), llm, None);
        let session = Session::new();

        let reply = executor.execute(&session, "Hi").await.unwrap();

        assert_eq!(reply.text(), "Hello!");
        assert_eq!(reply.iterations, 1);
        assert!(!reply.has_tool_calls());
    }

    /// A model that keeps requesting tools is stopped at the iteration limit.
    #[tokio::test]
    async fn test_executor_max_iterations() {
        let tool_call = ContentBlock::tool_use(ToolUse::new(
            "call_1",
            "test_tool",
            serde_json::json!({}),
        ));
        let tool_message = Message::new(hehe_core::Role::Assistant, vec![tool_call]);

        // Script more tool-calling replies than the limit allows.
        let scripted: Vec<CompletionResponse> = ["resp-1", "resp-2", "resp-3"]
            .iter()
            .map(|id| CompletionResponse::new(*id, "mock", tool_message.clone()))
            .collect();
        let llm = Arc::new(MockLlm::new(scripted));

        let limited = AgentConfig::new("mock", "You are helpful.").with_max_iterations(2);
        let executor = Executor::new(limited, llm, None);
        let session = Session::new();

        let outcome = executor.execute(&session, "Hi").await;

        assert!(matches!(outcome, Err(AgentError::MaxIterationsReached(2))));
    }
}