// mcpkill 0.1.0
//
// Universal MCP proxy — semantic cache + chunking to kill token waste
// Documentation
use anyhow::Result;
use serde_json::Value;

use crate::cache::{Cache, Chunk};
use crate::chunker::chunk;
use crate::embedder::Embedder;
use crate::similarity::cosine;
use crate::token;

/// Orchestrates the full pipeline:
///   1. Extract text content from a MCP `tools/call` response
///   2. Embed the original query string
///   3. Cache lookup (cosine similarity)
///   4. Hit  → re-rank cached chunks, record hit, return top-K
///   5. Miss → smart-chunk + embed + store, return top-K
///
/// When `dry_run` is enabled the original response is returned unchanged;
/// filtering decisions are only printed to stderr. The pipeline itself still
/// runs in that mode, so a cache miss is still chunked, embedded and stored.
pub struct Filter {
    /// Store of previously seen queries and their embedded chunks.
    cache: Cache,
    /// Produces embedding vectors for the query and for each chunk.
    embedder: Embedder,
    /// Maximum number of chunks kept per response (the "K" in top-K).
    max_chunks: usize,
    /// Threshold forwarded to the cache lookup; presumably the minimum
    /// similarity required for a hit — confirm against `Cache::search`.
    threshold: f32,
    /// When set, `process` returns the original response untouched and
    /// logging is forced on.
    dry_run: bool,
    /// Enables diagnostic output on stderr.
    verbose: bool,
    /// Path the cache was opened with; also handed to the eviction call.
    db_path: String,
    /// Cache size budget in bytes (the constructor takes it in MiB).
    max_db_bytes: u64,
}

impl Filter {
    /// Builds a filter backed by the cache at `db_path`.
    ///
    /// * `max_chunks` — upper bound on chunks kept per response.
    /// * `threshold` — forwarded to the cache lookup on every call.
    /// * `dry_run` — run the pipeline but hand responses back untouched.
    /// * `verbose` — print diagnostics to stderr.
    /// * `max_db_mb` — cache size budget in MiB; stored internally in bytes.
    ///
    /// # Errors
    /// Propagates any failure from `Cache::new` or `Embedder::new`.
    pub fn new(
        db_path: &str,
        max_chunks: usize,
        threshold: f32,
        dry_run: bool,
        verbose: bool,
        max_db_mb: u64,
    ) -> Result<Self> {
        Ok(Self {
            cache: Cache::new(db_path)?,
            embedder: Embedder::new()?,
            max_chunks,
            threshold,
            dry_run,
            verbose,
            db_path: db_path.to_string(),
            // Saturate instead of overflowing: a huge limit must mean
            // "effectively unlimited", not a panic (debug) or a wrapped,
            // tiny budget (release).
            max_db_bytes: max_db_mb.saturating_mul(1024 * 1024),
        })
    }

