#[cfg(feature = "engine")]
use std::sync::Arc;
#[cfg(feature = "engine")]
use std::time::Duration;
#[cfg(feature = "engine")]
use async_trait::async_trait;
#[cfg(feature = "engine")]
use cortexai_agents::{AgentEngine, ExecutionTrace};
#[cfg(feature = "engine")]
use cortexai_core::tool::ToolRegistry;
#[cfg(feature = "engine")]
use cortexai_core::types::{AgentConfig, AgentRole, Task};
#[cfg(feature = "engine")]
use cortexai_crew::{Crew, CrewConfig};
#[cfg(feature = "engine")]
use cortexai_crew::process::Process;
#[cfg(feature = "engine")]
use cortexai_providers::LLMBackend;
#[cfg(feature = "engine")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "engine")]
use serde_json::json;
#[cfg(feature = "engine")]
use tracing::info;
#[cfg(feature = "engine")]
use crate::error::McpError;
#[cfg(feature = "engine")]
use crate::protocol::{CallToolResult, McpTool, ToolContent};
#[cfg(feature = "engine")]
use crate::server::ToolHandler;
/// Crew execution timeout, in seconds, that `CrewEngineHandler::new` installs
/// unless the caller overrides it with `with_timeout`.
#[cfg(feature = "engine")]
const DEFAULT_CREW_TIMEOUT_SECS: u64 = 60;
/// One task entry of a `run_crew` request.
///
/// A dedicated agent is spawned per entry; see `CrewEngineHandler::run_crew`.
#[cfg(feature = "engine")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrewTaskInput {
/// Text handed to `Task::new` as the task description.
pub description: String,
/// Role label; used both as the agent's name and as its `AgentRole::Custom`
/// value, and interpolated into the agent's system prompt.
pub agent_role: String,
/// Optional description of the expected output. When absent, `run_crew`
/// substitutes the literal "Complete the task".
#[serde(default)]
pub expected_output: Option<String>,
}
/// Deserialized arguments of the `run_crew` MCP tool.
#[cfg(feature = "engine")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunCrewInput {
/// Name given to the crew; also echoed in the formatted result text.
pub crew_name: String,
/// Tasks to run; one agent is created for each entry.
pub tasks: Vec<CrewTaskInput>,
/// Execution process name ("sequential", "parallel" or "hierarchical",
/// matched case-insensitively). Missing or unrecognized values are treated
/// as sequential by `CrewEngineHandler::parse_process` / `run_crew`.
#[serde(default)]
pub process: Option<String>,
/// Per-agent iteration cap; `run_crew` falls back to 10 when absent.
#[serde(default)]
pub max_iterations: Option<usize>,
/// When true, a second text content block holding the execution trace as
/// pretty-printed JSON is appended to a successful response.
#[serde(default)]
pub trace: bool,
}
/// MCP tool handler that exposes crew execution as the `run_crew` tool.
#[cfg(feature = "engine")]
pub struct CrewEngineHandler {
// Engine on which per-task agents are spawned and later stopped.
engine: Arc<AgentEngine>,
// LLM backend passed to every spawned agent.
backend: Arc<dyn LLMBackend>,
// Tool registry passed to every spawned agent.
tool_registry: Arc<ToolRegistry>,
// Upper bound on `Crew::kickoff`; its whole-second value is also passed to
// each agent's `with_timeout`.
timeout: Duration,
}
#[cfg(feature = "engine")]
impl CrewEngineHandler {
    /// Creates a handler that spawns agents on `engine`, giving each one
    /// `backend` and `tool_registry`, with the default crew timeout of
    /// `DEFAULT_CREW_TIMEOUT_SECS` seconds.
    pub fn new(
        engine: Arc<AgentEngine>,
        backend: Arc<dyn LLMBackend>,
        tool_registry: Arc<ToolRegistry>,
    ) -> Self {
        Self {
            engine,
            backend,
            tool_registry,
            timeout: Duration::from_secs(DEFAULT_CREW_TIMEOUT_SECS),
        }
    }

    /// Builder-style override of the crew execution timeout.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Maps a process name to a [`Process`], ignoring case.
    ///
    /// Anything other than "parallel" or "hierarchical" (including unknown
    /// names) deliberately falls back to [`Process::Sequential`].
    fn parse_process(process_str: &str) -> Process {
        match process_str.to_lowercase().as_str() {
            "parallel" => Process::Parallel,
            "hierarchical" => Process::Hierarchical,
            _ => Process::Sequential,
        }
    }

