project_map_cli_rust/mcp/
server.rs1use std::path::Path;
2use std::sync::Arc;
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5
6use rust_mcp_sdk::{McpServer as SdkMcpServer, TransportOptions, StdioTransport, ToMcpServerHandler};
7use rust_mcp_sdk::mcp_server::{server_runtime, McpServerOptions, ServerHandler};
8use rust_mcp_sdk::schema::{
9 CallToolRequestParams, CallToolResult, InitializeResult,
10 PaginatedRequestParams, ListToolsResult, ServerCapabilities, ServerCapabilitiesTools,
11 Implementation, ProtocolVersion, RpcError,
12};
13use rust_mcp_sdk::schema::schema_utils::CallToolError;
14use rust_mcp_sdk::macros::{mcp_tool, JsonSchema};
15use tracing_subscriber::fmt;
16
17use crate::error::Result;
18use crate::core::query_engine::QueryEngine;
19use crate::core::orchestrator::Orchestrator;
20use crate::core::toon::ToonFormatter;
21
22#[mcp_tool(
25 name = "pm_status",
26 description = "Returns current workspace context and available commands."
27)]
28#[derive(JsonSchema, Deserialize, Serialize)]
29pub struct PmStatusTool {}
30
31#[mcp_tool(
32 name = "pm_query",
33 description = "Search for symbols or get file context."
34)]
35#[derive(JsonSchema, Deserialize, Serialize)]
36pub struct PmQueryTool {
37 pub query: Option<String>,
39 pub path: Option<String>,
41}
42
43#[mcp_tool(
44 name = "pm_check_blast_radius",
45 description = "Identifies all components and files that depend on or import a specific symbol."
46)]
47#[derive(JsonSchema, Deserialize, Serialize)]
48pub struct PmCheckBlastRadiusTool {
49 pub path: String,
51 pub symbol: String,
53}
54
55#[mcp_tool(
56 name = "pm_plan",
57 description = "Analyze the architectural impact (fan-out) of a symbol before starting a refactor."
58)]
59#[derive(JsonSchema, Deserialize, Serialize)]
60pub struct PmPlanTool {
61 pub symbol: String,
63}
64
65#[mcp_tool(
66 name = "pm_semantic_search",
67 description = "Search for logic using natural language keywords (e.g., 'auth', 'database')."
68)]
69#[derive(JsonSchema, Deserialize, Serialize)]
70pub struct PmSemanticSearchTool {
71 pub query: String,
73}
74
75#[mcp_tool(
76 name = "pm_fetch_symbol",
77 description = "Extract raw source code for a specific class or function."
78)]
79#[derive(JsonSchema, Deserialize, Serialize)]
80pub struct PmFetchSymbolTool {
81 pub path: String,
83 pub symbol: String,
85}
86
87#[mcp_tool(
88 name = "pm_init",
89 description = "Refresh the map index after significant code changes to maintain discovery accuracy."
90)]
91#[derive(JsonSchema, Deserialize, Serialize)]
92pub struct PmInitTool {}
93
94pub struct McpServer {
97 engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
98}
99
100impl McpServer {
101 pub fn new() -> Self {
102 let engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
103 Self {
104 engine: Arc::new(std::sync::RwLock::new(engine)),
105 }
106 }
107
108 pub async fn run(&self) -> Result<()> {
109 let _ = fmt()
110 .with_writer(std::io::stderr)
111 .try_init();
112
113 let server_info = InitializeResult {
114 protocol_version: ProtocolVersion::V2025_11_25.into(),
115 capabilities: ServerCapabilities {
116 tools: Some(ServerCapabilitiesTools { list_changed: None }),
117 ..Default::default()
118 },
119 server_info: Implementation {
120 name: "project-map-cli-rust".to_string(),
121 version: env!("CARGO_PKG_VERSION").to_string(),
122 title: Some("Project Map CLI".to_string()),
123 description: None,
124 icons: vec![],
125 website_url: None,
126 },
127 instructions: None,
128 meta: None,
129 };
130
131 let transport = StdioTransport::new(TransportOptions::default())
132 .map_err(|e| crate::error::AppError::Generic(format!("Transport error: {}", e)))?;
133 let handler = self.clone_for_handler();
134
135 let options = McpServerOptions {
136 server_details: server_info,
137 transport,
138 handler: handler.to_mcp_server_handler(),
139 task_store: None,
140 client_task_store: None,
141 message_observer: None,
142 };
143
144 let server = server_runtime::create_server(options);
145 server.start().await.map_err(|e| crate::error::AppError::Generic(format!("Server error: {}", e)))?;
146
147 Ok(())
148 }
149
150 fn clone_for_handler(&self) -> McpServerHandler {
151 McpServerHandler {
152 engine: Arc::clone(&self.engine),
153 }
154 }
155}
156
157pub struct McpServerHandler {
158 engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
159}
160
161#[async_trait]
162impl ServerHandler for McpServerHandler {
163 async fn handle_list_tools_request(
164 &self,
165 _params: Option<PaginatedRequestParams>,
166 _runtime: Arc<dyn SdkMcpServer>,
167 ) -> std::result::Result<ListToolsResult, RpcError> {
168 Ok(ListToolsResult {
169 tools: vec![
170 PmStatusTool::tool(),
171 PmQueryTool::tool(),
172 PmCheckBlastRadiusTool::tool(),
173 PmPlanTool::tool(),
174 PmSemanticSearchTool::tool(),
175 PmFetchSymbolTool::tool(),
176 PmInitTool::tool(),
177 ],
178 next_cursor: None,
179 meta: None,
180 })
181 }
182
183 async fn handle_call_tool_request(
184 &self,
185 params: CallToolRequestParams,
186 _runtime: Arc<dyn SdkMcpServer>,
187 ) -> std::result::Result<CallToolResult, CallToolError> {
188 let arguments = serde_json::Value::Object(params.arguments.unwrap_or_default());
189 let text = match params.name.as_str() {
190 "pm_status" => {
191 let is_ready = self.engine.read().unwrap().is_some();
192 ToonFormatter::format_status(is_ready, Some(".project-map/latest/.project-map.json"))
193 }
194 "pm_query" => {
195 let args: PmQueryTool = serde_json::from_value(arguments)
196 .map_err(|e| CallToolError(Box::new(e)))?;
197
198 if let Some(ref engine) = *self.engine.read().unwrap() {
199 if let Some(q) = args.query {
200 let matches = engine.find_symbols(&q);
201 ToonFormatter::format_symbols(&q, &matches)
202 } else if let Some(p) = args.path {
203 let symbols = engine.get_file_outline(&p);
204 ToonFormatter::format_file_context(&p, &symbols)
205 } else {
206 "Error: Provide query or path".to_string()
207 }
208 } else {
209 "Error: Index not loaded".to_string()
210 }
211 }
212 "pm_check_blast_radius" => {
213 let args: PmCheckBlastRadiusTool = serde_json::from_value(arguments)
214 .map_err(|e| CallToolError(Box::new(e)))?;
215
216 if let Some(ref engine) = *self.engine.read().unwrap() {
217 let results = engine.check_blast_radius(&args.path, &args.symbol);
218 ToonFormatter::format_blast_radius(&args.path, &args.symbol, &results)
219 } else {
220 "Error: Index not loaded".to_string()
221 }
222 }
223 "pm_plan" => {
224 let args: PmPlanTool = serde_json::from_value(arguments)
225 .map_err(|e| CallToolError(Box::new(e)))?;
226
227 if let Some(ref engine) = *self.engine.read().unwrap() {
228 let impact = engine.analyze_impact(&args.symbol);
229 ToonFormatter::format_impact_analysis(&args.symbol, &impact)
230 } else {
231 "Error: Index not loaded".to_string()
232 }
233 }
234 "pm_semantic_search" => {
235 let args: PmSemanticSearchTool = serde_json::from_value(arguments)
236 .map_err(|e| CallToolError(Box::new(e)))?;
237
238 if let Some(ref engine) = *self.engine.read().unwrap() {
239 let matches = engine.find_symbols(&args.query);
240 ToonFormatter::format_symbols(&args.query, &matches)
241 } else {
242 "Error: Index not loaded".to_string()
243 }
244 }
245 "pm_fetch_symbol" => {
246 let args: PmFetchSymbolTool = serde_json::from_value(arguments)
247 .map_err(|e| CallToolError(Box::new(e)))?;
248
249 if let Some(ref engine) = *self.engine.read().unwrap() {
250 if let Some(node) = engine.find_symbol_in_path(&args.path, &args.symbol) {
251 if let Ok(content) = std::fs::read_to_string(&node.path) {
252 let bytes = content.as_bytes();
253 if node.start_byte < bytes.len() && node.end_byte <= bytes.len() {
254 let content_str = String::from_utf8_lossy(&bytes[node.start_byte..node.end_byte]);
255 ToonFormatter::format_fetch_result(&args.path, &args.symbol, Some(&content_str))
256 } else {
257 "Error: Byte range out of bounds".to_string()
258 }
259 } else {
260 "Error: Could not read file".to_string()
261 }
262 } else {
263 ToonFormatter::format_fetch_result(&args.path, &args.symbol, None)
264 }
265 } else {
266 "Error: Index not loaded".to_string()
267 }
268 }
269 "pm_init" => {
270 let mut orch = Orchestrator::new();
271 if orch.build_index(Path::new(".")).is_ok() && orch.save_index_versioned(Path::new(".project-map")).is_ok() {
272 let new_engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
273 *self.engine.write().unwrap() = new_engine;
274 "Index refreshed successfully.".to_string()
275 } else {
276 "Failed to refresh index.".to_string()
277 }
278 }
279
280 _ => return Err(CallToolError(Box::new(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Unknown tool")))),
281 };
282
283 Ok(CallToolResult::text_content(vec![text.into()]))
284 }
285}