use crate::event::{EventStream, EventType};
use crate::llm::{ChatClient, ChatMessage, ChatRequest};
use crate::tool::Tool;
use crate::types::{AgentError, AgentInput, AgentOutput, AgentOutputMetadata, AgentResult};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

#[cfg(test)]
#[path = "agent_test.rs"]
mod agent_test;

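/// Configuration for an [`Agent`]: a name, a system prompt, and the tools
/// the agent may call. The tool list is skipped during serialization
/// because trait objects cannot be serialized, so a deserialized config
/// starts with an empty tool list.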
#[derive(Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    pub name: String,
    pub system_prompt: String,

    #[serde(skip)]
    pub tools: Vec<Arc<dyn Tool>>,
}

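// Manual `Debug` impl: `dyn Tool` carries no `Debug` bound, so the tool
// list is reported as a count rather than formatted directly.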
impl std::fmt::Debug for AgentConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AgentConfig")
            .field("name", &self.name)
            .field("system_prompt", &self.system_prompt)
            .field("tools", &format!("{} tools", self.tools.len()))
            .finish()
    }
}

impl AgentConfig {
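    /// Starts a builder for a config with the given agent name.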
    pub fn builder(name: impl Into<String>) -> AgentConfigBuilder {
        AgentConfigBuilder {
            name: name.into(),
            system_prompt: String::new(),
            tools: Vec::new(),
        }
    }
}

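/// Builder for [`AgentConfig`], created via [`AgentConfig::builder`].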
pub struct AgentConfigBuilder {
    name: String,
    system_prompt: String,
    tools: Vec<Arc<dyn Tool>>,
}

impl AgentConfigBuilder {
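    /// Sets the system prompt sent with every LLM request.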
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

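    /// Appends a single tool to the agent's tool list.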
    pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
        self.tools.push(tool);
        self
    }

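    /// Replaces the entire tool list.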
    pub fn tools(mut self, tools: Vec<Arc<dyn Tool>>) -> Self {
        self.tools = tools;
        self
    }

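    /// Finalizes the builder into an [`AgentConfig`].
    ///
    /// A minimal usage sketch (the prompt text is illustrative):
    ///
    /// ```ignore
    /// let config = AgentConfig::builder("summarizer")
    ///     .system_prompt("Summarize the user's input in one sentence.")
    ///     .build();
    /// ```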
    pub fn build(self) -> AgentConfig {
        AgentConfig {
            name: self.name,
            system_prompt: self.system_prompt,
            tools: self.tools,
        }
    }
}

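/// An agent that sends its input to an LLM (when a client is attached) and
/// reports lifecycle events along the way. Without a client, execution
/// falls back to a mock run that echoes the input.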
pub struct Agent {
    config: AgentConfig,
    llm_client: Option<Arc<dyn ChatClient>>,
}

impl Agent {
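    /// Creates an agent with no LLM client attached; until one is set via
    /// [`with_llm_client`](Self::with_llm_client), execution falls back to
    /// a mock run.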
    pub fn new(config: AgentConfig) -> Self {
        Self {
            config,
            llm_client: None,
        }
    }

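    /// Attaches the chat client used for LLM-backed execution.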
    pub fn with_llm_client(mut self, client: Arc<dyn ChatClient>) -> Self {
        self.llm_client = Some(client);
        self
    }

    pub fn name(&self) -> &str {
        &self.config.name
    }

    pub fn config(&self) -> &AgentConfig {
        &self.config
    }

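    /// Executes the agent without event reporting.
    ///
    /// A minimal sketch of a full run (`config`, `client`, and `input` are
    /// illustrative stand-ins, not items defined in this module):
    ///
    /// ```ignore
    /// let agent = Agent::new(config).with_llm_client(client);
    /// let output = agent.execute(input).await?;
    /// ```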
    pub async fn execute(&self, input: AgentInput) -> AgentResult {
        self.execute_with_events(input, None).await
    }

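    /// Executes the agent, appending lifecycle events to `event_stream`
    /// when one is provided.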
    pub async fn execute_with_events(
        &self,
        input: AgentInput,
        event_stream: Option<&EventStream>,
    ) -> AgentResult {
        let start = std::time::Instant::now();

        // Attribute events to the upstream agent that produced this input,
        // falling back to "workflow" at the start of a run.
        let source = input
            .metadata
            .previous_agent
            .clone()
            .unwrap_or_else(|| "workflow".to_string());

        if let Some(stream) = event_stream {
            stream.append(
                EventType::AgentProcessing,
                source.clone(),
                serde_json::json!({
                    "agent": self.config.name,
                    "input": input.data,
                }),
            );
        }

        if let Some(client) = &self.llm_client {
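            // String inputs pass through verbatim; structured inputs are
            // pretty-printed as JSON.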
            let user_message = if let Some(s) = input.data.as_str() {
                s.to_string()
            } else {
                serde_json::to_string_pretty(&input.data).unwrap_or_default()
            };

            let messages = vec![
                ChatMessage::system(&self.config.system_prompt),
                ChatMessage::user(&user_message),
            ];

            let request = ChatRequest::new(messages)
                .with_temperature(0.7)
                .with_max_tokens(500);

            if let Some(stream) = event_stream {
                stream.append(
                    EventType::AgentLlmRequestStarted,
                    source.clone(),
                    serde_json::json!({
                        "agent": self.config.name,
                        "provider": client.provider(),
                    }),
                );
            }

            match client.chat_stream(request).await {
                Ok(mut text_stream) => {
                    let mut full_response = String::new();

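                    // Accumulate the streamed response, forwarding each
                    // non-empty chunk as an event.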
                    while let Some(chunk_result) = text_stream.next().await {
                        match chunk_result {
                            Ok(chunk) => {
                                if !chunk.is_empty() {
                                    full_response.push_str(&chunk);

                                    if let Some(stream) = event_stream {
                                        stream.append(
                                            EventType::AgentLlmStreamChunk,
                                            source.clone(),
                                            serde_json::json!({
                                                "agent": self.config.name,
                                                "chunk": chunk,
                                            }),
                                        );
                                    }
                                }
                            }
                            Err(e) => {
                                if let Some(stream) = event_stream {
                                    stream.append(
                                        EventType::AgentLlmRequestFailed,
                                        source.clone(),
                                        serde_json::json!({
                                            "agent": self.config.name,
                                            "error": e.to_string(),
                                        }),
                                    );
                                }
                                return Err(AgentError::ExecutionError(format!(
                                    "LLM streaming failed: {}",
                                    e
                                )));
                            }
                        }
                    }

                    if let Some(stream) = event_stream {
                        stream.append(
                            EventType::AgentLlmRequestCompleted,
                            source.clone(),
                            serde_json::json!({
                                "agent": self.config.name,
                            }),
                        );
                    }

                    let output_data = serde_json::json!({
                        "response": full_response,
                    });

                    if let Some(stream) = event_stream {
                        stream.append(
                            EventType::AgentCompleted,
                            source.clone(),
                            serde_json::json!({
                                "agent": self.config.name,
                                "execution_time_ms": start.elapsed().as_millis() as u64,
                            }),
                        );
                    }

                    Ok(AgentOutput {
                        data: output_data,
                        metadata: AgentOutputMetadata {
                            agent_name: self.config.name.clone(),
                            execution_time_ms: start.elapsed().as_millis() as u64,
                            tool_calls_count: 0,
                        },
                    })
                }
                Err(e) => {
                    // A failed request is surfaced as both an LLM-level and
                    // an agent-level failure event.
                    if let Some(stream) = event_stream {
                        stream.append(
                            EventType::AgentLlmRequestFailed,
                            source.clone(),
                            serde_json::json!({
                                "agent": self.config.name,
                                "error": e.to_string(),
                            }),
                        );
                        stream.append(
                            EventType::AgentFailed,
                            source.clone(),
                            serde_json::json!({
                                "agent": self.config.name,
                                "error": e.to_string(),
                            }),
                        );
                    }

                    Err(AgentError::ExecutionError(format!(
                        "LLM call failed: {}",
                        e
                    )))
                }
            }
        } else {
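            // No LLM client configured: perform a mock execution that
            // echoes the input back.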
            let output_data = serde_json::json!({
                "agent": self.config.name,
                "processed": input.data,
                "system_prompt": self.config.system_prompt,
                "note": "Mock execution - no LLM client configured"
            });

            if let Some(stream) = event_stream {
                stream.append(
                    EventType::AgentCompleted,
                    source,
                    serde_json::json!({
                        "agent": self.config.name,
                        "execution_time_ms": start.elapsed().as_millis() as u64,
                        "mock": true,
                    }),
                );
            }

            Ok(AgentOutput {
                data: output_data,
                metadata: AgentOutputMetadata {
                    agent_name: self.config.name.clone(),
                    execution_time_ms: start.elapsed().as_millis() as u64,
                    tool_calls_count: 0,
                },
            })
        }
    }
}