Skip to main content

agent_runtime/
agent.rs

1use crate::event::{EventStream, EventType};
2use crate::llm::types::ToolCall;
3use crate::llm::{ChatClient, ChatMessage, ChatRequest};
4use crate::tool::ToolRegistry;
5use crate::tool_loop_detection::{ToolCallTracker, ToolLoopDetectionConfig};
6use crate::types::{AgentError, AgentInput, AgentOutput, AgentOutputMetadata, AgentResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11#[cfg(test)]
12#[path = "agent_test.rs"]
13mod agent_test;
14
15/// Agent configuration
16#[derive(Clone, Serialize, Deserialize)]
17pub struct AgentConfig {
18    pub name: String,
19    pub system_prompt: String,
20
21    #[serde(skip)]
22    pub tools: Option<Arc<ToolRegistry>>,
23
24    pub max_tool_iterations: usize,
25
26    /// Tool loop detection configuration
27    #[serde(skip)]
28    pub tool_loop_detection: Option<ToolLoopDetectionConfig>,
29}
30
31impl std::fmt::Debug for AgentConfig {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("AgentConfig")
34            .field("name", &self.name)
35            .field("system_prompt", &self.system_prompt)
36            .field(
37                "tools",
38                &self.tools.as_ref().map(|t| format!("{} tools", t.len())),
39            )
40            .field("max_tool_iterations", &self.max_tool_iterations)
41            .field(
42                "tool_loop_detection",
43                &self.tool_loop_detection.as_ref().map(|c| c.enabled),
44            )
45            .finish()
46    }
47}
48
49impl AgentConfig {
50    pub fn builder(name: impl Into<String>) -> AgentConfigBuilder {
51        AgentConfigBuilder {
52            name: name.into(),
53            system_prompt: String::new(),
54            tools: None,
55            max_tool_iterations: 10,
56            tool_loop_detection: Some(ToolLoopDetectionConfig::default()),
57        }
58    }
59}
60
61/// Builder for AgentConfig
62pub struct AgentConfigBuilder {
63    name: String,
64    system_prompt: String,
65    tools: Option<Arc<ToolRegistry>>,
66    max_tool_iterations: usize,
67    tool_loop_detection: Option<ToolLoopDetectionConfig>,
68}
69
70impl AgentConfigBuilder {
71    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
72        self.system_prompt = prompt.into();
73        self
74    }
75
76    pub fn tools(mut self, tools: Arc<ToolRegistry>) -> Self {
77        self.tools = Some(tools);
78        self
79    }
80
81    pub fn max_tool_iterations(mut self, max: usize) -> Self {
82        self.max_tool_iterations = max;
83        self
84    }
85
86    pub fn tool_loop_detection(mut self, config: ToolLoopDetectionConfig) -> Self {
87        self.tool_loop_detection = Some(config);
88        self
89    }
90
91    pub fn disable_tool_loop_detection(mut self) -> Self {
92        self.tool_loop_detection = None;
93        self
94    }
95
96    pub fn build(self) -> AgentConfig {
97        AgentConfig {
98            name: self.name,
99            system_prompt: self.system_prompt,
100            tools: self.tools,
101            max_tool_iterations: self.max_tool_iterations,
102            tool_loop_detection: self.tool_loop_detection,
103        }
104    }
105}
106
107/// Agent execution unit
108pub struct Agent {
109    config: AgentConfig,
110    llm_client: Option<Arc<dyn ChatClient>>,
111}
112
113impl Agent {
114    pub fn new(config: AgentConfig) -> Self {
115        Self {
116            config,
117            llm_client: None,
118        }
119    }
120
121    pub fn with_llm_client(mut self, client: Arc<dyn ChatClient>) -> Self {
122        self.llm_client = Some(client);
123        self
124    }
125
126    pub fn name(&self) -> &str {
127        &self.config.name
128    }
129
130    pub fn config(&self) -> &AgentConfig {
131        &self.config
132    }
133
134    /// Execute the agent with the given input
135    pub async fn execute(&self, input: &AgentInput) -> AgentResult {
136        self.execute_with_events(input.clone(), None).await
137    }
138
139    /// Execute the agent with event stream for observability
140    pub async fn execute_with_events(
141        &self,
142        input: AgentInput,
143        event_stream: Option<&EventStream>,
144    ) -> AgentResult {
145        let start = std::time::Instant::now();
146
147        // Emit agent processing event
148        if let Some(stream) = event_stream {
149            stream.append(
150                EventType::AgentProcessing,
151                input
152                    .metadata
153                    .previous_agent
154                    .clone()
155                    .unwrap_or_else(|| "workflow".to_string()),
156                serde_json::json!({
157                    "agent": self.config.name,
158                    "input": input.data,
159                }),
160            );
161        }
162
163        // If we have an LLM client, use it
164        if let Some(client) = &self.llm_client {
165            // Build messages from chat_history OR from input data
166            let messages = if let Some(history) = &input.chat_history {
167                // Use provided chat history as-is
168                // Outer layer is managing the conversation context
169                history.clone()
170            } else {
171                // Build messages from scratch (legacy behavior)
172                let user_message = if let Some(s) = input.data.as_str() {
173                    s.to_string()
174                } else {
175                    serde_json::to_string_pretty(&input.data).unwrap_or_default()
176                };
177
178                vec![
179                    ChatMessage::system(&self.config.system_prompt),
180                    ChatMessage::user(&user_message),
181                ]
182            };
183
184            let mut request = ChatRequest::new(messages.clone())
185                .with_temperature(0.7)
186                .with_max_tokens(8192);
187
188            // Get tool schemas if available
189            let tool_schemas = self
190                .config
191                .tools
192                .as_ref()
193                .map(|registry| registry.list_tools())
194                .filter(|tools| !tools.is_empty());
195
196            // Tool calling loop
197            let mut iteration = 0;
198            let mut total_tool_calls = 0;
199
200            // Initialize tool call tracker for loop detection
201            let mut tool_tracker = if self.config.tool_loop_detection.is_some() {
202                Some(ToolCallTracker::new())
203            } else {
204                None
205            };
206
207            loop {
208                iteration += 1;
209
210                // Check iteration limit
211                if iteration > self.config.max_tool_iterations {
212                    return Err(AgentError::ExecutionError(format!(
213                        "Maximum tool iterations ({}) exceeded",
214                        self.config.max_tool_iterations
215                    )));
216                }
217
218                // Add tools to request if available
219                if let Some(ref schemas) = tool_schemas {
220                    request = request.with_tools(schemas.clone());
221                }
222
223                // Emit LLM request started event
224                if let Some(stream) = event_stream {
225                    stream.append(
226                        EventType::AgentLlmRequestStarted,
227                        input
228                            .metadata
229                            .previous_agent
230                            .clone()
231                            .unwrap_or_else(|| "workflow".to_string()),
232                        serde_json::json!({
233                            "agent": self.config.name,
234                            "iteration": iteration,
235                        }),
236                    );
237                }
238
239                // Call LLM with streaming + full response (for tool calls)
240                let event_stream_for_streaming = event_stream.cloned();
241                let agent_name = self.config.name.clone();
242                let previous_agent = input
243                    .metadata
244                    .previous_agent
245                    .clone()
246                    .unwrap_or_else(|| "workflow".to_string());
247
248                // Create channel for streaming chunks
249                let (chunk_tx, mut chunk_rx) = tokio::sync::mpsc::channel(100);
250
251                // Spawn task to receive chunks and emit events
252                let _chunk_event_task = tokio::spawn(async move {
253                    while let Some(chunk) = chunk_rx.recv().await {
254                        if let Some(stream) = &event_stream_for_streaming {
255                            stream.append(
256                                EventType::AgentLlmStreamChunk,
257                                previous_agent.clone(),
258                                serde_json::json!({
259                                    "agent": &agent_name,
260                                    "chunk": chunk,
261                                }),
262                            );
263                        }
264                    }
265                });
266
267                match client.chat_stream(request.clone(), chunk_tx).await {
268                    Ok(response) => {
269                        // Emit LLM request completed event
270                        if let Some(stream) = event_stream {
271                            stream.append(
272                                EventType::AgentLlmRequestCompleted,
273                                input
274                                    .metadata
275                                    .previous_agent
276                                    .clone()
277                                    .unwrap_or_else(|| "workflow".to_string()),
278                                serde_json::json!({
279                                    "agent": self.config.name,
280                                }),
281                            );
282                        }
283
284                        // Check if we have tool calls (and they're not empty)
285                        if let Some(tool_calls) = response.tool_calls.clone() {
286                            if tool_calls.is_empty() {
287                                // Empty tool calls array - treat as final response
288                            } else {
289                                total_tool_calls += tool_calls.len();
290
291                                // Add assistant message with tool calls to conversation
292                                let assistant_msg = ChatMessage::assistant_with_tool_calls(
293                                    response.content.clone(),
294                                    tool_calls.clone(),
295                                );
296                                request.messages.push(assistant_msg);
297
298                                // Execute each tool call
299                                for tool_call in tool_calls {
300                                    // Check for duplicate tool call (loop detection)
301                                    if let (Some(tracker), Some(loop_config)) =
302                                        (&tool_tracker, &self.config.tool_loop_detection)
303                                    {
304                                        if loop_config.enabled {
305                                            // Parse tool arguments from JSON string
306                                            let args_value: serde_json::Value =
307                                                serde_json::from_str(&tool_call.function.arguments)
308                                                    .unwrap_or(serde_json::json!({}));
309
310                                            // Convert to HashMap for comparison
311                                            let args_map: HashMap<String, serde_json::Value> =
312                                                args_value
313                                                    .as_object()
314                                                    .map(|obj| {
315                                                        obj.iter()
316                                                            .map(|(k, v)| (k.clone(), v.clone()))
317                                                            .collect()
318                                                    })
319                                                    .unwrap_or_default();
320
321                                            if let Some(previous_result) = tracker
322                                                .check_for_loop(&tool_call.function.name, &args_map)
323                                            {
324                                                // Loop detected! Inject message instead of calling tool
325                                                let loop_message = loop_config.get_message(
326                                                    &tool_call.function.name,
327                                                    &previous_result,
328                                                );
329
330                                                // Emit loop detected event
331                                                if let Some(stream) = event_stream {
332                                                    stream.append(
333                                                        EventType::AgentToolLoopDetected,
334                                                        input
335                                                            .metadata
336                                                            .previous_agent
337                                                            .clone()
338                                                            .unwrap_or_else(|| {
339                                                                "workflow".to_string()
340                                                            }),
341                                                        serde_json::json!({
342                                                            "agent": self.config.name,
343                                                            "tool": tool_call.function.name,
344                                                            "message": loop_message,
345                                                        }),
346                                                    );
347                                                }
348
349                                                // Add system message explaining the loop
350                                                let tool_msg = ChatMessage::tool_result(
351                                                    &tool_call.id,
352                                                    &loop_message,
353                                                );
354                                                request.messages.push(tool_msg);
355
356                                                // Skip actual tool execution
357                                                continue;
358                                            }
359                                        }
360                                    }
361
362                                    // No loop detected - execute the tool normally
363                                    let tool_result = self
364                                        .execute_tool_call(
365                                            &tool_call,
366                                            &input
367                                                .metadata
368                                                .previous_agent
369                                                .clone()
370                                                .unwrap_or_else(|| "workflow".to_string()),
371                                            event_stream,
372                                        )
373                                        .await;
374
375                                    // Record this call in the tracker
376                                    if let Some(tracker) = &mut tool_tracker {
377                                        // Parse tool arguments from JSON string
378                                        let args_value: serde_json::Value =
379                                            serde_json::from_str(&tool_call.function.arguments)
380                                                .unwrap_or(serde_json::json!({}));
381
382                                        // Convert to HashMap
383                                        let args_map: HashMap<String, serde_json::Value> =
384                                            args_value
385                                                .as_object()
386                                                .map(|obj| {
387                                                    obj.iter()
388                                                        .map(|(k, v)| (k.clone(), v.clone()))
389                                                        .collect()
390                                                })
391                                                .unwrap_or_default();
392
393                                        let result_json = serde_json::to_value(&tool_result)
394                                            .unwrap_or(serde_json::json!({}));
395                                        tracker.record_call(
396                                            &tool_call.function.name,
397                                            &args_map,
398                                            &result_json,
399                                        );
400                                    }
401
402                                    // Add tool result to conversation
403                                    let tool_msg =
404                                        ChatMessage::tool_result(&tool_call.id, &tool_result);
405                                    request.messages.push(tool_msg);
406                                }
407
408                                // Continue loop to get next response
409                                continue;
410                            }
411                        }
412
413                        // No tool calls (or empty array), we have the final response
414                        let response_text = response.content.trim();
415                        let token_count = response
416                            .usage
417                            .map(|u| u.total_tokens)
418                            .unwrap_or_else(|| (response_text.len() as f32 / 4.0).ceil() as u32);
419
420                        let output_data = serde_json::json!({
421                            "response": response_text,
422                            "content_type": "text/plain",
423                            "token_count": token_count,
424                        });
425
426                        // Add final assistant response to chat history
427                        request.messages.push(ChatMessage::assistant(response_text));
428
429                        // Emit agent completed event
430                        if let Some(stream) = event_stream {
431                            stream.append(
432                                EventType::AgentCompleted,
433                                input
434                                    .metadata
435                                    .previous_agent
436                                    .clone()
437                                    .unwrap_or_else(|| "workflow".to_string()),
438                                serde_json::json!({
439                                    "agent": self.config.name,
440                                    "execution_time_ms": start.elapsed().as_millis() as u64,
441                                }),
442                            );
443                        }
444
445                        return Ok(AgentOutput {
446                            data: output_data,
447                            metadata: AgentOutputMetadata {
448                                agent_name: self.config.name.clone(),
449                                execution_time_ms: start.elapsed().as_millis() as u64,
450                                tool_calls_count: total_tool_calls,
451                            },
452                            chat_history: Some(request.messages),
453                        });
454                    }
455                    Err(e) => {
456                        // Emit LLM request failed event
457                        if let Some(stream) = event_stream {
458                            stream.append(
459                                EventType::AgentLlmRequestFailed,
460                                input
461                                    .metadata
462                                    .previous_agent
463                                    .clone()
464                                    .unwrap_or_else(|| "workflow".to_string()),
465                                serde_json::json!({
466                                    "agent": self.config.name,
467                                    "error": e.to_string(),
468                                }),
469                            );
470                        }
471
472                        // Emit agent failed event
473                        if let Some(stream) = event_stream {
474                            stream.append(
475                                EventType::AgentFailed,
476                                input
477                                    .metadata
478                                    .previous_agent
479                                    .clone()
480                                    .unwrap_or_else(|| "workflow".to_string()),
481                                serde_json::json!({
482                                    "agent": self.config.name,
483                                    "error": e.to_string(),
484                                }),
485                            );
486                        }
487
488                        return Err(AgentError::ExecutionError(format!(
489                            "LLM call failed: {}",
490                            e
491                        )));
492                    }
493                }
494            }
495        } else {
496            // Mock execution fallback
497            let output_data = serde_json::json!({
498                "agent": self.config.name,
499                "processed": input.data,
500                "system_prompt": self.config.system_prompt,
501                "note": "Mock execution - no LLM client configured"
502            });
503
504            if let Some(stream) = event_stream {
505                stream.append(
506                    EventType::AgentCompleted,
507                    input
508                        .metadata
509                        .previous_agent
510                        .clone()
511                        .unwrap_or_else(|| "workflow".to_string()),
512                    serde_json::json!({
513                        "agent": self.config.name,
514                        "execution_time_ms": start.elapsed().as_millis() as u64,
515                        "mock": true,
516                    }),
517                );
518            }
519
520            Ok(AgentOutput {
521                data: output_data,
522                metadata: AgentOutputMetadata {
523                    agent_name: self.config.name.clone(),
524                    execution_time_ms: start.elapsed().as_millis() as u64,
525                    tool_calls_count: 0,
526                },
527                chat_history: None, // No LLM client means no chat history
528            })
529        }
530    }
531
532    /// Execute a single tool call
533    async fn execute_tool_call(
534        &self,
535        tool_call: &ToolCall,
536        previous_agent: &str,
537        event_stream: Option<&EventStream>,
538    ) -> String {
539        let tool_name = &tool_call.function.name;
540
541        // Emit tool call started event
542        if let Some(stream) = event_stream {
543            stream.append(
544                EventType::ToolCallStarted,
545                previous_agent.to_string(),
546                serde_json::json!({
547                    "agent": self.config.name,
548                    "tool": tool_name,
549                    "tool_call_id": tool_call.id,
550                    "arguments": tool_call.function.arguments,
551                }),
552            );
553        }
554
555        // Get the tool registry
556        let registry = match &self.config.tools {
557            Some(reg) => reg,
558            None => {
559                let error_msg = "No tool registry configured".to_string();
560                if let Some(stream) = event_stream {
561                    stream.append(
562                        EventType::ToolCallFailed,
563                        previous_agent.to_string(),
564                        serde_json::json!({
565                            "agent": self.config.name,
566                            "tool": tool_name,
567                            "tool_call_id": tool_call.id,
568                            "arguments": tool_call.function.arguments,
569                            "error": error_msg,
570                            "duration_ms": 0,
571                        }),
572                    );
573                }
574                return format!("Error: {}", error_msg);
575            }
576        };
577
578        // Parse arguments from JSON string
579        let params: HashMap<String, serde_json::Value> =
580            match serde_json::from_str(&tool_call.function.arguments) {
581                Ok(p) => p,
582                Err(e) => {
583                    let error_msg = format!("Failed to parse tool arguments: {}", e);
584                    if let Some(stream) = event_stream {
585                        stream.append(
586                            EventType::ToolCallFailed,
587                            previous_agent.to_string(),
588                            serde_json::json!({
589                                "agent": self.config.name,
590                                "tool": tool_name,
591                                "tool_call_id": tool_call.id,
592                                "arguments": tool_call.function.arguments,
593                                "error": error_msg,
594                                "duration_ms": 0,
595                            }),
596                        );
597                    }
598                    return format!("Error: {}", error_msg);
599                }
600            };
601
602        // Execute the tool
603        let start_time = std::time::Instant::now();
604        match registry.call_tool(tool_name, params.clone()).await {
605            Ok(result) => {
606                // Emit tool call completed event
607                if let Some(stream) = event_stream {
608                    stream.append(
609                        EventType::ToolCallCompleted,
610                        previous_agent.to_string(),
611                        serde_json::json!({
612                            "agent": self.config.name,
613                            "tool": tool_name,
614                            "tool_call_id": tool_call.id,
615                            "arguments": params,
616                            "result": result.output,
617                            "duration_ms": (result.duration_ms * 1000.0).round() / 1000.0,
618                        }),
619                    );
620                }
621
622                // Convert result to string for LLM
623                serde_json::to_string(&result.output).unwrap_or_else(|_| result.output.to_string())
624            }
625            Err(e) => {
626                let error_msg = format!("Tool execution failed: {}", e);
627                if let Some(stream) = event_stream {
628                    stream.append(
629                        EventType::ToolCallFailed,
630                        previous_agent.to_string(),
631                        serde_json::json!({
632                            "agent": self.config.name,
633                            "tool": tool_name,
634                            "tool_call_id": tool_call.id,
635                            "arguments": params,
636                            "error": error_msg,
637                            "duration_ms": start_time.elapsed().as_secs_f64() * 1000.0,
638                        }),
639                    );
640                }
641                format!("Error: {}", error_msg)
642            }
643        }
644    }
645}