chasm_cli/agency/
executor.rs

1// Copyright (c) 2024-2026 Nervosys LLC
2// SPDX-License-Identifier: Apache-2.0
3//! Agent Execution
4//!
5//! Handles the execution of individual agents with tool calling.
6
7#![allow(dead_code)]
8
9use crate::agency::agent::{Agent, AgentStatus};
10use crate::agency::error::{AgencyError, AgencyResult};
11use crate::agency::models::{
12    AgencyEvent, AgencyMessage, EventType, MessageRole, TokenUsage, ToolCall, ToolResult,
13};
14use crate::agency::session::{generate_message_id, Session};
15use crate::agency::tools::ToolRegistry;
16use chrono::Utc;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::Arc;
20use tokio::sync::mpsc;
21
/// Execution context passed to tools
///
/// Carries per-session identity, a snapshot of session state, and the
/// tool/streaming policy for a single agent turn.
#[derive(Debug, Clone)]
pub struct ExecutionContext {
    /// Session ID
    pub session_id: String,
    /// Agent name
    pub agent_name: String,
    /// User ID
    pub user_id: Option<String>,
    /// Current state (cloned from the session's state data at construction)
    pub state: HashMap<String, serde_json::Value>,
    /// Whether to allow tool execution
    pub allow_tools: bool,
    /// Maximum tool calls per turn
    pub max_tool_calls: u32,
    /// Event sender for streaming; when `None`, emitted events are dropped
    pub event_sender: Option<mpsc::Sender<AgencyEvent>>,
}
40
41impl ExecutionContext {
42    pub fn new(session: &Session) -> Self {
43        Self {
44            session_id: session.id.clone(),
45            agent_name: session.agent_name.clone(),
46            user_id: session.user_id.clone(),
47            state: session.state.data.clone(),
48            allow_tools: true,
49            max_tool_calls: 10,
50            event_sender: None,
51        }
52    }
53
54    /// Send an event to listeners
55    pub async fn emit(&self, event: AgencyEvent) {
56        if let Some(sender) = &self.event_sender {
57            let _ = sender.send(event).await;
58        }
59    }
60}
61
/// Result of agent execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionResult {
    /// Final response text
    pub response: String,
    /// Messages generated during execution (user, tool, and assistant)
    pub messages: Vec<AgencyMessage>,
    /// Events emitted during execution
    pub events: Vec<AgencyEvent>,
    /// Aggregated token usage across all model calls in the turn
    pub token_usage: TokenUsage,
    /// Execution duration in milliseconds
    pub duration_ms: u64,
    /// Whether execution completed successfully
    pub success: bool,
    /// Error message if failed
    pub error: Option<String>,
}
80
/// Agent executor handles running an agent with optional tool calling
pub struct Executor {
    /// Registry used to resolve and run the tools requested by the model
    tool_registry: Arc<ToolRegistry>,
}
85
86impl Executor {
87    /// Create a new executor with the given tool registry
88    pub fn new(tool_registry: Arc<ToolRegistry>) -> Self {
89        Self { tool_registry }
90    }
91
92    /// Execute an agent with a user message
93    pub async fn execute(
94        &self,
95        agent: &Agent,
96        session: &mut Session,
97        user_message: &str,
98        ctx: &mut ExecutionContext,
99    ) -> AgencyResult<ExecutionResult> {
100        let start_time = std::time::Instant::now();
101        let mut messages = Vec::new();
102        let mut events = Vec::new();
103        let mut token_usage = TokenUsage::default();
104
105        // Set agent status
106        agent.set_status(AgentStatus::Thinking);
107
108        // Emit start event
109        let start_event = AgencyEvent {
110            event_type: EventType::AgentStarted,
111            agent_name: agent.name().to_string(),
112            data: serde_json::json!({ "message": user_message }),
113            timestamp: Utc::now(),
114            session_id: Some(session.id.clone()),
115        };
116        events.push(start_event.clone());
117        ctx.emit(start_event).await;
118
119        // Add user message
120        let user_msg = AgencyMessage {
121            id: generate_message_id(),
122            role: MessageRole::User,
123            content: user_message.to_string(),
124            tool_calls: vec![],
125            tool_result: None,
126            timestamp: Utc::now(),
127            tokens: None,
128            agent_name: Some(agent.name().to_string()),
129            metadata: HashMap::new(),
130        };
131        session.add_message(user_msg.clone());
132        messages.push(user_msg);
133
134        // Execute with tool loop
135        let mut tool_call_count = 0;
136        #[allow(unused_assignments)]
137        let mut final_response = String::new();
138
139        loop {
140            // Call the model
141            agent.set_status(AgentStatus::Thinking);
142            let thinking_event = AgencyEvent {
143                event_type: EventType::AgentThinking,
144                agent_name: agent.name().to_string(),
145                data: serde_json::json!({}),
146                timestamp: Utc::now(),
147                session_id: Some(session.id.clone()),
148            };
149            events.push(thinking_event.clone());
150            ctx.emit(thinking_event).await;
151
152            // Call the model with the current session context
153            let model_response = self.call_model(agent, session).await?;
154
155            token_usage.add(&model_response.usage);
156
157            // Check for tool calls
158            if !model_response.tool_calls.is_empty() && ctx.allow_tools {
159                agent.set_status(AgentStatus::WaitingForTool);
160
161                for tool_call in &model_response.tool_calls {
162                    tool_call_count += 1;
163                    if tool_call_count > ctx.max_tool_calls {
164                        return Err(AgencyError::MaxIterationsExceeded(ctx.max_tool_calls));
165                    }
166
167                    // Emit tool call event
168                    let call_event = AgencyEvent {
169                        event_type: EventType::ToolCallStarted,
170                        agent_name: agent.name().to_string(),
171                        data: serde_json::json!({
172                            "tool": tool_call.name,
173                            "arguments": tool_call.arguments
174                        }),
175                        timestamp: Utc::now(),
176                        session_id: Some(session.id.clone()),
177                    };
178                    events.push(call_event.clone());
179                    ctx.emit(call_event).await;
180
181                    // Execute tool
182                    agent.set_status(AgentStatus::Executing);
183                    let tool_result = self.execute_tool(tool_call).await;
184
185                    // Emit tool result event
186                    let result_event = AgencyEvent {
187                        event_type: EventType::ToolCallCompleted,
188                        agent_name: agent.name().to_string(),
189                        data: serde_json::json!({
190                            "tool": tool_call.name,
191                            "success": tool_result.success,
192                            "content": tool_result.content
193                        }),
194                        timestamp: Utc::now(),
195                        session_id: Some(session.id.clone()),
196                    };
197                    events.push(result_event.clone());
198                    ctx.emit(result_event).await;
199
200                    // Add tool result message
201                    let tool_msg = AgencyMessage {
202                        id: generate_message_id(),
203                        role: MessageRole::Tool,
204                        content: tool_result.content.clone(),
205                        tool_calls: vec![],
206                        tool_result: Some(tool_result),
207                        timestamp: Utc::now(),
208                        tokens: None,
209                        agent_name: Some(agent.name().to_string()),
210                        metadata: HashMap::new(),
211                    };
212                    session.add_message(tool_msg.clone());
213                    messages.push(tool_msg);
214                }
215
216                // Continue loop to get model response after tool results
217                continue;
218            }
219
220            // No tool calls - we have the final response
221            final_response = model_response.content.clone();
222
223            // Add assistant message
224            let assistant_msg = AgencyMessage {
225                id: generate_message_id(),
226                role: MessageRole::Assistant,
227                content: model_response.content,
228                tool_calls: model_response.tool_calls,
229                tool_result: None,
230                timestamp: Utc::now(),
231                tokens: Some(model_response.usage.completion_tokens),
232                agent_name: Some(agent.name().to_string()),
233                metadata: HashMap::new(),
234            };
235            session.add_message(assistant_msg.clone());
236            messages.push(assistant_msg);
237
238            break;
239        }
240
241        // Emit end event
242        agent.set_status(AgentStatus::Completed);
243        let end_event = AgencyEvent {
244            event_type: EventType::AgentCompleted,
245            agent_name: agent.name().to_string(),
246            data: serde_json::json!({ "response": final_response }),
247            timestamp: Utc::now(),
248            session_id: Some(session.id.clone()),
249        };
250        events.push(end_event.clone());
251        ctx.emit(end_event).await;
252
253        Ok(ExecutionResult {
254            response: final_response,
255            messages,
256            events,
257            token_usage,
258            duration_ms: start_time.elapsed().as_millis() as u64,
259            success: true,
260            error: None,
261        })
262    }
263
264    /// Call the model using the appropriate provider
265    async fn call_model(&self, agent: &Agent, session: &Session) -> AgencyResult<ModelResponse> {
266        use crate::agency::models::ModelProvider;
267
268        let messages = session.to_api_messages();
269        let tools = agent.tool_definitions();
270        let model_config = agent.model();
271
272        // Build request body
273        let mut request_body = serde_json::json!({
274            "model": model_config.model,
275            "messages": messages,
276            "temperature": model_config.temperature,
277        });
278
279        if let Some(max_tokens) = model_config.max_tokens {
280            request_body["max_tokens"] = serde_json::json!(max_tokens);
281        }
282
283        if !tools.is_empty() {
284            request_body["tools"] = serde_json::json!(tools);
285        }
286
287        // Determine endpoint based on provider
288        let endpoint = match model_config.provider {
289            // Cloud Providers
290            ModelProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
291            ModelProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
292            ModelProvider::Google => format!(
293                "https://generativelanguage.googleapis.com/v1/models/{}:generateContent",
294                model_config.model
295            ),
296            ModelProvider::Groq => "https://api.groq.com/openai/v1/chat/completions".to_string(),
297            ModelProvider::Together => "https://api.together.xyz/v1/chat/completions".to_string(),
298            ModelProvider::Fireworks => {
299                "https://api.fireworks.ai/inference/v1/chat/completions".to_string()
300            }
301            ModelProvider::DeepSeek => "https://api.deepseek.com/v1/chat/completions".to_string(),
302            ModelProvider::Mistral => "https://api.mistral.ai/v1/chat/completions".to_string(),
303            ModelProvider::Cohere => "https://api.cohere.ai/v1/chat".to_string(),
304            ModelProvider::Perplexity => "https://api.perplexity.ai/chat/completions".to_string(),
305            ModelProvider::Azure => model_config.endpoint.clone().unwrap_or_default(),
306
307            // Local Providers (OpenAI-compatible)
308            ModelProvider::Ollama => model_config
309                .endpoint
310                .clone()
311                .unwrap_or_else(|| "http://localhost:11434/api/chat".to_string()),
312            ModelProvider::LMStudio => model_config
313                .endpoint
314                .clone()
315                .unwrap_or_else(|| "http://localhost:1234/v1/chat/completions".to_string()),
316            ModelProvider::Jan => model_config
317                .endpoint
318                .clone()
319                .unwrap_or_else(|| "http://localhost:1337/v1/chat/completions".to_string()),
320            ModelProvider::GPT4All => model_config
321                .endpoint
322                .clone()
323                .unwrap_or_else(|| "http://localhost:4891/v1/chat/completions".to_string()),
324            ModelProvider::LocalAI => model_config
325                .endpoint
326                .clone()
327                .unwrap_or_else(|| "http://localhost:8080/v1/chat/completions".to_string()),
328            ModelProvider::Llamafile => model_config
329                .endpoint
330                .clone()
331                .unwrap_or_else(|| "http://localhost:8080/v1/chat/completions".to_string()),
332            ModelProvider::TextGenWebUI => model_config
333                .endpoint
334                .clone()
335                .unwrap_or_else(|| "http://localhost:5000/v1/chat/completions".to_string()),
336            ModelProvider::VLLM => model_config
337                .endpoint
338                .clone()
339                .unwrap_or_else(|| "http://localhost:8000/v1/chat/completions".to_string()),
340            ModelProvider::KoboldCpp => model_config
341                .endpoint
342                .clone()
343                .unwrap_or_else(|| "http://localhost:5001/v1/chat/completions".to_string()),
344            ModelProvider::TabbyML => model_config
345                .endpoint
346                .clone()
347                .unwrap_or_else(|| "http://localhost:8080/v1/chat/completions".to_string()),
348            ModelProvider::Exo => model_config
349                .endpoint
350                .clone()
351                .unwrap_or_else(|| "http://localhost:52415/v1/chat/completions".to_string()),
352
353            // Generic
354            ModelProvider::OpenAICompatible | ModelProvider::Custom => model_config
355                .endpoint
356                .clone()
357                .unwrap_or_else(|| "http://localhost:8080/v1/chat/completions".to_string()),
358        };
359
360        if endpoint.is_empty() {
361            return Err(AgencyError::ConfigError(
362                "No endpoint configured for model provider".to_string(),
363            ));
364        }
365
366        // Make HTTP request
367        let client = reqwest::Client::new();
368        let mut request = client.post(&endpoint).json(&request_body);
369
370        // Add authentication
371        if let Some(api_key) = &model_config.api_key {
372            match model_config.provider {
373                ModelProvider::Anthropic => {
374                    request = request.header("x-api-key", api_key);
375                    request = request.header("anthropic-version", "2023-06-01");
376                }
377                ModelProvider::Google => {
378                    // Google uses query parameter for API key
379                    request = client
380                        .post(format!("{}?key={}", endpoint, api_key))
381                        .json(&request_body);
382                }
383                _ => {
384                    request = request.header("Authorization", format!("Bearer {}", api_key));
385                }
386            }
387        }
388
389        let response = request
390            .send()
391            .await
392            .map_err(|e| AgencyError::NetworkError(format!("HTTP request failed: {}", e)))?;
393
394        if !response.status().is_success() {
395            let status = response.status();
396            let body: String = response.text().await.unwrap_or_default();
397            return Err(AgencyError::ModelError(format!(
398                "Model API error ({}): {}",
399                status, body
400            )));
401        }
402
403        let response_body: serde_json::Value = response
404            .json()
405            .await
406            .map_err(|e| AgencyError::ModelError(format!("Failed to parse response: {}", e)))?;
407
408        // Parse response based on provider format
409        let (content, tool_calls, usage) =
410            Self::parse_model_response(&response_body, &model_config.provider)?;
411
412        Ok(ModelResponse {
413            content,
414            tool_calls,
415            usage,
416        })
417    }
418
419    /// Parse model response based on provider format
420    fn parse_model_response(
421        response: &serde_json::Value,
422        provider: &crate::agency::models::ModelProvider,
423    ) -> AgencyResult<(String, Vec<ToolCall>, TokenUsage)> {
424        use crate::agency::models::ModelProvider;
425
426        match provider {
427            ModelProvider::Anthropic => {
428                // Anthropic format
429                let content = response["content"][0]["text"]
430                    .as_str()
431                    .unwrap_or("")
432                    .to_string();
433                let usage = TokenUsage::new(
434                    response["usage"]["input_tokens"].as_u64().unwrap_or(0) as u32,
435                    response["usage"]["output_tokens"].as_u64().unwrap_or(0) as u32,
436                );
437                // Parse tool_use blocks for Anthropic
438                let mut tool_calls = vec![];
439                if let Some(content_blocks) = response["content"].as_array() {
440                    for block in content_blocks {
441                        if block["type"].as_str() == Some("tool_use") {
442                            tool_calls.push(ToolCall {
443                                id: block["id"].as_str().unwrap_or("").to_string(),
444                                name: block["name"].as_str().unwrap_or("").to_string(),
445                                arguments: block["input"].clone(),
446                                timestamp: Utc::now(),
447                            });
448                        }
449                    }
450                }
451                Ok((content, tool_calls, usage))
452            }
453            ModelProvider::Google => {
454                // Google Gemini format
455                let content = response["candidates"][0]["content"]["parts"][0]["text"]
456                    .as_str()
457                    .unwrap_or("")
458                    .to_string();
459                let usage = TokenUsage::new(
460                    response["usageMetadata"]["promptTokenCount"]
461                        .as_u64()
462                        .unwrap_or(0) as u32,
463                    response["usageMetadata"]["candidatesTokenCount"]
464                        .as_u64()
465                        .unwrap_or(0) as u32,
466                );
467                // Parse function calls for Google
468                let mut tool_calls = vec![];
469                if let Some(parts) = response["candidates"][0]["content"]["parts"].as_array() {
470                    for part in parts {
471                        if let Some(fn_call) = part.get("functionCall") {
472                            tool_calls.push(ToolCall {
473                                id: uuid::Uuid::new_v4().to_string(),
474                                name: fn_call["name"].as_str().unwrap_or("").to_string(),
475                                arguments: fn_call["args"].clone(),
476                                timestamp: Utc::now(),
477                            });
478                        }
479                    }
480                }
481                Ok((content, tool_calls, usage))
482            }
483            _ => {
484                // OpenAI-compatible format (OpenAI, Ollama, Azure, OpenAICompatible, Custom)
485                let choice = &response["choices"][0];
486                let content = choice["message"]["content"]
487                    .as_str()
488                    .unwrap_or("")
489                    .to_string();
490
491                let mut tool_calls = vec![];
492                if let Some(calls) = choice["message"]["tool_calls"].as_array() {
493                    for call in calls {
494                        tool_calls.push(ToolCall {
495                            id: call["id"].as_str().unwrap_or("").to_string(),
496                            name: call["function"]["name"].as_str().unwrap_or("").to_string(),
497                            arguments: serde_json::from_str(
498                                call["function"]["arguments"].as_str().unwrap_or("{}"),
499                            )
500                            .unwrap_or_default(),
501                            timestamp: Utc::now(),
502                        });
503                    }
504                }
505
506                let usage = TokenUsage::new(
507                    response["usage"]["prompt_tokens"].as_u64().unwrap_or(0) as u32,
508                    response["usage"]["completion_tokens"].as_u64().unwrap_or(0) as u32,
509                );
510
511                Ok((content, tool_calls, usage))
512            }
513        }
514    }
515
516    /// Execute a tool
517    async fn execute_tool(&self, tool_call: &ToolCall) -> ToolResult {
518        let start = std::time::Instant::now();
519
520        // Check if tool exists
521        if let Some(executor) = self.tool_registry.get_executor(&tool_call.name) {
522            match executor.execute(tool_call.arguments.clone()).await {
523                Ok(result) => result,
524                Err(e) => ToolResult {
525                    call_id: tool_call.id.clone(),
526                    name: tool_call.name.clone(),
527                    success: false,
528                    content: format!("Tool execution failed: {}", e),
529                    duration_ms: start.elapsed().as_millis() as u64,
530                    data: None,
531                },
532            }
533        } else {
534            // Tool not found - return error result
535            ToolResult {
536                call_id: tool_call.id.clone(),
537                name: tool_call.name.clone(),
538                success: false,
539                content: format!("Tool '{}' not found in registry", tool_call.name),
540                duration_ms: start.elapsed().as_millis() as u64,
541                data: None,
542            }
543        }
544    }
545}
546
/// Response from model API
struct ModelResponse {
    /// Assistant text content (empty string when the provider returned none,
    /// e.g. a tool-call-only turn)
    content: String,
    /// Tool invocations requested by the model (empty when none)
    tool_calls: Vec<ToolCall>,
    /// Token counts reported by the provider (zeros when absent)
    usage: TokenUsage,
}
553
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agency::agent::AgentBuilder;

    /// Smoke test for a full execution round-trip.
    ///
    /// Ignored by default: it performs a live model call and therefore needs
    /// network access and API credentials.
    #[tokio::test]
    #[ignore = "Integration test - requires API credentials"]
    async fn test_executor() {
        let tool_registry = Arc::new(ToolRegistry::new());
        let executor = Executor::new(tool_registry);

        // `execute` takes `&Agent`, so the agent needs neither `mut` nor a
        // mutable borrow (the old `&mut agent` relied on reborrow coercion).
        let agent = AgentBuilder::new("test_agent")
            .description("Test agent")
            .instruction("You are a helpful assistant.")
            .model("gemini-2.5-flash")
            .build();

        let mut session = Session::new("test_agent", None);
        let mut ctx = ExecutionContext::new(&session);

        let result = executor
            .execute(&agent, &mut session, "Hello!", &mut ctx)
            .await
            .unwrap();

        assert!(result.success);
        assert!(!result.response.is_empty());
        assert!(!result.messages.is_empty());
    }
}