use anyhow::Result;
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use crate::vectordb::provider::EmbeddingProvider;
pub const FAST_EMBEDDING_DIM: usize = 384;
pub struct FastEmbeddingProvider {
trigram_cache: HashMap<String, u64>,
}
impl FastEmbeddingProvider {
pub fn new() -> Self {
Self {
trigram_cache: HashMap::new(),
}
}
fn extract_ngrams(&self, text: &str, n: usize) -> Vec<String> {
let chars: Vec<char> = text.chars().collect();
if chars.len() < n {
return vec![text.to_string()];
}
let mut ngrams = Vec::with_capacity(chars.len() - n + 1);
for i in 0..=(chars.len() - n) {
let ngram: String = chars[i..(i + n)].iter().collect();
ngrams.push(ngram);
}
ngrams
}
fn hash_string(&mut self, s: &str) -> u64 {
if let Some(&hash) = self.trigram_cache.get(s) {
return hash;
}
let mut hasher = DefaultHasher::new();
s.hash(&mut hasher);
let hash = hasher.finish();
if s.len() == 3 {
self.trigram_cache.insert(s.to_string(), hash);
}
hash
}
}
impl EmbeddingProvider for FastEmbeddingProvider {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
let mut provider = self.clone();
let text = text.to_lowercase();
let trigrams = provider.extract_ngrams(&text, 3);
let mut embedding = vec![0.0; FAST_EMBEDDING_DIM];
for (i, trigram) in trigrams.iter().enumerate() {
let hash = provider.hash_string(trigram);
let position_weight = 1.0 - (i as f32 / trigrams.len() as f32) * 0.5;
for j in 0..3 {
let index = ((hash >> (j * 16)) % FAST_EMBEDDING_DIM as u64) as usize;
embedding[index] += position_weight;
}
}
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for x in &mut embedding {
*x /= norm;
}
}
Ok(embedding)
}
fn embedding_dimension(&self) -> usize {
FAST_EMBEDDING_DIM
}
fn name(&self) -> &'static str {
"Fast-Trigram"
}
fn description(&self) -> &'static str {
"Fast embedding using character trigrams with position weighting (less accurate but quicker than ONNX)"
}
}
impl Clone for FastEmbeddingProvider {
fn clone(&self) -> Self {
Self {
trigram_cache: self.trigram_cache.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vectordb::provider::tests::test_provider_basics;

    /// Cosine similarity of two vectors of equal length.
    ///
    /// Divides by both norms rather than relying on the provider's output
    /// already being unit length. Returns 0.0 if either vector is all zeros.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "vectors must have the same dimension");
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }

    #[test]
    fn test_fast_provider() {
        let provider = FastEmbeddingProvider::new();
        test_provider_basics(&provider);
    }

    #[test]
    fn test_deterministic_embeddings() {
        let provider = FastEmbeddingProvider::new();
        let text = "fn main() { println!(\"Hello, world!\"); }";
        let embedding1 = provider.embed(text).unwrap();
        let embedding2 = provider.embed(text).unwrap();
        assert_eq!(embedding1, embedding2);
    }

    #[test]
    fn test_embeddings_are_unit_length() {
        let provider = FastEmbeddingProvider::new();
        // Covers ordinary code, text shorter than one trigram, and multi-byte UTF-8.
        for text in ["fn main() {}", "ab", "Grüße, 世界! ñandú"] {
            let embedding = provider.embed(text).unwrap();
            assert_eq!(embedding.len(), provider.embedding_dimension());
            let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-5,
                "embedding of {text:?} should have unit norm, got {norm}"
            );
        }
    }

    #[test]
    fn test_case_insensitive() {
        let provider = FastEmbeddingProvider::new();
        let upper = provider.embed("Hello World").unwrap();
        let lower = provider.embed("hello world").unwrap();
        assert_eq!(upper, lower);
    }

    #[test]
    fn test_similar_texts() {
        let provider = FastEmbeddingProvider::new();
        let text1 = "fn calculate_sum(a: i32, b: i32) -> i32 { a + b }";
        let text2 = "fn calculate_sum(a: i32, b: i32) -> i32 { return a + b; }";
        let text3 = "struct Point { x: i32, y: i32 }";
        let embedding1 = provider.embed(text1).unwrap();
        let embedding2 = provider.embed(text2).unwrap();
        let embedding3 = provider.embed(text3).unwrap();
        let sim_1_2 = cosine_similarity(&embedding1, &embedding2);
        let sim_1_3 = cosine_similarity(&embedding1, &embedding3);
        assert!(sim_1_2 > 0.8, "Similar texts should have high similarity: {}", sim_1_2);
        assert!(sim_1_3 < 0.8, "Different texts should have lower similarity: {}", sim_1_3);
    }
}