Skip to main content

project_map_cli_rust/mcp/
server.rs

1use std::path::Path;
2use std::sync::Arc;
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5
6use rust_mcp_sdk::{McpServer as SdkMcpServer, TransportOptions, StdioTransport, ToMcpServerHandler};
7use rust_mcp_sdk::mcp_server::{server_runtime, McpServerOptions, ServerHandler};
8use rust_mcp_sdk::schema::{
9    CallToolRequestParams, CallToolResult, InitializeResult,
10    PaginatedRequestParams, ListToolsResult, ServerCapabilities, ServerCapabilitiesTools,
11    Implementation, ProtocolVersion, RpcError,
12};
13use rust_mcp_sdk::schema::schema_utils::CallToolError;
14use rust_mcp_sdk::macros::{mcp_tool, JsonSchema};
15use tracing_subscriber::fmt;
16
17use crate::error::Result;
18use crate::core::query_engine::QueryEngine;
19use crate::core::orchestrator::Orchestrator;
20use crate::core::toon::ToonFormatter;
21use crate::core::utils::render_tree;
22
23// --- Tool Definitions ---
24
25#[mcp_tool(
26    name = "pm_status",
27    description = "Returns current workspace context and available commands."
28)]
29#[derive(JsonSchema, Deserialize, Serialize)]
30pub struct PmStatusTool {}
31
32#[mcp_tool(
33    name = "pm_query",
34    description = "Search for symbols or get file context."
35)]
36#[derive(JsonSchema, Deserialize, Serialize)]
37pub struct PmQueryTool {
38    /// Search query for symbols
39    pub query: Option<String>,
40    /// File path to get outline
41    pub path: Option<String>,
42    /// Search for filenames across the tree
43    pub filename: Option<String>,
44}
45
46#[mcp_tool(
47    name = "pm_check_blast_radius",
48    description = "Identifies all components and files that depend on or import a specific symbol."
49)]
50#[derive(JsonSchema, Deserialize, Serialize)]
51pub struct PmCheckBlastRadiusTool {
52    /// File path where the symbol is defined
53    pub path: String,
54    /// Symbol name to check
55    pub symbol: String,
56}
57
58#[mcp_tool(
59    name = "pm_plan",
60    description = "Analyze the architectural impact (fan-out) of a symbol before starting a refactor."
61)]
62#[derive(JsonSchema, Deserialize, Serialize)]
63pub struct PmPlanTool {
64    /// Symbol name to analyze
65    pub symbol: String,
66}
67
68#[mcp_tool(
69    name = "pm_semantic_search",
70    description = "Search for logic using natural language keywords (e.g., 'auth', 'database')."
71)]
72#[derive(JsonSchema, Deserialize, Serialize)]
73pub struct PmSemanticSearchTool {
74    /// Natural language query
75    pub query: String,
76}
77
78#[mcp_tool(
79    name = "pm_fetch_symbol",
80    description = "Extract raw source code for a specific class or function."
81)]
82#[derive(JsonSchema, Deserialize, Serialize)]
83pub struct PmFetchSymbolTool {
84    /// File path
85    pub path: String,
86    /// Symbol name
87    pub symbol: String,
88}
89
90#[mcp_tool(
91    name = "pm_init",
92    description = "Refresh the map index after significant code changes to maintain discovery accuracy."
93)]
94#[derive(JsonSchema, Deserialize, Serialize)]
95pub struct PmInitTool {}
96
97// --- Server Implementation ---
98
99pub struct McpServer {
100    engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
101}
102
103impl McpServer {
104    pub fn new() -> Self {
105        let engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
106        Self {
107            engine: Arc::new(std::sync::RwLock::new(engine)),
108        }
109    }
110
111    pub async fn run(&self) -> Result<()> {
112        let _ = fmt()
113            .with_writer(std::io::stderr)
114            .try_init();
115
116        let server_info = InitializeResult {
117            protocol_version: ProtocolVersion::V2025_11_25.into(),
118            capabilities: ServerCapabilities {
119                tools: Some(ServerCapabilitiesTools { list_changed: None }),
120                ..Default::default()
121            },
122            server_info: Implementation {
123                name: "project-map-cli-rust".to_string(),
124                version: env!("CARGO_PKG_VERSION").to_string(),
125                title: Some("Project Map CLI".to_string()),
126                description: None,
127                icons: vec![],
128                website_url: None,
129            },
130            instructions: None,
131            meta: None,
132        };
133
134        let transport = StdioTransport::new(TransportOptions::default())
135            .map_err(|e| crate::error::AppError::Generic(format!("Transport error: {}", e)))?;
136        let handler = self.clone_for_handler();
137        
138        let options = McpServerOptions {
139            server_details: server_info,
140            transport,
141            handler: handler.to_mcp_server_handler(),
142            task_store: None,
143            client_task_store: None,
144            message_observer: None,
145        };
146
147        let server = server_runtime::create_server(options);
148        server.start().await.map_err(|e| crate::error::AppError::Generic(format!("Server error: {}", e)))?;
149
150        Ok(())
151    }
152
153    fn clone_for_handler(&self) -> McpServerHandler {
154        McpServerHandler {
155            engine: Arc::clone(&self.engine),
156        }
157    }
158}
159
160pub struct McpServerHandler {
161    engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
162}
163
164#[async_trait]
165impl ServerHandler for McpServerHandler {
166    async fn handle_list_tools_request(
167        &self,
168        _params: Option<PaginatedRequestParams>,
169        _runtime: Arc<dyn SdkMcpServer>,
170    ) -> std::result::Result<ListToolsResult, RpcError> {
171        Ok(ListToolsResult {
172            tools: vec![
173                PmStatusTool::tool(),
174                PmQueryTool::tool(),
175                PmCheckBlastRadiusTool::tool(),
176                PmPlanTool::tool(),
177                PmSemanticSearchTool::tool(),
178                PmFetchSymbolTool::tool(),
179                PmInitTool::tool(),
180            ],
181            next_cursor: None,
182            meta: None,
183        })
184    }
185
186    async fn handle_call_tool_request(
187        &self,
188        params: CallToolRequestParams,
189        _runtime: Arc<dyn SdkMcpServer>,
190    ) -> std::result::Result<CallToolResult, CallToolError> {
191        let arguments = serde_json::Value::Object(params.arguments.unwrap_or_default());
192        let text = match params.name.as_str() {
193            "pm_status" => {
194                let (is_ready, tree, active_features) = if let Some(ref engine) = *self.engine.read().unwrap() {
195                    let paths = engine.get_all_file_paths();
196                    let tree = render_tree(&paths, 3);
197                    
198                    let mut active = Vec::new();
199                    if let Ok(entries) = std::fs::read_dir("projects/active") {
200                        for entry in entries.flatten() {
201                            if entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false) {
202                                active.push(entry.file_name().to_string_lossy().into_owned());
203                            }
204                        }
205                    }
206                    (true, Some(tree), active)
207                } else {
208                    (false, None, Vec::new())
209                };
210
211                ToonFormatter::format_status(is_ready, Some(".project-map/latest/.project-map.json"), tree.as_deref(), &active_features)
212            }
213            "pm_query" => {
214                let args: PmQueryTool = serde_json::from_value(arguments)
215                    .map_err(|e| CallToolError(Box::new(e)))?;
216                
217                if let Some(ref engine) = *self.engine.read().unwrap() {
218                    if let Some(q) = args.query {
219                        let matches = engine.find_symbols(&q);
220                        ToonFormatter::format_symbols(&q, &matches)
221                    } else if let Some(p) = args.path {
222                        let symbols = engine.get_file_outline(&p);
223                        ToonFormatter::format_file_context(&p, &symbols)
224                    } else if let Some(f) = args.filename {
225                        let matches = engine.find_files(&f);
226                        ToonFormatter::format_file_matches(&f, &matches)
227                    } else {
228                        "Error: Provide query, path, or filename".to_string()
229                    }
230                } else {
231                    "Error: Index not loaded".to_string()
232                }
233            }
234            "pm_check_blast_radius" => {
235                let args: PmCheckBlastRadiusTool = serde_json::from_value(arguments)
236                    .map_err(|e| CallToolError(Box::new(e)))?;
237                
238                if let Some(ref engine) = *self.engine.read().unwrap() {
239                    let results = engine.check_blast_radius(&args.path, &args.symbol);
240                    ToonFormatter::format_blast_radius(&args.path, &args.symbol, &results)
241                } else {
242                    "Error: Index not loaded".to_string()
243                }
244            }
245            "pm_plan" => {
246                let args: PmPlanTool = serde_json::from_value(arguments)
247                    .map_err(|e| CallToolError(Box::new(e)))?;
248                
249                if let Some(ref engine) = *self.engine.read().unwrap() {
250                    let impact = engine.analyze_impact(&args.symbol);
251                    ToonFormatter::format_impact_analysis(&args.symbol, &impact)
252                } else {
253                    "Error: Index not loaded".to_string()
254                }
255            }
256            "pm_semantic_search" => {
257                let args: PmSemanticSearchTool = serde_json::from_value(arguments)
258                    .map_err(|e| CallToolError(Box::new(e)))?;
259                
260                if let Some(ref engine) = *self.engine.read().unwrap() {
261                    let matches = engine.find_symbols(&args.query);
262                    ToonFormatter::format_symbols(&args.query, &matches)
263                } else {
264                    "Error: Index not loaded".to_string()
265                }
266            }
267            "pm_fetch_symbol" => {
268                let args: PmFetchSymbolTool = serde_json::from_value(arguments)
269                    .map_err(|e| CallToolError(Box::new(e)))?;
270                
271                if let Some(ref engine) = *self.engine.read().unwrap() {
272                    if let Some(node) = engine.find_symbol_in_path(&args.path, &args.symbol) {
273                        if let Ok(content) = std::fs::read_to_string(&node.path) {
274                            let bytes = content.as_bytes();
275                            if node.start_byte < bytes.len() && node.end_byte <= bytes.len() {
276                                let content_str = String::from_utf8_lossy(&bytes[node.start_byte..node.end_byte]);
277                                ToonFormatter::format_fetch_result(&args.path, &args.symbol, Some(&content_str))
278                            } else {
279                                "Error: Byte range out of bounds".to_string()
280                            }
281                        } else {
282                            "Error: Could not read file".to_string()
283                        }
284                    } else {
285                        ToonFormatter::format_fetch_result(&args.path, &args.symbol, None)
286                    }
287                } else {
288                    "Error: Index not loaded".to_string()
289                }
290            }
291            "pm_init" => {
292                let mut orch = Orchestrator::new();
293                let _ = orch.scaffold_if_empty(Path::new("."));
294                if orch.build_index(Path::new(".")).is_ok() && orch.save_index_versioned(Path::new(".project-map")).is_ok() {
295                    let new_engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
296                    *self.engine.write().unwrap() = new_engine;
297                    "Index refreshed successfully.".to_string()
298                } else {
299                    "Failed to refresh index.".to_string()
300                }
301            }
302
303            _ => return Err(CallToolError(Box::new(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Unknown tool")))),
304        };
305
306        Ok(CallToolResult::text_content(vec![text.into()]))
307    }
308}