project_rag/
mcp_server.rs

1use crate::client::RagClient;
2use crate::types::*;
3
4use anyhow::{Context, Result};
5use rmcp::{
6    ErrorData as McpError, Peer, RoleServer, ServerHandler, ServiceExt,
7    handler::server::{router::prompt::PromptRouter, tool::ToolRouter, wrapper::Parameters},
8    model::*,
9    prompt, prompt_handler, prompt_router,
10    service::RequestContext,
11    tool, tool_handler, tool_router,
12};
13use std::sync::Arc;
14use tokio_util::sync::CancellationToken;
15
16/// Guard that cancels a CancellationToken when dropped.
17/// This ensures that if the async handler's future is dropped (e.g., due to client disconnect),
18/// the cancellation token is triggered, allowing cooperative cancellation of long-running operations.
19struct CancelOnDropGuard {
20    token: CancellationToken,
21}
22
23impl CancelOnDropGuard {
24    fn new(token: CancellationToken) -> Self {
25        Self { token }
26    }
27}
28
29impl Drop for CancelOnDropGuard {
30    fn drop(&mut self) {
31        if !self.token.is_cancelled() {
32            tracing::info!("Tool handler dropped, triggering cancellation");
33            self.token.cancel();
34        }
35    }
36}
37
38#[derive(Clone)]
39pub struct RagMcpServer {
40    client: Arc<RagClient>,
41    tool_router: ToolRouter<Self>,
42    prompt_router: PromptRouter<Self>,
43}
44
45impl RagMcpServer {
46    /// Create a new RAG MCP server with default configuration
47    pub async fn new() -> Result<Self> {
48        let client = RagClient::new().await?;
49        Self::with_client(Arc::new(client))
50    }
51
52    /// Create a new RAG MCP server with an existing client
53    pub fn with_client(client: Arc<RagClient>) -> Result<Self> {
54        Ok(Self {
55            client,
56            tool_router: Self::tool_router(),
57            prompt_router: Self::prompt_router(),
58        })
59    }
60
61    /// Get the underlying client
62    pub fn client(&self) -> &RagClient {
63        &self.client
64    }
65
66    /// Create a new RAG MCP server with custom configuration
67    pub async fn with_config(config: crate::config::Config) -> Result<Self> {
68        let client = RagClient::with_config(config).await?;
69        Self::with_client(Arc::new(client))
70    }
71
72    /// Normalize a path to a canonical absolute form for consistent cache lookups
73    pub fn normalize_path(path: &str) -> Result<String> {
74        RagClient::normalize_path(path)
75    }
76
77    /// Index a codebase directory (convenience method for testing)
78    #[allow(clippy::too_many_arguments)]
79    pub async fn do_index(
80        &self,
81        path: String,
82        project: Option<String>,
83        include_patterns: Vec<String>,
84        exclude_patterns: Vec<String>,
85        max_file_size: usize,
86        peer: Option<Peer<RoleServer>>,
87        progress_token: Option<ProgressToken>,
88        cancel_token: Option<CancellationToken>,
89    ) -> Result<IndexResponse> {
90        let cancel_token = cancel_token.unwrap_or_default();
91        crate::client::indexing::do_index_smart(
92            &self.client,
93            path,
94            project,
95            include_patterns,
96            exclude_patterns,
97            max_file_size,
98            peer,
99            progress_token,
100            cancel_token,
101        )
102        .await
103    }
104}
105
106#[tool_router(router = tool_router)]
107impl RagMcpServer {
108    #[tool(
109        description = "Index a codebase directory, creating embeddings for semantic search. Automatically performs full indexing for new codebases or incremental updates for previously indexed codebases."
110    )]
111    async fn index_codebase(
112        &self,
113        meta: Meta,
114        peer: Peer<RoleServer>,
115        Parameters(req): Parameters<IndexRequest>,
116    ) -> Result<String, String> {
117        // Validate request inputs
118        req.validate()?;
119
120        // Get progress token if provided
121        let progress_token = meta.get_progress_token();
122
123        // Create a cancellation token for this indexing operation
124        // When this handler's future is dropped (e.g., client disconnects),
125        // the CancellationToken will be dropped and signal cancellation
126        let cancel_token = CancellationToken::new();
127        let cancel_token_for_index = cancel_token.clone();
128
129        // Use a guard to cancel on drop
130        let _cancel_guard = CancelOnDropGuard::new(cancel_token);
131
132        let response = crate::client::indexing::do_index_smart(
133            &self.client,
134            req.path,
135            req.project,
136            req.include_patterns,
137            req.exclude_patterns,
138            req.max_file_size,
139            Some(peer),
140            progress_token,
141            cancel_token_for_index,
142        )
143        .await
144        .map_err(|e| format!("{:#}", e))?; // Use alternate display to show full error chain
145
146        serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
147    }
148
149    #[tool(description = "Query the indexed codebase using semantic search")]
150    async fn query_codebase(
151        &self,
152        Parameters(req): Parameters<QueryRequest>,
153    ) -> Result<String, String> {
154        // Validate request inputs
155        req.validate()?;
156
157        let response = self
158            .client
159            .query_codebase(req)
160            .await
161            .map_err(|e| format!("{:#}", e))?;
162
163        serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
164    }
165
166    #[tool(description = "Get statistics about the indexed codebase")]
167    async fn get_statistics(
168        &self,
169        Parameters(_req): Parameters<StatisticsRequest>,
170    ) -> Result<String, String> {
171        let response = self
172            .client
173            .get_statistics()
174            .await
175            .map_err(|e| format!("{:#}", e))?;
176
177        serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
178    }
179
180    #[tool(description = "Clear all indexed data from the vector database")]
181    async fn clear_index(
182        &self,
183        Parameters(_req): Parameters<ClearRequest>,
184    ) -> Result<String, String> {
185        let response = self
186            .client
187            .clear_index()
188            .await
189            .map_err(|e| format!("{:#}", e))?;
190
191        serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
192    }
193
194    #[tool(description = "Advanced search with filters for file type, language, and path patterns")]
195    async fn search_by_filters(
196        &self,
197        Parameters(req): Parameters<AdvancedSearchRequest>,
198    ) -> Result<String, String> {
199        // Validate request inputs
200        req.validate()?;
201
202        let response = self
203            .client
204            .search_with_filters(req)
205            .await
206            .map_err(|e| format!("{:#}", e))?;
207
208        serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
209    }
210
211    #[tool(description = "Search git commit history using semantic search with on-demand indexing")]
212    async fn search_git_history(
213        &self,
214        Parameters(req): Parameters<SearchGitHistoryRequest>,
215    ) -> Result<String, String> {
216        // Validate request inputs
217        req.validate()?;
218
219        let response = self
220            .client
221            .search_git_history(req)
222            .await
223            .map_err(|e| format!("{:#}", e))?;
224
225        serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
226    }
227
228    #[tool(description = "Find the definition of a symbol at a given file location (line and column)")]
229    async fn find_definition(
230        &self,
231        Parameters(req): Parameters<FindDefinitionRequest>,
232    ) -> Result<String, String> {
233        // Validate request inputs
234        req.validate()?;
235
236        let response = self
237            .client
238            .find_definition(req)
239            .await
240            .map_err(|e| format!("{:#}", e))?;
241
242        serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
243    }
244
245    #[tool(description = "Find all references to a symbol at a given file location")]
246    async fn find_references(
247        &self,
248        Parameters(req): Parameters<FindReferencesRequest>,
249    ) -> Result<String, String> {
250        // Validate request inputs
251        req.validate()?;
252
253        let response = self
254            .client
255            .find_references(req)
256            .await
257            .map_err(|e| format!("{:#}", e))?;
258
259        serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
260    }
261
262    #[tool(description = "Get the call graph for a function at a given file location (callers and callees)")]
263    async fn get_call_graph(
264        &self,
265        Parameters(req): Parameters<GetCallGraphRequest>,
266    ) -> Result<String, String> {
267        // Validate request inputs
268        req.validate()?;
269
270        let response = self
271            .client
272            .get_call_graph(req)
273            .await
274            .map_err(|e| format!("{:#}", e))?;
275
276        serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
277    }
278}
279
280// Prompts for slash commands
281#[prompt_router]
282impl RagMcpServer {
283    #[prompt(
284        name = "index",
285        description = "Index a codebase directory to enable semantic search (automatically performs full or incremental based on existing index)"
286    )]
287    async fn index_prompt(
288        &self,
289        Parameters(args): Parameters<serde_json::Value>,
290    ) -> Result<GetPromptResult, McpError> {
291        let path = args.get("path").and_then(|v| v.as_str()).unwrap_or(".");
292
293        let messages = vec![PromptMessage::new_text(
294            PromptMessageRole::User,
295            format!(
296                "Please index the codebase at path: '{}'. This will automatically perform a full index if this is the first time, or an incremental update if the codebase has been indexed before.",
297                path
298            ),
299        )];
300
301        Ok(GetPromptResult {
302            description: Some(format!(
303                "Index codebase at {} (auto-detects full/incremental)",
304                path
305            )),
306            messages,
307        })
308    }
309
310    #[prompt(
311        name = "query",
312        description = "Search the indexed codebase using semantic search"
313    )]
314    async fn query_prompt(
315        &self,
316        Parameters(args): Parameters<serde_json::Value>,
317    ) -> Result<Vec<PromptMessage>, McpError> {
318        let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
319
320        Ok(vec![PromptMessage::new_text(
321            PromptMessageRole::User,
322            format!("Please search the codebase for: {}", query),
323        )])
324    }
325
326    #[prompt(
327        name = "stats",
328        description = "Get statistics about the indexed codebase"
329    )]
330    async fn stats_prompt(&self) -> Vec<PromptMessage> {
331        vec![PromptMessage::new_text(
332            PromptMessageRole::User,
333            "Please get statistics about the indexed codebase.",
334        )]
335    }
336
337    #[prompt(
338        name = "clear",
339        description = "Clear all indexed data from the vector database"
340    )]
341    async fn clear_prompt(&self) -> Vec<PromptMessage> {
342        vec![PromptMessage::new_text(
343            PromptMessageRole::User,
344            "Please clear all indexed data from the vector database.",
345        )]
346    }
347
348    #[prompt(
349        name = "search",
350        description = "Advanced search with filters (file type, language, path)"
351    )]
352    async fn search_prompt(
353        &self,
354        Parameters(args): Parameters<serde_json::Value>,
355    ) -> Result<Vec<PromptMessage>, McpError> {
356        let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
357
358        Ok(vec![PromptMessage::new_text(
359            PromptMessageRole::User,
360            format!("Please perform an advanced search for: {}", query),
361        )])
362    }
363
364    #[prompt(
365        name = "git-search",
366        description = "Search git commit history using semantic search (automatically indexes commits on-demand)"
367    )]
368    async fn git_search_prompt(
369        &self,
370        Parameters(args): Parameters<serde_json::Value>,
371    ) -> Result<Vec<PromptMessage>, McpError> {
372        let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
373        let path = args.get("path").and_then(|v| v.as_str()).unwrap_or(".");
374
375        Ok(vec![PromptMessage::new_text(
376            PromptMessageRole::User,
377            format!(
378                "Please search git commit history at path '{}' for: {}. This will automatically index commits as needed.",
379                path, query
380            ),
381        )])
382    }
383
384    #[prompt(
385        name = "definition",
386        description = "Find where a symbol is defined at a given file location"
387    )]
388    async fn definition_prompt(
389        &self,
390        Parameters(args): Parameters<serde_json::Value>,
391    ) -> Result<Vec<PromptMessage>, McpError> {
392        let file = args.get("file").and_then(|v| v.as_str()).unwrap_or("");
393        let line = args.get("line").and_then(|v| v.as_u64()).unwrap_or(1);
394        let column = args.get("column").and_then(|v| v.as_u64()).unwrap_or(0);
395
396        Ok(vec![PromptMessage::new_text(
397            PromptMessageRole::User,
398            format!(
399                "Please find the definition of the symbol at file '{}', line {}, column {}.",
400                file, line, column
401            ),
402        )])
403    }
404
405    #[prompt(
406        name = "references",
407        description = "Find all references to a symbol at a given file location"
408    )]
409    async fn references_prompt(
410        &self,
411        Parameters(args): Parameters<serde_json::Value>,
412    ) -> Result<Vec<PromptMessage>, McpError> {
413        let file = args.get("file").and_then(|v| v.as_str()).unwrap_or("");
414        let line = args.get("line").and_then(|v| v.as_u64()).unwrap_or(1);
415        let column = args.get("column").and_then(|v| v.as_u64()).unwrap_or(0);
416
417        Ok(vec![PromptMessage::new_text(
418            PromptMessageRole::User,
419            format!(
420                "Please find all references to the symbol at file '{}', line {}, column {}.",
421                file, line, column
422            ),
423        )])
424    }
425
426    #[prompt(
427        name = "callgraph",
428        description = "Get the call graph (callers and callees) for a function at a given location"
429    )]
430    async fn callgraph_prompt(
431        &self,
432        Parameters(args): Parameters<serde_json::Value>,
433    ) -> Result<Vec<PromptMessage>, McpError> {
434        let file = args.get("file").and_then(|v| v.as_str()).unwrap_or("");
435        let line = args.get("line").and_then(|v| v.as_u64()).unwrap_or(1);
436        let column = args.get("column").and_then(|v| v.as_u64()).unwrap_or(0);
437
438        Ok(vec![PromptMessage::new_text(
439            PromptMessageRole::User,
440            format!(
441                "Please get the call graph for the function at file '{}', line {}, column {}. Show what calls this function and what it calls.",
442                file, line, column
443            ),
444        )])
445    }
446}
447
448#[tool_handler(router = self.tool_router)]
449#[prompt_handler]
450impl ServerHandler for RagMcpServer {
451    fn get_info(&self) -> ServerInfo {
452        ServerInfo {
453            protocol_version: ProtocolVersion::default(),
454            capabilities: ServerCapabilities::builder()
455                .enable_tools()
456                .enable_prompts()
457                .build(),
458            server_info: Implementation {
459                name: "project".into(),
460                title: Some("Project RAG - Code Understanding with Semantic Search".into()),
461                version: env!("CARGO_PKG_VERSION").into(),
462                icons: None,
463                website_url: None,
464            },
465            instructions: Some(
466                "RAG-based codebase indexing and semantic search. \
467                Use index_codebase to create embeddings (automatically performs full or incremental indexing), \
468                query_codebase to search, and search_by_filters for advanced queries."
469                    .into(),
470            ),
471        }
472    }
473}
474
475impl RagMcpServer {
476    pub async fn serve_stdio() -> Result<()> {
477        tracing::info!("Starting RAG MCP server");
478
479        let server = Self::new().await.context("Failed to create MCP server")?;
480
481        let transport = rmcp::transport::io::stdio();
482
483        server.serve(transport).await?.waiting().await?;
484
485        Ok(())
486    }
487}
488
489#[cfg(test)]
490mod tests;