use anyhow::{Context, Result};
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};
use hf_hub::{Repo, RepoType, api::sync::Api};
use std::sync::Arc;
use tokenizers::Tokenizer;
use crate::config::EmbeddingModel;
/// Builds the text that gets embedded for one memory: its title and its
/// content, each rendered through `Display`, separated by one space.
#[must_use]
pub fn embedding_document(
    title: impl std::fmt::Display,
    content: impl std::fmt::Display,
) -> String {
    [title.to_string(), content.to_string()].join(" ")
}
/// HuggingFace Hub repo id of the local embedding model.
const MINILM_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Output vector width of all-MiniLM-L6-v2 (what `Embedder::dim` reports for
/// the local backend).
#[allow(dead_code)]
const MINILM_DIM: usize = 384;
/// Truncation length, in tokens, applied by the local tokenizer.
const MAX_SEQ_LEN: usize = 256;
/// Wall-clock budget for fetching the local model files from the hub.
const HF_DOWNLOAD_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(180);
/// Model directory, relative to `$HOME`, used when the hub download fails or
/// remote fetching is disabled.
/// NOTE(review): the hf-hub cache normally names snapshot dirs by commit hash
/// (recorded in `refs/main`), not `main` — confirm this path is populated.
const FALLBACK_MODEL_SUBDIR: &str =
".cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/main";
/// Ollama model name used for the nomic-embed backend.
pub(crate) const NOMIC_OLLAMA_MODEL: &str = "nomic-embed-text";
/// Substring (matched case-insensitively) that identifies nomic-embed models,
/// which need the task prefixes below.
const NOMIC_MODEL_FAMILY_NEEDLE: &str = "nomic-embed";
/// File names fetched from the model repo / looked up in the fallback dir.
pub(crate) const HF_CONFIG_FILE: &str = "config.json";
pub(crate) const HF_TOKENIZER_FILE: &str = "tokenizer.json";
pub(crate) const HF_WEIGHTS_FILE: &str = "model.safetensors";
/// Vector width assumed for the Ollama nomic-embed backend.
#[allow(dead_code)]
const NOMIC_DIM: usize = 768;
/// Task prefixes prepended to inputs for nomic-embed models, by role.
const NOMIC_PREFIX_DOCUMENT: &str = "search_document: ";
const NOMIC_PREFIX_QUERY: &str = "search_query: ";
/// Which side of a retrieval pair a text is embedded as.
///
/// Only models that need task prefixes (the nomic-embed family) act on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbedRole {
    /// Content that is stored and later searched.
    Document,
    /// A search query matched against stored documents.
    Query,
}

impl EmbedRole {
    /// The nomic-embed task prefix for this role, trailing space included.
    #[must_use]
    pub fn nomic_prefix(self) -> &'static str {
        match self {
            EmbedRole::Query => NOMIC_PREFIX_QUERY,
            EmbedRole::Document => NOMIC_PREFIX_DOCUMENT,
        }
    }
}
/// Outcome of trying to embed one piece of content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbedStatus {
    /// A vector was produced.
    Indexed,
    /// Embedding was not attempted; carries the reason (e.g. empty or
    /// oversize content).
    Skipped(String),
    /// The embedder was called but produced no usable vector; carries the
    /// reason.
    Failed(String),
}

impl EmbedStatus {
    /// Short lowercase label for the variant, without its reason.
    #[must_use]
    pub fn as_str(&self) -> &str {
        match self {
            EmbedStatus::Indexed => "indexed",
            EmbedStatus::Skipped(..) => "skipped",
            EmbedStatus::Failed(..) => "failed",
        }
    }

    /// True for every outcome other than [`EmbedStatus::Indexed`].
    #[must_use]
    pub fn is_degraded(&self) -> bool {
        self != &EmbedStatus::Indexed
    }

    /// The carried reason, or `""` for [`EmbedStatus::Indexed`].
    #[must_use]
    pub fn reason(&self) -> &str {
        match self {
            EmbedStatus::Skipped(why) | EmbedStatus::Failed(why) => why.as_str(),
            EmbedStatus::Indexed => "",
        }
    }
}

/// Renders as the label, followed by `": <reason>"` when there is a reason.
impl std::fmt::Display for EmbedStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())?;
        match self {
            EmbedStatus::Indexed => Ok(()),
            EmbedStatus::Skipped(why) | EmbedStatus::Failed(why) => write!(f, ": {why}"),
        }
    }
}
/// Largest content size, in bytes, that is sent to an embedder.
pub const EMBED_MAX_BYTES: usize = 64 * 1024;

/// Explains why content of `byte_len` bytes is too large to embed.
///
/// Returns `None` when the content fits; a length equal to the cap is allowed.
#[must_use]
pub fn oversize_embed_reason(byte_len: usize) -> Option<String> {
    if byte_len <= EMBED_MAX_BYTES {
        return None;
    }
    Some(format!(
        "content {byte_len} bytes exceeds embed cap {EMBED_MAX_BYTES} bytes"
    ))
}
/// Common interface over embedding backends: the real `Embedder` and the
/// test doubles.
///
/// Only [`Embed::embed`] is required; the remaining methods have defaults.
pub trait Embed: Send + Sync {
    /// Embeds `text` (as a document, for backends that distinguish roles).
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds `text` as a search query. By default this is just
    /// [`Embed::embed`], for backends with no document/query distinction.
    fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        Embed::embed(self, text)
    }

    /// Embeds each text in order via [`Embed::embed`], stopping at the first
    /// error.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(Embed::embed(self, text)?);
        }
        Ok(vectors)
    }

    /// Whether the backend is currently degraded; `false` unless overridden.
    fn is_degraded(&self) -> bool {
        false
    }
}
/// An embedding backend.
///
/// The model, tokenizer, client, and degraded flag are held in `Arc`s, so
/// clones share them.
#[derive(Clone)]
pub enum Embedder {
/// all-MiniLM-L6-v2 run in-process with candle.
Local {
model: Arc<BertModel>,
tokenizer: Arc<Tokenizer>,
/// Device the input tensors are created on (`Device::Cpu` in `new_local`).
device: Device,
},
/// A remote HTTP embedding backend: Ollama, or an OpenAI-compatible API
/// (see `from_resolved`).
Ollama {
client: Arc<crate::llm::OllamaClient>,
/// Model name sent with every embed request.
model_name: String,
/// Vector width this backend is expected to return.
dim: usize,
/// Set after a failed embed call and cleared after a successful one;
/// shared by all clones.
degraded: Arc<std::sync::atomic::AtomicBool>,
},
}
/// Result of `Embedder::cosine_similarity_checked`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CosineComparison {
/// The vectors had equal length; holds their cosine similarity.
Comparable(f32),
/// The vectors had different lengths, so no similarity was computed.
DimensionMismatch {
query_dim: usize,
stored_dim: usize,
},
}
impl Embedder {
    /// Builds the default embedder, the local all-MiniLM-L6-v2 model.
    ///
    /// # Errors
    /// Returns every error [`Embedder::new_local`] can return.
    #[allow(dead_code)]
    pub fn new() -> Result<Self> {
        Self::new_local()
    }

