//! Text embedding backends: a deterministic hash-based placeholder and a
//! TF-IDF bag-of-words baseline, both implementing the [`Embedder`] trait.

use std::collections::HashMap;

use crate::error::RagError;

/// Common interface for turning text into a fixed-length vector.
pub trait Embedder: Send + Sync {
    /// Embeds `text` into a vector of length `embedding_dim()`.
    fn embed(&self, text: &str) -> Result<Vec<f32>, RagError>;

    /// Dimensionality of the vectors produced by `embed`.
    fn embedding_dim(&self) -> usize;
}

/// Deterministic, dependency-free embedder that hashes the input text.
/// Equal inputs always produce equal vectors, which makes it useful as a
/// stand-in for a real embedding model in tests and examples.
pub struct IdentityEmbedder {
    dim: usize,
}

impl IdentityEmbedder {
    /// Creates an embedder producing `dim`-dimensional vectors.
    /// Rejects `dim == 0`, since a zero-length embedding is meaningless.
    pub fn new(dim: usize) -> Result<Self, RagError> {
        if dim == 0 {
            return Err(RagError::DimensionMismatch {
                expected: 1,
                got: 0,
            });
        }
        Ok(Self { dim })
    }

    /// Expands an FNV-1a fingerprint of `text` into `dim` pseudo-random
    /// components in `[-1, 1]`.
    fn hash_to_vec(&self, text: &str) -> Vec<f32> {
        let text_bytes = text.as_bytes();
        // FNV-1a: XOR each byte into the state, then multiply by the prime.
        let mut fingerprint: u64 = 0xcbf2_9ce4_8422_2325;
        for &byte in text_bytes {
            fingerprint ^= byte as u64;
            fingerprint = fingerprint.wrapping_mul(0x0000_0100_0000_01B3);
        }
        // Fold in the length as one extra round.
        fingerprint ^= text_bytes.len() as u64;
        fingerprint = fingerprint.wrapping_mul(0x0000_0100_0000_01B3);
        // Derive one value per dimension with a splitmix64-style finalizer.
        (0..self.dim)
            .map(|d| {
                let mut z =
                    fingerprint.wrapping_add((d as u64).wrapping_mul(0x9e37_79b9_7f4a_7c15));
                z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
                z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
                z ^= z >> 31;
                let signed = z as i64;
                let f = (signed as f64 / (i64::MAX as f64)).clamp(-1.0, 1.0) as f32;
                // Nudge exact zeros so the vector can never be all-zero
                // before normalization.
                if f == 0.0 {
                    ((d + 1) as f32) * 1e-7
                } else {
                    f
                }
            })
            .collect()
    }
}

impl Embedder for IdentityEmbedder {
    fn embed(&self, text: &str) -> Result<Vec<f32>, RagError> {
        let mut v = self.hash_to_vec(text);
        l2_normalize(&mut v);
        Ok(v)
    }

    fn embedding_dim(&self) -> usize {
        self.dim
    }
}

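// A minimal usage sketch (this test module and its names are additions for
// illustration, not part of the original API): identical inputs embed to the
// same unit-length vector of the requested dimension.
#[cfg(test)]
mod identity_embedder_sketch {
    use super::*;

    #[test]
    fn identity_embedding_is_deterministic_and_normalized() {
        let embedder = IdentityEmbedder::new(8).expect("nonzero dim is valid");
        let a = embedder.embed("hello world").unwrap();
        let b = embedder.embed("hello world").unwrap();
        assert_eq!(a, b);
        assert_eq!(a.len(), embedder.embedding_dim());
        let norm: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }
}
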
/// Bag-of-words TF-IDF embedder fit once on a fixed corpus.
pub struct TfIdfEmbedder {
    /// Term -> column index in the output vector.
    vocab: HashMap<String, usize>,
    /// Smoothed inverse document frequency, one weight per vocabulary term.
    idf: Vec<f32>,
    dim: usize,
}

impl TfIdfEmbedder {
    /// Builds a vocabulary and IDF table from `documents`, keeping at most
    /// `max_features` terms (the most frequent ones, ties broken alphabetically).
    pub fn fit(documents: &[&str], max_features: usize) -> Self {
        let max_features = max_features.max(1);
        let n_docs = documents.len().max(1);
        // Document frequency: in how many documents does each term occur?
        let mut df: HashMap<String, usize> = HashMap::new();
        for doc in documents {
            let tokens = tokenize(doc);
            let unique: std::collections::HashSet<String> = tokens.into_iter().collect();
            for tok in unique {
                *df.entry(tok).or_insert(0) += 1;
            }
        }
        // Keep the highest-frequency terms; sort by count desc, then term asc,
        // so the vocabulary is deterministic across runs.
        let mut df_vec: Vec<(String, usize)> = df.into_iter().collect();
        df_vec.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
        df_vec.truncate(max_features);
        let dim = df_vec.len();
        let mut vocab = HashMap::with_capacity(dim);
        let mut idf = vec![0.0f32; dim];
        for (idx, (term, doc_freq)) in df_vec.into_iter().enumerate() {
            vocab.insert(term, idx);
            // Smoothed IDF: ln((1 + N) / (1 + df)) + 1, always positive.
            idf[idx] = ((1.0 + n_docs as f32) / (1.0 + doc_freq as f32)).ln() + 1.0;
        }
        Self { vocab, idf, dim }
    }

    /// Term-frequency (bag-of-words) vector for `text`, without IDF weighting.
    pub fn embed_bow(&self, text: &str) -> Vec<f32> {
        let tokens = tokenize(text);
        let n_tokens = tokens.len().max(1) as f32;
        let mut tf = vec![0.0f32; self.dim];
        for tok in &tokens {
            // Out-of-vocabulary tokens are silently dropped.
            if let Some(&idx) = self.vocab.get(tok) {
                tf[idx] += 1.0;
            }
        }
        // Divide by token count to get relative frequencies.
        for v in tf.iter_mut() {
            *v /= n_tokens;
        }
        tf
    }

    /// Number of vocabulary terms, which equals the embedding dimension.
    pub fn vocab_size(&self) -> usize {
        self.dim
    }

    /// Read-only access to the term -> index mapping.
    pub fn vocab(&self) -> &HashMap<String, usize> {
        &self.vocab
    }
}

impl Embedder for TfIdfEmbedder {
    fn embed(&self, text: &str) -> Result<Vec<f32>, RagError> {
        if self.dim == 0 {
            return Err(RagError::EmbeddingFailed(
                "TfIdfEmbedder has an empty vocabulary".into(),
            ));
        }
        // TF-IDF: weight each term frequency by its inverse document frequency.
        let mut tf = self.embed_bow(text);
        for (i, v) in tf.iter_mut().enumerate() {
            *v *= self.idf[i];
        }
        l2_normalize(&mut tf);
        Ok(tf)
    }

    fn embedding_dim(&self) -> usize {
        self.dim
    }
}

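// A usage sketch for the TF-IDF path (corpus, query, and test names here are
// illustrative additions): fit on a small corpus, then embed a query and check
// that an in-vocabulary term gets a positive weight.
#[cfg(test)]
mod tfidf_embedder_sketch {
    use super::*;

    #[test]
    fn tfidf_embeds_corpus_terms_with_positive_weight() {
        let docs = ["the cat sat", "the dog sat", "a cat ran"];
        let embedder = TfIdfEmbedder::fit(&docs, 16);
        assert!(embedder.vocab_size() <= 16);

        let v = embedder.embed("the cat").unwrap();
        assert_eq!(v.len(), embedder.embedding_dim());
        let cat_idx = embedder.vocab()["cat"];
        assert!(v[cat_idx] > 0.0);
    }
}
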
/// Scales `v` in place to unit Euclidean length. Near-zero vectors are left
/// unchanged to avoid dividing by (almost) zero.
pub fn l2_normalize(v: &mut [f32]) {
    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-10 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}

/// Lowercased tokens, split on whitespace and ASCII punctuation.
pub(crate) fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
        .filter(|s| !s.is_empty())
        .map(|s| s.to_lowercase())
        .collect()
}

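// A small sketch exercising the helpers (test module added for illustration):
// tokenize lowercases and splits on whitespace/punctuation, and l2_normalize
// leaves a zero vector untouched rather than dividing by zero.
#[cfg(test)]
mod helper_sketch {
    use super::*;

    #[test]
    fn tokenize_splits_and_lowercases() {
        assert_eq!(tokenize("Hello, RAG-world!"), vec!["hello", "rag", "world"]);
    }

    #[test]
    fn l2_normalize_scales_and_skips_zero() {
        let mut w = vec![3.0f32, 4.0];
        l2_normalize(&mut w);
        assert!((w[0] - 0.6).abs() < 1e-6 && (w[1] - 0.8).abs() < 1e-6);

        let mut z = vec![0.0f32; 4];
        l2_normalize(&mut z);
        assert_eq!(z, vec![0.0f32; 4]);
    }
}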