use std::sync::OnceLock;
use regex::Regex;
use rmcp::ErrorData as McpError;
use super::ServerState;
use super::helpers::{json_result, kind_to_str, parse_kind};
use super::tokens;
use super::types_compress::{CompressParams, CompressResponse, ExpandParams, ExpandResponse};
use crate::query;
/// Upper bound, in bytes, on the symbol body returned by `expand`; longer bodies
/// are cut and flagged as truncated.
const EXPAND_BODY_CAP: usize = 128 * 1024;

/// Note attached to compress responses when `tokens::TOKENS_ARE_COUNTED` is set.
const TOKENS_NOTE_REAL: &str =
    "tokens counted with the o200k (gpt-4o) tokenizer; offline runs fall back to a word estimate";

/// Note attached to compress responses when token counts are a byte-based estimate.
const TOKENS_NOTE_HEURISTIC: &str =
    "estimate (bytes/4); build with --features documents for real token counts";

/// Human-readable explanation of how the reported token counts were produced.
fn tokens_note() -> String {
    let note = match tokens::TOKENS_ARE_COUNTED {
        true => TOKENS_NOTE_REAL,
        false => TOKENS_NOTE_HEURISTIC,
    };
    note.to_owned()
}
static RE_SPACES: OnceLock<Regex> = OnceLock::new();
static RE_BLANK_LINES: OnceLock<Regex> = OnceLock::new();
static RE_FILLERS: OnceLock<Regex> = OnceLock::new();

/// Compile `pattern` into `cell` on first use and hand back the cached regex.
///
/// Panics with `what` if the pattern fails to compile. Every caller below passes
/// a string literal, so such a panic indicates a programming error.
fn cached_regex(cell: &'static OnceLock<Regex>, pattern: &str, what: &str) -> &'static Regex {
    cell.get_or_init(|| Regex::new(pattern).expect(what))
}

/// Runs of two or more spaces and/or tabs.
fn spaces_re() -> &'static Regex {
    cached_regex(&RE_SPACES, r"[ \t]{2,}", "compile RE_SPACES")
}

/// Three or more consecutive newlines.
fn blank_lines_re() -> &'static Regex {
    cached_regex(&RE_BLANK_LINES, r"\n{3,}", "compile RE_BLANK_LINES")
}

