Skip to main content

agent_runtime/
agent.rs

1use crate::event::{EventStream, EventType};
2use crate::llm::types::ToolCall;
3use crate::llm::{ChatClient, ChatMessage, ChatRequest};
4use crate::tool::ToolRegistry;
5use crate::tool_loop_detection::{ToolCallTracker, ToolLoopDetectionConfig};
6use crate::types::{AgentError, AgentInput, AgentOutput, AgentOutputMetadata, AgentResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11#[cfg(test)]
12#[path = "agent_test.rs"]
13mod agent_test;
14
15/// Agent configuration
16#[derive(Clone, Serialize, Deserialize)]
17pub struct AgentConfig {
18    pub name: String,
19    pub system_prompt: String,
20
21    #[serde(skip)]
22    pub tools: Option<Arc<ToolRegistry>>,
23
24    pub max_tool_iterations: usize,
25
26    /// Tool loop detection configuration
27    #[serde(skip)]
28    pub tool_loop_detection: Option<ToolLoopDetectionConfig>,
29}
30
31impl std::fmt::Debug for AgentConfig {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("AgentConfig")
34            .field("name", &self.name)
35            .field("system_prompt", &self.system_prompt)
36            .field(
37                "tools",
38                &self.tools.as_ref().map(|t| format!("{} tools", t.len())),
39            )
40            .field("max_tool_iterations", &self.max_tool_iterations)
41            .field(
42                "tool_loop_detection",
43                &self.tool_loop_detection.as_ref().map(|c| c.enabled),
44            )
45            .finish()
46    }
47}
48
49impl AgentConfig {
50    pub fn builder(name: impl Into<String>) -> AgentConfigBuilder {
51        AgentConfigBuilder {
52            name: name.into(),
53            system_prompt: String::new(),
54            tools: None,
55            max_tool_iterations: 10,
56            tool_loop_detection: Some(ToolLoopDetectionConfig::default()),
57        }
58    }
59}
60
61/// Builder for AgentConfig
62pub struct AgentConfigBuilder {
63    name: String,
64    system_prompt: String,
65    tools: Option<Arc<ToolRegistry>>,
66    max_tool_iterations: usize,
67    tool_loop_detection: Option<ToolLoopDetectionConfig>,
68}
69
70impl AgentConfigBuilder {
71    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
72        self.system_prompt = prompt.into();
73        self
74    }
75
76    pub fn tools(mut self, tools: Arc<ToolRegistry>) -> Self {
77        self.tools = Some(tools);
78        self
79    }
80
81    pub fn max_tool_iterations(mut self, max: usize) -> Self {
82        self.max_tool_iterations = max;
83        self
84    }
85
86    pub fn tool_loop_detection(mut self, config: ToolLoopDetectionConfig) -> Self {
87        self.tool_loop_detection = Some(config);
88        self
89    }
90
91    pub fn disable_tool_loop_detection(mut self) -> Self {
92        self.tool_loop_detection = None;
93        self
94    }
95
96    pub fn build(self) -> AgentConfig {
97        AgentConfig {
98            name: self.name,
99            system_prompt: self.system_prompt,
100            tools: self.tools,
101            max_tool_iterations: self.max_tool_iterations,
102            tool_loop_detection: self.tool_loop_detection,
103        }
104    }
105}
106
107/// Agent execution unit
108pub struct Agent {
109    config: AgentConfig,
110    llm_client: Option<Arc<dyn ChatClient>>,
111}
112
113impl Agent {
114    pub fn new(config: AgentConfig) -> Self {
115        Self {
116            config,
117            llm_client: None,
118        }
119    }
120
121    pub fn with_llm_client(mut self, client: Arc<dyn ChatClient>) -> Self {
122        self.llm_client = Some(client);
123        self
124    }
125
126    pub fn name(&self) -> &str {
127        &self.config.name
128    }
129
130    pub fn config(&self) -> &AgentConfig {
131        &self.config
132    }
133
134    /// Execute the agent with the given input
135    pub async fn execute(&self, input: &AgentInput) -> AgentResult {
136        self.execute_with_events(input.clone(), None).await
137    }
138
139    /// Execute the agent with event stream for observability
140    pub async fn execute_with_events(
141        &self,
142        input: AgentInput,
143        event_stream: Option<&EventStream>,
144    ) -> AgentResult {
145        let start = std::time::Instant::now();
146
147        // Emit agent processing event
148        if let Some(stream) = event_stream {
149            stream.append(
150                EventType::AgentProcessing,
151                input
152                    .metadata
153                    .previous_agent
154                    .clone()
155                    .unwrap_or_else(|| "workflow".to_string()),
156                serde_json::json!({
157                    "agent": self.config.name,
158                    "input": input.data,
159                }),
160            );
161        }
162
163        // If we have an LLM client, use it
164        if let Some(client) = &self.llm_client {
165            // Convert input to user message
166            let user_message = if let Some(s) = input.data.as_str() {
167                s.to_string()
168            } else {
169                serde_json::to_string_pretty(&input.data).unwrap_or_default()
170            };
171
172            // Build messages with system prompt
173            let mut messages = vec![ChatMessage::system(&self.config.system_prompt)];
174            messages.push(ChatMessage::user(&user_message));
175
176            let mut request = ChatRequest::new(messages.clone())
177                .with_temperature(0.7)
178                .with_max_tokens(8192);
179
180            // Get tool schemas if available
181            let tool_schemas = self
182                .config
183                .tools
184                .as_ref()
185                .map(|registry| registry.list_tools())
186                .filter(|tools| !tools.is_empty());
187
188            // Tool calling loop
189            let mut iteration = 0;
190            let mut total_tool_calls = 0;
191
192            // Initialize tool call tracker for loop detection
193            let mut tool_tracker = if self.config.tool_loop_detection.is_some() {
194                Some(ToolCallTracker::new())
195            } else {
196                None
197            };
198
199            loop {
200                iteration += 1;
201
202                // Check iteration limit
203                if iteration > self.config.max_tool_iterations {
204                    return Err(AgentError::ExecutionError(format!(
205                        "Maximum tool iterations ({}) exceeded",
206                        self.config.max_tool_iterations
207                    )));
208                }
209
210                // Add tools to request if available
211                if let Some(ref schemas) = tool_schemas {
212                    request = request.with_tools(schemas.clone());
213                }
214
215                // Emit LLM request started event
216                if let Some(stream) = event_stream {
217                    stream.append(
218                        EventType::AgentLlmRequestStarted,
219                        input
220                            .metadata
221                            .previous_agent
222                            .clone()
223                            .unwrap_or_else(|| "workflow".to_string()),
224                        serde_json::json!({
225                            "agent": self.config.name,
226                            "iteration": iteration,
227                        }),
228                    );
229                }
230
231                // Call LLM with streaming + full response (for tool calls)
232                let event_stream_for_streaming = event_stream.cloned();
233                let agent_name = self.config.name.clone();
234                let previous_agent = input
235                    .metadata
236                    .previous_agent
237                    .clone()
238                    .unwrap_or_else(|| "workflow".to_string());
239
240                // Create channel for streaming chunks
241                let (chunk_tx, mut chunk_rx) = tokio::sync::mpsc::channel(100);
242
243                // Spawn task to receive chunks and emit events
244                let _chunk_event_task = tokio::spawn(async move {
245                    while let Some(chunk) = chunk_rx.recv().await {
246                        if let Some(stream) = &event_stream_for_streaming {
247                            stream.append(
248                                EventType::AgentLlmStreamChunk,
249                                previous_agent.clone(),
250                                serde_json::json!({
251                                    "agent": &agent_name,
252                                    "chunk": chunk,
253                                }),
254                            );
255                        }
256                    }
257                });
258
259                match client.chat_stream(request.clone(), chunk_tx).await {
260                    Ok(response) => {
261                        // Emit LLM request completed event
262                        if let Some(stream) = event_stream {
263                            stream.append(
264                                EventType::AgentLlmRequestCompleted,
265                                input
266                                    .metadata
267                                    .previous_agent
268                                    .clone()
269                                    .unwrap_or_else(|| "workflow".to_string()),
270                                serde_json::json!({
271                                    "agent": self.config.name,
272                                }),
273                            );
274                        }
275
276                        // Check if we have tool calls (and they're not empty)
277                        if let Some(tool_calls) = response.tool_calls.clone() {
278                            if tool_calls.is_empty() {
279                                // Empty tool calls array - treat as final response
280                            } else {
281                                total_tool_calls += tool_calls.len();
282
283                                // Add assistant message with tool calls to conversation
284                                let assistant_msg = ChatMessage::assistant_with_tool_calls(
285                                    response.content.clone(),
286                                    tool_calls.clone(),
287                                );
288                                request.messages.push(assistant_msg);
289
290                                // Execute each tool call
291                                for tool_call in tool_calls {
292                                    // Check for duplicate tool call (loop detection)
293                                    if let (Some(tracker), Some(loop_config)) =
294                                        (&tool_tracker, &self.config.tool_loop_detection)
295                                    {
296                                        if loop_config.enabled {
297                                            // Parse tool arguments from JSON string
298                                            let args_value: serde_json::Value =
299                                                serde_json::from_str(&tool_call.function.arguments)
300                                                    .unwrap_or(serde_json::json!({}));
301
302                                            // Convert to HashMap for comparison
303                                            let args_map: HashMap<String, serde_json::Value> =
304                                                args_value
305                                                    .as_object()
306                                                    .map(|obj| {
307                                                        obj.iter()
308                                                            .map(|(k, v)| (k.clone(), v.clone()))
309                                                            .collect()
310                                                    })
311                                                    .unwrap_or_default();
312
313                                            if let Some(previous_result) = tracker
314                                                .check_for_loop(&tool_call.function.name, &args_map)
315                                            {
316                                                // Loop detected! Inject message instead of calling tool
317                                                let loop_message = loop_config.get_message(
318                                                    &tool_call.function.name,
319                                                    &previous_result,
320                                                );
321
322                                                // Emit loop detected event
323                                                if let Some(stream) = event_stream {
324                                                    stream.append(
325                                                        EventType::AgentToolLoopDetected,
326                                                        input
327                                                            .metadata
328                                                            .previous_agent
329                                                            .clone()
330                                                            .unwrap_or_else(|| {
331                                                                "workflow".to_string()
332                                                            }),
333                                                        serde_json::json!({
334                                                            "agent": self.config.name,
335                                                            "tool": tool_call.function.name,
336                                                            "message": loop_message,
337                                                        }),
338                                                    );
339                                                }
340
341                                                // Add system message explaining the loop
342                                                let tool_msg = ChatMessage::tool_result(
343                                                    &tool_call.id,
344                                                    &loop_message,
345                                                );
346                                                request.messages.push(tool_msg);
347
348                                                // Skip actual tool execution
349                                                continue;
350                                            }
351                                        }
352                                    }
353
354                                    // No loop detected - execute the tool normally
355                                    let tool_result = self
356                                        .execute_tool_call(
357                                            &tool_call,
358                                            &input
359                                                .metadata
360                                                .previous_agent
361                                                .clone()
362                                                .unwrap_or_else(|| "workflow".to_string()),
363                                            event_stream,
364                                        )
365                                        .await;
366
367                                    // Record this call in the tracker
368                                    if let Some(tracker) = &mut tool_tracker {
369                                        // Parse tool arguments from JSON string
370                                        let args_value: serde_json::Value =
371                                            serde_json::from_str(&tool_call.function.arguments)
372                                                .unwrap_or(serde_json::json!({}));
373
374                                        // Convert to HashMap
375                                        let args_map: HashMap<String, serde_json::Value> =
376                                            args_value
377                                                .as_object()
378                                                .map(|obj| {
379                                                    obj.iter()
380                                                        .map(|(k, v)| (k.clone(), v.clone()))
381                                                        .collect()
382                                                })
383                                                .unwrap_or_default();
384
385                                        let result_json = serde_json::to_value(&tool_result)
386                                            .unwrap_or(serde_json::json!({}));
387                                        tracker.record_call(
388                                            &tool_call.function.name,
389                                            &args_map,
390                                            &result_json,
391                                        );
392                                    }
393
394                                    // Add tool result to conversation
395                                    let tool_msg =
396                                        ChatMessage::tool_result(&tool_call.id, &tool_result);
397                                    request.messages.push(tool_msg);
398                                }
399
400                                // Continue loop to get next response
401                                continue;
402                            }
403                        }
404
405                        // No tool calls (or empty array), we have the final response
406                        let response_text = response.content.trim();
407                        let token_count = response
408                            .usage
409                            .map(|u| u.total_tokens)
410                            .unwrap_or_else(|| (response_text.len() as f32 / 4.0).ceil() as u32);
411
412                        let output_data = serde_json::json!({
413                            "response": response_text,
414                            "content_type": "text/plain",
415                            "token_count": token_count,
416                        });
417
418                        // Emit agent completed event
419                        if let Some(stream) = event_stream {
420                            stream.append(
421                                EventType::AgentCompleted,
422                                input
423                                    .metadata
424                                    .previous_agent
425                                    .clone()
426                                    .unwrap_or_else(|| "workflow".to_string()),
427                                serde_json::json!({
428                                    "agent": self.config.name,
429                                    "execution_time_ms": start.elapsed().as_millis() as u64,
430                                }),
431                            );
432                        }
433
434                        return Ok(AgentOutput {
435                            data: output_data,
436                            metadata: AgentOutputMetadata {
437                                agent_name: self.config.name.clone(),
438                                execution_time_ms: start.elapsed().as_millis() as u64,
439                                tool_calls_count: total_tool_calls,
440                            },
441                        });
442                    }
443                    Err(e) => {
444                        // Emit LLM request failed event
445                        if let Some(stream) = event_stream {
446                            stream.append(
447                                EventType::AgentLlmRequestFailed,
448                                input
449                                    .metadata
450                                    .previous_agent
451                                    .clone()
452                                    .unwrap_or_else(|| "workflow".to_string()),
453                                serde_json::json!({
454                                    "agent": self.config.name,
455                                    "error": e.to_string(),
456                                }),
457                            );
458                        }
459
460                        // Emit agent failed event
461                        if let Some(stream) = event_stream {
462                            stream.append(
463                                EventType::AgentFailed,
464                                input
465                                    .metadata
466                                    .previous_agent
467                                    .clone()
468                                    .unwrap_or_else(|| "workflow".to_string()),
469                                serde_json::json!({
470                                    "agent": self.config.name,
471                                    "error": e.to_string(),
472                                }),
473                            );
474                        }
475
476                        return Err(AgentError::ExecutionError(format!(
477                            "LLM call failed: {}",
478                            e
479                        )));
480                    }
481                }
482            }
483        } else {
484            // Mock execution fallback
485            let output_data = serde_json::json!({
486                "agent": self.config.name,
487                "processed": input.data,
488                "system_prompt": self.config.system_prompt,
489                "note": "Mock execution - no LLM client configured"
490            });
491
492            if let Some(stream) = event_stream {
493                stream.append(
494                    EventType::AgentCompleted,
495                    input
496                        .metadata
497                        .previous_agent
498                        .clone()
499                        .unwrap_or_else(|| "workflow".to_string()),
500                    serde_json::json!({
501                        "agent": self.config.name,
502                        "execution_time_ms": start.elapsed().as_millis() as u64,
503                        "mock": true,
504                    }),
505                );
506            }
507
508            Ok(AgentOutput {
509                data: output_data,
510                metadata: AgentOutputMetadata {
511                    agent_name: self.config.name.clone(),
512                    execution_time_ms: start.elapsed().as_millis() as u64,
513                    tool_calls_count: 0,
514                },
515            })
516        }
517    }
518
519    /// Execute a single tool call
520    async fn execute_tool_call(
521        &self,
522        tool_call: &ToolCall,
523        previous_agent: &str,
524        event_stream: Option<&EventStream>,
525    ) -> String {
526        let tool_name = &tool_call.function.name;
527
528        // Emit tool call started event
529        if let Some(stream) = event_stream {
530            stream.append(
531                EventType::ToolCallStarted,
532                previous_agent.to_string(),
533                serde_json::json!({
534                    "agent": self.config.name,
535                    "tool": tool_name,
536                    "tool_call_id": tool_call.id,
537                    "arguments": tool_call.function.arguments,
538                }),
539            );
540        }
541
542        // Get the tool registry
543        let registry = match &self.config.tools {
544            Some(reg) => reg,
545            None => {
546                let error_msg = "No tool registry configured".to_string();
547                if let Some(stream) = event_stream {
548                    stream.append(
549                        EventType::ToolCallFailed,
550                        previous_agent.to_string(),
551                        serde_json::json!({
552                            "agent": self.config.name,
553                            "tool": tool_name,
554                            "tool_call_id": tool_call.id,
555                            "arguments": tool_call.function.arguments,
556                            "error": error_msg,
557                            "duration_ms": 0,
558                        }),
559                    );
560                }
561                return format!("Error: {}", error_msg);
562            }
563        };
564
565        // Parse arguments from JSON string
566        let params: HashMap<String, serde_json::Value> =
567            match serde_json::from_str(&tool_call.function.arguments) {
568                Ok(p) => p,
569                Err(e) => {
570                    let error_msg = format!("Failed to parse tool arguments: {}", e);
571                    if let Some(stream) = event_stream {
572                        stream.append(
573                            EventType::ToolCallFailed,
574                            previous_agent.to_string(),
575                            serde_json::json!({
576                                "agent": self.config.name,
577                                "tool": tool_name,
578                                "tool_call_id": tool_call.id,
579                                "arguments": tool_call.function.arguments,
580                                "error": error_msg,
581                                "duration_ms": 0,
582                            }),
583                        );
584                    }
585                    return format!("Error: {}", error_msg);
586                }
587            };
588
589        // Execute the tool
590        let start_time = std::time::Instant::now();
591        match registry.call_tool(tool_name, params.clone()).await {
592            Ok(result) => {
593                // Emit tool call completed event
594                if let Some(stream) = event_stream {
595                    stream.append(
596                        EventType::ToolCallCompleted,
597                        previous_agent.to_string(),
598                        serde_json::json!({
599                            "agent": self.config.name,
600                            "tool": tool_name,
601                            "tool_call_id": tool_call.id,
602                            "arguments": params,
603                            "result": result.output,
604                            "duration_ms": (result.duration_ms * 1000.0).round() / 1000.0,
605                        }),
606                    );
607                }
608
609                // Convert result to string for LLM
610                serde_json::to_string(&result.output).unwrap_or_else(|_| result.output.to_string())
611            }
612            Err(e) => {
613                let error_msg = format!("Tool execution failed: {}", e);
614                if let Some(stream) = event_stream {
615                    stream.append(
616                        EventType::ToolCallFailed,
617                        previous_agent.to_string(),
618                        serde_json::json!({
619                            "agent": self.config.name,
620                            "tool": tool_name,
621                            "tool_call_id": tool_call.id,
622                            "arguments": params,
623                            "error": error_msg,
624                            "duration_ms": start_time.elapsed().as_secs_f64() * 1000.0,
625                        }),
626                    );
627                }
628                format!("Error: {}", error_msg)
629            }
630        }
631    }
632}