use anyhow::{Context, Result, bail};
use std::path::Path;
use crate::domain::MemoryRecord;
/// Embedding width assumed when it cannot be read off an actual embedding.
const DEFAULT_DIM: usize = 384;

/// Record-id → embedding store, searched by brute-force cosine similarity.
pub struct EmbeddingIndex {
    /// Length every stored embedding (and every query) must have.
    dim: usize,
    /// `(record id, embedding)` pairs, in insertion order.
    entries: Vec<(String, Vec<f32>)>,
}
impl EmbeddingIndex {
pub fn new(dim: usize) -> Self {
Self {
entries: Vec::new(),
dim,
}
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn load(path: &Path) -> Result<Self> {
let data = std::fs::read(path)
.with_context(|| format!("open embedding index: {}", path.display()))?;
if data.len() < 8 {
bail!("embedding index too small");
}
let entry_count = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
let dim = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
if dim == 0 {
bail!("embedding index has zero dimension");
}
let mut entries = Vec::with_capacity(entry_count);
let mut offset = 8;
for _ in 0..entry_count {
if offset + 4 > data.len() {
break;
}
let id_len = u32::from_le_bytes([
data[offset],
data[offset + 1],
data[offset + 2],
data[offset + 3],
]) as usize;
offset += 4;
if offset + id_len > data.len() {
break;
}
let record_id = String::from_utf8_lossy(&data[offset..offset + id_len]).to_string();
offset += id_len;
let vec_bytes = dim * 4;
if offset + vec_bytes > data.len() {
break;
}
let mut embedding = vec![0f32; dim];
for i in 0..dim {
let b = offset + i * 4;
embedding[i] = f32::from_le_bytes([data[b], data[b + 1], data[b + 2], data[b + 3]]);
}
offset += vec_bytes;
entries.push((record_id, embedding));
}
Ok(Self { entries, dim })
}
pub fn save(&self, path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let mut buf = Vec::new();
buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
buf.extend_from_slice(&(self.dim as u32).to_le_bytes());
for (id, emb) in &self.entries {
let id_bytes = id.as_bytes();
buf.extend_from_slice(&(id_bytes.len() as u32).to_le_bytes());
buf.extend_from_slice(id_bytes);
for &val in emb {
buf.extend_from_slice(&val.to_le_bytes());
}
}
std::fs::write(path, &buf)
.with_context(|| format!("write embedding index: {}", path.display()))?;
Ok(())
}
pub fn add(&mut self, record_id: &str, embedding: Vec<f32>) {
if embedding.len() == self.dim {
self.entries.retain(|(id, _)| id != record_id);
self.entries.push((record_id.to_string(), embedding));
}
}
pub fn search(&self, query_embedding: &[f32], limit: usize) -> Vec<(String, f32)> {
if query_embedding.len() != self.dim || self.entries.is_empty() {
return Vec::new();
}
let mut scores: Vec<(String, f32)> = self
.entries
.iter()
.map(|(id, emb)| (id.clone(), cosine_similarity(query_embedding, emb)))
.collect();
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scores.truncate(limit);
scores
}
pub fn build_from_records_with_model(
records: &[(String, &MemoryRecord)],
model: &fastembed::TextEmbedding,
) -> Result<Self> {
let texts: Vec<String> = records
.iter()
.map(|(_, r)| format!("{}: {}. {}", r.memory_type, r.title, r.summary))
.collect();
let embeddings = model.embed(texts, None)?;
let dim = embeddings.first().map(|e| e.len()).unwrap_or(DEFAULT_DIM);
let mut index = Self::new(dim);
for (i, (id, _)) in records.iter().enumerate() {
if let Some(emb) = embeddings.get(i) {
index.entries.push((id.clone(), emb.clone()));
}
}
Ok(index)
}
pub fn embed_query(model: &fastembed::TextEmbedding, query: &str) -> Result<Vec<f32>> {
let results = model.embed(vec![query.to_string()], None)?;
results.into_iter().next().context("no embedding returned")
}
}
/// Best-effort incremental indexing: embeds `record` and upserts it into the on-disk
/// embedding index under `record_id`.
///
/// Does nothing when embeddings or auto-indexing are disabled, or when the index file
/// does not exist yet (this function never creates one). It also does nothing when
/// loading the index, obtaining the model, or embedding the text fails. Save errors are
/// discarded as well, so this never reports failure to the caller.
pub fn try_append_record(
    config: &crate::config::EmbeddingConfig,
    record_id: &str,
    record: &crate::domain::MemoryRecord,
) {
    let indexing_wanted = config.enabled && config.auto_index;
    if !indexing_wanted {
        return;
    }

    let index_path = config.resolved_index_path();
    if !index_path.exists() {
        return;
    }

    let Ok(mut index) = EmbeddingIndex::load(&index_path) else {
        return;
    };
    let Some(model) = cached_model_for(config.model_id.as_deref()) else {
        return;
    };

    let text = format!("{}: {}. {}", record.memory_type, record.title, record.summary);
    let Ok(embedding) = EmbeddingIndex::embed_query(model, &text) else {
        return;
    };

    index.add(record_id, embedding);
    // Persisting is best-effort by design; the error is intentionally dropped.
    let _ = index.save(&index_path);
}
/// Returns the shared embedding model for the default model id (no id configured).
///
/// Delegates to [`cached_model_for`]; yields `None` when the model could not be loaded.
pub fn cached_model() -> Option<&'static fastembed::TextEmbedding> {
    let default_model_id: Option<&str> = None;
    cached_model_for(default_model_id)
}
/// Returns the process-wide embedding model for `model_id`, loading it on first use.
///
/// `model_id` is resolved with [`resolve_model_variant`], so aliases share one instance,
/// as do `None` and unknown ids (both fall back to the default model). Each resolved
/// model is loaded at most once per process. A success is kept for the rest of the
/// process, and a failure is remembered as `None` so a failed load or download is not
/// retried on every call.
pub fn cached_model_for(model_id: Option<&str>) -> Option<&'static fastembed::TextEmbedding> {
    use std::collections::HashMap;
    use std::sync::{Mutex, OnceLock, PoisonError};

    // Keyed by resolved variant: one global slot would hand whichever model was loaded
    // first to every caller, mixing embedding spaces and widths.
    type ModelCache = Mutex<HashMap<String, Option<&'static fastembed::TextEmbedding>>>;
    static MODELS: OnceLock<ModelCache> = OnceLock::new();

    let variant = resolve_model_variant(model_id);
    let key = format!("{variant:?}");
    let mut models = MODELS
        .get_or_init(|| Mutex::new(HashMap::new()))
        .lock()
        .unwrap_or_else(PoisonError::into_inner);
    let slot = models.entry(key).or_insert_with(|| {
        fastembed::TextEmbedding::try_new(
            fastembed::InitOptions::new(variant).with_show_download_progress(false),
        )
        .ok()
        // Leaked on purpose to obtain a `'static` reference. The leak is bounded by
        // the number of distinct model variants.
        .map(|model| &*Box::leak(Box::new(model)))
    });
    *slot
}
pub fn resolve_model_variant(model_id: Option<&str>) -> fastembed::EmbeddingModel {
match model_id {
Some("bge-small-zh-v1.5" | "bge-small-zh") => fastembed::EmbeddingModel::BGESmallZHV15,
Some("bge-large-zh-v1.5" | "bge-large-zh") => fastembed::EmbeddingModel::BGELargeZHV15,
Some("all-MiniLM-L6-v2" | "minilm") => fastembed::EmbeddingModel::AllMiniLML6V2,
Some("nomic-embed-text-v1.5" | "nomic") => fastembed::EmbeddingModel::NomicEmbedTextV15,
Some("multilingual-e5-small" | "e5-small") => {
fastembed::EmbeddingModel::MultilingualE5Small
}
Some("multilingual-e5-large" | "e5-large") => {
fastembed::EmbeddingModel::MultilingualE5Large
}
Some("bge-small-en-v1.5" | "bge-small-en") => fastembed::EmbeddingModel::BGESmallENV15,
_ => fastembed::EmbeddingModel::BGESmallZHV15,
}
}
/// Output width of the embedding model selected by `model_id` (or one of its aliases).
///
/// Absent or unrecognised ids report 512, the width listed for the default
/// `bge-small-zh` model.
pub fn model_dimensions(model_id: Option<&str>) -> usize {
    const FALLBACK: usize = 512;
    const KNOWN: &[(&[&str], usize)] = &[
        (&["bge-small-zh-v1.5", "bge-small-zh"], 512),
        (&["bge-large-zh-v1.5", "bge-large-zh"], 1024),
        (&["all-MiniLM-L6-v2", "minilm"], 384),
        (&["nomic-embed-text-v1.5", "nomic"], 768),
        (&["multilingual-e5-small", "e5-small"], 384),
        (&["multilingual-e5-large", "e5-large"], 1024),
        (&["bge-small-en-v1.5", "bge-small-en"], 384),
    ];

    let Some(id) = model_id else {
        return FALLBACK;
    };
    KNOWN
        .iter()
        .find(|(aliases, _)| aliases.contains(&id))
        .map_or(FALLBACK, |&(_, dims)| dims)
}
/// Cosine similarity of `a` and `b`: their dot product divided by the product of their
/// Euclidean norms (about `[-1, 1]` for finite inputs).
///
/// Returns `0.0` when the norm product is below `1e-10`, i.e. when either vector is
/// (near-)zero and the angle is undefined.
///
/// Callers must pass equal-length slices. A mismatch trips a debug assertion; in
/// release builds the extra components of the longer slice are ignored. The former
/// index loop panicked out-of-bounds when `b` was the shorter slice.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "cosine_similarity: length mismatch");
    // Single pass in index order, so results match the sequential index loop exactly,
    // without per-element bounds checks.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom < 1e-10 {
        0.0
    } else {
        dot / denom
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory that is unique per process and per test, and is removed on drop
    /// (also when an assertion fails). Concurrent test runs therefore never share files.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(tag: &str) -> Self {
            let dir = std::env::temp_dir()
                .join(format!("spool-emb-test-{}-{tag}", std::process::id()));
            std::fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-5);
    }

    #[test]
    fn index_save_load_roundtrip() {
        let mut index = EmbeddingIndex::new(3);
        index.add("rec-1", vec![0.1, 0.2, 0.3]);
        index.add("rec-2", vec![0.4, 0.5, 0.6]);
        let dir = ScratchDir::new("roundtrip");
        let path = dir.path().join("test-index.bin");
        index.save(&path).unwrap();
        let loaded = EmbeddingIndex::load(&path).unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.dim(), 3);
        let results = loaded.search(&[0.1, 0.2, 0.3], 2);
        assert_eq!(results[0].0, "rec-1");
        assert!(results[0].1 > 0.99);
    }

    #[test]
    fn load_rejects_short_file_and_zero_dimension() {
        let dir = ScratchDir::new("bad-header");
        let short = dir.path().join("short.bin");
        std::fs::write(&short, [0u8; 4]).unwrap();
        assert!(EmbeddingIndex::load(&short).is_err());

        let zero_dim = dir.path().join("zero-dim.bin");
        std::fs::write(&zero_dim, [1u8, 0, 0, 0, 0, 0, 0, 0]).unwrap();
        assert!(EmbeddingIndex::load(&zero_dim).is_err());
    }

    #[test]
    fn load_keeps_complete_entries_of_truncated_file() {
        let mut index = EmbeddingIndex::new(2);
        index.add("a", vec![1.0, 0.0]);
        index.add("b", vec![0.0, 1.0]);
        let dir = ScratchDir::new("truncated");
        let path = dir.path().join("index.bin");
        index.save(&path).unwrap();

        let bytes = std::fs::read(&path).unwrap();
        std::fs::write(&path, &bytes[..bytes.len() - 3]).unwrap();
        let loaded = EmbeddingIndex::load(&path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded.search(&[1.0, 0.0], 1)[0].0, "a");
    }

    #[test]
    fn search_returns_most_similar() {
        let mut index = EmbeddingIndex::new(3);
        index.add("database", vec![0.9, 0.1, 0.0]);
        index.add("frontend", vec![0.0, 0.1, 0.9]);
        index.add("db-related", vec![0.8, 0.2, 0.1]);
        let results = index.search(&[1.0, 0.0, 0.0], 3);
        assert_eq!(results[0].0, "database");
        assert_eq!(results[1].0, "db-related");
    }

    #[test]
    fn search_respects_limit_and_query_dimension() {
        let mut index = EmbeddingIndex::new(2);
        index.add("a", vec![1.0, 0.0]);
        index.add("b", vec![0.0, 1.0]);
        assert_eq!(index.search(&[1.0, 0.0], 1).len(), 1);
        assert_eq!(index.search(&[1.0, 0.0], 10).len(), 2);
        assert!(index.search(&[1.0, 0.0], 0).is_empty());
        assert!(index.search(&[1.0, 0.0, 0.0], 5).is_empty());
    }

    #[test]
    fn search_keeps_insertion_order_for_equal_scores() {
        let mut index = EmbeddingIndex::new(2);
        index.add("first", vec![1.0, 0.0]);
        index.add("second", vec![2.0, 0.0]);
        index.add("third", vec![3.0, 0.0]);
        let ids: Vec<String> = index
            .search(&[1.0, 0.0], 2)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert_eq!(ids, ["first", "second"]);
    }

    #[test]
    fn add_replaces_existing_id_and_ignores_wrong_width() {
        let mut index = EmbeddingIndex::new(2);
        index.add("a", vec![1.0, 0.0]);
        index.add("a", vec![0.0, 1.0]);
        index.add("b", vec![1.0]);
        assert_eq!(index.len(), 1);
        let results = index.search(&[0.0, 1.0], 1);
        assert_eq!(results[0].0, "a");
        assert!(results[0].1 > 0.99);
    }
}