1use crate::client::RagClient;
2use crate::types::*;
3
4use anyhow::{Context, Result};
5use rmcp::{
6 ErrorData as McpError, Peer, RoleServer, ServerHandler, ServiceExt,
7 handler::server::{router::prompt::PromptRouter, tool::ToolRouter, wrapper::Parameters},
8 model::*,
9 prompt, prompt_handler, prompt_router,
10 service::RequestContext,
11 tool, tool_handler, tool_router,
12};
13use std::sync::Arc;
14use tokio_util::sync::CancellationToken;
15
16struct CancelOnDropGuard {
20 token: CancellationToken,
21}
22
23impl CancelOnDropGuard {
24 fn new(token: CancellationToken) -> Self {
25 Self { token }
26 }
27}
28
29impl Drop for CancelOnDropGuard {
30 fn drop(&mut self) {
31 if !self.token.is_cancelled() {
32 tracing::info!("Tool handler dropped, triggering cancellation");
33 self.token.cancel();
34 }
35 }
36}
37
38#[derive(Clone)]
39pub struct RagMcpServer {
40 client: Arc<RagClient>,
41 tool_router: ToolRouter<Self>,
42 prompt_router: PromptRouter<Self>,
43}
44
45impl RagMcpServer {
46 pub async fn new() -> Result<Self> {
48 let client = RagClient::new().await?;
49 Self::with_client(Arc::new(client))
50 }
51
52 pub fn with_client(client: Arc<RagClient>) -> Result<Self> {
54 Ok(Self {
55 client,
56 tool_router: Self::tool_router(),
57 prompt_router: Self::prompt_router(),
58 })
59 }
60
61 pub fn client(&self) -> &RagClient {
63 &self.client
64 }
65
66 pub async fn with_config(config: crate::config::Config) -> Result<Self> {
68 let client = RagClient::with_config(config).await?;
69 Self::with_client(Arc::new(client))
70 }
71
72 pub fn normalize_path(path: &str) -> Result<String> {
74 RagClient::normalize_path(path)
75 }
76
77 #[allow(clippy::too_many_arguments)]
79 pub async fn do_index(
80 &self,
81 path: String,
82 project: Option<String>,
83 include_patterns: Vec<String>,
84 exclude_patterns: Vec<String>,
85 max_file_size: usize,
86 peer: Option<Peer<RoleServer>>,
87 progress_token: Option<ProgressToken>,
88 cancel_token: Option<CancellationToken>,
89 ) -> Result<IndexResponse> {
90 let cancel_token = cancel_token.unwrap_or_default();
91 crate::client::indexing::do_index_smart(
92 &self.client,
93 path,
94 project,
95 include_patterns,
96 exclude_patterns,
97 max_file_size,
98 peer,
99 progress_token,
100 cancel_token,
101 )
102 .await
103 }
104}
105
106#[tool_router(router = tool_router)]
107impl RagMcpServer {
108 #[tool(
109 description = "Index a codebase directory, creating embeddings for semantic search. Automatically performs full indexing for new codebases or incremental updates for previously indexed codebases."
110 )]
111 async fn index_codebase(
112 &self,
113 meta: Meta,
114 peer: Peer<RoleServer>,
115 Parameters(req): Parameters<IndexRequest>,
116 ) -> Result<String, String> {
117 req.validate()?;
119
120 let progress_token = meta.get_progress_token();
122
123 let cancel_token = CancellationToken::new();
127 let cancel_token_for_index = cancel_token.clone();
128
129 let _cancel_guard = CancelOnDropGuard::new(cancel_token);
131
132 let response = crate::client::indexing::do_index_smart(
133 &self.client,
134 req.path,
135 req.project,
136 req.include_patterns,
137 req.exclude_patterns,
138 req.max_file_size,
139 Some(peer),
140 progress_token,
141 cancel_token_for_index,
142 )
143 .await
144 .map_err(|e| format!("{:#}", e))?; serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
147 }
148
149 #[tool(description = "Query the indexed codebase using semantic search")]
150 async fn query_codebase(
151 &self,
152 Parameters(req): Parameters<QueryRequest>,
153 ) -> Result<String, String> {
154 req.validate()?;
156
157 let response = self
158 .client
159 .query_codebase(req)
160 .await
161 .map_err(|e| format!("{:#}", e))?;
162
163 serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
164 }
165
166 #[tool(description = "Get statistics about the indexed codebase")]
167 async fn get_statistics(
168 &self,
169 Parameters(_req): Parameters<StatisticsRequest>,
170 ) -> Result<String, String> {
171 let response = self
172 .client
173 .get_statistics()
174 .await
175 .map_err(|e| format!("{:#}", e))?;
176
177 serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
178 }
179
180 #[tool(description = "Clear all indexed data from the vector database")]
181 async fn clear_index(
182 &self,
183 Parameters(_req): Parameters<ClearRequest>,
184 ) -> Result<String, String> {
185 let response = self
186 .client
187 .clear_index()
188 .await
189 .map_err(|e| format!("{:#}", e))?;
190
191 serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
192 }
193
194 #[tool(description = "Advanced search with filters for file type, language, and path patterns")]
195 async fn search_by_filters(
196 &self,
197 Parameters(req): Parameters<AdvancedSearchRequest>,
198 ) -> Result<String, String> {
199 req.validate()?;
201
202 let response = self
203 .client
204 .search_with_filters(req)
205 .await
206 .map_err(|e| format!("{:#}", e))?;
207
208 serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
209 }
210
211 #[tool(description = "Search git commit history using semantic search with on-demand indexing")]
212 async fn search_git_history(
213 &self,
214 Parameters(req): Parameters<SearchGitHistoryRequest>,
215 ) -> Result<String, String> {
216 req.validate()?;
218
219 let response = self
220 .client
221 .search_git_history(req)
222 .await
223 .map_err(|e| format!("{:#}", e))?;
224
225 serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
226 }
227
228 #[tool(description = "Find the definition of a symbol at a given file location (line and column)")]
229 async fn find_definition(
230 &self,
231 Parameters(req): Parameters<FindDefinitionRequest>,
232 ) -> Result<String, String> {
233 req.validate()?;
235
236 let response = self
237 .client
238 .find_definition(req)
239 .await
240 .map_err(|e| format!("{:#}", e))?;
241
242 serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
243 }
244
245 #[tool(description = "Find all references to a symbol at a given file location")]
246 async fn find_references(
247 &self,
248 Parameters(req): Parameters<FindReferencesRequest>,
249 ) -> Result<String, String> {
250 req.validate()?;
252
253 let response = self
254 .client
255 .find_references(req)
256 .await
257 .map_err(|e| format!("{:#}", e))?;
258
259 serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
260 }
261
262 #[tool(description = "Get the call graph for a function at a given file location (callers and callees)")]
263 async fn get_call_graph(
264 &self,
265 Parameters(req): Parameters<GetCallGraphRequest>,
266 ) -> Result<String, String> {
267 req.validate()?;
269
270 let response = self
271 .client
272 .get_call_graph(req)
273 .await
274 .map_err(|e| format!("{:#}", e))?;
275
276 serde_json::to_string_pretty(&response).map_err(|e| format!("Serialization failed: {}", e))
277 }
278}
279
280#[prompt_router]
282impl RagMcpServer {
283 #[prompt(
284 name = "index",
285 description = "Index a codebase directory to enable semantic search (automatically performs full or incremental based on existing index)"
286 )]
287 async fn index_prompt(
288 &self,
289 Parameters(args): Parameters<serde_json::Value>,
290 ) -> Result<GetPromptResult, McpError> {
291 let path = args.get("path").and_then(|v| v.as_str()).unwrap_or(".");
292
293 let messages = vec![PromptMessage::new_text(
294 PromptMessageRole::User,
295 format!(
296 "Please index the codebase at path: '{}'. This will automatically perform a full index if this is the first time, or an incremental update if the codebase has been indexed before.",
297 path
298 ),
299 )];
300
301 Ok(GetPromptResult {
302 description: Some(format!(
303 "Index codebase at {} (auto-detects full/incremental)",
304 path
305 )),
306 messages,
307 })
308 }
309
310 #[prompt(
311 name = "query",
312 description = "Search the indexed codebase using semantic search"
313 )]
314 async fn query_prompt(
315 &self,
316 Parameters(args): Parameters<serde_json::Value>,
317 ) -> Result<Vec<PromptMessage>, McpError> {
318 let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
319
320 Ok(vec![PromptMessage::new_text(
321 PromptMessageRole::User,
322 format!("Please search the codebase for: {}", query),
323 )])
324 }
325
326 #[prompt(
327 name = "stats",
328 description = "Get statistics about the indexed codebase"
329 )]
330 async fn stats_prompt(&self) -> Vec<PromptMessage> {
331 vec![PromptMessage::new_text(
332 PromptMessageRole::User,
333 "Please get statistics about the indexed codebase.",
334 )]
335 }
336
337 #[prompt(
338 name = "clear",
339 description = "Clear all indexed data from the vector database"
340 )]
341 async fn clear_prompt(&self) -> Vec<PromptMessage> {
342 vec![PromptMessage::new_text(
343 PromptMessageRole::User,
344 "Please clear all indexed data from the vector database.",
345 )]
346 }
347
348 #[prompt(
349 name = "search",
350 description = "Advanced search with filters (file type, language, path)"
351 )]
352 async fn search_prompt(
353 &self,
354 Parameters(args): Parameters<serde_json::Value>,
355 ) -> Result<Vec<PromptMessage>, McpError> {
356 let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
357
358 Ok(vec![PromptMessage::new_text(
359 PromptMessageRole::User,
360 format!("Please perform an advanced search for: {}", query),
361 )])
362 }
363
364 #[prompt(
365 name = "git-search",
366 description = "Search git commit history using semantic search (automatically indexes commits on-demand)"
367 )]
368 async fn git_search_prompt(
369 &self,
370 Parameters(args): Parameters<serde_json::Value>,
371 ) -> Result<Vec<PromptMessage>, McpError> {
372 let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
373 let path = args.get("path").and_then(|v| v.as_str()).unwrap_or(".");
374
375 Ok(vec![PromptMessage::new_text(
376 PromptMessageRole::User,
377 format!(
378 "Please search git commit history at path '{}' for: {}. This will automatically index commits as needed.",
379 path, query
380 ),
381 )])
382 }
383
384 #[prompt(
385 name = "definition",
386 description = "Find where a symbol is defined at a given file location"
387 )]
388 async fn definition_prompt(
389 &self,
390 Parameters(args): Parameters<serde_json::Value>,
391 ) -> Result<Vec<PromptMessage>, McpError> {
392 let file = args.get("file").and_then(|v| v.as_str()).unwrap_or("");
393 let line = args.get("line").and_then(|v| v.as_u64()).unwrap_or(1);
394 let column = args.get("column").and_then(|v| v.as_u64()).unwrap_or(0);
395
396 Ok(vec![PromptMessage::new_text(
397 PromptMessageRole::User,
398 format!(
399 "Please find the definition of the symbol at file '{}', line {}, column {}.",
400 file, line, column
401 ),
402 )])
403 }
404
405 #[prompt(
406 name = "references",
407 description = "Find all references to a symbol at a given file location"
408 )]
409 async fn references_prompt(
410 &self,
411 Parameters(args): Parameters<serde_json::Value>,
412 ) -> Result<Vec<PromptMessage>, McpError> {
413 let file = args.get("file").and_then(|v| v.as_str()).unwrap_or("");
414 let line = args.get("line").and_then(|v| v.as_u64()).unwrap_or(1);
415 let column = args.get("column").and_then(|v| v.as_u64()).unwrap_or(0);
416
417 Ok(vec![PromptMessage::new_text(
418 PromptMessageRole::User,
419 format!(
420 "Please find all references to the symbol at file '{}', line {}, column {}.",
421 file, line, column
422 ),
423 )])
424 }
425
426 #[prompt(
427 name = "callgraph",
428 description = "Get the call graph (callers and callees) for a function at a given location"
429 )]
430 async fn callgraph_prompt(
431 &self,
432 Parameters(args): Parameters<serde_json::Value>,
433 ) -> Result<Vec<PromptMessage>, McpError> {
434 let file = args.get("file").and_then(|v| v.as_str()).unwrap_or("");
435 let line = args.get("line").and_then(|v| v.as_u64()).unwrap_or(1);
436 let column = args.get("column").and_then(|v| v.as_u64()).unwrap_or(0);
437
438 Ok(vec![PromptMessage::new_text(
439 PromptMessageRole::User,
440 format!(
441 "Please get the call graph for the function at file '{}', line {}, column {}. Show what calls this function and what it calls.",
442 file, line, column
443 ),
444 )])
445 }
446}
447
448#[tool_handler(router = self.tool_router)]
449#[prompt_handler]
450impl ServerHandler for RagMcpServer {
451 fn get_info(&self) -> ServerInfo {
452 ServerInfo {
453 protocol_version: ProtocolVersion::default(),
454 capabilities: ServerCapabilities::builder()
455 .enable_tools()
456 .enable_prompts()
457 .build(),
458 server_info: Implementation {
459 name: "project".into(),
460 title: Some("Project RAG - Code Understanding with Semantic Search".into()),
461 version: env!("CARGO_PKG_VERSION").into(),
462 icons: None,
463 website_url: None,
464 },
465 instructions: Some(
466 "RAG-based codebase indexing and semantic search. \
467 Use index_codebase to create embeddings (automatically performs full or incremental indexing), \
468 query_codebase to search, and search_by_filters for advanced queries."
469 .into(),
470 ),
471 }
472 }
473}
474
475impl RagMcpServer {
476 pub async fn serve_stdio() -> Result<()> {
477 tracing::info!("Starting RAG MCP server");
478
479 let server = Self::new().await.context("Failed to create MCP server")?;
480
481 let transport = rmcp::transport::io::stdio();
482
483 server.serve(transport).await?.waiting().await?;
484
485 Ok(())
486 }
487}
488
489#[cfg(test)]
490mod tests;