use anyhow::{Context, Result};
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};
use hf_hub::{Repo, RepoType, api::sync::Api};
use std::sync::{Arc, Mutex};
use tokenizers::Tokenizer;
use crate::config::EmbeddingModel;
/// Hugging Face Hub repository of the local embedding model.
const MINILM_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Expected length of vectors from the local MiniLM backend (as reported by `Embedder::dim`).
// Within this file it is only read by items that are themselves `allow(dead_code)`
// (`Embedder::dim`, `EMBEDDING_DIM`) or by test code, so the lint would fire here too.
#[allow(dead_code)]
const MINILM_DIM: usize = 384;
/// Token limit passed to the local tokenizer's truncation setting.
const MAX_SEQ_LEN: usize = 256;
/// Directory, relative to the user's home directory, probed for the model files
/// when the hf-hub download fails.
const FALLBACK_MODEL_SUBDIR: &str =
    ".cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/main";
/// Ollama model name used by the nomic embedding backend.
const NOMIC_OLLAMA_MODEL: &str = "nomic-embed-text";
/// Expected length of vectors from the Ollama nomic backend (as reported by `Embedder::dim`).
// Same situation as `MINILM_DIM`: only read by dead-code-allowed items and tests.
#[allow(dead_code)]
const NOMIC_DIM: usize = 768;
/// Text embedder backed either by an in-process BERT model (candle) or by an
/// Ollama server.
///
/// The model, tokenizer and Ollama client are held in `Arc`s, so clones share them.
#[derive(Clone)]
pub enum Embedder {
    /// all-MiniLM-L6-v2 executed locally.
    Local {
        /// BERT weights; the mutex serialises forward passes across all clones.
        model: Arc<Mutex<BertModel>>,
        /// Tokenizer; as built by `new_local` it truncates and does not pad.
        tokenizer: Arc<Tokenizer>,
        /// Device the input tensors are created on (CPU when built by `new_local`).
        device: Device,
    },
    /// Embeddings requested from an Ollama server.
    Ollama {
        /// Shared client used for `embed_text` calls.
        client: Arc<crate::llm::OllamaClient>,
        /// Ollama model name passed to `embed_text`.
        model_name: String,
    },
}
impl Embedder {
    /// Creates the default embedder: the local all-MiniLM-L6-v2 model.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Embedder::new_local`].
    #[allow(dead_code)]
    pub fn new() -> Result<Self> {
        Self::new_local()
    }

    /// Loads all-MiniLM-L6-v2 for CPU inference.
    ///
    /// Model files come from hf-hub; if that fails, the error is printed to stderr
    /// and the files are looked up on disk via [`Embedder::load_from_fallback`].
    /// The tokenizer truncates to `MAX_SEQ_LEN` tokens and does not pad.
    ///
    /// # Errors
    ///
    /// Fails when neither hf-hub nor the fallback directory provides the files, or
    /// when the config, tokenizer or weights cannot be read or parsed.
    pub fn new_local() -> Result<Self> {
        let device = Device::Cpu;
        let (config_path, tokenizer_path, weights_path) = match Self::download_via_hf_hub() {
            Ok(paths) => paths,
            Err(e) => {
                eprintln!("ai-memory: hf-hub download failed ({e}), trying fallback dir");
                Self::load_from_fallback()?
            }
        };

        let config_data =
            std::fs::read_to_string(&config_path).context("failed to read config.json")?;
        let config: Config =
            serde_json::from_str(&config_data).context("failed to parse config.json")?;

        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
        let truncation = tokenizers::TruncationParams {
            max_length: MAX_SEQ_LEN,
            ..Default::default()
        };
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
        // Texts are encoded one at a time (see `embed_local`), so padding is never needed.
        tokenizer.with_padding(None);

        // SAFETY: `from_mmaped_safetensors` memory-maps the weights file; this is only
        // sound if the file is not modified or truncated while the mapping is alive.
        // NOTE(review): nothing here enforces that; we rely on the hf-hub cache or the
        // fallback directory not being rewritten while the process runs.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
                .context("failed to load model weights")?
        };
        let model = BertModel::load(vb, &config).context("failed to build BertModel")?;

