project_map_cli_rust/mcp/
server.rs1use std::path::Path;
2use std::sync::Arc;
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5
6use rust_mcp_sdk::{McpServer as SdkMcpServer, TransportOptions, StdioTransport, ToMcpServerHandler};
7use rust_mcp_sdk::mcp_server::{server_runtime, McpServerOptions, ServerHandler};
8use rust_mcp_sdk::schema::{
9 CallToolRequestParams, CallToolResult, InitializeResult,
10 PaginatedRequestParams, ListToolsResult, ServerCapabilities, ServerCapabilitiesTools,
11 Implementation, ProtocolVersion, RpcError,
12};
13use rust_mcp_sdk::schema::schema_utils::CallToolError;
14use rust_mcp_sdk::macros::{mcp_tool, JsonSchema};
15use tracing_subscriber::fmt;
16
17use crate::error::Result;
18use crate::core::query_engine::QueryEngine;
19use crate::core::orchestrator::Orchestrator;
20
21#[mcp_tool(
24 name = "pm_status",
25 description = "Returns current workspace context and available commands."
26)]
27#[derive(JsonSchema, Deserialize, Serialize)]
28pub struct PmStatusTool {}
29
30#[mcp_tool(
31 name = "pm_query",
32 description = "Search for symbols or get file context."
33)]
34#[derive(JsonSchema, Deserialize, Serialize)]
35pub struct PmQueryTool {
36 pub query: Option<String>,
38 pub path: Option<String>,
40}
41
42#[mcp_tool(
43 name = "pm_check_blast_radius",
44 description = "Identifies all components and files that depend on or import a specific symbol."
45)]
46#[derive(JsonSchema, Deserialize, Serialize)]
47pub struct PmCheckBlastRadiusTool {
48 pub path: String,
50 pub symbol: String,
52}
53
54#[mcp_tool(
55 name = "pm_plan",
56 description = "Analyze the architectural impact (fan-out) of a symbol before starting a refactor."
57)]
58#[derive(JsonSchema, Deserialize, Serialize)]
59pub struct PmPlanTool {
60 pub symbol: String,
62}
63
64#[mcp_tool(
65 name = "pm_semantic_search",
66 description = "Search for logic using natural language keywords (e.g., 'auth', 'database')."
67)]
68#[derive(JsonSchema, Deserialize, Serialize)]
69pub struct PmSemanticSearchTool {
70 pub query: String,
72}
73
74#[mcp_tool(
75 name = "pm_fetch_symbol",
76 description = "Extract raw source code for a specific class or function."
77)]
78#[derive(JsonSchema, Deserialize, Serialize)]
79pub struct PmFetchSymbolTool {
80 pub path: String,
82 pub symbol: String,
84}
85
86#[mcp_tool(
87 name = "pm_init",
88 description = "Refresh the map index after significant code changes to maintain discovery accuracy."
89)]
90#[derive(JsonSchema, Deserialize, Serialize)]
91pub struct PmInitTool {}
92
93pub struct McpServer {
96 engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
97}
98
99impl McpServer {
100 pub fn new() -> Self {
101 let engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
102 Self {
103 engine: Arc::new(std::sync::RwLock::new(engine)),
104 }
105 }
106
107 pub async fn run(&self) -> Result<()> {
108 let _ = fmt()
109 .with_writer(std::io::stderr)
110 .try_init();
111
112 let server_info = InitializeResult {
113 protocol_version: ProtocolVersion::V2024_11_05.to_string(),
114 capabilities: ServerCapabilities {
115 tools: Some(ServerCapabilitiesTools { list_changed: None }),
116 ..Default::default()
117 },
118 server_info: Implementation {
119 name: "project-map-cli-rust".to_string(),
120 version: env!("CARGO_PKG_VERSION").to_string(),
121 title: Some("Project Map CLI".to_string()),
122 description: None,
123 icons: vec![],
124 website_url: None,
125 },
126 instructions: None,
127 meta: None,
128 };
129
130 let transport = StdioTransport::new(TransportOptions::default())
131 .map_err(|e| crate::error::AppError::Generic(format!("Transport error: {}", e)))?;
132 let handler = self.clone_for_handler();
133
134 let options = McpServerOptions {
135 server_details: server_info,
136 transport,
137 handler: handler.to_mcp_server_handler(),
138 task_store: None,
139 client_task_store: None,
140 message_observer: None,
141 };
142
143 let server = server_runtime::create_server(options);
144 server.start().await.map_err(|e| crate::error::AppError::Generic(format!("Server error: {}", e)))?;
145
146 Ok(())
147 }
148
149 fn clone_for_handler(&self) -> McpServerHandler {
150 McpServerHandler {
151 engine: Arc::clone(&self.engine),
152 }
153 }
154}
155
156pub struct McpServerHandler {
157 engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
158}
159
160#[async_trait]
161impl ServerHandler for McpServerHandler {
162 async fn handle_list_tools_request(
163 &self,
164 _params: Option<PaginatedRequestParams>,
165 _runtime: Arc<dyn SdkMcpServer>,
166 ) -> std::result::Result<ListToolsResult, RpcError> {
167 Ok(ListToolsResult {
168 tools: vec![
169 PmStatusTool::tool(),
170 PmQueryTool::tool(),
171 PmCheckBlastRadiusTool::tool(),
172 PmPlanTool::tool(),
173 PmSemanticSearchTool::tool(),
174 PmFetchSymbolTool::tool(),
175 PmInitTool::tool(),
176 ],
177 next_cursor: None,
178 meta: None,
179 })
180 }
181
182 async fn handle_call_tool_request(
183 &self,
184 params: CallToolRequestParams,
185 _runtime: Arc<dyn SdkMcpServer>,
186 ) -> std::result::Result<CallToolResult, CallToolError> {
187 let arguments = serde_json::Value::Object(params.arguments.unwrap_or_default());
188 let text = match params.name.as_str() {
189 "pm_status" => {
190 if self.engine.read().unwrap().is_some() {
191 "Status: System healthy. Index is present.".to_string()
192 } else {
193 "Status: Index missing. Run project-map build.".to_string()
194 }
195 }
196 "pm_query" => {
197 let args: PmQueryTool = serde_json::from_value(arguments)
198 .map_err(|e| CallToolError(Box::new(e)))?;
199
200 if let Some(ref engine) = *self.engine.read().unwrap() {
201 if let Some(q) = args.query {
202 let matches = engine.find_symbols(&q);
203 format!("Matches: {}", matches.len())
204 } else if let Some(p) = args.path {
205 let symbols = engine.get_file_outline(&p);
206 format!("Symbols in {}: {}", p, symbols.len())
207 } else {
208 "Error: Provide query or path".to_string()
209 }
210 } else {
211 "Error: Index not loaded".to_string()
212 }
213 }
214 "pm_check_blast_radius" => {
215 let args: PmCheckBlastRadiusTool = serde_json::from_value(arguments)
216 .map_err(|e| CallToolError(Box::new(e)))?;
217
218 if let Some(ref engine) = *self.engine.read().unwrap() {
219 let results = engine.check_blast_radius(&args.path, &args.symbol);
220
221 if results.is_empty() {
222 "No dependent components found.".to_string()
223 } else {
224 let mut unique_files = std::collections::HashSet::new();
225 for r in &results { unique_files.insert(&r.path); }
226 format!("Blast Radius for {}:\n- Total Impacted Nodes: {}\n- Unique Files: {}\n(Top 5: {})",
227 args.symbol, results.len(), unique_files.len(),
228 results.iter().take(5).map(|r| r.name.as_str()).collect::<Vec<_>>().join(", "))
229 }
230 } else {
231 "Error: Index not loaded".to_string()
232 }
233 }
234 "pm_plan" => {
235 let args: PmPlanTool = serde_json::from_value(arguments)
236 .map_err(|e| CallToolError(Box::new(e)))?;
237
238 if let Some(ref engine) = *self.engine.read().unwrap() {
239 let impact = engine.analyze_impact(&args.symbol);
240 let blast = engine.check_blast_radius("", &args.symbol);
241
242 let mut unique_blast = std::collections::HashSet::new();
243 for r in &blast { unique_blast.insert(&r.path); }
244
245 format!("Architectural Plan for {}:\n- Fan-out (Dependencies): {} nodes\n- Fan-in (Dependents): {} nodes across {} files.",
246 args.symbol, impact.len(), blast.len(), unique_blast.len())
247 } else {
248 "Error: Index not loaded".to_string()
249 }
250 }
251 "pm_semantic_search" => {
252 let args: PmSemanticSearchTool = serde_json::from_value(arguments)
253 .map_err(|e| CallToolError(Box::new(e)))?;
254
255 if let Some(ref engine) = *self.engine.read().unwrap() {
256 let matches = engine.find_symbols(&args.query);
257 let mut result = format!("Semantic Search Results ({}):", matches.len());
258 for m in matches.iter().take(15) {
259 result.push_str(&format!("\n- {}: {}", m.path, m.name));
260 }
261 result
262 } else {
263 "Error: Index not loaded".to_string()
264 }
265 }
266 "pm_fetch_symbol" => {
267 let args: PmFetchSymbolTool = serde_json::from_value(arguments)
268 .map_err(|e| CallToolError(Box::new(e)))?;
269
270 if let Some(ref engine) = *self.engine.read().unwrap() {
271 if let Some(node) = engine.find_symbol_in_path(&args.path, &args.symbol) {
272 if let Ok(content) = std::fs::read_to_string(&node.path) {
273 let bytes = content.as_bytes();
274 if node.start_byte < bytes.len() && node.end_byte <= bytes.len() {
275 String::from_utf8_lossy(&bytes[node.start_byte..node.end_byte]).to_string()
276 } else {
277 "Error: Byte range out of bounds".to_string()
278 }
279 } else {
280 "Error: Could not read file".to_string()
281 }
282 } else {
283 "Error: Symbol not found".to_string()
284 }
285 } else {
286 "Error: Index not loaded".to_string()
287 }
288 }
289 "pm_init" => {
290 let mut orch = Orchestrator::new();
291 if orch.build_index(Path::new(".")).is_ok() && orch.save_index_versioned(Path::new(".project-map")).is_ok() {
292 let new_engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
293 *self.engine.write().unwrap() = new_engine;
294 "Index refreshed successfully.".to_string()
295 } else {
296 "Failed to refresh index.".to_string()
297 }
298 }
299
300 _ => return Err(CallToolError(Box::new(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Unknown tool")))),
301 };
302
303 Ok(CallToolResult::text_content(vec![text.into()]))
304 }
305}