use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use anyhow::Result;
use clap::Args;
use rmcp::{
ErrorData as McpError, RoleServer, ServerHandler, ServiceExt,
handler::server::{router::prompt::PromptRouter, tool::ToolRouter, wrapper::Parameters},
model::{
CallToolResult, Content, GetPromptRequestParams, GetPromptResult, ListPromptsResult,
PaginatedRequestParams, PromptMessage, PromptMessageContent, PromptMessageRole,
ServerCapabilities, ServerInfo,
},
prompt, prompt_handler, prompt_router, schemars,
schemars::JsonSchema,
service::RequestContext,
tool, tool_handler, tool_router,
transport::stdio,
};
use serde::Deserialize;
use super::db::{SearchDb, SearchResult};
use super::snippet::SnippetExtractor;
use crate::index::format::{SymbolEntry, SymbolOutput};
use crate::mount::MountTable;
use crate::mount::handler::flush_dirty_mounts;
use crate::utils::format::{
EnrichedSearchResult, ExploreResult, OutputFormat, ReferenceWithSnippet, SymbolWithSnippet,
format_explore, format_references, format_search_results, format_symbols,
};
use crate::utils::manifest;
/// Resolves the caller-supplied `context_lines` into the value passed to snippet extraction.
///
/// * `None` -> `10` (the default advertised in the tool descriptions).
/// * `Some(0)` -> `1`.
/// * any other value is passed through unchanged (including negatives).
///
/// NOTE(review): the server instructions describe `context_lines = 0` as "metadata only",
/// but 0 is coerced to 1 here — confirm which behavior is intended.
fn normalize_context_lines(value: Option<i32>) -> i32 {
    value.map_or(10, |lines| if lines == 0 { 1 } else { lines })
}
// Arguments of the `search` tool. The same struct is deserialized from MCP JSON (serde +
// schemars) and parsed from the command line (clap `Args`).
// Plain `//` comments are used on purpose: `///` doc comments would be picked up by the
// clap/schemars derives as help text and schema descriptions.
#[derive(Debug, Deserialize, JsonSchema, Args)]
pub struct SearchParams {
    // FTS query; positional on the CLI and required in MCP calls (non-`Option`, no default).
    pub query: String,
    // Optional scope filter; comma-separated on the CLI.
    #[arg(short, long, value_delimiter = ',')]
    pub scope: Option<Vec<String>>,
    // Optional symbol-kind filter; comma-separated on the CLI.
    #[arg(short, long, value_delimiter = ',')]
    pub kind: Option<Vec<String>>,
    // Optional path filter (`-f` / `--path` on the CLI).
    #[arg(short = 'f', long)]
    pub path: Option<String>,
    #[arg(short, long)]
    pub project: Option<String>,
    #[arg(short = 'v', long)]
    pub visibility: Option<String>,
    // Page size; the `search` tool falls back to 10 when omitted.
    #[arg(short, long)]
    pub limit: Option<u32>,
    // Number of results to skip; the `search` tool falls back to 0 when omitted.
    #[arg(short, long)]
    pub offset: Option<u32>,
    // Snippet context; normalized by `normalize_context_lines`.
    #[arg(long)]
    pub context_lines: Option<i32>,
    // CLI default is "text"; MCP callers that omit it get `OutputFormat::default()`
    // (presumably the same variant — confirm).
    #[arg(long, default_value = "text")]
    #[serde(default)]
    pub format: OutputFormat,
}
// Arguments of the `get_file_symbols` tool (MCP JSON via serde/schemars, CLI via clap).
// `//` comments only: `///` would become clap help / schema descriptions.
#[derive(Debug, Deserialize, JsonSchema, Args)]
pub struct GetFileSymbolsParams {
    // File whose symbols are listed; positional on the CLI, required in MCP calls.
    pub file: String,
    #[arg(short = 'v', long)]
    pub visibility: Option<String>,
    // Snippet context; normalized by `normalize_context_lines`.
    #[arg(long)]
    pub context_lines: Option<i32>,
    // Page size; the tool falls back to 100 when omitted.
    #[arg(short, long)]
    pub limit: Option<u32>,
    #[arg(short, long)]
    pub offset: Option<u32>,
    // CLI default is "text"; MCP callers that omit it get `OutputFormat::default()`.
    #[arg(long, default_value = "text")]
    #[serde(default)]
    pub format: OutputFormat,
}
// Arguments of the `get_children` tool (MCP JSON via serde/schemars, CLI via clap).
// `//` comments only: `///` would become clap help / schema descriptions.
#[derive(Debug, Deserialize, JsonSchema, Args)]
pub struct GetChildrenParams {
    // File containing the parent symbol; positional on the CLI.
    pub file: String,
    // Parent symbol identifier; positional on the CLI. Its exact format is defined by
    // `SearchDb::get_children` (not visible here) — presumably the symbol name.
    pub parent: String,
    #[arg(short = 'v', long)]
    pub visibility: Option<String>,
    // Snippet context; normalized by `normalize_context_lines`.
    #[arg(long)]
    pub context_lines: Option<i32>,
    // Page size; the tool falls back to 100 when omitted.
    #[arg(short, long)]
    pub limit: Option<u32>,
    #[arg(short, long)]
    pub offset: Option<u32>,
    // CLI default is "text"; MCP callers that omit it get `OutputFormat::default()`.
    #[arg(long, default_value = "text")]
    #[serde(default)]
    pub format: OutputFormat,
}
// Arguments of the `get_callers` tool (MCP JSON via serde/schemars, CLI via clap).
// `//` comments only: `///` would become clap help / schema descriptions.
#[derive(Debug, Deserialize, JsonSchema, Args)]
pub struct GetCallersParams {
    // Name of the referenced symbol; positional on the CLI.
    pub name: String,
    // Optional reference-kind filter; exposed as `-k` / `--ref-kind` on the CLI.
    #[arg(short = 'k', long = "ref-kind")]
    pub reference_kind: Option<String>,
    #[arg(short, long)]
    pub project: Option<String>,
    #[arg(short = 'v', long)]
    pub visibility: Option<String>,
    // Page size; the tool falls back to 100 when omitted.
    #[arg(short, long)]
    pub limit: Option<u32>,
    #[arg(short, long)]
    pub offset: Option<u32>,
    // Snippet context; normalized by `normalize_context_lines`.
    #[arg(long)]
    pub context_lines: Option<i32>,
    // CLI default is "text"; MCP callers that omit it get `OutputFormat::default()`.
    #[arg(long, default_value = "text")]
    #[serde(default)]
    pub format: OutputFormat,
}
// Arguments of the `get_callees` tool (MCP JSON via serde/schemars, CLI via clap).
// `//` comments only: `///` would become clap help / schema descriptions.
#[derive(Debug, Deserialize, JsonSchema, Args)]
pub struct GetCalleesParams {
    // Function/method whose outgoing references are listed; positional on the CLI.
    pub caller: String,
    // Optional reference-kind filter; exposed as `-k` / `--ref-kind` on the CLI.
    #[arg(short = 'k', long = "ref-kind")]
    pub reference_kind: Option<String>,
    #[arg(short, long)]
    pub project: Option<String>,
    #[arg(short = 'v', long)]
    pub visibility: Option<String>,
    // Page size; the tool falls back to 100 when omitted.
    #[arg(short, long)]
    pub limit: Option<u32>,
    #[arg(short, long)]
    pub offset: Option<u32>,
    // Snippet context; normalized by `normalize_context_lines`.
    #[arg(long)]
    pub context_lines: Option<i32>,
    // CLI default is "text"; MCP callers that omit it get `OutputFormat::default()`.
    #[arg(long, default_value = "text")]
    #[serde(default)]
    pub format: OutputFormat,
}
// Default budget of file entries returned by `explore`, shared by the CLI and MCP paths.
const DEFAULT_EXPLORE_MAX_ENTRIES: u32 = 200;

