//! interpolize 1.0.0
//!
//! A Rust program that scrapes Discord, learns how your friends talk, and generates new messages in their collective voice. Yes, this is what we're doing with our lives.
use anyhow::Result;
use std::collections::HashMap;
use crate::storage::Message;

/// A dense `f32` vector: a word/text embedding or a basis vector used during training.
pub type Vector = Vec<f32>;

/// Word embeddings learned from message history.
///
/// Built by [`Embeddings::train`] and persisted with [`Embeddings::save`] /
/// [`Embeddings::load`].
pub struct Embeddings {
    /// Maps each known (lowercased, punctuation-trimmed) token to its row in `vectors`.
    pub vocab: HashMap<String, usize>,
    /// Word vectors, indexed by the values stored in `vocab`.
    pub vectors: Vec<Vector>,
    /// Embedding width; `embed_text` and `centroid` return vectors of this length.
    pub dim: usize,
}

/// Splits `text` into lowercase word tokens.
///
/// The text is lowercased and split on whitespace. Each piece then has
/// non-alphanumeric characters stripped from both ends, so inner punctuation
/// such as the apostrophe in "don't" survives. Pieces that end up empty are
/// dropped.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let mut tokens = Vec::new();
    for piece in lowered.split_whitespace() {
        let word = piece.trim_matches(|ch: char| !ch.is_alphanumeric());
        if word.is_empty() {
            continue;
        }
        tokens.push(word.to_owned());
    }
    tokens
}

/// Builds the token → index vocabulary from all message contents.
///
/// Only tokens (see `tokenize`) occurring at least `min_freq` times across
/// `messages` are kept. Indices are dense (`0..vocab.len()`) and assigned in
/// lexicographic order of the token.
///
/// Sorting matters for reproducibility. `HashMap` iteration order is
/// randomized per process, so assigning indices in iteration order would
/// permute matrix rows between runs. It would also change downstream float
/// summation order and the index-dependent SVD start vector, so identical
/// input could yield different embeddings.
fn build_vocab(messages: &[Message], min_freq: usize) -> HashMap<String, usize> {
    let mut freq: HashMap<String, usize> = HashMap::new();
    for msg in messages {
        for tok in tokenize(&msg.content) {
            *freq.entry(tok).or_default() += 1;
        }
    }

    let mut kept: Vec<String> = freq
        .into_iter()
        .filter(|&(_, count)| count >= min_freq)
        .map(|(word, _)| word)
        .collect();
    // Deterministic index assignment, independent of hash seeding.
    kept.sort_unstable();

    kept.into_iter()
        .enumerate()
        .map(|(idx, word)| (word, idx))
        .collect()
}

/// Builds a dense, distance-weighted co-occurrence matrix over `vocab`.
///
/// Each message is tokenized and out-of-vocabulary tokens are dropped *before*
/// windowing, so distances are measured in the filtered token sequence. For
/// every ordered pair of distinct positions at most `window` apart,
/// `matrix[center][ctx]` grows by `1 / distance`.
///
/// The result is `vocab.len()` × `vocab.len()`, so memory grows
/// quadratically with vocabulary size.
fn build_cooccurrence(
    messages: &[Message],
    vocab: &HashMap<String, usize>,
    window: usize,
) -> Vec<Vec<f32>> {
    let size = vocab.len();
    let mut counts: Vec<Vec<f32>> = (0..size).map(|_| vec![0f32; size]).collect();

    for message in messages {
        let ids: Vec<usize> = tokenize(&message.content)
            .into_iter()
            .filter_map(|tok| vocab.get(&tok).copied())
            .collect();

        for (pos, &center) in ids.iter().enumerate() {
            let lo = pos.saturating_sub(window);
            let hi = ids.len().min(pos + window + 1);
            // Visit positions lo..hi, skipping the center word itself.
            for (other, &ctx) in ids.iter().enumerate().take(hi).skip(lo) {
                if other == pos {
                    continue;
                }
                let distance = (pos as f32 - other as f32).abs();
                counts[center][ctx] += 1.0 / distance;
            }
        }
    }

    counts
}

/// Turns raw co-occurrence weights into positive PMI (PPMI) scores.
///
/// Let `T` be the sum of every cell, `R[i]` the sum of row `i` and `C[j]` the
/// sum of column `j`. Each strictly positive cell then becomes
/// `max(0, ln((m[i][j] / T) / ((R[i] / T) * (C[j] / T))))`, and every other
/// cell is `0.0`.
///
/// The output is `matrix.len()` × `matrix.len()`; callers are expected to
/// pass a square matrix.
fn pmi(matrix: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let size = matrix.len();
    let grand_total: f32 = matrix.iter().flatten().sum();
    let row_totals: Vec<f32> = matrix.iter().map(|row| row.iter().sum()).collect();
    let col_totals: Vec<f32> = (0..size)
        .map(|col| matrix.iter().map(|row| row[col]).sum())
        .collect();

    (0..size)
        .map(|i| {
            (0..size)
                .map(|j| {
                    let count = matrix[i][j];
                    if count > 0.0 {
                        let joint = count / grand_total;
                        let marginal_i = row_totals[i] / grand_total;
                        let marginal_j = col_totals[j] / grand_total;
                        // Negative associations are clipped to zero (the "positive" in PPMI).
                        (joint / (marginal_i * marginal_j)).ln().max(0.0)
                    } else {
                        0.0
                    }
                })
                .collect()
        })
        .collect()
}

