//! mcpkill 0.1.0
//!
//! Universal MCP proxy — semantic cache + chunking to kill token waste
//! Documentation
use anyhow::{Context, Result};
use rayon::prelude::*;
use rusqlite::{params, Connection};
use std::sync::{Arc, Mutex};

use crate::similarity::cosine;

/// A single cached chunk with its embedding.
///
/// Derives `Debug`, `Clone` and `PartialEq` so chunks can be logged, copied
/// and compared in tests. `Eq` is deliberately not derived because the
/// embedding holds `f32` values.
#[derive(Debug, Clone, PartialEq)]
pub struct Chunk {
    /// Text content of the chunk.
    pub text: String,
    /// Embedding vector stored alongside the text.
    pub embedding: Vec<f32>,
}

/// Aggregate statistics returned by [`Cache::stats`].
///
/// All counters are sums over the whole database at the time of the call;
/// no per-entry breakdown is kept.
#[derive(serde::Serialize)]
pub struct CacheStats {
    /// Number of rows in the `queries` table.
    pub entries: i64,
    /// Number of rows in the `chunks` table.
    pub total_chunks: i64,
    /// Sum of `hit_count` over all queries (0 when the table is empty).
    pub cache_hits: i64,
    /// Sum of `original_tokens` over all queries (0 when the table is empty).
    pub tokens_original: i64,
    /// Sum of `returned_tokens` over all queries (0 when the table is empty).
    pub tokens_returned: i64,
    /// Size in bytes of the file at the path given to [`Cache::stats`];
    /// 0 if its metadata could not be read.
    pub db_bytes: u64,
}

impl CacheStats {
    /// Approximate total tokens saved across all calls (first call + hits).
    ///
    /// The exact figure would be `Σ (original - returned) × (1 + hit_count)`
    /// over every entry, but only aggregates are available here. Hits are
    /// therefore assumed to be spread evenly across entries: each hit is
    /// credited with the average per-entry net saving
    /// `(tokens_original - tokens_returned) / entries`.
    ///
    /// The multiplication is done before the division so small averages are
    /// not truncated to zero, and saturating arithmetic keeps extreme
    /// aggregates from overflowing.
    pub fn tokens_saved(&self) -> i64 {
        // Net saving realised by the initial (cache-filling) calls.
        let first_pass = self.tokens_original - self.tokens_returned;
        // `max(1)` avoids dividing by zero on an empty cache.
        let from_hits = self.cache_hits.saturating_mul(first_pass) / self.entries.max(1);
        first_pass.saturating_add(from_hits)
    }

    /// Percentage (0–100) of all recorded calls that were served from cache.
    ///
    /// Total calls are counted as `entries + cache_hits` (one initial call
    /// per entry plus every hit). Returns `0.0` when nothing was recorded.
    pub fn hit_rate(&self) -> f64 {
        let total_calls = self.entries + self.cache_hits;
        if total_calls == 0 {
            return 0.0;
        }
        self.cache_hits as f64 / total_calls as f64 * 100.0
    }
}

/// SQLite-backed semantic cache.
///
/// Schema
/// ──────
/// queries : one row per unique MCP call — stores token counts + usage stats
/// chunks  : N rows per query — each chunk has its own embedding for re-ranking
///
/// Every method locks the single shared connection for the duration of the
/// call, so database access through one `Cache` is serialised.
pub struct Cache {
    // The one SQLite connection, guarded by a mutex and held in an `Arc`.
    conn: Arc<Mutex<Connection>>,
}

/// Current schema version. Bump this whenever the schema changes.
///
/// Compared against SQLite's `user_version` pragma after `migrate` has run
/// (debug builds assert that they match), so it must equal the highest
/// target version listed in the migration table.
const SCHEMA_VERSION: u32 = 1;

impl Cache {
    /// Open the cache database at `path` and bring its schema up to date.
    ///
    /// The connection is configured with WAL journaling, `synchronous =
    /// NORMAL` and foreign-key enforcement before `migrate` is run on it.
    ///
    /// # Errors
    /// Fails if the database cannot be opened (the error names `path`), if
    /// the pragmas cannot be applied, or if a migration fails.
    pub fn new(path: &str) -> Result<Self> {
        let db = Connection::open(path)
            .with_context(|| format!("Cannot open cache database at {path}"))?;

        // Connection-level settings must be in place before any schema work.
        db.execute_batch(
            "PRAGMA journal_mode = WAL;
             PRAGMA synchronous  = NORMAL;
             PRAGMA foreign_keys = ON;",
        )?;

        migrate(&db)?;

        let conn = Arc::new(Mutex::new(db));
        Ok(Self { conn })
    }

