agent_runtime/agent.rs
1use crate::event::{EventStream, EventType};
2use crate::llm::types::ToolCall;
3use crate::llm::{ChatClient, ChatMessage, ChatRequest};
4use crate::tool::ToolRegistry;
5use crate::tool_loop_detection::{ToolCallTracker, ToolLoopDetectionConfig};
6use crate::types::{AgentError, AgentInput, AgentOutput, AgentOutputMetadata, AgentResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11#[cfg(test)]
12#[path = "agent_test.rs"]
13mod agent_test;
14
15/// Agent configuration
16#[derive(Clone, Serialize, Deserialize)]
17pub struct AgentConfig {
18 pub name: String,
19 pub system_prompt: String,
20
21 #[serde(skip)]
22 pub tools: Option<Arc<ToolRegistry>>,
23
24 pub max_tool_iterations: usize,
25
26 /// Tool loop detection configuration
27 #[serde(skip)]
28 pub tool_loop_detection: Option<ToolLoopDetectionConfig>,
29}
30
31impl std::fmt::Debug for AgentConfig {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("AgentConfig")
34 .field("name", &self.name)
35 .field("system_prompt", &self.system_prompt)
36 .field(
37 "tools",
38 &self.tools.as_ref().map(|t| format!("{} tools", t.len())),
39 )
40 .field("max_tool_iterations", &self.max_tool_iterations)
41 .field(
42 "tool_loop_detection",
43 &self.tool_loop_detection.as_ref().map(|c| c.enabled),
44 )
45 .finish()
46 }
47}
48
49impl AgentConfig {
50 pub fn builder(name: impl Into<String>) -> AgentConfigBuilder {
51 AgentConfigBuilder {
52 name: name.into(),
53 system_prompt: String::new(),
54 tools: None,
55 max_tool_iterations: 10,
56 tool_loop_detection: Some(ToolLoopDetectionConfig::default()),
57 }
58 }
59}
60
61/// Builder for AgentConfig
62pub struct AgentConfigBuilder {
63 name: String,
64 system_prompt: String,
65 tools: Option<Arc<ToolRegistry>>,
66 max_tool_iterations: usize,
67 tool_loop_detection: Option<ToolLoopDetectionConfig>,
68}
69
70impl AgentConfigBuilder {
71 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
72 self.system_prompt = prompt.into();
73 self
74 }
75
76 pub fn tools(mut self, tools: Arc<ToolRegistry>) -> Self {
77 self.tools = Some(tools);
78 self
79 }
80
81 pub fn max_tool_iterations(mut self, max: usize) -> Self {
82 self.max_tool_iterations = max;
83 self
84 }
85
86 pub fn tool_loop_detection(mut self, config: ToolLoopDetectionConfig) -> Self {
87 self.tool_loop_detection = Some(config);
88 self
89 }
90
91 pub fn disable_tool_loop_detection(mut self) -> Self {
92 self.tool_loop_detection = None;
93 self
94 }
95
96 pub fn build(self) -> AgentConfig {
97 AgentConfig {
98 name: self.name,
99 system_prompt: self.system_prompt,
100 tools: self.tools,
101 max_tool_iterations: self.max_tool_iterations,
102 tool_loop_detection: self.tool_loop_detection,
103 }
104 }
105}
106
107/// Agent execution unit
108pub struct Agent {
109 config: AgentConfig,
110 llm_client: Option<Arc<dyn ChatClient>>,
111}
112
113impl Agent {
114 pub fn new(config: AgentConfig) -> Self {
115 Self {
116 config,
117 llm_client: None,
118 }
119 }
120
121 pub fn with_llm_client(mut self, client: Arc<dyn ChatClient>) -> Self {
122 self.llm_client = Some(client);
123 self
124 }
125
126 pub fn name(&self) -> &str {
127 &self.config.name
128 }
129
130 pub fn config(&self) -> &AgentConfig {
131 &self.config
132 }
133
134 /// Execute the agent with the given input
135 pub async fn execute(&self, input: &AgentInput) -> AgentResult {
136 self.execute_with_events(input.clone(), None).await
137 }
138
139 /// Execute the agent with event stream for observability
140 pub async fn execute_with_events(
141 &self,
142 input: AgentInput,
143 event_stream: Option<&EventStream>,
144 ) -> AgentResult {
145 let start = std::time::Instant::now();
146
147 // Emit agent processing event
148 if let Some(stream) = event_stream {
149 stream.append(
150 EventType::AgentProcessing,
151 input
152 .metadata
153 .previous_agent
154 .clone()
155 .unwrap_or_else(|| "workflow".to_string()),
156 serde_json::json!({
157 "agent": self.config.name,
158 "input": input.data,
159 }),
160 );
161 }
162
163 // If we have an LLM client, use it
164 if let Some(client) = &self.llm_client {
165 // Build messages from chat_history OR from input data
166 let messages = if let Some(history) = &input.chat_history {
167 // Use provided chat history as-is
168 // Outer layer is managing the conversation context
169 history.clone()
170 } else {
171 // Build messages from scratch (legacy behavior)
172 let user_message = if let Some(s) = input.data.as_str() {
173 s.to_string()
174 } else {
175 serde_json::to_string_pretty(&input.data).unwrap_or_default()
176 };
177
178 vec![
179 ChatMessage::system(&self.config.system_prompt),
180 ChatMessage::user(&user_message),
181 ]
182 };
183
184 let mut request = ChatRequest::new(messages.clone())
185 .with_temperature(0.7)
186 .with_max_tokens(8192);
187
188 // Get tool schemas if available
189 let tool_schemas = self
190 .config
191 .tools
192 .as_ref()
193 .map(|registry| registry.list_tools())
194 .filter(|tools| !tools.is_empty());
195
196 // Tool calling loop
197 let mut iteration = 0;
198 let mut total_tool_calls = 0;
199
200 // Initialize tool call tracker for loop detection
201 let mut tool_tracker = if self.config.tool_loop_detection.is_some() {
202 Some(ToolCallTracker::new())
203 } else {
204 None
205 };
206
207 loop {
208 iteration += 1;
209
210 // Check iteration limit
211 if iteration > self.config.max_tool_iterations {
212 return Err(AgentError::ExecutionError(format!(
213 "Maximum tool iterations ({}) exceeded",
214 self.config.max_tool_iterations
215 )));
216 }
217
218 // Add tools to request if available
219 if let Some(ref schemas) = tool_schemas {
220 request = request.with_tools(schemas.clone());
221 }
222
223 // Emit LLM request started event
224 if let Some(stream) = event_stream {
225 stream.append(
226 EventType::AgentLlmRequestStarted,
227 input
228 .metadata
229 .previous_agent
230 .clone()
231 .unwrap_or_else(|| "workflow".to_string()),
232 serde_json::json!({
233 "agent": self.config.name,
234 "iteration": iteration,
235 }),
236 );
237 }
238
239 // Call LLM with streaming + full response (for tool calls)
240 let event_stream_for_streaming = event_stream.cloned();
241 let agent_name = self.config.name.clone();
242 let previous_agent = input
243 .metadata
244 .previous_agent
245 .clone()
246 .unwrap_or_else(|| "workflow".to_string());
247
248 // Create channel for streaming chunks
249 let (chunk_tx, mut chunk_rx) = tokio::sync::mpsc::channel(100);
250
251 // Spawn task to receive chunks and emit events
252 let _chunk_event_task = tokio::spawn(async move {
253 while let Some(chunk) = chunk_rx.recv().await {
254 if let Some(stream) = &event_stream_for_streaming {
255 stream.append(
256 EventType::AgentLlmStreamChunk,
257 previous_agent.clone(),
258 serde_json::json!({
259 "agent": &agent_name,
260 "chunk": chunk,
261 }),
262 );
263 }
264 }
265 });
266
267 match client.chat_stream(request.clone(), chunk_tx).await {
268 Ok(response) => {
269 // Emit LLM request completed event
270 if let Some(stream) = event_stream {
271 stream.append(
272 EventType::AgentLlmRequestCompleted,
273 input
274 .metadata
275 .previous_agent
276 .clone()
277 .unwrap_or_else(|| "workflow".to_string()),
278 serde_json::json!({
279 "agent": self.config.name,
280 }),
281 );
282 }
283
284 // Check if we have tool calls (and they're not empty)
285 if let Some(tool_calls) = response.tool_calls.clone() {
286 if tool_calls.is_empty() {
287 // Empty tool calls array - treat as final response
288 } else {
289 total_tool_calls += tool_calls.len();
290
291 // Add assistant message with tool calls to conversation
292 let assistant_msg = ChatMessage::assistant_with_tool_calls(
293 response.content.clone(),
294 tool_calls.clone(),
295 );
296 request.messages.push(assistant_msg);
297
298 // Execute each tool call
299 for tool_call in tool_calls {
300 // Check for duplicate tool call (loop detection)
301 if let (Some(tracker), Some(loop_config)) =
302 (&tool_tracker, &self.config.tool_loop_detection)
303 {
304 if loop_config.enabled {
305 // Parse tool arguments from JSON string
306 let args_value: serde_json::Value =
307 serde_json::from_str(&tool_call.function.arguments)
308 .unwrap_or(serde_json::json!({}));
309
310 // Convert to HashMap for comparison
311 let args_map: HashMap<String, serde_json::Value> =
312 args_value
313 .as_object()
314 .map(|obj| {
315 obj.iter()
316 .map(|(k, v)| (k.clone(), v.clone()))
317 .collect()
318 })
319 .unwrap_or_default();
320
321 if let Some(previous_result) = tracker
322 .check_for_loop(&tool_call.function.name, &args_map)
323 {
324 // Loop detected! Inject message instead of calling tool
325 let loop_message = loop_config.get_message(
326 &tool_call.function.name,
327 &previous_result,
328 );
329
330 // Emit loop detected event
331 if let Some(stream) = event_stream {
332 stream.append(
333 EventType::AgentToolLoopDetected,
334 input
335 .metadata
336 .previous_agent
337 .clone()
338 .unwrap_or_else(|| {
339 "workflow".to_string()
340 }),
341 serde_json::json!({
342 "agent": self.config.name,
343 "tool": tool_call.function.name,
344 "message": loop_message,
345 }),
346 );
347 }
348
349 // Add system message explaining the loop
350 let tool_msg = ChatMessage::tool_result(
351 &tool_call.id,
352 &loop_message,
353 );
354 request.messages.push(tool_msg);
355
356 // Skip actual tool execution
357 continue;
358 }
359 }
360 }
361
362 // No loop detected - execute the tool normally
363 let tool_result = self
364 .execute_tool_call(
365 &tool_call,
366 &input
367 .metadata
368 .previous_agent
369 .clone()
370 .unwrap_or_else(|| "workflow".to_string()),
371 event_stream,
372 )
373 .await;
374
375 // Record this call in the tracker
376 if let Some(tracker) = &mut tool_tracker {
377 // Parse tool arguments from JSON string
378 let args_value: serde_json::Value =
379 serde_json::from_str(&tool_call.function.arguments)
380 .unwrap_or(serde_json::json!({}));
381
382 // Convert to HashMap
383 let args_map: HashMap<String, serde_json::Value> =
384 args_value
385 .as_object()
386 .map(|obj| {
387 obj.iter()
388 .map(|(k, v)| (k.clone(), v.clone()))
389 .collect()
390 })
391 .unwrap_or_default();
392
393 let result_json = serde_json::to_value(&tool_result)
394 .unwrap_or(serde_json::json!({}));
395 tracker.record_call(
396 &tool_call.function.name,
397 &args_map,
398 &result_json,
399 );
400 }
401
402 // Add tool result to conversation
403 let tool_msg =
404 ChatMessage::tool_result(&tool_call.id, &tool_result);
405 request.messages.push(tool_msg);
406 }
407
408 // Continue loop to get next response
409 continue;
410 }
411 }
412
413 // No tool calls (or empty array), we have the final response
414 let response_text = response.content.trim();
415 let token_count = response
416 .usage
417 .map(|u| u.total_tokens)
418 .unwrap_or_else(|| (response_text.len() as f32 / 4.0).ceil() as u32);
419
420 let output_data = serde_json::json!({
421 "response": response_text,
422 "content_type": "text/plain",
423 "token_count": token_count,
424 });
425
426 // Add final assistant response to chat history
427 request.messages.push(ChatMessage::assistant(response_text));
428
429 // Emit agent completed event
430 if let Some(stream) = event_stream {
431 stream.append(
432 EventType::AgentCompleted,
433 input
434 .metadata
435 .previous_agent
436 .clone()
437 .unwrap_or_else(|| "workflow".to_string()),
438 serde_json::json!({
439 "agent": self.config.name,
440 "execution_time_ms": start.elapsed().as_millis() as u64,
441 }),
442 );
443 }
444
445 return Ok(AgentOutput {
446 data: output_data,
447 metadata: AgentOutputMetadata {
448 agent_name: self.config.name.clone(),
449 execution_time_ms: start.elapsed().as_millis() as u64,
450 tool_calls_count: total_tool_calls,
451 },
452 chat_history: Some(request.messages),
453 });
454 }
455 Err(e) => {
456 // Emit LLM request failed event
457 if let Some(stream) = event_stream {
458 stream.append(
459 EventType::AgentLlmRequestFailed,
460 input
461 .metadata
462 .previous_agent
463 .clone()
464 .unwrap_or_else(|| "workflow".to_string()),
465 serde_json::json!({
466 "agent": self.config.name,
467 "error": e.to_string(),
468 }),
469 );
470 }
471
472 // Emit agent failed event
473 if let Some(stream) = event_stream {
474 stream.append(
475 EventType::AgentFailed,
476 input
477 .metadata
478 .previous_agent
479 .clone()
480 .unwrap_or_else(|| "workflow".to_string()),
481 serde_json::json!({
482 "agent": self.config.name,
483 "error": e.to_string(),
484 }),
485 );
486 }
487
488 return Err(AgentError::ExecutionError(format!(
489 "LLM call failed: {}",
490 e
491 )));
492 }
493 }
494 }
495 } else {
496 // Mock execution fallback
497 let output_data = serde_json::json!({
498 "agent": self.config.name,
499 "processed": input.data,
500 "system_prompt": self.config.system_prompt,
501 "note": "Mock execution - no LLM client configured"
502 });
503
504 if let Some(stream) = event_stream {
505 stream.append(
506 EventType::AgentCompleted,
507 input
508 .metadata
509 .previous_agent
510 .clone()
511 .unwrap_or_else(|| "workflow".to_string()),
512 serde_json::json!({
513 "agent": self.config.name,
514 "execution_time_ms": start.elapsed().as_millis() as u64,
515 "mock": true,
516 }),
517 );
518 }
519
520 Ok(AgentOutput {
521 data: output_data,
522 metadata: AgentOutputMetadata {
523 agent_name: self.config.name.clone(),
524 execution_time_ms: start.elapsed().as_millis() as u64,
525 tool_calls_count: 0,
526 },
527 chat_history: None, // No LLM client means no chat history
528 })
529 }
530 }
531
532 /// Execute a single tool call
533 async fn execute_tool_call(
534 &self,
535 tool_call: &ToolCall,
536 previous_agent: &str,
537 event_stream: Option<&EventStream>,
538 ) -> String {
539 let tool_name = &tool_call.function.name;
540
541 // Emit tool call started event
542 if let Some(stream) = event_stream {
543 stream.append(
544 EventType::ToolCallStarted,
545 previous_agent.to_string(),
546 serde_json::json!({
547 "agent": self.config.name,
548 "tool": tool_name,
549 "tool_call_id": tool_call.id,
550 "arguments": tool_call.function.arguments,
551 }),
552 );
553 }
554
555 // Get the tool registry
556 let registry = match &self.config.tools {
557 Some(reg) => reg,
558 None => {
559 let error_msg = "No tool registry configured".to_string();
560 if let Some(stream) = event_stream {
561 stream.append(
562 EventType::ToolCallFailed,
563 previous_agent.to_string(),
564 serde_json::json!({
565 "agent": self.config.name,
566 "tool": tool_name,
567 "tool_call_id": tool_call.id,
568 "arguments": tool_call.function.arguments,
569 "error": error_msg,
570 "duration_ms": 0,
571 }),
572 );
573 }
574 return format!("Error: {}", error_msg);
575 }
576 };
577
578 // Parse arguments from JSON string
579 let params: HashMap<String, serde_json::Value> =
580 match serde_json::from_str(&tool_call.function.arguments) {
581 Ok(p) => p,
582 Err(e) => {
583 let error_msg = format!("Failed to parse tool arguments: {}", e);
584 if let Some(stream) = event_stream {
585 stream.append(
586 EventType::ToolCallFailed,
587 previous_agent.to_string(),
588 serde_json::json!({
589 "agent": self.config.name,
590 "tool": tool_name,
591 "tool_call_id": tool_call.id,
592 "arguments": tool_call.function.arguments,
593 "error": error_msg,
594 "duration_ms": 0,
595 }),
596 );
597 }
598 return format!("Error: {}", error_msg);
599 }
600 };
601
602 // Execute the tool
603 let start_time = std::time::Instant::now();
604 match registry.call_tool(tool_name, params.clone()).await {
605 Ok(result) => {
606 // Emit tool call completed event
607 if let Some(stream) = event_stream {
608 stream.append(
609 EventType::ToolCallCompleted,
610 previous_agent.to_string(),
611 serde_json::json!({
612 "agent": self.config.name,
613 "tool": tool_name,
614 "tool_call_id": tool_call.id,
615 "arguments": params,
616 "result": result.output,
617 "duration_ms": (result.duration_ms * 1000.0).round() / 1000.0,
618 }),
619 );
620 }
621
622 // Convert result to string for LLM
623 serde_json::to_string(&result.output).unwrap_or_else(|_| result.output.to_string())
624 }
625 Err(e) => {
626 let error_msg = format!("Tool execution failed: {}", e);
627 if let Some(stream) = event_stream {
628 stream.append(
629 EventType::ToolCallFailed,
630 previous_agent.to_string(),
631 serde_json::json!({
632 "agent": self.config.name,
633 "tool": tool_name,
634 "tool_call_id": tool_call.id,
635 "arguments": params,
636 "error": error_msg,
637 "duration_ms": start_time.elapsed().as_secs_f64() * 1000.0,
638 }),
639 );
640 }
641 format!("Error: {}", error_msg)
642 }
643 }
644 }
645}