    /// Loads all-MiniLM-L6-v2 and runs it on the CPU.
    ///
    /// The config, tokenizer, and weights are fetched from the HuggingFace Hub
    /// within `HF_DOWNLOAD_TIMEOUT`, unless remote fetching is disabled (see
    /// `remote_fetch_disabled`). When fetching is disabled, or the download
    /// fails, they are looked up in the local fallback directory instead.
    /// The tokenizer truncates at `MAX_SEQ_LEN` tokens and does not pad; batch
    /// padding is done by hand in `embed_local_batch`.
    ///
    /// # Errors
    /// Fails when the model files cannot be located, or when the config,
    /// tokenizer, or weights cannot be read, parsed, or assembled into a
    /// `BertModel`.
    pub fn new_local() -> Result<Self> {
        let device = Device::Cpu;
        let (config_path, tokenizer_path, weights_path) = if Self::remote_fetch_disabled() {
            Self::load_from_fallback()?
        } else {
            match Self::download_within(HF_DOWNLOAD_TIMEOUT, Self::download_via_hf_hub) {
                Ok(paths) => paths,
                Err(e) => {
                    eprintln!("ai-memory: hf-hub download failed ({e}), trying fallback dir");
                    Self::load_from_fallback()?
                }
            }
        };
        let config_data =
            std::fs::read_to_string(&config_path).context("failed to read config.json")?;
        let config: Config =
            serde_json::from_str(&config_data).context("failed to parse config.json")?;
        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
        let truncation = tokenizers::TruncationParams {
            max_length: MAX_SEQ_LEN,
            ..Default::default()
        };
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
        tokenizer.with_padding(None);
        // SAFETY: the weights file is memory-mapped, which is only sound while
        // nothing modifies or truncates that file behind our back. It lives in
        // the HuggingFace cache (or the fallback dir) and is assumed to stay
        // untouched for the lifetime of the process.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
                .context("failed to load model weights")?
        };
        let model = BertModel::load(vb, &config).context("failed to build BertModel")?;
        Ok(Self::Local {
            model: Arc::new(model),
            tokenizer: Arc::new(tokenizer),
            device,
        })
    }

    /// Wraps `client` as a nomic-embed-text backend with `NOMIC_DIM`-wide
    /// vectors.
    pub fn new_ollama(client: Arc<crate::llm::OllamaClient>) -> Self {
        Self::new_remote(client, NOMIC_OLLAMA_MODEL.to_string(), NOMIC_DIM)
    }

    /// Wraps `client` as a remote backend for `model_name`, which is expected
    /// to return `dim`-wide vectors. The degraded flag starts cleared.
    #[must_use]
    pub fn new_remote(
        client: Arc<crate::llm::OllamaClient>,
        model_name: String,
        dim: usize,
    ) -> Self {
        Self::Ollama {
            client,
            model_name,
            dim,
            degraded: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Builds the embedder for a resolved embeddings config and a tier.
    ///
    /// Returns `Ok(None)` when `tier_model` is `None` (the tier has embeddings
    /// off). If the resolved backend is an API backend (per
    /// `crate::config::is_api_embed_backend`), a remote OpenAI-compatible
    /// embedder for `resolved.model` is built, whatever `tier_model` is.
    /// Otherwise `tier_model` picks the local MiniLM model, or nomic-embed via
    /// Ollama at `resolved.url`.
    ///
    /// # Errors
    /// Fails when an API backend has no known vector dim, when a client cannot
    /// be built, or when [`Embedder::for_model`] fails.
    pub fn from_resolved(
        resolved: &crate::config::ResolvedEmbeddings,
        tier_model: Option<crate::config::EmbeddingModel>,
    ) -> Result<Option<Self>> {
        let Some(tier_model) = tier_model else {
            return Ok(None);
        };
        if crate::config::is_api_embed_backend(&resolved.backend) {
            let Some(dim) = resolved.embedding_dim else {
                anyhow::bail!(
                    "embedding model {:?} (backend {:?}) has no known vector dim — \
                     pick a model from the known-dims table (override with the \
                     {} env var) or set the `[embeddings].dim` escape hatch in \
                     config.toml (#1598)",
                    resolved.model,
                    resolved.backend,
                    crate::config::ENV_EMBED_MODEL,
                );
            };
            let api_key = resolved.api_key().unwrap_or_default();
            let client = crate::llm::OllamaClient::new_openai_compatible(
                &resolved.url,
                &resolved.model,
                api_key,
            )
            .context("failed to build OpenAI-compatible embed client (#1598)")?
            .with_embed_dimensions(resolved.requested_dim);
            return Ok(Some(Self::new_remote(
                Arc::new(client),
                resolved.model.clone(),
                dim as usize,
            )));
        }
        match tier_model {
            crate::config::EmbeddingModel::MiniLmL6V2 => {
                Self::for_model(tier_model, None).map(Some)
            }
            crate::config::EmbeddingModel::NomicEmbedV15 => {
                let client =
                    crate::llm::OllamaClient::new_with_url(&resolved.url, NOMIC_OLLAMA_MODEL)
                        .context("failed to build Ollama embed client")?;
                Self::for_model(tier_model, Some(Arc::new(client))).map(Some)
            }
        }
    }

    /// Builds the embedder for `model`.
    ///
    /// nomic-embed-text-v1.5 requires `ollama_client`. Pulling the model
    /// through it is attempted first; a failed pull is only logged as a
    /// warning.
    ///
    /// # Errors
    /// Fails when nomic is requested without a client, or when
    /// [`Embedder::new_local`] fails for MiniLM.
    pub fn for_model(
        model: EmbeddingModel,
        ollama_client: Option<Arc<crate::llm::OllamaClient>>,
    ) -> Result<Self> {
        match model {
            EmbeddingModel::MiniLmL6V2 => Self::new_local(),
            EmbeddingModel::NomicEmbedV15 => {
                let client = ollama_client.ok_or_else(|| {
                    anyhow::anyhow!("nomic-embed-text-v1.5 requires Ollama (smart tier or above)")
                })?;
                if let Err(e) = client.ensure_embed_model(NOMIC_OLLAMA_MODEL) {
                    eprintln!("ai-memory: warning: failed to pull nomic model: {e}");
                }
                Ok(Self::new_ollama(client))
            }
        }
    }

    /// Vector width this embedder produces (or, for remote backends, is
    /// configured to expect).
    #[allow(dead_code)]
    pub fn dim(&self) -> usize {
        match self {
            Self::Local { .. } => MINILM_DIM,
            Self::Ollama { dim, .. } => *dim,
        }
    }

    /// Human-readable `"<model> (<dim>-dim, local|remote)"` summary.
    #[must_use]
    pub fn model_description(&self) -> String {
        match self {
            Self::Local { .. } => format!("all-MiniLM-L6-v2 ({MINILM_DIM}-dim, local)"),
            Self::Ollama {
                model_name, dim, ..
            } => format!("{model_name} ({dim}-dim, remote)"),
        }
    }

    /// Whether the most recent remote embed call failed. Always `false` for
    /// the local model.
    #[must_use]
    pub fn is_degraded(&self) -> bool {
        match self {
            Self::Local { .. } => false,
            Self::Ollama { degraded, .. } => degraded.load(std::sync::atomic::Ordering::Relaxed),
        }
    }

    /// Embeds `text` as a stored document ([`EmbedRole::Document`]).
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.embed_with_role(text, EmbedRole::Document)
    }

    /// Embeds `text` as a search query ([`EmbedRole::Query`]).
    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        self.embed_with_role(text, EmbedRole::Query)
    }

    /// Embeds `text` for the given retrieval `role`.
    ///
    /// The local model ignores `role`. Remote backends prepend the nomic task
    /// prefix when `model_name` is a nomic-embed model. Each remote call sets
    /// the degraded flag on failure and clears it on success.
    pub fn embed_with_role(&self, text: &str, role: EmbedRole) -> Result<Vec<f32>> {
        match self {
            Self::Local {
                model,
                tokenizer,
                device,
            } => Self::embed_local(model, tokenizer, device, text),
            Self::Ollama {
                client,
                model_name,
                degraded,
                ..
            } => {
                let result = if Self::model_requires_nomic_prefix(model_name) {
                    let prefixed = format!("{}{}", role.nomic_prefix(), text);
                    client.embed_text(&prefixed, model_name)
                } else {
                    client.embed_text(text, model_name)
                };
                degraded.store(result.is_err(), std::sync::atomic::Ordering::Relaxed);
                result
            }
        }
    }

    /// True when `model_name` contains `nomic-embed`, compared
    /// case-insensitively, so HF ids like `nomic-ai/nomic-embed-text-v1.5`
    /// also match.
    fn model_requires_nomic_prefix(model_name: &str) -> bool {
        model_name
            .to_ascii_lowercase()
            .contains(NOMIC_MODEL_FAMILY_NEEDLE)
    }

    /// Embeds `text` as a document and reports the outcome instead of an
    /// error.
    ///
    /// Empty content, and content over `EMBED_MAX_BYTES`, is `Skipped` without
    /// calling the model. An error or an empty vector from the model is
    /// `Failed`; errors are also logged under the `embeddings.degrade` target.
    pub fn embed_with_status(&self, text: &str) -> (Option<Vec<f32>>, EmbedStatus) {
        if text.is_empty() {
            return (None, EmbedStatus::Skipped("empty content".to_string()));
        }
        if let Some(reason) = oversize_embed_reason(text.len()) {
            return (None, EmbedStatus::Skipped(reason));
        }
        match self.embed(text) {
            Ok(v) if v.is_empty() => (
                None,
                EmbedStatus::Failed("embedder returned empty vector".to_string()),
            ),
            Ok(v) => (Some(v), EmbedStatus::Indexed),
            Err(e) => {
                let reason = format!("{e:#}");
                tracing::warn!(target: "embeddings.degrade", reason = %reason, "embed_with_status: embedder failed");
                (None, EmbedStatus::Failed(reason))
            }
        }
    }

    /// Runs one text through the local BERT model and returns its mean-pooled,
    /// L2-normalised embedding.
    fn embed_local(
        model: &BertModel,
        tokenizer: &Tokenizer,
        device: &Device,
        text: &str,
    ) -> Result<Vec<f32>> {
        let encoding = tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("tokenisation failed: {e}"))?;
        let seq_len = encoding.get_ids().len();
        let input_ids = Tensor::new(encoding.get_ids(), device)?.reshape((1, seq_len))?;
        let attention_mask_tensor =
            Tensor::new(encoding.get_attention_mask(), device)?.reshape((1, seq_len))?;
        let token_type_ids =
            Tensor::new(encoding.get_type_ids(), device)?.reshape((1, seq_len))?;
        let hidden = model
            .forward(&input_ids, &token_type_ids, Some(&attention_mask_tensor))
            .context("model forward pass failed")?;
        let normalised = Self::mean_pool_l2_normalize(&hidden, &attention_mask_tensor)?;
        let embedding: Vec<f32> = normalised.squeeze(0)?.to_vec1()?;
        Ok(embedding)
    }

    /// Mean-pools `hidden` over the sequence axis and L2-normalises each row.
    ///
    /// Tokens are weighted by their attention-mask value, so padding (mask
    /// `0`) does not contribute. `attention_mask` is `(batch, seq)`, and it is
    /// broadcast against `hidden`, so `hidden` must be `(batch, seq, width)`.
    /// The result is `(batch, width)`. Both divisions are clamped away from
    /// zero.
    fn mean_pool_l2_normalize(hidden: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let mask = attention_mask
            .unsqueeze(2)?
            .to_dtype(candle_core::DType::F32)?
            .broadcast_as(hidden.shape())?;
        let summed = hidden.mul(&mask)?.sum(1)?;
        let count = mask.sum(1)?.clamp(1e-9, f64::MAX)?;
        let pooled = summed.div(&count)?;
        let norm = pooled
            .sqr()?
            .sum_keepdim(1)?
            .sqrt()?
            .clamp(1e-12, f64::MAX)?;
        Ok(pooled.broadcast_div(&norm)?)
    }

    /// Embeds every text as a document, in input order.
    ///
    /// Empty input returns an empty `Vec` without touching the backend. The
    /// remote path applies the nomic document prefix when the model needs it,
    /// and updates the degraded flag like [`Embedder::embed_with_role`].
    #[allow(dead_code)]
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        match self {
            Self::Local {
                model,
                tokenizer,
                device,
            } => Self::embed_local_batch(model, tokenizer, device, texts),
            Self::Ollama {
                client,
                model_name,
                degraded,
                ..
            } => {
                let result = if Self::model_requires_nomic_prefix(model_name) {
                    let prefixed: Vec<String> = texts
                        .iter()
                        .map(|t| format!("{}{}", EmbedRole::Document.nomic_prefix(), t))
                        .collect();
                    let refs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
                    client.embed_texts(&refs, model_name)
                } else {
                    client.embed_texts(texts, model_name)
                };
                degraded.store(result.is_err(), std::sync::atomic::Ordering::Relaxed);
                result
            }
        }
    }

    /// Runs all `texts` through the local model in one forward pass.
    ///
    /// Rows are right-padded with id/mask/type `0` up to the longest encoding.
    /// The zero mask keeps padding out of the mean pool, so each row matches
    /// what [`Embedder::embed_local`] would produce for that text alone.
    fn embed_local_batch(
        model: &BertModel,
        tokenizer: &Tokenizer,
        device: &Device,
        texts: &[&str],
    ) -> Result<Vec<Vec<f32>>> {
        let inputs: Vec<&str> = texts.to_vec();
        let encodings = tokenizer
            .encode_batch(inputs, true)
            .map_err(|e| anyhow::anyhow!("tokenisation batch failed: {e}"))?;
        let max_len = encodings
            .iter()
            .map(tokenizers::Encoding::len)
            .max()
            .unwrap_or(0);
        if max_len == 0 {
            return Ok(texts.iter().map(|_| Vec::new()).collect());
        }
        let batch_size = encodings.len();
        let mut input_ids_flat = Vec::with_capacity(batch_size * max_len);
        let mut attention_mask_flat = Vec::with_capacity(batch_size * max_len);
        let mut token_type_ids_flat = Vec::with_capacity(batch_size * max_len);
        for enc in &encodings {
            let ids = enc.get_ids();
            input_ids_flat.extend_from_slice(ids);
            attention_mask_flat.extend_from_slice(enc.get_attention_mask());
            token_type_ids_flat.extend_from_slice(enc.get_type_ids());
            for _ in ids.len()..max_len {
                input_ids_flat.push(0);
                attention_mask_flat.push(0);
                token_type_ids_flat.push(0);
            }
        }
        let input_ids =
            Tensor::new(input_ids_flat.as_slice(), device)?.reshape((batch_size, max_len))?;
        let attention_mask_tensor =
            Tensor::new(attention_mask_flat.as_slice(), device)?.reshape((batch_size, max_len))?;
        let token_type_ids =
            Tensor::new(token_type_ids_flat.as_slice(), device)?.reshape((batch_size, max_len))?;
        let hidden = model
            .forward(&input_ids, &token_type_ids, Some(&attention_mask_tensor))
            .context("model forward pass (batched) failed")?;
        let normalised = Self::mean_pool_l2_normalize(&hidden, &attention_mask_tensor)?;
        let mut out: Vec<Vec<f32>> = Vec::with_capacity(batch_size);
        for i in 0..batch_size {
            let row: Vec<f32> = normalised.get(i)?.to_vec1()?;
            out.push(row);
        }
        Ok(out)
    }

    /// Cosine similarity of `a` and `b`.
    ///
    /// Returns `0.0` when the lengths differ, when either vector has
    /// (near-)zero norm, or when the result is not finite.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }
        let mut dot: f32 = 0.0;
        let mut sq_a: f32 = 0.0;
        let mut sq_b: f32 = 0.0;
        for (&x, &y) in a.iter().zip(b.iter()) {
            dot += x * y;
            sq_a += x * x;
            sq_b += y * y;
        }
        let denom = sq_a.sqrt() * sq_b.sqrt();
        if denom < 1e-12 {
            return 0.0;
        }
        let score = dot / denom;
        if score.is_finite() { score } else { 0.0 }
    }

    /// Like [`Embedder::cosine_similarity`], but reports a length mismatch
    /// explicitly instead of returning `0.0`.
    #[must_use]
    pub fn cosine_similarity_checked(query: &[f32], stored: &[f32]) -> CosineComparison {
        if query.len() != stored.len() {
            return CosineComparison::DimensionMismatch {
                query_dim: query.len(),
                stored_dim: stored.len(),
            };
        }
        CosineComparison::Comparable(Self::cosine_similarity(query, stored))
    }

    /// Element-wise `w * primary + (1 - w) * secondary`, where `w` is
    /// `primary_weight` clamped to `[0, 1]`.
    ///
    /// Returns a copy of `primary` if the lengths differ. The result is not
    /// re-normalised.
    #[must_use]
    pub fn fuse(primary: &[f32], secondary: &[f32], primary_weight: f32) -> Vec<f32> {
        if primary.len() != secondary.len() {
            return primary.to_vec();
        }
        let w = primary_weight.clamp(0.0, 1.0);
        let one_minus_w = 1.0 - w;
        primary
            .iter()
            .zip(secondary.iter())
            .map(|(p, s)| w * p + one_minus_w * s)
            .collect()
    }

    /// Runs `f` on a new thread and waits at most `budget` for its result.
    ///
    /// On timeout the worker is not stopped. It keeps running detached, and
    /// its eventual result is dropped, since the send simply fails once the
    /// receiver is gone. A worker that dies without sending (e.g. panics)
    /// yields the "terminated without a result" error.
    fn download_within<F>(
        budget: std::time::Duration,
        f: F,
    ) -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
    where
        F: FnOnce() -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
            + Send
            + 'static,
    {
        let (tx, rx) = std::sync::mpsc::channel();
        std::thread::spawn(move || {
            let _ = tx.send(f());
        });
        match rx.recv_timeout(budget) {
            Ok(result) => result,
            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => anyhow::bail!(
                "hf-hub model download exceeded {}s budget",
                budget.as_secs()
            ),
            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                anyhow::bail!("hf-hub model download thread terminated without a result")
            }
        }
    }

    /// Fetches (or reuses from the hf-hub cache) the MiniLM config, tokenizer,
    /// and weights, and returns their local paths in that order.
    fn download_via_hf_hub() -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
    {
        let api = Api::new().context("failed to initialise HuggingFace Hub API")?;
        let repo = api.repo(Repo::new(MINILM_MODEL_ID.to_string(), RepoType::Model));
        let config_path = repo
            .get(HF_CONFIG_FILE)
            .context("failed to download config.json")?;
        let tokenizer_path = repo
            .get(HF_TOKENIZER_FILE)
            .context("failed to download tokenizer.json")?;
        let weights_path = repo
            .get(HF_WEIGHTS_FILE)
            .context("failed to download model.safetensors")?;
        Ok((config_path, tokenizer_path, weights_path))
    }

    /// True when `AI_MEMORY_EMBED_OFFLINE` or `HF_HUB_OFFLINE` holds a truthy
    /// value.
    ///
    /// The truthy values are `1`, `true`, `yes`, and `on`, compared
    /// case-insensitively after trimming. This means `True`, `YES`, and `On`
    /// count too, as they do for the Python `huggingface_hub` client.
    fn remote_fetch_disabled() -> bool {
        const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
        let truthy = |name: &str| {
            std::env::var(name)
                .map(|v| {
                    let v = v.trim();
                    TRUTHY.iter().any(|t| v.eq_ignore_ascii_case(t))
                })
                .unwrap_or(false)
        };
        truthy("AI_MEMORY_EMBED_OFFLINE") || truthy("HF_HUB_OFFLINE")
    }

    /// Finds the model files in the local cache without any network access.
    ///
    /// The dir `$HOME/<FALLBACK_MODEL_SUBDIR>` is tried first, with `HOME`
    /// defaulting to `/root`. After that comes the snapshot dir named by the
    /// cache's `refs/main`, which is where hf-hub stores previously downloaded
    /// files. The first dir holding all three files wins.
    fn load_from_fallback() -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
    {
        let home = std::env::var("HOME").unwrap_or_else(|_| "/root".to_string());
        let primary = std::path::PathBuf::from(home).join(FALLBACK_MODEL_SUBDIR);
        let mut candidates = vec![primary.clone()];
        if let Some(resolved) = Self::resolve_main_snapshot(&primary) {
            if resolved != primary {
                candidates.push(resolved);
            }
        }
        for dir in &candidates {
            let config = dir.join(HF_CONFIG_FILE);
            let tokenizer = dir.join(HF_TOKENIZER_FILE);
            let weights = dir.join(HF_WEIGHTS_FILE);
            if config.exists() && tokenizer.exists() && weights.exists() {
                return Ok((config, tokenizer, weights));
            }
        }
        anyhow::bail!(
            "model files not found in fallback dir: {}. Download them manually from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
            primary.display()
        )
    }

    /// Maps `<model>/snapshots/main` to `<model>/snapshots/<commit>`.
    ///
    /// The commit is read from `<model>/refs/main`. Returns `None` when that
    /// file is missing, or holds anything other than a non-empty alphanumeric
    /// id; rejecting `/`, `..`, etc. keeps the result inside the snapshots
    /// dir.
    fn resolve_main_snapshot(main_snapshot_dir: &std::path::Path) -> Option<std::path::PathBuf> {
        let snapshots_dir = main_snapshot_dir.parent()?;
        let model_dir = snapshots_dir.parent()?;
        let commit = std::fs::read_to_string(model_dir.join("refs").join("main")).ok()?;
        let commit = commit.trim();
        if commit.is_empty() || !commit.chars().all(|c| c.is_ascii_alphanumeric()) {
            return None;
        }
        Some(snapshots_dir.join(commit))
    }
}
/// Trait view of `Embedder`: every method forwards to the inherent method of
/// the same name.
impl Embed for Embedder {
// `Self::embed` and friends resolve to the inherent methods, which take
// precedence over the trait's, so these forwarders do not recurse.
fn embed(&self, text: &str) -> Result<Vec<f32>> {
Self::embed(self, text)
}
fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
Self::embed_query(self, text)
}
fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
Self::embed_batch(self, texts)
}
// Overrides the trait default so remote failures are visible through `dyn Embed`.
fn is_degraded(&self) -> bool {
Self::is_degraded(self)
}
}
/// Vector width of the default local embedder (all-MiniLM-L6-v2).
#[allow(dead_code)]
pub const EMBEDDING_DIM: usize = MINILM_DIM;
/// Header byte marking a blob of little-endian `f32` values.
pub const EMBEDDING_HEADER_LE_F32: u8 = 0x01;
/// Header byte for big-endian `f32` values; recognised but rejected on decode.
pub const EMBEDDING_HEADER_BE_F32: u8 = 0x02;