    // ── Read ─────────────────────────────────────────────────────────────────

    /// Find the most similar cached query at or above `threshold`.
    ///
    /// Loads every stored query embedding, scores each one against
    /// `query_emb` with [`cosine`] in parallel (rayon), and returns the id
    /// and ordered chunks of the best-scoring entry.
    ///
    /// # Returns
    /// * `Ok(Some((query_id, chunks)))` when the best score is `>= threshold`.
    /// * `Ok(None)` when the cache is empty or no comparable score reaches
    ///   the threshold (a NaN score or NaN threshold never produces a hit).
    ///
    /// # Errors
    /// Fails if the candidate query cannot be prepared or run, or if the
    /// chunks of the winning entry cannot be loaded.
    ///
    /// # Panics
    /// Panics if the connection mutex is poisoned.
    pub fn search(&self, query_emb: &[f32], threshold: f32) -> Result<Option<(i64, Vec<Chunk>)>> {
        let conn = self.conn.lock().unwrap();

        let candidates: Vec<(i64, Vec<f32>)> = {
            let mut stmt = conn.prepare("SELECT id, query_embedding FROM queries")?;
            // Best effort: a candidate row that fails to decode is skipped so
            // one bad row cannot disable lookups for the whole cache.
            let rows: Vec<(i64, Vec<u8>)> = stmt
                .query_map([], |row| {
                    Ok((row.get::<_, i64>(0)?, row.get::<_, Vec<u8>>(1)?))
                })?
                .filter_map(|r| r.ok())
                .collect();
            // `rows` is owned, so nothing borrows `stmt` past this point.
            rows.into_iter()
                .map(|(id, blob)| (id, bytes_to_f32(&blob)))
                .collect()
        };

        // Parallel cosine similarity via rayon.
        let best: Option<(i64, f32)> = candidates
            .par_iter()
            .map(|(id, emb)| (*id, cosine(query_emb, emb)))
            // Should `cosine` ever yield NaN, that score must not take part:
            // NaN compares false against everything and would otherwise slip
            // past the threshold check as a bogus hit.
            .filter(|(_, score)| !score.is_nan())
            .max_by(|a, b| a.1.total_cmp(&b.1));

        match best {
            // Written as `>=` so that a NaN threshold also results in a miss.
            Some((query_id, score)) if score >= threshold => {
                let chunks = self.load_chunks(&conn, query_id)?;
                Ok(Some((query_id, chunks)))
            }
            _ => Ok(None),
        }
    }

    // ── Write ────────────────────────────────────────────────────────────────

    /// Persist a new query with its chunks and token counts.
    ///
    /// The query row and all of its chunk rows are written in a single
    /// transaction: either the whole entry becomes visible or, if any insert
    /// fails, nothing is stored. Chunks keep the order of the `chunks` slice
    /// via their `position` column.
    ///
    /// # Parameters
    /// * `query_text` / `query_emb` — the call and its embedding.
    /// * `chunks` — `(text, embedding)` pairs, in the order to be returned.
    /// * `original_tokens` / `returned_tokens` — token counts recorded for
    ///   the statistics.
    ///
    /// # Errors
    /// Fails if the transaction cannot be started or committed, or if any
    /// insert fails; in that case the transaction is rolled back.
    ///
    /// # Panics
    /// Panics if the connection mutex is poisoned.
    pub fn store(
        &self,
        query_text: &str,
        query_emb: &[f32],
        chunks: &[(String, Vec<f32>)],
        original_tokens: usize,
        returned_tokens: usize,
    ) -> Result<()> {
        let mut conn = self.conn.lock().unwrap();
        let now = unix_now();

        // Dropping `tx` without committing (any `?` below) rolls back.
        let tx = conn.transaction()?;

        tx.execute(
            "INSERT INTO queries
             (query_text, query_embedding, original_tokens, returned_tokens, created_at, last_used_at)
             VALUES (?1, ?2, ?3, ?4, ?5, ?5)",
            params![
                query_text,
                f32_to_bytes(query_emb),
                original_tokens as i64,
                returned_tokens as i64,
                now
            ],
        )?;

        let query_id = tx.last_insert_rowid();

        {
            // Scoped so the prepared statement is released before commit.
            let mut stmt = tx.prepare(
                "INSERT INTO chunks (query_id, text, embedding, position) VALUES (?1, ?2, ?3, ?4)",
            )?;

            for (pos, (text, emb)) in chunks.iter().enumerate() {
                stmt.execute(params![query_id, text, f32_to_bytes(emb), pos as i64])?;
            }
        }

        tx.commit()?;
        Ok(())
    }

