intent_engine/mcp/
server.rs

1//! Intent-Engine MCP Server (Rust Implementation)
2//!
3//! This is a native Rust implementation of the MCP (Model Context Protocol) server
4//! that provides a JSON-RPC 2.0 interface for AI assistants to interact with
5//! intent-engine's task management capabilities.
6//!
7//! Unlike the Python wrapper (mcp-server.py), this implementation directly uses
8//! the Rust library functions, avoiding subprocess overhead and improving performance.
9
10use crate::events::EventManager;
11use crate::project::ProjectContext;
12use crate::report::ReportManager;
13use crate::tasks::TaskManager;
14use crate::workspace::WorkspaceManager;
15use serde::{Deserialize, Serialize};
16use serde_json::{json, Value};
17use std::io::{self, BufRead, Write};
18
19#[derive(Debug, Deserialize)]
20struct JsonRpcRequest {
21    jsonrpc: String,
22    id: Option<Value>,
23    method: String,
24    params: Option<Value>,
25}
26
27#[derive(Debug, Serialize)]
28struct JsonRpcResponse {
29    jsonrpc: String,
30    id: Option<Value>,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    result: Option<Value>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    error: Option<JsonRpcError>,
35}
36
37#[derive(Debug, Serialize)]
38struct JsonRpcError {
39    code: i32,
40    message: String,
41}
42
43#[derive(Debug, Deserialize)]
44struct ToolCallParams {
45    name: String,
46    arguments: Value,
47}
48
/// MCP tool schema, embedded at compile time from `mcp-server.json`.
///
/// `handle_tools_list` parses this text as JSON and returns its `tools` entry.
const MCP_TOOLS: &str = include_str!("../../mcp-server.json");
51
52/// Run the MCP server
53/// This is the main entry point for MCP server mode
54pub async fn run() -> io::Result<()> {
55    run_server().await
56}
57
58async fn run_server() -> io::Result<()> {
59    let stdin = io::stdin();
60    let mut stdout = io::stdout();
61    let reader = stdin.lock();
62
63    for line in reader.lines() {
64        let line = line?;
65        if line.trim().is_empty() {
66            continue;
67        }
68
69        let response = match serde_json::from_str::<JsonRpcRequest>(&line) {
70            Ok(request) => {
71                // Handle notifications (no id = no response needed)
72                if request.id.is_none() {
73                    handle_notification(&request).await;
74                    continue; // Skip sending response for notifications
75                }
76                handle_request(request).await
77            },
78            Err(e) => JsonRpcResponse {
79                jsonrpc: "2.0".to_string(),
80                id: None,
81                result: None,
82                error: Some(JsonRpcError {
83                    code: -32700,
84                    message: format!("Parse error: {}", e),
85                }),
86            },
87        };
88
89        let response_json = serde_json::to_string(&response)?;
90        writeln!(stdout, "{}", response_json)?;
91        stdout.flush()?;
92    }
93
94    Ok(())
95}
96
97async fn handle_notification(request: &JsonRpcRequest) {
98    // Handle MCP notifications (no response required)
99    match request.method.as_str() {
100        "initialized" => {
101            eprintln!("✓ MCP client initialized");
102        },
103        "notifications/cancelled" => {
104            eprintln!("⚠ Request cancelled");
105        },
106        _ => {
107            eprintln!("⚠ Unknown notification: {}", request.method);
108        },
109    }
110}
111
112async fn handle_request(request: JsonRpcRequest) -> JsonRpcResponse {
113    // Validate JSON-RPC version
114    if request.jsonrpc != "2.0" {
115        return JsonRpcResponse {
116            jsonrpc: "2.0".to_string(),
117            id: request.id,
118            result: None,
119            error: Some(JsonRpcError {
120                code: -32600,
121                message: format!("Invalid JSON-RPC version: {}", request.jsonrpc),
122            }),
123        };
124    }
125
126    let result = match request.method.as_str() {
127        "initialize" => handle_initialize(request.params),
128        "ping" => Ok(json!({})), // Ping response for connection keep-alive
129        "tools/list" => handle_tools_list(),
130        "tools/call" => handle_tool_call(request.params).await,
131        _ => Err(format!("Method not found: {}", request.method)),
132    };
133
134    match result {
135        Ok(value) => JsonRpcResponse {
136            jsonrpc: "2.0".to_string(),
137            id: request.id,
138            result: Some(value),
139            error: None,
140        },
141        Err(message) => JsonRpcResponse {
142            jsonrpc: "2.0".to_string(),
143            id: request.id,
144            result: None,
145            error: Some(JsonRpcError {
146                code: -32000,
147                message,
148            }),
149        },
150    }
151}
152
153fn handle_initialize(_params: Option<Value>) -> Result<Value, String> {
154    // MCP initialize handshake
155    // Return server capabilities and info per MCP specification
156    Ok(json!({
157        "protocolVersion": "2024-11-05",
158        "capabilities": {
159            "tools": {
160                "listChanged": false  // Static tool list, no dynamic changes
161            }
162        },
163        "serverInfo": {
164            "name": "intent-engine",
165            "version": env!("CARGO_PKG_VERSION")
166        }
167    }))
168}
169
170fn handle_tools_list() -> Result<Value, String> {
171    let config: Value = serde_json::from_str(MCP_TOOLS)
172        .map_err(|e| format!("Failed to parse MCP tools schema: {}", e))?;
173
174    Ok(json!({
175        "tools": config.get("tools").unwrap_or(&json!([]))
176    }))
177}
178
179async fn handle_tool_call(params: Option<Value>) -> Result<Value, String> {
180    let params: ToolCallParams = serde_json::from_value(params.unwrap_or(json!({})))
181        .map_err(|e| format!("Invalid tool call parameters: {}", e))?;
182
183    let result = match params.name.as_str() {
184        "task_add" => handle_task_add(params.arguments).await,
185        "task_add_dependency" => handle_task_add_dependency(params.arguments).await,
186        "task_start" => handle_task_start(params.arguments).await,
187        "task_pick_next" => handle_task_pick_next(params.arguments).await,
188        "task_spawn_subtask" => handle_task_spawn_subtask(params.arguments).await,
189        "task_switch" => handle_task_switch(params.arguments).await,
190        "task_done" => handle_task_done(params.arguments).await,
191        "task_update" => handle_task_update(params.arguments).await,
192        "task_list" => handle_task_list(params.arguments).await,
193        "task_get" => handle_task_get(params.arguments).await,
194        "task_context" => handle_task_context(params.arguments).await,
195        "task_delete" => handle_task_delete(params.arguments).await,
196        "event_add" => handle_event_add(params.arguments).await,
197        "event_list" => handle_event_list(params.arguments).await,
198        "unified_search" => handle_unified_search(params.arguments).await,
199        "current_task_get" => handle_current_task_get(params.arguments).await,
200        "report_generate" => handle_report_generate(params.arguments).await,
201        _ => Err(format!("Unknown tool: {}", params.name)),
202    }?;
203
204    Ok(json!({
205        "content": [{
206            "type": "text",
207            "text": serde_json::to_string_pretty(&result)
208                .unwrap_or_else(|_| "{}".to_string())
209        }]
210    }))
211}
212
213// Tool Handlers
214
215async fn handle_task_add(args: Value) -> Result<Value, String> {
216    // Improved parameter validation with specific error messages
217    let name = match args.get("name") {
218        None => return Err("Missing required parameter: name".to_string()),
219        Some(value) => {
220            if value.is_null() {
221                return Err("Parameter 'name' cannot be null".to_string());
222            }
223            match value.as_str() {
224                Some(s) if s.trim().is_empty() => {
225                    return Err("Parameter 'name' cannot be empty".to_string());
226                },
227                Some(s) => s,
228                None => return Err(format!("Parameter 'name' must be a string, got: {}", value)),
229            }
230        },
231    };
232
233    let spec = args.get("spec").and_then(|v| v.as_str());
234    let parent_id = args.get("parent_id").and_then(|v| v.as_i64());
235
236    let ctx = ProjectContext::load_or_init()
237        .await
238        .map_err(|e| format!("Failed to load project context: {}", e))?;
239
240    let task_mgr = TaskManager::new(&ctx.pool);
241    let task = task_mgr
242        .add_task(name, spec, parent_id)
243        .await
244        .map_err(|e| format!("Failed to add task: {}", e))?;
245
246    serde_json::to_value(&task).map_err(|e| format!("Serialization error: {}", e))
247}
248
249async fn handle_task_add_dependency(args: Value) -> Result<Value, String> {
250    let blocked_task_id = args
251        .get("blocked_task_id")
252        .and_then(|v| v.as_i64())
253        .ok_or("Missing required parameter: blocked_task_id")?;
254
255    let blocking_task_id = args
256        .get("blocking_task_id")
257        .and_then(|v| v.as_i64())
258        .ok_or("Missing required parameter: blocking_task_id")?;
259
260    let ctx = ProjectContext::load_or_init()
261        .await
262        .map_err(|e| format!("Failed to load project context: {}", e))?;
263
264    let dependency =
265        crate::dependencies::add_dependency(&ctx.pool, blocking_task_id, blocked_task_id)
266            .await
267            .map_err(|e| format!("Failed to add dependency: {}", e))?;
268
269    serde_json::to_value(&dependency).map_err(|e| format!("Serialization error: {}", e))
270}
271
272async fn handle_task_start(args: Value) -> Result<Value, String> {
273    let task_id = args
274        .get("task_id")
275        .and_then(|v| v.as_i64())
276        .ok_or("Missing required parameter: task_id")?;
277
278    let with_events = args
279        .get("with_events")
280        .and_then(|v| v.as_bool())
281        .unwrap_or(true);
282
283    let ctx = ProjectContext::load_or_init()
284        .await
285        .map_err(|e| format!("Failed to load project context: {}", e))?;
286
287    let task_mgr = TaskManager::new(&ctx.pool);
288    let task = task_mgr
289        .start_task(task_id, with_events)
290        .await
291        .map_err(|e| format!("Failed to start task: {}", e))?;
292
293    serde_json::to_value(&task).map_err(|e| format!("Serialization error: {}", e))
294}
295
296async fn handle_task_pick_next(args: Value) -> Result<Value, String> {
297    let _max_count = args.get("max_count").and_then(|v| v.as_i64());
298    let _capacity = args.get("capacity").and_then(|v| v.as_i64());
299
300    let ctx = ProjectContext::load_or_init()
301        .await
302        .map_err(|e| format!("Failed to load project context: {}", e))?;
303
304    let task_mgr = TaskManager::new(&ctx.pool);
305    let response = task_mgr
306        .pick_next()
307        .await
308        .map_err(|e| format!("Failed to pick next task: {}", e))?;
309
310    serde_json::to_value(&response).map_err(|e| format!("Serialization error: {}", e))
311}
312
313async fn handle_task_spawn_subtask(args: Value) -> Result<Value, String> {
314    let name = args
315        .get("name")
316        .and_then(|v| v.as_str())
317        .ok_or("Missing required parameter: name")?;
318
319    let spec = args.get("spec").and_then(|v| v.as_str());
320
321    let ctx = ProjectContext::load_or_init()
322        .await
323        .map_err(|e| format!("Failed to load project context: {}", e))?;
324
325    let task_mgr = TaskManager::new(&ctx.pool);
326    let subtask = task_mgr
327        .spawn_subtask(name, spec)
328        .await
329        .map_err(|e| format!("Failed to spawn subtask: {}", e))?;
330
331    serde_json::to_value(&subtask).map_err(|e| format!("Serialization error: {}", e))
332}
333
334async fn handle_task_switch(args: Value) -> Result<Value, String> {
335    let task_id = args
336        .get("task_id")
337        .and_then(|v| v.as_i64())
338        .ok_or("Missing required parameter: task_id")?;
339
340    let ctx = ProjectContext::load_or_init()
341        .await
342        .map_err(|e| format!("Failed to load project context: {}", e))?;
343
344    let task_mgr = TaskManager::new(&ctx.pool);
345    let task = task_mgr
346        .switch_to_task(task_id)
347        .await
348        .map_err(|e| format!("Failed to switch task: {}", e))?;
349
350    serde_json::to_value(&task).map_err(|e| format!("Serialization error: {}", e))
351}
352
353async fn handle_task_done(args: Value) -> Result<Value, String> {
354    let task_id = args.get("task_id").and_then(|v| v.as_i64());
355
356    let ctx = ProjectContext::load_or_init()
357        .await
358        .map_err(|e| format!("Failed to load project context: {}", e))?;
359
360    let task_mgr = TaskManager::new(&ctx.pool);
361
362    // If task_id is provided, set it as current first
363    if let Some(id) = task_id {
364        let workspace_mgr = WorkspaceManager::new(&ctx.pool);
365        workspace_mgr
366            .set_current_task(id)
367            .await
368            .map_err(|e| format!("Failed to set current task: {}", e))?;
369    }
370
371    let task = task_mgr
372        .done_task()
373        .await
374        .map_err(|e| format!("Failed to mark task as done: {}", e))?;
375
376    serde_json::to_value(&task).map_err(|e| format!("Serialization error: {}", e))
377}
378
379async fn handle_task_update(args: Value) -> Result<Value, String> {
380    let task_id = args
381        .get("task_id")
382        .and_then(|v| v.as_i64())
383        .ok_or("Missing required parameter: task_id")?;
384
385    let name = args.get("name").and_then(|v| v.as_str());
386    let spec = args.get("spec").and_then(|v| v.as_str());
387    let status = args.get("status").and_then(|v| v.as_str());
388    let complexity = args
389        .get("complexity")
390        .and_then(|v| v.as_i64())
391        .map(|v| v as i32);
392    let priority = match args.get("priority").and_then(|v| v.as_str()) {
393        Some(p) => Some(
394            crate::priority::PriorityLevel::parse_to_int(p)
395                .map_err(|e| format!("Invalid priority: {}", e))?,
396        ),
397        None => None,
398    };
399    let parent_id = args.get("parent_id").and_then(|v| v.as_i64()).map(Some);
400
401    let ctx = ProjectContext::load_or_init()
402        .await
403        .map_err(|e| format!("Failed to load project context: {}", e))?;
404
405    let task_mgr = TaskManager::new(&ctx.pool);
406    let task = task_mgr
407        .update_task(task_id, name, spec, parent_id, status, complexity, priority)
408        .await
409        .map_err(|e| format!("Failed to update task: {}", e))?;
410
411    serde_json::to_value(&task).map_err(|e| format!("Serialization error: {}", e))
412}
413
414async fn handle_task_list(args: Value) -> Result<Value, String> {
415    let status = args.get("status").and_then(|v| v.as_str());
416    let parent = args.get("parent").and_then(|v| v.as_str());
417
418    let parent_opt = parent.map(|p| {
419        if p == "null" {
420            None
421        } else {
422            p.parse::<i64>().ok()
423        }
424    });
425
426    let ctx = ProjectContext::load()
427        .await
428        .map_err(|e| format!("Failed to load project context: {}", e))?;
429
430    let task_mgr = TaskManager::new(&ctx.pool);
431    let tasks = task_mgr
432        .find_tasks(status, parent_opt)
433        .await
434        .map_err(|e| format!("Failed to list tasks: {}", e))?;
435
436    serde_json::to_value(&tasks).map_err(|e| format!("Serialization error: {}", e))
437}
438
439async fn handle_task_get(args: Value) -> Result<Value, String> {
440    let task_id = args
441        .get("task_id")
442        .and_then(|v| v.as_i64())
443        .ok_or("Missing required parameter: task_id")?;
444
445    let with_events = args
446        .get("with_events")
447        .and_then(|v| v.as_bool())
448        .unwrap_or(false);
449
450    let ctx = ProjectContext::load()
451        .await
452        .map_err(|e| format!("Failed to load project context: {}", e))?;
453
454    let task_mgr = TaskManager::new(&ctx.pool);
455
456    if with_events {
457        let task = task_mgr
458            .get_task_with_events(task_id)
459            .await
460            .map_err(|e| format!("Failed to get task: {}", e))?;
461        serde_json::to_value(&task).map_err(|e| format!("Serialization error: {}", e))
462    } else {
463        let task = task_mgr
464            .get_task(task_id)
465            .await
466            .map_err(|e| format!("Failed to get task: {}", e))?;
467        serde_json::to_value(&task).map_err(|e| format!("Serialization error: {}", e))
468    }
469}
470
471async fn handle_task_context(args: Value) -> Result<Value, String> {
472    // Get task_id from args, or fall back to current task
473    let task_id = if let Some(id) = args.get("task_id").and_then(|v| v.as_i64()) {
474        id
475    } else {
476        // Fall back to current_task_id if no task_id provided
477        let ctx = ProjectContext::load()
478            .await
479            .map_err(|e| format!("Failed to load project context: {}", e))?;
480
481        let current_task_id: Option<String> =
482            sqlx::query_scalar("SELECT value FROM workspace_state WHERE key = 'current_task_id'")
483                .fetch_optional(&ctx.pool)
484                .await
485                .map_err(|e| format!("Database error: {}", e))?;
486
487        current_task_id
488            .and_then(|s| s.parse::<i64>().ok())
489            .ok_or_else(|| {
490                "No current task is set and task_id was not provided. \
491                 Use task_start or task_switch to set a task first, or provide task_id parameter."
492                    .to_string()
493            })?
494    };
495
496    let ctx = ProjectContext::load()
497        .await
498        .map_err(|e| format!("Failed to load project context: {}", e))?;
499
500    let task_mgr = TaskManager::new(&ctx.pool);
501    let context = task_mgr
502        .get_task_context(task_id)
503        .await
504        .map_err(|e| format!("Failed to get task context: {}", e))?;
505
506    serde_json::to_value(&context).map_err(|e| format!("Serialization error: {}", e))
507}
508
509async fn handle_task_delete(args: Value) -> Result<Value, String> {
510    let task_id = args
511        .get("task_id")
512        .and_then(|v| v.as_i64())
513        .ok_or("Missing required parameter: task_id")?;
514
515    let ctx = ProjectContext::load()
516        .await
517        .map_err(|e| format!("Failed to load project context: {}", e))?;
518
519    let task_mgr = TaskManager::new(&ctx.pool);
520    task_mgr
521        .delete_task(task_id)
522        .await
523        .map_err(|e| format!("Failed to delete task: {}", e))?;
524
525    Ok(json!({"success": true, "deleted_task_id": task_id}))
526}
527
528async fn handle_event_add(args: Value) -> Result<Value, String> {
529    let task_id = args.get("task_id").and_then(|v| v.as_i64());
530
531    let event_type = args
532        .get("event_type")
533        .and_then(|v| v.as_str())
534        .ok_or("Missing required parameter: event_type")?;
535
536    let data = args
537        .get("data")
538        .and_then(|v| v.as_str())
539        .ok_or("Missing required parameter: data")?;
540
541    let ctx = ProjectContext::load_or_init()
542        .await
543        .map_err(|e| format!("Failed to load project context: {}", e))?;
544
545    // Determine the target task ID
546    let target_task_id = if let Some(id) = task_id {
547        id
548    } else {
549        // Fall back to current_task_id
550        let current_task_id: Option<String> =
551            sqlx::query_scalar("SELECT value FROM workspace_state WHERE key = 'current_task_id'")
552                .fetch_optional(&ctx.pool)
553                .await
554                .map_err(|e| format!("Database error: {}", e))?;
555
556        current_task_id
557            .and_then(|s| s.parse::<i64>().ok())
558            .ok_or_else(|| {
559                "No current task is set and task_id was not provided. \
560                 Use task_start or task_switch to set a task first."
561                    .to_string()
562            })?
563    };
564
565    let event_mgr = EventManager::new(&ctx.pool);
566    let event = event_mgr
567        .add_event(target_task_id, event_type, data)
568        .await
569        .map_err(|e| format!("Failed to add event: {}", e))?;
570
571    serde_json::to_value(&event).map_err(|e| format!("Serialization error: {}", e))
572}
573
574async fn handle_event_list(args: Value) -> Result<Value, String> {
575    let task_id = args.get("task_id").and_then(|v| v.as_i64());
576
577    let limit = args.get("limit").and_then(|v| v.as_i64());
578    let log_type = args
579        .get("type")
580        .and_then(|v| v.as_str())
581        .map(|s| s.to_string());
582    let since = args
583        .get("since")
584        .and_then(|v| v.as_str())
585        .map(|s| s.to_string());
586
587    let ctx = ProjectContext::load()
588        .await
589        .map_err(|e| format!("Failed to load project context: {}", e))?;
590
591    let event_mgr = EventManager::new(&ctx.pool);
592    let events = event_mgr
593        .list_events(task_id, limit, log_type, since)
594        .await
595        .map_err(|e| format!("Failed to list events: {}", e))?;
596
597    serde_json::to_value(&events).map_err(|e| format!("Serialization error: {}", e))
598}
599
600async fn handle_unified_search(args: Value) -> Result<Value, String> {
601    use crate::search::SearchManager;
602
603    let query = args
604        .get("query")
605        .and_then(|v| v.as_str())
606        .ok_or("Missing required parameter: query")?;
607
608    let include_tasks = args
609        .get("include_tasks")
610        .and_then(|v| v.as_bool())
611        .unwrap_or(true);
612
613    let include_events = args
614        .get("include_events")
615        .and_then(|v| v.as_bool())
616        .unwrap_or(true);
617
618    let limit = args.get("limit").and_then(|v| v.as_i64());
619
620    let ctx = ProjectContext::load()
621        .await
622        .map_err(|e| format!("Failed to load project context: {}", e))?;
623
624    let search_mgr = SearchManager::new(&ctx.pool);
625    let results = search_mgr
626        .unified_search(query, include_tasks, include_events, limit)
627        .await
628        .map_err(|e| format!("Failed to perform unified search: {}", e))?;
629
630    serde_json::to_value(&results).map_err(|e| format!("Serialization error: {}", e))
631}
632
633async fn handle_current_task_get(_args: Value) -> Result<Value, String> {
634    let ctx = ProjectContext::load()
635        .await
636        .map_err(|e| format!("Failed to load project context: {}", e))?;
637
638    let workspace_mgr = WorkspaceManager::new(&ctx.pool);
639    let response = workspace_mgr
640        .get_current_task()
641        .await
642        .map_err(|e| format!("Failed to get current task: {}", e))?;
643
644    serde_json::to_value(&response).map_err(|e| format!("Serialization error: {}", e))
645}
646
647async fn handle_report_generate(args: Value) -> Result<Value, String> {
648    let since = args.get("since").and_then(|v| v.as_str()).map(String::from);
649    let status = args
650        .get("status")
651        .and_then(|v| v.as_str())
652        .map(String::from);
653    let filter_name = args
654        .get("filter_name")
655        .and_then(|v| v.as_str())
656        .map(String::from);
657    let filter_spec = args
658        .get("filter_spec")
659        .and_then(|v| v.as_str())
660        .map(String::from);
661    let summary_only = args
662        .get("summary_only")
663        .and_then(|v| v.as_bool())
664        .unwrap_or(true);
665
666    let ctx = ProjectContext::load()
667        .await
668        .map_err(|e| format!("Failed to load project context: {}", e))?;
669
670    let report_mgr = ReportManager::new(&ctx.pool);
671    let report = report_mgr
672        .generate_report(since, status, filter_name, filter_spec, summary_only)
673        .await
674        .map_err(|e| format!("Failed to generate report: {}", e))?;
675
676    serde_json::to_value(&report).map_err(|e| format!("Serialization error: {}", e))
677}
678
// Unit tests are kept in the sibling file `server_tests.rs` and are only
// compiled for test builds.
#[cfg(test)]
#[path = "server_tests.rs"]
mod tests;