        Ok(Self::Local {
            model: Arc::new(Mutex::new(model)),
            tokenizer: Arc::new(tokenizer),
            device,
        })
    }

    /// Wraps an Ollama client that serves `nomic-embed-text` embeddings.
    ///
    /// Does not contact the server. [`Embedder::for_model`] additionally tries to
    /// pull the model first.
    pub fn new_ollama(client: Arc<crate::llm::OllamaClient>) -> Self {
        Self::Ollama {
            client,
            model_name: NOMIC_OLLAMA_MODEL.to_string(),
        }
    }

    /// Builds the embedder for `model`.
    ///
    /// For `NomicEmbedV15`, a failure to pull the model into Ollama is only reported
    /// on stderr; the embedder is still returned.
    ///
    /// # Errors
    ///
    /// Fails if `NomicEmbedV15` is requested without `ollama_client`, or if the
    /// local model cannot be loaded for `MiniLmL6V2`.
    pub fn for_model(
        model: EmbeddingModel,
        ollama_client: Option<Arc<crate::llm::OllamaClient>>,
    ) -> Result<Self> {
        match model {
            EmbeddingModel::MiniLmL6V2 => Self::new_local(),
            EmbeddingModel::NomicEmbedV15 => {
                let client = ollama_client.ok_or_else(|| {
                    anyhow::anyhow!("nomic-embed-text-v1.5 requires Ollama (smart tier or above)")
                })?;
                // Best effort: the model may already be present on the server.
                if let Err(e) = client.ensure_embed_model(NOMIC_OLLAMA_MODEL) {
                    eprintln!("ai-memory: warning: failed to pull nomic model: {e}");
                }
                Ok(Self::new_ollama(client))
            }
        }
    }

    /// Expected length of the vectors returned by [`Embedder::embed`] for this backend.
    #[allow(dead_code)]
    pub fn dim(&self) -> usize {
        match self {
            Self::Local { .. } => MINILM_DIM,
            Self::Ollama { .. } => NOMIC_DIM,
        }
    }

    /// Human-readable name of the backing model.
    pub fn model_description(&self) -> &str {
        match self {
            Self::Local { .. } => "all-MiniLM-L6-v2 (384-dim, local)",
            Self::Ollama { .. } => "nomic-embed-text-v1.5 (768-dim, Ollama)",
        }
    }

    /// Embeds a single text.
    ///
    /// The local backend returns a mean-pooled, L2-normalised vector. The Ollama
    /// backend returns whatever `OllamaClient::embed_text` produces.
    ///
    /// # Errors
    ///
    /// Fails if the model mutex is poisoned, tokenisation or the forward pass fails,
    /// or `embed_text` returns an error.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        match self {
            Self::Local {
                model,
                tokenizer,
                device,
            } => {
                // NOTE(review): `embed_local` only needs `&BertModel`, so this mutex
                // serialises inference without guarding any mutation. Consider
                // `Arc<BertModel>` if `BertModel: Sync`.
                let model_guard = model
                    .lock()
                    .map_err(|e| anyhow::anyhow!("model lock poisoned: {e}"))?;
                Self::embed_local(&model_guard, tokenizer, device, text)
            }
            Self::Ollama { client, model_name } => client.embed_text(text, model_name),
        }
    }

    /// Runs one text through the BERT model and returns its sentence embedding.
    ///
    /// The embedding is the mean of the token vectors, weighted by the attention
    /// mask, then L2-normalised.
    fn embed_local(
        model: &BertModel,
        tokenizer: &Tokenizer,
        device: &Device,
        text: &str,
    ) -> Result<Vec<f32>> {
        let encoding = tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("tokenisation failed: {e}"))?;
        let input_ids = encoding.get_ids();
        let attention_mask = encoding.get_attention_mask();
        let token_type_ids = encoding.get_type_ids();
        let seq_len = input_ids.len();

        // Batch of one: every input tensor is shaped (1, seq_len).
        let input_ids = Tensor::new(input_ids, device)?.reshape((1, seq_len))?;
        let attention_mask_tensor = Tensor::new(attention_mask, device)?.reshape((1, seq_len))?;
        let token_type_ids = Tensor::new(token_type_ids, device)?.reshape((1, seq_len))?;

        // Assumed to be (1, seq_len, hidden_size); the pooling below relies on rank 3.
        let hidden = model
            .forward(&input_ids, &token_type_ids, Some(&attention_mask_tensor))
            .context("model forward pass failed")?;

        // Mean pooling: zero out masked positions, sum over the sequence axis, then
        // divide by the number of unmasked tokens (clamped to avoid dividing by zero).
        let mask = attention_mask_tensor
            .unsqueeze(2)?
            .to_dtype(candle_core::DType::F32)?
            .broadcast_as(hidden.shape())?;
        let masked = hidden.mul(&mask)?;
        let summed = masked.sum(1)?;
        let count = mask.sum(1)?.clamp(1e-9, f64::MAX)?;
        let pooled = summed.div(&count)?;

        // L2 normalisation. The clamp keeps an all-zero vector from producing NaNs.
        let norm = pooled
            .sqr()?
            .sum_keepdim(1)?
            .sqrt()?
            .clamp(1e-12, f64::MAX)?;
        let normalised = pooled.broadcast_div(&norm)?;
        let embedding: Vec<f32> = normalised.squeeze(0)?.to_vec1()?;
        Ok(embedding)
    }

    /// Embeds each text in order with [`Embedder::embed`], stopping at the first error.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`Embedder::embed`].
    #[allow(dead_code)]
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Cosine similarity of `a` and `b`.
    ///
    /// Returns `0.0` when the lengths differ, when both are empty, or when either
    /// vector has (near-)zero norm.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        let denom = norm_a * norm_b;
        if denom < 1e-12 { 0.0 } else { dot / denom }
    }

    /// Element-wise `w * primary + (1 - w) * secondary`, with `w` clamped to `[0, 1]`.
    ///
    /// The result is NOT re-normalised. When the lengths differ, a copy of
    /// `primary` is returned unchanged.
    #[must_use]
    pub fn fuse(primary: &[f32], secondary: &[f32], primary_weight: f32) -> Vec<f32> {
        if primary.len() != secondary.len() {
            return primary.to_vec();
        }
        let w = primary_weight.clamp(0.0, 1.0);
        let one_minus_w = 1.0 - w;
        primary
            .iter()
            .zip(secondary.iter())
            .map(|(p, s)| w * p + one_minus_w * s)
            .collect()
    }

    /// Fetches config, tokenizer and safetensors weights for `MINILM_MODEL_ID`
    /// through the hf-hub sync API and returns their local paths.
    fn download_via_hf_hub() -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
    {
        let api = Api::new().context("failed to initialise HuggingFace Hub API")?;
        let repo = api.repo(Repo::new(MINILM_MODEL_ID.to_string(), RepoType::Model));
        let config_path = repo
            .get("config.json")
            .context("failed to download config.json")?;
        let tokenizer_path = repo
            .get("tokenizer.json")
            .context("failed to download tokenizer.json")?;
        let weights_path = repo
            .get("model.safetensors")
            .context("failed to download model.safetensors")?;
        Ok((config_path, tokenizer_path, weights_path))
    }

    /// Locates the model files on disk without touching the network.
    ///
    /// The base directory is `FALLBACK_MODEL_SUBDIR` under the home directory, which
    /// is taken from `HOME`, then `USERPROFILE`, then `/root`. If that directory is
    /// incomplete, the snapshot named by the cache's `refs/main` file is tried as
    /// well, because the hf-hub cache stores files under `snapshots/<commit>` rather
    /// than `snapshots/main`.
    ///
    /// # Errors
    ///
    /// Fails when no candidate directory contains all three model files.
    fn load_from_fallback() -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
    {
        // `var_os` keeps non-UTF-8 home paths instead of silently falling back to `/root`.
        let home = std::env::var_os("HOME")
            .or_else(|| std::env::var_os("USERPROFILE"))
            .map_or_else(|| std::path::PathBuf::from("/root"), std::path::PathBuf::from);
        let dir = home.join(FALLBACK_MODEL_SUBDIR);

        Self::model_files_in(&dir)
            .or_else(|| {
                Self::snapshot_from_main_ref(&dir)
                    .and_then(|snapshot| Self::model_files_in(&snapshot))
            })
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "model files not found in fallback dir: {}. Download them manually from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
                    dir.display()
                )
            })
    }

    /// Returns the config, tokenizer and weights paths inside `dir` if all three exist.
    fn model_files_in(
        dir: &std::path::Path,
    ) -> Option<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)> {
        let config = dir.join("config.json");
        let tokenizer = dir.join("tokenizer.json");
        let weights = dir.join("model.safetensors");
        (config.exists() && tokenizer.exists() && weights.exists())
            .then_some((config, tokenizer, weights))
    }

    /// Maps `<repo>/snapshots/main` to `<repo>/snapshots/<commit>`, where `<commit>`
    /// is read from `<repo>/refs/main`.
    ///
    /// Assumes `main_dir` ends in `snapshots/main`, as `FALLBACK_MODEL_SUBDIR` does.
    /// Returns `None` when the ref file is missing or does not hold a plain
    /// alphanumeric commit id, so a malformed ref cannot point outside `snapshots/`.
    fn snapshot_from_main_ref(main_dir: &std::path::Path) -> Option<std::path::PathBuf> {
        let snapshots = main_dir.parent()?;
        let repo = snapshots.parent()?;
        let raw = std::fs::read_to_string(repo.join("refs").join("main")).ok()?;
        let commit = raw.trim();
        let is_plain_id = !commit.is_empty() && commit.chars().all(|c| c.is_ascii_alphanumeric());
        is_plain_id.then(|| snapshots.join(commit))
    }
}
/// Embedding length of the default (local MiniLM) embedder.
// NOTE(review): the `allow` suggests no non-test code currently reads this; confirm
// before removing it.
#[allow(dead_code)]
pub const EMBEDDING_DIM: usize = MINILM_DIM;
#[cfg(test)]
mod tests {
    // Unit tests for the pure vector helpers `cosine_similarity` and `fuse`.
    use super::*;