    /// Increment hit counter and refresh last_used_at for a cache hit.
    pub fn record_hit(&self, query_id: i64) -> Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "UPDATE queries SET hit_count = hit_count + 1, last_used_at = ?1 WHERE id = ?2",
            params![unix_now(), query_id],
        )?;
        Ok(())
    }

    // ── Maintenance ──────────────────────────────────────────────────────────

    /// Remove entries not used within `days` days. Returns count removed.
    ///
    /// An entry is removed when its `last_used_at` is older than
    /// `now - days × 86 400 s`. The window is computed with saturating
    /// arithmetic, so a very large `days` simply means "nothing is old
    /// enough". It can no longer overflow into a cutoff in the future, which
    /// would have deleted every entry.
    ///
    /// # Errors
    /// Fails if the DELETE statement cannot be executed.
    ///
    /// # Panics
    /// Panics if the connection mutex is poisoned.
    pub fn evict_expired(&self, days: u64) -> Result<usize> {
        let conn = self.conn.lock().unwrap();

        // Clamp to the i64 range before converting so the cast is lossless.
        let window_secs = days.saturating_mul(86_400).min(i64::MAX as u64) as i64;
        let cutoff = unix_now().saturating_sub(window_secs);

        let n = conn.execute(
            "DELETE FROM queries WHERE last_used_at < ?1",
            params![cutoff],
        )?;
        Ok(n)
    }

    /// Remove entries not used in `days` days (user-facing alias).
    pub fn clear_older_than(&self, days: u64) -> Result<usize> {
        self.evict_expired(days)
    }

    /// Delete every row from `queries`, emptying the cache.
    ///
    /// Returns the number of query rows removed.
    ///
    /// # Panics
    /// Panics if the connection mutex is poisoned.
    pub fn clear_all(&self) -> Result<usize> {
        let db = self.conn.lock().unwrap();
        let removed = db.execute("DELETE FROM queries", [])?;
        Ok(removed)
    }

    /// Shrink the cache when the DB file at `path` exceeds `max_bytes`.
    ///
    /// Each call that finds the file over the limit deletes the
    /// least-recently-used 20% of entries (rounded up) and then compacts the
    /// database. It performs a single eviction round and does not loop until
    /// the file is under the limit.
    ///
    /// Compaction matters: deleting rows only moves pages to SQLite's free
    /// list and leaves the file the same size. Without `VACUUM` the size
    /// check would stay true on every later call and the cache would be
    /// drained 20% at a time until empty.
    ///
    /// If the file's metadata cannot be read, its size is treated as 0 and
    /// nothing is evicted.
    ///
    /// # Errors
    /// Fails if counting, deleting, or compacting fails.
    ///
    /// # Panics
    /// Panics if the connection mutex is poisoned.
    pub fn evict_lru_if_needed(&self, path: &str, max_bytes: u64) -> Result<()> {
        let size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
        if size <= max_bytes {
            return Ok(());
        }

        let conn = self.conn.lock().unwrap();
        // Delete the least-recently-used 20% of entries.
        let total: i64 = conn.query_row("SELECT COUNT(*) FROM queries", [], |r| r.get(0))?;
        let to_delete = (total as f64 * 0.20).ceil() as i64;

        conn.execute(
            "DELETE FROM queries WHERE id IN (
                SELECT id FROM queries ORDER BY last_used_at ASC LIMIT ?1
            )",
            params![to_delete],
        )?;

        // Return freed pages to the filesystem; the checkpoint truncates the
        // WAL so the space reclaimed by VACUUM is reflected in the main file.
        conn.execute_batch(
            "VACUUM;
             PRAGMA wal_checkpoint(TRUNCATE);",
        )?;

        Ok(())
    }

    // ── Stats ────────────────────────────────────────────────────────────────

    /// Collect aggregate statistics for the whole cache.
    ///
    /// Row counts and token/hit sums come from one SQL query; `db_bytes` is
    /// the size of the file at `path`, or 0 when its metadata cannot be read.
    /// NOTE(review): `path` is presumably the path the cache was opened
    /// with — confirm at the call site.
    ///
    /// # Errors
    /// Fails if the aggregate query cannot be executed.
    ///
    /// # Panics
    /// Panics if the connection mutex is poisoned.
    pub fn stats(&self, path: &str) -> Result<CacheStats> {
        let db = self.conn.lock().unwrap();

        // Fill the counters straight from the result row; the file size is
        // patched in afterwards.
        let mut summary = db
            .query_row(
                "SELECT
                    (SELECT COUNT(*) FROM queries),
                    (SELECT COUNT(*) FROM chunks),
                    COALESCE((SELECT SUM(hit_count)       FROM queries), 0),
                    COALESCE((SELECT SUM(original_tokens) FROM queries), 0),
                    COALESCE((SELECT SUM(returned_tokens) FROM queries), 0)",
                [],
                |row| {
                    Ok(CacheStats {
                        entries: row.get(0)?,
                        total_chunks: row.get(1)?,
                        cache_hits: row.get(2)?,
                        tokens_original: row.get(3)?,
                        tokens_returned: row.get(4)?,
                        db_bytes: 0,
                    })
                },
            )?;

        summary.db_bytes = match std::fs::metadata(path) {
            Ok(meta) => meta.len(),
            Err(_) => 0,
        };

        Ok(summary)
    }

    // ── Private ───────────────────────────────────────────────────────────────

    fn load_chunks(&self, conn: &Connection, query_id: i64) -> Result<Vec<Chunk>> {
        let mut stmt = conn
            .prepare("SELECT text, embedding FROM chunks WHERE query_id = ?1 ORDER BY position")?;

        let chunks = stmt
            .query_map(params![query_id], |row| {
                Ok((row.get::<_, String>(0)?, row.get::<_, Vec<u8>>(1)?))
            })?
            .filter_map(|r| r.ok())
            .map(|(text, blob)| Chunk {
                text,
                embedding: bytes_to_f32(&blob),
            })
            .collect();

        Ok(chunks)
    }
}

