interpolize 5.0.1

A Rust program that scrapes Discord, learns how your friends talk, and generates new messages in their collective voice. Yes, this is what we're doing with our lives.
use std::collections::HashMap;
use crate::embed::{Embeddings, Vector, cosine};

/// Continuation counts for a single context length: maps a context (a run of
/// tokens of one fixed length) to how many times each next token followed it.
type Chain = HashMap<Vec<String>, HashMap<String, f32>>;

/// A variable-order, whitespace-token Markov model; see [`Markov::train`].
pub struct Markov {
    /// `chains[k - 1]` holds the counts for contexts of exactly `k` tokens,
    /// for `k` in `1..=max_order`.
    chains: Vec<Chain>,
    /// Longest context length tracked (the number of entries in `chains`).
    max_order: usize,
    /// Occurrence count of every training token; used as the candidate set
    /// when no context suffix matches during generation.
    unigrams: HashMap<String, f32>,
}

/// Splits `text` into whitespace-delimited tokens.
///
/// Punctuation stays attached to the word it touches (`"hi."` is one token).
/// `split_whitespace` never yields empty pieces, so no extra filtering is needed.
fn tokenize(text: &str) -> Vec<String> {
    text.split_whitespace().map(str::to_owned).collect()
}

/// Reassembles `tokens` into one string, placing a single space between
/// neighbouring tokens (no leading or trailing space).
fn detokenize(tokens: &[String]) -> String {
    let mut text = String::new();
    for (position, token) in tokens.iter().enumerate() {
        if position > 0 {
            text.push(' ');
        }
        text.push_str(token);
    }
    text
}

impl Markov {
    /// Trains a variable-order model on `texts`.
    ///
    /// For every order `k` in `1..=max_order`, counts how often each token
    /// followed each `k`-token context; `chains[k - 1]` holds those counts.
    /// Unigram counts are kept as a fallback for when no context matches.
    /// Each text is tokenized on its own, so no n-gram spans two texts.
    pub fn train(texts: &[&str], max_order: usize) -> Self {
        let mut chains: Vec<Chain> = (0..max_order).map(|_| HashMap::new()).collect();
        let mut unigrams: HashMap<String, f32> = HashMap::new();

        for text in texts {
            let tokens = tokenize(text);
            for tok in &tokens {
                *unigrams.entry(tok.clone()).or_default() += 1.0;
            }
            for order in 1..=max_order {
                for window in tokens.windows(order + 1) {
                    let ctx = window[..order].to_vec();
                    let next = window[order].clone();
                    *chains[order - 1]
                        .entry(ctx)
                        .or_default()
                        .entry(next)
                        .or_default() += 1.0;
                }
            }
        }

        Self { chains, max_order, unigrams }
    }

    /// Returns the continuation counts for the longest suffix of `ctx`
    /// (at most `max_order` tokens) that was seen as a context in training,
    /// or `None` when no suffix of any length matches.
    fn candidates(&self, ctx: &[String]) -> Option<&HashMap<String, f32>> {
        let longest = self.max_order.min(ctx.len());
        // `Vec<String>: Borrow<[String]>`, so the chain can be queried with a
        // slice of `ctx` directly instead of allocating a key per lookup.
        (1..=longest)
            .rev()
            .find_map(|len| self.chains[len - 1].get(&ctx[ctx.len() - len..]))
    }

    /// Generates up to `max_tokens` tokens continuing `seed`.
    ///
    /// At each step the candidates come from the longest matching context,
    /// falling back to the unigram counts. When both `style_vec` and
    /// `embeddings` are given, a candidate present in `embeddings.vocab` has
    /// its count multiplied by `max(1 + cosine(token_vector, style_vec), 0.01)`.
    ///
    /// `temperature` reshapes the distribution: each candidate is weighted by
    /// `score^(1 / temperature)`. At `1.0` sampling is proportional to the
    /// scores; lower values sharpen it toward the top candidate, and higher
    /// values flatten it toward uniform. A temperature that is zero, negative,
    /// or NaN selects the highest-scoring candidate greedily.
    ///
    /// Generation stops early after a token ending in `.`, `!`, or `?`, or
    /// when no candidate has a positive score. Only the newly generated tokens
    /// are returned; the seed is not included.
    pub fn generate(
        &self,
        seed: &[String],
        max_tokens: usize,
        style_vec: Option<&Vector>,
        embeddings: Option<&Embeddings>,
        temperature: f32,
    ) -> Vec<String> {
        let mut output = seed.to_vec();

        for _ in 0..max_tokens {
            let raw = self.candidates(&output).unwrap_or(&self.unigrams);
            let scored = Self::score_candidates(raw, style_vec, embeddings);
            let chosen = match Self::sample(&scored, temperature) {
                Some(tok) => tok.clone(),
                None => break,
            };

            let stop = Self::ends_sentence(&chosen);
            output.push(chosen);
            if stop {
                break;
            }
        }

        // Drop the seed, keeping only what was generated.
        output.split_off(seed.len())
    }

