interpolize 1.0.0

A Rust program that scrapes Discord, learns how your friends talk, and generates new messages in their collective voice. Yes, this is what we're doing with our lives.
use crate::config::Config;
use crate::embed::{Embeddings, Vector, cosine};
use crate::storage::{Message, thread};

/// Blends per-channel message embeddings into a single "style" vector and
/// retrieves corpus messages that are close to a query or to that style.
///
/// Borrows the configuration, the embedding model and the message corpus for
/// `'a`; the only owned state is computed once in [`Interpolator::new`].
pub struct Interpolator<'a> {
    config: &'a Config,
    embeddings: &'a Embeddings,
    /// The corpus; `msg_vectors[i]` is the embedding of `messages[i]`.
    messages: &'a [Message],
    /// One embedding per entry of `messages`, in the same order.
    msg_vectors: Vec<Vector>,
    /// One centroid per entry of `config.channels`, in the same order.
    channel_centroids: Vec<Vector>,
}

impl<'a> Interpolator<'a> {
    /// Embeds every message in `messages` and computes one centroid per
    /// configured channel (from the contents of that channel's messages).
    ///
    /// Cost: one `embed_text` call per message, plus one pass over the corpus
    /// per channel to gather that channel's texts.
    pub fn new(config: &'a Config, embeddings: &'a Embeddings, messages: &'a [Message]) -> Self {
        let msg_vectors: Vec<Vector> = messages
            .iter()
            .map(|m| embeddings.embed_text(&m.content))
            .collect();

        // A channel with no messages passes an empty slice to `centroid`;
        // NOTE(review): what that returns (presumably a zero vector) is up to
        // `Embeddings` — confirm.
        let channel_centroids: Vec<Vector> = config
            .channels
            .iter()
            .map(|ch| {
                let texts: Vec<&str> = messages
                    .iter()
                    .filter(|m| m.channel_id == ch.id)
                    .map(|m| m.content.as_str())
                    .collect();
                embeddings.centroid(&texts)
            })
            .collect();

        Self { config, embeddings, messages, msg_vectors, channel_centroids }
    }

    /// Returns the weighted sum of the channel centroids, scaled to unit length.
    ///
    /// Weights come from `config.normalized_weights()` and are paired with the
    /// channel centroids by position. If the sum has (near-)zero length it is
    /// returned unnormalized, so the result may be the zero vector.
    pub fn style_vector(&self) -> Vector {
        let weights = self.config.normalized_weights();
        let mut result = vec![0f32; self.embeddings.dim];

        for (centroid, &w) in self.channel_centroids.iter().zip(weights.iter()) {
            // Zip rather than index so a centroid whose length disagrees with
            // `embeddings.dim` cannot trigger an out-of-bounds panic.
            for (acc, x) in result.iter_mut().zip(centroid.iter()) {
                *acc += w * x;
            }
        }

        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-8 {
            result.iter_mut().for_each(|x| *x /= norm);
        }
        result
    }

    /// Returns up to `k` messages from `pool`, most similar to `query` first.
    ///
    /// A pool message that also appears in the corpus (matched by `id`) reuses
    /// its precomputed vector; any other message is embedded on demand rather
    /// than being scored against an unrelated vector. Locating the corpus entry
    /// is a linear scan, so this is O(|pool| · |corpus|).
    pub fn knn<'b>(&self, query: &Vector, pool: &'b [Message], k: usize) -> Vec<&'b Message> {
        let scored: Vec<(usize, f32)> = pool
            .iter()
            .enumerate()
            .map(|(i, m)| (i, self.similarity_to(query, m)))
            .collect();

        Self::top_k(scored, k)
            .into_iter()
            .map(|(i, _)| &pool[i])
            .collect()
    }

    /// Finds the `k` corpus messages most similar to `query` and expands each
    /// into its thread via `thread(.., config.retrieval.thread_depth)`.
    ///
    /// Threads are concatenated in rank order, keeping only the first
    /// occurrence of each message `id`, so the result can be longer than `k`.
    pub fn context_messages(&self, query: &str, k: usize) -> Vec<Message> {
        let qv = self.embeddings.embed_text(query);
        let depth = self.config.retrieval.thread_depth;

        let mut result: Vec<Message> = Vec::new();
        for (i, _) in Self::top_k(self.scores_against(&qv), k) {
            for m in thread(&self.messages[i].id, self.messages, depth) {
                if !result.iter().any(|r| r.id == m.id) {
                    result.push(m);
                }
            }
        }
        result
    }

    /// Returns the `k` corpus messages closest to [`Self::style_vector`],
    /// best match first.
    pub fn style_messages(&self, k: usize) -> Vec<&Message> {
        let sv = self.style_vector();
        Self::top_k(self.scores_against(&sv), k)
            .into_iter()
            .map(|(i, _)| &self.messages[i])
            .collect()
    }

    /// Builds the generation seed for `query`.
    ///
    /// Returns the whitespace-separated tokens of the retrieved context
    /// messages (`config.retrieval.context_k` hits, see
    /// [`Self::context_messages`]) followed by the tokens of `query`, together
    /// with [`Self::style_vector`].
    pub fn build_seed(&self, query: &str) -> (Vec<String>, Vector) {
        let ctx_msgs = self.context_messages(query, self.config.retrieval.context_k);

        // The style signal reaches the caller only through the returned style
        // vector; `config.retrieval.style_k` is not consulted here.
        // NOTE(review): an earlier version computed `style_messages(style_k)`
        // and discarded the result — confirm whether those messages were meant
        // to be part of the seed.
        let sv = self.style_vector();

        // Tokenizing each piece separately is equivalent to joining them with
        // spaces and splitting once, but skips the intermediate String.
        let tokens: Vec<String> = ctx_msgs
            .iter()
            .map(|m| m.content.as_str())
            .chain(std::iter::once(query))
            .flat_map(str::split_whitespace)
            .map(str::to_owned)
            .collect();

        (tokens, sv)
    }

    /// Cosine similarity between `query` and the embedding of `msg`, using the
    /// precomputed corpus vector when `msg.id` is in the corpus and embedding
    /// `msg.content` otherwise.
    fn similarity_to(&self, query: &Vector, msg: &Message) -> f32 {
        match self.messages.iter().position(|known| known.id == msg.id) {
            Some(idx) => cosine(query, &self.msg_vectors[idx]),
            None => cosine(query, &self.embeddings.embed_text(&msg.content)),
        }
    }

    /// Pairs each corpus index with the cosine similarity between `v` and that
    /// message's precomputed vector.
    fn scores_against(&self, v: &Vector) -> Vec<(usize, f32)> {
        self.msg_vectors
            .iter()
            .enumerate()
            .map(|(i, mv)| (i, cosine(v, mv)))
            .collect()
    }

    /// Sorts `(index, score)` pairs best-first and keeps at most `k` of them.
    ///
    /// The sort is stable, so equal scores keep their index order. NaN scores
    /// are ranked last instead of panicking inside the comparator.
    /// NOTE(review): whether `cosine` can return NaN (e.g. for a zero query or
    /// style vector) depends on its implementation — confirm.
    fn top_k(mut scored: Vec<(usize, f32)>, k: usize) -> Vec<(usize, f32)> {
        fn rank(score: f32) -> f32 {
            if score.is_nan() {
                f32::NEG_INFINITY
            } else {
                score
            }
        }

        scored.sort_by(|a, b| {
            rank(b.1)
                .partial_cmp(&rank(a.1))
                .expect("NaN scores are mapped to -inf before comparison")
        });
        scored.truncate(k);
        scored
    }
}