    /// Process one `tools/call` JSON-RPC response.
    ///
    /// Returns the (possibly filtered) response. The response is returned
    /// unchanged when it has no `result.content` array, when that array holds
    /// no `"text"` parts, or when dry-run is enabled.
    ///
    /// On a cache hit the returned chunks come from the cached entry, not from
    /// the text of `response`. On a miss the text parts are joined, chunked,
    /// embedded and stored before the top-K selection is returned.
    ///
    /// # Errors
    /// Fails if embedding, the cache lookup, or storing a new entry fails.
    /// Hit recording and eviction are best-effort: a failure there is only
    /// logged.
    pub fn process(&self, response: &Value, query: &str) -> Result<Value> {
        let Some(content) = response
            .get("result")
            .and_then(|r| r.get("content"))
            .and_then(|c| c.as_array())
        else {
            return Ok(response.clone());
        };

        // Split text parts from everything else (images, resources, …).
        let (text_parts, other_parts): (Vec<_>, Vec<_>) = content
            .iter()
            .partition(|p| p.get("type").and_then(|t| t.as_str()) == Some("text"));

        if text_parts.is_empty() {
            return Ok(response.clone());
        }

        // Text parts without a string `text` field are skipped.
        let full_text: String = text_parts
            .iter()
            .filter_map(|p| p.get("text").and_then(|t| t.as_str()))
            .collect::<Vec<_>>()
            .join("\n\n");

        let original_tokens = token::estimate(&full_text);
        let query_emb = self.embedder.embed(query)?;

        let top_texts = match self.cache.search(&query_emb, self.threshold)? {
            Some((query_id, cached)) => {
                self.log(format!(
                    "CACHE HIT  [{} chunks] original={original_tokens}t",
                    cached.len()
                ));
                // Best-effort bookkeeping: never fail the request over it,
                // but do not hide the failure either.
                if let Err(e) = self.cache.record_hit(query_id) {
                    self.log(format!("WARN  record_hit failed: {e}"));
                }
                self.top_ranked(&cached, &query_emb)
            }
            None => {
                self.log(format!(
                    "CACHE MISS [{original_tokens}t] — chunking {} bytes",
                    full_text.len()
                ));
                let chunks = chunk(&full_text);
                let texts: Vec<&str> = chunks.iter().map(String::as_str).collect();
                let embeddings = self.embedder.embed_batch(&texts)?;

                // Pair every chunk with its embedding for ranking and storage.
                let chunk_data: Vec<(String, Vec<f32>)> =
                    chunks.into_iter().zip(embeddings).collect();

                let top = self.top_from_data(&chunk_data, &query_emb);
                let returned_tokens = token::estimate(&top.join(""));

                self.cache.store(
                    query,
                    &query_emb,
                    &chunk_data,
                    original_tokens,
                    returned_tokens,
                )?;

                // Eviction is best-effort as well; surface failures in the log
                // so an ever-growing database does not go unnoticed.
                if let Err(e) = self
                    .cache
                    .evict_lru_if_needed(&self.db_path, self.max_db_bytes)
                {
                    self.log(format!("WARN  cache eviction failed: {e}"));
                }

                top
            }
        };

        self.log(format!(
            "→ returning {}/{} chunks  (~{}t)",
            top_texts.len(),
            self.max_chunks,
            token::estimate(&top_texts.join(""))
        ));

        if self.dry_run {
            self.log("DRY-RUN — returning original response unchanged");
            return Ok(response.clone());
        }

        Ok(self.rebuild_response(response, &top_texts, other_parts))
    }

    // ── Private ───────────────────────────────────────────────────────────────

    /// Ranks cached chunks against the query and returns the top-K texts.
    fn top_ranked(&self, chunks: &[Chunk], query_emb: &[f32]) -> Vec<String> {
        let pairs = chunks
            .iter()
            .map(|c| (c.text.as_str(), c.embedding.as_slice()));
        self.top_k(pairs, query_emb)
    }

    /// Ranks freshly embedded `(text, embedding)` pairs and returns the top-K texts.
    fn top_from_data(&self, chunks: &[(String, Vec<f32>)], query_emb: &[f32]) -> Vec<String> {
        let pairs = chunks.iter().map(|(t, e)| (t.as_str(), e.as_slice()));
        self.top_k(pairs, query_emb)
    }

    /// Score all items by cosine similarity, sort descending, return top-K texts.
    ///
    /// A NaN score is ranked last. At most `max_chunks` texts are returned,
    /// most similar first.
    fn top_k<'a>(
        &self,
        items: impl Iterator<Item = (&'a str, &'a [f32])>,
        query_emb: &[f32],
    ) -> Vec<String> {
        let mut scored: Vec<(f32, &str)> = items
            .map(|(text, emb)| {
                let score = cosine(query_emb, emb);
                // Demote NaN so it can never outrank a real score.
                let score = if score.is_nan() {
                    f32::NEG_INFINITY
                } else {
                    score
                };
                (score, text)
            })
            .collect();
        // `total_cmp` is a genuine total order, which the sort requires;
        // `partial_cmp` with an `Equal` fallback is not one once NaN shows up.
        scored.sort_unstable_by(|a, b| b.0.total_cmp(&a.0));
        scored
            .into_iter()
            .take(self.max_chunks)
            .map(|(_, t)| t.to_string())
            .collect()
    }

    /// Clones `original` and replaces `result.content` with the non-text parts
    /// followed by a single text part holding the selected chunks, separated
    /// by `---` rules.
    fn rebuild_response(
        &self,
        original: &Value,
        text_chunks: &[String],
        other_parts: Vec<&Value>,
    ) -> Value {
        let mut content: Vec<Value> = other_parts.into_iter().cloned().collect();
        content.push(serde_json::json!({
            "type": "text",
            "text": text_chunks.join("\n\n---\n\n")
        }));
        let mut response = original.clone();
        // `process` only gets here after finding `result.content`, so `result`
        // is an object and this indexing assigns in place.
        response["result"]["content"] = Value::Array(content);
        response
    }

    /// Writes a prefixed diagnostic line to stderr when verbose or dry-run is on.
    fn log(&self, msg: impl std::fmt::Display) {
        if self.verbose || self.dry_run {
            eprintln!("[mcpkill] {msg}");
        }
    }
}