Skip to main content

project_map_cli_rust/mcp/
server.rs

1use std::path::Path;
2use std::sync::Arc;
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5
6use rust_mcp_sdk::{McpServer as SdkMcpServer, TransportOptions, StdioTransport};
7use rust_mcp_sdk::mcp_server::{server_runtime, ServerHandler};
8use rust_mcp_sdk::schema::{
9    CallToolRequest, CallToolResult, InitializeResult,
10    ListToolsRequest, ListToolsResult, ServerCapabilities, ServerCapabilitiesTools,
11    Implementation, ProtocolVersion, RpcError,
12};
13use rust_mcp_sdk::schema::schema_utils::CallToolError;
14use rust_mcp_sdk::macros::{mcp_tool, JsonSchema};
15
16use crate::error::Result;
17use crate::core::query_engine::QueryEngine;
18use crate::core::orchestrator::Orchestrator;
19
20// --- Tool Definitions ---
21
22#[mcp_tool(
23    name = "pm_status",
24    description = "Returns current workspace context and available commands."
25)]
26#[derive(JsonSchema, Deserialize, Serialize)]
27pub struct PmStatusTool {}
28
29#[mcp_tool(
30    name = "pm_query",
31    description = "Search for symbols or get file context."
32)]
33#[derive(JsonSchema, Deserialize, Serialize)]
34pub struct PmQueryTool {
35    /// Search query for symbols
36    pub query: Option<String>,
37    /// File path to get outline
38    pub path: Option<String>,
39}
40
41#[mcp_tool(
42    name = "pm_check_blast_radius",
43    description = "Identifies all components and files that depend on or import a specific symbol."
44)]
45#[derive(JsonSchema, Deserialize, Serialize)]
46pub struct PmCheckBlastRadiusTool {
47    /// File path where the symbol is defined
48    pub path: String,
49    /// Symbol name to check
50    pub symbol: String,
51}
52
53#[mcp_tool(
54    name = "pm_plan",
55    description = "Analyze the architectural impact (fan-out) of a symbol before starting a refactor."
56)]
57#[derive(JsonSchema, Deserialize, Serialize)]
58pub struct PmPlanTool {
59    /// Symbol name to analyze
60    pub symbol: String,
61}
62
63#[mcp_tool(
64    name = "pm_semantic_search",
65    description = "Search for logic using natural language keywords (e.g., 'auth', 'database')."
66)]
67#[derive(JsonSchema, Deserialize, Serialize)]
68pub struct PmSemanticSearchTool {
69    /// Natural language query
70    pub query: String,
71}
72
73#[mcp_tool(
74    name = "pm_fetch_symbol",
75    description = "Extract raw source code for a specific class or function."
76)]
77#[derive(JsonSchema, Deserialize, Serialize)]
78pub struct PmFetchSymbolTool {
79    /// File path
80    pub path: String,
81    /// Symbol name
82    pub symbol: String,
83}
84
85#[mcp_tool(
86    name = "pm_init",
87    description = "Refresh the map index after significant code changes to maintain discovery accuracy."
88)]
89#[derive(JsonSchema, Deserialize, Serialize)]
90pub struct PmInitTool {}
91
92// --- Server Implementation ---
93
94pub struct McpServer {
95    engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
96}
97
98impl McpServer {
99    pub fn new() -> Self {
100        let engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
101        Self {
102            engine: Arc::new(std::sync::RwLock::new(engine)),
103        }
104    }
105
106    pub async fn run(&self) -> Result<()> {
107        let server_info = InitializeResult {
108            protocol_version: ProtocolVersion::V2024_11_05.to_string(),
109            capabilities: ServerCapabilities {
110                tools: Some(ServerCapabilitiesTools { list_changed: None }),
111                ..Default::default()
112            },
113            server_info: Implementation {
114                name: "project-map-cli-rust".to_string(),
115                version: "0.1.2".to_string(),
116                title: Some("Project Map CLI".to_string()),
117            },
118            instructions: None,
119            meta: None,
120        };
121
122        let transport = StdioTransport::new(TransportOptions::default())
123            .map_err(|e| crate::error::AppError::Generic(format!("Transport error: {}", e)))?;
124        let handler = self.clone_for_handler();
125        
126        let server = server_runtime::create_server(server_info, transport, handler);
127        server.start().await.map_err(|e| crate::error::AppError::Generic(format!("Server error: {}", e)))?;
128
129        Ok(())
130    }
131
132    fn clone_for_handler(&self) -> McpServerHandler {
133        McpServerHandler {
134            engine: Arc::clone(&self.engine),
135        }
136    }
137}
138
139pub struct McpServerHandler {
140    engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
141}
142
143#[async_trait]
144impl ServerHandler for McpServerHandler {
145    async fn handle_list_tools_request(
146        &self,
147        _request: ListToolsRequest,
148        _runtime: &dyn SdkMcpServer,
149    ) -> std::result::Result<ListToolsResult, RpcError> {
150        Ok(ListToolsResult {
151            tools: vec![
152                PmStatusTool::tool(),
153                PmQueryTool::tool(),
154                PmCheckBlastRadiusTool::tool(),
155                PmPlanTool::tool(),
156                PmSemanticSearchTool::tool(),
157                PmFetchSymbolTool::tool(),
158                PmInitTool::tool(),
159            ],
160            next_cursor: None,
161            meta: None,
162        })
163    }
164
165    async fn handle_call_tool_request(
166        &self,
167        request: CallToolRequest,
168        _runtime: &dyn SdkMcpServer,
169    ) -> std::result::Result<CallToolResult, CallToolError> {
170        let arguments = serde_json::Value::Object(request.params.arguments.unwrap_or_default());
171        let text = match request.params.name.as_str() {
172            "pm_status" => {
173                if self.engine.read().unwrap().is_some() {
174                    "Status: System healthy. Index is present.".to_string()
175                } else {
176                    "Status: Index missing. Run project-map build.".to_string()
177                }
178            }
179            "pm_query" => {
180                let args: PmQueryTool = serde_json::from_value(arguments)
181                    .map_err(|e| CallToolError(Box::new(e)))?;
182                
183                if let Some(ref engine) = *self.engine.read().unwrap() {
184                    if let Some(q) = args.query {
185                        let matches = engine.find_symbols(&q);
186                        format!("Matches: {}", matches.len())
187                    } else if let Some(p) = args.path {
188                        let symbols = engine.get_file_outline(&p);
189                        format!("Symbols in {}: {}", p, symbols.len())
190                    } else {
191                        "Error: Provide query or path".to_string()
192                    }
193                } else {
194                    "Error: Index not loaded".to_string()
195                }
196            }
197            "pm_check_blast_radius" => {
198                let args: PmCheckBlastRadiusTool = serde_json::from_value(arguments)
199                    .map_err(|e| CallToolError(Box::new(e)))?;
200                
201                if let Some(ref engine) = *self.engine.read().unwrap() {
202                    let results = engine.check_blast_radius(&args.path, &args.symbol);
203
204                    if results.is_empty() {
205                        "No dependent components found.".to_string()
206                    } else {
207                        let mut unique_files = std::collections::HashSet::new();
208                        for r in &results { unique_files.insert(&r.path); }
209                        format!("Blast Radius for {}:\n- Total Impacted Nodes: {}\n- Unique Files: {}\n(Top 5: {})", 
210                            args.symbol, results.len(), unique_files.len(),
211                            results.iter().take(5).map(|r| r.name.as_str()).collect::<Vec<_>>().join(", "))
212                    }
213                } else {
214                    "Error: Index not loaded".to_string()
215                }
216            }
217            "pm_plan" => {
218                let args: PmPlanTool = serde_json::from_value(arguments)
219                    .map_err(|e| CallToolError(Box::new(e)))?;
220                
221                if let Some(ref engine) = *self.engine.read().unwrap() {
222                    let impact = engine.analyze_impact(&args.symbol);
223                    let blast = engine.check_blast_radius("", &args.symbol);
224
225                    let mut unique_blast = std::collections::HashSet::new();
226                    for r in &blast { unique_blast.insert(&r.path); }
227
228                    format!("Architectural Plan for {}:\n- Fan-out (Dependencies): {} nodes\n- Fan-in (Dependents): {} nodes across {} files.", 
229                        args.symbol, impact.len(), blast.len(), unique_blast.len())
230                } else {
231                    "Error: Index not loaded".to_string()
232                }
233            }
234            "pm_semantic_search" => {
235                let args: PmSemanticSearchTool = serde_json::from_value(arguments)
236                    .map_err(|e| CallToolError(Box::new(e)))?;
237                
238                if let Some(ref engine) = *self.engine.read().unwrap() {
239                    let matches = engine.find_symbols(&args.query);
240                    let mut result = format!("Semantic Search Results ({}):", matches.len());
241                    for m in matches.iter().take(15) {
242                        result.push_str(&format!("\n- {}: {}", m.path, m.name));
243                    }
244                    result
245                } else {
246                    "Error: Index not loaded".to_string()
247                }
248            }
249            "pm_fetch_symbol" => {
250                let args: PmFetchSymbolTool = serde_json::from_value(arguments)
251                    .map_err(|e| CallToolError(Box::new(e)))?;
252                
253                if let Some(ref engine) = *self.engine.read().unwrap() {
254                    if let Some(node) = engine.find_symbol_in_path(&args.path, &args.symbol) {
255                        if let Ok(content) = std::fs::read_to_string(&node.path) {
256                            let bytes = content.as_bytes();
257                            if node.start_byte < bytes.len() && node.end_byte <= bytes.len() {
258                                String::from_utf8_lossy(&bytes[node.start_byte..node.end_byte]).to_string()
259                            } else {
260                                "Error: Byte range out of bounds".to_string()
261                            }
262                        } else {
263                            "Error: Could not read file".to_string()
264                        }
265                    } else {
266                        "Error: Symbol not found".to_string()
267                    }
268                } else {
269                    "Error: Index not loaded".to_string()
270                }
271            }
272            "pm_init" => {
273                let mut orch = Orchestrator::new();
274                if orch.build_index(Path::new(".")).is_ok() && orch.save_index_versioned(Path::new(".project-map")).is_ok() {
275                    let new_engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
276                    *self.engine.write().unwrap() = new_engine;
277                    "Index refreshed successfully.".to_string()
278                } else {
279                    "Failed to refresh index.".to_string()
280                }
281            }
282
283            _ => return Err(CallToolError(Box::new(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Unknown tool")))),
284        };
285
286        Ok(CallToolResult::text_content(vec![text.into()]))
287    }
288}