Skip to main content

project_map_cli_rust/mcp/
server.rs

1use std::path::Path;
2use std::sync::Arc;
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5
6use rust_mcp_sdk::{McpServer as SdkMcpServer, TransportOptions, StdioTransport};
7use rust_mcp_sdk::mcp_server::{server_runtime, ServerHandler};
8use rust_mcp_sdk::schema::{
9    CallToolRequest, CallToolResult, InitializeResult,
10    ListToolsRequest, ListToolsResult, ServerCapabilities, ServerCapabilitiesTools,
11    Implementation, ProtocolVersion, RpcError,
12};
13use rust_mcp_sdk::schema::schema_utils::CallToolError;
14use rust_mcp_sdk::macros::{mcp_tool, JsonSchema};
15use tracing_subscriber::fmt;
16
17use crate::error::Result;
18use crate::core::query_engine::QueryEngine;
19use crate::core::orchestrator::Orchestrator;
20
21// --- Tool Definitions ---
22
// Argument type for the `pm_status` tool.
//
// The tool takes no parameters, so the struct has no fields. The call handler
// answers it with a short status line that says whether an index is currently
// loaded in memory.
#[mcp_tool(
    name = "pm_status",
    description = "Returns current workspace context and available commands."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmStatusTool {}
29
// Argument type for the `pm_query` tool.
//
// Both fields are optional. The call handler checks `query` first and only
// falls back to `path` when `query` is absent; when neither is supplied it
// replies with an error message asking for one of them.
//
// NOTE(review): the `///` comments on the fields are presumably surfaced as
// parameter descriptions in the generated JSON schema — confirm before
// rewording them.
#[mcp_tool(
    name = "pm_query",
    description = "Search for symbols or get file context."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmQueryTool {
    /// Search query for symbols
    pub query: Option<String>,
    /// File path to get outline
    pub path: Option<String>,
}
41
// Argument type for the `pm_check_blast_radius` tool.
//
// Both fields are required. The call handler forwards them to
// `QueryEngine::check_blast_radius` and summarises the result: total number of
// hits, number of distinct file paths, and the names of the first five hits.
#[mcp_tool(
    name = "pm_check_blast_radius",
    description = "Identifies all components and files that depend on or import a specific symbol."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmCheckBlastRadiusTool {
    /// File path where the symbol is defined
    pub path: String,
    /// Symbol name to check
    pub symbol: String,
}
53
// Argument type for the `pm_plan` tool.
//
// The call handler feeds `symbol` to both `QueryEngine::analyze_impact`
// (reported as fan-out) and `QueryEngine::check_blast_radius` with an empty
// path (reported as fan-in), and returns the two counts.
#[mcp_tool(
    name = "pm_plan",
    description = "Analyze the architectural impact (fan-out) of a symbol before starting a refactor."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmPlanTool {
    /// Symbol name to analyze
    pub symbol: String,
}
63
// Argument type for the `pm_semantic_search` tool.
//
// The call handler passes `query` to `QueryEngine::find_symbols` (the same
// engine call `pm_query` uses) and lists at most the first 15 matches as
// "path: name" lines.
#[mcp_tool(
    name = "pm_semantic_search",
    description = "Search for logic using natural language keywords (e.g., 'auth', 'database')."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmSemanticSearchTool {
    /// Natural language query
    pub query: String,
}
73
// Argument type for the `pm_fetch_symbol` tool.
//
// The call handler looks the symbol up with
// `QueryEngine::find_symbol_in_path(path, symbol)`, reads the file recorded on
// the matching node from disk, and returns the bytes between the node's
// `start_byte` and `end_byte` offsets as text.
#[mcp_tool(
    name = "pm_fetch_symbol",
    description = "Extract raw source code for a specific class or function."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmFetchSymbolTool {
    /// File path
    pub path: String,
    /// Symbol name
    pub symbol: String,
}
85
// Argument type for the `pm_init` tool.
//
// The tool takes no parameters. The call handler rebuilds the index for the
// current directory through `Orchestrator`, saves it under `.project-map`, and
// then reloads the in-memory query engine from the saved snapshot.
#[mcp_tool(
    name = "pm_init",
    description = "Refresh the map index after significant code changes to maintain discovery accuracy."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmInitTool {}
