// project_map_cli_rust/mcp/server.rs
use std::path::Path;
use std::sync::{Arc, PoisonError};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use rust_mcp_sdk::{McpServer as SdkMcpServer, TransportOptions, StdioTransport, ToMcpServerHandler};
use rust_mcp_sdk::mcp_server::{server_runtime, McpServerOptions, ServerHandler};
use rust_mcp_sdk::schema::{
    CallToolRequestParams, CallToolResult, InitializeResult,
    PaginatedRequestParams, ListToolsResult, ServerCapabilities, ServerCapabilitiesTools,
    Implementation, ProtocolVersion, RpcError,
};
use rust_mcp_sdk::schema::schema_utils::CallToolError;
use rust_mcp_sdk::macros::{mcp_tool, JsonSchema};
use tracing_subscriber::fmt;

use crate::error::Result;
use crate::core::query_engine::QueryEngine;
use crate::core::orchestrator::Orchestrator;
use crate::core::toon::ToonFormatter;
use crate::core::utils::render_tree;

23#[mcp_tool(
26 name = "pm_status",
27 description = "Returns current workspace context and available commands."
28)]
29#[derive(JsonSchema, Deserialize, Serialize)]
30pub struct PmStatusTool {}
31
32#[mcp_tool(
33 name = "pm_query",
34 description = "Search for symbols or get file context."
35)]
36#[derive(JsonSchema, Deserialize, Serialize)]
37pub struct PmQueryTool {
38 pub query: Option<String>,
40 pub path: Option<String>,
42 pub filename: Option<String>,
44}
45
46#[mcp_tool(
47 name = "pm_check_blast_radius",
48 description = "Identifies all components and files that depend on or import a specific symbol."
49)]
50#[derive(JsonSchema, Deserialize, Serialize)]
51pub struct PmCheckBlastRadiusTool {
52 pub path: String,
54 pub symbol: String,
56}
57
58#[mcp_tool(
59 name = "pm_plan",
60 description = "Analyze the architectural impact (fan-out) of a symbol before starting a refactor."
61)]
62#[derive(JsonSchema, Deserialize, Serialize)]
63pub struct PmPlanTool {
64 pub symbol: String,
66}
67
68#[mcp_tool(
69 name = "pm_semantic_search",
70 description = "Search for logic using natural language keywords (e.g., 'auth', 'database')."
71)]
72#[derive(JsonSchema, Deserialize, Serialize)]
73pub struct PmSemanticSearchTool {
74 pub query: String,
76}
77
78#[mcp_tool(
79 name = "pm_fetch_symbol",
80 description = "Extract raw source code for a specific class or function."
81)]
82#[derive(JsonSchema, Deserialize, Serialize)]
83pub struct PmFetchSymbolTool {
84 pub path: String,
86 pub symbol: String,
88}
89
90#[mcp_tool(
91 name = "pm_init",
92 description = "Refresh the map index after significant code changes to maintain discovery accuracy."
93)]
94#[derive(JsonSchema, Deserialize, Serialize)]
95pub struct PmInitTool {}
96
97pub struct McpServer {
100 engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
101}
102
103impl McpServer {
104 pub fn new() -> Self {
105 let engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
106 Self {
107 engine: Arc::new(std::sync::RwLock::new(engine)),
108 }
109 }
110
111 pub async fn run(&self) -> Result<()> {
112 let _ = fmt()
113 .with_writer(std::io::stderr)
114 .try_init();
115
116 let server_info = InitializeResult {
117 protocol_version: ProtocolVersion::V2025_11_25.into(),
118 capabilities: ServerCapabilities {
119 tools: Some(ServerCapabilitiesTools { list_changed: None }),
120 ..Default::default()
121 },
122 server_info: Implementation {
123 name: "project-map-cli-rust".to_string(),
124 version: env!("CARGO_PKG_VERSION").to_string(),
125 title: Some("Project Map CLI".to_string()),
126 description: None,
127 icons: vec![],
128 website_url: None,
129 },
130 instructions: None,
131 meta: None,
132 };
133
134 let transport = StdioTransport::new(TransportOptions::default())
135 .map_err(|e| crate::error::AppError::Generic(format!("Transport error: {}", e)))?;
136 let handler = self.clone_for_handler();
137
138 let options = McpServerOptions {
139 server_details: server_info,
140 transport,
141 handler: handler.to_mcp_server_handler(),
142 task_store: None,
143 client_task_store: None,
144 message_observer: None,
145 };
146
147 let server = server_runtime::create_server(options);
148 server.start().await.map_err(|e| crate::error::AppError::Generic(format!("Server error: {}", e)))?;
149
150 Ok(())
151 }
152
153 fn clone_for_handler(&self) -> McpServerHandler {
154 McpServerHandler {
155 engine: Arc::clone(&self.engine),
156 }
157 }
158}
159
160pub struct McpServerHandler {
161 engine: Arc<std::sync::RwLock<Option<QueryEngine>>>,
162}
163
164#[async_trait]
165impl ServerHandler for McpServerHandler {
166 async fn handle_list_tools_request(
167 &self,
168 _params: Option<PaginatedRequestParams>,
169 _runtime: Arc<dyn SdkMcpServer>,
170 ) -> std::result::Result<ListToolsResult, RpcError> {
171 Ok(ListToolsResult {
172 tools: vec![
173 PmStatusTool::tool(),
174 PmQueryTool::tool(),
175 PmCheckBlastRadiusTool::tool(),
176 PmPlanTool::tool(),
177 PmSemanticSearchTool::tool(),
178 PmFetchSymbolTool::tool(),
179 PmInitTool::tool(),
180 ],
181 next_cursor: None,
182 meta: None,
183 })
184 }
185
186 async fn handle_call_tool_request(
187 &self,
188 params: CallToolRequestParams,
189 _runtime: Arc<dyn SdkMcpServer>,
190 ) -> std::result::Result<CallToolResult, CallToolError> {
191 let arguments = serde_json::Value::Object(params.arguments.unwrap_or_default());
192 let text = match params.name.as_str() {
193 "pm_status" => {
194 let (is_ready, tree, active_features) = if let Some(ref engine) = *self.engine.read().unwrap() {
195 let paths = engine.get_all_file_paths();
196 let tree = render_tree(&paths, 3);
197
198 let mut active = Vec::new();
199 if let Ok(entries) = std::fs::read_dir("projects/active") {
200 for entry in entries.flatten() {
201 if entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false) {
202 active.push(entry.file_name().to_string_lossy().into_owned());
203 }
204 }
205 }
206 (true, Some(tree), active)
207 } else {
208 (false, None, Vec::new())
209 };
210
211 ToonFormatter::format_status(is_ready, Some(".project-map/latest/.project-map.json"), tree.as_deref(), &active_features)
212 }
213 "pm_query" => {
214 let args: PmQueryTool = serde_json::from_value(arguments)
215 .map_err(|e| CallToolError(Box::new(e)))?;
216
217 if let Some(ref engine) = *self.engine.read().unwrap() {
218 if let Some(q) = args.query {
219 let matches = engine.find_symbols(&q);
220 ToonFormatter::format_symbols(&q, &matches)
221 } else if let Some(p) = args.path {
222 let symbols = engine.get_file_outline(&p);
223 ToonFormatter::format_file_context(&p, &symbols)
224 } else if let Some(f) = args.filename {
225 let matches = engine.find_files(&f);
226 ToonFormatter::format_file_matches(&f, &matches)
227 } else {
228 "Error: Provide query, path, or filename".to_string()
229 }
230 } else {
231 "Error: Index not loaded".to_string()
232 }
233 }
234 "pm_check_blast_radius" => {
235 let args: PmCheckBlastRadiusTool = serde_json::from_value(arguments)
236 .map_err(|e| CallToolError(Box::new(e)))?;
237
238 if let Some(ref engine) = *self.engine.read().unwrap() {
239 let results = engine.check_blast_radius(&args.path, &args.symbol);
240 ToonFormatter::format_blast_radius(&args.path, &args.symbol, &results)
241 } else {
242 "Error: Index not loaded".to_string()
243 }
244 }
245 "pm_plan" => {
246 let args: PmPlanTool = serde_json::from_value(arguments)
247 .map_err(|e| CallToolError(Box::new(e)))?;
248
249 if let Some(ref engine) = *self.engine.read().unwrap() {
250 let impact = engine.analyze_impact(&args.symbol);
251 ToonFormatter::format_impact_analysis(&args.symbol, &impact)
252 } else {
253 "Error: Index not loaded".to_string()
254 }
255 }
256 "pm_semantic_search" => {
257 let args: PmSemanticSearchTool = serde_json::from_value(arguments)
258 .map_err(|e| CallToolError(Box::new(e)))?;
259
260 if let Some(ref engine) = *self.engine.read().unwrap() {
261 let matches = engine.find_symbols(&args.query);
262 ToonFormatter::format_symbols(&args.query, &matches)
263 } else {
264 "Error: Index not loaded".to_string()
265 }
266 }
267 "pm_fetch_symbol" => {
268 let args: PmFetchSymbolTool = serde_json::from_value(arguments)
269 .map_err(|e| CallToolError(Box::new(e)))?;
270
271 if let Some(ref engine) = *self.engine.read().unwrap() {
272 if let Some(node) = engine.find_symbol_in_path(&args.path, &args.symbol) {
273 if let Ok(content) = std::fs::read_to_string(&node.path) {
274 let bytes = content.as_bytes();
275 if node.start_byte < bytes.len() && node.end_byte <= bytes.len() {
276 let content_str = String::from_utf8_lossy(&bytes[node.start_byte..node.end_byte]);
277 ToonFormatter::format_fetch_result(&args.path, &args.symbol, Some(&content_str))
278 } else {
279 "Error: Byte range out of bounds".to_string()
280 }
281 } else {
282 "Error: Could not read file".to_string()
283 }
284 } else {
285 ToonFormatter::format_fetch_result(&args.path, &args.symbol, None)
286 }
287 } else {
288 "Error: Index not loaded".to_string()
289 }
290 }
291 "pm_init" => {
292 let mut orch = Orchestrator::new();
293 let _ = orch.scaffold_if_empty(Path::new("."));
294 if orch.build_index(Path::new(".")).is_ok() && orch.save_index_versioned(Path::new(".project-map")).is_ok() {
295 let new_engine = QueryEngine::load(Path::new(".project-map/latest/.project-map.json")).ok();
296 *self.engine.write().unwrap() = new_engine;
297 "Index refreshed successfully.".to_string()
298 } else {
299 "Failed to refresh index.".to_string()
300 }
301 }
302
303 _ => return Err(CallToolError(Box::new(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Unknown tool")))),
304 };
305
306 Ok(CallToolResult::text_content(vec![text.into()]))
307 }
308}