Skip to main content

cortexai_mcp/
crew_handler.rs

//! Crew-as-MCP-Tool
//!
//! Exposes cortex crew workflows as MCP tools, allowing external MCP clients
//! to invoke entire multi-agent crew executions as a single tool call.

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::fmt::Write as _;
use std::sync::Arc;
use tracing::{debug, info};

use crate::error::McpError;
use crate::protocol::{CallToolResult, McpTool, ToolContent};
use crate::server::ToolHandler;
16
17// =============================================================================
18// Input / Output types
19// =============================================================================
20
21/// Input schema for crew MCP tools
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CrewMcpInput {
24    /// The main task description for the crew
25    pub task: String,
26    /// Additional context as key-value pairs
27    #[serde(default)]
28    pub context: HashMap<String, String>,
29    /// Execution mode: "sequential", "parallel", or "hierarchical"
30    #[serde(default)]
31    pub mode: Option<String>,
32    /// Maximum number of iterations for the crew execution
33    #[serde(default)]
34    pub max_iterations: Option<u32>,
35}
36
37/// Per-task outcome within a crew execution
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct TaskResult {
40    /// Name or identifier of the task
41    pub name: String,
42    /// Output produced by this task
43    pub output: String,
44    /// Whether this task succeeded
45    pub success: bool,
46}
47
48/// Output structure for crew responses
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct CrewMcpOutput {
51    /// Final aggregated result
52    pub result: String,
53    /// Per-task outcomes
54    pub task_results: Vec<TaskResult>,
55    /// Total execution duration in milliseconds
56    pub duration_ms: u64,
57    /// Names of agents that participated
58    pub agents_used: Vec<String>,
59}

// =============================================================================
// Configuration
// =============================================================================

/// Presentation settings carried by a crew MCP handler and by its builder.
#[derive(Clone, Debug)]
pub struct CrewMcpConfig {
    /// Text placed in front of the crew name to form the MCP tool name (default `"crew_"`).
    pub name_prefix: String,
    /// When `true` (the default), the text response lists each task's outcome.
    pub include_task_results: bool,
}

