use super::*;
use candle_core::DType;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Serializes the tests that mutate process environment variables.
///
/// `std::env::set_var` / `remove_var` are `unsafe` because concurrent access to the
/// environment from another thread is undefined behaviour on some platforms, and the
/// default test harness runs tests on parallel threads. Every test below that writes
/// `CODEMEM_*` or API-key variables acquires this lock first so their writes never
/// interleave.
static ENV_MUTEX: std::sync::Mutex<()> =
    std::sync::Mutex::new(());
/// Minimal [`EmbeddingProvider`] that answers every request with a vector of `dims`
/// elements, each equal to `0.1`.
struct MockProvider {
    /// Length of every vector returned by `embed`.
    dims: usize,
    /// Incremented on each `embed` call. Once the mock is moved into a `Box` the tests
    /// cannot read it; tests needing call counts use a shared `Arc<AtomicUsize>` instead.
    call_count: AtomicUsize,
}

impl MockProvider {
    /// Builds a mock emitting `dims`-element vectors, with its call counter at zero.
    fn new(dims: usize) -> Self {
        let call_count = AtomicUsize::default();
        Self { call_count, dims }
    }
}

impl EmbeddingProvider for MockProvider {
    fn name(&self) -> &str {
        "mock"
    }

    fn dimensions(&self) -> usize {
        self.dims
    }

    fn embed(&self, _text: &str) -> Result<Vec<f32>, CodememError> {
        self.call_count.fetch_add(1, Ordering::SeqCst);
        let constant = vec![0.1; self.dims];
        Ok(constant)
    }
}
/// Provider that records how it was invoked: `single_calls` counts `embed`,
/// `batch_calls` counts `embed_batch`. Every vector is `dims` copies of `0.1`.
struct BatchTrackingMock {
    dims: usize,
    /// Incremented once per `embed` call.
    single_calls: AtomicUsize,
    /// Incremented once per `embed_batch` call, regardless of how many texts it carries.
    batch_calls: AtomicUsize,
}

impl BatchTrackingMock {
    /// Creates a tracker emitting `dims`-element vectors with both counters at zero.
    fn new(dims: usize) -> Self {
        Self {
            single_calls: AtomicUsize::default(),
            batch_calls: AtomicUsize::default(),
            dims,
        }
    }
}

impl EmbeddingProvider for BatchTrackingMock {
    fn name(&self) -> &str {
        "batch-mock"
    }

    fn dimensions(&self) -> usize {
        self.dims
    }

    fn embed(&self, _text: &str) -> Result<Vec<f32>, CodememError> {
        self.single_calls.fetch_add(1, Ordering::SeqCst);
        Ok(vec![0.1; self.dims])
    }

    /// Overrides the trait's default batch method so batch traffic can be told apart
    /// from single-text calls.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        self.batch_calls.fetch_add(1, Ordering::SeqCst);
        let rows = texts.len();
        Ok(vec![vec![0.1; self.dims]; rows])
    }
}
/// Provider whose `embed` always fails with `CodememError::Embedding("mock failure")`.
/// It reports 768 dimensions and is used to check that `CachedProvider` surfaces inner
/// errors instead of caching them.
struct FailingMock;

impl EmbeddingProvider for FailingMock {
    fn name(&self) -> &str {
        "failing-mock"
    }

    fn dimensions(&self) -> usize {
        768
    }