/// Case-insensitive filler phrases, each followed by an optional `,`/`.` and an
/// optional single space.
fn fillers_re() -> &'static Regex {
    cached_regex(
        &RE_FILLERS,
        r"(?i)\b(it is worth noting that|it should be noted that|it is important to note that|please note that|as you can see|as mentioned (?:above|earlier|before|previously)|in other words|to be honest|needless to say|for what it's worth|at the end of the day|as a matter of fact|the fact of the matter is|all things considered)\b[,.]?[ ]?",
        "compile RE_FILLERS",
    )
}
/// Cheap, model-free compression of prose.
///
/// Steps, in order:
/// 1. normalise CRLF line endings to LF;
/// 2. collapse runs of two or more spaces/tabs into a single space;
/// 3. collapse three or more consecutive newlines into a single blank line;
/// 4. delete the filler phrases matched by `fillers_re` (with an optional trailing
///    `,`/`.` and one space);
/// 5. split on blank lines and drop every paragraph whose trimmed text has already
///    been seen, keeping the first occurrence. Empty paragraphs are always kept.
fn lexical_pass(text: &str) -> String {
    // The blank-line regex and the paragraph split below only recognise bare `\n`.
    // Without this, Windows-style text would skip blank-line collapsing and
    // paragraph deduplication entirely.
    let text = text.replace("\r\n", "\n");
    let text = spaces_re().replace_all(&text, " ");
    let text = blank_lines_re().replace_all(&text, "\n\n");
    let text = fillers_re().replace_all(&text, "");

    // Key on borrowed, trimmed slices of `text`: no allocation per paragraph.
    let mut seen: ahash::AHashSet<&str> = ahash::AHashSet::new();
    let kept: Vec<&str> = text
        .split("\n\n")
        .filter(|&para| {
            let key = para.trim();
            key.is_empty() || seen.insert(key)
        })
        .collect();
    kept.join("\n\n")
}
/// Handle the `expand` tool: return the source text of one named symbol in a file.
///
/// The symbol is looked up by exact name in the file's outline
/// (`query::file_outline`), optionally narrowed by `params.kind`; exactly one
/// match is required. The body is sliced from the file currently on disk using
/// the indexed byte offsets. Bodies longer than `EXPAND_BODY_CAP` bytes are cut
/// at a UTF-8 character boundary and flagged `truncated`. `bytes` always reports
/// the full, untruncated size.
///
/// # Errors
///
/// Returns `invalid_params` when:
/// - `kind` does not parse;
/// - the outline lookup fails;
/// - no symbol matches (the message lists the available symbols);
/// - several symbols match (the message lists the candidates);
/// - the file cannot be read.
pub(super) async fn run_expand(
    state: &ServerState,
    params: ExpandParams,
) -> Result<rmcp::model::CallToolResult, McpError> {
    let kind_filter = params
        .kind
        .as_deref()
        .map(parse_kind)
        .transpose()
        .map_err(|e| {
            // Only reachable when `kind` was supplied; print the raw value rather
            // than its `Option` wrapper.
            McpError::invalid_params(
                format!(
                    "expand: invalid kind {:?}: {e}",
                    params.kind.as_deref().unwrap_or_default()
                ),
                None,
            )
        })?;

    // Scope the read guard so the store lock is released before the file read below.
    let l1 = {
        let store = state.store.read().await;
        query::file_outline(&store, &params.path).map_err(|e| {
            McpError::invalid_params(format!("expand: file_outline({}): {e}", params.path), None)
        })?
    };

    let candidates: Vec<&crate::extract::Symbol> = l1
        .symbols
        .iter()
        .filter(|s| s.name == params.name)
        .filter(|s| kind_filter.is_none_or(|k| s.kind == k))
        .collect();
    let symbol = match candidates.len() {
        0 => {
            let all_names: Vec<String> = l1
                .symbols
                .iter()
                .filter(|s| kind_filter.is_none_or(|k| s.kind == k))
                .map(|s| format!("[{}] {}", kind_to_str(s.kind), s.name))
                .collect();
            return Err(McpError::invalid_params(
                format!(
                    "expand: symbol {:?} not found in {} (available: {})",
                    params.name,
                    params.path,
                    all_names.join(", ")
                ),
                None,
            ));
        }
        1 => candidates[0],
        _ => {
            let matches: Vec<String> = candidates
                .iter()
                .map(|s| format!("[{}] {}", kind_to_str(s.kind), s.name))
                .collect();
            return Err(McpError::invalid_params(
                format!(
                    "expand: {:?} matches {} symbols in {}; supply `kind` to disambiguate: {}",
                    params.name,
                    candidates.len(),
                    params.path,
                    matches.join(", ")
                ),
                None,
            ));
        }
    };

    let abs = state.root.join(params.path.to_path_buf());
    let file_bytes = std::fs::read(&abs).map_err(|e| {
        McpError::invalid_params(format!("expand: read {}: {e}", params.path), None)
    })?;

    // Offsets come from the index; clamp `end` to the current file length in case
    // the file shrank since indexing. An out-of-range slice yields an empty body.
    let start = symbol.start_byte as usize;
    let end = (symbol.end_byte as usize).min(file_bytes.len());
    let raw = file_bytes.get(start..end).unwrap_or(&[]);

    // 0-based row of the end offset, counted from the bytes on disk.
    let end_row = file_bytes
        .get(..end)
        .unwrap_or(&[])
        .iter()
        .filter(|&&b| b == b'\n')
        .count() as u32;

    let full_bytes = raw.len();
    let truncated = full_bytes > EXPAND_BODY_CAP;
    let body_bytes = if truncated {
        // Cutting mid-character would make the lossy decode below end in U+FFFD.
        &raw[..utf8_cut_point(raw, EXPAND_BODY_CAP)]
    } else {
        raw
    };
    let body = String::from_utf8_lossy(body_bytes).into_owned();

    let response = ExpandResponse {
        path: params.path.to_string(),
        name: symbol.name.clone(),
        kind: kind_to_str(symbol.kind).to_string(),
        start_row: symbol.start_row + 1,
        end_row: end_row + 1,
        body,
        bytes: full_bytes,
        truncated,
    };
    json_result(&response)
}

