impl LocalEmbedder {
pub fn new() -> Self {
Self {
dimension: 256,
document_frequencies: RwLock::new(HashMap::new()),
doc_count: RwLock::new(0),
}
}
pub fn with_dimension(dimension: usize) -> Self {
Self {
dimension,
document_frequencies: RwLock::new(HashMap::new()),
doc_count: RwLock::new(0),
}
}
pub fn fit(&self, documents: &[String]) -> Result<(), String> {
let mut df = self
.document_frequencies
.write()
.map_err(|e| format!("Lock error: {e}"))?;
let mut count = self
.doc_count
.write()
.map_err(|e| format!("Lock error: {e}"))?;
df.clear();
*count = documents.len();
for doc in documents {
let tokens: std::collections::HashSet<String> =
self.tokenize(doc).into_iter().collect();
for token in tokens {
*df.entry(token).or_insert(0) += 1;
}
}
Ok(())
}
pub fn embed(&self, text: &str) -> Result<Vec<f32>, String> {
let tokens = self.tokenize(text);
let mut embedding = vec![0.0f32; self.dimension];
let mut tf: HashMap<String, usize> = HashMap::new();
for token in &tokens {
*tf.entry(token.clone()).or_insert(0) += 1;
}
let doc_count = *self
.doc_count
.read()
.map_err(|e| format!("Lock error: {e}"))?;
let df = self
.document_frequencies
.read()
.map_err(|e| format!("Lock error: {e}"))?;
for (token, count) in &tf {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
token.hash(&mut hasher);
let hash = hasher.finish();
let idx = (hash as usize) % self.dimension;
let term_freq = (1.0 + *count as f32).ln();
let doc_freq = df.get(token).copied().unwrap_or(1) as f32;
let n = (doc_count.max(1)) as f32;
let inv_doc_freq = (n / doc_freq).ln();
let weight = term_freq * inv_doc_freq;
let sign = if (hash >> 32) & 1 == 0 { 1.0 } else { -1.0 };
embedding[idx] += sign * weight;
}
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for x in &mut embedding {
*x /= norm;
}
}
Ok(embedding)
}
fn tokenize(&self, text: &str) -> Vec<String> {
text.to_lowercase()
.split(|c: char| !c.is_alphanumeric() && c != '_')
.filter(|s| s.len() > 1)
.map(|s| s.to_string())
.collect()
}
pub fn dimension(&self) -> usize {
self.dimension
}
}
impl Default for LocalEmbedder {
fn default() -> Self {
Self::new()
}
}
unsafe impl Send for LocalEmbedder {}
unsafe impl Sync for LocalEmbedder {}