    #[test]
    fn cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let sim = Embedder::cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn cosine_similarity_dimension_mismatch() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn fuse_weighted_sum() {
        let p = vec![1.0, 0.0, 0.0];
        let s = vec![0.0, 1.0, 0.0];
        let f = Embedder::fuse(&p, &s, 0.7);
        assert!((f[0] - 0.7).abs() < 1e-6);
        assert!((f[1] - 0.3).abs() < 1e-6);
        assert!((f[2] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn fuse_primary_weight_clamped() {
        let p = vec![1.0, 1.0];
        let s = vec![0.0, 0.0];
        // Weights above 1 behave like 1 (pure primary)...
        let f = Embedder::fuse(&p, &s, 2.0);
        assert!((f[0] - 1.0).abs() < 1e-6);
        assert!((f[1] - 1.0).abs() < 1e-6);
        // ...and weights below 0 behave like 0 (pure secondary).
        let f = Embedder::fuse(&p, &s, -0.5);
        assert!((f[0] - 0.0).abs() < 1e-6);
        assert!((f[1] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn fuse_dimension_mismatch_returns_primary() {
        let p = vec![1.0, 2.0, 3.0];
        let s = vec![4.0, 5.0];
        let f = Embedder::fuse(&p, &s, 0.7);
        assert_eq!(f, p);
    }

    #[test]
    fn fuse_cosine_pulls_toward_context() {
        // fused = [0.7, 0.3]: still closest to the query, with a visible context component.
        let q = vec![1.0_f32, 0.0];
        let ctx = vec![0.0_f32, 1.0];
        let fused = Embedder::fuse(&q, &ctx, 0.7);
        let sim_q = Embedder::cosine_similarity(&fused, &q);
        let sim_ctx = Embedder::cosine_similarity(&fused, &ctx);
        assert!(sim_q > sim_ctx);
        assert!(sim_q > 0.9);
        assert!(sim_ctx > 0.3);
    }

    #[test]
    fn test_fuse_with_weight_one_returns_primary() {
        let primary = vec![0.6_f32, -0.8, 0.0];
        let secondary = vec![0.0_f32, 0.0, 1.0];
        let fused = Embedder::fuse(&primary, &secondary, 1.0);
        assert_eq!(fused.len(), primary.len());
        for (i, (f, p)) in fused.iter().zip(primary.iter()).enumerate() {
            assert!(
                (f - p).abs() < 1e-6,
                "fuse weight=1 idx {i}: fused {} != primary {}",
                f,
                p
            );
        }
        let sim = Embedder::cosine_similarity(&fused, &primary);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "cos(fuse(p,s,1.0), p) must be 1.0"
        );
    }

    /// Pins the contract that `fuse` does NOT re-normalise its output. Callers that
    /// need a unit vector must normalise it themselves; direction, and therefore
    /// cosine similarity, is unaffected.
    #[test]
    fn test_fuse_is_not_l2_normalized() {
        // [3, 0, 0] and [0, 4, 0] averaged with w = 0.5 give [1.5, 2, 0], norm 2.5.
        let primary = vec![3.0_f32, 0.0, 0.0];
        let secondary = vec![0.0_f32, 4.0, 0.0];
        let fused = Embedder::fuse(&primary, &secondary, 0.5);
        let norm = fused.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 2.5).abs() < 1e-5,
            "fuse currently returns un-normalized vec; norm should be 2.5, got {norm}"
        );
        let normalized: Vec<f32> = fused.iter().map(|x| x / norm).collect();
        let renorm = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (renorm - 1.0).abs() < 1e-5,
            "renormalized fused must have unit norm, got {renorm}"
        );
        let sim = Embedder::cosine_similarity(&fused, &normalized);
        assert!(
            (sim - 1.0).abs() < 1e-5,
            "cos(raw_fuse, normalize(raw_fuse)) must be 1.0, got {sim}"
        );
    }
}
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports
)]
pub mod test_support {
    //! Model-free stand-in for `Embedder`, for tests that need deterministic vectors
    //! of the right length without loading weights or contacting Ollama.
    use super::*;

