Skip to main content

code_baseline/
mcp.rs

1use crate::cli::toml_config::TomlConfig;
2use crate::presets;
3use crate::scan;
4use serde_json::json;
5use std::io::{self, BufRead, Write};
6use std::path::{Path, PathBuf};
7
8/// Run a simple MCP-compatible server over stdio.
9///
10/// Reads JSON-RPC requests from stdin, processes them, and writes
11/// JSON-RPC responses to stdout. Supports the MCP protocol for
12/// tool discovery and execution.
13pub fn run_mcp_server(config_path: &Path) {
14    let stdin = io::stdin();
15    let mut stdout = io::stdout();
16
17    // Read line-delimited JSON-RPC messages
18    for line in stdin.lock().lines() {
19        let line = match line {
20            Ok(l) => l,
21            Err(_) => break,
22        };
23
24        if line.trim().is_empty() {
25            continue;
26        }
27
28        let request: serde_json::Value = match serde_json::from_str(&line) {
29            Ok(v) => v,
30            Err(e) => {
31                let error_response = json!({
32                    "jsonrpc": "2.0",
33                    "id": null,
34                    "error": { "code": -32700, "message": format!("Parse error: {}", e) }
35                });
36                let _ = writeln!(stdout, "{}", error_response);
37                let _ = stdout.flush();
38                continue;
39            }
40        };
41
42        let id = request.get("id").cloned();
43        let method = request.get("method").and_then(|m| m.as_str()).unwrap_or("");
44        let params = request.get("params").cloned().unwrap_or(json!({}));
45
46        let response = match method {
47            "initialize" => handle_initialize(id.clone()),
48            "tools/list" => handle_tools_list(id.clone()),
49            "tools/call" => handle_tools_call(id.clone(), &params, config_path),
50            "notifications/initialized" | "notifications/cancelled" => continue,
51            _ => json!({
52                "jsonrpc": "2.0",
53                "id": id,
54                "error": { "code": -32601, "message": format!("Unknown method: {}", method) }
55            }),
56        };
57
58        let _ = writeln!(stdout, "{}", response);
59        let _ = stdout.flush();
60    }
61}
62
63fn handle_initialize(id: Option<serde_json::Value>) -> serde_json::Value {
64    json!({
65        "jsonrpc": "2.0",
66        "id": id,
67        "result": {
68            "protocolVersion": "2024-11-05",
69            "capabilities": {
70                "tools": {}
71            },
72            "serverInfo": {
73                "name": "baseline",
74                "version": env!("CARGO_PKG_VERSION")
75            }
76        }
77    })
78}
79
80fn handle_tools_list(id: Option<serde_json::Value>) -> serde_json::Value {
81    json!({
82        "jsonrpc": "2.0",
83        "id": id,
84        "result": {
85            "tools": [
86                {
87                    "name": "baseline_scan",
88                    "description": "Scan files for rule violations. Returns structured violations with fix suggestions.",
89                    "inputSchema": {
90                        "type": "object",
91                        "properties": {
92                            "paths": {
93                                "type": "array",
94                                "items": { "type": "string" },
95                                "description": "File or directory paths to scan"
96                            },
97                            "content": {
98                                "type": "string",
99                                "description": "Inline file content to scan (alternative to paths)"
100                            },
101                            "filename": {
102                                "type": "string",
103                                "description": "Virtual filename for glob matching when using content"
104                            }
105                        }
106                    }
107                },
108                {
109                    "name": "baseline_list_rules",
110                    "description": "List all configured rules and their descriptions.",
111                    "inputSchema": {
112                        "type": "object",
113                        "properties": {}
114                    }
115                }
116            ]
117        }
118    })
119}
120
121fn handle_tools_call(
122    id: Option<serde_json::Value>,
123    params: &serde_json::Value,
124    config_path: &Path,
125) -> serde_json::Value {
126    let tool_name = params
127        .get("name")
128        .and_then(|n| n.as_str())
129        .unwrap_or("");
130
131    let arguments = params.get("arguments").cloned().unwrap_or(json!({}));
132
133    match tool_name {
134        "baseline_scan" => handle_scan(&id, &arguments, config_path),
135        "baseline_list_rules" => handle_list_rules(&id, config_path),
136        _ => json!({
137            "jsonrpc": "2.0",
138            "id": id,
139            "error": { "code": -32602, "message": format!("Unknown tool: {}", tool_name) }
140        }),
141    }
142}
143
144fn handle_scan(
145    id: &Option<serde_json::Value>,
146    arguments: &serde_json::Value,
147    config_path: &Path,
148) -> serde_json::Value {
149    // Check for inline content mode
150    if let Some(content) = arguments.get("content").and_then(|c| c.as_str()) {
151        let filename = arguments
152            .get("filename")
153            .and_then(|f| f.as_str())
154            .unwrap_or("stdin.tsx");
155
156        match scan::run_scan_stdin(config_path, content, filename) {
157            Ok(result) => {
158                let violations = format_violations_json(&result);
159                json!({
160                    "jsonrpc": "2.0",
161                    "id": id,
162                    "result": {
163                        "content": [{ "type": "text", "text": violations.to_string() }]
164                    }
165                })
166            }
167            Err(e) => json!({
168                "jsonrpc": "2.0",
169                "id": id,
170                "result": {
171                    "content": [{ "type": "text", "text": format!("Error: {}", e) }],
172                    "isError": true
173                }
174            }),
175        }
176    } else {
177        // File paths mode
178        let paths: Vec<PathBuf> = arguments
179            .get("paths")
180            .and_then(|p| p.as_array())
181            .map(|arr| {
182                arr.iter()
183                    .filter_map(|v| v.as_str().map(PathBuf::from))
184                    .collect()
185            })
186            .unwrap_or_else(|| vec![PathBuf::from(".")]);
187
188        match scan::run_scan(config_path, &paths) {
189            Ok(result) => {
190                let violations = format_violations_json(&result);
191                json!({
192                    "jsonrpc": "2.0",
193                    "id": id,
194                    "result": {
195                        "content": [{ "type": "text", "text": violations.to_string() }]
196                    }
197                })
198            }
199            Err(e) => json!({
200                "jsonrpc": "2.0",
201                "id": id,
202                "result": {
203                    "content": [{ "type": "text", "text": format!("Error: {}", e) }],
204                    "isError": true
205                }
206            }),
207        }
208    }
209}
210
211fn handle_list_rules(
212    id: &Option<serde_json::Value>,
213    config_path: &Path,
214) -> serde_json::Value {
215    let config_text = match std::fs::read_to_string(config_path) {
216        Ok(c) => c,
217        Err(e) => {
218            return json!({
219                "jsonrpc": "2.0",
220                "id": id,
221                "result": {
222                    "content": [{ "type": "text", "text": format!("Error reading config: {}", e) }],
223                    "isError": true
224                }
225            });
226        }
227    };
228
229    let toml_config: TomlConfig = match toml::from_str(&config_text) {
230        Ok(c) => c,
231        Err(e) => {
232            return json!({
233                "jsonrpc": "2.0",
234                "id": id,
235                "result": {
236                    "content": [{ "type": "text", "text": format!("Error parsing config: {}", e) }],
237                    "isError": true
238                }
239            });
240        }
241    };
242
243    let resolved = match presets::resolve_rules(&toml_config.baseline.extends, &toml_config.rule) {
244        Ok(r) => r,
245        Err(e) => {
246            return json!({
247                "jsonrpc": "2.0",
248                "id": id,
249                "result": {
250                    "content": [{ "type": "text", "text": format!("Error resolving rules: {}", e) }],
251                    "isError": true
252                }
253            });
254        }
255    };
256
257    let rules: Vec<serde_json::Value> = resolved
258        .iter()
259        .map(|r| {
260            json!({
261                "id": r.id,
262                "type": r.rule_type,
263                "severity": r.severity,
264                "glob": r.glob,
265                "message": r.message,
266            })
267        })
268        .collect();
269
270    let text = serde_json::to_string_pretty(&json!({ "rules": rules })).unwrap();
271
272    json!({
273        "jsonrpc": "2.0",
274        "id": id,
275        "result": {
276            "content": [{ "type": "text", "text": text }]
277        }
278    })
279}
280
281fn format_violations_json(result: &scan::ScanResult) -> serde_json::Value {
282    use crate::config::Severity;
283
284    let violations: Vec<serde_json::Value> = result
285        .violations
286        .iter()
287        .map(|v| {
288            let mut obj = json!({
289                "rule_id": v.rule_id,
290                "severity": match v.severity {
291                    Severity::Error => "error",
292                    Severity::Warning => "warning",
293                },
294                "file": v.file.display().to_string(),
295                "line": v.line,
296                "column": v.column,
297                "message": v.message,
298                "suggest": v.suggest,
299            });
300
301            if let Some(ref fix) = v.fix {
302                obj["fix"] = json!({ "old": fix.old, "new": fix.new });
303            }
304
305            obj
306        })
307        .collect();
308
309    json!({
310        "violations": violations,
311        "summary": {
312            "total": result.violations.len(),
313            "errors": result.violations.iter().filter(|v| v.severity == Severity::Error).count(),
314            "warnings": result.violations.iter().filter(|v| v.severity == Severity::Warning).count(),
315            "files_scanned": result.files_scanned,
316        }
317    })
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Severity;
    use crate::rules::Violation;
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// A `ScanResult` holding `violations`, with zeroed counters and empty
    /// ratchet/diff metadata. Tests override fields via struct-update syntax.
    fn scan_result(violations: Vec<Violation>) -> scan::ScanResult {
        scan::ScanResult {
            violations,
            files_scanned: 0,
            rules_loaded: 0,
            ratchet_counts: HashMap::new(),
            changed_files_count: None,
            base_ref: None,
        }
    }

    /// A violation with no position, suggestion, source line or fix.
    fn violation(
        rule_id: &'static str,
        severity: Severity,
        file: &'static str,
        message: &'static str,
    ) -> Violation {
        Violation {
            rule_id: rule_id.into(),
            severity,
            file: PathBuf::from(file),
            line: None,
            column: None,
            message: message.into(),
            suggest: None,
            source_line: None,
            fix: None,
        }
    }

    #[test]
    fn initialize_returns_protocol_version() {
        let resp = handle_initialize(Some(json!(1)));
        assert_eq!(resp["jsonrpc"], "2.0");
        assert_eq!(resp["id"], 1);
        assert_eq!(resp["result"]["protocolVersion"], "2024-11-05");
        assert_eq!(resp["result"]["serverInfo"]["name"], "baseline");
        assert!(resp["result"]["capabilities"]["tools"].is_object());
    }

    #[test]
    fn tools_list_returns_both_tools() {
        let resp = handle_tools_list(Some(json!(2)));
        assert_eq!(resp["jsonrpc"], "2.0");
        assert_eq!(resp["id"], 2);
        let tools = resp["result"]["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0]["name"], "baseline_scan");
        assert_eq!(tools[0]["inputSchema"]["properties"]["paths"]["type"], "array");
        assert_eq!(tools[1]["name"], "baseline_list_rules");
        assert_eq!(tools[1]["inputSchema"]["type"], "object");
    }

    #[test]
    fn format_violations_empty() {
        let result = scan::ScanResult {
            files_scanned: 3,
            rules_loaded: 2,
            ..scan_result(vec![])
        };
        let json = format_violations_json(&result);
        assert_eq!(json["summary"]["total"], 0);
        assert_eq!(json["summary"]["errors"], 0);
        assert_eq!(json["summary"]["warnings"], 0);
        assert_eq!(json["summary"]["files_scanned"], 3);
        assert!(json["violations"].as_array().unwrap().is_empty());
    }

    #[test]
    fn format_violations_with_fix() {
        let flagged = Violation {
            line: Some(5),
            column: Some(10),
            suggest: Some("use good class".into()),
            fix: Some(crate::rules::Fix {
                old: "bg-red-500".into(),
                new: "bg-destructive".into(),
            }),
            ..violation("test-rule", Severity::Error, "test.tsx", "bad class")
        };
        let result = scan::ScanResult {
            files_scanned: 1,
            rules_loaded: 1,
            ..scan_result(vec![flagged])
        };
        let json = format_violations_json(&result);
        assert_eq!(json["summary"]["total"], 1);
        assert_eq!(json["summary"]["errors"], 1);
        assert_eq!(json["summary"]["warnings"], 0);
        let v = &json["violations"][0];
        assert_eq!(v["rule_id"], "test-rule");
        assert_eq!(v["severity"], "error");
        assert_eq!(v["file"], "test.tsx");
        assert_eq!(v["line"], 5);
        assert_eq!(v["column"], 10);
        assert_eq!(v["message"], "bad class");
        assert_eq!(v["suggest"], "use good class");
        assert_eq!(v["fix"]["old"], "bg-red-500");
        assert_eq!(v["fix"]["new"], "bg-destructive");
    }

    #[test]
    fn format_violations_counts_severities() {
        let result = scan::ScanResult {
            files_scanned: 2,
            rules_loaded: 2,
            ..scan_result(vec![
                Violation {
                    line: Some(1),
                    ..violation("r1", Severity::Error, "a.ts", "err")
                },
                Violation {
                    line: Some(2),
                    ..violation("r2", Severity::Warning, "b.ts", "warn")
                },
            ])
        };
        let json = format_violations_json(&result);
        assert_eq!(json["summary"]["errors"], 1);
        assert_eq!(json["summary"]["warnings"], 1);
        assert_eq!(json["summary"]["total"], 2);
        assert_eq!(json["violations"][0]["severity"], "error");
        assert_eq!(json["violations"][1]["severity"], "warning");
    }

    #[test]
    fn format_violations_omits_fix_when_absent() {
        let result = scan_result(vec![violation("r1", Severity::Warning, "a.ts", "warn")]);
        let json = format_violations_json(&result);
        let v = &json["violations"][0];
        assert!(v.get("fix").is_none());
        assert!(v["suggest"].is_null());
        assert!(v["line"].is_null());
        assert!(v["column"].is_null());
    }

    #[test]
    fn unknown_tool_returns_error() {
        let resp = handle_tools_call(
            Some(json!(3)),
            &json!({ "name": "nonexistent_tool", "arguments": {} }),
            std::path::Path::new("baseline.toml"),
        );
        assert!(resp["error"].is_object());
        assert_eq!(resp["error"]["code"], -32602);
        assert_eq!(resp["error"]["message"], "Unknown tool: nonexistent_tool");
    }

    #[test]
    fn missing_tool_name_returns_error() {
        let resp = handle_tools_call(
            Some(json!(4)),
            &json!({}),
            std::path::Path::new("baseline.toml"),
        );
        assert_eq!(resp["id"], 4);
        assert_eq!(resp["error"]["code"], -32602);
    }
}