use moka::sync::Cache;
use std::sync::Arc;
/// Errors produced by embedding providers and the embedding cache in this module.
#[derive(Debug, Clone)]
pub enum EmbeddingError {
/// The named model is not available (displayed as "Embedding model not available: …").
ModelNotAvailable(String),
/// Input exceeded the provider's `max_length`. Units are provider-defined; the mock
/// provider in this file compares against `str::len`, i.e. UTF-8 bytes.
TextTooLong { max_length: usize, actual: usize },
/// A vector's length did not match the expected dimension.
DimensionMismatch { expected: usize, actual: usize },
/// Free-form failure reported by an underlying provider.
ProviderError(String),
/// Free-form failure attributed to the caching layer.
CacheError(String),
}
impl std::fmt::Display for EmbeddingError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ModelNotAvailable(model) => write!(f, "Embedding model not available: {}", model),
Self::TextTooLong { max_length, actual } => {
write!(f, "Text too long: {} > {} max", actual, max_length)
}
Self::DimensionMismatch { expected, actual } => {
write!(
f,
"Dimension mismatch: expected {}, got {}",
expected, actual
)
}
Self::ProviderError(msg) => write!(f, "Provider error: {}", msg),
Self::CacheError(msg) => write!(f, "Cache error: {}", msg),
}
}
}
// `EmbeddingError` derives `Debug` and implements `Display`, which is all the
// default `std::error::Error` methods need.
impl std::error::Error for EmbeddingError {}
/// Convenience alias for results whose error type is [`EmbeddingError`].
pub type EmbeddingResult<T> = Result<T, EmbeddingError>;
/// A source of text embeddings.
///
/// The `Send + Sync` supertraits allow a provider to be shared across threads
/// (e.g. behind an `Arc`).
pub trait EmbeddingProvider: Send + Sync {
    /// Name of the model backing this provider.
    fn model_name(&self) -> &str;
    /// Advertised length of the vectors this provider produces.
    fn dimension(&self) -> usize;
    /// Largest input accepted; the unit (bytes, chars, tokens) is implementation-defined.
    fn max_length(&self) -> usize;
    /// Embeds a single text.
    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>>;
    /// Embeds several texts, returning one vector per input in input order.
    ///
    /// The default calls [`EmbeddingProvider::embed`] on each text in turn and
    /// stops at the first error.
    fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
        texts
            .iter()
            .copied()
            .map(|text| self.embed(text))
            .collect::<EmbeddingResult<Vec<Vec<f32>>>>()
    }
    /// Scales `embedding` in place to unit L2 norm.
    ///
    /// Vectors whose norm is not greater than `1e-10` are left untouched. This includes
    /// all-zero vectors and a NaN norm.
    fn normalize(&self, embedding: &mut [f32]) {
        let squared_sum: f32 = embedding.iter().map(|value| value * value).sum();
        let magnitude = squared_sum.sqrt();
        if magnitude > 1e-10 {
            embedding.iter_mut().for_each(|value| *value /= magnitude);
        }
    }
}
/// Settings used to construct the embedding providers in this module.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
/// Model identifier, returned by the providers' `model_name()`.
pub model: String,
/// Optional location of model files.
/// NOTE(review): not read anywhere in this module's visible code.
pub model_path: Option<String>,
/// Length of the vectors a provider built from this config produces.
pub dimension: usize,
/// Maximum accepted input length; the mock provider compares it with `text.len()` (bytes).
pub max_length: usize,
/// Whether the mock provider L2-normalizes its output.
pub normalize: bool,
/// Intended batch size.
/// NOTE(review): not read anywhere in this module's visible code.
pub batch_size: usize,
/// Intended cache capacity in entries.
/// NOTE(review): `CachedEmbeddingProvider`'s constructors take the size explicitly and do not read this.
pub cache_size: usize,
/// Intended cache time-to-live in seconds.
/// NOTE(review): not read by `CachedEmbeddingProvider::with_ttl`, which takes `ttl_secs` explicitly.
pub cache_ttl_secs: u64,
}
impl Default for EmbeddingConfig {
    /// Defaults to the `all-MiniLM-L6-v2` model with 384-dimensional, normalized output.
    ///
    /// The cache defaults are 10,000 entries and a one-hour (3600 s) TTL.
    fn default() -> Self {
        Self {
            cache_ttl_secs: 3600,
            cache_size: 10_000,
            batch_size: 32,
            normalize: true,
            max_length: 512,
            dimension: 384,
            model_path: None,
            model: String::from("all-MiniLM-L6-v2"),
        }
    }
}
impl EmbeddingConfig {
    /// Builds a config for a sentence-transformers model.
    ///
    /// Known model names are mapped to their output dimension. Any other name falls
    /// back to 384. All remaining fields come from [`EmbeddingConfig::default`].
    pub fn sentence_transformer(model: &str) -> Self {
        let dimension = match model {
            "all-mpnet-base-v2" => 768,
            "all-MiniLM-L6-v2"
            | "all-MiniLM-L12-v2"
            | "paraphrase-MiniLM-L6-v2"
            | "multi-qa-MiniLM-L6-cos-v1" => 384,
            // Unrecognised names silently get the MiniLM dimension.
            _ => 384,
        };
        Self {
            model: model.to_owned(),
            dimension,
            ..Self::default()
        }
    }