    /// Fake embedder; the variant only selects the output length.
    pub enum MockEmbedder {
        /// Produces `MINILM_DIM`-long vectors.
        Local,
        /// Produces `NOMIC_DIM`-long vectors.
        Ollama,
    }

    impl MockEmbedder {
        /// Mirrors `Embedder::new_local`'s fallible signature; never fails.
        pub fn new_local() -> Result<Self> {
            Ok(MockEmbedder::Local)
        }

        /// Mirrors `Embedder::new_ollama`, but needs no client.
        pub fn new_ollama() -> Self {
            MockEmbedder::Ollama
        }

        /// Deterministic pseudo-embedding of `text`; never fails.
        ///
        /// A base-31 rolling hash of the bytes is reduced modulo 1000 and scaled
        /// into `[0, 1)`. That offset is added to `|sin(i * 1e-4)|` for each index `i`.
        pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
            let mut hash: u32 = 0;
            for byte in text.bytes() {
                hash = hash.wrapping_mul(31).wrapping_add(u32::from(byte));
            }
            let offset = ((hash % 1000) as f32) / 1000.0;

            let width = self.dim();
            let mut vector = Vec::with_capacity(width);
            for idx in 0..width {
                vector.push(offset + ((idx as f32) * 0.0001).sin().abs());
            }
            Ok(vector)
        }