/// Projects each row of `matrix` onto up to `dim` leading directions found by
/// power iteration with deflation.
///
/// Returns one vector per row of `matrix`, each with `min(dim, matrix.len())`
/// components. Component `k` of row `i` is the dot product of the ORIGINAL
/// `matrix[i]` with the `k`-th basis vector.
///
/// Method: for each component, start from a fixed deterministic vector
/// (`sin(1.1·i + 0.3)`) and run exactly 30 multiply-by-matrix steps with no
/// convergence test. Each step re-orthogonalizes against earlier basis vectors
/// and normalizes. The found direction is then projected out of the working
/// copy of the matrix.
///
/// NOTE(review): the iteration multiplies by the matrix itself, not by MᵀM.
/// The directions are therefore (approximate) dominant eigenvectors, which
/// match singular vectors only for a symmetric input. The PPMI input built in
/// `train` appears symmetric by construction — confirm before reusing this
/// elsewhere.
///
/// Indexing `vecs[i][j]` for `j < n` assumes `matrix` is square (n × n).
/// Cost is roughly `30 · min(dim, n) · n²` multiply-adds, plus an n × n working
/// copy of the matrix.
fn truncated_svd(matrix: &[Vec<f32>], dim: usize) -> Vec<Vector> {
    let n = matrix.len();
    // Cannot extract more directions than the matrix has rows.
    let actual_dim = dim.min(n);
    // Working copy that gets deflated after each component is found.
    let mut vecs: Vec<Vector> = (0..n).map(|i| matrix[i].clone()).collect();

    let mut basis: Vec<Vector> = Vec::new();

    for _ in 0..actual_dim {
        // Fixed, non-random start vector keeps the iteration deterministic.
        let mut v: Vector = (0..n).map(|i| (i as f32 * 1.1 + 0.3).sin()).collect();
        normalize(&mut v);

        for _ in 0..30 {
            // av = vecs · v
            let mut av = vec![0f32; n];
            for i in 0..n {
                for j in 0..n {
                    av[i] += vecs[i][j] * v[j];
                }
            }
            // Gram-Schmidt against previously found directions to limit numerical drift.
            for b in &basis {
                let dot: f32 = av.iter().zip(b.iter()).map(|(a, b)| a * b).sum();
                for k in 0..n { av[k] -= dot * b[k]; }
            }
            // If av has (near-)zero norm, normalize leaves it as-is, so a
            // near-zero vector can end up in `basis`.
            normalize(&mut av);
            v = av;
        }

        basis.push(v.clone());

        // Deflation: remove the v component from every row (vecs ← vecs·(I − v vᵀ))
        // so later components converge to other directions.
        for i in 0..n {
            let dot: f32 = vecs[i].iter().zip(v.iter()).map(|(a, b)| a * b).sum();
            for j in 0..n { vecs[i][j] -= dot * v[j]; }
        }
    }

    // Project the original (undeflated) rows onto the collected basis.
    (0..n)
        .map(|i| {
            basis.iter()
                .map(|b| matrix[i].iter().zip(b.iter()).map(|(a, c)| a * c).sum())
                .collect()
        })
        .collect()
}

/// Scales `v` in place to unit Euclidean (L2) length.
///
/// Vectors whose norm is not greater than `1e-8` (including all-zero vectors)
/// are left untouched rather than divided by a near-zero value.
///
/// Takes a mutable slice rather than `&mut Vec<f32>` because the buffer is
/// never resized; this matches the slice-based signature of `cosine`. Existing
/// `&mut Vec<f32>` arguments still coerce automatically.
fn normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-8 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}

/// Cosine similarity between `a` and `b`.
///
/// The dot product runs over the overlapping prefix of the two slices, while
/// each norm covers its whole slice. If either norm is below `1e-8`, the
/// similarity is reported as `0.0` instead of dividing by (nearly) zero.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let l2 = |xs: &[f32]| xs.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (l2(a), l2(b));
    if norm_a < 1e-8 || norm_b < 1e-8 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (norm_a * norm_b)
}

