1use crate::config::AgentConfig;
2use crate::error::{AgentError, Result};
3use crate::event::AgentEvent;
4use crate::executor::Executor;
5use crate::response::AgentResponse;
6use crate::session::Session;
7use hehe_llm::LlmProvider;
8use hehe_tools::{ToolExecutor, ToolRegistry};
9use std::sync::Arc;
10use tokio::sync::mpsc;
11use tokio_stream::wrappers::ReceiverStream;
12use tokio_stream::Stream;
13
14pub struct Agent {
15 config: AgentConfig,
16 llm: Arc<dyn LlmProvider>,
17 tools: Option<Arc<ToolExecutor>>,
18}
19
20impl Agent {
21 pub fn builder() -> AgentBuilder {
22 AgentBuilder::new()
23 }
24
25 pub fn config(&self) -> &AgentConfig {
26 &self.config
27 }
28
29 pub fn llm(&self) -> &Arc<dyn LlmProvider> {
30 &self.llm
31 }
32
33 pub fn create_session(&self) -> Session {
34 Session::new()
35 }
36
37 pub async fn chat(&self, session: &Session, message: &str) -> Result<String> {
38 let response = self.process(session, message).await?;
39 Ok(response.text)
40 }
41
42 pub async fn process(&self, session: &Session, message: &str) -> Result<AgentResponse> {
43 let executor = Executor::new(self.config.clone(), self.llm.clone(), self.tools.clone());
44 executor.execute(session, message).await
45 }
46
47 pub fn chat_stream(
48 &self,
49 session: &Session,
50 message: &str,
51 ) -> impl Stream<Item = AgentEvent> + Send {
52 let (tx, rx) = mpsc::channel(100);
53 let executor = Executor::new(self.config.clone(), self.llm.clone(), self.tools.clone());
54 let session = session.clone();
55 let message = message.to_string();
56
57 tokio::spawn(async move {
58 let _ = executor.execute_stream(&session, &message, tx).await;
59 });
60
61 ReceiverStream::new(rx)
62 }
63}
64
65#[derive(Default)]
66pub struct AgentBuilder {
67 config: Option<AgentConfig>,
68 name: Option<String>,
69 system_prompt: Option<String>,
70 model: Option<String>,
71 temperature: Option<f32>,
72 max_tokens: Option<usize>,
73 max_iterations: Option<usize>,
74 tools_enabled: Option<bool>,
75 llm: Option<Arc<dyn LlmProvider>>,
76 tool_registry: Option<Arc<ToolRegistry>>,
77}
78
79impl AgentBuilder {
80 pub fn new() -> Self {
81 Self::default()
82 }
83
84 pub fn config(mut self, config: AgentConfig) -> Self {
85 self.config = Some(config);
86 self
87 }
88
89 pub fn name(mut self, name: impl Into<String>) -> Self {
90 self.name = Some(name.into());
91 self
92 }
93
94 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
95 self.system_prompt = Some(prompt.into());
96 self
97 }
98
99 pub fn model(mut self, model: impl Into<String>) -> Self {
100 self.model = Some(model.into());
101 self
102 }
103
104 pub fn temperature(mut self, temperature: f32) -> Self {
105 self.temperature = Some(temperature);
106 self
107 }
108
109 pub fn max_tokens(mut self, max_tokens: usize) -> Self {
110 self.max_tokens = Some(max_tokens);
111 self
112 }
113
114 pub fn max_iterations(mut self, max_iterations: usize) -> Self {
115 self.max_iterations = Some(max_iterations);
116 self
117 }
118
119 pub fn tools_enabled(mut self, enabled: bool) -> Self {
120 self.tools_enabled = Some(enabled);
121 self
122 }
123
124 pub fn llm(mut self, llm: Arc<dyn LlmProvider>) -> Self {
125 self.llm = Some(llm);
126 self
127 }
128
129 pub fn tool_registry(mut self, registry: Arc<ToolRegistry>) -> Self {
130 self.tool_registry = Some(registry);
131 self
132 }
133
134 pub fn build(self) -> Result<Agent> {
135 let llm = self.llm.ok_or_else(|| AgentError::config("LLM provider is required"))?;
136
137 let mut config = self.config.unwrap_or_default();
138
139 if let Some(name) = self.name {
140 config.name = name;
141 }
142 if let Some(prompt) = self.system_prompt {
143 config.system_prompt = prompt;
144 }
145 if let Some(model) = self.model {
146 config.model = model;
147 }
148 if let Some(temp) = self.temperature {
149 config.temperature = temp;
150 }
151 if let Some(max) = self.max_tokens {
152 config.max_tokens = Some(max);
153 }
154 if let Some(max) = self.max_iterations {
155 config.max_iterations = max;
156 }
157 if let Some(enabled) = self.tools_enabled {
158 config.tools_enabled = enabled;
159 }
160
161 let tools = self.tool_registry.map(|registry| {
162 Arc::new(ToolExecutor::new(registry))
163 });
164
165 Ok(Agent { config, llm, tools })
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use hehe_core::capability::Capabilities;
    use hehe_core::stream::StreamChunk;
    use hehe_core::Message;
    use hehe_llm::{BoxStream, CompletionRequest, CompletionResponse, LlmError, ModelInfo};

    /// Provider stub: `complete` always answers "Hello from mock!" and
    /// `complete_stream` yields nothing.
    struct MockLlm;

    #[async_trait]
    impl LlmProvider for MockLlm {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> &Capabilities {
            static CAPS: std::sync::OnceLock<Capabilities> = std::sync::OnceLock::new();
            CAPS.get_or_init(Capabilities::text_basic)
        }

        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> std::result::Result<CompletionResponse, LlmError> {
            Ok(CompletionResponse::new("id", "mock", Message::assistant("Hello from mock!")))
        }

        async fn complete_stream(
            &self,
            _request: CompletionRequest,
        ) -> std::result::Result<BoxStream<StreamChunk>, LlmError> {
            use futures::stream;
            Ok(Box::pin(stream::empty()))
        }

        async fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, LlmError> {
            Ok(vec![])
        }

        fn default_model(&self) -> &str {
            "mock"
        }
    }

    #[test]
    fn test_builder_missing_llm() {
        let result = Agent::builder()
            .system_prompt("You are helpful")
            .build();

        assert!(result.is_err());
    }

    /// Builder setters must win over a supplied base config, while fields the
    /// builder never touched keep the base config's values.
    #[test]
    // Mutating a default (rather than a struct literal) keeps this test
    // dependent only on the config fields it actually touches.
    #[allow(clippy::field_reassign_with_default)]
    fn test_builder_overrides_supplied_config() {
        let mut base = AgentConfig::default();
        base.name = "from-config".to_string();
        base.max_iterations = 3;

        let agent = Agent::builder()
            .config(base)
            .name("from-builder")
            .max_tokens(256)
            .tools_enabled(false)
            .llm(Arc::new(MockLlm))
            .build()
            .unwrap();

        assert_eq!(agent.config().name, "from-builder");
        assert_eq!(agent.config().max_tokens, Some(256));
        assert!(!agent.config().tools_enabled);
        assert_eq!(agent.config().max_iterations, 3);
    }

    #[tokio::test]
    async fn test_agent_chat() {
        let agent = Agent::builder()
            .system_prompt("You are helpful.")
            .model("mock")
            .llm(Arc::new(MockLlm))
            .build()
            .unwrap();

        let session = agent.create_session();
        let response = agent.chat(&session, "Hi").await.unwrap();

        assert_eq!(response, "Hello from mock!");
        assert_eq!(session.message_count(), 2);
    }

    #[tokio::test]
    async fn test_agent_process() {
        let agent = Agent::builder()
            .name("test-agent")
            .system_prompt("You are helpful.")
            .model("mock")
            .temperature(0.5)
            .max_iterations(5)
            .llm(Arc::new(MockLlm))
            .build()
            .unwrap();

        assert_eq!(agent.config().name, "test-agent");
        assert_eq!(agent.config().temperature, 0.5);
        assert_eq!(agent.config().max_iterations, 5);

        let session = agent.create_session();
        let response = agent.process(&session, "Hi").await.unwrap();

        assert_eq!(response.text(), "Hello from mock!");
        assert_eq!(response.iterations, 1);
    }

    #[tokio::test]
    async fn test_session_persistence() {
        let agent = Agent::builder()
            .system_prompt("You are helpful.")
            .llm(Arc::new(MockLlm))
            .build()
            .unwrap();

        let session = agent.create_session();

        agent.chat(&session, "First message").await.unwrap();
        agent.chat(&session, "Second message").await.unwrap();

        assert_eq!(session.message_count(), 4);
    }
}