    fn embed(&self, _text: &str) -> Result<Vec<f32>, CodememError> {
        let reason = "mock failure";
        Err(CodememError::Embedding(reason.into()))
    }
}
#[test]
fn cached_provider_cache_hit() {
    // Two lookups of the same text return identical vectors and occupy one cache slot.
    let cached = CachedProvider::new(Box::new(MockProvider::new(4)), 100);
    let first = cached.embed("hello").unwrap();
    assert_eq!(first.len(), 4);
    let second = cached.embed("hello").unwrap();
    assert_eq!(first, second);
    let (entries, capacity) = cached.cache_stats();
    assert_eq!(entries, 1);
    assert_eq!(capacity, 100);
}
#[test]
fn cached_provider_cache_miss() {
    // Each distinct text gets its own cache entry.
    let cached = CachedProvider::new(Box::new(MockProvider::new(4)), 100);
    for text in ["hello", "world"] {
        cached.embed(text).unwrap();
    }
    let (entries, _capacity) = cached.cache_stats();
    assert_eq!(entries, 2);
}
#[test]
fn cached_provider_batch_empty() {
    // An empty batch yields an empty result rather than an error.
    let cached = CachedProvider::new(Box::new(MockProvider::new(4)), 100);
    let embeddings = cached.embed_batch(&[]).unwrap();
    assert_eq!(embeddings.len(), 0);
}
#[test]
fn cached_provider_batch_single() {
    // A one-element batch returns one full-width vector and caches it.
    let cached = CachedProvider::new(Box::new(MockProvider::new(4)), 100);
    let embeddings = cached.embed_batch(&["hello"]).unwrap();
    assert_eq!(embeddings.len(), 1);
    assert_eq!(embeddings.first().map(Vec::len), Some(4));
    let (entries, _) = cached.cache_stats();
    assert_eq!(entries, 1);
}
#[test]
fn cached_provider_batch_mixed_cache() {
    // "hello" is pre-warmed, so the batch mixes one cache hit with one miss.
    let cached = CachedProvider::new(Box::new(MockProvider::new(4)), 100);
    cached.embed("hello").unwrap();
    let embeddings = cached.embed_batch(&["hello", "world"]).unwrap();
    assert_eq!(embeddings.len(), 2);
    let (entries, _) = cached.cache_stats();
    assert_eq!(entries, 2);
}
#[test]
fn cached_provider_batch_all_cached() {
    // Every text is pre-warmed; the batch is served entirely from the cache.
    let cached = CachedProvider::new(Box::new(MockProvider::new(4)), 100);
    for text in ["hello", "world"] {
        cached.embed(text).unwrap();
    }
    let embeddings = cached.embed_batch(&["hello", "world"]).unwrap();
    assert_eq!(embeddings.len(), 2);
    let (entries, _) = cached.cache_stats();
    assert_eq!(entries, 2);
}
#[test]
fn cached_provider_batch_delegates_to_inner_batch() {
    use std::sync::Arc;

    /// Forwards every call to a shared `BatchTrackingMock`, so the test can still read
    /// the mock's counters after `CachedProvider` has taken ownership of the box.
    struct SharedTracker(Arc<BatchTrackingMock>);

    impl EmbeddingProvider for SharedTracker {
        fn dimensions(&self) -> usize {
            self.0.dimensions()
        }
        fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
            self.0.embed(text)
        }
        fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
            self.0.embed_batch(texts)
        }
        fn name(&self) -> &str {
            self.0.name()
        }
    }

    let tracker = Arc::new(BatchTrackingMock::new(4));
    let provider = CachedProvider::new(Box::new(SharedTracker(Arc::clone(&tracker))), 100);
    let result = provider.embed_batch(&["a", "b", "c"]).unwrap();
    assert_eq!(result.len(), 3);
    let (size, _) = provider.cache_stats();
    assert_eq!(size, 3);
    // The three cache misses must reach the inner provider through its batch method,
    // not through one `embed` call per text.
    assert_eq!(
        tracker.batch_calls.load(Ordering::SeqCst),
        1,
        "uncached texts should be forwarded in a single inner embed_batch call"
    );
    assert_eq!(
        tracker.single_calls.load(Ordering::SeqCst),
        0,
        "batch path should not fall back to per-text embed calls"
    );
}
#[test]
fn cached_provider_zero_capacity() {
    // A requested capacity of 0 is reported as 1, and two inserts leave only one entry.
    let cached = CachedProvider::new(Box::new(MockProvider::new(4)), 0);
    for text in ["a", "b"] {
        cached.embed(text).unwrap();
    }
    let (entries, capacity) = cached.cache_stats();
    assert_eq!(capacity, 1);
    assert_eq!(entries, 1);
}
#[test]
fn cached_provider_name_delegates() {
    // The cache wrapper reports the wrapped provider's name unchanged.
    let wrapped = CachedProvider::new(Box::new(MockProvider::new(4)), 100);
    let reported = wrapped.name();
    assert_eq!(reported, "mock");
}
#[test]
fn cached_provider_dimensions_delegates() {
    // The cache wrapper reports the wrapped provider's dimensionality unchanged.
    let wrapped = CachedProvider::new(Box::new(MockProvider::new(768)), 100);
    let reported = wrapped.dimensions();
    assert_eq!(reported, 768);
}
#[test]
fn cached_provider_inner_error_propagates() {
    // An inner failure surfaces through `embed` with its original message intact.
    let cached = CachedProvider::new(Box::new(FailingMock), 100);
    let outcome = cached.embed("test");
    assert!(outcome.is_err());
    let message = outcome.err().map(|e| e.to_string()).unwrap_or_default();
    assert!(
        message.contains("mock failure"),
        "Should propagate inner error: {message}"
    );
}
#[test]
fn cached_provider_inner_batch_error_propagates() {
    // An inner failure also surfaces through the batch path.
    let cached = CachedProvider::new(Box::new(FailingMock), 100);
    let outcome = cached.embed_batch(&["test"]);
    assert!(outcome.is_err());
}
#[test]
fn cached_provider_evicts_lru() {
    use std::sync::Arc;

    /// Provider whose `embed` bumps a shared counter, so cache hits and misses remain
    /// distinguishable after the provider is moved into the cache.
    struct CountingMock {
        dims: usize,
        count: Arc<AtomicUsize>,
    }
    impl EmbeddingProvider for CountingMock {
        fn dimensions(&self) -> usize {
            self.dims
        }
        fn embed(&self, _text: &str) -> Result<Vec<f32>, CodememError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(vec![0.1; self.dims])
        }
        fn name(&self) -> &str {
            "counting"
        }
    }

    let calls = Arc::new(AtomicUsize::new(0));
    let provider = CachedProvider::new(
        Box::new(CountingMock {
            dims: 2,
            count: Arc::clone(&calls),
        }),
        2,
    );
    provider.embed("a").unwrap();
    provider.embed("b").unwrap();
    provider.embed("c").unwrap();
    assert_eq!(calls.load(Ordering::SeqCst), 3);
    let (size, cap) = provider.cache_stats();
    assert_eq!(size, 2);
    assert_eq!(cap, 2);

    // The two most recently inserted entries survive the overflow...
    provider.embed("b").unwrap();
    provider.embed("c").unwrap();
    assert_eq!(
        calls.load(Ordering::SeqCst),
        3,
        "\"b\" and \"c\" should still be cached"
    );
    // ...while the least recently used one ("a") was evicted and must be recomputed.
    provider.embed("a").unwrap();
    assert_eq!(
        calls.load(Ordering::SeqCst),
        4,
        "\"a\" should have been evicted as least recently used"
    );
}
#[test]
fn cached_provider_cache_hit_avoids_inner_call() {
    use std::sync::Arc;

    /// Provider whose `embed` bumps a shared counter, so inner calls remain observable
    /// after the provider is moved into the cache.
    struct CountingMock {
        dims: usize,
        count: Arc<AtomicUsize>,
    }
    impl EmbeddingProvider for CountingMock {
        fn dimensions(&self) -> usize {
            self.dims
        }
        fn embed(&self, _text: &str) -> Result<Vec<f32>, CodememError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(vec![0.1; self.dims])
        }
        fn name(&self) -> &str {
            "counting"
        }
    }

    let call_count = Arc::new(AtomicUsize::new(0));
    let provider = CachedProvider::new(
        Box::new(CountingMock {
            dims: 4,
            count: Arc::clone(&call_count),
        }),
        100,
    );
    provider.embed("hello").unwrap();
    assert_eq!(call_count.load(Ordering::SeqCst), 1);
    provider.embed("hello").unwrap();
    assert_eq!(
        call_count.load(Ordering::SeqCst),
        1,
        "Second call should be cached"
    );
    provider.embed("world").unwrap();
    assert_eq!(call_count.load(Ordering::SeqCst), 2);
}
#[test]
fn from_env_unknown_provider() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX, so this write
    // does not interleave with theirs.
    unsafe { std::env::set_var("CODEMEM_EMBED_PROVIDER", "nonexistent_provider_xyz") };
    let outcome = from_env(None);
    // SAFETY: ENV_MUTEX is still held.
    unsafe { std::env::remove_var("CODEMEM_EMBED_PROVIDER") };
    let Err(error) = outcome else {
        panic!("Expected error for unknown provider");
    };
    let message = error.to_string();
    assert!(
        message.contains("Unknown embedding provider"),
        "Error should mention unknown provider: {message}"
    );
}
#[test]
fn embedding_service_missing_model() {
    // Constructing the service from a directory that does not exist must fail cleanly.
    let missing = Path::new("/nonexistent/path");
    let Err(error) = EmbeddingService::new(missing, 16, DType::F32) else {
        panic!("Expected error for missing model");
    };
    let message = error.to_string();
    assert!(
        message.contains("Model not found"),
        "Error should mention missing model: {message}"
    );
}
#[test]
fn default_model_dir_path() {
    // NOTE(review): the default directory contains ".codemem", so it presumably derives
    // from the user's home directory and therefore from environment variables. Holding
    // ENV_MUTEX keeps that read from racing the tests that call `set_var`/`remove_var`,
    // which is undefined behaviour on some platforms.
    let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
    let dir = EmbeddingService::default_model_dir();
    let rendered = dir.to_string_lossy();
    assert!(rendered.contains(MODEL_NAME));
    assert!(rendered.contains(".codemem"));
}
#[test]
fn constants_are_sensible() {
    // Pin the default model identity first, then the numeric tuning knobs.
    assert_eq!(MODEL_NAME, "bge-base-en-v1.5");
    assert_eq!(DEFAULT_HF_REPO, "BAAI/bge-base-en-v1.5");
    assert_eq!(DEFAULT_REMOTE_DIMENSIONS, 768);
    assert_eq!(DEFAULT_BATCH_SIZE, 16);
    assert_eq!(CACHE_CAPACITY, 10_000);
}
#[test]
fn parse_dtype_f32() {
    for spelling in ["f32", "float32"] {
        assert!(matches!(parse_dtype(spelling).unwrap(), DType::F32));
    }
    // An empty string is not an f32 spelling: it falls back to F16.
    assert!(matches!(parse_dtype("").unwrap(), DType::F16));
}
#[test]
fn parse_dtype_f16() {
    // All accepted half-precision spellings map to F16.
    for spelling in ["f16", "float16", "half"] {
        assert!(matches!(parse_dtype(spelling).unwrap(), DType::F16));
    }
}
#[test]
fn parse_dtype_bf16() {
    // Both bfloat16 spellings map to BF16.
    for spelling in ["bf16", "bfloat16"] {
        assert!(matches!(parse_dtype(spelling).unwrap(), DType::BF16));
    }
}
#[test]
fn parse_dtype_case_insensitive() {
    // Upper-case spellings resolve exactly like their lower-case forms.
    let f16 = parse_dtype("F16").unwrap();
    assert!(matches!(f16, DType::F16));
    let f32 = parse_dtype("F32").unwrap();
    assert!(matches!(f32, DType::F32));
    let bf16 = parse_dtype("BF16").unwrap();
    assert!(matches!(bf16, DType::BF16));
}
#[test]
fn parse_dtype_unknown() {
    // Unsupported dtypes are rejected with a descriptive error.
    let message = parse_dtype("int8").unwrap_err().to_string();
    assert!(message.contains("Unknown dtype"), "Error: {message}");
}
#[test]
fn resolve_model_id_full_repo() {
    // A full `org/name` id is kept verbatim; the local dir is the part after the slash.
    let (hf_repo, local_dir) = resolve_model_id("BAAI/bge-base-en-v1.5").unwrap();
    assert_eq!(local_dir, "bge-base-en-v1.5");
    assert_eq!(hf_repo, "BAAI/bge-base-en-v1.5");
}
#[test]
fn resolve_model_id_short_bge() {
    // A bare `bge-*` name is expanded to the BAAI organisation.
    let (hf_repo, local_dir) = resolve_model_id("bge-small-en-v1.5").unwrap();
    assert_eq!(local_dir, "bge-small-en-v1.5");
    assert_eq!(hf_repo, "BAAI/bge-small-en-v1.5");
}
#[test]
fn resolve_model_id_other_repo() {
    // Non-BAAI repos are accepted as long as they are fully qualified.
    let (hf_repo, local_dir) =
        resolve_model_id("sentence-transformers/all-MiniLM-L6-v2").unwrap();
    assert_eq!(local_dir, "all-MiniLM-L6-v2");
    assert_eq!(hf_repo, "sentence-transformers/all-MiniLM-L6-v2");
}
#[test]
fn resolve_model_id_bare_name_rejected() {
    // A bare name without an organisation (and not a known short form) is an error.
    let message = resolve_model_id("my-custom-model")
        .unwrap_err()
        .to_string();
    let explains_format = message.contains("must be a full HuggingFace repo ID");
    assert!(explains_format, "Error: {message}");
}
#[test]
fn model_dir_for_custom_model() {
    // NOTE(review): the model directory contains ".codemem", so it presumably derives
    // from the user's home directory and therefore from environment variables. Holding
    // ENV_MUTEX keeps that read from racing the tests that call `set_var`/`remove_var`.
    let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
    let dir = EmbeddingService::model_dir_for("bge-small-en-v1.5");
    let rendered = dir.to_string_lossy();
    assert!(rendered.contains("bge-small-en-v1.5"));
    assert!(rendered.contains(".codemem"));
}
#[test]
fn from_env_batch_size_env_var() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX, so these writes
    // do not interleave with theirs.
    unsafe {
        std::env::set_var("CODEMEM_EMBED_BATCH_SIZE", "8");
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "ollama");
    }
    let outcome = from_env(None);
    // SAFETY: ENV_MUTEX is still held.
    unsafe {
        std::env::remove_var("CODEMEM_EMBED_BATCH_SIZE");
        std::env::remove_var("CODEMEM_EMBED_PROVIDER");
    }
    assert!(outcome.is_ok());
}
#[test]
fn from_env_candle_bare_name_rejected() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX.
    unsafe {
        std::env::set_var("CODEMEM_EMBED_MODEL", "my-custom-model");
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "candle");
    }
    let outcome = from_env(None);
    // SAFETY: ENV_MUTEX is still held.
    unsafe {
        std::env::remove_var("CODEMEM_EMBED_MODEL");
        std::env::remove_var("CODEMEM_EMBED_PROVIDER");
    }
    let Err(error) = outcome else {
        panic!("Expected error for bare model name without org/");
    };
    let message = error.to_string();
    assert!(
        message.contains("must be a full HuggingFace repo ID"),
        "Error should explain the format requirement: {message}"
    );
}
#[test]
fn from_env_ollama_provider() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX.
    unsafe { std::env::set_var("CODEMEM_EMBED_PROVIDER", "ollama") };
    let outcome = from_env(None);
    // SAFETY: ENV_MUTEX is still held.
    unsafe { std::env::remove_var("CODEMEM_EMBED_PROVIDER") };
    let embedder = outcome.expect("from_env should succeed for ollama");
    assert_eq!(embedder.name(), "ollama");
}
#[test]
fn from_env_openai_provider() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX.
    unsafe {
        std::env::set_var("OPENAI_API_KEY", "test-key-123");
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "openai");
    }
    let outcome = from_env(None);
    // SAFETY: ENV_MUTEX is still held.
    unsafe {
        std::env::remove_var("OPENAI_API_KEY");
        std::env::remove_var("CODEMEM_EMBED_PROVIDER");
    }
    let embedder = outcome.expect("from_env should succeed for openai");
    assert_eq!(embedder.name(), "openai");
}
#[test]
fn from_env_openai_missing_api_key() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX.
    unsafe {
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "openai");
        // Neither key source may be present, or the provider would be constructed.
        std::env::remove_var("CODEMEM_EMBED_API_KEY");
        std::env::remove_var("OPENAI_API_KEY");
    }
    let outcome = from_env(None);
    // SAFETY: ENV_MUTEX is still held.
    unsafe { std::env::remove_var("CODEMEM_EMBED_PROVIDER") };
    assert!(outcome.is_err());
    let message = outcome.err().map(|e| e.to_string()).unwrap_or_default();
    assert!(
        message.contains("API_KEY"),
        "Should mention API key requirement: {message}"
    );
}
#[test]
fn from_env_gemini_provider() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX.
    unsafe {
        std::env::set_var("CODEMEM_EMBED_API_KEY", "test-gemini-key");
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "gemini");
    }
    let outcome = from_env(None);
    // SAFETY: ENV_MUTEX is still held.
    unsafe {
        std::env::remove_var("CODEMEM_EMBED_API_KEY");
        std::env::remove_var("CODEMEM_EMBED_PROVIDER");
    }
    let embedder = outcome.expect("from_env should succeed for gemini");
    assert_eq!(embedder.name(), "gemini");
}
#[test]
fn from_env_google_alias() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX.
    unsafe {
        std::env::set_var("CODEMEM_EMBED_API_KEY", "test-google-key");
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "google");
    }
    let outcome = from_env(None);
    // SAFETY: ENV_MUTEX is still held.
    unsafe {
        std::env::remove_var("CODEMEM_EMBED_API_KEY");
        std::env::remove_var("CODEMEM_EMBED_PROVIDER");
    }
    // "google" is an alias: the resulting provider still identifies as gemini.
    let embedder = outcome.expect("'google' alias should create gemini provider");
    assert_eq!(embedder.name(), "gemini");
}
#[test]
fn from_env_gemini_missing_api_key() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX.
    unsafe {
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "gemini");
        // Clear every key source the gemini provider might consult.
        for key in ["CODEMEM_EMBED_API_KEY", "GEMINI_API_KEY", "GOOGLE_API_KEY"] {
            std::env::remove_var(key);
        }
    }
    let outcome = from_env(None);
    // SAFETY: ENV_MUTEX is still held.
    unsafe { std::env::remove_var("CODEMEM_EMBED_PROVIDER") };
    assert!(outcome.is_err(), "gemini without API key should fail");
}
#[test]
fn from_env_with_config_ollama_url() {
    let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
    // Clear every env override of the config fields this test relies on. In particular
    // CODEMEM_EMBED_DIMENSIONS overrides `config.dimensions` (see
    // `from_env_env_var_overrides_config`), so a value inherited from the developer's
    // shell would otherwise break the 512 assertion below.
    unsafe { std::env::set_var("CODEMEM_EMBED_PROVIDER", "ollama") };
    unsafe { std::env::remove_var("CODEMEM_EMBED_URL") };
    unsafe { std::env::remove_var("CODEMEM_EMBED_MODEL") };
    unsafe { std::env::remove_var("CODEMEM_EMBED_DIMENSIONS") };
    let config = codemem_core::EmbeddingConfig {
        provider: "ollama".to_string(),
        url: "http://custom:11434".to_string(),
        model: "custom-model".to_string(),
        dimensions: 512,
        cache_capacity: 5000,
        ..Default::default()
    };
    let result = from_env(Some(&config));
    unsafe { std::env::remove_var("CODEMEM_EMBED_PROVIDER") };
    let provider = result.expect("from_env with config should succeed");
    assert_eq!(provider.name(), "ollama");
    assert_eq!(provider.dimensions(), 512);
}
#[test]
fn from_env_env_var_overrides_config() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // The config asks for candle/512; the environment asks for ollama/256 and must win.
    let config = codemem_core::EmbeddingConfig {
        provider: "candle".to_string(),
        dimensions: 512,
        ..Default::default()
    };
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX.
    unsafe {
        std::env::set_var("CODEMEM_EMBED_DIMENSIONS", "256");
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "ollama");
    }
    let outcome = from_env(Some(&config));
    // SAFETY: ENV_MUTEX is still held.
    unsafe {
        std::env::remove_var("CODEMEM_EMBED_DIMENSIONS");
        std::env::remove_var("CODEMEM_EMBED_PROVIDER");
    }
    let embedder = outcome.expect("from_env should succeed");
    assert_eq!(embedder.name(), "ollama");
    assert_eq!(embedder.dimensions(), 256);
}
#[test]
fn from_env_openai_with_custom_api_key_env() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX.
    unsafe {
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "openai");
        // Only the codemem-specific key is available; the OpenAI default is absent.
        std::env::set_var("CODEMEM_EMBED_API_KEY", "custom-key");
        std::env::remove_var("OPENAI_API_KEY");
    }
    let outcome = from_env(None);
    // SAFETY: ENV_MUTEX is still held.
    unsafe {
        std::env::remove_var("CODEMEM_EMBED_PROVIDER");
        std::env::remove_var("CODEMEM_EMBED_API_KEY");
    }
    let embedder = outcome.expect("Should use CODEMEM_EMBED_API_KEY");
    assert_eq!(embedder.name(), "openai");
}
#[test]
fn from_env_empty_string_treated_as_candle() {
    let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
    unsafe { std::env::set_var("CODEMEM_EMBED_PROVIDER", "") };
    let result = from_env(None);
    unsafe { std::env::remove_var("CODEMEM_EMBED_PROVIDER") };
    match result {
        Ok(p) => assert_eq!(p.name(), "candle"),
        // Without a local model the candle path may legitimately fail. What matters is
        // that the empty string was routed to candle rather than rejected as a name.
        Err(e) => {
            let err = e.to_string();
            assert!(
                !err.contains("Unknown embedding provider"),
                "Empty provider must fall back to candle, not be rejected: {err}"
            );
            assert!(
                err.contains("Model not found") || err.contains("model"),
                "Should be a candle model error, not unknown provider: {err}"
            );
        }
    }
}
#[test]
fn from_env_case_insensitive() {
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // SAFETY: every env-mutating test in this module holds ENV_MUTEX.
    unsafe { std::env::set_var("CODEMEM_EMBED_PROVIDER", "OLLAMA") };
    let outcome = from_env(None);
    // SAFETY: ENV_MUTEX is still held.
    unsafe { std::env::remove_var("CODEMEM_EMBED_PROVIDER") };
    // An upper-case provider name resolves to the same lower-case provider.
    let embedder = outcome.expect("Provider name should be case-insensitive");
    assert_eq!(embedder.name(), "ollama");
}
#[test]
fn cached_provider_concurrent_embed_no_panic() {
    use std::sync::Arc;
    use std::thread;

    // Ten threads each embed a distinct key; every key must land in the shared cache.
    let shared = Arc::new(CachedProvider::new(Box::new(MockProvider::new(4)), 100));
    let mut workers = Vec::with_capacity(10);
    for i in 0..10 {
        let provider = Arc::clone(&shared);
        workers.push(thread::spawn(move || {
            let key = format!("text_{}", i);
            let outcome = provider.embed(&key);
            assert!(outcome.is_ok(), "Thread {i} should not panic or error");
            assert_eq!(outcome.unwrap().len(), 4);
        }));
    }
    for worker in workers {
        worker.join().expect("Thread should not panic");
    }
    let (entries, _) = shared.cache_stats();
    assert_eq!(entries, 10);
}
#[test]
fn cached_provider_concurrent_embed_same_key() {
    use std::sync::Arc;
    use std::thread;

    let call_count = Arc::new(AtomicUsize::new(0));
    let count_clone = call_count.clone();

    /// Provider that counts `embed` calls and sleeps briefly, to widen the window in
    /// which concurrent misses on the same key can overlap.
    struct SlowCountingMock {
        dims: usize,
        count: Arc<AtomicUsize>,
    }
    impl EmbeddingProvider for SlowCountingMock {
        fn dimensions(&self) -> usize {
            self.dims
        }
        fn embed(&self, _text: &str) -> Result<Vec<f32>, CodememError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            std::thread::sleep(std::time::Duration::from_millis(1));
            Ok(vec![0.42; self.dims])
        }
        fn name(&self) -> &str {
            "slow-counting"
        }
    }

    let provider = Arc::new(CachedProvider::new(
        Box::new(SlowCountingMock {
            dims: 4,
            count: count_clone,
        }),
        100,
    ));
    let handles: Vec<_> = (0..10)
        .map(|_| {
            let p = Arc::clone(&provider);
            thread::spawn(move || {
                let result = p.embed("same_key");
                assert!(result.is_ok());
                let embedding = result.unwrap();
                assert_eq!(embedding, vec![0.42; 4]);
            })
        })
        .collect();
    for handle in handles {
        handle.join().expect("Thread should not panic");
    }
    let (size, _) = provider.cache_stats();
    assert_eq!(size, 1);

    // Racing misses may each reach the inner provider, but never more than once per
    // thread and at least once overall.
    let racing_calls = call_count.load(Ordering::SeqCst);
    assert!(
        (1..=10).contains(&racing_calls),
        "inner provider should run between 1 and 10 times, ran {racing_calls}"
    );
    // Once the race has settled, the key must be served from the cache.
    provider.embed("same_key").unwrap();
    assert_eq!(
        call_count.load(Ordering::SeqCst),
        racing_calls,
        "a lookup after the race should be a cache hit"
    );
}
#[test]
fn cached_provider_concurrent_embed_batch_no_corruption() {
    use std::sync::Arc;
    use std::thread;

    /// Provider whose output is a pure function of the whole input text, so a swapped
    /// or corrupted cache entry is detectable.
    struct DistinctMock {
        dims: usize,
    }
    impl DistinctMock {
        /// Encodes the first `dims` bytes of `text` as floats, zero-padding short texts.
        /// Distinct texts of at most `dims` bytes therefore map to distinct vectors.
        fn encode(dims: usize, text: &str) -> Vec<f32> {
            (0..dims)
                .map(|k| text.as_bytes().get(k).copied().map_or(0.0, f32::from))
                .collect()
        }
    }
    impl EmbeddingProvider for DistinctMock {
        fn dimensions(&self) -> usize {
            self.dims
        }
        fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
            Ok(Self::encode(self.dims, text))
        }
        fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
            texts.iter().map(|t| self.embed(t)).collect()
        }
        fn name(&self) -> &str {
            "distinct"
        }
    }

    let provider = Arc::new(CachedProvider::new(
        Box::new(DistinctMock { dims: 4 }),
        1000,
    ));
    let handles: Vec<_> = (0..5)
        .map(|i| {
            let p = Arc::clone(&provider);
            thread::spawn(move || {
                // Texts are "t{i}_{j}": exactly 4 bytes, so each gets a unique vector.
                let texts: Vec<String> = (0..3).map(|j| format!("t{}_{}", i, j)).collect();
                let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
                let result = p.embed_batch(&text_refs);
                assert!(result.is_ok(), "Thread {i} batch should succeed");
                let embeddings = result.unwrap();
                assert_eq!(embeddings.len(), 3, "Thread {i} should get 3 embeddings");
                for (text, emb) in text_refs.iter().zip(&embeddings) {
                    assert_eq!(
                        emb,
                        &DistinctMock::encode(4, text),
                        "Thread {i}: embedding for {text} was swapped or corrupted"
                    );
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().expect("Thread should not panic");
    }
    let (size, _) = provider.cache_stats();
    assert_eq!(size, 15);
}
#[test]
fn cached_provider_concurrent_reads_and_writes() {
    use std::sync::Arc;
    use std::thread;

    let mock = MockProvider::new(4);
    let provider = Arc::new(CachedProvider::new(Box::new(mock), 100));
    for i in 0..5 {
        provider.embed(&format!("pre_{}", i)).unwrap();
    }
    // Even threads re-read pre-warmed keys (cache hits); odd threads insert new keys.
    let handles: Vec<_> = (0..10)
        .map(|i| {
            let p = Arc::clone(&provider);
            thread::spawn(move || {
                if i % 2 == 0 {
                    let result = p.embed(&format!("pre_{}", i % 5));
                    assert!(result.is_ok());
                } else {
                    let result = p.embed(&format!("new_{}", i));
                    assert!(result.is_ok());
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().expect("Thread should not panic");
    }
    // The 5 pre-warmed keys (even threads hit pre_0, pre_2, pre_4, pre_1, pre_3) plus
    // new_1, new_3, new_5, new_7 and new_9 fit well within the capacity of 100.
    let (size, _) = provider.cache_stats();
    assert_eq!(size, 10);
}
#[test]
fn cached_provider_embed_empty_string() {
    // The empty string is a valid, cacheable key like any other.
    let cached = CachedProvider::new(Box::new(MockProvider::new(4)), 100);
    let first = cached.embed("");
    assert!(first.is_ok());
    let vector = first.unwrap();
    assert_eq!(vector.len(), 4);
    let (entries, _) = cached.cache_stats();
    assert_eq!(entries, 1);
    let again = cached.embed("");
    assert!(again.is_ok());
    assert_eq!(again.unwrap(), vector);
}
#[test]
fn cached_provider_embed_very_long_string() {
    // A 15 000-character key is accepted and produces a normal-width vector.
    let cached = CachedProvider::new(Box::new(MockProvider::new(4)), 100);
    let huge: String = std::iter::repeat('a').take(15_000).collect();
    let outcome = cached.embed(&huge);
    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap().len(), 4);
}
#[test]
fn cached_provider_batch_with_duplicates_avoids_redundant_calls() {
    use std::sync::Arc;

    /// Provider that adds one to a shared counter for every text it embeds, whether the
    /// text arrives through `embed` or `embed_batch`.
    struct CountingBatchMock {
        dims: usize,
        count: Arc<AtomicUsize>,
    }
    impl EmbeddingProvider for CountingBatchMock {
        fn name(&self) -> &str {
            "counting-batch"
        }
        fn dimensions(&self) -> usize {
            self.dims
        }
        fn embed(&self, _text: &str) -> Result<Vec<f32>, CodememError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(vec![0.1; self.dims])
        }
        fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
            self.count.fetch_add(texts.len(), Ordering::SeqCst);
            Ok(vec![vec![0.1; self.dims]; texts.len()])
        }
    }

    let embedded_texts = Arc::new(AtomicUsize::new(0));
    let cached = CachedProvider::new(
        Box::new(CountingBatchMock {
            dims: 4,
            count: Arc::clone(&embedded_texts),
        }),
        100,
    );
    cached.embed("hello").unwrap();
    let baseline = embedded_texts.load(Ordering::SeqCst);
    assert_eq!(baseline, 1);
    // "hello" is cached and appears twice; only "world" should reach the inner provider.
    let vectors = cached.embed_batch(&["hello", "world", "hello"]).unwrap();
    assert_eq!(vectors.len(), 3);
    let delta = embedded_texts.load(Ordering::SeqCst) - baseline;
    assert_eq!(delta, 1, "Only uncached text should hit inner provider");
}
#[test]
fn cached_provider_batch_zero_items() {
    // An empty batch returns nothing and leaves the cache untouched.
    let cached = CachedProvider::new(Box::new(MockProvider::new(4)), 100);
    let vectors = cached.embed_batch(&[]).unwrap();
    assert_eq!(vectors.len(), 0);
    let (entries, _) = cached.cache_stats();
    assert_eq!(entries, 0);
}