// Serde default for `ExploreParams::max_entries`, so MCP callers may omit the field.
fn default_explore_max_entries() -> u32 {
    DEFAULT_EXPLORE_MAX_ENTRIES
}

// Arguments of the `explore` tool (MCP JSON via serde/schemars, CLI via clap).
// `//` comments only: `///` would become clap help / schema descriptions.
#[derive(Debug, Deserialize, JsonSchema, Args)]
pub struct ExploreParams {
    // Optional path to scope the listing; positional (and optional) on the CLI.
    pub path: Option<String>,
    #[arg(short, long)]
    pub project: Option<String>,
    #[arg(short = 'v', long)]
    pub visibility: Option<String>,
    // Budget of listed files. The clap default applies only to the CLI; without the serde
    // default, an MCP call that omits this field fails to deserialize even though the tool
    // description advertises "(default: 200)".
    #[arg(short, long, default_value_t = DEFAULT_EXPLORE_MAX_ENTRIES)]
    #[serde(default = "default_explore_max_entries")]
    pub max_entries: u32,
    // CLI default is "text"; MCP callers that omit it get `OutputFormat::default()`.
    #[arg(long, default_value = "text")]
    #[serde(default)]
    pub format: OutputFormat,
}
/// MCP server exposing the code index through tools and prompts.
///
/// `db` and `mount_table` are `Arc`-shared, so clones of the server operate on the same
/// database and mount table.
#[derive(Clone)]
pub struct CodeIndexServer {
    /// Shared search database; each tool call locks it for the duration of its queries.
    db: Arc<Mutex<SearchDb>>,
    /// Shared mount table used to resolve project roots and to flush dirty mounts.
    mount_table: Arc<Mutex<MountTable>>,
    /// Reads source snippets relative to the workspace root captured in `new`.
    snippet_extractor: SnippetExtractor,
    /// Router generated by `#[tool_router]` for the tool methods.
    tool_router: ToolRouter<Self>,
    /// Router generated by `#[prompt_router]` for the prompt methods.
    prompt_router: PromptRouter<Self>,
}
impl CodeIndexServer {
    /// Builds a server over the shared search database and mount table.
    ///
    /// The snippet extractor is rooted at the mount table's workspace root as it is at
    /// construction time.
    ///
    /// # Panics
    ///
    /// Panics if the mount-table mutex is poisoned.
    pub fn new(db: Arc<Mutex<SearchDb>>, mount_table: Arc<Mutex<MountTable>>) -> Self {
        // Scope the guard so the lock is released before `mount_table` is moved into `Self`.
        let workspace_root = {
            let table = mount_table.lock().expect("mount table lock poisoned");
            table.workspace_root().to_path_buf()
        };
        let snippet_extractor = SnippetExtractor::new(workspace_root);
        Self {
            db,
            mount_table,
            snippet_extractor,
            tool_router: Self::tool_router(),
            prompt_router: Self::prompt_router(),
        }
    }