/// Why an embedding blob could not be decoded.
#[derive(Debug)]
pub enum EmbeddingFormatError {
    /// A headed blob started with an unrecognised header byte.
    UnknownHeader(u8),
    /// A headed blob used the big-endian header.
    BigEndianUnsupported,
    /// The blob length fits neither the headed (`4n + 1`) nor the legacy
    /// unheaded (`4n`) layout.
    MalformedLength(usize),
}

impl std::fmt::Display for EmbeddingFormatError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            EmbeddingFormatError::UnknownHeader(byte) => {
                write!(f, "unknown embedding header byte: 0x{byte:02x}")
            }
            EmbeddingFormatError::BigEndianUnsupported => f.write_str(
                "big-endian f32 embeddings (header 0x02) are not supported until v0.7",
            ),
            EmbeddingFormatError::MalformedLength(len) => {
                write!(f, "embedding payload length {len} is not a multiple of 4")
            }
        }
    }
}

impl std::error::Error for EmbeddingFormatError {}

/// Serialises `embedding` as the LE-f32 header byte followed by each value's
/// four little-endian bytes.
#[must_use]
pub fn encode_embedding_blob(embedding: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(1 + 4 * embedding.len());
    blob.push(EMBEDDING_HEADER_LE_F32);
    blob.extend(embedding.iter().flat_map(|value| value.to_le_bytes()));
    blob
}

