use anyhow::Result;
use std::collections::HashMap;
use crate::storage::Message;
/// A dense embedding vector.
pub type Vector = Vec<f32>;

/// Word embeddings learned from message text.
///
/// `vocab` maps each token to a row of `vectors`. Vectors produced by
/// [`Embeddings::train`] have `min(dim, vocab.len())` components, so individual word
/// vectors may be shorter than `dim`.
#[derive(Debug, Clone)]
pub struct Embeddings {
    /// Token → index into `vectors`.
    pub vocab: HashMap<String, usize>,
    /// One embedding per vocabulary entry, addressed by the indices stored in `vocab`.
    pub vectors: Vec<Vector>,
    /// Requested dimensionality; also the length of vectors built from text.
    pub dim: usize,
}
/// Splits `text` into lowercase word tokens.
///
/// The whole text is lowercased and split on whitespace. Each piece then has leading and
/// trailing non-alphanumeric characters stripped; inner punctuation such as the apostrophe
/// in "don't" is kept. Pieces that end up empty are dropped.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let mut words = Vec::new();
    for piece in lowered.split_whitespace() {
        let core = piece.trim_matches(|ch: char| !ch.is_alphanumeric());
        if !core.is_empty() {
            words.push(core.to_owned());
        }
    }
    words
}
/// Builds the vocabulary from the message contents.
///
/// Each token produced by [`tokenize`] that occurs at least `min_freq` times across all
/// messages is mapped to a dense index in `0..len`.
///
/// Indices are assigned by descending frequency, with ties broken alphabetically, so a given
/// corpus always yields the same mapping. Iterating the counting `HashMap` directly would not
/// do this, because `std`'s hasher is randomly seeded per process. Downstream,
/// `truncated_svd`'s start vector depends on the index, so a random order would make
/// training non-reproducible.
fn build_vocab(messages: &[Message], min_freq: usize) -> HashMap<String, usize> {
    let mut freq: HashMap<String, usize> = HashMap::new();
    for msg in messages {
        for tok in tokenize(&msg.content) {
            *freq.entry(tok).or_default() += 1;
        }
    }

    let mut kept: Vec<(String, usize)> = freq
        .into_iter()
        .filter(|&(_, count)| count >= min_freq)
        .collect();
    // Most frequent first; words are unique keys, so the alphabetical tie-break makes the
    // order total and the unstable sort deterministic.
    kept.sort_unstable_by(|(wa, ca), (wb, cb)| cb.cmp(ca).then_with(|| wa.cmp(wb)));

    kept.into_iter()
        .enumerate()
        .map(|(idx, (word, _))| (word, idx))
        .collect()
}
/// Builds a dense `n × n` co-occurrence matrix over the vocabulary, where `n = vocab.len()`.
///
/// For every in-vocabulary token of every message, each other in-vocabulary token at most
/// `window` positions away adds `1 / distance` to `matrix[center][context]`. Every pair is
/// visited from both ends, so the result is symmetric up to floating-point summation order.
/// Out-of-vocabulary tokens are removed *before* windowing, so distances count only
/// in-vocabulary tokens. A `window` of 0 yields an all-zero matrix.
///
/// Memory is `n²` floats regardless of how sparse the counts are.
fn build_cooccurrence(
    messages: &[Message],
    vocab: &HashMap<String, usize>,
    window: usize,
) -> Vec<Vec<f32>> {
    let n = vocab.len();
    let mut matrix = vec![vec![0f32; n]; n];
    for msg in messages {
        let tokens: Vec<usize> = tokenize(&msg.content)
            .iter()
            .filter_map(|t| vocab.get(t).copied())
            .collect();
        for (i, &center) in tokens.iter().enumerate() {
            let start = i.saturating_sub(window);
            // Saturating so an enormous `window` cannot overflow `i + window + 1`.
            let end = (i + 1).saturating_add(window).min(tokens.len());
            let row = &mut matrix[center];
            for (j, &ctx) in tokens.iter().enumerate().take(end).skip(start) {
                if j == i {
                    continue;
                }
                let dist = (i as f32 - j as f32).abs();
                row[ctx] += 1.0 / dist;
            }
        }
    }
    matrix
}
/// Converts a co-occurrence matrix into positive pointwise mutual information (PPMI).
///
/// For every strictly positive cell, `PMI(i, j) = ln(P(i, j) / (P(i) · P(j)))`. The
/// probabilities come from the grand total, the row totals and the column totals. The result
/// is clipped below at `0.0`. Cells that are not positive in the input become `0.0`.
///
/// The output is `n × n` with `n = matrix.len()`; every row must have at least `n` entries,
/// otherwise this panics.
fn pmi(matrix: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let size = matrix.len();
    let grand_total: f32 = matrix.iter().flatten().sum();
    let row_totals: Vec<f32> = matrix.iter().map(|row| row.iter().sum()).collect();
    let col_totals: Vec<f32> = (0..size)
        .map(|col| matrix.iter().map(|row| row[col]).sum())
        .collect();

    matrix
        .iter()
        .zip(&row_totals)
        .map(|(row, &row_total)| {
            (0..size)
                .map(|col| {
                    let count = row[col];
                    if count > 0.0 {
                        let p_joint = count / grand_total;
                        let p_row = row_total / grand_total;
                        let p_col = col_totals[col] / grand_total;
                        (p_joint / (p_row * p_col)).ln().max(0.0)
                    } else {
                        0.0
                    }
                })
                .collect()
        })
        .collect()
}
/// Projects each row of a square matrix onto `min(dim, n)` directions, where
/// `n = matrix.len()`.
///
/// The directions are found by power iteration with deflation. The result has one
/// `min(dim, n)`-long vector per row.
///
/// NOTE(review): each step multiplies by the (deflated) matrix itself, not by
/// `matrixᵀ·matrix`. It therefore finds dominant eigenvectors, which coincide with the
/// singular vectors the name suggests only for symmetric input. The PPMI matrix built in this
/// file is symmetric up to rounding. Confirm this before reusing the function elsewhere.
///
/// Cost: about `min(dim, n) × 30` dense `n × n` matrix–vector products, plus a full working
/// copy of `matrix`. Assumes every row has at least `n` entries.
fn truncated_svd(matrix: &[Vec<f32>], dim: usize) -> Vec<Vector> {
let n = matrix.len();
let actual_dim = dim.min(n);
// Working copy that is deflated after each direction is found.
let mut vecs: Vec<Vector> = (0..n).map(|i| matrix[i].clone()).collect();
// Directions found so far (normally unit length).
let mut basis: Vec<Vector> = Vec::new();
for _ in 0..actual_dim {
// Fixed, deterministic start vector; it is the same for every component.
let mut v: Vector = (0..n).map(|i| (i as f32 * 1.1 + 0.3).sin()).collect();
normalize(&mut v);
// Exactly 30 power-iteration steps; there is no convergence check.
for _ in 0..30 {
// av = vecs · v
let mut av = vec![0f32; n];
for i in 0..n {
for j in 0..n {
av[i] += vecs[i][j] * v[j];
}
}
// Gram–Schmidt: remove components along already-found directions.
for b in &basis {
let dot: f32 = av.iter().zip(b.iter()).map(|(a, b)| a * b).sum();
for k in 0..n { av[k] -= dot * b[k]; }
}
// `normalize` leaves vectors with norm <= 1e-8 unscaled, so if the deflated matrix
// has nothing left, v shrinks toward zero and this component's outputs are ~0.
normalize(&mut av);
v = av;
}
basis.push(v.clone());
// Deflate: strip each working row's component along v so later components find
// other directions.
for i in 0..n {
let dot: f32 = vecs[i].iter().zip(v.iter()).map(|(a, b)| a * b).sum();
for j in 0..n { vecs[i][j] -= dot * v[j]; }
}
}
// Project the ORIGINAL (undeflated) rows onto the collected basis.
(0..n)
.map(|i| {
basis.iter()
.map(|b| matrix[i].iter().zip(b.iter()).map(|(a, c)| a * c).sum())
.collect()
})
.collect()
}
/// Scales `v` in place to unit Euclidean length.
///
/// A vector whose norm is at most `1e-8`, including the all-zero vector, is left untouched,
/// so callers may still get a zero vector back.
///
/// The parameter is a slice rather than `&mut Vec<f32>`, so arrays and sub-slices work too.
/// `&mut Vec<f32>` arguments still coerce automatically.
fn normalize(v: &mut [f32]) {
    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-8 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
/// Cosine similarity of `a` and `b`, i.e. `a·b / (‖a‖ · ‖b‖)`.
///
/// Returns `0.0` when either vector's Euclidean norm is below `1e-8`. If the slices differ in
/// length, the dot product covers only the overlapping prefix, while each norm covers its
/// whole slice.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let l2 = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (l2(a), l2(b));
    if norm_a < 1e-8 || norm_b < 1e-8 {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
impl Embeddings {
    /// Trains word embeddings from the text of `messages`.
    ///
    /// Pipeline:
    /// 1. Keep tokens that occur at least twice.
    /// 2. Accumulate `1 / distance`-weighted co-occurrences within `window` tokens.
    /// 3. Apply positive PMI.
    /// 4. Project onto `dim` directions found by power iteration (`truncated_svd`).
    /// 5. Scale each word vector to unit length; vectors with norm ≤ 1e-8 are left as is.
    ///
    /// Progress is reported on stderr.
    ///
    /// `dim` is stored as requested, but each word vector has
    /// `min(dim, vocabulary size)` components.
    pub fn train(messages: &[Message], dim: usize, window: usize) -> Self {
        eprintln!("building vocab...");
        let vocab = build_vocab(messages, 2);
        eprintln!(" {} tokens", vocab.len());
        // Scoped so the n×n co-occurrence matrix is freed before the SVD step, which makes
        // its own full working copy of the PPMI matrix. This lowers peak memory by one matrix.
        let ppmi = {
            eprintln!("cooccurrence matrix...");
            let cooc = build_cooccurrence(messages, &vocab, window);
            eprintln!("PMI...");
            pmi(&cooc)
        };
        eprintln!("SVD dim={}...", dim);
        let mut vectors = truncated_svd(&ppmi, dim);
        for v in &mut vectors {
            normalize(v);
        }
        Self { vocab, vectors, dim }
    }

    /// Embeds `text` as the mean of its in-vocabulary token vectors, scaled to unit length.
    ///
    /// The result always has length `self.dim`. It is all zeros when no token of `text` is in
    /// the vocabulary.
    ///
    /// # Panics
    ///
    /// Panics if a vocabulary index is out of range for `vectors`, or if a stored vector is
    /// longer than `dim`. Such an `Embeddings` is inconsistent; `load` rejects files like that.
    pub fn embed_text(&self, text: &str) -> Vector {
        let tokens: Vec<usize> = tokenize(text)
            .iter()
            .filter_map(|t| self.vocab.get(t).copied())
            .collect();
        if tokens.is_empty() {
            return vec![0f32; self.dim];
        }
        let mut result = vec![0f32; self.dim];
        for &idx in &tokens {
            for (i, val) in self.vectors[idx].iter().enumerate() {
                result[i] += val;
            }
        }
        let n = tokens.len() as f32;
        for x in &mut result {
            *x /= n;
        }
        normalize(&mut result);
        result
    }

    /// Mean of the [`embed_text`](Self::embed_text) vectors of `texts`, scaled to unit length.
    ///
    /// Returns a zero vector of length `self.dim` when `texts` is empty.
    pub fn centroid(&self, texts: &[&str]) -> Vector {
        let mut result = vec![0f32; self.dim];
        if texts.is_empty() {
            return result;
        }
        // Accumulate one text at a time instead of materialising every embedding first.
        for text in texts {
            for (acc, x) in result.iter_mut().zip(self.embed_text(text)) {
                *acc += x;
            }
        }
        let n = texts.len() as f32;
        for x in &mut result {
            *x /= n;
        }
        normalize(&mut result);
        result
    }

    /// Writes `(dim, vocabulary pairs, vectors)` to `path` using bincode.
    ///
    /// The vocabulary pairs are sorted before writing. `HashMap` iteration order is randomly
    /// seeded, and sorting makes identical models produce byte-identical files.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails or the file cannot be written.
    pub fn save(&self, path: &str) -> Result<()> {
        let mut vocab_pairs: Vec<(String, usize)> = self
            .vocab
            .iter()
            .map(|(k, v)| (k.clone(), *v))
            .collect();
        vocab_pairs.sort_unstable();
        let data = bincode::serialize(&(self.dim, &vocab_pairs, &self.vectors))?;
        std::fs::write(path, data)?;
        Ok(())
    }

    /// Reads embeddings previously written by [`save`](Self::save).
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the file cannot be read;
    /// - the contents do not decode;
    /// - a vocabulary index does not refer to a stored vector;
    /// - a stored vector is longer than `dim`.
    pub fn load(path: &str) -> Result<Self> {
        let data = std::fs::read(path)?;
        let (dim, pairs, vectors): (usize, Vec<(String, usize)>, Vec<Vector>) =
            bincode::deserialize(&data)?;
        // Reject inconsistent files here. Otherwise `embed_text` would panic later, when it
        // indexes `vectors` by a vocabulary index or writes a too-long vector into a
        // `dim`-sized buffer.
        let invalid = |msg: String| std::io::Error::new(std::io::ErrorKind::InvalidData, msg);
        if let Some((word, idx)) = pairs.iter().find(|(_, idx)| *idx >= vectors.len()) {
            return Err(invalid(format!(
                "{}: token {:?} has index {} but the file holds only {} vectors",
                path,
                word,
                idx,
                vectors.len()
            ))
            .into());
        }
        if let Some(row) = vectors.iter().position(|v| v.len() > dim) {
            return Err(invalid(format!(
                "{}: vector {} has {} components, more than dim = {}",
                path,
                row,
                vectors[row].len(),
                dim
            ))
            .into());
        }
        let vocab = pairs.into_iter().collect();
        Ok(Self { vocab, vectors, dim })
    }
}