use anyhow::Result;
use serde_json::Value;
use crate::cache::{Cache, Chunk};
use crate::chunker::chunk;
use crate::embedder::Embedder;
use crate::similarity::cosine;
use crate::token;
/// Trims the text content of JSON responses down to the chunks most similar to a
/// query, reusing a cache opened from `db_path`.
pub struct Filter {
    // Collaborators.
    /// Cache of query/chunk embeddings that is searched and populated while filtering.
    cache: Cache,
    /// Computes embeddings for queries and for text chunks.
    embedder: Embedder,

    // Ranking parameters.
    /// Maximum number of chunks kept in a filtered response.
    max_chunks: usize,
    /// Similarity threshold handed to the cache lookup.
    threshold: f32,

    // Cache storage limits.
    /// Location the cache was opened from; also passed to LRU eviction.
    db_path: String,
    /// Size budget, in bytes, handed to LRU eviction.
    max_db_bytes: u64,

    // Behaviour flags.
    /// When set, results are computed and logged but the original response is returned.
    dry_run: bool,
    /// When set, diagnostic messages are written to stderr.
    verbose: bool,
}
impl Filter {
    /// Creates a filter whose cache is opened from `db_path`.
    ///
    /// * `max_chunks` — maximum number of chunks kept per filtered response.
    /// * `threshold` — similarity threshold passed to the cache lookup.
    /// * `dry_run` — compute and log results, but return responses unchanged.
    /// * `verbose` — write diagnostic messages to stderr.
    /// * `max_db_mb` — cache size budget in MiB, handed to LRU eviction.
    ///
    /// # Errors
    /// Fails if the cache or the embedder cannot be initialised.
    pub fn new(
        db_path: &str,
        max_chunks: usize,
        threshold: f32,
        dry_run: bool,
        verbose: bool,
        max_db_mb: u64,
    ) -> Result<Self> {
        Ok(Self {
            cache: Cache::new(db_path)?,
            embedder: Embedder::new()?,
            max_chunks,
            threshold,
            dry_run,
            verbose,
            db_path: db_path.to_string(),
            // Saturate instead of overflowing (debug panic / release wrap) on huge budgets.
            max_db_bytes: max_db_mb.saturating_mul(1024 * 1024),
        })
    }

    /// Reduces the `"type": "text"` parts of `response["result"]["content"]` to the
    /// (at most `max_chunks`) chunks most similar to `query`.
    ///
    /// The response is returned unchanged when:
    /// * it has no `result.content` array;
    /// * the array contains no text parts;
    /// * the text parts carry no non-blank text;
    /// * no chunk ends up selected.
    ///
    /// Otherwise the text parts are joined. The ranked chunks come from one of two sources:
    /// * a cache entry found for a similar earlier query;
    /// * freshly chunking and embedding the text, in which case the result is stored.
    ///
    /// The selected chunks replace all text parts. In dry-run mode the selection is
    /// only logged (the cache is still updated) and the original response is returned.
    ///
    /// # Errors
    /// Propagates failures from embedding, cache search and cache store.
    pub fn process(&self, response: &Value, query: &str) -> Result<Value> {
        let Some(content) = response
            .get("result")
            .and_then(|r| r.get("content"))
            .and_then(|c| c.as_array())
        else {
            return Ok(response.clone());
        };

        let (text_parts, other_parts): (Vec<_>, Vec<_>) = content
            .iter()
            .partition(|p| p.get("type").and_then(|t| t.as_str()) == Some("text"));
        if text_parts.is_empty() {
            return Ok(response.clone());
        }

        let full_text: String = text_parts
            .iter()
            .filter_map(|p| p.get("text").and_then(|t| t.as_str()))
            .collect::<Vec<_>>()
            .join("\n\n");
        // Nothing to rank: don't embed or cache an empty document, which would leave a
        // useless cache entry behind for future similar queries.
        if full_text.trim().is_empty() {
            return Ok(response.clone());
        }

        let original_tokens = token::estimate(&full_text);
        let query_emb = self.embedder.embed(query)?;

        let top_texts = match self.cache.search(&query_emb, self.threshold)? {
            Some((query_id, cached)) => {
                // NOTE(review): these chunks were stored for an earlier, similar query and
                // are not derived from `response` itself — confirm reuse is intended even
                // when the current content differs from what was cached.
                self.log(format!(
                    "CACHE HIT [{} chunks] original={original_tokens}t",
                    cached.len()
                ));
                // Hit bookkeeping is best-effort: report a failure instead of dropping it.
                if let Err(e) = self.cache.record_hit(query_id) {
                    self.log(format!("failed to record cache hit: {e:#}"));
                }
                self.top_ranked(&cached, &query_emb)
            }
            None => {
                self.log(format!(
                    "CACHE MISS [{original_tokens}t] — chunking {} bytes",
                    full_text.len()
                ));
                let chunks = chunk(&full_text);
                let texts: Vec<&str> = chunks.iter().map(String::as_str).collect();
                let embeddings = self.embedder.embed_batch(&texts)?;
                let chunk_data: Vec<(String, Vec<f32>)> =
                    chunks.into_iter().zip(embeddings).collect();
                let top = self.top_from_data(&chunk_data, &query_emb);
                let returned_tokens = token::estimate(&top.join(""));
                // Stored regardless of dry-run mode.
                self.cache.store(
                    query,
                    &query_emb,
                    &chunk_data,
                    original_tokens,
                    returned_tokens,
                )?;
                // Eviction is best-effort: it must not fail the response, but a failure
                // should still be visible.
                if let Err(e) = self
                    .cache
                    .evict_lru_if_needed(&self.db_path, self.max_db_bytes)
                {
                    self.log(format!("cache eviction failed: {e:#}"));
                }
                top
            }
        };

        self.log(format!(
            "→ returning {}/{} chunks (~{}t)",
            top_texts.len(),
            self.max_chunks,
            token::estimate(&top_texts.join(""))
        ));
        if top_texts.is_empty() {
            // Rebuilding now would replace every text part with an empty string.
            self.log("no chunks selected — returning original response unchanged");
            return Ok(response.clone());
        }
        if self.dry_run {
            self.log("DRY-RUN — returning original response unchanged");
            return Ok(response.clone());
        }
        Ok(self.rebuild_response(response, &top_texts, other_parts))
    }