/// Decodes a blob written by [`encode_embedding_blob`], or a legacy blob of
/// bare little-endian `f32`s with no header.
///
/// The layout is told apart by length: `4n + 1` bytes is headed, `4n` is
/// legacy (including the empty blob), and anything else is malformed.
///
/// # Errors
/// [`EmbeddingFormatError`] for a big-endian or unknown header, or a
/// malformed length.
pub fn decode_embedding_blob(bytes: &[u8]) -> Result<Vec<f32>, EmbeddingFormatError> {
    let read_le = |raw: &[u8]| -> Vec<f32> {
        raw.chunks_exact(4)
            .map(|w| f32::from_le_bytes([w[0], w[1], w[2], w[3]]))
            .collect()
    };
    match bytes.len() % 4 {
        0 => Ok(read_le(bytes)),
        1 => match bytes[0] {
            EMBEDDING_HEADER_LE_F32 => Ok(read_le(&bytes[1..])),
            EMBEDDING_HEADER_BE_F32 => Err(EmbeddingFormatError::BigEndianUnsupported),
            other => Err(EmbeddingFormatError::UnknownHeader(other)),
        },
        _ => Err(EmbeddingFormatError::MalformedLength(bytes.len())),
    }
}

/// Number of `f32` values a blob holds, without decoding it.
///
/// A headed blob is `4n + 1` bytes and a legacy one `4n`. Integer division by
/// 4 discards the header byte, so both layouts (and the empty blob) reduce to
/// `len / 4`. Malformed lengths are not rejected here.
#[must_use]
pub fn decoded_dim(bytes: &[u8]) -> usize {
    bytes.len() / 4
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let sim = Embedder::cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn embed_role_maps_to_nomic_prefix() {
        assert_eq!(EmbedRole::Document.nomic_prefix(), NOMIC_PREFIX_DOCUMENT);
        assert_eq!(EmbedRole::Query.nomic_prefix(), NOMIC_PREFIX_QUERY);
        assert_ne!(
            EmbedRole::Document.nomic_prefix(),
            EmbedRole::Query.nomic_prefix()
        );
        // The prefix is concatenated directly onto the text, so it must end
        // with the separating space itself.
        assert!(NOMIC_PREFIX_DOCUMENT.ends_with(' '));
        assert!(NOMIC_PREFIX_QUERY.ends_with(' '));
    }

    #[test]
    fn nomic_prefix_gating_is_model_scoped() {
        assert!(Embedder::model_requires_nomic_prefix(NOMIC_OLLAMA_MODEL));
        assert!(Embedder::model_requires_nomic_prefix(&format!(
            "{NOMIC_OLLAMA_MODEL}:v1.5"
        )));
        let other_embed_models = ["mxbai-embed-large", "all-minilm"];
        for model in other_embed_models {
            assert!(!Embedder::model_requires_nomic_prefix(model));
        }
    }

    /// Client aimed at loopback port 1, presumably closed on test hosts, so
    /// any embed call through it fails without reaching a real service.
    fn offline_openai_compatible_client() -> Arc<crate::llm::OllamaClient> {
        Arc::new(
            crate::llm::OllamaClient::new_openai_compatible(
                "http://127.0.0.1:1",
                "test-embed-model",
                "",
            )
            .expect("client builds without network"),
        )
    }

    #[test]
    fn new_remote_carries_dynamic_dim_and_truthful_description_1598() {
        let embedder = Embedder::new_remote(
            offline_openai_compatible_client(),
            "google/gemini-embedding-2".to_string(),
            3072,
        );
        assert_eq!(embedder.dim(), 3072);
        assert_eq!(
            embedder.model_description(),
            "google/gemini-embedding-2 (3072-dim, remote)"
        );
        assert!(!embedder.is_degraded());
    }

    #[test]
    fn new_ollama_preserves_nomic_defaults_1598() {
        let embedder = Embedder::new_ollama(offline_openai_compatible_client());
        assert_eq!(embedder.dim(), NOMIC_DIM);
        let desc = embedder.model_description();
        assert!(desc.contains(NOMIC_OLLAMA_MODEL), "desc: {desc}");
        assert!(desc.contains("768"), "desc: {desc}");
        assert!(!embedder.is_degraded());
    }

    #[test]
    fn remote_embed_failure_latches_degraded_flag_1598() {
        let embedder = Embedder::new_remote(
            offline_openai_compatible_client(),
            "test-embed-model".to_string(),
            8,
        );
        assert!(!embedder.is_degraded());
        let err = embedder.embed("hello");
        assert!(err.is_err(), "embed against a closed port must error");
        assert!(embedder.is_degraded());
    }

    #[test]
    fn local_embedder_is_never_degraded_via_trait_default_1598() {
        // Neither mock overrides `Embed::is_degraded`, so both must report the
        // trait default of `false`.
        let local = crate::embeddings::test_support::MockEmbedder::new_local()
            .expect("mock local embedder builds");
        let ollama = crate::embeddings::test_support::MockEmbedder::new_ollama();
        for mock in [&local, &ollama] {
            let as_trait: &dyn Embed = mock;
            assert!(!as_trait.is_degraded());
        }
    }

    #[test]
    fn from_resolved_keyword_tier_yields_none_1598() {
        let resolved = crate::config::ResolvedEmbeddings::from_parts(
            "openrouter".to_string(),
            "https://openrouter.ai/api/v1".to_string(),
            "google/gemini-embedding-2".to_string(),
            Some(3072),
            None,
        );
        let built = Embedder::from_resolved(&resolved, None).expect("keyword tier is Ok(None)");
        assert!(built.is_none());
    }

    #[test]
    fn from_resolved_api_backend_unknown_dim_bails_with_escape_hatch_1598() {
        let resolved = crate::config::ResolvedEmbeddings::from_parts(
            "openrouter".to_string(),
            "https://openrouter.ai/api/v1".to_string(),
            "some/unknown-embed-model".to_string(),
            None,
            None,
        );
        let result = Embedder::from_resolved(
            &resolved,
            Some(crate::config::EmbeddingModel::NomicEmbedV15),
        );
        let Err(err) = result else {
            panic!("unknown dim on an API backend must fail closed");
        };
        let msg = format!("{err:#}");
        assert!(msg.contains("dim"), "error must name the dim gap: {msg}");
        assert!(
            msg.contains("[embeddings].dim"),
            "error must name the config escape hatch: {msg}"
        );
        assert!(
            msg.contains(crate::config::ENV_EMBED_MODEL),
            "error must name the model env var: {msg}"
        );
    }

    #[test]
    fn from_resolved_api_backend_builds_remote_embedder_1598() {
        let resolved = crate::config::ResolvedEmbeddings::from_parts(
            "openrouter".to_string(),
            "https://openrouter.ai/api/v1".to_string(),
            "google/gemini-embedding-2".to_string(),
            Some(3072),
            None,
        );
        let built = Embedder::from_resolved(
            &resolved,
            Some(crate::config::EmbeddingModel::NomicEmbedV15),
        )
        .expect("API-backend construction succeeds")
        .expect("tier gates embeddings on");
        assert!(matches!(built, Embedder::Ollama { .. }));
        assert_eq!(built.dim(), 3072);
        assert_eq!(
            built.model_description(),
            "google/gemini-embedding-2 (3072-dim, remote)"
        );
    }

    #[test]
    fn nomic_prefix_gating_covers_hf_id_and_case_forms_1598() {
        assert!(Embedder::model_requires_nomic_prefix(
            "nomic-ai/nomic-embed-text-v1.5"
        ));
        assert!(Embedder::model_requires_nomic_prefix(
            "nomic-embed-text-v1.5"
        ));
        assert!(Embedder::model_requires_nomic_prefix(
            "Nomic-AI/Nomic-Embed-Text-v1.5"
        ));
        assert!(!Embedder::model_requires_nomic_prefix(
            "google/gemini-embedding-2"
        ));
        assert!(!Embedder::model_requires_nomic_prefix(
            "ibm-granite/granite-embedding-125m-english"
        ));
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn cosine_similarity_dimension_mismatch() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn cosine_similarity_checked_comparable_matches_plain_cosine() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![2.0, 1.0, 0.5];
        let plain = Embedder::cosine_similarity(&a, &b);
        match Embedder::cosine_similarity_checked(&a, &b) {
            CosineComparison::Comparable(c) => assert!((c - plain).abs() < 1e-6),
            CosineComparison::DimensionMismatch { .. } => {
                panic!("equal-length vectors must compare as Comparable")
            }
        }
    }

    #[test]
    fn cosine_similarity_checked_flags_dimension_mismatch() {
        let query = vec![0.0_f32; 5];
        let stored = vec![0.0_f32; 3];
        match Embedder::cosine_similarity_checked(&query, &stored) {
            CosineComparison::DimensionMismatch {
                query_dim,
                stored_dim,
            } => {
                assert_eq!(query_dim, 5);
                assert_eq!(stored_dim, 3);
            }
            CosineComparison::Comparable(_) => {
                panic!("differing-length vectors must report DimensionMismatch")
            }
        }
    }

    #[test]
    fn encode_embedding_blob_prefixes_le_header() {
        let v = vec![1.0_f32, 2.0_f32];
        let blob = encode_embedding_blob(&v);
        assert_eq!(blob.len(), 1 + 8);
        assert_eq!(blob[0], EMBEDDING_HEADER_LE_F32);
    }

    #[test]
    fn decode_embedding_blob_round_trip_v17() {
        let v = vec![1.5_f32, -0.25, 0.0];
        let blob = encode_embedding_blob(&v);
        let back = decode_embedding_blob(&blob).expect("round-trips");
        assert_eq!(back, v);
    }

    #[test]
    fn decode_embedding_blob_legacy_unheaded_le_f32() {
        let v = vec![1.0_f32, 2.0, 3.0];
        let raw: Vec<u8> = v.iter().flat_map(|f| f.to_le_bytes()).collect();
        let back = decode_embedding_blob(&raw).expect("legacy decodes");
        assert_eq!(back, v);
    }

    #[test]
    fn decode_embedding_blob_rejects_be_header() {
        let mut blob = vec![EMBEDDING_HEADER_BE_F32];
        blob.extend_from_slice(&1.0_f32.to_be_bytes());
        let err = decode_embedding_blob(&blob).expect_err("BE rejected");
        assert!(matches!(err, EmbeddingFormatError::BigEndianUnsupported));
    }

    #[test]
    fn decode_embedding_blob_rejects_unknown_header() {
        let mut blob = vec![0xff_u8];
        blob.extend_from_slice(&1.0_f32.to_le_bytes());
        let err = decode_embedding_blob(&blob).expect_err("unknown header rejected");
        assert!(matches!(err, EmbeddingFormatError::UnknownHeader(0xff)));
    }

    #[test]
    fn decode_embedding_blob_rejects_malformed_length() {
        let blob = vec![0u8; 6];
        let err = decode_embedding_blob(&blob).expect_err("malformed length rejected");
        assert!(matches!(err, EmbeddingFormatError::MalformedLength(6)));
    }

    #[test]
    fn decoded_dim_handles_all_three_paths() {
        assert_eq!(decoded_dim(&[]), 0);
        let raw: Vec<u8> = vec![0u8; 16];
        assert_eq!(decoded_dim(&raw), 4);
        let mut headed = vec![EMBEDDING_HEADER_LE_F32];
        headed.extend_from_slice(&[0u8; 12]);
        assert_eq!(decoded_dim(&headed), 3);
    }

    #[test]
    fn fuse_weighted_sum() {
        let p = vec![1.0, 0.0, 0.0];
        let s = vec![0.0, 1.0, 0.0];
        let f = Embedder::fuse(&p, &s, 0.7);
        assert!((f[0] - 0.7).abs() < 1e-6);
        assert!((f[1] - 0.3).abs() < 1e-6);
        assert!((f[2] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn fuse_primary_weight_clamped() {
        let p = vec![1.0, 1.0];
        let s = vec![0.0, 0.0];
        let f = Embedder::fuse(&p, &s, 2.0);
        assert!((f[0] - 1.0).abs() < 1e-6);
        assert!((f[1] - 1.0).abs() < 1e-6);
        let f = Embedder::fuse(&p, &s, -0.5);
        assert!((f[0] - 0.0).abs() < 1e-6);
        assert!((f[1] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn fuse_dimension_mismatch_returns_primary() {
        let p = vec![1.0, 2.0, 3.0];
        let s = vec![4.0, 5.0];
        let f = Embedder::fuse(&p, &s, 0.7);
        assert_eq!(f, p);
    }

    #[test]
    fn fuse_cosine_pulls_toward_context() {
        let q = vec![1.0_f32, 0.0];
        let ctx = vec![0.0_f32, 1.0];
        let fused = Embedder::fuse(&q, &ctx, 0.7);
        let sim_q = Embedder::cosine_similarity(&fused, &q);
        let sim_ctx = Embedder::cosine_similarity(&fused, &ctx);
        assert!(sim_q > sim_ctx);
        assert!(sim_q > 0.9);
        assert!(sim_ctx > 0.3);
    }

    #[test]
    fn test_fuse_with_weight_one_returns_primary() {
        let primary = vec![0.6_f32, -0.8, 0.0];
        let secondary = vec![0.0_f32, 0.0, 1.0];
        let fused = Embedder::fuse(&primary, &secondary, 1.0);
        assert_eq!(fused.len(), primary.len());
        for (i, (f, p)) in fused.iter().zip(primary.iter()).enumerate() {
            assert!(
                (f - p).abs() < 1e-6,
                "fuse weight=1 idx {i}: fused {} != primary {}",
                f,
                p
            );
        }
        let sim = Embedder::cosine_similarity(&fused, &primary);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "cos(fuse(p,s,1.0), p) must be 1.0"
        );
    }

    #[test]
    fn embed_status_as_str_each_variant() {
        assert_eq!(EmbedStatus::Indexed.as_str(), "indexed");
        assert_eq!(
            EmbedStatus::Skipped("too big".to_string()).as_str(),
            "skipped"
        );
        assert_eq!(
            EmbedStatus::Failed("ollama down".to_string()).as_str(),
            "failed"
        );
    }

    #[test]
    fn oversize_embed_reason_boundary_1595() {
        assert_eq!(oversize_embed_reason(0), None);
        assert_eq!(
            oversize_embed_reason(EMBED_MAX_BYTES),
            None,
            "cap itself is allowed"
        );
        let reason = oversize_embed_reason(EMBED_MAX_BYTES + 1).expect("over-cap must skip");
        assert!(
            reason.contains(&(EMBED_MAX_BYTES + 1).to_string())
                && reason.contains(&EMBED_MAX_BYTES.to_string()),
            "reason must name size + cap, got: {reason}"
        );
    }

    #[test]
    fn embed_status_is_degraded_only_for_non_indexed() {
        assert!(!EmbedStatus::Indexed.is_degraded());
        assert!(EmbedStatus::Skipped("x".to_string()).is_degraded());
        assert!(EmbedStatus::Failed("x".to_string()).is_degraded());
    }

    #[test]
    fn embed_status_reason_helper() {
        assert_eq!(EmbedStatus::Indexed.reason(), "");
        assert_eq!(EmbedStatus::Skipped("r1".to_string()).reason(), "r1");
        assert_eq!(EmbedStatus::Failed("r2".to_string()).reason(), "r2");
    }

    #[test]
    fn embed_status_display_includes_reason() {
        assert_eq!(format!("{}", EmbedStatus::Indexed), "indexed");
        assert_eq!(
            format!("{}", EmbedStatus::Skipped("oversize".to_string())),
            "skipped: oversize"
        );
        assert_eq!(
            format!("{}", EmbedStatus::Failed("timeout".to_string())),
            "failed: timeout"
        );
    }

    #[test]
    fn embedding_format_error_display_each_variant() {
        let unk = EmbeddingFormatError::UnknownHeader(0xab);
        assert!(unk.to_string().contains("0xab"));
        let be = EmbeddingFormatError::BigEndianUnsupported;
        assert!(be.to_string().contains("big-endian"));
        let ml = EmbeddingFormatError::MalformedLength(7);
        assert!(ml.to_string().contains("7"));
    }

    #[test]
    fn embedding_format_error_is_std_error() {
        let e: Box<dyn std::error::Error> = Box::new(EmbeddingFormatError::BigEndianUnsupported);
        assert!(e.source().is_none());
    }

    #[test]
    fn decode_embedding_blob_empty_returns_empty_vec() {
        let v = decode_embedding_blob(&[]).expect("empty decodes to empty");
        assert!(v.is_empty());
    }

    /// `fuse` returns the raw weighted sum; callers that need a unit vector
    /// must re-normalise. Cosine similarity is scale-invariant, so the raw and
    /// normalised forms rank identically.
    #[test]
    fn test_fuse_is_not_l2_normalized() {
        let primary = vec![3.0_f32, 0.0, 0.0];
        let secondary = vec![0.0_f32, 4.0, 0.0];
        let fused = Embedder::fuse(&primary, &secondary, 0.5);
        let norm = fused.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 2.5).abs() < 1e-5,
            "fuse currently returns un-normalized vec; norm should be 2.5, got {norm}"
        );
        let normalized: Vec<f32> = fused.iter().map(|x| x / norm).collect();
        let renorm = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (renorm - 1.0).abs() < 1e-5,
            "renormalized fused must have unit norm, got {renorm}"
        );
        let sim = Embedder::cosine_similarity(&fused, &normalized);
        assert!(
            (sim - 1.0).abs() < 1e-5,
            "cos(raw_fuse, normalize(raw_fuse)) must be 1.0, got {sim}"
        );
    }
}
// Test-only doubles for the `Embed` trait.
// The clippy allows below are presumably needed because the mock methods
// mirror `Embedder`'s signatures (`&self` receivers, `Result` returns) even
// where the mock itself does not need them.
#[cfg(test)]
#[allow(
clippy::unused_self,
clippy::unnecessary_wraps,
clippy::needless_pass_by_value,
clippy::wildcard_imports
)]
pub mod test_support {
use super::*;
/// Deterministic stand-in for `Embedder` that needs no model files or network.
pub enum MockEmbedder {
/// Produces `MINILM_DIM`-wide vectors.
Local,
/// Produces `NOMIC_DIM`-wide vectors.
Ollama,
}
impl MockEmbedder {
/// Always succeeds; returns `Result` only to mirror `Embedder::new_local`.
pub fn new_local() -> Result<Self> {
Ok(Self::Local)
}
pub fn new_ollama() -> Self {
Self::Ollama
}
/// Returns a vector that depends only on `text`'s bytes, so equal inputs
/// give equal outputs. Component `i` is `base + |sin(i * 0.0001)|`, where
/// `base` is a rolling hash of the bytes mapped into `0.0..=0.999`. Never fails.
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
let dim = match self {
Self::Local => MINILM_DIM,
Self::Ollama => NOMIC_DIM,
};
// Polynomial rolling hash (base 31); wrapping keeps it overflow-free.
let hash = text.bytes().fold(0u32, |acc, b| {
acc.wrapping_mul(31).wrapping_add(u32::from(b))
});
let base = ((hash % 1000) as f32) / 1000.0;
let embedding: Vec<f32> = (0..dim)
.map(|i| base + ((i as f32) * 0.0001).sin().abs())
.collect();
Ok(embedding)
}
/// Embeds each text in order with [`MockEmbedder::embed`].
pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
texts.iter().map(|t| self.embed(t)).collect()
}
/// Vector width matching the real backend this variant imitates.
pub fn dim(&self) -> usize {
match self {
Self::Local => MINILM_DIM,
Self::Ollama => NOMIC_DIM,
}
}
pub fn model_description(&self) -> &str {
match self {
Self::Local => "mock-all-MiniLM-L6-v2 (384-dim, local)",
Self::Ollama => "mock-nomic-embed-text-v1.5 (768-dim, Ollama)",
}
}
}
// Forwards to the inherent methods; `is_degraded` keeps the trait default.
impl Embed for MockEmbedder {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
Self::embed(self, text)
}
fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
Self::embed_batch(self, texts)
}
}
/// An `Embed` implementation whose every call fails, for error-path tests.
pub struct FailingEmbedder;
impl Embed for FailingEmbedder {
fn embed(&self, _text: &str) -> Result<Vec<f32>> {
Err(anyhow::anyhow!("test: synthetic embed failure"))
}
fn embed_batch(&self, _texts: &[&str]) -> Result<Vec<Vec<f32>>> {
Err(anyhow::anyhow!("test: synthetic embed_batch failure"))
}
}
}
/// Unit tests for the deterministic `MockEmbedder` test double.
#[cfg(test)]
mod mock_tests {
    use super::test_support::*;
    use super::*;

