use std::sync::atomic::Ordering;
use rmcp::ErrorData as McpError;
use rmcp::model::CallToolResult;
use super::ServerState;
use super::helpers::{SEARCH_LIMIT_DEFAULT, SEARCH_LIMIT_MAX, json_result};
use super::types::{GrepHit, WorkspaceGrepParams, WorkspaceGrepResponse};
/// Runs a regex search over the files listed in the workspace cache and
/// returns one page of hits.
///
/// # Parameters
/// * `state`  - server state; this function reads `cache_generation`, `cache`
///   and `root` from it.
/// * `params` - the request: `pattern` (regex source), optional `limit`,
///   optional `cursor` from a previous page, optional `path_contains`
///   substring filter, optional `language` filter, and `include_context`.
///
/// # Paging
/// A page holds at most `limit` hits (`limit` defaults to
/// `SEARCH_LIMIT_DEFAULT`, is capped at `SEARCH_LIMIT_MAX` and is raised to at
/// least 1) and reads at most `max(limit * 8, 2000)` files. When either bound
/// stops the scan, `truncated` is set and `next_cursor` records the exact
/// position to continue from: the index of the file (counted among the files
/// that pass the filters) and the ordinal of the first match in that file that
/// has not been returned yet. Both are packed into the cursor's single `u64`
/// offset by `pack_grep_offset`; cursors that carry only a file count decode
/// to "start of that file".
///
/// A cursor whose snapshot id differs from the current cache generation is
/// answered with an empty response that has `cursor_invalidated: true`.
///
/// NOTE(review): resuming assumes `cache.by_path` iterates in the same order
/// for the same generation, and that files on disk are unchanged between
/// pages - confirm against the cache implementation.
///
/// # Counters
/// `total_matches` and `total_files_matched` describe the hits returned in
/// this page: the number of hits, and the number of files that contributed at
/// least one of them.
///
/// # Errors
/// Returns an error when the cursor cannot be decoded, and
/// `McpError::invalid_params` when `pattern` is not a valid regex. Files that
/// cannot be read as UTF-8 text are skipped with a debug log, not reported as
/// errors.
pub(super) fn run_workspace_grep(
    state: &ServerState,
    params: WorkspaceGrepParams,
) -> Result<CallToolResult, McpError> {
    // A page must be able to hold at least one hit, otherwise a paginated
    // scan could never advance past the first match.
    let limit = params
        .limit
        .unwrap_or(SEARCH_LIMIT_DEFAULT)
        .min(SEARCH_LIMIT_MAX)
        .max(1) as usize;
    let scan_cap = limit.saturating_mul(8).max(2_000);
    let generation = state.cache_generation.load(Ordering::Relaxed);

    // Resume position: number of filter-passing files to skip, plus the number
    // of matches already returned from the first file that is scanned.
    let (skip_files, skip_matches) = match params.cursor.as_ref() {
        Some(c) => {
            let (offset, snapshot_id) = c.decode_in_memory()?;
            if snapshot_id != generation {
                return json_result(&WorkspaceGrepResponse {
                    pattern: params.pattern,
                    total_files_matched: 0,
                    total_matches: 0,
                    truncated: false,
                    hits: Vec::new(),
                    next_cursor: None,
                    cursor_invalidated: true,
                });
            }
            unpack_grep_offset(offset)
        }
        None => (0, 0),
    };

    let re = regex::Regex::new(&params.pattern)
        .map_err(|e| McpError::invalid_params(format!("invalid regex: {e}"), None))?;
    let path_finder = params
        .path_contains
        .as_deref()
        .map(|n| memchr::memmem::Finder::new(n.as_bytes()));
    let lang_filter = params.language.as_deref();
    let cache = state.cache.load_full();

    let mut hits: Vec<GrepHit> = Vec::with_capacity(limit.min(64));
    let mut total_matches: u32 = 0;
    let mut total_files_matched: usize = 0;
    let mut truncated = false;
    let mut files_visited: usize = 0;
    let mut files_seen: usize = 0;
    // The match skip applies only to the first file scanned on this page.
    let mut pending_match_skip = skip_matches;
    // Set when the hit limit stops the scan inside a file:
    // (file index, ordinal of the first match not yet returned).
    let mut resume_at: Option<(usize, usize)> = None;

    'files: for (path, entry) in &cache.by_path {
        if files_visited >= scan_cap {
            truncated = true;
            break;
        }
        let path_ok = path_finder
            .as_ref()
            .is_none_or(|f| f.find(path.as_bytes()).is_some());
        if !path_ok {
            continue;
        }
        let lang_ok = lang_filter.is_none_or(|l| entry.language == l);
        if !lang_ok {
            continue;
        }
        if files_seen < skip_files {
            files_seen += 1;
            continue;
        }
        let file_index = files_seen;
        files_seen += 1;
        files_visited += 1;
        let already_returned = std::mem::take(&mut pending_match_skip);

        let abs = state.root.join(path.to_path_buf());
        let source = match std::fs::read_to_string(&abs) {
            Ok(s) => s,
            Err(e) => {
                tracing::debug!(path = %abs.display(), error = %e, "workspace_grep: skipping unreadable file");
                continue;
            }
        };

        // Byte offset at which every line starts; entry 0 is the first line.
        // A file ending in '\n' gets a final entry equal to `source.len()`.
        let line_starts: Vec<usize> = std::iter::once(0)
            .chain(memchr::memchr_iter(b'\n', source.as_bytes()).map(|pos| pos + 1))
            .collect();

        let mut file_counted = false;
        for (ordinal, mat) in re.find_iter(&source).enumerate() {
            if ordinal < already_returned {
                // Returned on an earlier page.
                continue;
            }
            if hits.len() >= limit {
                // The page is full and at least one more match exists:
                // remember exactly where to continue so nothing is dropped.
                truncated = true;
                resume_at = Some((file_index, ordinal));
                break 'files;
            }

            let match_start = mat.start();
            // Index of the last line whose start is <= the match start.
            let line_idx = line_starts
                .partition_point(|&ls| ls <= match_start)
                .saturating_sub(1);
            let line_start_byte = line_starts[line_idx];
            // 1-based line number and 0-based byte column, saturating instead
            // of wrapping for inputs beyond the u32 range.
            let line_num = u32::try_from(line_idx)
                .unwrap_or(u32::MAX)
                .saturating_add(1);
            let column = u32::try_from(match_start - line_start_byte).unwrap_or(u32::MAX);

            let context_before = if params.include_context && line_idx > 0 {
                Some(extract_line(&source, &line_starts, line_idx - 1))
            } else {
                None
            };
            // A following line exists only if its start lies inside the file;
            // the entry produced by a trailing '\n' is not a real line.
            let has_next_line = line_starts
                .get(line_idx + 1)
                .is_some_and(|&next| next < source.len());
            let context_after = if params.include_context && has_next_line {
                Some(extract_line(&source, &line_starts, line_idx + 1))
            } else {
                None
            };

            total_matches = total_matches.saturating_add(1);
            if !file_counted {
                file_counted = true;
                total_files_matched += 1;
            }
            hits.push(GrepHit {
                path: path.clone(),
                line_num,
                column,
                matched_text: mat.as_str().to_owned(),
                context_before,
                context_after,
            });
        }
    }

    let next_cursor = if truncated {
        let offset = match resume_at {
            Some((file_index, ordinal)) => pack_grep_offset(file_index, ordinal),
            // Stopped by the scan cap at a file boundary.
            None => pack_grep_offset(files_seen, 0),
        };
        Some(super::cursor::Cursor::encode_in_memory(offset, generation))
    } else {
        None
    };

    json_result(&WorkspaceGrepResponse {
        pattern: params.pattern,
        total_files_matched,
        total_matches,
        truncated,
        hits,
        next_cursor,
        cursor_invalidated: false,
    })
}