92
93// --- Server Implementation ---
94
/// Entry point for serving the project map over MCP.
///
/// Owns the shared query-engine slot; [`McpServer::run`] hands a clone of that
/// slot to the request handler and drives the server over a stdio transport.
pub struct McpServer {
    /// Shared, lock-protected query engine. Holds `None` when no index could
    /// be loaded from disk.
    engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
}
98
99impl McpServer {
100    pub fn new() -> Self {
101        let engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
102        Self {
103            engine: Arc::new(std::sync::RwLock::new(engine)),
104        }
105    }
106
107    pub async fn run(&self) -> Result<()> {
108        let _ = fmt()
109            .with_writer(std::io::stderr)
110            .try_init();
111
112        let server_info = InitializeResult {
113            protocol_version: ProtocolVersion::V2024_11_05.to_string(),
114            capabilities: ServerCapabilities {
115                tools: Some(ServerCapabilitiesTools { list_changed: None }),
116                ..Default::default()
117            },
118            server_info: Implementation {
119                name: "project-map-cli-rust".to_string(),
120                version: env!("CARGO_PKG_VERSION").to_string(),
121                title: Some("Project Map CLI".to_string()),
122            },
123            instructions: None,
124            meta: None,
125        };
126
127        let transport = StdioTransport::new(TransportOptions::default())
128            .map_err(|e| crate::error::AppError::Generic(format!("Transport error: {}", e)))?;
129        let handler = self.clone_for_handler();
130        
131        let server = server_runtime::create_server(server_info, transport, handler);
132        server.start().await.map_err(|e| crate::error::AppError::Generic(format!("Server error: {}", e)))?;
133
134        Ok(())
135    }
136
137    fn clone_for_handler(&self) -> McpServerHandler {
138        McpServerHandler {
139            engine: Arc::clone(&self.engine),
140        }
141    }
142}
143
/// Request handler registered with the MCP runtime.
///
/// It answers tool-listing and tool-call requests using the query engine it
/// shares with the [`McpServer`] that created it.
pub struct McpServerHandler {
    /// Same shared engine slot as the owning `McpServer`; `None` means no
    /// index is currently loaded.
    engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
}
147
148#[async_trait]
149impl ServerHandler for McpServerHandler {
150    async fn handle_list_tools_request(
151        &self,
152        _request: ListToolsRequest,
153        _runtime: &dyn SdkMcpServer,
154    ) -> std::result::Result<ListToolsResult, RpcError> {
155        Ok(ListToolsResult {
156            tools: vec![
157                PmStatusTool::tool(),
158                PmQueryTool::tool(),
159                PmCheckBlastRadiusTool::tool(),
160                PmPlanTool::tool(),
161                PmSemanticSearchTool::tool(),
162                PmFetchSymbolTool::tool(),
163                PmInitTool::tool(),
164            ],
165            next_cursor: None,
166            meta: None,
167        })
168    }
169
170    async fn handle_call_tool_request(
171        &self,
172        request: CallToolRequest,
173        _runtime: &dyn SdkMcpServer,
174    ) -> std::result::Result<CallToolResult, CallToolError> {
175        let arguments = serde_json::Value::Object(request.params.arguments.unwrap_or_default());
176        let text = match request.params.name.as_str() {
177            "pm_status" => {
178                if self.engine.read().unwrap().is_some() {
179                    "Status: System healthy. Index is present.".to_string()
180                } else {
181                    "Status: Index missing. Run project-map build.".to_string()
182                }
183            }
184            "pm_query" => {
185                let args: PmQueryTool = serde_json::from_value(arguments)
186                    .map_err(|e| CallToolError(Box::new(e)))?;
187                
188                if let Some(ref engine) = *self.engine.read().unwrap() {
189                    if let Some(q) = args.query {
190                        let matches = engine.find_symbols(&q);
191                        format!("Matches: {}", matches.len())
192                    } else if let Some(p) = args.path {
193                        let symbols = engine.get_file_outline(&p);
194                        format!("Symbols in {}: {}", p, symbols.len())
195                    } else {
196                        "Error: Provide query or path".to_string()
197                    }
198                } else {
199                    "Error: Index not loaded".to_string()
200                }
201            }
202            "pm_check_blast_radius" => {
203                let args: PmCheckBlastRadiusTool = serde_json::from_value(arguments)
204                    .map_err(|e| CallToolError(Box::new(e)))?;
205                
206                if let Some(ref engine) = *self.engine.read().unwrap() {
207                    let results = engine.check_blast_radius(&args.path, &args.symbol);
208
209                    if results.is_empty() {
210                        "No dependent components found.".to_string()
211                    } else {
212                        let mut unique_files = std::collections::HashSet::new();
213                        for r in &results { unique_files.insert(&r.path); }
214                        format!("Blast Radius for {}:\n- Total Impacted Nodes: {}\n- Unique Files: {}\n(Top 5: {})", 
215                            args.symbol, results.len(), unique_files.len(),
216                            results.iter().take(5).map(|r| r.name.as_str()).collect::<Vec<_>>().join(", "))
217                    }
218                } else {
219                    "Error: Index not loaded".to_string()
220                }
221            }
222            "pm_plan" => {
223                let args: PmPlanTool = serde_json::from_value(arguments)
224                    .map_err(|e| CallToolError(Box::new(e)))?;
225                
226                if let Some(ref engine) = *self.engine.read().unwrap() {
227                    let impact = engine.analyze_impact(&args.symbol);
228                    let blast = engine.check_blast_radius("", &args.symbol);
229
230                    let mut unique_blast = std::collections::HashSet::new();
231                    for r in &blast { unique_blast.insert(&r.path); }
232
233                    format!("Architectural Plan for {}:\n- Fan-out (Dependencies): {} nodes\n- Fan-in (Dependents): {} nodes across {} files.", 
234                        args.symbol, impact.len(), blast.len(), unique_blast.len())
235                } else {
236                    "Error: Index not loaded".to_string()
237                }
238            }
239            "pm_semantic_search" => {
240                let args: PmSemanticSearchTool = serde_json::from_value(arguments)
241                    .map_err(|e| CallToolError(Box::new(e)))?;
242                
243                if let Some(ref engine) = *self.engine.read().unwrap() {
244                    let matches = engine.find_symbols(&args.query);
245                    let mut result = format!("Semantic Search Results ({}):", matches.len());
246                    for m in matches.iter().take(15) {
247                        result.push_str(&format!("\n- {}: {}", m.path, m.name));
248                    }
249                    result
250                } else {
251                    "Error: Index not loaded".to_string()
252                }
253            }
254            "pm_fetch_symbol" => {
255                let args: PmFetchSymbolTool = serde_json::from_value(arguments)
256                    .map_err(|e| CallToolError(Box::new(e)))?;
257                
258                if let Some(ref engine) = *self.engine.read().unwrap() {
259                    if let Some(node) = engine.find_symbol_in_path(&args.path, &args.symbol) {
260                        if let Ok(content) = std::fs::read_to_string(&node.path) {
261                            let bytes = content.as_bytes();
262                            if node.start_byte < bytes.len() && node.end_byte <= bytes.len() {
263                                String::from_utf8_lossy(&bytes[node.start_byte..node.end_byte]).to_string()
264                            } else {
265                                "Error: Byte range out of bounds".to_string()
266                            }
267                        } else {
268                            "Error: Could not read file".to_string()
269                        }
270                    } else {
271                        "Error: Symbol not found".to_string()
272                    }
273                } else {
274                    "Error: Index not loaded".to_string()
275                }
276            }
277            "pm_init" => {
278                let mut orch = Orchestrator::new();
279                if orch.build_index(Path::new(".")).is_ok() && orch.save_index_versioned(Path::new(".project-map")).is_ok() {
280                    let new_engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
281                    *self.engine.write().unwrap() = new_engine;
282                    "Index refreshed successfully.".to_string()
283                } else {
284                    "Failed to refresh index.".to_string()
285                }
286            }
287
288            _ => return Err(CallToolError(Box::new(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Unknown tool")))),
289        };
290
291        Ok(CallToolResult::text_content(vec![text.into()]))
292    }
293}