    #[test]
    fn mock_local_new() {
        let embedder = MockEmbedder::new_local();
        assert!(embedder.is_ok());
    }

    #[test]
    fn mock_ollama_new() {
        let embedder = MockEmbedder::new_ollama();
        assert!(
            matches!(embedder, MockEmbedder::Ollama),
            "expected Ollama variant"
        );
    }

    #[test]
    fn mock_local_dim() {
        let embedder = MockEmbedder::new_local().unwrap();
        assert_eq!(embedder.dim(), MINILM_DIM);
    }

    #[test]
    fn mock_ollama_dim() {
        let embedder = MockEmbedder::new_ollama();
        assert_eq!(embedder.dim(), NOMIC_DIM);
    }

    #[test]
    fn mock_embed_local_deterministic() {
        let embedder = MockEmbedder::new_local().unwrap();
        let e1 = embedder.embed("test").unwrap();
        let e2 = embedder.embed("test").unwrap();
        assert_eq!(e1, e2);
    }

    #[test]
    fn mock_embed_local_dimension() {
        let embedder = MockEmbedder::new_local().unwrap();
        let embedding = embedder.embed("hello world").unwrap();
        assert_eq!(embedding.len(), MINILM_DIM);
    }

    #[test]
    fn mock_embed_ollama_dimension() {
        let embedder = MockEmbedder::new_ollama();
        let embedding = embedder.embed("hello world").unwrap();
        assert_eq!(embedding.len(), NOMIC_DIM);
    }