impl Default for CrewMcpConfig {
    /// Prefix tool names with `crew_` and include per-task results.
    fn default() -> Self {
        let name_prefix = String::from("crew_");
        Self {
            name_prefix,
            include_task_results: true,
        }
    }
}
82
83// =============================================================================
84// Handler
85// =============================================================================
86
87/// Handler type for crew execution
88pub type CrewHandlerFn = Arc<
89    dyn Fn(
90            CrewMcpInput,
91        ) -> std::pin::Pin<
92            Box<dyn std::future::Future<Output = Result<CrewMcpOutput, String>> + Send>,
93        > + Send
94        + Sync,
95>;
96
97/// MCP ToolHandler that wraps a crew workflow
98pub struct CrewMcpHandler {
99    name: String,
100    description: String,
101    capabilities: Vec<String>,
102    handler: CrewHandlerFn,
103    config: CrewMcpConfig,
104}
105
106impl CrewMcpHandler {
107    /// Create a builder for fluent construction
108    pub fn builder(name: impl Into<String>) -> CrewMcpHandlerBuilder {
109        CrewMcpHandlerBuilder::new(name)
110    }
111
112    /// Get the tool name
113    pub fn name(&self) -> &str {
114        &self.name
115    }
116
117    /// Get the capabilities
118    pub fn capabilities(&self) -> &[String] {
119        &self.capabilities
120    }
121}
122
123#[async_trait]
124impl ToolHandler for CrewMcpHandler {
125    fn definition(&self) -> McpTool {
126        let schema = json!({
127            "type": "object",
128            "properties": {
129                "task": {
130                    "type": "string",
131                    "description": "The main task description for the crew"
132                },
133                "context": {
134                    "type": "object",
135                    "description": "Additional context as key-value pairs",
136                    "additionalProperties": { "type": "string" }
137                },
138                "mode": {
139                    "type": "string",
140                    "description": "Execution mode: sequential, parallel, or hierarchical",
141                    "enum": ["sequential", "parallel", "hierarchical"]
142                },
143                "max_iterations": {
144                    "type": "integer",
145                    "description": "Maximum number of iterations for crew execution"
146                }
147            },
148            "required": ["task"]
149        });
150
151        let description = if self.capabilities.is_empty() {
152            self.description.clone()
153        } else {
154            format!(
155                "{}\n\nCapabilities: {}",
156                self.description,
157                self.capabilities.join(", ")
158            )
159        };
160
161        McpTool {
162            name: self.name.clone(),
163            description: Some(description),
164            input_schema: schema,
165        }
166    }
167
168    async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError> {
169        debug!(tool = %self.name, "Executing crew MCP handler");
170
171        let input: CrewMcpInput = serde_json::from_value(arguments)
172            .map_err(|e| McpError::InvalidParams(format!("Invalid input: {}", e)))?;
173
174        info!(
175            tool = %self.name,
176            task = %input.task,
177            mode = ?input.mode,
178            "Crew executing task"
179        );
180
181        let result = (self.handler)(input).await;
182
183        match result {
184            Ok(output) => {
185                let response_text = build_success_response(&output, &self.config);
186
187                let structured = json!({
188                    "duration_ms": output.duration_ms,
189                    "agents_used": output.agents_used,
190                    "task_count": output.task_results.len(),
191                });
192
193                Ok(CallToolResult {
194                    content: vec![
195                        ToolContent::text(response_text),
196                        ToolContent::text(format!(
197                            "\n---\nStructured output: {}",
198                            serde_json::to_string_pretty(&structured).unwrap_or_default()
199                        )),
200                    ],
201                    is_error: false,
202                })
203            }
204            Err(e) => Ok(CallToolResult {
205                content: vec![ToolContent::text(format!("Crew error: {}", e))],
206                is_error: true,
207            }),
208        }
209    }
210}
211
212/// Build the human-readable success response text
213fn build_success_response(output: &CrewMcpOutput, config: &CrewMcpConfig) -> String {
214    let mut parts = vec![output.result.clone()];
215
216    if config.include_task_results && !output.task_results.is_empty() {
217        let tasks_str = output
218            .task_results
219            .iter()
220            .map(|t| format!("  - {} [{}]: {}", t.name, if t.success { "OK" } else { "FAIL" }, t.output))
221            .collect::<Vec<_>>()
222            .join("\n");
223        parts.push(format!("\n\nTask results:\n{}", tasks_str));
224    }
225
226    if !output.agents_used.is_empty() {
227        parts.push(format!("\n\nAgents used: {}", output.agents_used.join(", ")));
228    }
229
230    parts.join("")
231}
232
233// =============================================================================
234// Builder
235// =============================================================================
236
237/// Builder for CrewMcpHandler
238pub struct CrewMcpHandlerBuilder {
239    name: String,
240    description: String,
241    capabilities: Vec<String>,
242    config: CrewMcpConfig,
243}
244
245impl CrewMcpHandlerBuilder {
246    pub fn new(name: impl Into<String>) -> Self {
247        Self {
248            name: name.into(),
249            description: String::new(),
250            capabilities: Vec::new(),
251            config: CrewMcpConfig::default(),
252        }
253    }
254
255    pub fn description(mut self, description: impl Into<String>) -> Self {
256        self.description = description.into();
257        self
258    }
259
260    pub fn capability(mut self, capability: impl Into<String>) -> Self {
261        self.capabilities.push(capability.into());
262        self
263    }
264
265    pub fn capabilities(mut self, capabilities: Vec<String>) -> Self {
266        self.capabilities.extend(capabilities);
267        self
268    }
269
270    pub fn name_prefix(mut self, prefix: impl Into<String>) -> Self {
271        self.config.name_prefix = prefix.into();
272        self
273    }
274
275    pub fn include_task_results(mut self, include: bool) -> Self {
276        self.config.include_task_results = include;
277        self
278    }
279
280    pub fn config(mut self, config: CrewMcpConfig) -> Self {
281        self.config = config;
282        self
283    }
284
285    /// Build with a handler function
286    pub fn handler<F, Fut>(self, handler: F) -> CrewMcpHandler
287    where
288        F: Fn(CrewMcpInput) -> Fut + Send + Sync + 'static,
289        Fut: std::future::Future<Output = Result<CrewMcpOutput, String>> + Send + 'static,
290    {
291        let tool_name = format!("{}{}", self.config.name_prefix, self.name);
292
293        CrewMcpHandler {
294            name: tool_name,
295            description: self.description,
296            capabilities: self.capabilities,
297            handler: Arc::new(move |input| Box::pin(handler(input))),
298            config: self.config,
299        }
300    }
301}

