// project_map_cli_rust/mcp/server.rs

use std::path::Path;
use std::sync::Arc;
use std::sync::{PoisonError, RwLockReadGuard};

use async_trait::async_trait;
use rust_mcp_sdk::macros::{mcp_tool, JsonSchema};
use rust_mcp_sdk::mcp_server::{server_runtime, McpServerOptions, ServerHandler};
use rust_mcp_sdk::schema::schema_utils::CallToolError;
use rust_mcp_sdk::schema::{
    CallToolRequestParams, CallToolResult, Implementation, InitializeResult, ListToolsResult,
    PaginatedRequestParams, ProtocolVersion, RpcError, ServerCapabilities, ServerCapabilitiesTools,
};
use rust_mcp_sdk::{McpServer as SdkMcpServer, StdioTransport, ToMcpServerHandler, TransportOptions};
use serde::{Deserialize, Serialize};
use tracing_subscriber::fmt;

use crate::core::orchestrator::Orchestrator;
use crate::core::query_engine::QueryEngine;
use crate::core::toon::ToonFormatter;
use crate::error::Result;

// --- Tool Definitions ---

// Arguments for the `pm_status` tool. It takes no parameters; the handler only
// reports whether an index is currently loaded and which index file it uses.
#[mcp_tool(
    name = "pm_status",
    description = "Returns current workspace context and available commands."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmStatusTool {}

// Arguments for the `pm_query` tool. Both fields are optional: the handler uses
// `query` (symbol search) when it is present and otherwise falls back to `path`
// (file outline); supplying neither yields an error message.
#[mcp_tool(
    name = "pm_query",
    description = "Search for symbols or get file context."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmQueryTool {
    /// Search query for symbols
    pub query: Option<String>,
    /// File path to get outline
    pub path: Option<String>,
}

// Arguments for the `pm_check_blast_radius` tool. Both fields are required and
// are passed to `QueryEngine::check_blast_radius` by the handler.
#[mcp_tool(
    name = "pm_check_blast_radius",
    description = "Identifies all components and files that depend on or import a specific symbol."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmCheckBlastRadiusTool {
    /// File path where the symbol is defined
    pub path: String,
    /// Symbol name to check
    pub symbol: String,
}

// Arguments for the `pm_plan` tool. The symbol name is passed to
// `QueryEngine::analyze_impact` by the handler; no file path is involved.
#[mcp_tool(
    name = "pm_plan",
    description = "Analyze the architectural impact (fan-out) of a symbol before starting a refactor."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmPlanTool {
    /// Symbol name to analyze
    pub symbol: String,
}

// Arguments for the `pm_semantic_search` tool.
// NOTE(review): the handler currently answers this with the same
// `QueryEngine::find_symbols` lookup that `pm_query` uses, so no natural-language
// matching happens at this layer despite the description — confirm this is intended.
#[mcp_tool(
    name = "pm_semantic_search",
    description = "Search for logic using natural language keywords (e.g., 'auth', 'database')."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmSemanticSearchTool {
    /// Natural language query
    pub query: String,
}

// Arguments for the `pm_fetch_symbol` tool. The handler resolves the symbol with
// `QueryEngine::find_symbol_in_path` and returns the source text covered by the
// byte range recorded for it in the index.
#[mcp_tool(
    name = "pm_fetch_symbol",
    description = "Extract raw source code for a specific class or function."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmFetchSymbolTool {
    /// File path
    pub path: String,
    /// Symbol name
    pub symbol: String,
}

// Arguments for the `pm_init` tool. It takes no parameters; the handler rebuilds
// the index rooted at the current working directory, saves it under
// `.project-map`, and reloads it into the shared engine.
#[mcp_tool(
    name = "pm_init",
    description = "Refresh the map index after significant code changes to maintain discovery accuracy."
)]
#[derive(JsonSchema, Deserialize, Serialize)]
pub struct PmInitTool {}

// --- Server Implementation ---

/// MCP server entry point for the project map.
///
/// Holds the query engine behind a shared lock; the slot is `None` when no index
/// could be loaded.
pub struct McpServer {
    // Shared (via `Arc::clone`) with the handler built by `clone_for_handler`, so an
    // index reload performed by the handler is visible through this value too.
    engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
}

100impl McpServer {
101    pub fn new() -> Self {
102        let engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
103        Self {
104            engine: Arc::new(std::sync::RwLock::new(engine)),
105        }
106    }
107
108    pub async fn run(&self) -> Result<()> {
109        let _ = fmt()
110            .with_writer(std::io::stderr)
111            .try_init();
112
113        let server_info = InitializeResult {
114            protocol_version: ProtocolVersion::V2025_11_25.into(),
115            capabilities: ServerCapabilities {
116                tools: Some(ServerCapabilitiesTools { list_changed: None }),
117                ..Default::default()
118            },
119            server_info: Implementation {
120                name: "project-map-cli-rust".to_string(),
121                version: env!("CARGO_PKG_VERSION").to_string(),
122                title: Some("Project Map CLI".to_string()),
123                description: None,
124                icons: vec![],
125                website_url: None,
126            },
127            instructions: None,
128            meta: None,
129        };
130
131        let transport = StdioTransport::new(TransportOptions::default())
132            .map_err(|e| crate::error::AppError::Generic(format!("Transport error: {}", e)))?;
133        let handler = self.clone_for_handler();
134        
135        let options = McpServerOptions {
136            server_details: server_info,
137            transport,
138            handler: handler.to_mcp_server_handler(),
139            task_store: None,
140            client_task_store: None,
141            message_observer: None,
142        };
143
144        let server = server_runtime::create_server(options);
145        server.start().await.map_err(|e| crate::error::AppError::Generic(format!("Server error: {}", e)))?;
146
147        Ok(())
148    }
149
150    fn clone_for_handler(&self) -> McpServerHandler {
151        McpServerHandler {
152            engine: Arc::clone(&self.engine),
153        }
154    }
155}
156
/// Request handler registered with the MCP runtime.
///
/// Created by `McpServer::clone_for_handler`; it shares the same engine slot as the
/// server that created it.
pub struct McpServerHandler {
    // `None` while no index is loaded; written by the `pm_init` tool.
    engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
}