    /// Builds a crew from `input`, runs it under the handler timeout and
    /// renders the outcome as a tool result.
    ///
    /// One agent is spawned per task. Every agent that was successfully
    /// spawned is stopped again before this function returns, whether the
    /// crew finished, failed, timed out, or could not be fully set up.
    ///
    /// # Errors
    ///
    /// Returns `McpError::Internal` when an agent cannot be spawned or a task
    /// cannot be added to the crew. Crew execution failures and timeouts are
    /// not errors at this level: they are reported as an `Ok` result with
    /// `is_error: true`.
    async fn run_crew(&self, input: RunCrewInput) -> Result<CallToolResult, McpError> {
        // NOTE(review): the collector is only created and finalized here; it
        // is not handed to the crew or the engine in this function. Confirm
        // how `ExecutionTrace` is expected to capture events.
        let trace_collector = input.trace.then(|| Arc::new(ExecutionTrace::new()));
        let start = std::time::Instant::now();

        let process = input
            .process
            .as_deref()
            .map(Self::parse_process)
            .unwrap_or(Process::Sequential);
        let crew_config = CrewConfig::new(&input.crew_name)
            .with_description(format!("MCP crew: {}", input.crew_name))
            .with_process(process);
        let mut crew = Crew::new(crew_config, self.engine.clone());

        // Same for every task, so resolve the default once.
        let max_iter = input.max_iterations.unwrap_or(10);
        let mut agent_ids = Vec::with_capacity(input.tasks.len());
        // Setup problems are remembered as a message rather than returned on
        // the spot, so the cleanup below still runs for agents already spawned.
        let mut setup_failure: Option<String> = None;

        for task_input in &input.tasks {
            // NOTE(review): the agent timeout is given in whole seconds, so a
            // sub-second handler timeout becomes 0 here — confirm that
            // `AgentConfig::with_timeout` tolerates that.
            let agent_config = AgentConfig::new(
                &task_input.agent_role,
                AgentRole::Custom(task_input.agent_role.clone()),
            )
            .with_system_prompt(format!(
                "You are a {} agent. Complete the assigned task thoroughly.",
                task_input.agent_role
            ))
            .with_max_iterations(max_iter)
            .with_timeout(self.timeout.as_secs());
            let agent_id = agent_config.id.clone();

            if let Err(e) = self
                .engine
                .spawn_agent(
                    agent_config.clone(),
                    self.tool_registry.clone(),
                    self.backend.clone(),
                )
                .await
            {
                setup_failure = Some(format!("Failed to spawn agent: {}", e));
                break;
            }
            // Record the id as soon as the agent exists so it is stopped even
            // if adding its task fails.
            agent_ids.push(agent_id.clone());
            crew.add_agent(agent_config);

            let task = Task::new(&task_input.description)
                .with_agent(agent_id)
                .with_expected_output(
                    task_input
                        .expected_output
                        .as_deref()
                        .unwrap_or("Complete the task"),
                );
            if let Err(e) = crew.add_task(task) {
                setup_failure = Some(format!("Failed to add task: {}", e));
                break;
            }
        }

        // Only start the crew when setup succeeded.
        let outcome: Result<_, String> = match setup_failure {
            Some(message) => Err(message),
            None => Ok(tokio::time::timeout(self.timeout, crew.kickoff()).await),
        };

        // Best-effort cleanup on every path; a failure to stop one agent must
        // not mask the crew outcome or prevent stopping the others.
        for agent_id in &agent_ids {
            let _ = self.engine.stop_agent(agent_id).await;
        }

        let crew_result = outcome.map_err(McpError::Internal)?;
        // Saturate instead of silently truncating the u128 millisecond count.
        let elapsed_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

        let tool_result = match crew_result {
            Ok(Ok(task_results)) => {
                let result_text = format_crew_results(&input, &task_results, elapsed_ms);
                let mut content = vec![ToolContent::text(result_text.clone())];
                if let Some(trace) = trace_collector {
                    let finalized = trace.finalize(result_text);
                    content.push(ToolContent::text(
                        serde_json::to_string_pretty(&finalized.to_json())
                            .unwrap_or_else(|_| "{}".to_string()),
                    ));
                }
                CallToolResult {
                    content,
                    is_error: false,
                }
            }
            // The crew ran but reported a failure.
            Ok(Err(e)) => CallToolResult {
                content: vec![ToolContent::text(format!("Crew execution error: {}", e))],
                is_error: true,
            },
            // `tokio::time::timeout` elapsed before `kickoff` completed.
            Err(_) => CallToolResult {
                content: vec![ToolContent::text(format!(
                    "Crew execution timed out after {}s",
                    self.timeout.as_secs()
                ))],
                is_error: true,
            },
        };
        Ok(tool_result)
    }
}
/// Renders crew results as plain text: a header line naming the crew and the
/// elapsed milliseconds, followed by one entry per task result.
///
/// Each result is paired with the task description at the same index in
/// `input.tasks` ("unknown" when there is none). Successful results show the
/// pretty-printed JSON output; failed ones show the error text, or an empty
/// string when no error is recorded.
#[cfg(feature = "engine")]
fn format_crew_results(
    input: &RunCrewInput,
    task_results: &[cortexai_core::types::TaskResult],
    elapsed_ms: u64,
) -> String {
    let header = format!("Crew '{}' completed in {}ms\n", input.crew_name, elapsed_ms);

    task_results
        .iter()
        .enumerate()
        .fold(header, |mut report, (idx, outcome)| {
            let description = match input.tasks.get(idx) {
                Some(task) => task.description.as_str(),
                None => "unknown",
            };
            let (status, detail) = if outcome.success {
                (
                    "OK",
                    serde_json::to_string_pretty(&outcome.output).unwrap_or_default(),
                )
            } else {
                ("FAIL", outcome.error.clone().unwrap_or_default())
            };
            report.push_str(&format!(
                "Task {} [{}]: {}\n {}\n",
                idx + 1,
                status,
                description,
                detail
            ));
            report
        })
}
#[cfg(feature = "engine")]
#[async_trait]
impl ToolHandler for CrewEngineHandler {
    /// Returns the MCP definition of the `run_crew` tool, including the JSON
    /// schema of its arguments (`crew_name` and `tasks` are required).
    fn definition(&self) -> McpTool {
        let summary = "Run a Cortex multi-agent crew. \
                       Creates agents for each task, executes tasks according to \
                       the specified process, and returns per-task results.";

        // Schema of a single element of the `tasks` array.
        let task_schema = json!({
            "type": "object",
            "properties": {
                "description": {
                    "type": "string",
                    "description": "What this task should accomplish"
                },
                "agent_role": {
                    "type": "string",
                    "description": "Role of the agent for this task"
                },
                "expected_output": {
                    "type": "string",
                    "description": "Optional expected output description"
                }
            },
            "required": ["description", "agent_role"]
        });

        let input_schema = json!({
            "type": "object",
            "properties": {
                "crew_name": {
                    "type": "string",
                    "description": "Name for the crew"
                },
                "tasks": {
                    "type": "array",
                    "items": task_schema,
                    "description": "Tasks to execute"
                },
                "process": {
                    "type": "string",
                    "enum": ["sequential", "parallel", "hierarchical"],
                    "description": "Execution process (default: sequential)"
                },
                "max_iterations": {
                    "type": "integer",
                    "description": "Max iterations per agent"
                },
                "trace": {
                    "type": "boolean",
                    "description": "When true, include an execution trace in the response (default false)"
                }
            },
            "required": ["crew_name", "tasks"]
        });

        McpTool {
            name: String::from("run_crew"),
            description: Some(String::from(summary)),
            input_schema,
        }
    }

