agent_runtime/agent.rs
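//! Agent execution: `AgentConfig` and its builder, plus the `Agent` type
//! that drives an LLM-backed tool-calling loop with optional event-stream
//! observability and tool loop detection.
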
use crate::event::{EventStream, EventType};
use crate::llm::types::ToolCall;
use crate::llm::{ChatClient, ChatMessage, ChatRequest};
use crate::tool::ToolRegistry;
use crate::tool_loop_detection::{ToolCallTracker, ToolLoopDetectionConfig};
use crate::types::{AgentError, AgentInput, AgentOutput, AgentOutputMetadata, AgentResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;

#[cfg(test)]
#[path = "agent_test.rs"]
mod agent_test;

/// Agent configuration
#[derive(Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    pub name: String,
    pub system_prompt: String,

    #[serde(skip)]
    pub tools: Option<Arc<ToolRegistry>>,

    pub max_tool_iterations: usize,

    /// Tool loop detection configuration
    #[serde(skip)]
    pub tool_loop_detection: Option<ToolLoopDetectionConfig>,
}

impl std::fmt::Debug for AgentConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AgentConfig")
            .field("name", &self.name)
            .field("system_prompt", &self.system_prompt)
            .field(
                "tools",
                &self.tools.as_ref().map(|t| format!("{} tools", t.len())),
            )
            .field("max_tool_iterations", &self.max_tool_iterations)
            .field(
                "tool_loop_detection",
                &self.tool_loop_detection.as_ref().map(|c| c.enabled),
            )
            .finish()
    }
}

impl AgentConfig {
    pub fn builder(name: impl Into<String>) -> AgentConfigBuilder {
        AgentConfigBuilder {
            name: name.into(),
            system_prompt: String::new(),
            tools: None,
            max_tool_iterations: 10,
            tool_loop_detection: Some(ToolLoopDetectionConfig::default()),
        }
    }
}

/// Builder for AgentConfig
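///
/// A minimal construction sketch using only the methods defined below
/// (marked `ignore` because the exact import paths depend on how this
/// module is re-exported):
///
/// ```ignore
/// let config = AgentConfig::builder("researcher")
///     .system_prompt("You answer questions using the available tools.")
///     .max_tool_iterations(5)
///     .tool_loop_detection(ToolLoopDetectionConfig::default())
///     .build();
/// ```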
pub struct AgentConfigBuilder {
    name: String,
    system_prompt: String,
    tools: Option<Arc<ToolRegistry>>,
    max_tool_iterations: usize,
    tool_loop_detection: Option<ToolLoopDetectionConfig>,
}

impl AgentConfigBuilder {
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    pub fn tools(mut self, tools: Arc<ToolRegistry>) -> Self {
        self.tools = Some(tools);
        self
    }

    pub fn max_tool_iterations(mut self, max: usize) -> Self {
        self.max_tool_iterations = max;
        self
    }

    pub fn tool_loop_detection(mut self, config: ToolLoopDetectionConfig) -> Self {
        self.tool_loop_detection = Some(config);
        self
    }

    pub fn disable_tool_loop_detection(mut self) -> Self {
        self.tool_loop_detection = None;
        self
    }

    pub fn build(self) -> AgentConfig {
        AgentConfig {
            name: self.name,
            system_prompt: self.system_prompt,
            tools: self.tools,
            max_tool_iterations: self.max_tool_iterations,
            tool_loop_detection: self.tool_loop_detection,
        }
    }
}

/// Agent execution unit
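///
/// Wraps an `AgentConfig` and an optional `ChatClient`; without a client,
/// `execute` falls back to a mock response. A minimal driving sketch
/// (marked `ignore`; how `AgentInput` and the `ChatClient` implementation
/// are constructed is assumed to live elsewhere in the crate):
///
/// ```ignore
/// async fn run(client: Arc<dyn ChatClient>, input: AgentInput) -> AgentResult {
///     let config = AgentConfig::builder("assistant")
///         .system_prompt("You are a helpful assistant.")
///         .build();
///     let agent = Agent::new(config).with_llm_client(client);
///     agent.execute(&input).await
/// }
/// ```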
pub struct Agent {
    config: AgentConfig,
    llm_client: Option<Arc<dyn ChatClient>>,
}

impl Agent {
    pub fn new(config: AgentConfig) -> Self {
        Self {
            config,
            llm_client: None,
        }
    }

    pub fn with_llm_client(mut self, client: Arc<dyn ChatClient>) -> Self {
        self.llm_client = Some(client);
        self
    }

    pub fn name(&self) -> &str {
        &self.config.name
    }

    pub fn config(&self) -> &AgentConfig {
        &self.config
    }

    /// Execute the agent with the given input
    pub async fn execute(&self, input: &AgentInput) -> AgentResult {
        self.execute_with_events(input.clone(), None).await
    }

    /// Execute the agent with an event stream for observability
    pub async fn execute_with_events(
        &self,
        input: AgentInput,
        event_stream: Option<&EventStream>,
    ) -> AgentResult {
        let start = std::time::Instant::now();

        // Emit agent processing event
        if let Some(stream) = event_stream {
            stream.append(
                EventType::AgentProcessing,
                input
                    .metadata
                    .previous_agent
                    .clone()
                    .unwrap_or_else(|| "workflow".to_string()),
                serde_json::json!({
                    "agent": self.config.name,
                    "input": input.data,
                }),
            );
        }

        // If we have an LLM client, use it
        if let Some(client) = &self.llm_client {
            // Convert input to user message
            let user_message = if let Some(s) = input.data.as_str() {
                s.to_string()
            } else {
                serde_json::to_string_pretty(&input.data).unwrap_or_default()
            };

            // Build messages with system prompt
            let mut messages = vec![ChatMessage::system(&self.config.system_prompt)];
            messages.push(ChatMessage::user(&user_message));

            let mut request = ChatRequest::new(messages.clone())
                .with_temperature(0.7)
                .with_max_tokens(8192);

            // Get tool schemas if available
            let tool_schemas = self
                .config
                .tools
                .as_ref()
                .map(|registry| registry.list_tools())
                .filter(|tools| !tools.is_empty());

            // Tool calling loop
            let mut iteration = 0;
            let mut total_tool_calls = 0;

            // Initialize tool call tracker for loop detection
            let mut tool_tracker = if self.config.tool_loop_detection.is_some() {
                Some(ToolCallTracker::new())
            } else {
                None
            };

            loop {
                iteration += 1;

                // Check iteration limit
                if iteration > self.config.max_tool_iterations {
                    return Err(AgentError::ExecutionError(format!(
                        "Maximum tool iterations ({}) exceeded",
                        self.config.max_tool_iterations
                    )));
                }

                // Add tools to request if available
                if let Some(ref schemas) = tool_schemas {
                    request = request.with_tools(schemas.clone());
                }

                // Emit LLM request started event
                if let Some(stream) = event_stream {
                    stream.append(
                        EventType::AgentLlmRequestStarted,
                        input
                            .metadata
                            .previous_agent
                            .clone()
                            .unwrap_or_else(|| "workflow".to_string()),
                        serde_json::json!({
                            "agent": self.config.name,
                            "iteration": iteration,
                        }),
                    );
                }

                // Call LLM with streaming + full response (for tool calls)
                let event_stream_for_streaming = event_stream.cloned();
                let agent_name = self.config.name.clone();
                let previous_agent = input
                    .metadata
                    .previous_agent
                    .clone()
                    .unwrap_or_else(|| "workflow".to_string());

                // Create channel for streaming chunks
                let (chunk_tx, mut chunk_rx) = tokio::sync::mpsc::channel(100);

                // Spawn task to receive chunks and emit events
                let _chunk_event_task = tokio::spawn(async move {
                    while let Some(chunk) = chunk_rx.recv().await {
                        if let Some(stream) = &event_stream_for_streaming {
                            stream.append(
                                EventType::AgentLlmStreamChunk,
                                previous_agent.clone(),
                                serde_json::json!({
                                    "agent": &agent_name,
                                    "chunk": chunk,
                                }),
                            );
                        }
                    }
                });

                match client.chat_stream(request.clone(), chunk_tx).await {
                    Ok(response) => {
                        // Emit LLM request completed event
                        if let Some(stream) = event_stream {
                            stream.append(
                                EventType::AgentLlmRequestCompleted,
                                input
                                    .metadata
                                    .previous_agent
                                    .clone()
                                    .unwrap_or_else(|| "workflow".to_string()),
                                serde_json::json!({
                                    "agent": self.config.name,
                                }),
                            );
                        }

                        // Check if we have tool calls (and they're not empty)
                        if let Some(tool_calls) = response.tool_calls.clone() {
                            if tool_calls.is_empty() {
                                // Empty tool calls array - treat as final response
                            } else {
                                total_tool_calls += tool_calls.len();

                                // Add assistant message with tool calls to conversation
                                let assistant_msg = ChatMessage::assistant_with_tool_calls(
                                    response.content.clone(),
                                    tool_calls.clone(),
                                );
                                request.messages.push(assistant_msg);

                                // Execute each tool call
                                for tool_call in tool_calls {
                                    // Check for duplicate tool call (loop detection)
                                    if let (Some(tracker), Some(loop_config)) =
                                        (&tool_tracker, &self.config.tool_loop_detection)
                                    {
                                        if loop_config.enabled {
                                            // Parse tool arguments from JSON string
                                            let args_value: serde_json::Value =
                                                serde_json::from_str(&tool_call.function.arguments)
                                                    .unwrap_or(serde_json::json!({}));

                                            // Convert to HashMap for comparison
                                            let args_map: HashMap<String, serde_json::Value> =
                                                args_value
                                                    .as_object()
                                                    .map(|obj| {
                                                        obj.iter()
                                                            .map(|(k, v)| (k.clone(), v.clone()))
                                                            .collect()
                                                    })
                                                    .unwrap_or_default();

                                            if let Some(previous_result) = tracker
                                                .check_for_loop(&tool_call.function.name, &args_map)
                                            {
                                                // Loop detected! Inject message instead of calling tool
                                                let loop_message = loop_config.get_message(
                                                    &tool_call.function.name,
                                                    &previous_result,
                                                );

                                                // Emit loop detected event
                                                if let Some(stream) = event_stream {
                                                    stream.append(
                                                        EventType::AgentToolLoopDetected,
                                                        input
                                                            .metadata
                                                            .previous_agent
                                                            .clone()
                                                            .unwrap_or_else(|| {
                                                                "workflow".to_string()
                                                            }),
                                                        serde_json::json!({
                                                            "agent": self.config.name,
                                                            "tool": tool_call.function.name,
                                                            "message": loop_message,
                                                        }),
                                                    );
                                                }

                                                // Add a tool-result message explaining the loop
                                                let tool_msg = ChatMessage::tool_result(
                                                    &tool_call.id,
                                                    &loop_message,
                                                );
                                                request.messages.push(tool_msg);

                                                // Skip actual tool execution
                                                continue;
                                            }
                                        }
                                    }

                                    // No loop detected - execute the tool normally
                                    let tool_result = self
                                        .execute_tool_call(
                                            &tool_call,
                                            &input
                                                .metadata
                                                .previous_agent
                                                .clone()
                                                .unwrap_or_else(|| "workflow".to_string()),
                                            event_stream,
                                        )
                                        .await;

                                    // Record this call in the tracker
                                    if let Some(tracker) = &mut tool_tracker {
                                        // Parse tool arguments from JSON string
                                        let args_value: serde_json::Value =
                                            serde_json::from_str(&tool_call.function.arguments)
                                                .unwrap_or(serde_json::json!({}));

                                        // Convert to HashMap
                                        let args_map: HashMap<String, serde_json::Value> =
                                            args_value
                                                .as_object()
                                                .map(|obj| {
                                                    obj.iter()
                                                        .map(|(k, v)| (k.clone(), v.clone()))
                                                        .collect()
                                                })
                                                .unwrap_or_default();

                                        let result_json = serde_json::to_value(&tool_result)
                                            .unwrap_or(serde_json::json!({}));
                                        tracker.record_call(
                                            &tool_call.function.name,
                                            &args_map,
                                            &result_json,
                                        );
                                    }

                                    // Add tool result to conversation
                                    let tool_msg =
                                        ChatMessage::tool_result(&tool_call.id, &tool_result);
                                    request.messages.push(tool_msg);
                                }

                                // Continue loop to get next response
                                continue;
                            }
                        }

                        // No tool calls (or empty array), we have the final response
                        let response_text = response.content.trim();
                        let token_count = response
                            .usage
                            .map(|u| u.total_tokens)
                            .unwrap_or_else(|| (response_text.len() as f32 / 4.0).ceil() as u32);

                        let output_data = serde_json::json!({
                            "response": response_text,
                            "content_type": "text/plain",
                            "token_count": token_count,
                        });

                        // Emit agent completed event
                        if let Some(stream) = event_stream {
                            stream.append(
                                EventType::AgentCompleted,
                                input
                                    .metadata
                                    .previous_agent
                                    .clone()
                                    .unwrap_or_else(|| "workflow".to_string()),
                                serde_json::json!({
                                    "agent": self.config.name,
                                    "execution_time_ms": start.elapsed().as_millis() as u64,
                                }),
                            );
                        }

                        return Ok(AgentOutput {
                            data: output_data,
                            metadata: AgentOutputMetadata {
                                agent_name: self.config.name.clone(),
                                execution_time_ms: start.elapsed().as_millis() as u64,
                                tool_calls_count: total_tool_calls,
                            },
                        });
                    }
                    Err(e) => {
                        // Emit LLM request failed event
                        if let Some(stream) = event_stream {
                            stream.append(
                                EventType::AgentLlmRequestFailed,
                                input
                                    .metadata
                                    .previous_agent
                                    .clone()
                                    .unwrap_or_else(|| "workflow".to_string()),
                                serde_json::json!({
                                    "agent": self.config.name,
                                    "error": e.to_string(),
                                }),
                            );
                        }

                        // Emit agent failed event
                        if let Some(stream) = event_stream {
                            stream.append(
                                EventType::AgentFailed,
                                input
                                    .metadata
                                    .previous_agent
                                    .clone()
                                    .unwrap_or_else(|| "workflow".to_string()),
                                serde_json::json!({
                                    "agent": self.config.name,
                                    "error": e.to_string(),
                                }),
                            );
                        }

                        return Err(AgentError::ExecutionError(format!(
                            "LLM call failed: {}",
                            e
                        )));
                    }
                }
            }
        } else {
            // Mock execution fallback
            let output_data = serde_json::json!({
                "agent": self.config.name,
                "processed": input.data,
                "system_prompt": self.config.system_prompt,
                "note": "Mock execution - no LLM client configured"
            });

            if let Some(stream) = event_stream {
                stream.append(
                    EventType::AgentCompleted,
                    input
                        .metadata
                        .previous_agent
                        .clone()
                        .unwrap_or_else(|| "workflow".to_string()),
                    serde_json::json!({
                        "agent": self.config.name,
                        "execution_time_ms": start.elapsed().as_millis() as u64,
                        "mock": true,
                    }),
                );
            }

            Ok(AgentOutput {
                data: output_data,
                metadata: AgentOutputMetadata {
                    agent_name: self.config.name.clone(),
                    execution_time_ms: start.elapsed().as_millis() as u64,
                    tool_calls_count: 0,
                },
            })
        }
    }

    /// Execute a single tool call
    async fn execute_tool_call(
        &self,
        tool_call: &ToolCall,
        previous_agent: &str,
        event_stream: Option<&EventStream>,
    ) -> String {
        let tool_name = &tool_call.function.name;

        // Emit tool call started event
        if let Some(stream) = event_stream {
            stream.append(
                EventType::ToolCallStarted,
                previous_agent.to_string(),
                serde_json::json!({
                    "agent": self.config.name,
                    "tool": tool_name,
                    "tool_call_id": tool_call.id,
                    "arguments": tool_call.function.arguments,
                }),
            );
        }

        // Get the tool registry
        let registry = match &self.config.tools {
            Some(reg) => reg,
            None => {
                let error_msg = "No tool registry configured".to_string();
                if let Some(stream) = event_stream {
                    stream.append(
                        EventType::ToolCallFailed,
                        previous_agent.to_string(),
                        serde_json::json!({
                            "agent": self.config.name,
                            "tool": tool_name,
                            "tool_call_id": tool_call.id,
                            "arguments": tool_call.function.arguments,
                            "error": error_msg,
                            "duration_ms": 0,
                        }),
                    );
                }
                return format!("Error: {}", error_msg);
            }
        };

        // Parse arguments from JSON string
        let params: HashMap<String, serde_json::Value> =
            match serde_json::from_str(&tool_call.function.arguments) {
                Ok(p) => p,
                Err(e) => {
                    let error_msg = format!("Failed to parse tool arguments: {}", e);
                    if let Some(stream) = event_stream {
                        stream.append(
                            EventType::ToolCallFailed,
                            previous_agent.to_string(),
                            serde_json::json!({
                                "agent": self.config.name,
                                "tool": tool_name,
                                "tool_call_id": tool_call.id,
                                "arguments": tool_call.function.arguments,
                                "error": error_msg,
                                "duration_ms": 0,
                            }),
                        );
                    }
                    return format!("Error: {}", error_msg);
                }
            };

        // Execute the tool
        let start_time = std::time::Instant::now();
        match registry.call_tool(tool_name, params.clone()).await {
            Ok(result) => {
                // Emit tool call completed event
                if let Some(stream) = event_stream {
                    stream.append(
                        EventType::ToolCallCompleted,
                        previous_agent.to_string(),
                        serde_json::json!({
                            "agent": self.config.name,
                            "tool": tool_name,
                            "tool_call_id": tool_call.id,
                            "arguments": params,
                            "result": result.output,
                            "duration_ms": (result.duration_ms * 1000.0).round() / 1000.0,
                        }),
                    );
                }

                // Convert result to string for LLM
                serde_json::to_string(&result.output).unwrap_or_else(|_| result.output.to_string())
            }
            Err(e) => {
                let error_msg = format!("Tool execution failed: {}", e);
                if let Some(stream) = event_stream {
                    stream.append(
                        EventType::ToolCallFailed,
                        previous_agent.to_string(),
                        serde_json::json!({
                            "agent": self.config.name,
                            "tool": tool_name,
                            "tool_call_id": tool_call.id,
                            "arguments": params,
                            "error": error_msg,
                            "duration_ms": start_time.elapsed().as_secs_f64() * 1000.0,
                        }),
                    );
                }
                format!("Error: {}", error_msg)
            }
        }
    }
}