    /// Pairs each symbol with a source snippet of `context_lines` context.
    ///
    /// Symbols whose file the extractor reports as missing are dropped; order is preserved.
    fn enrich_with_snippets(
        &self,
        symbols: Vec<SymbolEntry>,
        context_lines: i32,
    ) -> Vec<SymbolWithSnippet> {
        let extractor = &self.snippet_extractor;
        symbols
            .into_iter()
            .filter(|entry| extractor.file_exists(&entry.project, &entry.file))
            .map(|entry| {
                let (start, end) = (entry.line[0], entry.line[1]);
                let snippet = extractor.extract_snippet(
                    &entry.project,
                    &entry.file,
                    start,
                    end,
                    context_lines,
                );
                SymbolOutput::from_entry(&entry, snippet)
            })
            .collect()
    }

    /// Pairs each reference with a source snippet of `context_lines` context.
    ///
    /// References whose file the extractor reports as missing are dropped; order is preserved.
    fn enrich_refs_with_snippets(
        &self,
        references: Vec<crate::index::format::ReferenceEntry>,
        context_lines: i32,
    ) -> Vec<ReferenceWithSnippet> {
        let extractor = &self.snippet_extractor;
        references
            .into_iter()
            .filter(|entry| extractor.file_exists(&entry.project, &entry.file))
            .map(|reference| {
                let (start, end) = (reference.line[0], reference.line[1]);
                let context = extractor.extract_snippet(
                    &reference.project,
                    &reference.file,
                    start,
                    end,
                    context_lines,
                );
                ReferenceWithSnippet { reference, context }
            })
            .collect()
    }
}
#[tool_router]
impl CodeIndexServer {
#[tool(
description = "Search symbols, files, and texts. FTS5 with BM25 ranking.\n\n\
**Query syntax:**\n\
- `foo bar` — match both (implicit AND)\n\
- `foo|bar` or `foo OR bar` — match either\n\
- `foo*` — prefix (matches fooBar, fooHandler)\n\
- `\"exact phrase\"` — literal match\n\
- `foo NOT test` — exclude term\n\n\
**Tip:** Use `|` to search multiple terms efficiently: `handler|middleware|context`\n\n\
**Params:** query (required), limit (default 10), snippet_lines (default 10)\n\n\
**Optional filters:** scope, kind, path, project, visibility"
)]
pub async fn search(
&self,
Parameters(params): Parameters<SearchParams>,
) -> Result<CallToolResult, McpError> {
let db = self
.db
.lock()
.map_err(|e| McpError::internal_error(format!("db lock poisoned: {e}"), None))?;
let scope = params.scope.unwrap_or_default();
let limit = params.limit.unwrap_or(10);
let offset = params.offset.unwrap_or(0);
let kind = params.kind.unwrap_or_default();
let results = db
.search(
¶ms.query,
&scope,
&kind,
params.path.as_deref(),
params.project.as_deref(),
params.visibility.as_deref(),
limit,
offset,
)
.map_err(|e| McpError::internal_error(format!("search failed: {e}"), None))?;
drop(db);
let context_lines = normalize_context_lines(params.context_lines);
let enriched: Vec<EnrichedSearchResult> = results
.into_iter()
.filter_map(|result| match result {
SearchResult::Symbol(symbol) => {
if !self
.snippet_extractor
.file_exists(&symbol.project, &symbol.file)
{
return None;
}
let snippet = self.snippet_extractor.extract_snippet(
&symbol.project,
&symbol.file,
symbol.line[0],
symbol.line[1],
context_lines,
);
Some(EnrichedSearchResult::Symbol(SymbolOutput::from_entry(
&symbol, snippet,
)))
}
SearchResult::File(file) => Some(EnrichedSearchResult::File(file)),
SearchResult::Text(text) => Some(EnrichedSearchResult::Text(text)),
})
.collect();
let output = format_search_results(&enriched, params.format)
.map_err(|e| McpError::internal_error(format!("serialization failed: {e}"), None))?;
Ok(CallToolResult::success(vec![Content::text(output)]))
}
#[tool(
description = "Get all symbols in a file, ordered by line number. Returns code snippets by default."
)]
pub async fn get_file_symbols(
&self,
Parameters(params): Parameters<GetFileSymbolsParams>,
) -> Result<CallToolResult, McpError> {
let limit = params.limit.unwrap_or(100);
let offset = params.offset.unwrap_or(0);
let db = self
.db
.lock()
.map_err(|e| McpError::internal_error(format!("db lock poisoned: {e}"), None))?;
let results = db
.get_file_symbols(¶ms.file, params.visibility.as_deref(), limit, offset)
.map_err(|e| McpError::internal_error(format!("get_file_symbols failed: {e}"), None))?;
drop(db);
let context_lines = normalize_context_lines(params.context_lines);
let enriched = self.enrich_with_snippets(results, context_lines);
let output = format_symbols(&enriched, params.format)
.map_err(|e| McpError::internal_error(format!("serialization failed: {e}"), None))?;
Ok(CallToolResult::success(vec![Content::text(output)]))
}
#[tool(
description = "Get direct children of a symbol (e.g. methods of a class). Returns code snippets by default."
)]
pub async fn get_children(
&self,
Parameters(params): Parameters<GetChildrenParams>,
) -> Result<CallToolResult, McpError> {
let limit = params.limit.unwrap_or(100);
let offset = params.offset.unwrap_or(0);
let db = self
.db
.lock()
.map_err(|e| McpError::internal_error(format!("db lock poisoned: {e}"), None))?;
let results = db
.get_children(
¶ms.file,
¶ms.parent,
params.visibility.as_deref(),
limit,
offset,
)
.map_err(|e| McpError::internal_error(format!("get_children failed: {e}"), None))?;
drop(db);
let context_lines = normalize_context_lines(params.context_lines);
let enriched = self.enrich_with_snippets(results, context_lines);
let output = format_symbols(&enriched, params.format)
.map_err(|e| McpError::internal_error(format!("serialization failed: {e}"), None))?;
Ok(CallToolResult::success(vec![Content::text(output)]))
}
#[tool(
description = "Explore a project's file structure. Returns project metadata, subprojects (if any), and files grouped by directory. Use 'path' to scope to a subdirectory. Files are capped per directory if total exceeds max_entries (default: 200)."
)]
pub async fn explore(
&self,
Parameters(params): Parameters<ExploreParams>,
) -> Result<CallToolResult, McpError> {
let mut project_path = params.project.as_deref().unwrap_or("").to_string();
let mut path_filter = params.path.clone();
let db = self
.db
.lock()
.map_err(|e| McpError::internal_error(format!("db lock poisoned: {e}"), None))?;
if params.project.is_none()
&& let Some(ref path) = path_filter
{
let all_projects = db.list_projects().map_err(|e| {
McpError::internal_error(format!("list_projects failed: {e}"), None)
})?;
for proj in &all_projects {
if proj.is_empty() {
continue;
}
if path == proj {
project_path = proj.clone();
path_filter = None;
break;
} else if path.starts_with(&format!("{}/", proj)) {
project_path = proj.clone();
path_filter = Some(path[proj.len() + 1..].to_string());
break;
}
}
}
drop(db);
let mt = self.mount_table.lock().map_err(|e| {
McpError::internal_error(format!("mount table lock poisoned: {e}"), None)
})?;
let project_root = mt.project_root(&project_path).ok_or_else(|| {
McpError::invalid_params(format!("Project not found: '{}'", project_path), None)
})?;
let metadata = manifest::extract_metadata(&project_root);
drop(mt);
let db = self
.db
.lock()
.map_err(|e| McpError::internal_error(format!("db lock poisoned: {e}"), None))?;
let subprojects: Vec<String> = db
.list_projects()
.map_err(|e| McpError::internal_error(format!("list_projects failed: {e}"), None))?
.into_iter()
.filter(|p| {
if project_path.is_empty() {
!p.is_empty()
} else {
p.starts_with(&format!("{}/", project_path))
}
})
.collect();
let full_overview = db.explore_dir_overview(&project_path, None).map_err(|e| {
McpError::internal_error(format!("explore_dir_overview failed: {e}"), None)
})?;
let filtered_overview = if path_filter.is_some() {
db.explore_dir_overview(&project_path, path_filter.as_deref())
.map_err(|e| {
McpError::internal_error(format!("explore_dir_overview failed: {e}"), None)
})?
} else {
full_overview.clone()
};
let max_level = super::db::visibility_max_level(params.visibility.as_deref(), "public");
let max_entries = params.max_entries as usize;
let mut total_visible_files = 0usize;
let mut num_visible_groups = 0usize;
let mut total_map: BTreeMap<(String, Option<String>), usize> = BTreeMap::new();
for (dir, lang, min_vis, count) in &filtered_overview {
*total_map.entry((dir.clone(), lang.clone())).or_default() += *count;
let passes_filter = match max_level {
Some(max) => *min_vis <= max,
None => true,
};
if passes_filter && lang.is_some() {
total_visible_files += count;
num_visible_groups += 1;
}
}
let cap = if total_visible_files <= max_entries {
usize::MAX
} else {
(max_entries / num_visible_groups.max(1)).max(1)
};
let files = db
.explore_files_capped(
&project_path,
path_filter.as_deref(),
params.visibility.as_deref(),
cap,
)
.map_err(|e| {
McpError::internal_error(format!("explore_files_capped failed: {e}"), None)
})?;
drop(db);
let mut files_by_dir: BTreeMap<String, Vec<String>> = BTreeMap::new();
let mut fetched_counts: BTreeMap<(String, Option<String>), usize> = BTreeMap::new();
for (dir, filename, lang) in files {
files_by_dir.entry(dir.clone()).or_default().push(filename);
*fetched_counts.entry((dir, lang)).or_default() += 1;
}
let mut result_dirs: BTreeMap<String, Vec<String>> = BTreeMap::new();
let mut all_dirs: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
for (dir, _lang, _min_vis, _count) in &full_overview {
all_dirs.insert(dir.clone());
}
for dir in all_dirs {
let mut entries = files_by_dir.remove(&dir).unwrap_or_default();
let mut remainders: BTreeMap<String, usize> = BTreeMap::new(); for ((d, lang), total) in &total_map {
if d != &dir {
continue;
}
let fetched = fetched_counts
.get(&(d.clone(), lang.clone()))
.copied()
.unwrap_or(0);
let remaining = total.saturating_sub(fetched);
if remaining > 0 {
let lang_name = lang.as_deref().unwrap_or("other");
*remainders.entry(lang_name.to_string()).or_default() += remaining;
}
}
for (lang, count) in remainders {
entries.push(format!("+{} {} files", count, lang));
}
result_dirs.insert(dir, entries);
}
let result = ExploreResult {
project: if project_path.is_empty() {
None
} else {
Some(project_path.to_string())
},
metadata,
subprojects: if subprojects.is_empty() {
None
} else {
Some(subprojects)
},
directories: result_dirs,
};
let output = format_explore(&result, params.format)
.map_err(|e| McpError::internal_error(format!("serialization failed: {e}"), None))?;
Ok(CallToolResult::success(vec![Content::text(output)]))
}
#[tool(
description = "Find all places that call or reference a symbol. Returns references sorted by file and line. Useful for finding callers of a function/method. Note: For struct/class fields and methods, use `get_children` instead."
)]
pub async fn get_callers(
&self,
Parameters(params): Parameters<GetCallersParams>,
) -> Result<CallToolResult, McpError> {
let limit = params.limit.unwrap_or(100);
let offset = params.offset.unwrap_or(0);
let db = self
.db
.lock()
.map_err(|e| McpError::internal_error(format!("db lock poisoned: {e}"), None))?;
let results = db
.get_callers(
¶ms.name,
params.reference_kind.as_deref(),
params.project.as_deref(),
params.visibility.as_deref(),
limit,
offset,
)
.map_err(|e| McpError::internal_error(format!("get_callers failed: {e}"), None))?;
drop(db);
let context_lines = normalize_context_lines(params.context_lines);
let enriched = self.enrich_refs_with_snippets(results, context_lines);
let output = format_references(&enriched, params.format)
.map_err(|e| McpError::internal_error(format!("serialization failed: {e}"), None))?;
Ok(CallToolResult::success(vec![Content::text(output)]))
}
#[tool(
description = "Find all symbols that a given function/method calls or references. Returns references sorted by file and line. Useful for understanding dependencies and call chains."
)]
pub async fn get_callees(
&self,
Parameters(params): Parameters<GetCalleesParams>,
) -> Result<CallToolResult, McpError> {
let limit = params.limit.unwrap_or(100);
let offset = params.offset.unwrap_or(0);
let db = self
.db
.lock()
.map_err(|e| McpError::internal_error(format!("db lock poisoned: {e}"), None))?;
let results = db
.get_callees(
¶ms.caller,
params.reference_kind.as_deref(),
params.project.as_deref(),
params.visibility.as_deref(),
limit,
offset,
)
.map_err(|e| McpError::internal_error(format!("get_callees failed: {e}"), None))?;
drop(db);
let context_lines = normalize_context_lines(params.context_lines);
let enriched = self.enrich_refs_with_snippets(results, context_lines);
let output = format_references(&enriched, params.format)
.map_err(|e| McpError::internal_error(format!("serialization failed: {e}"), None))?;
Ok(CallToolResult::success(vec![Content::text(output)]))
}
#[tool(
description = "Flush pending index changes to .codeindex/ files on disk. Call this when you need the index persisted (e.g., before git operations). Returns the number of projects flushed."
)]
pub async fn flush_index(&self) -> Result<CallToolResult, McpError> {
let flushed = flush_dirty_mounts(&self.mount_table, &self.db)
.map_err(|e| McpError::internal_error(format!("flush_index failed: {e}"), None))?;
let message = if flushed == 0 {
"No pending changes to flush.".to_string()
} else {
format!("Flushed {} project(s) to disk.", flushed)
};
Ok(CallToolResult::success(vec![Content::text(message)]))
}
}
// Builds a prompt consisting of one user message, with the given prompt description.
fn single_user_prompt(description: &'static str, text: &'static str) -> GetPromptResult {
    GetPromptResult {
        description: Some(description.into()),
        messages: vec![PromptMessage {
            role: PromptMessageRole::User,
            content: PromptMessageContent::text(text),
        }],
    }
}