        /// Embeds each text in order.
        pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            let mut batch = Vec::with_capacity(texts.len());
            for text in texts {
                batch.push(self.embed(text)?);
            }
            Ok(batch)
        }

        /// Output length for this variant.
        pub fn dim(&self) -> usize {
            match self {
                MockEmbedder::Ollama => NOMIC_DIM,
                MockEmbedder::Local => MINILM_DIM,
            }
        }

        /// Human-readable name, prefixed with `mock-` to distinguish it from the real model.
        pub fn model_description(&self) -> &str {
            match self {
                MockEmbedder::Ollama => "mock-nomic-embed-text-v1.5 (768-dim, Ollama)",
                MockEmbedder::Local => "mock-all-MiniLM-L6-v2 (384-dim, local)",
            }
        }
    }
}
#[cfg(test)]
mod mock_tests {
    // Checks the mock's construction, output lengths, determinism and descriptions.
    use super::test_support::*;
    use super::*;

    #[test]
    fn mock_local_new() {
        assert!(MockEmbedder::new_local().is_ok());
    }

    #[test]
    fn mock_ollama_new() {
        assert!(
            matches!(MockEmbedder::new_ollama(), MockEmbedder::Ollama),
            "expected Ollama variant"
        );
    }

    #[test]
    fn mock_local_dim() {
        assert_eq!(MockEmbedder::new_local().unwrap().dim(), MINILM_DIM);
    }