    #[test]
    fn mock_embed_batch_local() {
        let embedder = MockEmbedder::new_local().unwrap();
        // A fixed-size array coerces to `&[&str]`; no `Vec` allocation needed.
        let texts = ["text1", "text2", "text3"];
        let embeddings = embedder.embed_batch(&texts).unwrap();
        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), MINILM_DIM);
        }
    }

    #[test]
    fn mock_embed_batch_ollama() {
        let embedder = MockEmbedder::new_ollama();
        let texts = ["text1", "text2"];
        let embeddings = embedder.embed_batch(&texts).unwrap();
        assert_eq!(embeddings.len(), 2);
        for emb in &embeddings {
            assert_eq!(emb.len(), NOMIC_DIM);
        }
    }

    #[test]
    fn mock_local_model_description() {
        let embedder = MockEmbedder::new_local().unwrap();
        let desc = embedder.model_description();
        assert!(desc.contains("MiniLM"));
        assert!(desc.contains("384"));
    }

    #[test]
    fn mock_ollama_model_description() {
        let embedder = MockEmbedder::new_ollama();
        let desc = embedder.model_description();
        assert!(desc.contains("nomic"));
        assert!(desc.contains("768"));
    }

    #[test]
    fn mock_embed_different_texts_different_vectors() {
        let embedder = MockEmbedder::new_local().unwrap();
        let e1 = embedder.embed("text one").unwrap();
        let e2 = embedder.embed("text two").unwrap();
        // Element 0 equals the hash-derived base offset, so differing first
        // elements show the two inputs hashed to different bases.
        assert_ne!(e1[0], e2[0]);
    }
}
/// Pins `Embedder::cosine_similarity` to the closed-form value for two small
/// vectors.
///
/// NOTE(review): the name is misleading. Nothing here touches a cache or
/// eviction; this is a cosine-similarity check. Consider renaming it (e.g.
/// `cosine_similarity_matches_closed_form`) once no test filters rely on it.
#[test]
fn cache_evicts_least_recently_used() {
    let lhs = vec![1.0, 2.0, 3.0];
    let rhs = vec![4.0, 5.0, 6.0];
    // dot = 1*4 + 2*5 + 3*6 = 32; |lhs|^2 = 14; |rhs|^2 = 77.
    let norms = 14.0_f32.sqrt() * 77.0_f32.sqrt();
    let expected = 32.0 / norms;
    let actual = Embedder::cosine_similarity(&lhs, &rhs);
    assert!((actual - expected).abs() < 1e-5);
}
/// Edge-case coverage for `Embedder` model dispatch, cosine similarity and
/// vector fusion.
#[cfg(test)]
mod w12h_extra_tests {
    use super::*;