// Canned prompts that steer a client through the index tools.
// `//` comments (not `///`) so the `#[prompt]` macro output is unchanged.
#[prompt_router]
impl CodeIndexServer {
    // Architecture overview: explore -> search entry points -> file symbols -> callees.
    #[prompt(name = "explore-codebase")]
    pub fn explore_codebase(&self) -> GetPromptResult {
        single_user_prompt(
            "Understand codebase structure and architecture",
            "Help me understand this codebase. Use `explore` for structure, `search` for entry points (main, cli, handler), `get_file_symbols` for key files, and `get_callees` to trace flows. Summarize the architecture.",
        )
    }

    // Locate a definition, then enumerate its references.
    #[prompt(name = "find-symbol")]
    pub fn find_symbol(&self) -> GetPromptResult {
        single_user_prompt(
            "Find symbol definition and all callers",
            "Find where a symbol is defined and how it's used. Use `search` to locate it, `get_file_symbols` for context, and `get_callers` to find all references. Summarize its role in the codebase.",
        )
    }

    // Follow `get_callees` recursively from an entry point.
    #[prompt(name = "trace-call-chain")]
    pub fn trace_call_chain(&self) -> GetPromptResult {
        single_user_prompt(
            "Trace call chain from entry point",
            "Trace the execution flow from an entry point. Use `search` to find it, then `get_callees` recursively to follow the call chain. Map out the main execution path and key decision points.",
        )
    }