    #[test]
    fn mock_ollama_dim() {
        assert_eq!(MockEmbedder::new_ollama().dim(), NOMIC_DIM);
    }

    #[test]
    fn mock_embed_local_deterministic() {
        let mock = MockEmbedder::new_local().unwrap();
        let first = mock.embed("test").unwrap();
        assert_eq!(first, mock.embed("test").unwrap());
    }

    #[test]
    fn mock_embed_local_dimension() {
        let mock = MockEmbedder::new_local().unwrap();
        assert_eq!(mock.embed("hello world").unwrap().len(), MINILM_DIM);
    }

    #[test]
    fn mock_embed_ollama_dimension() {
        let mock = MockEmbedder::new_ollama();
        assert_eq!(mock.embed("hello world").unwrap().len(), NOMIC_DIM);
    }

    #[test]
    fn mock_embed_batch_local() {
        let batch = MockEmbedder::new_local()
            .unwrap()
            .embed_batch(&["text1", "text2", "text3"])
            .unwrap();
        assert_eq!(batch.len(), 3);
        batch
            .iter()
            .for_each(|emb| assert_eq!(emb.len(), MINILM_DIM));
    }

    #[test]
    fn mock_embed_batch_ollama() {
        let batch = MockEmbedder::new_ollama()
            .embed_batch(&["text1", "text2"])
            .unwrap();
        assert_eq!(batch.len(), 2);
        batch
            .iter()
            .for_each(|emb| assert_eq!(emb.len(), NOMIC_DIM));
    }

    #[test]
    fn mock_local_model_description() {
        let mock = MockEmbedder::new_local().unwrap();
        let desc = mock.model_description();
        for needle in ["MiniLM", "384"] {
            assert!(desc.contains(needle));
        }
    }

    #[test]
    fn mock_ollama_model_description() {
        let mock = MockEmbedder::new_ollama();
        let desc = mock.model_description();
        for needle in ["nomic", "768"] {
            assert!(desc.contains(needle));
        }
    }

    #[test]
    fn mock_embed_different_texts_different_vectors() {
        let mock = MockEmbedder::new_local().unwrap();
        let (one, two) = (
            mock.embed("text one").unwrap(),
            mock.embed("text two").unwrap(),
        );
        assert_ne!(one[0], two[0]);
    }
}
/// Checks `cosine_similarity` on non-axis-aligned vectors against a hand-computed
/// value: dot = 4 + 10 + 18 = 32, |v1|^2 = 14, |v2|^2 = 77.
// Renamed from `cache_evicts_least_recently_used`: the test never involved a cache.
#[test]
fn cosine_similarity_matches_hand_computed_value() {
    let v1 = vec![1.0, 2.0, 3.0];
    let v2 = vec![4.0, 5.0, 6.0];
    let sim = Embedder::cosine_similarity(&v1, &v2);
    let expected = 32.0 / (14.0_f32.sqrt() * 77.0_f32.sqrt());
    assert!((sim - expected).abs() < 1e-5);
}
#[cfg(test)]
mod w12h_extra_tests {
    // Edge cases for `for_model`, `cosine_similarity`, `fuse` and the dimension constants.
    use super::*;

    #[test]
    fn for_model_nomic_without_ollama_client_errors() {
        let Err(e) = Embedder::for_model(EmbeddingModel::NomicEmbedV15, None) else {
            panic!("expected NomicEmbedV15 without client to error");
        };
        let err = e.to_string();
        assert!(
            err.contains("Ollama") || err.contains("nomic"),
            "expected ollama error msg, got: {err}"
        );
    }

    #[test]
    fn cosine_similarity_both_zero_returns_zero() {
        let zeros = [0.0_f32; 3];
        assert_eq!(Embedder::cosine_similarity(&zeros, &zeros), 0.0);
    }