    /// Builds a config for an OpenAI embedding model with `max_length` 8192.
    ///
    /// `text-embedding-3-large` maps to 3072 dimensions. Every other name, known or
    /// not, maps to 1536.
    pub fn openai(model: &str) -> Self {
        let dimension = match model {
            "text-embedding-3-large" => 3072,
            "text-embedding-ada-002" | "text-embedding-3-small" => 1536,
            _ => 1536,
        };
        Self {
            model: model.to_owned(),
            dimension,
            max_length: 8192,
            ..Self::default()
        }
    }
}
/// Deterministic, model-free provider whose embeddings are derived from hashes of the input text.
pub struct MockEmbeddingProvider {
// Supplies dimension, max length, model name and the normalize flag.
config: EmbeddingConfig,
// When false, `embed` returns an all-zero vector; both constructors set it to true.
use_hash: bool,
}
impl MockEmbeddingProvider {
    /// Creates a hash-based mock producing `dimension`-length vectors.
    ///
    /// The model name is `"mock"`; all other settings come from `EmbeddingConfig::default()`.
    pub fn new(dimension: usize) -> Self {
        Self {
            config: EmbeddingConfig {
                model: "mock".to_string(),
                dimension,
                ..Default::default()
            },
            use_hash: true,
        }
    }

    /// Creates a hash-based mock that uses `config` verbatim.
    pub fn with_config(config: EmbeddingConfig) -> Self {
        Self {
            config,
            use_hash: true,
        }
    }

    /// Produces a pseudo-random vector in `[-1, 1]^dimension` that depends only on `text`.
    ///
    /// Component `i` is `hash(text, i)` mapped linearly from `[0, u64::MAX]` to `[-1, 1]`.
    ///
    /// NOTE(review): std documents `DefaultHasher`'s algorithm as unspecified and subject
    /// to change between Rust releases. Outputs are deterministic within one build, but
    /// stability across toolchains is not guaranteed.
    fn hash_embed(&self, text: &str) -> Vec<f32> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        // Hash the text once, then fork the hasher state for each dimension and mix in
        // the index. This yields exactly the same values as hashing `text` then `i` into
        // a fresh hasher per dimension, but without re-hashing the whole text
        // `dimension` times (O(len + dim) instead of O(len * dim)).
        let mut seeded = DefaultHasher::new();
        text.hash(&mut seeded);
        (0..self.config.dimension)
            .map(|i| {
                let mut hasher = seeded.clone();
                i.hash(&mut hasher);
                let unit = hasher.finish() as f64 / u64::MAX as f64;
                (unit * 2.0 - 1.0) as f32
            })
            .collect()
    }
}
impl EmbeddingProvider for MockEmbeddingProvider {
    fn model_name(&self) -> &str {
        self.config.model.as_str()
    }

    fn dimension(&self) -> usize {
        self.config.dimension
    }

    fn max_length(&self) -> usize {
        self.config.max_length
    }