    /// Deserializes `arguments` into a [`RunCrewInput`] and runs the crew.
    ///
    /// # Errors
    ///
    /// Returns `McpError::InvalidParams` when the arguments do not match the
    /// expected shape; otherwise forwards whatever `run_crew` returns.
    async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError> {
        let request = serde_json::from_value::<RunCrewInput>(arguments)
            .map_err(|err| McpError::InvalidParams(format!("Invalid input: {}", err)))?;

        info!(
            crew_name = %request.crew_name,
            task_count = request.tasks.len(),
            "Running crew via MCP"
        );

        self.run_crew(request).await
    }
}
#[cfg(all(test, feature = "engine"))]
mod tests {
    use std::sync::Arc;

    use cortexai_agents::AgentEngine;
    use cortexai_core::tool::ToolRegistry;
    use cortexai_providers::{LLMBackend, MockBackend, MockResponse};
    use serde_json::json;

    use crate::server::ToolHandler;

    use super::*;

    /// Timeout used by the tests that actually execute a crew.
    const RUN_TIMEOUT: Duration = Duration::from_secs(30);

    /// Wires `backend` into a handler backed by a fresh engine and an empty
    /// tool registry. The engine is returned as well so tests can inspect it.
    fn handler_for(backend: Arc<dyn LLMBackend>) -> (Arc<AgentEngine>, CrewEngineHandler) {
        let engine = Arc::new(AgentEngine::new());
        let handler =
            CrewEngineHandler::new(engine.clone(), backend, Arc::new(ToolRegistry::new()));
        (engine, handler)
    }