    #[test]
    fn cosine_similarity_negative_values() {
        let a = [1.0_f32, 2.0, 3.0];
        let b: Vec<f32> = a.iter().map(|x| -x).collect();
        assert!((Embedder::cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_empty_vectors() {
        let empty: [f32; 0] = [];
        assert_eq!(Embedder::cosine_similarity(&empty, &empty), 0.0);
    }

    #[test]
    fn fuse_zero_weight_returns_pure_secondary() {
        let fused = Embedder::fuse(&[1.0_f32, 0.0], &[0.0_f32, 1.0], 0.0);
        for (idx, want) in [0.0_f32, 1.0].into_iter().enumerate() {
            assert!((fused[idx] - want).abs() < 1e-6);
        }
    }

    #[test]
    fn fuse_empty_vectors_returns_empty() {
        let none: Vec<f32> = Vec::new();
        assert!(Embedder::fuse(&none, &none, 0.5).is_empty());
    }

    #[test]
    fn embedding_dim_constant_pinned() {
        assert_eq!(MINILM_DIM, 384);
        assert_eq!(NOMIC_DIM, 768);
        assert_eq!(EMBEDDING_DIM, MINILM_DIM);
    }

    #[test]
    fn fuse_dimension_mismatch_secondary_longer() {
        let primary = vec![1.0_f32, 2.0];
        let fused = Embedder::fuse(&primary, &[3.0_f32, 4.0, 5.0], 0.5);
        assert_eq!(fused, primary);
    }

    #[test]
    fn cosine_similarity_dimension_mismatch_inverse() {
        let sim = Embedder::cosine_similarity(&[1.0_f32, 0.0], &[1.0_f32, 0.0, 0.0]);
        assert_eq!(sim, 0.0);
    }
}
/// Lock shared by every test that reads or mutates the process-wide `HOME` variable.
///
/// `std::env::set_var` affects the whole process. A lock that lives inside a
/// single test function protects nothing, so the lock must be shared by all
/// tests that depend on `HOME`.
#[cfg(test)]
static HOME_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// `load_from_fallback` either finds the files or fails with an explanatory error.
#[test]
fn embedder_returns_unreachable_when_model_path_missing() {
    let _guard = HOME_ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    match Embedder::load_from_fallback() {
        // The fallback directory may legitimately hold the model on this machine.
        Ok(_) => {}
        Err(e) => {
            let err_msg = e.to_string();
            assert!(
                err_msg.contains("not found") || err_msg.contains("fallback"),
                "error should mention missing model files: {err_msg}"
            );
        }
    }
}

/// With placeholder files under a temporary `HOME`, `load_from_fallback` returns
/// their paths.
#[test]
fn load_from_fallback_succeeds_when_files_present() {
    let _guard = HOME_ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let tmp = std::env::temp_dir().join(format!("ai-memory-w12h-fallback-{}", std::process::id()));
    let model_dir = tmp.join(
        ".cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/main",
    );
    std::fs::create_dir_all(&model_dir).expect("mk model dir");
    for name in ["config.json", "tokenizer.json", "model.safetensors"] {
        std::fs::write(model_dir.join(name), b"{}").expect("write placeholder");
    }
    // `var_os` so that a non-UTF-8 `HOME` is restored exactly instead of being removed.
    let prev = std::env::var_os("HOME");
    // SAFETY: `set_var`/`remove_var` are only sound while no other thread accesses the
    // environment. The tests that depend on `HOME` serialise on `HOME_ENV_LOCK`.
    // NOTE(review): unrelated tests reading other variables concurrently are not covered.
    unsafe {
        std::env::set_var("HOME", &tmp);
    }
    let result = Embedder::load_from_fallback();
    unsafe {
        match prev {
            Some(p) => std::env::set_var("HOME", p),
            None => std::env::remove_var("HOME"),
        }
    }
    let _ = std::fs::remove_dir_all(&tmp);
    let (cfg, tok, w) = result.expect("placeholder files satisfy load_from_fallback");
    assert!(cfg.ends_with("config.json"));
    assert!(tok.ends_with("tokenizer.json"));
    assert!(w.ends_with("model.safetensors"));
}