// hehe_agent/agent.rs

1use crate::config::AgentConfig;
2use crate::error::{AgentError, Result};
3use crate::event::AgentEvent;
4use crate::executor::Executor;
5use crate::response::AgentResponse;
6use crate::session::Session;
7use hehe_llm::LlmProvider;
8use hehe_tools::{ToolExecutor, ToolRegistry};
9use std::sync::Arc;
10use tokio::sync::mpsc;
11use tokio_stream::wrappers::ReceiverStream;
12use tokio_stream::Stream;
13
/// A conversational agent that pairs an LLM provider with optional tools.
///
/// Construct via [`Agent::builder`]. The provider and tool executor are
/// shared behind `Arc`s, so per-request executors can clone them cheaply.
pub struct Agent {
    // Agent settings (name, system prompt, model, sampling knobs, limits).
    config: AgentConfig,
    // The LLM backend used for all completions.
    llm: Arc<dyn LlmProvider>,
    // Tool executor; `None` when no tool registry was supplied at build time.
    tools: Option<Arc<ToolExecutor>>,
}
19
20impl Agent {
21    pub fn builder() -> AgentBuilder {
22        AgentBuilder::new()
23    }
24
25    pub fn config(&self) -> &AgentConfig {
26        &self.config
27    }
28
29    pub fn llm(&self) -> &Arc<dyn LlmProvider> {
30        &self.llm
31    }
32
33    pub fn create_session(&self) -> Session {
34        Session::new()
35    }
36
37    pub async fn chat(&self, session: &Session, message: &str) -> Result<String> {
38        let response = self.process(session, message).await?;
39        Ok(response.text)
40    }
41
42    pub async fn process(&self, session: &Session, message: &str) -> Result<AgentResponse> {
43        let executor = Executor::new(self.config.clone(), self.llm.clone(), self.tools.clone());
44        executor.execute(session, message).await
45    }
46
47    pub fn chat_stream(
48        &self,
49        session: &Session,
50        message: &str,
51    ) -> impl Stream<Item = AgentEvent> + Send {
52        let (tx, rx) = mpsc::channel(100);
53        let executor = Executor::new(self.config.clone(), self.llm.clone(), self.tools.clone());
54        let session = session.clone();
55        let message = message.to_string();
56
57        tokio::spawn(async move {
58            let _ = executor.execute_stream(&session, &message, tx).await;
59        });
60
61        ReceiverStream::new(rx)
62    }
63}
64
/// Fluent builder for [`Agent`].
///
/// Every field is optional except the LLM provider, which `build` requires.
/// Individual setters override the corresponding fields of the base
/// `AgentConfig` supplied via `config` (or `AgentConfig::default()`).
#[derive(Default)]
pub struct AgentBuilder {
    // Optional base configuration; per-field setters below take precedence.
    config: Option<AgentConfig>,
    name: Option<String>,
    system_prompt: Option<String>,
    model: Option<String>,
    temperature: Option<f32>,
    max_tokens: Option<usize>,
    max_iterations: Option<usize>,
    tools_enabled: Option<bool>,
    // Required: `build` fails without an LLM provider.
    llm: Option<Arc<dyn LlmProvider>>,
    // Optional: when present, the agent gets a ToolExecutor.
    tool_registry: Option<Arc<ToolRegistry>>,
}
78
79impl AgentBuilder {
80    pub fn new() -> Self {
81        Self::default()
82    }
83
84    pub fn config(mut self, config: AgentConfig) -> Self {
85        self.config = Some(config);
86        self
87    }
88
89    pub fn name(mut self, name: impl Into<String>) -> Self {
90        self.name = Some(name.into());
91        self
92    }
93
94    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
95        self.system_prompt = Some(prompt.into());
96        self
97    }
98
99    pub fn model(mut self, model: impl Into<String>) -> Self {
100        self.model = Some(model.into());
101        self
102    }
103
104    pub fn temperature(mut self, temperature: f32) -> Self {
105        self.temperature = Some(temperature);
106        self
107    }
108
109    pub fn max_tokens(mut self, max_tokens: usize) -> Self {
110        self.max_tokens = Some(max_tokens);
111        self
112    }
113
114    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
115        self.max_iterations = Some(max_iterations);
116        self
117    }
118
119    pub fn tools_enabled(mut self, enabled: bool) -> Self {
120        self.tools_enabled = Some(enabled);
121        self
122    }
123
124    pub fn llm(mut self, llm: Arc<dyn LlmProvider>) -> Self {
125        self.llm = Some(llm);
126        self
127    }
128
129    pub fn tool_registry(mut self, registry: Arc<ToolRegistry>) -> Self {
130        self.tool_registry = Some(registry);
131        self
132    }
133
134    pub fn build(self) -> Result<Agent> {
135        let llm = self.llm.ok_or_else(|| AgentError::config("LLM provider is required"))?;
136
137        let mut config = self.config.unwrap_or_default();
138
139        if let Some(name) = self.name {
140            config.name = name;
141        }
142        if let Some(prompt) = self.system_prompt {
143            config.system_prompt = prompt;
144        }
145        if let Some(model) = self.model {
146            config.model = model;
147        }
148        if let Some(temp) = self.temperature {
149            config.temperature = temp;
150        }
151        if let Some(max) = self.max_tokens {
152            config.max_tokens = Some(max);
153        }
154        if let Some(max) = self.max_iterations {
155            config.max_iterations = max;
156        }
157        if let Some(enabled) = self.tools_enabled {
158            config.tools_enabled = enabled;
159        }
160
161        let tools = self.tool_registry.map(|registry| {
162            Arc::new(ToolExecutor::new(registry))
163        });
164
165        Ok(Agent { config, llm, tools })
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use hehe_core::capability::Capabilities;
    use hehe_core::stream::StreamChunk;
    use hehe_core::Message;
    use hehe_llm::{BoxStream, CompletionRequest, CompletionResponse, LlmError, ModelInfo};

    // Minimal LlmProvider stub: every completion replies "Hello from mock!"
    // and the streaming variant yields no chunks.
    struct MockLlm;

    #[async_trait]
    impl LlmProvider for MockLlm {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> &Capabilities {
            // Lazily-initialized static so a `&'static Capabilities` can be
            // returned without storing state on the mock.
            static CAPS: std::sync::OnceLock<Capabilities> = std::sync::OnceLock::new();
            CAPS.get_or_init(Capabilities::text_basic)
        }

        async fn complete(&self, _request: CompletionRequest) -> std::result::Result<CompletionResponse, LlmError> {
            // Fixed canned response regardless of the request contents.
            Ok(CompletionResponse::new("id", "mock", Message::assistant("Hello from mock!")))
        }

        async fn complete_stream(
            &self,
            _request: CompletionRequest,
        ) -> std::result::Result<BoxStream<StreamChunk>, LlmError> {
            use futures::stream;
            // The mock produces no incremental chunks.
            Ok(Box::pin(stream::empty()))
        }

        async fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, LlmError> {
            Ok(vec![])
        }

        fn default_model(&self) -> &str {
            "mock"
        }
    }

    // build() must fail when no LLM provider was supplied.
    #[test]
    fn test_builder_missing_llm() {
        let result = Agent::builder()
            .system_prompt("You are helpful")
            .build();

        assert!(result.is_err());
    }

    // One chat turn returns the mock's text and records both sides of the
    // exchange in the session.
    #[tokio::test]
    async fn test_agent_chat() {
        let agent = Agent::builder()
            .system_prompt("You are helpful.")
            .model("mock")
            .llm(Arc::new(MockLlm))
            .build()
            .unwrap();

        let session = agent.create_session();
        let response = agent.chat(&session, "Hi").await.unwrap();

        assert_eq!(response, "Hello from mock!");
        // 2 = user message + assistant reply — presumably; confirm against
        // Session's message-recording semantics.
        assert_eq!(session.message_count(), 2);
    }

    // Builder setters propagate into the config, and process() returns the
    // full AgentResponse (text + iteration count).
    #[tokio::test]
    async fn test_agent_process() {
        let agent = Agent::builder()
            .name("test-agent")
            .system_prompt("You are helpful.")
            .model("mock")
            .temperature(0.5)
            .max_iterations(5)
            .llm(Arc::new(MockLlm))
            .build()
            .unwrap();

        assert_eq!(agent.config().name, "test-agent");
        assert_eq!(agent.config().temperature, 0.5);
        assert_eq!(agent.config().max_iterations, 5);

        let session = agent.create_session();
        let response = agent.process(&session, "Hi").await.unwrap();

        // NOTE(review): `text()` here vs the `text` field accessed in
        // Agent::chat — verify AgentResponse intentionally exposes both.
        assert_eq!(response.text(), "Hello from mock!");
        assert_eq!(response.iterations, 1);
    }

    // A session accumulates history across multiple chat turns.
    #[tokio::test]
    async fn test_session_persistence() {
        let agent = Agent::builder()
            .system_prompt("You are helpful.")
            .llm(Arc::new(MockLlm))
            .build()
            .unwrap();

        let session = agent.create_session();

        agent.chat(&session, "First message").await.unwrap();
        agent.chat(&session, "Second message").await.unwrap();

        // Two turns × (user + assistant) = 4 messages — presumably; confirm
        // Session semantics.
        assert_eq!(session.message_count(), 4);
    }
}
275}