impl Embeddings {
    /// Trains word embeddings from `messages`.
    ///
    /// Pipeline:
    /// 1. Build a vocabulary of tokens seen at least twice.
    /// 2. Compute distance-weighted co-occurrence counts within `window` tokens.
    /// 3. Convert the counts to positive PMI.
    /// 4. Project onto up to `dim` directions via `truncated_svd`.
    ///
    /// Each word vector is then L2-normalized (near-zero vectors are left as-is)
    /// and zero-padded to exactly `dim` entries. Padding is needed because the
    /// projection rank is capped at the vocabulary size. It guarantees
    /// `vectors[i].len() == dim` without changing any norm or similarity.
    /// Progress is reported on stderr.
    pub fn train(messages: &[Message], dim: usize, window: usize) -> Self {
        eprintln!("building vocab...");
        let vocab = build_vocab(messages, 2);
        eprintln!("  {} tokens", vocab.len());

        eprintln!("cooccurrence matrix...");
        let cooc = build_cooccurrence(messages, &vocab, window);

        eprintln!("PMI...");
        let ppmi = pmi(&cooc);

        eprintln!("SVD dim={}...", dim);
        let mut vectors = truncated_svd(&ppmi, dim);
        for v in &mut vectors {
            normalize(v);
            // truncated_svd yields min(dim, vocab size) components; keep the
            // stored vectors consistent with `self.dim`.
            v.resize(dim, 0.0);
        }

        Self { vocab, vectors, dim }
    }

    /// Embeds `text` as the normalized mean of its in-vocabulary token vectors.
    ///
    /// Tokens absent from the vocabulary are ignored. Returns a zero vector of
    /// length `dim` when no token is known.
    ///
    /// # Panics
    /// Panics if `vocab` holds an index outside `vectors`. `train` and `load`
    /// never produce such a state.
    pub fn embed_text(&self, text: &str) -> Vector {
        let mut result = vec![0f32; self.dim];
        let mut count = 0usize;
        for idx in tokenize(text)
            .iter()
            .filter_map(|t| self.vocab.get(t).copied())
        {
            // zip stops at the shorter side, so an oversized stored vector
            // cannot write past `dim`.
            for (acc, val) in result.iter_mut().zip(&self.vectors[idx]) {
                *acc += val;
            }
            count += 1;
        }

        if count == 0 {
            return result;
        }

        let n = count as f32;
        for x in &mut result {
            *x /= n;
        }
        normalize(&mut result);
        result
    }

    /// Returns the normalized average of the embeddings of `texts`.
    ///
    /// Texts with no known tokens contribute a zero vector. An empty `texts`
    /// slice yields a zero vector of length `dim`.
    pub fn centroid(&self, texts: &[&str]) -> Vector {
        let mut result = vec![0f32; self.dim];
        if texts.is_empty() {
            return result;
        }

        for text in texts {
            let v = self.embed_text(text);
            for (acc, x) in result.iter_mut().zip(&v) {
                *acc += x;
            }
        }

        let n = texts.len() as f32;
        for x in &mut result {
            *x /= n;
        }
        normalize(&mut result);
        result
    }

    /// Serializes `(dim, vocab pairs, vectors)` with bincode and writes it to `path`.
    ///
    /// # Errors
    /// Returns an error if serialization or the file write fails.
    pub fn save(&self, path: &str) -> Result<()> {
        let vocab_pairs: Vec<(String, usize)> = self.vocab.iter()
            .map(|(k, v)| (k.clone(), *v))
            .collect();
        let data = bincode::serialize(&(self.dim, &vocab_pairs, &self.vectors))?;
        std::fs::write(path, data)?;
        Ok(())
    }

    /// Loads embeddings previously written by [`Embeddings::save`].
    ///
    /// Vectors shorter than the stored `dim` are zero-padded to `dim`.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - the file cannot be read or decoded;
    /// - a vocabulary entry points past the stored vectors;
    /// - a stored vector is longer than `dim`.
    ///
    /// The last two would otherwise cause panics later in `embed_text`.
    pub fn load(path: &str) -> Result<Self> {
        let data = std::fs::read(path)?;
        let (dim, pairs, mut vectors): (usize, Vec<(String, usize)>, Vec<Vector>) =
            bincode::deserialize(&data)?;

        let invalid =
            |msg: String| std::io::Error::new(std::io::ErrorKind::InvalidData, msg);

        if let Some((word, idx)) = pairs.iter().find(|(_, idx)| *idx >= vectors.len()) {
            return Err(invalid(format!(
                "{}: token {:?} maps to index {} but only {} vectors are stored",
                path,
                word,
                idx,
                vectors.len()
            ))
            .into());
        }
        if let Some(v) = vectors.iter().find(|v| v.len() > dim) {
            return Err(invalid(format!(
                "{}: stored vector has {} components, more than dim {}",
                path,
                v.len(),
                dim
            ))
            .into());
        }

        // A vocabulary smaller than `dim` yields shorter vectors; pad them so
        // every vector has exactly `dim` entries (zeros leave similarities unchanged).
        for v in &mut vectors {
            v.resize(dim, 0.0);
        }

        let vocab = pairs.into_iter().collect();
        Ok(Self { vocab, vectors, dim })
    }
}