#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Shorthand for a `TaskResult`.
    fn task(name: &str, output: &str, success: bool) -> TaskResult {
        TaskResult {
            name: name.to_string(),
            output: output.to_string(),
            success,
        }
    }

    /// Shorthand for a `CrewMcpOutput` with a zero duration.
    fn crew_output(result: &str, task_results: Vec<TaskResult>, agents_used: &[&str]) -> CrewMcpOutput {
        CrewMcpOutput {
            result: result.to_string(),
            task_results,
            duration_ms: 0,
            agents_used: agents_used.iter().map(|agent| agent.to_string()).collect(),
        }
    }

    #[test]
    fn test_crew_mcp_input_full_deserialization() {
        let json_val = json!({
            "task": "Research quantum computing advances",
            "context": {"domain": "physics", "depth": "detailed"},
            "mode": "parallel",
            "max_iterations": 5
        });

        let input: CrewMcpInput = serde_json::from_value(json_val).unwrap();
        assert_eq!(input.task, "Research quantum computing advances");
        assert_eq!(input.context.get("domain").unwrap(), "physics");
        assert_eq!(input.context.get("depth").unwrap(), "detailed");
        assert_eq!(input.mode, Some("parallel".to_string()));
        assert_eq!(input.max_iterations, Some(5));
    }

    #[test]
    fn test_crew_handler_definition_and_schema() {
        let handler = CrewMcpHandler::builder("research")
            .description("Research crew workflow")
            .capability("web_search")
            .capability("summarization")
            .handler(|_input: CrewMcpInput| async move { Ok(crew_output("done", Vec::new(), &[])) });

        let def = handler.definition();
        assert_eq!(def.name, "crew_research");
        let desc = def.description.unwrap();
        assert!(desc.contains("Research crew workflow"));
        assert!(desc.contains("web_search"));
        assert!(desc.contains("summarization"));

        // Verify schema has the right properties
        let schema = &def.input_schema;
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["task"].is_object());
        assert!(schema["properties"]["context"].is_object());
        assert!(schema["properties"]["mode"].is_object());
        assert!(schema["properties"]["max_iterations"].is_object());
        assert_eq!(schema["required"][0], "task");
    }

    #[test]
    fn test_crew_handler_description_without_capabilities() {
        let handler = CrewMcpHandler::builder("plain")
            .description("Plain crew")
            .handler(|_input: CrewMcpInput| async move { Ok(crew_output("done", Vec::new(), &[])) });

        // With no capabilities the description is passed through verbatim.
        let def = handler.definition();
        assert_eq!(def.description.as_deref(), Some("Plain crew"));
        assert!(handler.capabilities().is_empty());
    }

    #[test]
    fn test_crew_handler_custom_prefix() {
        let handler = CrewMcpHandler::builder("analysis")
            .description("Analysis crew")
            .name_prefix("workflow_")
            .handler(|_input: CrewMcpInput| async move { Ok(crew_output("done", Vec::new(), &[])) });

        let def = handler.definition();
        assert_eq!(def.name, "workflow_analysis");
    }

    #[tokio::test]
    async fn test_crew_handler_execution_with_mock() {
        let handler = CrewMcpHandler::builder("research")
            .description("Research crew")
            .handler(|input: CrewMcpInput| async move {
                let topic = input.context.get("topic").cloned().unwrap_or_default();
                Ok(CrewMcpOutput {
                    duration_ms: 1500,
                    ..crew_output(
                        &format!("Researched: {} - {}", input.task, topic),
                        vec![
                            task("gather", "Gathered data", true),
                            task("analyze", "Analysis complete", true),
                        ],
                        &["researcher", "analyst"],
                    )
                })
            });

        let result = handler
            .execute(json!({
                "task": "Find trends",
                "context": {"topic": "AI"},
                "mode": "sequential"
            }))
            .await
            .unwrap();

        assert!(!result.is_error);

        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Researched: Find trends - AI"));
        assert!(text.contains("gather [OK]"));
        assert!(text.contains("analyze [OK]"));
        assert!(text.contains("researcher"));
        assert!(text.contains("analyst"));

        // Verify structured output
        let structured_text = result.content[1].as_text().unwrap();
        assert!(structured_text.contains("duration_ms"));
        assert!(structured_text.contains("1500"));
    }

    #[tokio::test]
    async fn test_crew_handler_reports_failed_tasks() {
        let handler = CrewMcpHandler::builder("mixed").handler(|_input: CrewMcpInput| async move {
            Ok(crew_output(
                "partial",
                vec![task("fetch", "ok", true), task("parse", "bad json", false)],
                &[],
            ))
        });

        let result = handler.execute(json!({"task": "t"})).await.unwrap();
        // A failed task inside a successful crew run is not a tool-level error.
        assert!(!result.is_error);

        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("fetch [OK]: ok"));
        assert!(text.contains("parse [FAIL]: bad json"));
        assert!(!text.contains("Agents used"));
    }

    #[tokio::test]
    async fn test_crew_handler_can_hide_task_results() {
        let handler = CrewMcpHandler::builder("quiet")
            .include_task_results(false)
            .handler(|_input: CrewMcpInput| async move {
                Ok(crew_output("done", vec![task("gather", "data", true)], &["researcher"]))
            });

        let result = handler.execute(json!({"task": "t"})).await.unwrap();

        let text = result.content[0].as_text().unwrap();
        assert!(text.starts_with("done"));
        assert!(!text.contains("Task results"));
        assert!(text.contains("Agents used: researcher"));

        // The task count is still reported in the structured part.
        let structured_text = result.content[1].as_text().unwrap();
        assert!(structured_text.contains("\"task_count\": 1"));
    }

    #[tokio::test]
    async fn test_crew_handler_error_returns_is_error() {
        let handler = CrewMcpHandler::builder("failing_crew")
            .description("A crew that fails")
            .handler(|_: CrewMcpInput| async move {
                Err("Agent timeout: researcher did not respond".to_string())
            });

        let result = handler
            .execute(json!({"task": "do something"}))
            .await
            .unwrap();

        assert!(result.is_error);
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Crew error"));
        assert!(text.contains("Agent timeout"));
    }

    #[tokio::test]
    async fn test_crew_handler_invalid_input_returns_error() {
        let handler = CrewMcpHandler::builder("strict_crew")
            .description("Crew with strict input")
            .handler(|_: CrewMcpInput| async move { Ok(crew_output("ok", Vec::new(), &[])) });

        // Missing required "task" field
        let result = handler.execute(json!({"context": {"a": "b"}})).await;
        assert!(matches!(result, Err(McpError::InvalidParams(_))));
    }

    #[test]
    fn test_crew_mcp_input_minimal_deserialization() {
        let json_val = json!({"task": "simple task"});
        let input: CrewMcpInput = serde_json::from_value(json_val).unwrap();

        assert_eq!(input.task, "simple task");
        assert!(input.context.is_empty());
        assert!(input.mode.is_none());
        assert!(input.max_iterations.is_none());
    }
}
472}