    /// Pairs every candidate token with its frequency, optionally scaled by
    /// its cosine similarity to `style_vec` (see [`Markov::generate`]).
    fn score_candidates<'a>(
        raw: &'a HashMap<String, f32>,
        style_vec: Option<&Vector>,
        embeddings: Option<&Embeddings>,
    ) -> Vec<(&'a String, f32)> {
        raw.iter()
            .map(|(tok, &freq)| {
                let mut score = freq;
                if let (Some(sv), Some(emb)) = (style_vec, embeddings) {
                    if let Some(&idx) = emb.vocab.get(tok) {
                        let sim = cosine(&emb.vectors[idx], sv);
                        // Floor keeps strongly off-style tokens possible but rare.
                        score *= (1.0 + sim).max(0.01);
                    }
                }
                (tok, score)
            })
            .collect()
    }

    /// Picks one token from `scored` according to `temperature`; returns
    /// `None` when no candidate has a positive, finite score.
    fn sample<'a>(scored: &[(&'a String, f32)], temperature: f32) -> Option<&'a String> {
        let eligible = |s: f32| s.is_finite() && s > 0.0;
        let max = scored
            .iter()
            .map(|&(_, s)| s)
            .filter(|&s| eligible(s))
            .fold(0.0f32, f32::max);
        if max <= 0.0 {
            return None;
        }

        if temperature.is_nan() || temperature <= 0.0 {
            // Limit of T -> 0: all mass on the best candidate.
            return scored.iter().find(|&&(_, s)| s == max).map(|&(tok, _)| tok);
        }

        // Normalising by `max` keeps every weight in [0, 1], so `powf` cannot
        // overflow even for very small temperatures; the best token weighs 1.
        let exponent = temperature.recip();
        let weights: Vec<f32> = scored
            .iter()
            .map(|&(_, s)| if eligible(s) { (s / max).powf(exponent) } else { 0.0 })
            .collect();
        let total: f32 = weights.iter().sum();

        let mut remaining = rand_f32() * total;
        let mut fallback = None;
        for (&(tok, _), &w) in scored.iter().zip(&weights) {
            if w > 0.0 {
                if remaining < w {
                    return Some(tok);
                }
                remaining -= w;
                fallback = Some(tok);
            }
        }
        // Floating-point rounding can leave `remaining` marginally above the
        // final cumulative weight; fall back to the last eligible token.
        fallback
    }

    /// Whether `token` ends a sentence. Tokens keep attached punctuation
    /// (e.g. `"lol."`), so this checks the last character, not the whole token.
    fn ends_sentence(token: &str) -> bool {
        token.ends_with(|c: char| matches!(c, '.' | '!' | '?'))
    }

    /// Tokenizes every context message followed by `prompt` into one seed,
    /// generates a continuation with [`Markov::generate`], and joins the
    /// generated tokens with single spaces.
    pub fn generate_from(
        &self,
        prompt: &str,
        context_msgs: &[&str],
        style_vec: Option<&Vector>,
        embeddings: Option<&Embeddings>,
        max_tokens: usize,
        temperature: f32,
    ) -> String {
        let mut seed: Vec<String> = context_msgs
            .iter()
            .flat_map(|m| tokenize(m))
            .collect();

        seed.extend(tokenize(prompt));

        let generated = self.generate(&seed, max_tokens, style_vec, embeddings, temperature);
        detokenize(&generated)
    }
}

/// Returns a pseudo-random `f32` uniformly distributed in `[0, 1)`.
///
/// Uses a per-thread SplitMix64 generator, seeded once per thread from the
/// standard library's randomly keyed hasher mixed with the current time.
/// Not cryptographically secure; it only needs to be good enough for sampling.
fn rand_f32() -> f32 {
    use std::cell::Cell;

    /// Produces the starting state for this thread's generator.
    fn seed() -> u64 {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        use std::time::{SystemTime, UNIX_EPOCH};

        // A clock set before 1970 is not worth panicking over: the hasher's
        // random keys still make the seed unpredictable. Truncating the u128
        // nanosecond count is fine for seeding.
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(0);
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u64(nanos);
        hasher.finish()
    }

    thread_local! {
        static STATE: Cell<u64> = Cell::new(seed());
    }

    STATE.with(|state| {
        // SplitMix64: advance by the golden-ratio increment, then mix.
        let next = state.get().wrapping_add(0x9E37_79B9_7F4A_7C15);
        state.set(next);
        let mut z = next;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^= z >> 31;
        // The top 24 bits fit exactly in an f32 mantissa, so the quotient is
        // exact and never rounds up to 1.0.
        (z >> 40) as f32 / (1u64 << 24) as f32
    })
}