    // The tool definition exposes the expected name, a description and the
    // main schema properties.
    #[tokio::test]
    async fn test_crew_engine_handler_definition() {
        let (_engine, handler) = handler_for(Arc::new(
            MockBackend::new().with_response(MockResponse::text("done")),
        ));

        let tool = handler.definition();
        assert_eq!(tool.name, "run_crew");
        assert!(tool.description.is_some());

        let properties = &tool.input_schema["properties"];
        assert!(properties["crew_name"].is_object());
        assert!(properties["tasks"].is_object());
        assert!(properties["process"].is_object());
    }

    // With `trace: false` only the result text block is returned.
    #[tokio::test]
    async fn test_crew_engine_handler_trace_false_returns_normal_response() {
        let (_engine, handler) = handler_for(Arc::new(
            MockBackend::new().with_response(MockResponse::text("Result 1").with_latency(200)),
        ));
        let handler = handler.with_timeout(RUN_TIMEOUT);

        let arguments = json!({
            "crew_name": "test-crew",
            "tasks": [{"description": "Do task", "agent_role": "worker"}],
            "trace": false
        });
        let response = handler.execute(arguments).await.unwrap();

        assert_eq!(response.content.len(), 1);
    }

    // With `trace: true` a second block carrying the trace JSON is appended.
    #[tokio::test]
    async fn test_crew_engine_handler_trace_true_returns_response_and_trace() {
        let (_engine, handler) = handler_for(Arc::new(
            MockBackend::new().with_response(MockResponse::text("Task result").with_latency(200)),
        ));
        let handler = handler.with_timeout(RUN_TIMEOUT);

        let arguments = json!({
            "crew_name": "traced-crew",
            "tasks": [{"description": "Do task", "agent_role": "worker"}],
            "trace": true
        });
        let response = handler.execute(arguments).await.unwrap();

        assert_eq!(
            response.content.len(),
            2,
            "Expected 2 content blocks (response + trace)"
        );
        let raw_trace = response.content[1].as_text().unwrap();
        let parsed_trace: serde_json::Value =
            serde_json::from_str(raw_trace).expect("trace content should be valid JSON");
        assert!(parsed_trace["trace_id"].is_string());
        assert!(parsed_trace["total_duration_ms"].is_u64());
    }

    // A two-task sequential crew reports its name and leaves no agents behind.
    #[tokio::test]
    async fn test_crew_engine_handler_runs_two_tasks() {
        let (engine, handler) = handler_for(Arc::new(
            MockBackend::new()
                .with_response(MockResponse::text("Research result").with_latency(200))
                .with_response(MockResponse::text("Analysis result").with_latency(200)),
        ));
        let handler = handler.with_timeout(RUN_TIMEOUT);

        let arguments = json!({
            "crew_name": "test-crew",
            "tasks": [
                {
                    "description": "Research the topic",
                    "agent_role": "researcher"
                },
                {
                    "description": "Analyze findings",
                    "agent_role": "analyst"
                }
            ],
            "process": "sequential"
        });
        let response = handler.execute(arguments).await.unwrap();

        let body = response.content[0].as_text().unwrap();
        assert!(
            body.contains("test-crew"),
            "Expected crew name in output, got: {}",
            body
        );
        assert_eq!(engine.agent_count(), 0);
    }
}