// spool-memory 0.2.3

// Local-first developer memory system — persistent, structured knowledge for AI coding tools
// Documentation
//! Local embedding-based semantic retrieval over lifecycle records.
//!
//! Feature-gated behind `embedding`. Uses fastembed-rs (ONNX Runtime) to run
//! a local embedding model (default: BAAI/bge-small-zh-v1.5).
//!
//! The index is a simple `Vec<(RecordId, Vec<f32>)>` with brute-force cosine
//! similarity — sufficient for up to ~50K records (< 10ms search).
//!
//! Storage: `<config_dir>/.spool/embedding-index.bin`

use anyhow::{Context, Result, bail};
use std::path::Path;

use crate::domain::MemoryRecord;

/// Dimension assumed for an index built before any embedding has been produced.
///
/// Matches the default model (bge-small-zh-v1.5, 512 components) so that an
/// index created from zero records still accepts embeddings from that model.
const DEFAULT_DIM: usize = 512;

/// In-memory embedding index: a flat list of `(record id, embedding)` pairs
/// sharing one fixed dimension.
#[derive(Debug, Clone)]
pub struct EmbeddingIndex {
    /// Stored `(record id, embedding)` pairs, in insertion order.
    entries: Vec<(String, Vec<f32>)>,
    /// Number of `f32` components each embedding is expected to have.
    dim: usize,
}

impl EmbeddingIndex {
    pub fn new(dim: usize) -> Self {
        Self {
            entries: Vec::new(),
            dim,
        }
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn dim(&self) -> usize {
        self.dim
    }

    pub fn load(path: &Path) -> Result<Self> {
        let data = std::fs::read(path)
            .with_context(|| format!("open embedding index: {}", path.display()))?;
        if data.len() < 8 {
            bail!("embedding index too small");
        }
        let entry_count = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
        let dim = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
        if dim == 0 {
            bail!("embedding index has zero dimension");
        }

        let mut entries = Vec::with_capacity(entry_count);
        let mut offset = 8;
        for _ in 0..entry_count {
            if offset + 4 > data.len() {
                break;
            }
            let id_len = u32::from_le_bytes([
                data[offset],
                data[offset + 1],
                data[offset + 2],
                data[offset + 3],
            ]) as usize;
            offset += 4;
            if offset + id_len > data.len() {
                break;
            }
            let record_id = String::from_utf8_lossy(&data[offset..offset + id_len]).to_string();
            offset += id_len;
            let vec_bytes = dim * 4;
            if offset + vec_bytes > data.len() {
                break;
            }
            let mut embedding = vec![0f32; dim];
            for i in 0..dim {
                let b = offset + i * 4;
                embedding[i] = f32::from_le_bytes([data[b], data[b + 1], data[b + 2], data[b + 3]]);
            }
            offset += vec_bytes;
            entries.push((record_id, embedding));
        }
        Ok(Self { entries, dim })
    }

    pub fn save(&self, path: &Path) -> Result<()> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let mut buf = Vec::new();
        buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
        buf.extend_from_slice(&(self.dim as u32).to_le_bytes());
        for (id, emb) in &self.entries {
            let id_bytes = id.as_bytes();
            buf.extend_from_slice(&(id_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(id_bytes);
            for &val in emb {
                buf.extend_from_slice(&val.to_le_bytes());
            }
        }
        std::fs::write(path, &buf)
            .with_context(|| format!("write embedding index: {}", path.display()))?;
        Ok(())
    }

    pub fn add(&mut self, record_id: &str, embedding: Vec<f32>) {
        if embedding.len() == self.dim {
            self.entries.retain(|(id, _)| id != record_id);
            self.entries.push((record_id.to_string(), embedding));
        }
    }

    pub fn search(&self, query_embedding: &[f32], limit: usize) -> Vec<(String, f32)> {
        if query_embedding.len() != self.dim || self.entries.is_empty() {
            return Vec::new();
        }
        let mut scores: Vec<(String, f32)> = self
            .entries
            .iter()
            .map(|(id, emb)| (id.clone(), cosine_similarity(query_embedding, emb)))
            .collect();
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(limit);
        scores
    }

    pub fn build_from_records_with_model(
        records: &[(String, &MemoryRecord)],
        model: &fastembed::TextEmbedding,
    ) -> Result<Self> {
        let texts: Vec<String> = records
            .iter()
            .map(|(_, r)| format!("{}: {}. {}", r.memory_type, r.title, r.summary))
            .collect();
        let embeddings = model.embed(texts, None)?;
        let dim = embeddings.first().map(|e| e.len()).unwrap_or(DEFAULT_DIM);
        let mut index = Self::new(dim);
        for (i, (id, _)) in records.iter().enumerate() {
            if let Some(emb) = embeddings.get(i) {
                index.entries.push((id.clone(), emb.clone()));
            }
        }
        Ok(index)
    }

    pub fn embed_query(model: &fastembed::TextEmbedding, query: &str) -> Result<Vec<f32>> {
        let results = model.embed(vec![query.to_string()], None)?;
        results.into_iter().next().context("no embedding returned")
    }
}

/// Adds one record to the on-disk embedding index, ignoring every failure.
///
/// Does nothing unless both `enabled` and `auto_index` are set in `config` and
/// the index file already exists. The record is embedded as
/// `"{memory_type}: {title}. {summary}"` with the process-wide cached model,
/// inserted under `record_id`, and the whole index is written back.
/// Any error along the way (unreadable index, model unavailable, embedding or
/// save failure) ends the call silently.
pub fn try_append_record(
    config: &crate::config::EmbeddingConfig,
    record_id: &str,
    record: &crate::domain::MemoryRecord,
) {
    let indexing_active = config.enabled && config.auto_index;
    if !indexing_active {
        return;
    }

    let index_path = config.resolved_index_path();
    if !index_path.exists() {
        return;
    }
    let Ok(mut index) = EmbeddingIndex::load(&index_path) else {
        return;
    };

    let Some(model) = cached_model_for(config.model_id.as_deref()) else {
        return;
    };
    let text = format!(
        "{}: {}. {}",
        record.memory_type, record.title, record.summary
    );
    let Ok(vectors) = model.embed(vec![text], None) else {
        return;
    };
    let Some(emb) = vectors.into_iter().next() else {
        return;
    };

    index.add(record_id, emb);
    // Best effort: a failed write is deliberately ignored.
    let _ = index.save(&index_path);
}

/// Returns the process-global cached embedding model, requesting the default
/// model (no explicit model id) if nothing has been loaded yet.
///
/// Delegates to [`cached_model_for`] with `None`. Because that cache keeps
/// whichever model was requested first, this may return a model selected by an
/// earlier call with an explicit id.
///
/// Returns `None` if model initialization failed (model not downloaded, etc.).
pub fn cached_model() -> Option<&'static fastembed::TextEmbedding> {
    cached_model_for(None)
}

/// Returns the process-wide embedding model, initializing it on first call.
///
/// The model lives in a single `OnceLock`, so `model_id` only matters on the
/// very first call; later calls get the already-stored result whatever id they
/// pass. A failed initialization is stored as well, so `None` is returned for
/// the rest of the process lifetime once loading has failed.
pub fn cached_model_for(model_id: Option<&str>) -> Option<&'static fastembed::TextEmbedding> {
    use std::sync::OnceLock;
    static MODEL: OnceLock<Option<fastembed::TextEmbedding>> = OnceLock::new();

