1#![allow(dead_code)]
8
9use crate::agency::agent::{Agent, AgentStatus};
10use crate::agency::error::{AgencyError, AgencyResult};
11use crate::agency::models::{
12 AgencyEvent, AgencyMessage, EventType, MessageRole, TokenUsage, ToolCall, ToolResult,
13};
14use crate::agency::session::{generate_message_id, Session};
15use crate::agency::tools::ToolRegistry;
16use chrono::Utc;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::Arc;
20use tokio::sync::mpsc;
21
/// Per-run context threaded through a single agent execution.
#[derive(Debug, Clone)]
pub struct ExecutionContext {
    /// ID of the session this execution belongs to.
    pub session_id: String,
    /// Name of the agent being executed.
    pub agent_name: String,
    /// Optional ID of the user who owns the session.
    pub user_id: Option<String>,
    /// Snapshot of the session's key/value state taken at construction time.
    pub state: HashMap<String, serde_json::Value>,
    /// Whether the executor may run tool calls requested by the model.
    pub allow_tools: bool,
    /// Upper bound on tool invocations per execution; exceeding it aborts the run.
    pub max_tool_calls: u32,
    /// Optional channel for streaming events to an observer. When absent (or
    /// when the receiver is gone) events are silently dropped.
    pub event_sender: Option<mpsc::Sender<AgencyEvent>>,
}
40
41impl ExecutionContext {
42 pub fn new(session: &Session) -> Self {
43 Self {
44 session_id: session.id.clone(),
45 agent_name: session.agent_name.clone(),
46 user_id: session.user_id.clone(),
47 state: session.state.data.clone(),
48 allow_tools: true,
49 max_tool_calls: 10,
50 event_sender: None,
51 }
52 }
53
54 pub async fn emit(&self, event: AgencyEvent) {
56 if let Some(sender) = &self.event_sender {
57 let _ = sender.send(event).await;
58 }
59 }
60}
61
/// Outcome of a completed `Executor::execute` run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionResult {
    /// Final assistant text response.
    pub response: String,
    /// All messages produced during the run (user, tool, assistant), in order.
    pub messages: Vec<AgencyMessage>,
    /// All events emitted during the run, in order.
    pub events: Vec<AgencyEvent>,
    /// Token usage aggregated across every model call in the run.
    pub token_usage: TokenUsage,
    /// Wall-clock duration of the run in milliseconds.
    pub duration_ms: u64,
    /// Whether the run completed successfully.
    pub success: bool,
    /// Error description when `success` is false.
    pub error: Option<String>,
}
80
/// Drives agent turns: orchestrates model calls and tool execution.
pub struct Executor {
    /// Registry used to resolve tool implementations by name.
    tool_registry: Arc<ToolRegistry>,
}
85
86impl Executor {
87 pub fn new(tool_registry: Arc<ToolRegistry>) -> Self {
89 Self { tool_registry }
90 }
91
92 pub async fn execute(
94 &self,
95 agent: &Agent,
96 session: &mut Session,
97 user_message: &str,
98 ctx: &mut ExecutionContext,
99 ) -> AgencyResult<ExecutionResult> {
100 let start_time = std::time::Instant::now();
101 let mut messages = Vec::new();
102 let mut events = Vec::new();
103 let mut token_usage = TokenUsage::default();
104
105 agent.set_status(AgentStatus::Thinking);
107
108 let start_event = AgencyEvent {
110 event_type: EventType::AgentStarted,
111 agent_name: agent.name().to_string(),
112 data: serde_json::json!({ "message": user_message }),
113 timestamp: Utc::now(),
114 session_id: Some(session.id.clone()),
115 };
116 events.push(start_event.clone());
117 ctx.emit(start_event).await;
118
119 let user_msg = AgencyMessage {
121 id: generate_message_id(),
122 role: MessageRole::User,
123 content: user_message.to_string(),
124 tool_calls: vec![],
125 tool_result: None,
126 timestamp: Utc::now(),
127 tokens: None,
128 agent_name: Some(agent.name().to_string()),
129 metadata: HashMap::new(),
130 };
131 session.add_message(user_msg.clone());
132 messages.push(user_msg);
133
134 let mut tool_call_count = 0;
136 #[allow(unused_assignments)]
137 let mut final_response = String::new();
138
139 loop {
140 agent.set_status(AgentStatus::Thinking);
142 let thinking_event = AgencyEvent {
143 event_type: EventType::AgentThinking,
144 agent_name: agent.name().to_string(),
145 data: serde_json::json!({}),
146 timestamp: Utc::now(),
147 session_id: Some(session.id.clone()),
148 };
149 events.push(thinking_event.clone());
150 ctx.emit(thinking_event).await;
151
152 let model_response = self.call_model(agent, session).await?;
154
155 token_usage.add(&model_response.usage);
156
157 if !model_response.tool_calls.is_empty() && ctx.allow_tools {
159 agent.set_status(AgentStatus::WaitingForTool);
160
161 for tool_call in &model_response.tool_calls {
162 tool_call_count += 1;
163 if tool_call_count > ctx.max_tool_calls {
164 return Err(AgencyError::MaxIterationsExceeded(ctx.max_tool_calls));
165 }
166
167 let call_event = AgencyEvent {
169 event_type: EventType::ToolCallStarted,
170 agent_name: agent.name().to_string(),
171 data: serde_json::json!({
172 "tool": tool_call.name,
173 "arguments": tool_call.arguments
174 }),
175 timestamp: Utc::now(),
176 session_id: Some(session.id.clone()),
177 };
178 events.push(call_event.clone());
179 ctx.emit(call_event).await;
180
181 agent.set_status(AgentStatus::Executing);
183 let tool_result = self.execute_tool(tool_call).await;
184
185 let result_event = AgencyEvent {
187 event_type: EventType::ToolCallCompleted,
188 agent_name: agent.name().to_string(),
189 data: serde_json::json!({
190 "tool": tool_call.name,
191 "success": tool_result.success,
192 "content": tool_result.content
193 }),
194 timestamp: Utc::now(),
195 session_id: Some(session.id.clone()),
196 };
197 events.push(result_event.clone());
198 ctx.emit(result_event).await;
199
200 let tool_msg = AgencyMessage {
202 id: generate_message_id(),
203 role: MessageRole::Tool,
204 content: tool_result.content.clone(),
205 tool_calls: vec![],
206 tool_result: Some(tool_result),
207 timestamp: Utc::now(),
208 tokens: None,
209 agent_name: Some(agent.name().to_string()),
210 metadata: HashMap::new(),
211 };
212 session.add_message(tool_msg.clone());
213 messages.push(tool_msg);
214 }
215
216 continue;
218 }
219
220 final_response = model_response.content.clone();
222
223 let assistant_msg = AgencyMessage {
225 id: generate_message_id(),
226 role: MessageRole::Assistant,
227 content: model_response.content,
228 tool_calls: model_response.tool_calls,
229 tool_result: None,
230 timestamp: Utc::now(),
231 tokens: Some(model_response.usage.completion_tokens),
232 agent_name: Some(agent.name().to_string()),
233 metadata: HashMap::new(),
234 };
235 session.add_message(assistant_msg.clone());
236 messages.push(assistant_msg);
237
238 break;
239 }
240
241 agent.set_status(AgentStatus::Completed);
243 let end_event = AgencyEvent {
244 event_type: EventType::AgentCompleted,
245 agent_name: agent.name().to_string(),
246 data: serde_json::json!({ "response": final_response }),
247 timestamp: Utc::now(),
248 session_id: Some(session.id.clone()),
249 };
250 events.push(end_event.clone());
251 ctx.emit(end_event).await;
252
253 Ok(ExecutionResult {
254 response: final_response,
255 messages,
256 events,
257 token_usage,
258 duration_ms: start_time.elapsed().as_millis() as u64,
259 success: true,
260 error: None,
261 })
262 }
263
264 async fn call_model(&self, agent: &Agent, session: &Session) -> AgencyResult<ModelResponse> {
266 use crate::agency::models::ModelProvider;
267
268 let messages = session.to_api_messages();
269 let tools = agent.tool_definitions();
270 let model_config = agent.model();
271
272 let mut request_body = serde_json::json!({
274 "model": model_config.model,
275 "messages": messages,
276 "temperature": model_config.temperature,
277 });
278
279 if let Some(max_tokens) = model_config.max_tokens {
280 request_body["max_tokens"] = serde_json::json!(max_tokens);
281 }
282
283 if !tools.is_empty() {
284 request_body["tools"] = serde_json::json!(tools);
285 }
286
287 let endpoint = match model_config.provider {
289 ModelProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
291 ModelProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
292 ModelProvider::Google => format!(
293 "https://generativelanguage.googleapis.com/v1/models/{}:generateContent",
294 model_config.model
295 ),
296 ModelProvider::Groq => "https://api.groq.com/openai/v1/chat/completions".to_string(),
297 ModelProvider::Together => "https://api.together.xyz/v1/chat/completions".to_string(),
298 ModelProvider::Fireworks => {
299 "https://api.fireworks.ai/inference/v1/chat/completions".to_string()
300 }
301 ModelProvider::DeepSeek => "https://api.deepseek.com/v1/chat/completions".to_string(),
302 ModelProvider::Mistral => "https://api.mistral.ai/v1/chat/completions".to_string(),
303 ModelProvider::Cohere => "https://api.cohere.ai/v1/chat".to_string(),
304 ModelProvider::Perplexity => "https://api.perplexity.ai/chat/completions".to_string(),
305 ModelProvider::Azure => model_config.endpoint.clone().unwrap_or_default(),
306
307 ModelProvider::Ollama => model_config
309 .endpoint
310 .clone()
311 .unwrap_or_else(|| "http://localhost:11434/api/chat".to_string()),
312 ModelProvider::LMStudio => model_config
313 .endpoint
314 .clone()
315 .unwrap_or_else(|| "http://localhost:1234/v1/chat/completions".to_string()),
316 ModelProvider::Jan => model_config
317 .endpoint
318 .clone()
319 .unwrap_or_else(|| "http://localhost:1337/v1/chat/completions".to_string()),
320 ModelProvider::GPT4All => model_config
321 .endpoint
322 .clone()
323 .unwrap_or_else(|| "http://localhost:4891/v1/chat/completions".to_string()),
324 ModelProvider::LocalAI => model_config
325 .endpoint
326 .clone()
327 .unwrap_or_else(|| "http://localhost:8080/v1/chat/completions".to_string()),
328 ModelProvider::Llamafile => model_config
329 .endpoint
330 .clone()
331 .unwrap_or_else(|| "http://localhost:8080/v1/chat/completions".to_string()),
332 ModelProvider::TextGenWebUI => model_config
333 .endpoint
334 .clone()
335 .unwrap_or_else(|| "http://localhost:5000/v1/chat/completions".to_string()),
336 ModelProvider::VLLM => model_config
337 .endpoint
338 .clone()
339 .unwrap_or_else(|| "http://localhost:8000/v1/chat/completions".to_string()),
340 ModelProvider::KoboldCpp => model_config
341 .endpoint
342 .clone()
343 .unwrap_or_else(|| "http://localhost:5001/v1/chat/completions".to_string()),
344 ModelProvider::TabbyML => model_config
345 .endpoint
346 .clone()
347 .unwrap_or_else(|| "http://localhost:8080/v1/chat/completions".to_string()),
348 ModelProvider::Exo => model_config
349 .endpoint
350 .clone()
351 .unwrap_or_else(|| "http://localhost:52415/v1/chat/completions".to_string()),
352
353 ModelProvider::OpenAICompatible | ModelProvider::Custom => model_config
355 .endpoint
356 .clone()
357 .unwrap_or_else(|| "http://localhost:8080/v1/chat/completions".to_string()),
358 };
359
360 if endpoint.is_empty() {
361 return Err(AgencyError::ConfigError(
362 "No endpoint configured for model provider".to_string(),
363 ));
364 }
365
366 let client = reqwest::Client::new();
368 let mut request = client.post(&endpoint).json(&request_body);
369
370 if let Some(api_key) = &model_config.api_key {
372 match model_config.provider {
373 ModelProvider::Anthropic => {
374 request = request.header("x-api-key", api_key);
375 request = request.header("anthropic-version", "2023-06-01");
376 }
377 ModelProvider::Google => {
378 request = client
380 .post(format!("{}?key={}", endpoint, api_key))
381 .json(&request_body);
382 }
383 _ => {
384 request = request.header("Authorization", format!("Bearer {}", api_key));
385 }
386 }
387 }
388
389 let response = request
390 .send()
391 .await
392 .map_err(|e| AgencyError::NetworkError(format!("HTTP request failed: {}", e)))?;
393
394 if !response.status().is_success() {
395 let status = response.status();
396 let body: String = response.text().await.unwrap_or_default();
397 return Err(AgencyError::ModelError(format!(
398 "Model API error ({}): {}",
399 status, body
400 )));
401 }
402
403 let response_body: serde_json::Value = response
404 .json()
405 .await
406 .map_err(|e| AgencyError::ModelError(format!("Failed to parse response: {}", e)))?;
407
408 let (content, tool_calls, usage) =
410 Self::parse_model_response(&response_body, &model_config.provider)?;
411
412 Ok(ModelResponse {
413 content,
414 tool_calls,
415 usage,
416 })
417 }
418
419 fn parse_model_response(
421 response: &serde_json::Value,
422 provider: &crate::agency::models::ModelProvider,
423 ) -> AgencyResult<(String, Vec<ToolCall>, TokenUsage)> {
424 use crate::agency::models::ModelProvider;
425
426 match provider {
427 ModelProvider::Anthropic => {
428 let content = response["content"][0]["text"]
430 .as_str()
431 .unwrap_or("")
432 .to_string();
433 let usage = TokenUsage::new(
434 response["usage"]["input_tokens"].as_u64().unwrap_or(0) as u32,
435 response["usage"]["output_tokens"].as_u64().unwrap_or(0) as u32,
436 );
437 let mut tool_calls = vec![];
439 if let Some(content_blocks) = response["content"].as_array() {
440 for block in content_blocks {
441 if block["type"].as_str() == Some("tool_use") {
442 tool_calls.push(ToolCall {
443 id: block["id"].as_str().unwrap_or("").to_string(),
444 name: block["name"].as_str().unwrap_or("").to_string(),
445 arguments: block["input"].clone(),
446 timestamp: Utc::now(),
447 });
448 }
449 }
450 }
451 Ok((content, tool_calls, usage))
452 }
453 ModelProvider::Google => {
454 let content = response["candidates"][0]["content"]["parts"][0]["text"]
456 .as_str()
457 .unwrap_or("")
458 .to_string();
459 let usage = TokenUsage::new(
460 response["usageMetadata"]["promptTokenCount"]
461 .as_u64()
462 .unwrap_or(0) as u32,
463 response["usageMetadata"]["candidatesTokenCount"]
464 .as_u64()
465 .unwrap_or(0) as u32,
466 );
467 let mut tool_calls = vec![];
469 if let Some(parts) = response["candidates"][0]["content"]["parts"].as_array() {
470 for part in parts {
471 if let Some(fn_call) = part.get("functionCall") {
472 tool_calls.push(ToolCall {
473 id: uuid::Uuid::new_v4().to_string(),
474 name: fn_call["name"].as_str().unwrap_or("").to_string(),
475 arguments: fn_call["args"].clone(),
476 timestamp: Utc::now(),
477 });
478 }
479 }
480 }
481 Ok((content, tool_calls, usage))
482 }
483 _ => {
484 let choice = &response["choices"][0];
486 let content = choice["message"]["content"]
487 .as_str()
488 .unwrap_or("")
489 .to_string();
490
491 let mut tool_calls = vec![];
492 if let Some(calls) = choice["message"]["tool_calls"].as_array() {
493 for call in calls {
494 tool_calls.push(ToolCall {
495 id: call["id"].as_str().unwrap_or("").to_string(),
496 name: call["function"]["name"].as_str().unwrap_or("").to_string(),
497 arguments: serde_json::from_str(
498 call["function"]["arguments"].as_str().unwrap_or("{}"),
499 )
500 .unwrap_or_default(),
501 timestamp: Utc::now(),
502 });
503 }
504 }
505
506 let usage = TokenUsage::new(
507 response["usage"]["prompt_tokens"].as_u64().unwrap_or(0) as u32,
508 response["usage"]["completion_tokens"].as_u64().unwrap_or(0) as u32,
509 );
510
511 Ok((content, tool_calls, usage))
512 }
513 }
514 }
515
516 async fn execute_tool(&self, tool_call: &ToolCall) -> ToolResult {
518 let start = std::time::Instant::now();
519
520 if let Some(executor) = self.tool_registry.get_executor(&tool_call.name) {
522 match executor.execute(tool_call.arguments.clone()).await {
523 Ok(result) => result,
524 Err(e) => ToolResult {
525 call_id: tool_call.id.clone(),
526 name: tool_call.name.clone(),
527 success: false,
528 content: format!("Tool execution failed: {}", e),
529 duration_ms: start.elapsed().as_millis() as u64,
530 data: None,
531 },
532 }
533 } else {
534 ToolResult {
536 call_id: tool_call.id.clone(),
537 name: tool_call.name.clone(),
538 success: false,
539 content: format!("Tool '{}' not found in registry", tool_call.name),
540 duration_ms: start.elapsed().as_millis() as u64,
541 data: None,
542 }
543 }
544 }
545}
546
/// Parsed output of a single model API call, internal to this module.
struct ModelResponse {
    /// Assistant text content (parsing defaults to "" when absent).
    content: String,
    /// Tool invocations requested by the model, if any.
    tool_calls: Vec<ToolCall>,
    /// Prompt/completion token counts reported by the provider.
    usage: TokenUsage,
}
553
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agency::agent::AgentBuilder;

    /// End-to-end smoke test of `Executor::execute`; ignored by default
    /// because it performs a real model API call and needs credentials.
    #[tokio::test]
    #[ignore = "Integration test - requires API credentials"]
    async fn test_executor() {
        let tool_registry = Arc::new(ToolRegistry::new());
        let executor = Executor::new(tool_registry);

        // `execute` only needs `&Agent`, so the binding does not have to be
        // mutable (the previous `let mut agent` / `&mut agent` triggered an
        // unused-mut warning).
        let agent = AgentBuilder::new("test_agent")
            .description("Test agent")
            .instruction("You are a helpful assistant.")
            .model("gemini-2.5-flash")
            .build();

        let mut session = Session::new("test_agent", None);
        let mut ctx = ExecutionContext::new(&session);

        let result = executor
            .execute(&agent, &mut session, "Hello!", &mut ctx)
            .await
            .unwrap();

        assert!(result.success);
        assert!(!result.response.is_empty());
        assert!(!result.messages.is_empty());
    }
}