Skip to main content

agent_runtime/
agent.rs

1use crate::event::EventStream;
2use crate::llm::types::ToolCall;
3use crate::llm::{ChatClient, ChatMessage, ChatRequest};
4use crate::tool::ToolRegistry;
5use crate::tool_loop_detection::{ToolCallTracker, ToolLoopDetectionConfig};
6use crate::types::{AgentError, AgentInput, AgentOutput, AgentOutputMetadata, AgentResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11#[cfg(test)]
12#[path = "agent_test.rs"]
13mod agent_test;
14
15/// Agent configuration
16#[derive(Clone, Serialize, Deserialize)]
17pub struct AgentConfig {
18    pub name: String,
19    pub system_prompt: String,
20
21    #[serde(skip)]
22    pub tools: Option<Arc<ToolRegistry>>,
23
24    pub max_tool_iterations: usize,
25
26    /// Tool loop detection configuration
27    #[serde(skip)]
28    pub tool_loop_detection: Option<ToolLoopDetectionConfig>,
29}
30
31impl std::fmt::Debug for AgentConfig {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("AgentConfig")
34            .field("name", &self.name)
35            .field("system_prompt", &self.system_prompt)
36            .field(
37                "tools",
38                &self.tools.as_ref().map(|t| format!("{} tools", t.len())),
39            )
40            .field("max_tool_iterations", &self.max_tool_iterations)
41            .field(
42                "tool_loop_detection",
43                &self.tool_loop_detection.as_ref().map(|c| c.enabled),
44            )
45            .finish()
46    }
47}
48
49impl AgentConfig {
50    pub fn builder(name: impl Into<String>) -> AgentConfigBuilder {
51        AgentConfigBuilder {
52            name: name.into(),
53            system_prompt: String::new(),
54            tools: None,
55            max_tool_iterations: 10,
56            tool_loop_detection: Some(ToolLoopDetectionConfig::default()),
57        }
58    }
59}
60
61/// Builder for AgentConfig
62pub struct AgentConfigBuilder {
63    name: String,
64    system_prompt: String,
65    tools: Option<Arc<ToolRegistry>>,
66    max_tool_iterations: usize,
67    tool_loop_detection: Option<ToolLoopDetectionConfig>,
68}
69
70impl AgentConfigBuilder {
71    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
72        self.system_prompt = prompt.into();
73        self
74    }
75
76    pub fn tools(mut self, tools: Arc<ToolRegistry>) -> Self {
77        self.tools = Some(tools);
78        self
79    }
80
81    pub fn max_tool_iterations(mut self, max: usize) -> Self {
82        self.max_tool_iterations = max;
83        self
84    }
85
86    pub fn tool_loop_detection(mut self, config: ToolLoopDetectionConfig) -> Self {
87        self.tool_loop_detection = Some(config);
88        self
89    }
90
91    pub fn disable_tool_loop_detection(mut self) -> Self {
92        self.tool_loop_detection = None;
93        self
94    }
95
96    pub fn build(self) -> AgentConfig {
97        AgentConfig {
98            name: self.name,
99            system_prompt: self.system_prompt,
100            tools: self.tools,
101            max_tool_iterations: self.max_tool_iterations,
102            tool_loop_detection: self.tool_loop_detection,
103        }
104    }
105}
106
107/// Agent execution unit
108pub struct Agent {
109    config: AgentConfig,
110    llm_client: Option<Arc<dyn ChatClient>>,
111}
112
113impl Agent {
114    pub fn new(config: AgentConfig) -> Self {
115        Self {
116            config,
117            llm_client: None,
118        }
119    }
120
121    pub fn with_llm_client(mut self, client: Arc<dyn ChatClient>) -> Self {
122        self.llm_client = Some(client);
123        self
124    }
125
126    pub fn name(&self) -> &str {
127        &self.config.name
128    }
129
130    pub fn config(&self) -> &AgentConfig {
131        &self.config
132    }
133
134    /// Execute the agent with the given input
135    pub async fn execute(&self, input: &AgentInput) -> AgentResult {
136        self.execute_with_events(input.clone(), None).await
137    }
138
139    /// Execute the agent with event stream for observability
140    pub async fn execute_with_events(
141        &self,
142        input: AgentInput,
143        event_stream: Option<&EventStream>,
144    ) -> AgentResult {
145        let start = std::time::Instant::now();
146
147        let workflow_id = input
148            .metadata
149            .previous_agent
150            .clone()
151            .unwrap_or_else(|| "workflow".to_string());
152
153        // Emit Agent::Started event
154        if let Some(stream) = event_stream {
155            stream.agent_started(
156                &self.config.name,
157                workflow_id.clone(),
158                serde_json::json!({
159                    "input": input.data,
160                }),
161            );
162        }
163
164        // If we have an LLM client, use it
165        if let Some(client) = &self.llm_client {
166            // Build messages from chat_history OR from input data
167            let messages = if let Some(history) = &input.chat_history {
168                // Use provided chat history as-is
169                // Outer layer is managing the conversation context
170                history.clone()
171            } else {
172                // Build messages from scratch (legacy behavior)
173                let user_message = if let Some(s) = input.data.as_str() {
174                    s.to_string()
175                } else {
176                    serde_json::to_string_pretty(&input.data).unwrap_or_default()
177                };
178
179                vec![
180                    ChatMessage::system(&self.config.system_prompt),
181                    ChatMessage::user(&user_message),
182                ]
183            };
184
185            let mut request = ChatRequest::new(messages.clone())
186                .with_temperature(0.7)
187                .with_max_tokens(8192);
188
189            // Get tool schemas if available
190            let tool_schemas = self
191                .config
192                .tools
193                .as_ref()
194                .map(|registry| registry.list_tools())
195                .filter(|tools| !tools.is_empty());
196
197            // Tool calling loop
198            let mut iteration = 0;
199            let mut total_tool_calls = 0;
200
201            // Initialize tool call tracker for loop detection
202            let mut tool_tracker = if self.config.tool_loop_detection.is_some() {
203                Some(ToolCallTracker::new())
204            } else {
205                None
206            };
207
208            loop {
209                iteration += 1;
210
211                // Check iteration limit
212                if iteration > self.config.max_tool_iterations {
213                    return Err(AgentError::ExecutionError(format!(
214                        "Maximum tool iterations ({}) exceeded",
215                        self.config.max_tool_iterations
216                    )));
217                }
218
219                // Add tools to request if available
220                if let Some(ref schemas) = tool_schemas {
221                    request = request.with_tools(schemas.clone());
222                }
223
224                // Emit LlmRequest::Started event
225                if let Some(stream) = event_stream {
226                    stream.llm_started(
227                        &self.config.name,
228                        iteration,
229                        workflow_id.clone(),
230                        serde_json::json!({
231                            "messages": request.messages.len(),
232                        }),
233                    );
234                }
235
236                // Call LLM with streaming + full response (for tool calls)
237                let event_stream_for_streaming = event_stream.cloned();
238                let agent_name = self.config.name.clone();
239                let workflow_id_for_streaming = workflow_id.clone();
240
241                // Create channel for streaming chunks
242                let (chunk_tx, mut chunk_rx) = tokio::sync::mpsc::channel(100);
243
244                // Spawn task to receive chunks and emit events
245                let _chunk_event_task = tokio::spawn(async move {
246                    while let Some(chunk) = chunk_rx.recv().await {
247                        if let Some(stream) = &event_stream_for_streaming {
248                            stream.llm_progress(
249                                &agent_name,
250                                iteration,
251                                workflow_id_for_streaming.clone(),
252                                chunk,
253                            );
254                        }
255                    }
256                });
257
258                match client.chat_stream(request.clone(), chunk_tx).await {
259                    Ok(response) => {
260                        // Emit LlmRequest::Completed event
261                        if let Some(stream) = event_stream {
262                            stream.llm_completed(
263                                &self.config.name,
264                                iteration,
265                                workflow_id.clone(),
266                                serde_json::json!({
267                                    "content": response.content.chars().take(100).collect::<String>(),
268                                    "has_tool_calls": response.tool_calls.is_some(),
269                                }),
270                            );
271                        }
272
273                        // Check if we have tool calls (and they're not empty)
274                        if let Some(tool_calls) = response.tool_calls.clone() {
275                            if tool_calls.is_empty() {
276                                // Empty tool calls array - treat as final response
277                            } else {
278                                total_tool_calls += tool_calls.len();
279
280                                // Add assistant message with tool calls to conversation
281                                let assistant_msg = ChatMessage::assistant_with_tool_calls(
282                                    response.content.clone(),
283                                    tool_calls.clone(),
284                                );
285                                request.messages.push(assistant_msg);
286
287                                // Execute each tool call
288                                for tool_call in tool_calls {
289                                    // Check for duplicate tool call (loop detection)
290                                    if let (Some(tracker), Some(loop_config)) =
291                                        (&tool_tracker, &self.config.tool_loop_detection)
292                                    {
293                                        if loop_config.enabled {
294                                            // Parse tool arguments from JSON string
295                                            let args_value: serde_json::Value =
296                                                serde_json::from_str(&tool_call.function.arguments)
297                                                    .unwrap_or(serde_json::json!({}));
298
299                                            // Convert to HashMap for comparison
300                                            let args_map: HashMap<String, serde_json::Value> =
301                                                args_value
302                                                    .as_object()
303                                                    .map(|obj| {
304                                                        obj.iter()
305                                                            .map(|(k, v)| (k.clone(), v.clone()))
306                                                            .collect()
307                                                    })
308                                                    .unwrap_or_default();
309
310                                            if let Some(previous_result) = tracker
311                                                .check_for_loop(&tool_call.function.name, &args_map)
312                                            {
313                                                // Loop detected! Inject message instead of calling tool
314                                                let loop_message = loop_config.get_message(
315                                                    &tool_call.function.name,
316                                                    &previous_result,
317                                                );
318
319                                                // Emit tool loop detected event (System scope)
320                                                if let Some(stream) = event_stream {
321                                                    stream.append(
322                                                        crate::event::EventScope::System,
323                                                        crate::event::EventType::Progress,
324                                                        "system:tool_loop_detection".to_string(),
325                                                        crate::event::ComponentStatus::Running,
326                                                        workflow_id.clone(),
327                                                        Some(format!(
328                                                            "Tool loop detected: {}",
329                                                            tool_call.function.name
330                                                        )),
331                                                        serde_json::json!({
332                                                            "agent": self.config.name,
333                                                            "tool": tool_call.function.name,
334                                                            "message": loop_message,
335                                                        }),
336                                                    );
337                                                }
338
339                                                // Add system message explaining the loop
340                                                let tool_msg = ChatMessage::tool_result(
341                                                    &tool_call.id,
342                                                    &loop_message,
343                                                );
344                                                request.messages.push(tool_msg);
345
346                                                // Skip actual tool execution
347                                                continue;
348                                            }
349                                        }
350                                    }
351
352                                    // No loop detected - execute the tool normally
353                                    let tool_result = self
354                                        .execute_tool_call(
355                                            &tool_call,
356                                            &input
357                                                .metadata
358                                                .previous_agent
359                                                .clone()
360                                                .unwrap_or_else(|| "workflow".to_string()),
361                                            event_stream,
362                                        )
363                                        .await;
364
365                                    // Record this call in the tracker
366                                    if let Some(tracker) = &mut tool_tracker {
367                                        // Parse tool arguments from JSON string
368                                        let args_value: serde_json::Value =
369                                            serde_json::from_str(&tool_call.function.arguments)
370                                                .unwrap_or(serde_json::json!({}));
371
372                                        // Convert to HashMap
373                                        let args_map: HashMap<String, serde_json::Value> =
374                                            args_value
375                                                .as_object()
376                                                .map(|obj| {
377                                                    obj.iter()
378                                                        .map(|(k, v)| (k.clone(), v.clone()))
379                                                        .collect()
380                                                })
381                                                .unwrap_or_default();
382
383                                        let result_json = serde_json::to_value(&tool_result)
384                                            .unwrap_or(serde_json::json!({}));
385                                        tracker.record_call(
386                                            &tool_call.function.name,
387                                            &args_map,
388                                            &result_json,
389                                        );
390                                    }
391
392                                    // Add tool result to conversation
393                                    let tool_msg =
394                                        ChatMessage::tool_result(&tool_call.id, &tool_result);
395                                    request.messages.push(tool_msg);
396                                }
397
398                                // Continue loop to get next response
399                                continue;
400                            }
401                        }
402
403                        // No tool calls (or empty array), we have the final response
404                        let response_text = response.content.trim();
405                        let token_count = response
406                            .usage
407                            .map(|u| u.total_tokens)
408                            .unwrap_or_else(|| (response_text.len() as f32 / 4.0).ceil() as u32);
409
410                        let output_data = serde_json::json!({
411                            "response": response_text,
412                            "content_type": "text/plain",
413                            "token_count": token_count,
414                        });
415
416                        // Add final assistant response to chat history
417                        request.messages.push(ChatMessage::assistant(response_text));
418
419                        // Emit Agent::Completed event
420                        if let Some(stream) = event_stream {
421                            stream.agent_completed(
422                                &self.config.name,
423                                workflow_id.clone(),
424                                Some(format!(
425                                    "Agent completed in {}ms",
426                                    start.elapsed().as_millis()
427                                )),
428                                serde_json::json!({
429                                    "execution_time_ms": start.elapsed().as_millis() as u64,
430                                    "tool_calls": total_tool_calls,
431                                    "iterations": iteration,
432                                }),
433                            );
434                        }
435
436                        return Ok(AgentOutput {
437                            data: output_data,
438                            metadata: AgentOutputMetadata {
439                                agent_name: self.config.name.clone(),
440                                execution_time_ms: start.elapsed().as_millis() as u64,
441                                tool_calls_count: total_tool_calls,
442                            },
443                            chat_history: Some(request.messages),
444                        });
445                    }
446                    Err(e) => {
447                        // Emit LlmRequest::Failed event
448                        if let Some(stream) = event_stream {
449                            stream.llm_failed(
450                                &self.config.name,
451                                iteration,
452                                workflow_id.clone(),
453                                &e.to_string(),
454                            );
455                        }
456
457                        // Emit Agent::Failed event
458                        if let Some(stream) = event_stream {
459                            stream.agent_failed(
460                                &self.config.name,
461                                workflow_id.clone(),
462                                &e.to_string(),
463                                serde_json::json!({}),
464                            );
465                        }
466
467                        return Err(AgentError::ExecutionError(format!(
468                            "LLM call failed: {}",
469                            e
470                        )));
471                    }
472                }
473            }
474        } else {
475            // Mock execution fallback
476            let output_data = serde_json::json!({
477                "agent": self.config.name,
478                "processed": input.data,
479                "system_prompt": self.config.system_prompt,
480                "note": "Mock execution - no LLM client configured"
481            });
482
483            if let Some(stream) = event_stream {
484                stream.agent_completed(
485                    &self.config.name,
486                    workflow_id.clone(),
487                    Some("Agent completed (no LLM)".to_string()),
488                    serde_json::json!({
489                        "execution_time_ms": start.elapsed().as_millis() as u64,
490                        "mock": true,
491                    }),
492                );
493            }
494
495            Ok(AgentOutput {
496                data: output_data,
497                metadata: AgentOutputMetadata {
498                    agent_name: self.config.name.clone(),
499                    execution_time_ms: start.elapsed().as_millis() as u64,
500                    tool_calls_count: 0,
501                },
502                chat_history: None, // No LLM client means no chat history
503            })
504        }
505    }
506
507    /// Execute a single tool call
508    async fn execute_tool_call(
509        &self,
510        tool_call: &ToolCall,
511        previous_agent: &str,
512        event_stream: Option<&EventStream>,
513    ) -> String {
514        let tool_name = &tool_call.function.name;
515
516        // Emit Tool::Started event
517        if let Some(stream) = event_stream {
518            stream.tool_started(
519                tool_name,
520                previous_agent.to_string(),
521                serde_json::json!({
522                    "agent": self.config.name,
523                    "tool_call_id": tool_call.id,
524                    "arguments": tool_call.function.arguments,
525                }),
526            );
527        }
528
529        // Get the tool registry
530        let registry = match &self.config.tools {
531            Some(reg) => reg,
532            None => {
533                let error_msg = "No tool registry configured".to_string();
534                if let Some(stream) = event_stream {
535                    stream.tool_failed(
536                        tool_name,
537                        previous_agent.to_string(),
538                        &error_msg,
539                        serde_json::json!({
540                            "agent": self.config.name,
541                            "tool_call_id": tool_call.id,
542                            "duration_ms": 0,
543                        }),
544                    );
545                }
546                return format!("Error: {}", error_msg);
547            }
548        };
549
550        // Parse arguments from JSON string
551        let params: HashMap<String, serde_json::Value> =
552            match serde_json::from_str(&tool_call.function.arguments) {
553                Ok(p) => p,
554                Err(e) => {
555                    let error_msg = format!("Failed to parse tool arguments: {}", e);
556                    if let Some(stream) = event_stream {
557                        stream.tool_failed(
558                            tool_name,
559                            previous_agent.to_string(),
560                            &error_msg,
561                            serde_json::json!({
562                                "agent": self.config.name,
563                                "tool_call_id": tool_call.id,
564                                "duration_ms": 0,
565                                "duration_ms": 0,
566                            }),
567                        );
568                    }
569                    return format!("Error: {}", error_msg);
570                }
571            };
572
573        // Execute the tool
574        let start_time = std::time::Instant::now();
575        match registry.call_tool(tool_name, params.clone()).await {
576            Ok(result) => {
577                // Emit Tool::Completed event
578                if let Some(stream) = event_stream {
579                    stream.tool_completed(
580                        tool_name,
581                        previous_agent.to_string(),
582                        serde_json::json!({
583                            "agent": self.config.name,
584                            "tool_call_id": tool_call.id,
585                            "result": result.output,
586                            "duration_ms": (result.duration_ms * 1000.0).round() / 1000.0,
587                        }),
588                    );
589                }
590
591                // Convert result to string for LLM
592                serde_json::to_string(&result.output).unwrap_or_else(|_| result.output.to_string())
593            }
594            Err(e) => {
595                let error_msg = format!("Tool execution failed: {}", e);
596                if let Some(stream) = event_stream {
597                    stream.tool_failed(
598                        tool_name,
599                        previous_agent.to_string(),
600                        &error_msg,
601                        serde_json::json!({
602                            "agent": self.config.name,
603                            "tool_call_id": tool_call.id,
604                            "duration_ms": start_time.elapsed().as_secs_f64() * 1000.0,
605                        }),
606                    );
607                }
608                format!("Error: {}", error_msg)
609            }
610        }
611    }
612}