    /// Returns the hash-based embedding of `text` (or zeros if hashing is disabled).
    ///
    /// If the config asks for it, the vector is L2-normalized.
    ///
    /// # Errors
    /// Returns `TextTooLong` when `text.len()` (bytes) exceeds `max_length`.
    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
        let limit = self.config.max_length;
        let actual = text.len();
        if actual > limit {
            return Err(EmbeddingError::TextTooLong {
                max_length: limit,
                actual,
            });
        }
        let mut vector = match self.use_hash {
            true => self.hash_embed(text),
            false => vec![0.0; self.config.dimension],
        };
        if self.config.normalize {
            self.normalize(&mut vector);
        }
        Ok(vector)
    }
}
/// Wraps another provider and memoizes its embeddings in a capacity-bounded `moka` cache.
pub struct CachedEmbeddingProvider<P: EmbeddingProvider> {
// Provider that computes embeddings on cache misses.
inner: P,
// Keyed by a 64-bit `DefaultHasher` hash of the input text rather than the text itself,
// so two texts with colliding hashes would share one entry.
cache: Cache<u64, Vec<f32>>,
// Hit/miss/insert counters; held in an `Arc` so `stats()` callers can keep a handle.
stats: Arc<CacheStats>,
}
/// Counters describing cache effectiveness; this module updates them with `Relaxed` atomics.
#[derive(Debug, Default)]
pub struct CacheStats {
/// Lookups answered from the cache.
pub hits: std::sync::atomic::AtomicUsize,
/// Lookups that had to call the wrapped provider.
pub misses: std::sync::atomic::AtomicUsize,
/// Number of cache insertions performed.
/// NOTE(review): only ever incremented in this module. It does not drop on eviction or
/// expiry, and it counts repeated inserts of the same key (e.g. duplicate texts in one
/// batch) more than once, so it is not the live entry count.
pub size: std::sync::atomic::AtomicUsize,
}
impl CacheStats {
    /// Fraction of lookups served from the cache, in `[0, 1]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        use std::sync::atomic::Ordering;
        let hits = self.hits.load(Ordering::Relaxed);
        let lookups = hits + self.misses.load(Ordering::Relaxed);
        match lookups {
            0 => 0.0,
            _ => hits as f64 / lookups as f64,
        }
    }
}
impl<P: EmbeddingProvider> CachedEmbeddingProvider<P> {
pub fn new(inner: P, cache_size: usize) -> Self {
Self {
inner,
cache: Cache::new(cache_size as u64),
stats: Arc::new(CacheStats::default()),
}
}
pub fn with_ttl(inner: P, cache_size: usize, ttl_secs: u64) -> Self {
let cache = Cache::builder()
.max_capacity(cache_size as u64)
.time_to_live(std::time::Duration::from_secs(ttl_secs))
.build();
Self {
inner,
cache,
stats: Arc::new(CacheStats::default()),
}
}
pub fn stats(&self) -> &Arc<CacheStats> {
&self.stats
}
fn text_hash(text: &str) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
text.hash(&mut hasher);
hasher.finish()
}
}
impl<P: EmbeddingProvider> EmbeddingProvider for CachedEmbeddingProvider<P> {
fn model_name(&self) -> &str {
self.inner.model_name()
}
fn dimension(&self) -> usize {
self.inner.dimension()
}
fn max_length(&self) -> usize {
self.inner.max_length()
}
fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
let hash = Self::text_hash(text);
if let Some(cached) = self.cache.get(&hash) {
self.stats
.hits
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
return Ok(cached);
}
self.stats
.misses
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let embedding = self.inner.embed(text)?;
self.cache.insert(hash, embedding.clone());
self.stats
.size
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Ok(embedding)
}
fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
let mut results = Vec::with_capacity(texts.len());
let mut uncached: Vec<(usize, &str)> = Vec::new();
for (i, text) in texts.iter().enumerate() {
let hash = Self::text_hash(text);
if let Some(cached) = self.cache.get(&hash) {
self.stats
.hits
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
results.push((i, cached));
} else {
self.stats
.misses
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
uncached.push((i, *text));
}
}
if !uncached.is_empty() {
let uncached_texts: Vec<&str> = uncached.iter().map(|(_, t)| *t).collect();
let embeddings = self.inner.embed_batch(&uncached_texts)?;
for ((i, text), embedding) in uncached.iter().zip(embeddings.into_iter()) {
let hash = Self::text_hash(text);
self.cache.insert(hash, embedding.clone());
self.stats
.size
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
results.push((*i, embedding));
}
}
results.sort_by_key(|(i, _)| *i);
Ok(results.into_iter().map(|(_, e)| e).collect())
}
}
/// Placeholder for a provider backed by a local ONNX model.
///
/// NOTE(review): no model is loaded in this file. `embed` delegates to
/// `MockEmbeddingProvider`, so outputs are hash-based, not real model embeddings.
#[derive(Debug)]
pub struct LocalOnnxProvider {
config: EmbeddingConfig,
// Set to `false` by `new` and not read anywhere yet, hence the `dead_code` allowance.
#[allow(dead_code)]
model_loaded: bool,
}
impl LocalOnnxProvider {
    /// Creates a provider from `config`.
    ///
    /// This never fails at present: no model file is opened and `model_loaded` is `false`.
    pub fn new(config: EmbeddingConfig) -> EmbeddingResult<Self> {
        let provider = Self {
            config,
            model_loaded: false,
        };
        Ok(provider)
    }

