use anyhow::Result;
use async_trait::async_trait;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::embeddings::backend::EmbeddingBackend;
use crate::embeddings::pool::MemoryPool;
/// Embedding backend that derives vectors deterministically from a hash of the input text.
///
/// It runs no model; see its `EmbeddingBackend` impl for how vectors are produced.
pub struct StaticHashBackend {
    /// Number of components in every produced embedding.
    dimension: usize,
    /// Shared pool from which embedding buffers are obtained.
    memory_pool: Arc<MemoryPool>,
}
impl StaticHashBackend {
pub(in crate::embeddings) fn new(dimension: usize, memory_pool: Arc<MemoryPool>) -> Self {
Self {
dimension,
memory_pool,
}
}
fn text_hash(text: &str) -> u64 {
let mut hasher = DefaultHasher::new();
text.hash(&mut hasher);
hasher.finish()
}
}
/// SplitMix64 output finalizer: scrambles a 64-bit state so that nearby inputs
/// (e.g. successive counters) yield statistically independent outputs.
fn static_hash_mix(mut z: u64) -> u64 {
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

#[async_trait]
impl EmbeddingBackend for StaticHashBackend {
    /// Length of every vector returned by `process_batch`.
    fn embedding_dimension(&self) -> usize {
        self.dimension
    }

    /// This backend is hash-based, not a BERT model.
    fn is_bert_based(&self) -> bool {
        false
    }

    /// Produces one deterministic, L2-normalised vector per input text.
    ///
    /// Each text is hashed to a 64-bit seed. Component `i` is taken from a
    /// SplitMix64-style stream, `mix(seed + (i + 1) * GAMMA)`. The top 24 bits of
    /// that value are mapped exactly into `[-1, 1)`.
    ///
    /// The previous approach cast `seed + i` straight to `f32`. That kept only 24
    /// significant bits, so every component rounded to the same value. Every
    /// embedding then collapsed to roughly `±1/sqrt(dimension)` in all components.
    ///
    /// # Errors
    /// This implementation never returns `Err`.
    async fn process_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        // Weyl-sequence increment used by SplitMix64 (2^64 / golden ratio, odd).
        const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
        // 2^24: f32 has a 24-bit significand, so a 24-bit integer divided by this is exact.
        const UNIT_SCALE: f32 = (1u32 << 24) as f32;

        let mut results = Vec::with_capacity(texts.len());
        for text in &texts {
            let mut embedding = self.memory_pool.get_or_allocate();
            let seed = Self::text_hash(text);

            // Pooled buffers may hold stale data; start from empty.
            embedding.clear();
            embedding.extend((0..self.dimension).map(|i| {
                let counter = (i as u64).wrapping_add(1).wrapping_mul(GOLDEN_GAMMA);
                let bits = static_hash_mix(seed.wrapping_add(counter));
                // Top 24 bits -> [0, 1) exactly, then affine map to [-1, 1).
                ((bits >> 40) as f32 / UNIT_SCALE) * 2.0 - 1.0
            }));

            let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            // Guard against division by zero (e.g. dimension == 0).
            if norm > 0.0 {
                embedding.iter_mut().for_each(|v| *v /= norm);
            }
            results.push(embedding);
        }
        Ok(results)
    }
}