    /// The nomic model is served through an Ollama client, so requesting it
    /// without one must fail with an error naming the backend.
    #[test]
    fn for_model_nomic_without_ollama_client_errors() {
        let Err(e) = Embedder::for_model(EmbeddingModel::NomicEmbedV15, None) else {
            panic!("expected NomicEmbedV15 without client to error");
        };
        let err = e.to_string();
        let names_backend = err.contains("Ollama") || err.contains("nomic");
        assert!(names_backend, "expected ollama error msg, got: {err}");
    }

    /// Two all-zero vectors have no direction; similarity is defined as 0.
    #[test]
    fn cosine_similarity_both_zero_returns_zero() {
        let lhs = vec![0.0_f32; 3];
        let rhs = vec![0.0_f32; 3];
        assert_eq!(Embedder::cosine_similarity(&lhs, &rhs), 0.0);
    }

    /// Exactly opposite vectors score -1.
    #[test]
    fn cosine_similarity_negative_values() {
        let lhs = vec![1.0_f32, 2.0, 3.0];
        let rhs = vec![-1.0_f32, -2.0, -3.0];
        let sim = Embedder::cosine_similarity(&lhs, &rhs);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    /// Empty inputs score 0 rather than NaN.
    #[test]
    fn cosine_similarity_empty_vectors() {
        let lhs: Vec<f32> = Vec::new();
        let rhs: Vec<f32> = Vec::new();
        assert_eq!(Embedder::cosine_similarity(&lhs, &rhs), 0.0);
    }

    /// A primary weight of 0 yields the secondary vector.
    #[test]
    fn fuse_zero_weight_returns_pure_secondary() {
        let primary = vec![1.0_f32, 0.0];
        let secondary = vec![0.0_f32, 1.0];
        let fused = Embedder::fuse(&primary, &secondary, 0.0);
        assert!(fused[0].abs() < 1e-6);
        assert!((fused[1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn fuse_empty_vectors_returns_empty() {
        let primary: Vec<f32> = Vec::new();
        let secondary: Vec<f32> = Vec::new();
        assert!(Embedder::fuse(&primary, &secondary, 0.5).is_empty());
    }

    /// Guards the published dimensionality constants against accidental edits.
    #[test]
    fn embedding_dim_constant_pinned() {
        assert_eq!(EMBEDDING_DIM, MINILM_DIM);
        assert_eq!(MINILM_DIM, 384);
        assert_eq!(NOMIC_DIM, 768);
    }

    /// On a length mismatch `fuse` returns the primary vector unchanged.
    #[test]
    fn fuse_dimension_mismatch_secondary_longer() {
        let primary = vec![1.0_f32, 2.0];
        let secondary = vec![3.0_f32, 4.0, 5.0];
        let fused = Embedder::fuse(&primary, &secondary, 0.5);
        assert_eq!(fused, primary);
    }

    /// On a length mismatch cosine similarity is 0.
    #[test]
    fn cosine_similarity_dimension_mismatch_inverse() {
        let shorter = vec![1.0_f32, 0.0];
        let longer = vec![1.0_f32, 0.0, 0.0];
        assert_eq!(Embedder::cosine_similarity(&shorter, &longer), 0.0);
    }

    /// `for_model(MiniLmL6V2, None)` builds the local 384-dim embedder. When
    /// the model cannot be loaded in this environment, the error must point
    /// at the model/config/tokenizer/fallback/HuggingFace stage.
    #[test]
    fn pr9i_for_model_minilm_dispatches_to_new_local() {
        match Embedder::for_model(EmbeddingModel::MiniLmL6V2, None) {
            Ok(embedder) => {
                assert_eq!(embedder.dim(), 384);
                assert!(embedder.model_description().contains("MiniLM"));
            }
            Err(e) => {
                let msg = e.to_string();
                let accepted = ["model", "config", "tokenizer", "fallback", "HuggingFace"];
                assert!(
                    accepted.iter().any(|&needle| msg.contains(needle)),
                    "unexpected new_local error: {msg}"
                );
            }
        }
    }

    /// `Embedder::new` behaves like the local constructor: 384 dims on success,
    /// and a non-empty error message otherwise.
    #[test]
    fn pr9i_embedder_new_alias_is_new_local() {
        match Embedder::new() {
            Ok(embedder) => assert_eq!(embedder.dim(), 384),
            Err(e) => assert!(!e.to_string().is_empty()),
        }
    }
}
/// Lock shared by every test in this group that reads or mutates the
/// process-global `HOME` / `AI_MEMORY_EMBED_OFFLINE` variables.
///
/// It must be one item. A `static` declared inside each test body is a
/// separate mutex per test and serializes nothing, so the env-mutating tests
/// could race on `HOME` under the parallel test harness.
#[cfg(test)]
static FALLBACK_ENV_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Acquires [`FALLBACK_ENV_TEST_LOCK`]. Poisoning is ignored so that one
/// failed test does not cascade into the others.
#[cfg(test)]
fn lock_fallback_env() -> std::sync::MutexGuard<'static, ()> {
    FALLBACK_ENV_TEST_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Overrides environment variables while alive and restores the previous
/// values (or removes the variables) on drop, including during a panic.
///
/// It owns the shared env lock. `Drop::drop` runs before the `_lock` field is
/// dropped, so restoration also happens while the lock is held.
#[cfg(test)]
#[must_use = "dropping the override immediately restores the environment"]
struct FallbackEnvOverride {
    saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
    _lock: std::sync::MutexGuard<'static, ()>,
}

#[cfg(test)]
impl FallbackEnvOverride {
    /// Takes the shared lock, then sets each `(key, value)` pair in order.
    fn set(vars: &[(&'static str, &std::ffi::OsStr)]) -> Self {
        let lock = lock_fallback_env();
        let mut saved = Vec::with_capacity(vars.len());
        for &(key, value) in vars {
            // `var_os` (not `var`) so a non-UTF-8 previous value is restored
            // instead of being treated as unset and removed.
            saved.push((key, std::env::var_os(key)));
            // SAFETY: the shared lock is held, so no other test in this group
            // touches these variables concurrently. NOTE(review): tests outside
            // this group that read `HOME` are not serialized by this lock.
            unsafe { std::env::set_var(key, value) };
        }
        Self { saved, _lock: lock }
    }
}

#[cfg(test)]
impl Drop for FallbackEnvOverride {
    fn drop(&mut self) {
        // Reverse order, so a key set twice ends at its original value.
        for (key, previous) in self.saved.drain(..).rev() {
            // SAFETY: same invariant as in `set`; `_lock` is still held here.
            unsafe {
                match previous {
                    Some(value) => std::env::set_var(key, value),
                    None => std::env::remove_var(key),
                }
            }
        }
    }
}

/// Runs `load_from_fallback` against the real `HOME`. Success is accepted (a
/// cached model may exist); a failure must mention the missing files or the
/// fallback location.
#[test]
fn embedder_returns_unreachable_when_model_path_missing() {
    // Hold the lock so `HOME` is not swapped by a sibling test mid-call.
    let _lock = lock_fallback_env();
    if let Err(e) = Embedder::load_from_fallback() {
        let err_msg = e.to_string();
        assert!(
            err_msg.contains("not found") || err_msg.contains("fallback"),
            "error should mention missing model files: {err_msg}"
        );
    }
}

/// With the three model files present under `$HOME/<fallback subdir>`,
/// `load_from_fallback` returns their paths. The placeholder contents show
/// that it only checks for presence.
#[test]
fn load_from_fallback_succeeds_when_files_present() {
    let tmp = std::env::temp_dir().join(format!(
        "ai-memory-w12h-fallback-{}-{}",
        std::process::id(),
        uuid::Uuid::new_v4()
    ));
    let model_dir = tmp.join(
        ".cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/main",
    );
    std::fs::create_dir_all(&model_dir).expect("mk model dir");
    for name in ["config.json", "tokenizer.json", "model.safetensors"] {
        std::fs::write(model_dir.join(name), b"{}").expect("write placeholder");
    }
    let env = FallbackEnvOverride::set(&[("HOME", tmp.as_os_str())]);
    let result = Embedder::load_from_fallback();
    // Restore `HOME` and release the lock before touching the filesystem again.
    drop(env);
    let _ = std::fs::remove_dir_all(&tmp);
    let (cfg, tok, w) = result.expect("placeholder files satisfy load_from_fallback");
    assert!(cfg.ends_with("config.json"));
    assert!(tok.ends_with("tokenizer.json"));
    assert!(w.ends_with("model.safetensors"));
}

/// With `AI_MEMORY_EMBED_OFFLINE=1` and an empty `HOME`, `new_local` must not
/// go to the network. It must error, pointing at the fallback cache.
#[test]
fn offline_env_skips_network_and_errors_fast_on_empty_cache() {
    let tmp = std::env::temp_dir().join(format!(
        "ai-memory-1501-offline-{}-{}",
        std::process::id(),
        uuid::Uuid::new_v4()
    ));
    std::fs::create_dir_all(&tmp).expect("mk empty home");
    let env = FallbackEnvOverride::set(&[
        ("HOME", tmp.as_os_str()),
        ("AI_MEMORY_EMBED_OFFLINE", std::ffi::OsStr::new("1")),
    ]);
    assert!(
        Embedder::remote_fetch_disabled(),
        "offline knob must be honored"
    );
    let result = Embedder::new_local();
    drop(env);
    let _ = std::fs::remove_dir_all(&tmp);
    let msg = match result {
        Ok(_) => panic!("empty cache + offline must error (degrades to keyword)"),
        Err(e) => e.to_string(),
    };
    assert!(
        msg.contains("not found") || msg.contains("fallback"),
        "offline empty-cache error should point at the fallback dir: {msg}"
    );
}
/// Tests for the Ollama-backed `Embedder` variant, with a `wiremock` server
/// standing in for the Ollama HTTP API. Also covers the `Embed` trait and
/// `download_within`.
#[cfg(test)]
#[allow(clippy::too_many_lines)]
mod c5_ollama_variant_tests {
    use super::*;
    use crate::llm::OllamaClient;
    use serde_json::json;
    use std::sync::Arc;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a mock Ollama server and an `OllamaClient` pointed at it.
    ///
    /// Stubs:
    /// - `GET /api/tags` returns an empty model list.
    /// - `POST /api/pull` returns an empty 200.
    /// - `POST /api/embed` returns one embedding of `embedding_dim` floats (`i * 0.001`).
    ///
    /// Callers bind the returned server (`_server`) so the mocks stay up for
    /// the whole test.
    async fn ollama_with_embed_response(embedding_dim: usize) -> (Arc<OllamaClient>, MockServer) {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/api/pull"))
            .respond_with(ResponseTemplate::new(200).set_body_string(""))
            .mount(&server)
            .await;
        let vec_of_floats: Vec<f32> = (0..embedding_dim).map(|i| (i as f32) * 0.001).collect();
        Mock::given(method("POST"))
            .and(path("/api/embed"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "embeddings": [vec_of_floats],
            })))
            .mount(&server)
            .await;
        let uri = server.uri();
        // Built on a blocking thread. Presumably `new_with_url` does synchronous
        // HTTP (e.g. probing `/api/tags`), which must stay off the async runtime.
        // TODO: confirm against `OllamaClient`.
        let client = tokio::task::spawn_blocking(move || {
            OllamaClient::new_with_url(&uri, "test-model").expect("ollama client builds")
        })
        .await
        .expect("spawn blocking completes");
        (Arc::new(client), server)
    }

    // NOTE(review): despite its name, this only checks the variant, not the model name.
    #[tokio::test(flavor = "multi_thread")]
    async fn embedder_new_ollama_constructs_with_expected_model_name() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);
        assert!(matches!(embedder, Embedder::Ollama { .. }));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn embedder_for_model_nomic_with_client_succeeds() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = tokio::task::spawn_blocking(move || {
            Embedder::for_model(EmbeddingModel::NomicEmbedV15, Some(client))
                .expect("for_model NomicEmbedV15 with ollama client")
        })
        .await
        .unwrap();
        assert!(matches!(embedder, Embedder::Ollama { .. }));
        assert_eq!(embedder.dim(), NOMIC_DIM);
        let desc = embedder.model_description();
        assert!(desc.contains("nomic"));
        assert!(desc.contains("768"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn embedder_ollama_embed_returns_vector_from_wiremock() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);
        let v = tokio::task::spawn_blocking(move || embedder.embed("hello"))
            .await
            .unwrap()
            .expect("embed_text via wiremock");
        assert_eq!(v.len(), NOMIC_DIM);
    }

    /// Empty content is rejected before any request (called directly on the
    /// async thread).
    #[tokio::test(flavor = "multi_thread")]
    async fn embed_with_status_skipped_on_empty_content() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);
        let (vec_opt, status) = embedder.embed_with_status("");
        assert!(vec_opt.is_none());
        assert!(matches!(status, EmbedStatus::Skipped(_)));
        assert_eq!(status.as_str(), "skipped");
        assert!(status.reason().contains("empty"));
    }

    /// Content above `EMBED_MAX_BYTES` is skipped with an explanatory reason.
    #[tokio::test(flavor = "multi_thread")]
    async fn embed_with_status_skipped_on_oversized_content() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);
        let big = "a".repeat(EMBED_MAX_BYTES + 1);
        let (vec_opt, status) = embedder.embed_with_status(&big);
        assert!(vec_opt.is_none());
        match status {
            EmbedStatus::Skipped(r) => {
                assert!(r.contains("exceeds embed cap"), "got: {r}");
            }
            other => panic!("expected Skipped, got: {other:?}"),
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn embed_with_status_indexed_on_happy_path() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);
        let (vec_opt, status) =
            tokio::task::spawn_blocking(move || embedder.embed_with_status("hello world"))
                .await
                .unwrap();
        assert!(vec_opt.is_some());
        assert_eq!(status, EmbedStatus::Indexed);
        assert!(!status.is_degraded());
        assert_eq!(vec_opt.unwrap().len(), NOMIC_DIM);
    }

