Skip to main content

scute_test_utils/
mcp.rs

1use std::path::Path;
2
3use rmcp::{
4    ClientHandler, ErrorData, ServiceExt,
5    model::{
6        CallToolRequestParams, CallToolResult, ClientCapabilities, ClientInfo, Implementation,
7        InitializeRequestParams, ListRootsResult, Root, Tool,
8    },
9    service::{RequestContext, RoleClient, RunningService},
10    transport::TokioChildProcess,
11};
12use tempfile::TempDir;
13use tokio::process::Command;
14
15use crate::{Backend, CheckResult, ExitStatus, ListChecksResult, target_bin};
16
17pub(crate) struct McpBackend;
18
19impl Backend for McpBackend {
20    fn check(&self, dir: TempDir, working_dir: &Path, args: &[&str]) -> CheckResult {
21        let check_name = args.get(1).expect("check name required");
22        let tool_name = format!("check_{}", check_name.replace('-', "_"));
23        let tool_args = build_tool_args(check_name, &args[2..]);
24        let project_dir = working_dir.canonicalize().unwrap();
25
26        let client = McpTestClient::connect(&project_dir);
27        let result = client.call_tool(&tool_name, &tool_args);
28
29        let json = result
30            .structured_content
31            .clone()
32            .expect("structuredContent must be present");
33        let is_error = result.is_error.unwrap_or(false);
34        let exit_status = if !is_error {
35            ExitStatus::Success
36        } else if json.get("error").is_some() {
37            ExitStatus::Error
38        } else {
39            ExitStatus::Failure
40        };
41        let debug_info = format!("{result:?}");
42
43        CheckResult {
44            _dir: dir,
45            json,
46            project_dir,
47            exit_status,
48            debug_info,
49        }
50    }
51
52    fn list_checks(&self, dir: TempDir) -> ListChecksResult {
53        let project_dir = dir.path().canonicalize().unwrap_or(dir.path().into());
54        let client = McpTestClient::connect(&project_dir);
55        let checks = client
56            .list_tools()
57            .iter()
58            .map(|t| {
59                t.name
60                    .strip_prefix("check_")
61                    .expect("tool name starts with check_")
62                    .replace('_', "-")
63            })
64            .collect();
65        ListChecksResult { _dir: dir, checks }
66    }
67}
68
69/// An MCP client connected to a running Scute MCP server.
70///
71/// Wraps rmcp's client with its own tokio runtime so callers don't need async.
72/// Use for protocol-level tests that need direct access beyond the `Scute` harness.
73pub struct McpTestClient {
74    service: RunningService<RoleClient, RootsProvider>,
75    rt: tokio::runtime::Runtime,
76}
77
78impl McpTestClient {
79    pub fn connect(project_root: &std::path::Path) -> Self {
80        let rt = tokio::runtime::Builder::new_current_thread()
81            .enable_all()
82            .build()
83            .unwrap();
84
85        let root: Root = serde_json::from_value(serde_json::json!({
86            "uri": format!("file://{}", project_root.display())
87        }))
88        .expect("valid root");
89        let service = rt
90            .block_on(async {
91                let transport = TokioChildProcess::new({
92                    let mut cmd = Command::new(target_bin("scute"));
93                    cmd.arg("mcp");
94                    cmd.current_dir(project_root);
95                    cmd
96                })
97                .expect("failed to spawn scute mcp");
98
99                RootsProvider(vec![root]).serve(transport).await
100            })
101            .expect("failed to connect to scute mcp");
102
103        Self { service, rt }
104    }
105
106    pub fn call_tool(&self, name: &str, args: &serde_json::Value) -> CallToolResult {
107        self.rt
108            .block_on(
109                self.service.call_tool(
110                    CallToolRequestParams::new(name.to_string())
111                        .with_arguments(args.as_object().unwrap().clone()),
112                ),
113            )
114            .expect("call_tool failed")
115    }
116
117    pub fn list_tools(&self) -> Vec<Tool> {
118        self.rt
119            .block_on(self.service.list_all_tools())
120            .expect("list_all_tools failed")
121    }
122}
123
124/// A [`ClientHandler`] that provides project roots to the server.
125struct RootsProvider(Vec<Root>);
126
127impl ClientHandler for RootsProvider {
128    fn get_info(&self) -> ClientInfo {
129        InitializeRequestParams::new(
130            ClientCapabilities::builder().enable_roots().build(),
131            Implementation::new("scute-test", env!("CARGO_PKG_VERSION")),
132        )
133    }
134
135    fn list_roots(
136        &self,
137        _: RequestContext<RoleClient>,
138    ) -> impl Future<Output = Result<ListRootsResult, ErrorData>> + Send + '_ {
139        let result: ListRootsResult =
140            serde_json::from_value(serde_json::json!({ "roots": self.0 })).expect("valid roots");
141        std::future::ready(Ok(result))
142    }
143}
144
145fn build_tool_args(check_name: &str, args: &[&str]) -> serde_json::Value {
146    match check_name {
147        "commit-message" => {
148            let message = args.first().copied().unwrap_or("");
149            serde_json::json!({ "message": message })
150        }
151        "code-complexity" => positional_paths_args("paths", args),
152        "code-similarity" => source_files_args(args),
153        "dependency-freshness" => match args.first() {
154            Some(path) => serde_json::json!({ "path": path }),
155            None => serde_json::json!({}),
156        },
157        _ => serde_json::json!({}),
158    }
159}
160
161fn positional_paths_args(key: &str, args: &[&str]) -> serde_json::Value {
162    match args.first() {
163        Some(_) => serde_json::json!({ key: args }),
164        None => serde_json::json!({}),
165    }
166}
167
168fn source_files_args(args: &[&str]) -> serde_json::Value {
169    let mut json = serde_json::Map::new();
170    let mut files = Vec::new();
171    let mut i = 0;
172    while i < args.len() {
173        if args[i] == "--source-dir"
174            && let Some(val) = args.get(i + 1)
175        {
176            json.insert("source_dir".into(), serde_json::json!(val));
177            i += 2;
178            continue;
179        }
180        files.push(args[i]);
181        i += 1;
182    }
183    if !files.is_empty() {
184        json.insert("files".into(), serde_json::json!(files));
185    }
186    serde_json::Value::Object(json)
187}