    let load_model = || {
        let options = fastembed::InitOptions::new(resolve_model_variant(model_id))
            .with_show_download_progress(false);
        fastembed::TextEmbedding::try_new(options).ok()
    };
    MODEL.get_or_init(load_model).as_ref()
}

/// Map user-facing model_id string to fastembed enum variant.
pub fn resolve_model_variant(model_id: Option<&str>) -> fastembed::EmbeddingModel {
    match model_id {
        Some("bge-small-zh-v1.5" | "bge-small-zh") => fastembed::EmbeddingModel::BGESmallZHV15,
        Some("bge-large-zh-v1.5" | "bge-large-zh") => fastembed::EmbeddingModel::BGELargeZHV15,
        Some("all-MiniLM-L6-v2" | "minilm") => fastembed::EmbeddingModel::AllMiniLML6V2,
        Some("nomic-embed-text-v1.5" | "nomic") => fastembed::EmbeddingModel::NomicEmbedTextV15,
        Some("multilingual-e5-small" | "e5-small") => {
            fastembed::EmbeddingModel::MultilingualE5Small
        }
        Some("multilingual-e5-large" | "e5-large") => {
            fastembed::EmbeddingModel::MultilingualE5Large
        }
        Some("bge-small-en-v1.5" | "bge-small-en") => fastembed::EmbeddingModel::BGESmallENV15,
        _ => fastembed::EmbeddingModel::BGESmallZHV15,
    }
}

