use crate::types::embedding::EmbeddingProvider;
use crate::{MemvidError, Result};
use ndarray::Array;
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Tensor;
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use tokenizers::tokenizer::{Tokenizer, TruncationParams};
use tokenizers::{
PaddingDirection, PaddingParams, PaddingStrategy, TruncationDirection, TruncationStrategy,
};
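// ONNX Runtime's native code can write diagnostics directly to fd 2,
// bypassing Rust's logging. On macOS we temporarily redirect stderr to
// /dev/null around model loading and inference; elsewhere the suppressor
// is a no-op.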
#[cfg(target_os = "macos")]
mod stderr_suppress {
    use std::fs::{File, OpenOptions};
    use std::io;
    use std::os::unix::io::{AsRawFd, RawFd};
pub struct StderrSuppressor {
original_stderr: RawFd,
dev_null: File,
}
impl StderrSuppressor {
pub fn new() -> io::Result<Self> {
            // Open /dev/null for writing: dup2'ing a read-only fd onto
            // stderr would make later writes fail with EBADF instead of
            // being silently discarded.
            let dev_null = OpenOptions::new().write(true).open("/dev/null")?;
let original_stderr = unsafe { libc::dup(2) };
if original_stderr == -1 {
return Err(io::Error::last_os_error());
}
let result = unsafe { libc::dup2(dev_null.as_raw_fd(), 2) };
if result == -1 {
unsafe { libc::close(original_stderr) };
return Err(io::Error::last_os_error());
}
Ok(Self {
original_stderr,
dev_null,
})
}
}
impl Drop for StderrSuppressor {
fn drop(&mut self) {
unsafe {
libc::dup2(self.original_stderr, 2);
libc::close(self.original_stderr);
}
            // Touch the field so the compiler doesn't flag `dev_null` as
            // unused; the fd itself is closed when `self` is dropped.
            let _ = &self.dev_null;
}
}
}
#[cfg(not(target_os = "macos"))]
mod stderr_suppress {
pub struct StderrSuppressor;
impl StderrSuppressor {
pub fn new() -> std::io::Result<Self> {
Ok(Self)
}
}
}
use once_cell::sync::Lazy;
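// Force ort's global environment to initialize exactly once, with stderr
// suppressed so any startup diagnostics from the native runtime are
// discarded rather than printed.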
static ORT_INIT: Lazy<()> = Lazy::new(|| {
let _stderr_guard = stderr_suppress::StderrSuppressor::new().ok();
let _ = Session::builder();
tracing::debug!("ONNX Runtime global environment initialized");
});
fn ensure_ort_init() {
Lazy::force(&ORT_INIT);
}
/// Maximum token sequence length; inputs are padded/truncated to this size.
const MAX_SEQUENCE_LENGTH: usize = 512;
/// Idle time after which `maybe_unload` releases the session and tokenizer.
pub const MODEL_UNLOAD_TIMEOUT: Duration = Duration::from_secs(300);
/// Default number of embeddings kept in the LRU cache.
const DEFAULT_CACHE_CAPACITY: usize = 1000;
#[derive(Debug, Clone)]
pub struct TextEmbedModelInfo {
pub name: &'static str,
pub model_url: &'static str,
pub tokenizer_url: &'static str,
pub dims: u32,
pub max_tokens: usize,
pub is_default: bool,
}
pub static TEXT_EMBED_MODELS: &[TextEmbedModelInfo] = &[
TextEmbedModelInfo {
name: "bge-small-en-v1.5",
model_url: "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx",
tokenizer_url: "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/tokenizer.json",
dims: 384,
max_tokens: 512,
is_default: true,
},
TextEmbedModelInfo {
name: "bge-base-en-v1.5",
model_url: "https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/onnx/model.onnx",
tokenizer_url: "https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/tokenizer.json",
dims: 768,
max_tokens: 512,
is_default: false,
},
TextEmbedModelInfo {
name: "nomic-embed-text-v1.5",
model_url: "https://huggingface.co/nomic-ai/nomic-embed-text-v1.5/resolve/main/onnx/model.onnx",
tokenizer_url: "https://huggingface.co/nomic-ai/nomic-embed-text-v1.5/resolve/main/tokenizer.json",
dims: 768,
max_tokens: 512,
is_default: false,
},
TextEmbedModelInfo {
name: "gte-large",
model_url: "https://huggingface.co/thenlper/gte-large/resolve/main/onnx/model.onnx",
tokenizer_url: "https://huggingface.co/thenlper/gte-large/resolve/main/tokenizer.json",
dims: 1024,
max_tokens: 512,
is_default: false,
},
];
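/// Looks up a model entry by name, silently falling back to the default
/// model when the name is unknown.
///
/// ```ignore
/// // Sketch; assumes these functions are re-exported at the crate root.
/// assert_eq!(get_text_model_info("gte-large").dims, 1024);
/// // Unknown names fall back to bge-small-en-v1.5:
/// assert_eq!(get_text_model_info("no-such-model").dims, 384);
/// ```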
#[must_use]
pub fn get_text_model_info(name: &str) -> &'static TextEmbedModelInfo {
TEXT_EMBED_MODELS
.iter()
.find(|m| m.name == name)
        .unwrap_or_else(default_text_model_info)
}
#[must_use]
pub fn default_text_model_info() -> &'static TextEmbedModelInfo {
TEXT_EMBED_MODELS
.iter()
.find(|m| m.is_default)
.expect("No default text embedding model configured")
}
#[derive(Debug, Clone)]
pub struct TextEmbedConfig {
pub model_name: String,
pub models_dir: PathBuf,
pub offline: bool,
pub enable_cache: bool,
pub cache_capacity: usize,
}
impl Default for TextEmbedConfig {
fn default() -> Self {
        let models_dir = dirs_next::cache_dir()
            .map(|p| p.join("memvid").join("text-models"))
            .unwrap_or_else(|| PathBuf::from(".memvid-cache/text-models"));
Self {
model_name: default_text_model_info().name.to_string(),
models_dir,
            offline: true,
            enable_cache: true,
            cache_capacity: DEFAULT_CACHE_CAPACITY,
}
}
}
impl TextEmbedConfig {
#[must_use]
pub fn bge_small() -> Self {
Self {
model_name: "bge-small-en-v1.5".to_string(),
..Default::default()
}
}
#[must_use]
pub fn bge_base() -> Self {
Self {
model_name: "bge-base-en-v1.5".to_string(),
..Default::default()
}
}
#[must_use]
pub fn nomic() -> Self {
Self {
model_name: "nomic-embed-text-v1.5".to_string(),
..Default::default()
}
}
#[must_use]
pub fn gte_large() -> Self {
Self {
model_name: "gte-large".to_string(),
..Default::default()
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct CacheStats {
pub hits: usize,
pub misses: usize,
pub size: usize,
pub capacity: usize,
}
impl CacheStats {
    /// Fraction of cache lookups that were hits; 0.0 when there have been
    /// no lookups yet.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
let total = self.hits + self.misses;
if total == 0 {
0.0
} else {
self.hits as f64 / total as f64
}
}
}
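/// Simple LRU cache keyed by a 64-bit hash of the input text.
///
/// Recency is tracked with a `VecDeque`, so touching an entry is O(n) in
/// the number of cached items; that is acceptable at the default capacity
/// of 1000 but would not scale to much larger caches.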
struct EmbeddingCache {
cache: HashMap<u64, Vec<f32>>,
lru_queue: VecDeque<u64>,
capacity: usize,
hits: usize,
misses: usize,
}
impl EmbeddingCache {
fn new(capacity: usize) -> Self {
Self {
cache: HashMap::with_capacity(capacity),
lru_queue: VecDeque::with_capacity(capacity),
capacity,
hits: 0,
misses: 0,
}
}
fn get(&mut self, key: u64) -> Option<Vec<f32>> {
if let Some(embedding) = self.cache.get(&key) {
self.lru_queue.retain(|&k| k != key);
self.lru_queue.push_front(key);
self.hits += 1;
Some(embedding.clone())
} else {
self.misses += 1;
None
}
}
fn insert(&mut self, key: u64, value: Vec<f32>) {
if self.cache.contains_key(&key) {
self.cache.insert(key, value);
self.lru_queue.retain(|&k| k != key);
self.lru_queue.push_front(key);
return;
}
if self.cache.len() >= self.capacity {
if let Some(oldest_key) = self.lru_queue.pop_back() {
self.cache.remove(&oldest_key);
}
}
self.cache.insert(key, value);
self.lru_queue.push_front(key);
}
fn clear(&mut self) {
self.cache.clear();
self.lru_queue.clear();
self.hits = 0;
self.misses = 0;
}
fn stats(&self) -> CacheStats {
CacheStats {
hits: self.hits,
misses: self.misses,
size: self.cache.len(),
capacity: self.capacity,
}
}
}
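/// Local ONNX text embedder with lazy model loading, an optional LRU
/// embedding cache, and idle-based unloading.
///
/// A minimal usage sketch (marked `ignore` because it requires the model
/// and tokenizer files to already exist under `models_dir`; the `memvid::`
/// paths are assumptions about this crate's public module layout):
///
/// ```ignore
/// use memvid::text_embed::{LocalTextEmbedder, TextEmbedConfig};
///
/// let embedder = LocalTextEmbedder::new(TextEmbedConfig::bge_small())?;
/// let vec = embedder.encode_text("hello world")?; // L2-normalized
/// assert_eq!(vec.len(), 384);
/// ```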
pub struct LocalTextEmbedder {
config: TextEmbedConfig,
model_info: &'static TextEmbedModelInfo,
session: Mutex<Option<Session>>,
tokenizer: Mutex<Option<Tokenizer>>,
last_used: Mutex<Instant>,
cache: Mutex<Option<EmbeddingCache>>,
}
impl LocalTextEmbedder {
pub fn new(config: TextEmbedConfig) -> Result<Self> {
let model_info = get_text_model_info(&config.model_name);
let cache = if config.enable_cache {
Some(EmbeddingCache::new(config.cache_capacity))
} else {
None
};
Ok(Self {
config,
model_info,
session: Mutex::new(None),
tokenizer: Mutex::new(None),
last_used: Mutex::new(Instant::now()),
cache: Mutex::new(cache),
})
}
#[must_use]
pub fn model_info(&self) -> &'static TextEmbedModelInfo {
self.model_info
}
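    /// Resolves `<models_dir>/<model-name>.onnx`, returning an error with
    /// manual download instructions if the file is missing. No network
    /// access is attempted.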
fn ensure_model_file(&self) -> Result<PathBuf> {
let filename = format!("{}.onnx", self.model_info.name);
let path = self.config.models_dir.join(&filename);
if path.exists() {
return Ok(path);
}
Err(MemvidError::EmbeddingFailed {
reason: format!(
"Text embedding model not found at {}. Please download manually:\n\
mkdir -p {}\n\
curl -L '{}' -o '{}'",
path.display(),
self.config.models_dir.display(),
self.model_info.model_url,
path.display()
)
.into(),
})
}
fn ensure_tokenizer_file(&self) -> Result<PathBuf> {
let filename = format!("{}_tokenizer.json", self.model_info.name);
let path = self.config.models_dir.join(&filename);
if path.exists() {
return Ok(path);
}
Err(MemvidError::EmbeddingFailed {
reason: format!(
"Tokenizer not found at {}. Please download manually:\n\
curl -L '{}' -o '{}'",
path.display(),
self.model_info.tokenizer_url,
path.display()
)
.into(),
})
}
fn load_session(&self) -> Result<()> {
ensure_ort_init();
let mut session_guard = self
.session
.lock()
.map_err(|_| MemvidError::Lock("Failed to lock text embed session".into()))?;
if session_guard.is_some() {
return Ok(());
}
let model_path = self.ensure_model_file()?;
tracing::debug!(path = %model_path.display(), "Loading text embedding model");
let _stderr_guard = stderr_suppress::StderrSuppressor::new().ok();
let session = Session::builder()
.map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Failed to create session builder: {}", e).into(),
})?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Failed to set optimization level: {}", e).into(),
})?
.with_intra_threads(4)
.map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Failed to set intra threads: {}", e).into(),
})?
.commit_from_file(&model_path)
.map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Failed to load text embedding model: {}", e).into(),
})?;
*session_guard = Some(session);
tracing::info!(model = %self.model_info.name, "Text embedding model loaded");
Ok(())
}
fn load_tokenizer(&self) -> Result<()> {
let mut tokenizer_guard = self
.tokenizer
.lock()
.map_err(|_| MemvidError::Lock("Failed to lock tokenizer".into()))?;
if tokenizer_guard.is_some() {
return Ok(());
}
let tokenizer_path = self.ensure_tokenizer_file()?;
tracing::debug!(path = %tokenizer_path.display(), "Loading tokenizer");
let mut tokenizer =
Tokenizer::from_file(&tokenizer_path).map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Failed to load tokenizer: {}", e).into(),
})?;
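        // Pad every input to a fixed MAX_SEQUENCE_LENGTH tokens so the
        // input tensors always have the same shape; pad_id 0 / "[PAD]"
        // matches the BERT-style vocabularies these models ship with.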
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::Fixed(MAX_SEQUENCE_LENGTH),
direction: PaddingDirection::Right,
pad_to_multiple_of: None,
pad_id: 0,
pad_type_id: 0,
pad_token: "[PAD]".to_string(),
}));
tokenizer
.with_truncation(Some(TruncationParams {
max_length: MAX_SEQUENCE_LENGTH,
strategy: TruncationStrategy::LongestFirst,
stride: 0,
direction: TruncationDirection::Right,
}))
.map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Failed to apply truncation config: {}", e).into(),
})?;
*tokenizer_guard = Some(tokenizer);
tracing::info!(model = %self.model_info.name, "Tokenizer loaded");
Ok(())
}
fn cache_key(text: &str) -> u64 {
let mut hasher = DefaultHasher::new();
text.hash(&mut hasher);
hasher.finish()
}
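    /// Embeds a single text: checks the cache, lazily loads the model and
    /// tokenizer, runs ONNX inference, pools the first (CLS) token, and
    /// returns the L2-normalized vector.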
pub fn encode_text(&self, text: &str) -> Result<Vec<f32>> {
if let Ok(mut cache_guard) = self.cache.lock() {
if let Some(ref mut cache) = *cache_guard {
let key = Self::cache_key(text);
if let Some(embedding) = cache.get(key) {
tracing::debug!(text_len = text.len(), "Cache hit");
return Ok(embedding);
}
tracing::debug!(text_len = text.len(), "Cache miss");
}
}
let _stderr_guard = stderr_suppress::StderrSuppressor::new().ok();
self.load_session()?;
self.load_tokenizer()?;
let encoding = {
let tokenizer_guard = self
.tokenizer
.lock()
.map_err(|_| MemvidError::Lock("Failed to lock tokenizer".into()))?;
let tokenizer =
tokenizer_guard
.as_ref()
.ok_or_else(|| MemvidError::EmbeddingFailed {
reason: "Tokenizer not loaded".into(),
})?;
tokenizer
.encode(text, true)
.map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Text tokenization failed: {}", e).into(),
})?
};
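        // Widen the u32 token ids to i64: BERT-style ONNX exports declare
        // their input tensors as int64.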
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|id| *id as i64).collect();
let attention_mask: Vec<i64> = encoding
.get_attention_mask()
.iter()
.map(|id| *id as i64)
.collect();
let token_type_ids: Vec<i64> = encoding
.get_type_ids()
.iter()
.map(|id| *id as i64)
.collect();
        // With fixed padding this is always MAX_SEQUENCE_LENGTH.
        let max_length = input_ids.len();
let input_ids_array = Array::from_shape_vec((1, max_length), input_ids).map_err(|e| {
MemvidError::EmbeddingFailed {
reason: format!("Failed to create input_ids array: {}", e).into(),
}
})?;
let attention_mask_array =
Array::from_shape_vec((1, max_length), attention_mask).map_err(|e| {
MemvidError::EmbeddingFailed {
reason: format!("Failed to create attention_mask array: {}", e).into(),
}
})?;
let token_type_ids_array =
Array::from_shape_vec((1, max_length), token_type_ids).map_err(|e| {
MemvidError::EmbeddingFailed {
reason: format!("Failed to create token_type_ids array: {}", e).into(),
}
})?;
if let Ok(mut last) = self.last_used.lock() {
*last = Instant::now();
}
let mut session_guard = self
.session
.lock()
.map_err(|_| MemvidError::Lock("Failed to lock session".into()))?;
let session = session_guard
.as_mut()
.ok_or_else(|| MemvidError::EmbeddingFailed {
reason: "Session not loaded".into(),
})?;
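        // Discover the model's declared input names at runtime so we can
        // feed three inputs (ids, mask, token type ids), two, or one,
        // depending on how the model was exported.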
let input_names: Vec<String> = session.inputs.iter().map(|i| i.name.clone()).collect();
let output_name = session
.outputs
.first()
.map(|o| o.name.clone())
.unwrap_or_else(|| "last_hidden_state".to_string());
let input_ids_tensor =
Tensor::from_array(input_ids_array).map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Failed to create input_ids tensor: {}", e).into(),
})?;
let attention_mask_tensor =
Tensor::from_array(attention_mask_array).map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Failed to create attention_mask tensor: {}", e).into(),
})?;
let token_type_ids_tensor =
Tensor::from_array(token_type_ids_array).map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Failed to create token_type_ids tensor: {}", e).into(),
})?;
        // The `_stderr_guard` bound at the top of this function is still
        // alive here, so native stderr output stays suppressed during run().
let outputs = if input_names.len() >= 3 {
session
.run(ort::inputs![
input_names[0].clone() => input_ids_tensor,
input_names[1].clone() => attention_mask_tensor,
input_names[2].clone() => token_type_ids_tensor
])
.map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Text inference failed: {}", e).into(),
})?
} else if input_names.len() >= 2 {
session
.run(ort::inputs![
input_names[0].clone() => input_ids_tensor,
input_names[1].clone() => attention_mask_tensor
])
.map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Text inference failed: {}", e).into(),
})?
} else {
let name = input_names
.first()
.cloned()
.unwrap_or_else(|| "input_ids".to_string());
session
.run(ort::inputs![name => input_ids_tensor])
.map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Text inference failed: {}", e).into(),
})?
};
let output = outputs
.get(&output_name)
.ok_or_else(|| MemvidError::EmbeddingFailed {
reason: format!("No output '{}' from model", output_name).into(),
})?;
let (_shape, data) =
output
.try_extract_tensor::<f32>()
.map_err(|e| MemvidError::EmbeddingFailed {
reason: format!("Failed to extract embeddings: {}", e).into(),
})?;
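        // The output has shape [1, seq_len, hidden]; taking the first
        // `dims` values of the flattened buffer selects the first (CLS)
        // token's hidden state, i.e. CLS pooling.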
let embedding_dim = self.model_info.dims as usize;
let embedding: Vec<f32> = data.iter().take(embedding_dim).copied().collect();
if embedding.iter().any(|v| !v.is_finite()) {
return Err(MemvidError::EmbeddingFailed {
reason: "Text embedding contains non-finite values".into(),
});
}
let normalized = l2_normalize(&embedding);
tracing::debug!(
text_len = text.len(),
dims = normalized.len(),
"Generated text embedding"
);
if let Ok(mut cache_guard) = self.cache.lock() {
if let Some(ref mut cache) = *cache_guard {
let key = Self::cache_key(text);
cache.insert(key, normalized.clone());
}
}
Ok(normalized)
}
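    /// Embeds each text sequentially via [`Self::encode_text`]; inputs are
    /// not batched into a single inference call.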
pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let mut embeddings = Vec::with_capacity(texts.len());
for text in texts {
embeddings.push(self.encode_text(text)?);
}
Ok(embeddings)
}
pub fn cache_stats(&self) -> Option<CacheStats> {
if let Ok(cache_guard) = self.cache.lock() {
cache_guard.as_ref().map(|cache| cache.stats())
} else {
None
}
}
pub fn clear_cache(&self) -> Result<()> {
if let Ok(mut cache_guard) = self.cache.lock() {
if let Some(ref mut cache) = *cache_guard {
cache.clear();
tracing::debug!("Embedding cache cleared");
}
}
Ok(())
}
pub fn is_loaded(&self) -> bool {
self.session.lock().map(|g| g.is_some()).unwrap_or(false)
}
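    /// Releases the session and tokenizer if the embedder has been idle
    /// for longer than [`MODEL_UNLOAD_TIMEOUT`]; intended to be called
    /// periodically by the owner.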
pub fn maybe_unload(&self) -> Result<()> {
let last_used = self
.last_used
.lock()
.map_err(|_| MemvidError::Lock("Failed to check last_used".into()))?;
if last_used.elapsed() > MODEL_UNLOAD_TIMEOUT {
tracing::debug!(model = %self.model_info.name, "Model idle, unloading");
if let Ok(mut guard) = self.session.lock() {
*guard = None;
}
if let Ok(mut guard) = self.tokenizer.lock() {
*guard = None;
}
}
Ok(())
}
pub fn unload(&self) -> Result<()> {
if let Ok(mut guard) = self.session.lock() {
*guard = None;
}
if let Ok(mut guard) = self.tokenizer.lock() {
*guard = None;
}
tracing::debug!(model = %self.model_info.name, "Text embedding model unloaded");
Ok(())
}
}
impl EmbeddingProvider for LocalTextEmbedder {
fn kind(&self) -> &str {
"local"
}
fn model(&self) -> &str {
self.model_info.name
}
fn dimension(&self) -> usize {
self.model_info.dims as usize
}
fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
self.encode_text(text)
}
fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
self.encode_batch(texts)
}
    fn is_ready(&self) -> bool {
        // Loading is lazy, so "ready" only means the embedder can be
        // called; it does not verify that the model files exist on disk.
        true
    }
fn init(&mut self) -> Result<()> {
Ok(())
}
}
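/// Scales `v` to unit L2 norm; returns an all-zero vector when the norm is
/// zero or non-finite, so callers never see NaNs.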
fn l2_normalize(v: &[f32]) -> Vec<f32> {
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm.is_finite() && norm > 1e-10 {
v.iter().map(|x| x / norm).collect()
} else {
vec![0.0; v.len()]
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_registry() {
assert_eq!(TEXT_EMBED_MODELS.len(), 4);
let default_model = default_text_model_info();
assert_eq!(default_model.name, "bge-small-en-v1.5");
assert_eq!(default_model.dims, 384);
assert!(default_model.is_default);
}
#[test]
fn test_get_model_info() {
let bge_small = get_text_model_info("bge-small-en-v1.5");
assert_eq!(bge_small.dims, 384);
let bge_base = get_text_model_info("bge-base-en-v1.5");
assert_eq!(bge_base.dims, 768);
let nomic = get_text_model_info("nomic-embed-text-v1.5");
assert_eq!(nomic.dims, 768);
let gte = get_text_model_info("gte-large");
assert_eq!(gte.dims, 1024);
let unknown = get_text_model_info("unknown-model");
assert_eq!(unknown.name, "bge-small-en-v1.5");
}
#[test]
fn test_config_defaults() {
let config = TextEmbedConfig::default();
assert_eq!(config.model_name, "bge-small-en-v1.5");
assert!(config.offline);
let bge_small = TextEmbedConfig::bge_small();
assert_eq!(bge_small.model_name, "bge-small-en-v1.5");
let bge_base = TextEmbedConfig::bge_base();
assert_eq!(bge_base.model_name, "bge-base-en-v1.5");
let nomic = TextEmbedConfig::nomic();
assert_eq!(nomic.model_name, "nomic-embed-text-v1.5");
let gte = TextEmbedConfig::gte_large();
assert_eq!(gte.model_name, "gte-large");
}
#[test]
fn test_l2_normalize() {
let v = vec![3.0, 4.0];
let normalized = l2_normalize(&v);
assert_eq!(normalized.len(), 2);
assert!((normalized[0] - 0.6).abs() < 1e-6);
assert!((normalized[1] - 0.8).abs() < 1e-6);
let zero = vec![0.0, 0.0];
let normalized_zero = l2_normalize(&zero);
assert_eq!(normalized_zero, vec![0.0, 0.0]);
}
#[test]
fn test_embed_provider_trait() {
let config = TextEmbedConfig::default();
let embedder = LocalTextEmbedder::new(config).unwrap();
assert_eq!(embedder.kind(), "local");
assert_eq!(embedder.model(), "bge-small-en-v1.5");
assert_eq!(embedder.dimension(), 384);
assert!(embedder.is_ready());
}
#[test]
fn test_cache_enabled_by_default() {
let config = TextEmbedConfig::default();
assert!(config.enable_cache);
assert_eq!(config.cache_capacity, 1000);
let embedder = LocalTextEmbedder::new(config).unwrap();
assert!(embedder.cache_stats().is_some());
}
#[test]
fn test_cache_can_be_disabled() {
let config = TextEmbedConfig {
enable_cache: false,
..Default::default()
};
let embedder = LocalTextEmbedder::new(config).unwrap();
assert!(embedder.cache_stats().is_none());
}
#[test]
fn test_cache_basic_operations() {
let mut cache = EmbeddingCache::new(10);
let stats = cache.stats();
assert_eq!(stats.hits, 0);
assert_eq!(stats.misses, 0);
assert_eq!(stats.size, 0);
cache.insert(1, vec![1.0, 2.0, 3.0]);
assert_eq!(cache.stats().size, 1);
let result = cache.get(1);
assert!(result.is_some());
assert_eq!(result.unwrap(), vec![1.0, 2.0, 3.0]);
assert_eq!(cache.stats().hits, 1);
assert_eq!(cache.stats().misses, 0);
let result = cache.get(999);
assert!(result.is_none());
assert_eq!(cache.stats().hits, 1);
assert_eq!(cache.stats().misses, 1);
}
#[test]
fn test_cache_lru_eviction() {
let mut cache = EmbeddingCache::new(3);
cache.insert(1, vec![1.0]);
cache.insert(2, vec![2.0]);
cache.insert(3, vec![3.0]);
assert_eq!(cache.stats().size, 3);
let _ = cache.get(1);
cache.insert(4, vec![4.0]);
assert_eq!(cache.stats().size, 3);
assert!(cache.get(1).is_some());
assert!(cache.get(3).is_some());
assert!(cache.get(2).is_none());
assert!(cache.get(4).is_some());
}
#[test]
fn test_cache_clear() {
let mut cache = EmbeddingCache::new(10);
cache.insert(1, vec![1.0]);
cache.insert(2, vec![2.0]);
        let _ = cache.get(1);
        let _ = cache.get(999);
assert_eq!(cache.stats().size, 2);
assert_eq!(cache.stats().hits, 1);
assert_eq!(cache.stats().misses, 1);
cache.clear();
assert_eq!(cache.stats().size, 0);
assert_eq!(cache.stats().hits, 0);
assert_eq!(cache.stats().misses, 0);
}
#[test]
fn test_cache_stats_hit_rate() {
let stats = CacheStats {
hits: 7,
misses: 3,
size: 5,
capacity: 10,
};
assert_eq!(stats.hit_rate(), 0.7);
let stats_zero = CacheStats {
hits: 0,
misses: 0,
size: 0,
capacity: 10,
};
assert_eq!(stats_zero.hit_rate(), 0.0);
}
#[test]
fn test_cache_key_consistency() {
let key1 = LocalTextEmbedder::cache_key("hello world");
let key2 = LocalTextEmbedder::cache_key("hello world");
assert_eq!(key1, key2);
let key3 = LocalTextEmbedder::cache_key("goodbye world");
assert_ne!(key1, key3);
}
#[test]
    #[ignore]
    fn test_cache_integration() {
let config = TextEmbedConfig {
enable_cache: true,
cache_capacity: 100,
..Default::default()
};
let embedder = LocalTextEmbedder::new(config).unwrap();
let text = "test embedding";
let _ = embedder.encode_text(text).unwrap();
let stats1 = embedder.cache_stats().unwrap();
assert_eq!(stats1.misses, 1);
assert_eq!(stats1.hits, 0);
assert_eq!(stats1.size, 1);
let _ = embedder.encode_text(text).unwrap();
let stats2 = embedder.cache_stats().unwrap();
        assert_eq!(stats2.misses, 1);
        assert_eq!(stats2.hits, 1);
        assert_eq!(stats2.size, 1);
embedder.clear_cache().unwrap();
let stats3 = embedder.cache_stats().unwrap();
assert_eq!(stats3.size, 0);
assert_eq!(stats3.hits, 0);
assert_eq!(stats3.misses, 0);
}
}