/// Packs a grep resume position into one cursor offset: the file index goes in
/// the low 32 bits and the match ordinal in the high 32 bits. Values that do
/// not fit in 32 bits saturate at `u32::MAX`.
fn pack_grep_offset(file_index: usize, match_ordinal: usize) -> u64 {
    let file = u32::try_from(file_index).unwrap_or(u32::MAX);
    let ordinal = u32::try_from(match_ordinal).unwrap_or(u32::MAX);
    (u64::from(ordinal) << 32) | u64::from(file)
}

/// Inverse of `pack_grep_offset`: returns `(file_index, match_ordinal)`.
/// An offset with zero upper bits yields a match ordinal of 0.
fn unpack_grep_offset(offset: u64) -> (usize, usize) {
    let file = (offset & 0xFFFF_FFFF) as usize;
    let ordinal = (offset >> 32) as usize;
    (file, ordinal)
}
/// Returns line `line_idx` of `source` as an owned string without its line
/// terminator.
///
/// `line_starts` holds the byte offset at which each line begins. The line
/// ends where the next entry begins, or at the end of `source` when
/// `line_idx` is the last entry. Trailing `'\n'` characters are removed
/// first, then trailing `'\r'` characters.
///
/// # Panics
/// Panics if `line_idx` is out of range for `line_starts`, or if the offsets
/// do not describe a valid slice of `source`.
fn extract_line(source: &str, line_starts: &[usize], line_idx: usize) -> String {
    let begin = line_starts[line_idx];
    // The start of the following line doubles as this line's exclusive end.
    let finish = match line_starts.get(line_idx + 1) {
        Some(&next_start) => next_start,
        None => source.len(),
    };
    let without_lf = source[begin..finish].trim_end_matches('\n');
    String::from(without_lf.trim_end_matches('\r'))
}