/// Returns the embedding dimension associated with a model name or alias.
///
/// `None` and unrecognized names report 512, the same value as the
/// bge-small-zh entries.
pub fn model_dimensions(model_id: Option<&str>) -> usize {
    // The empty string matches no known name, so `None` falls through to the default arm.
    match model_id.unwrap_or("") {
        "all-MiniLM-L6-v2"
        | "minilm"
        | "multilingual-e5-small"
        | "e5-small"
        | "bge-small-en-v1.5"
        | "bge-small-en" => 384,
        "bge-small-zh-v1.5" | "bge-small-zh" => 512,
        "nomic-embed-text-v1.5" | "nomic" => 768,
        "bge-large-zh-v1.5" | "bge-large-zh" | "multilingual-e5-large" | "e5-large" => 1024,
        _ => 512,
    }
}

/// Computes the cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length (they are not comparable)
/// or when either vector is so close to zero that the product of the norms
/// is below `1e-10`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Single pass accumulating the dot product and both squared norms.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, norm_a, norm_b), (&x, &y)| (dot + x * y, norm_a + x * x, norm_b + y * y),
    );
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom < 1e-10 { 0.0 } else { dot / denom }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory unique to this process and test, removed on drop so
    /// it is cleaned up even when an assertion fails.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(tag: &str) -> Self {
            let name = format!("spool-emb-test-{}-{}", std::process::id(), tag);
            let dir = std::env::temp_dir().join(name);
            std::fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }

        fn file(&self, name: &str) -> std::path::PathBuf {
            self.0.join(name)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-5);
    }

    #[test]
    fn index_save_load_roundtrip() {
        let mut index = EmbeddingIndex::new(3);
        index.add("rec-1", vec![0.1, 0.2, 0.3]);
        index.add("rec-2", vec![0.4, 0.5, 0.6]);

        let dir = ScratchDir::new("roundtrip");
        let path = dir.file("test-index.bin");
        index.save(&path).unwrap();

        let loaded = EmbeddingIndex::load(&path).unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.dim(), 3);

        let results = loaded.search(&[0.1, 0.2, 0.3], 2);
        assert_eq!(results[0].0, "rec-1");
        assert!(results[0].1 > 0.99);
    }

    #[test]
    fn load_rejects_truncated_header() {
        let dir = ScratchDir::new("short-header");
        let path = dir.file("short.bin");
        std::fs::write(&path, [1u8, 0, 0]).unwrap();
        assert!(EmbeddingIndex::load(&path).is_err());
    }

    #[test]
    fn load_rejects_zero_dimension() {
        let dir = ScratchDir::new("zero-dim");
        let path = dir.file("zero-dim.bin");
        std::fs::write(&path, [0u8; 8]).unwrap();
        assert!(EmbeddingIndex::load(&path).is_err());
    }

    #[test]
    fn load_keeps_complete_entries_of_truncated_file() {
        let mut index = EmbeddingIndex::new(2);
        index.add("a", vec![1.0, 0.0]);
        index.add("b", vec![0.0, 1.0]);

        let dir = ScratchDir::new("truncated");
        let path = dir.file("truncated.bin");
        index.save(&path).unwrap();

        // Cut into the last embedding so only the first entry is complete.
        let bytes = std::fs::read(&path).unwrap();
        std::fs::write(&path, &bytes[..bytes.len() - 2]).unwrap();

        let loaded = EmbeddingIndex::load(&path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded.dim(), 2);
    }

    #[test]
    fn add_replaces_existing_id() {
        let mut index = EmbeddingIndex::new(2);
        index.add("a", vec![1.0, 0.0]);
        index.add("a", vec![0.0, 1.0]);
        assert_eq!(index.len(), 1);

        let results = index.search(&[0.0, 1.0], 1);
        assert_eq!(results[0].0, "a");
        assert!(results[0].1 > 0.99);
    }

    #[test]
    fn add_ignores_wrong_dimension() {
        let mut index = EmbeddingIndex::new(3);
        index.add("short", vec![1.0, 2.0]);
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn search_with_wrong_query_dimension_is_empty() {
        let mut index = EmbeddingIndex::new(3);
        index.add("a", vec![1.0, 0.0, 0.0]);
        assert!(index.search(&[1.0, 0.0], 5).is_empty());
    }

    #[test]
    fn search_returns_most_similar() {
        let mut index = EmbeddingIndex::new(3);
        index.add("database", vec![0.9, 0.1, 0.0]);
        index.add("frontend", vec![0.0, 0.1, 0.9]);
        index.add("db-related", vec![0.8, 0.2, 0.1]);

        let results = index.search(&[1.0, 0.0, 0.0], 3);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, "database");
        assert_eq!(results[1].0, "db-related");
        assert_eq!(results[2].0, "frontend");
    }
}