161#[async_trait]
162impl ServerHandler for McpServerHandler {
163    async fn handle_list_tools_request(
164        &self,
165        _params: Option<PaginatedRequestParams>,
166        _runtime: Arc<dyn SdkMcpServer>,
167    ) -> std::result::Result<ListToolsResult, RpcError> {
168        Ok(ListToolsResult {
169            tools: vec![
170                PmStatusTool::tool(),
171                PmQueryTool::tool(),
172                PmCheckBlastRadiusTool::tool(),
173                PmPlanTool::tool(),
174                PmSemanticSearchTool::tool(),
175                PmFetchSymbolTool::tool(),
176                PmInitTool::tool(),
177            ],
178            next_cursor: None,
179            meta: None,
180        })
181    }
182
183    async fn handle_call_tool_request(
184        &self,
185        params: CallToolRequestParams,
186        _runtime: Arc<dyn SdkMcpServer>,
187    ) -> std::result::Result<CallToolResult, CallToolError> {
188        let arguments = serde_json::Value::Object(params.arguments.unwrap_or_default());
189        let text = match params.name.as_str() {
190            "pm_status" => {
191                let is_ready = self.engine.read().unwrap().is_some();
192                ToonFormatter::format_status(is_ready, Some(".project-map/latest/.project-map.json"))
193            }
194            "pm_query" => {
195                let args: PmQueryTool = serde_json::from_value(arguments)
196                    .map_err(|e| CallToolError(Box::new(e)))?;
197                
198                if let Some(ref engine) = *self.engine.read().unwrap() {
199                    if let Some(q) = args.query {
200                        let matches = engine.find_symbols(&q);
201                        ToonFormatter::format_symbols(&q, &matches)
202                    } else if let Some(p) = args.path {
203                        let symbols = engine.get_file_outline(&p);
204                        ToonFormatter::format_file_context(&p, &symbols)
205                    } else {
206                        "Error: Provide query or path".to_string()
207                    }
208                } else {
209                    "Error: Index not loaded".to_string()
210                }
211            }
212            "pm_check_blast_radius" => {
213                let args: PmCheckBlastRadiusTool = serde_json::from_value(arguments)
214                    .map_err(|e| CallToolError(Box::new(e)))?;
215                
216                if let Some(ref engine) = *self.engine.read().unwrap() {
217                    let results = engine.check_blast_radius(&args.path, &args.symbol);
218                    ToonFormatter::format_blast_radius(&args.path, &args.symbol, &results)
219                } else {
220                    "Error: Index not loaded".to_string()
221                }
222            }
223            "pm_plan" => {
224                let args: PmPlanTool = serde_json::from_value(arguments)
225                    .map_err(|e| CallToolError(Box::new(e)))?;
226                
227                if let Some(ref engine) = *self.engine.read().unwrap() {
228                    let impact = engine.analyze_impact(&args.symbol);
229                    ToonFormatter::format_impact_analysis(&args.symbol, &impact)
230                } else {
231                    "Error: Index not loaded".to_string()
232                }
233            }
234            "pm_semantic_search" => {
235                let args: PmSemanticSearchTool = serde_json::from_value(arguments)
236                    .map_err(|e| CallToolError(Box::new(e)))?;
237                
238                if let Some(ref engine) = *self.engine.read().unwrap() {
239                    let matches = engine.find_symbols(&args.query);
240                    ToonFormatter::format_symbols(&args.query, &matches)
241                } else {
242                    "Error: Index not loaded".to_string()
243                }
244            }
245            "pm_fetch_symbol" => {
246                let args: PmFetchSymbolTool = serde_json::from_value(arguments)
247                    .map_err(|e| CallToolError(Box::new(e)))?;
248                
249                if let Some(ref engine) = *self.engine.read().unwrap() {
250                    if let Some(node) = engine.find_symbol_in_path(&args.path, &args.symbol) {
251                        if let Ok(content) = std::fs::read_to_string(&node.path) {
252                            let bytes = content.as_bytes();
253                            if node.start_byte < bytes.len() && node.end_byte <= bytes.len() {
254                                let content_str = String::from_utf8_lossy(&bytes[node.start_byte..node.end_byte]);
255                                ToonFormatter::format_fetch_result(&args.path, &args.symbol, Some(&content_str))
256                            } else {
257                                "Error: Byte range out of bounds".to_string()
258                            }
259                        } else {
260                            "Error: Could not read file".to_string()
261                        }
262                    } else {
263                        ToonFormatter::format_fetch_result(&args.path, &args.symbol, None)
264                    }
265                } else {
266                    "Error: Index not loaded".to_string()
267                }
268            }
269            "pm_init" => {
270                let mut orch = Orchestrator::new();
271                if orch.build_index(Path::new(".")).is_ok() && orch.save_index_versioned(Path::new(".project-map")).is_ok() {
272                    let new_engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
273                    *self.engine.write().unwrap() = new_engine;
274                    "Index refreshed successfully.".to_string()
275                } else {
276                    "Failed to refresh index.".to_string()
277                }
278            }
279
280            _ => return Err(CallToolError(Box::new(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Unknown tool")))),
281        };
282
283        Ok(CallToolResult::text_content(vec![text.into()]))
284    }
285}