/// Largest offset `<= cap` at which `bytes` can be cut without splitting a UTF-8
/// multi-byte sequence.
///
/// Requires `cap < bytes.len()`. Backs off over at most three continuation bytes
/// (`0b10xx_xxxx`), the longest tail a UTF-8 sequence can have. If no boundary is
/// found in that window (the data is not UTF-8), `cap` is returned unchanged.
fn utf8_cut_point(bytes: &[u8], cap: usize) -> usize {
    (cap.saturating_sub(3)..=cap)
        .rev()
        .find(|&i| (bytes[i] & 0xC0) != 0x80)
        .unwrap_or(cap)
}
/// Entry point for the `compress` tool.
///
/// Exactly one input must be supplied:
/// - `path` selects the structural outline of an indexed file;
/// - `text` selects the lexical prose pass.
///
/// # Errors
///
/// Returns `invalid_params` when both or neither input is present; otherwise the
/// selected strategy's result is returned as-is.
pub(super) async fn run_compress(
    state: &ServerState,
    params: CompressParams,
) -> Result<rmcp::model::CallToolResult, McpError> {
    match (params.text.as_deref(), params.path.as_ref()) {
        (None, Some(path)) => run_structural(state, path).await,
        (Some(text), None) => run_prose(text, &params),
        (Some(_), Some(_)) => Err(McpError::invalid_params(
            "supply exactly one of `text` or `path`, not both",
            None,
        )),
        (None, None) => Err(McpError::invalid_params(
            "supply exactly one of `text` or `path`",
            None,
        )),
    }
}
/// Structural compression: replace a file with an outline of its imports and its
/// symbols.
///
/// Each symbol is emitted as a `// [kind] name` comment line, followed by its
/// trimmed signature when one is recorded. The response uses
/// `strategy: "structural"`:
/// - `original_bytes` is the indexed `size_bytes`;
/// - `original_tokens` is counted from the file currently on disk, falling back to
///   `size_bytes / 4` if the file cannot be read.
///
/// # Errors
///
/// Returns `invalid_params` when `query::file_outline` fails for `path`.
async fn run_structural(
    state: &ServerState,
    path: &crate::path::RelPath,
) -> Result<rmcp::model::CallToolResult, McpError> {
    // Hold the store's read lock only for the outline lookup, not across the
    // blocking file read and tokenisation below.
    let l1 = {
        let store = state.store.read().await;
        query::file_outline(&store, path).map_err(|e| {
            McpError::invalid_params(format!("compress: file_outline({path}): {e}"), None)
        })?
    };

    let original_bytes = l1.size_bytes as usize;
    let abs = state.root.join(path.to_path_buf());
    let original_tokens = match std::fs::read(&abs) {
        Ok(source) => tokens::count_tokens(&String::from_utf8_lossy(&source)),
        // Unreadable file: estimate from the indexed size rather than failing.
        Err(_) => l1.size_bytes / 4,
    };

    let mut lines: Vec<String> = Vec::new();
    if !l1.imports.is_empty() {
        lines.push("// imports".to_string());
        for imp in &l1.imports {
            lines.push(imp.raw.trim().to_string());
        }
        lines.push(String::new());
    }
    if !l1.symbols.is_empty() {
        lines.push("// symbols".to_string());
        for sym in &l1.symbols {
            lines.push(format!("// [{}] {}", kind_to_str(sym.kind), sym.name));
            if let Some(sig) = &sym.signature {
                lines.push(sig.trim().to_string());
            }
        }
    }

    let output = lines.join("\n");
    let compressed_bytes = output.len();
    let compressed_tokens = tokens::count_tokens(&output);
    // An empty original cannot shrink; report a neutral ratio instead of dividing by zero.
    let ratio = if original_bytes == 0 {
        1.0_f32
    } else {
        compressed_bytes as f32 / original_bytes as f32
    };
    let response = CompressResponse {
        original_bytes,
        original_tokens,
        compressed_bytes,
        compressed_tokens,
        tokens_reduced: original_tokens.saturating_sub(compressed_tokens),
        tokens_counted: tokens::TOKENS_ARE_COUNTED,
        ratio,
        strategy: "structural".to_string(),
        output,
        tokens_note: tokens_note(),
    };
    json_result(&response)
}
/// Lexical compression of free-form `text` via `lexical_pass`, reported as a
/// `CompressResponse` with `strategy: "lexical"`.
///
/// `_params` is not read here.
fn run_prose(
    text: &str,
    _params: &CompressParams,
) -> Result<rmcp::model::CallToolResult, McpError> {
    let output = lexical_pass(text);
    let original_tokens = tokens::count_tokens(text);
    let compressed_tokens = tokens::count_tokens(&output);
    let (original_bytes, compressed_bytes) = (text.len(), output.len());
    // Empty input: report a neutral ratio rather than dividing by zero.
    let ratio = match original_bytes {
        0 => 1.0_f32,
        total => compressed_bytes as f32 / total as f32,
    };
    json_result(&CompressResponse {
        original_bytes,
        original_tokens,
        compressed_bytes,
        compressed_tokens,
        tokens_reduced: original_tokens.saturating_sub(compressed_tokens),
        tokens_counted: tokens::TOKENS_ARE_COUNTED,
        ratio,
        strategy: "lexical".to_string(),
        output,
        tokens_note: tokens_note(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lexical_pass_collapses_whitespace() {
        // Runs of blanks are written as escapes so that whitespace-normalising
        // tools cannot silently collapse the fixture itself.
        let input = "hello\t\t \tworld\n\n\n\nextra blank lines";
        let out = lexical_pass(input);
        assert_eq!(
            out, "hello world\n\nextra blank lines",
            "horizontal runs and extra blank lines should be collapsed"
        );
        assert!(
            !out.contains("\n\n\n"),
            "triple newline should be collapsed"
        );
    }

    #[test]
    fn lexical_pass_strips_fillers() {
        let input = "It is worth noting that this is important. The code runs fast.";
        let out = lexical_pass(input);
        assert!(
            !out.to_lowercase().contains("it is worth noting that"),
            "filler phrase should be removed: {out:?}"
        );
        assert!(
            out.contains("The code runs fast"),
            "non-filler content must survive: {out:?}"
        );
    }

    #[test]
    fn lexical_pass_deduplicates_paragraphs() {
        let repeated = "Hello world.\n\nHello world.\n\nDifferent paragraph.";
        let out = lexical_pass(repeated);
        let count = out.matches("Hello world.").count();
        assert_eq!(
            count, 1,
            "duplicate paragraph must appear only once: {out:?}"
        );
        assert!(
            out.contains("Different paragraph"),
            "unique paragraph must survive: {out:?}"
        );
    }

    #[test]
    fn lexical_pass_dedup_compares_trimmed_paragraphs() {
        let input = "Hello world.\n\n\tHello world.\t\n\nDifferent paragraph.";
        assert_eq!(
            lexical_pass(input),
            "Hello world.\n\nDifferent paragraph."
        );
    }

    #[test]
    fn lexical_pass_empty_input_stays_empty() {
        assert_eq!(lexical_pass(""), "");
    }

    #[test]
    fn compress_response_reports_consistent_reduced_tokens() {
        let original = "word ".repeat(40);
        let compressed = "word";
        let original_tokens = tokens::count_tokens(&original);
        let compressed_tokens = tokens::count_tokens(compressed);
        // Check the ordering up front: the equality below subtracts the two
        // counts directly.
        assert!(original_tokens >= compressed_tokens);
        let response = CompressResponse {
            original_bytes: original.len(),
            original_tokens,
            compressed_bytes: compressed.len(),
            compressed_tokens,
            tokens_reduced: original_tokens.saturating_sub(compressed_tokens),
            tokens_counted: tokens::TOKENS_ARE_COUNTED,
            ratio: compressed.len() as f32 / original.len() as f32,
            strategy: "lexical".to_string(),
            output: compressed.to_string(),
            tokens_note: tokens_note(),
        };
        assert_eq!(
            response.tokens_reduced,
            response.original_tokens - response.compressed_tokens,
            "tokens_reduced must equal original - compressed"
        );
    }

    #[cfg(not(feature = "documents"))]
    #[test]
    fn heuristic_count_is_bytes_over_four() {
        let text = "a".repeat(400);
        assert_eq!(tokens::count_tokens(&text), 100);
        assert!(tokens_note().contains("bytes/4"));
    }

    #[cfg(feature = "documents")]
    #[test]
    fn real_count_note_names_tokenizer() {
        assert!(tokens_note().contains("tokenizer"));
    }
}