1use crate::event::EventStream;
2use crate::llm::types::ToolCall;
3use crate::llm::{ChatClient, ChatMessage, ChatRequest};
4use crate::tool::ToolRegistry;
5use crate::tool_loop_detection::{ToolCallTracker, ToolLoopDetectionConfig};
6use crate::types::{AgentError, AgentInput, AgentOutput, AgentOutputMetadata, AgentResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11#[cfg(test)]
12#[path = "agent_test.rs"]
13mod agent_test;
14
15#[derive(Clone, Serialize, Deserialize)]
17pub struct AgentConfig {
18 pub name: String,
19 pub system_prompt: String,
20
21 #[serde(skip)]
22 pub tools: Option<Arc<ToolRegistry>>,
23
24 pub max_tool_iterations: usize,
25
26 #[serde(skip)]
28 pub tool_loop_detection: Option<ToolLoopDetectionConfig>,
29}
30
31impl std::fmt::Debug for AgentConfig {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("AgentConfig")
34 .field("name", &self.name)
35 .field("system_prompt", &self.system_prompt)
36 .field(
37 "tools",
38 &self.tools.as_ref().map(|t| format!("{} tools", t.len())),
39 )
40 .field("max_tool_iterations", &self.max_tool_iterations)
41 .field(
42 "tool_loop_detection",
43 &self.tool_loop_detection.as_ref().map(|c| c.enabled),
44 )
45 .finish()
46 }
47}
48
49impl AgentConfig {
50 pub fn builder(name: impl Into<String>) -> AgentConfigBuilder {
51 AgentConfigBuilder {
52 name: name.into(),
53 system_prompt: String::new(),
54 tools: None,
55 max_tool_iterations: 10,
56 tool_loop_detection: Some(ToolLoopDetectionConfig::default()),
57 }
58 }
59}
60
61pub struct AgentConfigBuilder {
63 name: String,
64 system_prompt: String,
65 tools: Option<Arc<ToolRegistry>>,
66 max_tool_iterations: usize,
67 tool_loop_detection: Option<ToolLoopDetectionConfig>,
68}
69
70impl AgentConfigBuilder {
71 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
72 self.system_prompt = prompt.into();
73 self
74 }
75
76 pub fn tools(mut self, tools: Arc<ToolRegistry>) -> Self {
77 self.tools = Some(tools);
78 self
79 }
80
81 pub fn max_tool_iterations(mut self, max: usize) -> Self {
82 self.max_tool_iterations = max;
83 self
84 }
85
86 pub fn tool_loop_detection(mut self, config: ToolLoopDetectionConfig) -> Self {
87 self.tool_loop_detection = Some(config);
88 self
89 }
90
91 pub fn disable_tool_loop_detection(mut self) -> Self {
92 self.tool_loop_detection = None;
93 self
94 }
95
96 pub fn build(self) -> AgentConfig {
97 AgentConfig {
98 name: self.name,
99 system_prompt: self.system_prompt,
100 tools: self.tools,
101 max_tool_iterations: self.max_tool_iterations,
102 tool_loop_detection: self.tool_loop_detection,
103 }
104 }
105}
106
107pub struct Agent {
109 config: AgentConfig,
110 llm_client: Option<Arc<dyn ChatClient>>,
111}
112
113impl Agent {
114 pub fn new(config: AgentConfig) -> Self {
115 Self {
116 config,
117 llm_client: None,
118 }
119 }
120
121 pub fn with_llm_client(mut self, client: Arc<dyn ChatClient>) -> Self {
122 self.llm_client = Some(client);
123 self
124 }
125
126 pub fn name(&self) -> &str {
127 &self.config.name
128 }
129
130 pub fn config(&self) -> &AgentConfig {
131 &self.config
132 }
133
134 pub async fn execute(&self, input: &AgentInput) -> AgentResult {
136 self.execute_with_events(input.clone(), None).await
137 }
138
139 pub async fn execute_with_events(
141 &self,
142 input: AgentInput,
143 event_stream: Option<&EventStream>,
144 ) -> AgentResult {
145 let start = std::time::Instant::now();
146
147 let workflow_id = input
148 .metadata
149 .previous_agent
150 .clone()
151 .unwrap_or_else(|| "workflow".to_string());
152
153 if let Some(stream) = event_stream {
155 stream.agent_started(
156 &self.config.name,
157 workflow_id.clone(),
158 serde_json::json!({
159 "input": input.data,
160 }),
161 );
162 }
163
164 if let Some(client) = &self.llm_client {
166 let messages = if let Some(history) = &input.chat_history {
168 history.clone()
171 } else {
172 let user_message = if let Some(s) = input.data.as_str() {
174 s.to_string()
175 } else {
176 serde_json::to_string_pretty(&input.data).unwrap_or_default()
177 };
178
179 vec![
180 ChatMessage::system(&self.config.system_prompt),
181 ChatMessage::user(&user_message),
182 ]
183 };
184
185 let mut request = ChatRequest::new(messages.clone())
186 .with_temperature(0.7)
187 .with_max_tokens(8192);
188
189 let tool_schemas = self
191 .config
192 .tools
193 .as_ref()
194 .map(|registry| registry.list_tools())
195 .filter(|tools| !tools.is_empty());
196
197 let mut iteration = 0;
199 let mut total_tool_calls = 0;
200
201 let mut tool_tracker = if self.config.tool_loop_detection.is_some() {
203 Some(ToolCallTracker::new())
204 } else {
205 None
206 };
207
208 loop {
209 iteration += 1;
210
211 if iteration > self.config.max_tool_iterations {
213 return Err(AgentError::ExecutionError(format!(
214 "Maximum tool iterations ({}) exceeded",
215 self.config.max_tool_iterations
216 )));
217 }
218
219 if let Some(ref schemas) = tool_schemas {
221 request = request.with_tools(schemas.clone());
222 }
223
224 if let Some(stream) = event_stream {
226 stream.llm_started(
227 &self.config.name,
228 iteration,
229 workflow_id.clone(),
230 serde_json::json!({
231 "messages": request.messages.len(),
232 }),
233 );
234 }
235
236 let event_stream_for_streaming = event_stream.cloned();
238 let agent_name = self.config.name.clone();
239 let workflow_id_for_streaming = workflow_id.clone();
240
241 let (chunk_tx, mut chunk_rx) = tokio::sync::mpsc::channel(100);
243
244 let _chunk_event_task = tokio::spawn(async move {
246 while let Some(chunk) = chunk_rx.recv().await {
247 if let Some(stream) = &event_stream_for_streaming {
248 stream.llm_progress(
249 &agent_name,
250 iteration,
251 workflow_id_for_streaming.clone(),
252 chunk,
253 );
254 }
255 }
256 });
257
258 match client.chat_stream(request.clone(), chunk_tx).await {
259 Ok(response) => {
260 if let Some(stream) = event_stream {
262 stream.llm_completed(
263 &self.config.name,
264 iteration,
265 workflow_id.clone(),
266 serde_json::json!({
267 "content": response.content.chars().take(100).collect::<String>(),
268 "has_tool_calls": response.tool_calls.is_some(),
269 }),
270 );
271 }
272
273 if let Some(tool_calls) = response.tool_calls.clone() {
275 if tool_calls.is_empty() {
276 } else {
278 total_tool_calls += tool_calls.len();
279
280 let assistant_msg = ChatMessage::assistant_with_tool_calls(
282 response.content.clone(),
283 tool_calls.clone(),
284 );
285 request.messages.push(assistant_msg);
286
287 for tool_call in tool_calls {
289 if let (Some(tracker), Some(loop_config)) =
291 (&tool_tracker, &self.config.tool_loop_detection)
292 {
293 if loop_config.enabled {
294 let args_value: serde_json::Value =
296 serde_json::from_str(&tool_call.function.arguments)
297 .unwrap_or(serde_json::json!({}));
298
299 let args_map: HashMap<String, serde_json::Value> =
301 args_value
302 .as_object()
303 .map(|obj| {
304 obj.iter()
305 .map(|(k, v)| (k.clone(), v.clone()))
306 .collect()
307 })
308 .unwrap_or_default();
309
310 if let Some(previous_result) = tracker
311 .check_for_loop(&tool_call.function.name, &args_map)
312 {
313 let loop_message = loop_config.get_message(
315 &tool_call.function.name,
316 &previous_result,
317 );
318
319 if let Some(stream) = event_stream {
321 stream.append(
322 crate::event::EventScope::System,
323 crate::event::EventType::Progress,
324 "system:tool_loop_detection".to_string(),
325 crate::event::ComponentStatus::Running,
326 workflow_id.clone(),
327 Some(format!(
328 "Tool loop detected: {}",
329 tool_call.function.name
330 )),
331 serde_json::json!({
332 "agent": self.config.name,
333 "tool": tool_call.function.name,
334 "message": loop_message,
335 }),
336 );
337 }
338
339 let tool_msg = ChatMessage::tool_result(
341 &tool_call.id,
342 &loop_message,
343 );
344 request.messages.push(tool_msg);
345
346 continue;
348 }
349 }
350 }
351
352 let tool_result = self
354 .execute_tool_call(
355 &tool_call,
356 &input
357 .metadata
358 .previous_agent
359 .clone()
360 .unwrap_or_else(|| "workflow".to_string()),
361 event_stream,
362 )
363 .await;
364
365 if let Some(tracker) = &mut tool_tracker {
367 let args_value: serde_json::Value =
369 serde_json::from_str(&tool_call.function.arguments)
370 .unwrap_or(serde_json::json!({}));
371
372 let args_map: HashMap<String, serde_json::Value> =
374 args_value
375 .as_object()
376 .map(|obj| {
377 obj.iter()
378 .map(|(k, v)| (k.clone(), v.clone()))
379 .collect()
380 })
381 .unwrap_or_default();
382
383 let result_json = serde_json::to_value(&tool_result)
384 .unwrap_or(serde_json::json!({}));
385 tracker.record_call(
386 &tool_call.function.name,
387 &args_map,
388 &result_json,
389 );
390 }
391
392 let tool_msg =
394 ChatMessage::tool_result(&tool_call.id, &tool_result);
395 request.messages.push(tool_msg);
396 }
397
398 continue;
400 }
401 }
402
403 let response_text = response.content.trim();
405 let token_count = response
406 .usage
407 .map(|u| u.total_tokens)
408 .unwrap_or_else(|| (response_text.len() as f32 / 4.0).ceil() as u32);
409
410 let output_data = serde_json::json!({
411 "response": response_text,
412 "content_type": "text/plain",
413 "token_count": token_count,
414 });
415
416 request.messages.push(ChatMessage::assistant(response_text));
418
419 if let Some(stream) = event_stream {
421 stream.agent_completed(
422 &self.config.name,
423 workflow_id.clone(),
424 Some(format!(
425 "Agent completed in {}ms",
426 start.elapsed().as_millis()
427 )),
428 serde_json::json!({
429 "execution_time_ms": start.elapsed().as_millis() as u64,
430 "tool_calls": total_tool_calls,
431 "iterations": iteration,
432 }),
433 );
434 }
435
436 return Ok(AgentOutput {
437 data: output_data,
438 metadata: AgentOutputMetadata {
439 agent_name: self.config.name.clone(),
440 execution_time_ms: start.elapsed().as_millis() as u64,
441 tool_calls_count: total_tool_calls,
442 },
443 chat_history: Some(request.messages),
444 });
445 }
446 Err(e) => {
447 if let Some(stream) = event_stream {
449 stream.llm_failed(
450 &self.config.name,
451 iteration,
452 workflow_id.clone(),
453 &e.to_string(),
454 );
455 }
456
457 if let Some(stream) = event_stream {
459 stream.agent_failed(
460 &self.config.name,
461 workflow_id.clone(),
462 &e.to_string(),
463 serde_json::json!({}),
464 );
465 }
466
467 return Err(AgentError::ExecutionError(format!(
468 "LLM call failed: {}",
469 e
470 )));
471 }
472 }
473 }
474 } else {
475 let output_data = serde_json::json!({
477 "agent": self.config.name,
478 "processed": input.data,
479 "system_prompt": self.config.system_prompt,
480 "note": "Mock execution - no LLM client configured"
481 });
482
483 if let Some(stream) = event_stream {
484 stream.agent_completed(
485 &self.config.name,
486 workflow_id.clone(),
487 Some("Agent completed (no LLM)".to_string()),
488 serde_json::json!({
489 "execution_time_ms": start.elapsed().as_millis() as u64,
490 "mock": true,
491 }),
492 );
493 }
494
495 Ok(AgentOutput {
496 data: output_data,
497 metadata: AgentOutputMetadata {
498 agent_name: self.config.name.clone(),
499 execution_time_ms: start.elapsed().as_millis() as u64,
500 tool_calls_count: 0,
501 },
502 chat_history: None, })
504 }
505 }
506
507 async fn execute_tool_call(
509 &self,
510 tool_call: &ToolCall,
511 previous_agent: &str,
512 event_stream: Option<&EventStream>,
513 ) -> String {
514 let tool_name = &tool_call.function.name;
515
516 if let Some(stream) = event_stream {
518 stream.tool_started(
519 tool_name,
520 previous_agent.to_string(),
521 serde_json::json!({
522 "agent": self.config.name,
523 "tool_call_id": tool_call.id,
524 "arguments": tool_call.function.arguments,
525 }),
526 );
527 }
528
529 let registry = match &self.config.tools {
531 Some(reg) => reg,
532 None => {
533 let error_msg = "No tool registry configured".to_string();
534 if let Some(stream) = event_stream {
535 stream.tool_failed(
536 tool_name,
537 previous_agent.to_string(),
538 &error_msg,
539 serde_json::json!({
540 "agent": self.config.name,
541 "tool_call_id": tool_call.id,
542 "duration_ms": 0,
543 }),
544 );
545 }
546 return format!("Error: {}", error_msg);
547 }
548 };
549
550 let params: HashMap<String, serde_json::Value> =
552 match serde_json::from_str(&tool_call.function.arguments) {
553 Ok(p) => p,
554 Err(e) => {
555 let error_msg = format!("Failed to parse tool arguments: {}", e);
556 if let Some(stream) = event_stream {
557 stream.tool_failed(
558 tool_name,
559 previous_agent.to_string(),
560 &error_msg,
561 serde_json::json!({
562 "agent": self.config.name,
563 "tool_call_id": tool_call.id,
564 "duration_ms": 0,
565 "duration_ms": 0,
566 }),
567 );
568 }
569 return format!("Error: {}", error_msg);
570 }
571 };
572
573 let start_time = std::time::Instant::now();
575 match registry.call_tool(tool_name, params.clone()).await {
576 Ok(result) => {
577 if let Some(stream) = event_stream {
579 stream.tool_completed(
580 tool_name,
581 previous_agent.to_string(),
582 serde_json::json!({
583 "agent": self.config.name,
584 "tool_call_id": tool_call.id,
585 "result": result.output,
586 "duration_ms": (result.duration_ms * 1000.0).round() / 1000.0,
587 }),
588 );
589 }
590
591 serde_json::to_string(&result.output).unwrap_or_else(|_| result.output.to_string())
593 }
594 Err(e) => {
595 let error_msg = format!("Tool execution failed: {}", e);
596 if let Some(stream) = event_stream {
597 stream.tool_failed(
598 tool_name,
599 previous_agent.to_string(),
600 &error_msg,
601 serde_json::json!({
602 "agent": self.config.name,
603 "tool_call_id": tool_call.id,
604 "duration_ms": start_time.elapsed().as_secs_f64() * 1000.0,
605 }),
606 );
607 }
608 format!("Error: {}", error_msg)
609 }
610 }
611 }
612}