    /// Creates a provider for a named sentence-transformers model, using
    /// [`EmbeddingConfig::sentence_transformer`] to pick the dimension.
    pub fn load_pretrained(model_name: &str) -> EmbeddingResult<Self> {
        Self::new(EmbeddingConfig::sentence_transformer(model_name))
    }
}
impl EmbeddingProvider for LocalOnnxProvider {
    fn model_name(&self) -> &str {
        self.config.model.as_str()
    }

    fn dimension(&self) -> usize {
        self.config.dimension
    }

    fn max_length(&self) -> usize {
        self.config.max_length
    }

    /// Returns a hash-based placeholder embedding for `text`.
    ///
    /// NOTE(review): no ONNX inference happens here. Each call builds a
    /// `MockEmbeddingProvider` from a clone of the config and returns its result,
    /// including its `TextTooLong` check and optional normalization.
    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
        MockEmbeddingProvider::with_config(self.config.clone()).embed(text)
    }
}
/// Pairs a vector index with an embedding provider so the index can be searched by text.
pub struct EmbeddingVectorIndex<V, P>
where
V: crate::context_query::VectorIndex,
P: EmbeddingProvider,
{
// Underlying index, queried through `search_by_embedding`.
index: Arc<V>,
// Provider used to turn query text into an embedding.
provider: Arc<P>,
}
impl<V, P> EmbeddingVectorIndex<V, P>
where
    V: crate::context_query::VectorIndex,
    P: EmbeddingProvider,
{
    /// Pairs `index` with `provider`.
    pub fn new(index: Arc<V>, provider: Arc<P>) -> Self {
        Self { index, provider }
    }

    /// Embeds `text` with the provider and searches `collection` with the result,
    /// passing `k` and `min_score` through to the index.
    ///
    /// # Errors
    /// Returns:
    /// - the provider error's `Display` text if embedding fails;
    /// - a dimension-mismatch message if the embedding's length differs from the
    ///   provider's `dimension()`;
    /// - otherwise, whatever error the index reports.
    pub fn search_text(
        &self,
        collection: &str,
        text: &str,
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        let embedding = self.provider.embed(text).map_err(|e| e.to_string())?;
        // Go through `search_embedding` so a provider whose output length disagrees
        // with its advertised `dimension()` is rejected here, as it is for direct
        // embedding searches, instead of being handed to the index unchecked.
        self.search_embedding(collection, &embedding, k, min_score)
    }

    /// Searches `collection` with a precomputed `embedding`.
    ///
    /// # Errors
    /// Returns a dimension-mismatch message if `embedding.len()` differs from the
    /// provider's `dimension()`. Otherwise returns whatever error the index reports.
    pub fn search_embedding(
        &self,
        collection: &str,
        embedding: &[f32],
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        let expected = self.provider.dimension();
        if embedding.len() != expected {
            return Err(format!(
                "Embedding dimension mismatch: expected {}, got {}",
                expected,
                embedding.len()
            ));
        }
        self.index
            .search_by_embedding(collection, embedding, k, min_score)
    }

    /// The embedding provider used for text queries.
    pub fn provider(&self) -> &Arc<P> {
        &self.provider
    }

    /// The wrapped vector index.
    pub fn index(&self) -> &Arc<V> {
        &self.index
    }
}
impl<V, P> crate::context_query::VectorIndex for EmbeddingVectorIndex<V, P>
where
    V: crate::context_query::VectorIndex,
    P: EmbeddingProvider,
{
    /// Dimension-checked embedding search; delegates to the inherent `search_embedding`.
    fn search_by_embedding(
        &self,
        collection: &str,
        embedding: &[f32],
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        Self::search_embedding(self, collection, embedding, k, min_score)
    }

    /// Text search; delegates to the inherent `search_text`, which embeds the query first.
    fn search_by_text(
        &self,
        collection: &str,
        text: &str,
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        Self::search_text(self, collection, text, k, min_score)
    }

    /// Statistics come straight from the wrapped index.
    fn stats(&self, collection: &str) -> Option<crate::context_query::VectorIndexStats> {
        self.index.stats(collection)
    }
}
/// Builds a hash-based mock provider producing `dimension`-length vectors, wrapped in
/// a cache of capacity `cache_size` with no time-based expiry.
pub fn create_mock_provider(
    dimension: usize,
    cache_size: usize,
) -> CachedEmbeddingProvider<MockEmbeddingProvider> {
    CachedEmbeddingProvider::new(MockEmbeddingProvider::new(dimension), cache_size)
}
/// Wraps `index` with a cached mock provider producing `dimension`-length vectors.
///
/// The cache holds up to 10,000 embeddings. For a different size, combine
/// `create_mock_provider` with `EmbeddingVectorIndex::new` directly.
pub fn create_embedding_index<V: crate::context_query::VectorIndex>(
    index: Arc<V>,
    dimension: usize,
) -> EmbeddingVectorIndex<V, CachedEmbeddingProvider<MockEmbeddingProvider>> {
    const CACHE_CAPACITY: usize = 10_000;
    EmbeddingVectorIndex::new(index, Arc::new(create_mock_provider(dimension, CACHE_CAPACITY)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Reads the `(hits, misses)` counters of a cached provider.
    fn counters<P: EmbeddingProvider>(cached: &CachedEmbeddingProvider<P>) -> (usize, usize) {
        let stats = cached.stats();
        (
            stats.hits.load(Ordering::Relaxed),
            stats.misses.load(Ordering::Relaxed),
        )
    }

    #[test]
    fn test_mock_embedding_deterministic() {
        let provider = MockEmbeddingProvider::new(384);
        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();
        assert_eq!(emb1, emb2);
        assert_eq!(emb1.len(), 384);
    }

    #[test]
    fn test_mock_embedding_different_texts() {
        let provider = MockEmbeddingProvider::new(384);
        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();
        assert_ne!(emb1, emb2);
    }

    #[test]
    fn test_cached_provider() {
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(128), 100);
        let _ = cached.embed("test text").unwrap();
        assert_eq!(counters(&cached), (0, 1));
        let _ = cached.embed("test text").unwrap();
        assert_eq!(counters(&cached), (1, 1));
        assert!(cached.stats().hit_rate() > 0.4);
    }

    #[test]
    fn test_hit_rate_without_lookups_is_zero() {
        assert_eq!(CacheStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_batch_embedding() {
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(128), 100);
        let texts = vec!["hello", "world", "test"];
        let embeddings = cached.embed_batch(&texts).unwrap();
        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), 128);
        }
    }

    #[test]
    fn test_batch_preserves_order_and_uses_cache() {
        let reference = MockEmbeddingProvider::new(64);
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(64), 100);
        let _ = cached.embed("b").unwrap();
        let texts = vec!["a", "b", "a"];
        let out = cached.embed_batch(&texts).unwrap();
        assert_eq!(out.len(), 3);
        assert_eq!(out[0], reference.embed("a").unwrap());
        assert_eq!(out[1], reference.embed("b").unwrap());
        assert_eq!(out[2], out[0]);
        // "b" was cached by the single embed; both "a" entries missed in the batch.
        assert_eq!(counters(&cached), (1, 3));
    }

    #[test]
    fn test_normalization() {
        let provider = MockEmbeddingProvider::new(3);
        let emb = provider.embed("test").unwrap();
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_text_too_long() {
        let config = EmbeddingConfig {
            max_length: 10,
            ..Default::default()
        };
        let provider = MockEmbeddingProvider::with_config(config);
        let result = provider.embed("this is a very long text that exceeds the limit");
        assert!(matches!(result, Err(EmbeddingError::TextTooLong { .. })));
    }
}