// ── Migration ────────────────────────────────────────────────────────────────

/// Apply schema migrations using SQLite's `user_version` pragma as a version
/// counter. Each migration is applied exactly once in order; new migrations
/// are appended to the `MIGRATIONS` slice.
fn migrate(conn: &Connection) -> Result<()> {
    // Each entry: (target_version, SQL to apply).
    const MIGRATIONS: &[(u32, &str)] = &[(
        1,
        "CREATE TABLE IF NOT EXISTS queries (
             id               INTEGER PRIMARY KEY AUTOINCREMENT,
             query_text       TEXT    NOT NULL,
             query_embedding  BLOB    NOT NULL,
             original_tokens  INTEGER NOT NULL DEFAULT 0,
             returned_tokens  INTEGER NOT NULL DEFAULT 0,
             hit_count        INTEGER NOT NULL DEFAULT 0,
             created_at       INTEGER NOT NULL,
             last_used_at     INTEGER NOT NULL
         );
         CREATE TABLE IF NOT EXISTS chunks (
             id         INTEGER PRIMARY KEY AUTOINCREMENT,
             query_id   INTEGER NOT NULL REFERENCES queries(id) ON DELETE CASCADE,
             text       TEXT    NOT NULL,
             embedding  BLOB    NOT NULL,
             position   INTEGER NOT NULL
         );
         CREATE INDEX IF NOT EXISTS idx_chunks_query    ON chunks(query_id);
         CREATE INDEX IF NOT EXISTS idx_queries_last    ON queries(last_used_at);
         CREATE INDEX IF NOT EXISTS idx_queries_created ON queries(created_at);",
    )];

    let current: u32 = conn.query_row("PRAGMA user_version", [], |r| r.get(0))?;

    for &(version, sql) in MIGRATIONS {
        if current < version {
            conn.execute_batch(sql)?;
            conn.execute_batch(&format!("PRAGMA user_version = {version}"))?;
        }
    }

    debug_assert_eq!(
        {
            let v: u32 = conn.query_row("PRAGMA user_version", [], |r| r.get(0))?;
            v
        },
        SCHEMA_VERSION,
        "schema version mismatch after migration"
    );

    Ok(())
}

// ── Helpers ──────────────────────────────────────────────────────────────────

/// Current wall-clock time as whole seconds relative to the Unix epoch.
///
/// Never panics: if the system clock is set before 1970 the result is the
/// (negative) number of seconds before the epoch. Values that do not fit in
/// an `i64` are clamped.
fn unix_now() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    // Lossless after clamping to the i64 range.
    let clamp = |secs: u64| secs.min(i64::MAX as u64) as i64;

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => clamp(since_epoch.as_secs()),
        Err(before_epoch) => -clamp(before_epoch.duration().as_secs()),
    }
}

/// Serialise a slice of `f32` values into a contiguous byte buffer, four
/// little-endian bytes per value, in input order.
fn f32_to_bytes(v: &[f32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(v.len() * std::mem::size_of::<f32>());
    for value in v {
        out.extend_from_slice(&value.to_le_bytes());
    }
    out
}

/// Decode a byte buffer into `f32` values, reading four little-endian bytes
/// per value. Trailing bytes that do not form a complete group of four are
/// ignored.
fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        floats.push(f32::from_le_bytes(raw));
    }
    floats
}