    /// An HTTP 500 from `/api/embed` surfaces as `Failed` with a non-empty reason.
    #[tokio::test(flavor = "multi_thread")]
    async fn embed_with_status_failed_when_embedder_errors() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/api/embed"))
            .respond_with(ResponseTemplate::new(500).set_body_string("server error"))
            .mount(&server)
            .await;
        let uri = server.uri();
        let embedder = tokio::task::spawn_blocking(move || {
            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
            Embedder::new_ollama(Arc::new(client))
        })
        .await
        .unwrap();
        let (vec_opt, status) =
            tokio::task::spawn_blocking(move || embedder.embed_with_status("hello"))
                .await
                .unwrap();
        assert!(vec_opt.is_none());
        match status {
            EmbedStatus::Failed(reason) => {
                assert!(!reason.is_empty());
            }
            other => panic!("expected Failed(_), got {other:?}"),
        }
    }

    // NOTE(review): this exercises `MockEmbedder::embed_batch`, not the real
    // `Embedder`, so it does not cover the production empty-batch path.
    #[test]
    fn perf_5_embed_batch_empty_input_returns_empty_vec() {
        use super::test_support::MockEmbedder;
        let mock = MockEmbedder::new_local().expect("mock local");
        let result = mock.embed_batch(&[]).expect("empty batch ok");
        assert!(
            result.is_empty(),
            "PERF-5: empty input must yield empty output (got {} rows)",
            result.len(),
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn embed_batch_via_inherent_impl_returns_one_vec_per_input() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);
        let vecs =
            tokio::task::spawn_blocking(move || embedder.embed_batch(&["one", "two", "three"]))
                .await
                .unwrap()
                .expect("batch embed succeeds");
        assert_eq!(vecs.len(), 3);
        for v in &vecs {
            assert_eq!(v.len(), NOMIC_DIM);
        }
    }

    /// Both `Embed` trait methods work through dynamic dispatch on `Embedder`.
    #[tokio::test(flavor = "multi_thread")]
    async fn embed_trait_for_embedder_delegates_to_inherent_impl() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder: Box<dyn Embed> = Box::new(Embedder::new_ollama(client));
        let (single, batch) = tokio::task::spawn_blocking(move || {
            let single = embedder.embed("alpha").expect("single embed");
            let batch = embedder.embed_batch(&["beta", "gamma"]).expect("batch embed");
            (single, batch)
        })
        .await
        .unwrap();
        assert_eq!(single.len(), NOMIC_DIM);
        assert_eq!(batch.len(), 2);
        for v in &batch {
            assert_eq!(v.len(), NOMIC_DIM);
        }
    }

    /// An implementor that only provides `embed` gets the trait's default
    /// `embed_batch`, which calls `embed` once per input.
    #[test]
    fn embed_trait_default_batch_default_impl_runs_for_external_impls() {
        struct ConstEmbedder;
        impl Embed for ConstEmbedder {
            fn embed(&self, _text: &str) -> Result<Vec<f32>> {
                Ok(vec![1.0_f32, 2.0_f32, 3.0_f32])
            }
        }
        let e = ConstEmbedder;
        let batch = e.embed_batch(&["a", "b"]).expect("default batch path");
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0], vec![1.0_f32, 2.0_f32, 3.0_f32]);
        assert_eq!(batch[1], vec![1.0_f32, 2.0_f32, 3.0_f32]);
    }

    /// The watchdog returns an error once the budget elapses instead of
    /// waiting for the stalled closure.
    #[test]
    fn download_within_times_out_on_stalled_closure() {
        let start = std::time::Instant::now();
        let res = Embedder::download_within(std::time::Duration::from_millis(50), || {
            std::thread::sleep(std::time::Duration::from_secs(30));
            Ok((
                std::path::PathBuf::new(),
                std::path::PathBuf::new(),
                std::path::PathBuf::new(),
            ))
        });
        let elapsed = start.elapsed();
        assert!(res.is_err(), "stalled download must error, not hang");
        assert!(
            res.unwrap_err().to_string().contains("budget"),
            "error should explain the timeout budget"
        );
        assert!(
            elapsed < std::time::Duration::from_secs(5),
            "watchdog must return promptly after the budget, not wait for the closure: {elapsed:?}"
        );
    }

    /// A closure that finishes within budget has its three paths returned
    /// unchanged and in order.
    #[test]
    fn download_within_passes_through_fast_result() {
        let res = Embedder::download_within(std::time::Duration::from_secs(5), || {
            Ok((
                std::path::PathBuf::from("config.json"),
                std::path::PathBuf::from("tokenizer.json"),
                std::path::PathBuf::from("model.safetensors"),
            ))
        })
        .expect("fast closure must pass through");
        assert_eq!(res.0, std::path::PathBuf::from("config.json"));
        assert_eq!(res.1, std::path::PathBuf::from("tokenizer.json"));
        assert_eq!(res.2, std::path::PathBuf::from("model.safetensors"));
    }
}