    // Newcomer summary: purpose, key modules, suggested reading order.
    #[prompt(name = "onboard")]
    pub fn onboard(&self) -> GetPromptResult {
        single_user_prompt(
            "Quick onboarding for newcomers",
            "I'm new to this codebase. Use `explore` for overview, find main entry points, identify core types (class/struct/trait). Give me a summary: project purpose, key modules, and suggested files to read first.",
        )
    }
}
#[tool_handler]
#[prompt_handler]
impl ServerHandler for CodeIndexServer {
    // Advertises tool + prompt capabilities together with a usage guide for clients.
    fn get_info(&self) -> ServerInfo {
        // Multi-line literal: continuation lines are intentionally flush-left, because any
        // leading whitespace would become part of the text.
        const INSTRUCTIONS: &str = "Code index providing full-text search and structural navigation over indexed codebases.
**Tools:**
- `explore`: Project structure — metadata, subprojects, files grouped by directory.
- `search`: Unified FTS across symbols, files, and texts. BM25-ranked results.
- `get_file_symbols`: All symbols in a file, ordered by line number.
- `get_children`: Direct children of a symbol (e.g., methods of a class).
- `get_callers`: Find all places that call/reference a symbol.
- `get_callees`: Find all symbols that a function/method calls.
- `flush_index`: Persist pending changes to .codeindex/ files.
**Common parameters:**
- `limit` (default 100): Maximum results to return
- `offset` (default 0): Skip N results for pagination
- `context_lines` (recommended: 10): Lines of code context for type info and docs (0=metadata only, -1=all)
- `kind`: Filter by symbol kind (function, method, class, struct, interface, enum, constant, variable, property, module, import, impl)
- `project`: Filter by project path (relative from workspace root)
- `visibility`: Filter by max visibility (public < internal < private). Default: public";

        let capabilities = ServerCapabilities::builder()
            .enable_tools()
            .enable_prompts()
            .build();

        ServerInfo {
            instructions: Some(INSTRUCTIONS.into()),
            capabilities,
            ..Default::default()
        }
    }
}
/// Concatenates the text parts of a tool result, separated by `'\n'`.
///
/// Non-text content items are skipped. Returns an empty string when there is no text.
pub fn extract_result_text(result: &CallToolResult) -> String {
    use rmcp::model::RawContent;

    let mut texts: Vec<&str> = Vec::new();
    for item in &result.content {
        if let RawContent::Text(text_part) = &item.raw {
            texts.push(text_part.text.as_str());
        }
    }
    texts.join("\n")
}
pub async fn start_server(
db: Arc<Mutex<SearchDb>>,
mount_table: Arc<Mutex<MountTable>>,
) -> Result<()> {
let server = CodeIndexServer::new(db, mount_table);
let service = server
.serve(stdio())
.await
.map_err(|e| anyhow::anyhow!("MCP serve error: {e}"))?;
service
.waiting()
.await
.map_err(|e| anyhow::anyhow!("MCP runtime error: {e}"))?;
Ok(())
}