    /// Ranks cached chunks against `query_emb` and returns the best `max_chunks` texts.
    fn top_ranked(&self, chunks: &[Chunk], query_emb: &[f32]) -> Vec<String> {
        let pairs = chunks
            .iter()
            .map(|c| (c.text.as_str(), c.embedding.as_slice()));
        self.top_k(pairs, query_emb)
    }

    /// Ranks freshly embedded `(text, embedding)` pairs and returns the best texts.
    fn top_from_data(&self, chunks: &[(String, Vec<f32>)], query_emb: &[f32]) -> Vec<String> {
        let pairs = chunks.iter().map(|(t, e)| (t.as_str(), e.as_slice()));
        self.top_k(pairs, query_emb)
    }

    /// Scores every `(text, embedding)` pair by cosine similarity to `query_emb` and
    /// returns up to `max_chunks` texts, highest score first.
    fn top_k<'a>(
        &self,
        items: impl Iterator<Item = (&'a str, &'a [f32])>,
        query_emb: &[f32],
    ) -> Vec<String> {
        let mut scored: Vec<(f32, &str)> = items
            .map(|(text, emb)| {
                let score = cosine(query_emb, emb);
                // A NaN score (e.g. from a zero-norm embedding) ranks last rather than
                // landing at an arbitrary position.
                let score = if score.is_nan() { f32::NEG_INFINITY } else { score };
                (score, text)
            })
            .collect();
        // `total_cmp` is a total order, so the sort is well-defined for every input;
        // std's sort may panic on comparators that are not a total order.
        scored.sort_unstable_by(|a, b| b.0.total_cmp(&a.0));
        scored
            .into_iter()
            .take(self.max_chunks)
            .map(|(_, t)| t.to_owned())
            .collect()
    }

    /// Returns a copy of `original` with `result.content` replaced by the non-text
    /// parts (in their original relative order) followed by one text part.
    /// That text part holds `text_chunks` joined with `\n\n---\n\n`.
    fn rebuild_response(
        &self,
        original: &Value,
        text_chunks: &[String],
        other_parts: Vec<&Value>,
    ) -> Value {
        let mut content: Vec<Value> = other_parts.into_iter().cloned().collect();
        content.push(serde_json::json!({
            "type": "text",
            "text": text_chunks.join("\n\n---\n\n")
        }));
        let mut response = original.clone();
        response["result"]["content"] = Value::Array(content);
        response
    }

    /// Writes `msg` to stderr, prefixed with `[mcpkill]`, when verbose or dry-run is on.
    fn log(&self, msg: impl std::fmt::Display) {
        if self.verbose || self.dry_run {
            eprintln!("[mcpkill] {msg}");
        }
    }
}