Skip to main content

purple_ssh/
mcp.rs

1use std::io::{BufRead, Write};
2use std::path::Path;
3
4use log::{debug, error, info, warn};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8use crate::ssh_config::model::{SshConfigFile, is_host_pattern};
9
10/// A JSON-RPC 2.0 request.
11#[derive(Debug, Deserialize)]
12pub struct JsonRpcRequest {
13    #[allow(dead_code)]
14    pub jsonrpc: String,
15    #[serde(default)]
16    pub id: Option<Value>,
17    pub method: String,
18    #[serde(default)]
19    pub params: Option<Value>,
20}
21
22/// A JSON-RPC 2.0 response.
23#[derive(Debug, Serialize)]
24pub struct JsonRpcResponse {
25    pub jsonrpc: String,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub id: Option<Value>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub result: Option<Value>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub error: Option<JsonRpcError>,
32}
33
34/// A JSON-RPC 2.0 error object.
35#[derive(Debug, Serialize)]
36pub struct JsonRpcError {
37    pub code: i64,
38    pub message: String,
39}
40
41impl JsonRpcResponse {
42    fn success(id: Option<Value>, result: Value) -> Self {
43        Self {
44            jsonrpc: "2.0".to_string(),
45            id,
46            result: Some(result),
47            error: None,
48        }
49    }
50
51    fn error(id: Option<Value>, code: i64, message: String) -> Self {
52        Self {
53            jsonrpc: "2.0".to_string(),
54            id,
55            result: None,
56            error: Some(JsonRpcError { code, message }),
57        }
58    }
59}
60
61/// Helper to build an MCP tool result (success).
62fn mcp_tool_result(text: &str) -> Value {
63    serde_json::json!({
64        "content": [{"type": "text", "text": text}]
65    })
66}
67
68/// Helper to build an MCP tool error result.
69fn mcp_tool_error(text: &str) -> Value {
70    serde_json::json!({
71        "content": [{"type": "text", "text": text}],
72        "isError": true
73    })
74}
75
76/// Verify that an alias exists in the SSH config. Returns error Value if not found.
77fn verify_alias_exists(alias: &str, config_path: &Path) -> Result<(), Value> {
78    let config = match SshConfigFile::parse(config_path) {
79        Ok(c) => c,
80        Err(e) => return Err(mcp_tool_error(&format!("Failed to parse SSH config: {e}"))),
81    };
82    let exists = config.host_entries().iter().any(|h| h.alias == alias);
83    if !exists {
84        return Err(mcp_tool_error(&format!("Host not found: {alias}")));
85    }
86    Ok(())
87}
88
89/// Run an SSH command with a timeout. Returns (exit_code, stdout, stderr).
90fn ssh_exec(
91    alias: &str,
92    config_path: &Path,
93    command: &str,
94    timeout_secs: u64,
95) -> Result<(i32, String, String), Value> {
96    let config_str = config_path.to_string_lossy();
97    let mut child = match std::process::Command::new("ssh")
98        .args([
99            "-F",
100            &config_str,
101            "-o",
102            "ConnectTimeout=10",
103            "-o",
104            "BatchMode=yes",
105            "--",
106            alias,
107            command,
108        ])
109        .stdin(std::process::Stdio::null())
110        .stdout(std::process::Stdio::piped())
111        .stderr(std::process::Stdio::piped())
112        .spawn()
113    {
114        Ok(c) => c,
115        Err(e) => return Err(mcp_tool_error(&format!("Failed to spawn ssh: {e}"))),
116    };
117
118    let timeout = std::time::Duration::from_secs(timeout_secs);
119    let start = std::time::Instant::now();
120    loop {
121        match child.try_wait() {
122            Ok(Some(status)) => {
123                let stdout = child
124                    .stdout
125                    .take()
126                    .map(|mut s| {
127                        let mut buf = String::new();
128                        if let Err(e) = std::io::Read::read_to_string(&mut s, &mut buf) {
129                            warn!("[external] Failed to read SSH stdout pipe: {e}");
130                        }
131                        buf
132                    })
133                    .unwrap_or_default();
134                let stderr = child
135                    .stderr
136                    .take()
137                    .map(|mut s| {
138                        let mut buf = String::new();
139                        if let Err(e) = std::io::Read::read_to_string(&mut s, &mut buf) {
140                            warn!("[external] Failed to read SSH stderr pipe: {e}");
141                        }
142                        buf
143                    })
144                    .unwrap_or_default();
145                return Ok((status.code().unwrap_or(-1), stdout, stderr));
146            }
147            Ok(None) => {
148                if start.elapsed() > timeout {
149                    if let Err(e) = child.kill() {
150                        warn!("[external] Failed to kill timed-out SSH process: {e}");
151                    }
152                    let _ = child.wait();
153                    warn!("[external] MCP SSH command timed out after {timeout_secs}s");
154                    return Err(mcp_tool_error(&format!(
155                        "SSH command timed out after {timeout_secs} seconds"
156                    )));
157                }
158                std::thread::sleep(std::time::Duration::from_millis(50));
159            }
160            Err(e) => return Err(mcp_tool_error(&format!("Failed to wait for ssh: {e}"))),
161        }
162    }
163}
164
165/// Dispatch a JSON-RPC method to the appropriate handler.
166pub(crate) fn dispatch(method: &str, params: Option<Value>, config_path: &Path) -> JsonRpcResponse {
167    match method {
168        "initialize" => handle_initialize(),
169        "tools/list" => handle_tools_list(),
170        "tools/call" => handle_tools_call(params, config_path),
171        _ => JsonRpcResponse::error(None, -32601, format!("Method not found: {method}")),
172    }
173}
174
175fn handle_initialize() -> JsonRpcResponse {
176    JsonRpcResponse::success(
177        None,
178        serde_json::json!({
179            "protocolVersion": "2024-11-05",
180            "capabilities": {
181                "tools": {}
182            },
183            "serverInfo": {
184                "name": "purple",
185                "version": env!("CARGO_PKG_VERSION")
186            }
187        }),
188    )
189}
190
191fn handle_tools_list() -> JsonRpcResponse {
192    let tools = serde_json::json!({
193        "tools": [
194            {
195                "name": "list_hosts",
196                "description": "List all SSH hosts available to connect to. Returns alias, hostname, user, port, tags and provider for each host. Use the tag parameter to filter by tag, provider tag or provider name (fuzzy match). Call this first to discover available hosts.",
197                "inputSchema": {
198                    "type": "object",
199                    "properties": {
200                        "tag": {
201                            "type": "string",
202                            "description": "Filter hosts by tag (fuzzy match against tags, provider_tags and provider name)"
203                        }
204                    }
205                }
206            },
207            {
208                "name": "get_host",
209                "description": "Get detailed information for a single SSH host including identity file, proxy jump, provider metadata, password source and tunnel count.",
210                "inputSchema": {
211                    "type": "object",
212                    "properties": {
213                        "alias": {
214                            "type": "string",
215                            "description": "The host alias to look up"
216                        }
217                    },
218                    "required": ["alias"]
219                }
220            },
221            {
222                "name": "run_command",
223                "description": "Run a shell command on a remote host via SSH. Non-interactive (BatchMode). Returns exit code, stdout and stderr. Suitable for diagnostic commands, not interactive programs.",
224                "inputSchema": {
225                    "type": "object",
226                    "properties": {
227                        "alias": {
228                            "type": "string",
229                            "description": "The host alias to connect to"
230                        },
231                        "command": {
232                            "type": "string",
233                            "description": "The command to execute"
234                        },
235                        "timeout": {
236                            "type": "integer",
237                            "description": "Timeout in seconds (default 30)",
238                            "default": 30,
239                            "minimum": 1,
240                            "maximum": 300
241                        }
242                    },
243                    "required": ["alias", "command"]
244                }
245            },
246            {
247                "name": "list_containers",
248                "description": "List all Docker or Podman containers on a remote host via SSH. Auto-detects the container runtime. Returns container ID, name, image, state, status and ports.",
249                "inputSchema": {
250                    "type": "object",
251                    "properties": {
252                        "alias": {
253                            "type": "string",
254                            "description": "The host alias to list containers for"
255                        }
256                    },
257                    "required": ["alias"]
258                }
259            },
260            {
261                "name": "container_action",
262                "description": "Start, stop or restart a Docker or Podman container on a remote host via SSH. Auto-detects the container runtime.",
263                "inputSchema": {
264                    "type": "object",
265                    "properties": {
266                        "alias": {
267                            "type": "string",
268                            "description": "The host alias"
269                        },
270                        "container_id": {
271                            "type": "string",
272                            "description": "The container ID or name"
273                        },
274                        "action": {
275                            "type": "string",
276                            "description": "The action to perform",
277                            "enum": ["start", "stop", "restart"]
278                        }
279                    },
280                    "required": ["alias", "container_id", "action"]
281                }
282            }
283        ]
284    });
285    JsonRpcResponse::success(None, tools)
286}
287
288fn handle_tools_call(params: Option<Value>, config_path: &Path) -> JsonRpcResponse {
289    let params = match params {
290        Some(p) => p,
291        None => {
292            return JsonRpcResponse::error(
293                None,
294                -32602,
295                "Invalid params: missing params object".to_string(),
296            );
297        }
298    };
299
300    let tool_name = match params.get("name").and_then(|n| n.as_str()) {
301        Some(n) => n,
302        None => {
303            return JsonRpcResponse::error(
304                None,
305                -32602,
306                "Invalid params: missing tool name".to_string(),
307            );
308        }
309    };
310
311    let args = params
312        .get("arguments")
313        .cloned()
314        .unwrap_or(serde_json::json!({}));
315
316    let result = match tool_name {
317        "list_hosts" => tool_list_hosts(&args, config_path),
318        "get_host" => tool_get_host(&args, config_path),
319        "run_command" => tool_run_command(&args, config_path),
320        "list_containers" => tool_list_containers(&args, config_path),
321        "container_action" => tool_container_action(&args, config_path),
322        _ => mcp_tool_error(&format!("Unknown tool: {tool_name}")),
323    };
324
325    JsonRpcResponse::success(None, result)
326}
327
328fn tool_list_hosts(args: &Value, config_path: &Path) -> Value {
329    let config = match SshConfigFile::parse(config_path) {
330        Ok(c) => c,
331        Err(e) => return mcp_tool_error(&format!("Failed to parse SSH config: {e}")),
332    };
333
334    let entries = config.host_entries();
335    let tag_filter = args.get("tag").and_then(|t| t.as_str());
336
337    let hosts: Vec<Value> = entries
338        .iter()
339        .filter(|entry| {
340            // Skip host patterns (already filtered by host_entries, but be safe)
341            if is_host_pattern(&entry.alias) {
342                return false;
343            }
344
345            // Apply tag filter (fuzzy: substring match on tags, provider_tags, provider name)
346            if let Some(tag) = tag_filter {
347                let tag_lower = tag.to_lowercase();
348                let matches_tags = entry
349                    .tags
350                    .iter()
351                    .any(|t| t.to_lowercase().contains(&tag_lower));
352                let matches_provider_tags = entry
353                    .provider_tags
354                    .iter()
355                    .any(|t| t.to_lowercase().contains(&tag_lower));
356                let matches_provider = entry
357                    .provider
358                    .as_ref()
359                    .is_some_and(|p| p.to_lowercase().contains(&tag_lower));
360                if !matches_tags && !matches_provider_tags && !matches_provider {
361                    return false;
362                }
363            }
364
365            true
366        })
367        .map(|entry| {
368            serde_json::json!({
369                "alias": entry.alias,
370                "hostname": entry.hostname,
371                "user": entry.user,
372                "port": entry.port,
373                "tags": entry.tags,
374                "provider": entry.provider,
375                "stale": entry.stale.is_some(),
376            })
377        })
378        .collect();
379
380    let json_str = serde_json::to_string_pretty(&hosts).unwrap_or_default();
381    mcp_tool_result(&json_str)
382}
383
384fn tool_get_host(args: &Value, config_path: &Path) -> Value {
385    let alias = match args.get("alias").and_then(|a| a.as_str()) {
386        Some(a) => a,
387        None => return mcp_tool_error("Missing required parameter: alias"),
388    };
389
390    let config = match SshConfigFile::parse(config_path) {
391        Ok(c) => c,
392        Err(e) => return mcp_tool_error(&format!("Failed to parse SSH config: {e}")),
393    };
394
395    let entries = config.host_entries();
396    let entry = entries.iter().find(|e| e.alias == alias);
397
398    match entry {
399        Some(entry) => {
400            let meta: serde_json::Map<String, Value> = entry
401                .provider_meta
402                .iter()
403                .map(|(k, v)| (k.clone(), Value::String(v.clone())))
404                .collect();
405
406            let host = serde_json::json!({
407                "alias": entry.alias,
408                "hostname": entry.hostname,
409                "user": entry.user,
410                "port": entry.port,
411                "identity_file": entry.identity_file,
412                "proxy_jump": entry.proxy_jump,
413                "tags": entry.tags,
414                "provider_tags": entry.provider_tags,
415                "provider": entry.provider,
416                "provider_meta": meta,
417                "askpass": entry.askpass,
418                "tunnel_count": entry.tunnel_count,
419                "stale": entry.stale.is_some(),
420            });
421
422            let json_str = serde_json::to_string_pretty(&host).unwrap_or_default();
423            mcp_tool_result(&json_str)
424        }
425        None => mcp_tool_error(&format!("Host not found: {alias}")),
426    }
427}
428
429fn tool_run_command(args: &Value, config_path: &Path) -> Value {
430    let alias = match args.get("alias").and_then(|a| a.as_str()) {
431        Some(a) if !a.is_empty() => a,
432        _ => return mcp_tool_error("Missing required parameter: alias"),
433    };
434    let command = match args.get("command").and_then(|c| c.as_str()) {
435        Some(c) if !c.is_empty() => c,
436        _ => return mcp_tool_error("Missing required parameter: command"),
437    };
438    let timeout_secs = args.get("timeout").and_then(|t| t.as_u64()).unwrap_or(30);
439
440    if let Err(e) = verify_alias_exists(alias, config_path) {
441        return e;
442    }
443
444    info!("MCP tool: ssh_exec alias={alias} command={command}");
445    match ssh_exec(alias, config_path, command, timeout_secs) {
446        Ok((exit_code, stdout, stderr)) => {
447            if exit_code != 0 {
448                error!("[external] MCP ssh_exec failed: alias={alias} exit={exit_code}");
449            }
450            let result = serde_json::json!({
451                "exit_code": exit_code,
452                "stdout": stdout,
453                "stderr": stderr
454            });
455            let json_str = serde_json::to_string_pretty(&result).unwrap_or_default();
456            mcp_tool_result(&json_str)
457        }
458        Err(e) => e,
459    }
460}
461
462fn tool_list_containers(args: &Value, config_path: &Path) -> Value {
463    let alias = match args.get("alias").and_then(|a| a.as_str()) {
464        Some(a) if !a.is_empty() => a,
465        _ => return mcp_tool_error("Missing required parameter: alias"),
466    };
467
468    if let Err(e) = verify_alias_exists(alias, config_path) {
469        return e;
470    }
471
472    // Build the combined detection + listing command
473    let command = crate::containers::container_list_command(None);
474
475    let (exit_code, stdout, stderr) = match ssh_exec(alias, config_path, &command, 30) {
476        Ok(r) => r,
477        Err(e) => return e,
478    };
479
480    if exit_code != 0 {
481        return mcp_tool_error(&format!("SSH command failed: {}", stderr.trim()));
482    }
483
484    match crate::containers::parse_container_output(&stdout, None) {
485        Ok((runtime, containers)) => {
486            let containers_json: Vec<Value> = containers
487                .iter()
488                .map(|c| {
489                    serde_json::json!({
490                        "id": c.id,
491                        "name": c.names,
492                        "image": c.image,
493                        "state": c.state,
494                        "status": c.status,
495                        "ports": c.ports,
496                    })
497                })
498                .collect();
499            let result = serde_json::json!({
500                "runtime": runtime.as_str(),
501                "containers": containers_json,
502            });
503            let json_str = serde_json::to_string_pretty(&result).unwrap_or_default();
504            mcp_tool_result(&json_str)
505        }
506        Err(e) => mcp_tool_error(&e),
507    }
508}
509
510fn tool_container_action(args: &Value, config_path: &Path) -> Value {
511    let alias = match args.get("alias").and_then(|a| a.as_str()) {
512        Some(a) if !a.is_empty() => a,
513        _ => return mcp_tool_error("Missing required parameter: alias"),
514    };
515    let container_id = match args.get("container_id").and_then(|c| c.as_str()) {
516        Some(c) if !c.is_empty() => c,
517        _ => return mcp_tool_error("Missing required parameter: container_id"),
518    };
519    let action_str = match args.get("action").and_then(|a| a.as_str()) {
520        Some(a) => a,
521        None => return mcp_tool_error("Missing required parameter: action"),
522    };
523
524    // Validate container ID (injection prevention)
525    if let Err(e) = crate::containers::validate_container_id(container_id) {
526        return mcp_tool_error(&e);
527    }
528
529    let action = match action_str {
530        "start" => crate::containers::ContainerAction::Start,
531        "stop" => crate::containers::ContainerAction::Stop,
532        "restart" => crate::containers::ContainerAction::Restart,
533        _ => {
534            return mcp_tool_error(&format!(
535                "Invalid action: {action_str}. Must be start, stop or restart"
536            ));
537        }
538    };
539
540    if let Err(e) = verify_alias_exists(alias, config_path) {
541        return e;
542    }
543
544    // First detect runtime
545    let detect_cmd = crate::containers::container_list_command(None);
546
547    let (detect_exit, detect_stdout, _detect_stderr) =
548        match ssh_exec(alias, config_path, &detect_cmd, 30) {
549            Ok(r) => r,
550            Err(e) => return e,
551        };
552
553    if detect_exit != 0 {
554        return mcp_tool_error("Failed to detect container runtime");
555    }
556
557    let runtime = match crate::containers::parse_container_output(&detect_stdout, None) {
558        Ok((rt, _)) => rt,
559        Err(e) => return mcp_tool_error(&format!("Failed to detect container runtime: {e}")),
560    };
561
562    let action_command = crate::containers::container_action_command(runtime, action, container_id);
563
564    let (action_exit, _action_stdout, action_stderr) =
565        match ssh_exec(alias, config_path, &action_command, 30) {
566            Ok(r) => r,
567            Err(e) => return e,
568        };
569
570    if action_exit == 0 {
571        let result = serde_json::json!({
572            "success": true,
573            "message": format!("Container {container_id} {}ed", action_str),
574        });
575        let json_str = serde_json::to_string_pretty(&result).unwrap_or_default();
576        mcp_tool_result(&json_str)
577    } else {
578        mcp_tool_error(&format!(
579            "Container action failed: {}",
580            action_stderr.trim()
581        ))
582    }
583}
584
585/// Run the MCP server, reading JSON-RPC requests from stdin and writing
586/// responses to stdout. Blocks until stdin is closed.
587pub fn run(config_path: &Path) -> anyhow::Result<()> {
588    let stdin = std::io::stdin();
589    let stdout = std::io::stdout();
590    let reader = stdin.lock();
591    let mut writer = stdout.lock();
592
593    for line in reader.lines() {
594        let line = match line {
595            Ok(l) => l,
596            Err(_) => break,
597        };
598        let trimmed = line.trim();
599        if trimmed.is_empty() {
600            continue;
601        }
602
603        let request: JsonRpcRequest = match serde_json::from_str(trimmed) {
604            Ok(r) => r,
605            Err(_) => {
606                let resp = JsonRpcResponse::error(None, -32700, "Parse error".to_string());
607                let json = serde_json::to_string(&resp)?;
608                writeln!(writer, "{json}")?;
609                writer.flush()?;
610                continue;
611            }
612        };
613
614        // Notifications (no id) don't get responses
615        if request.id.is_none() {
616            debug!("MCP notification: {}", request.method);
617            continue;
618        }
619
620        debug!("MCP request: method={}", request.method);
621        let mut response = dispatch(&request.method, request.params, config_path);
622        debug!(
623            "MCP response: method={} success={}",
624            request.method,
625            response.error.is_none()
626        );
627        response.id = request.id;
628
629        let json = serde_json::to_string(&response)?;
630        writeln!(writer, "{json}")?;
631        writer.flush()?;
632    }
633
634    Ok(())
635}
636
637#[cfg(test)]
638#[path = "mcp_tests.rs"]
639mod tests;