use crate::config::Config;
use crate::embed::{Embeddings, Vector, cosine};
use crate::storage::{Message, thread};
/// Embedding-backed retrieval over a borrowed slice of messages.
///
/// All message embeddings and per-channel centroids are computed once in
/// `new` and cached here; the borrowed inputs must outlive the interpolator.
pub struct Interpolator<'a> {
/// Source of the channel list, the normalized channel weights and the
/// retrieval settings read by the methods of this type.
config: &'a Config,
/// Embedding backend used to embed text, compute centroids, and supply the
/// embedding dimension (`dim`).
embeddings: &'a Embeddings,
/// The message corpus that is searched.
messages: &'a [Message],
/// Embedding of each message's `content`; `msg_vectors[i]` belongs to
/// `messages[i]`.
msg_vectors: Vec<Vector>,
/// One centroid per entry of `config.channels`, in the same order, built
/// from the contents of the messages whose `channel_id` matches the channel.
channel_centroids: Vec<Vector>,
}
impl<'a> Interpolator<'a> {
    /// Builds an interpolator over `messages`.
    ///
    /// Embeds every message's content once and computes one centroid per
    /// configured channel from the messages whose `channel_id` matches that
    /// channel's `id`. Both caches are index-aligned with their source
    /// (`messages` and `config.channels` respectively).
    pub fn new(config: &'a Config, embeddings: &'a Embeddings, messages: &'a [Message]) -> Self {
        let msg_vectors: Vec<Vector> = messages
            .iter()
            .map(|m| embeddings.embed_text(&m.content))
            .collect();
        let channel_centroids: Vec<Vector> = config
            .channels
            .iter()
            .map(|ch| {
                let texts: Vec<&str> = messages
                    .iter()
                    .filter(|m| m.channel_id == ch.id)
                    .map(|m| m.content.as_str())
                    .collect();
                embeddings.centroid(&texts)
            })
            .collect();
        Self { config, embeddings, messages, msg_vectors, channel_centroids }
    }

    /// Returns the weighted sum of the channel centroids, scaled to unit length.
    ///
    /// Weights come from `config.normalized_weights()` and are paired with the
    /// centroids by position. The result has `embeddings.dim` components;
    /// centroid components beyond that length are ignored rather than causing
    /// an out-of-bounds panic. When the summed vector's norm is not above
    /// `1e-8` it is returned without scaling.
    pub fn style_vector(&self) -> Vector {
        let weights = self.config.normalized_weights();
        let mut result = vec![0f32; self.embeddings.dim];
        for (centroid, &w) in self.channel_centroids.iter().zip(weights.iter()) {
            for (acc, x) in result.iter_mut().zip(centroid.iter()) {
                *acc += w * x;
            }
        }
        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-8 {
            for x in result.iter_mut() {
                *x /= norm;
            }
        }
        result
    }

    /// Returns up to `k` messages from `pool` most similar to `query`, best first.
    ///
    /// A pool message that is also part of this interpolator's corpus (matched
    /// by `id`) reuses its cached embedding. Any other pool message is embedded
    /// on demand, so it is scored against its own content instead of borrowing
    /// the vector of an unrelated message.
    pub fn knn<'b>(&self, query: &Vector, pool: &'b [Message], k: usize) -> Vec<&'b Message> {
        let scored: Vec<(usize, f32)> = pool
            .iter()
            .enumerate()
            .map(|(i, candidate)| {
                let cached = self.messages.iter().position(|m| m.id == candidate.id);
                let score = match cached {
                    Some(idx) => cosine(query, &self.msg_vectors[idx]),
                    None => cosine(query, &self.embeddings.embed_text(&candidate.content)),
                };
                (i, score)
            })
            .collect();
        Self::top_k(scored, k)
            .into_iter()
            .map(|i| &pool[i])
            .collect()
    }

    /// Returns the thread context around the `k` messages most similar to `query`.
    ///
    /// For each hit, best first, `thread` is called with the hit's id, the full
    /// corpus and `config.retrieval.thread_depth`; the messages it yields are
    /// appended in order, skipping any whose `id` is already in the result.
    pub fn context_messages(&self, query: &str, k: usize) -> Vec<Message> {
        let qv = self.embeddings.embed_text(query);
        let depth = self.config.retrieval.thread_depth;
        let scored: Vec<(usize, f32)> = self
            .msg_vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i, cosine(&qv, v)))
            .collect();
        let mut result: Vec<Message> = Vec::new();
        for i in Self::top_k(scored, k) {
            let msg = &self.messages[i];
            for m in thread(&msg.id, self.messages, depth) {
                if !result.iter().any(|r: &Message| r.id == m.id) {
                    result.push(m);
                }
            }
        }
        result
    }

    /// Returns up to `k` corpus messages most similar to the style vector, best first.
    pub fn style_messages(&self, k: usize) -> Vec<&Message> {
        let sv = self.style_vector();
        let scored: Vec<(usize, f32)> = self
            .msg_vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i, cosine(&sv, v)))
            .collect();
        Self::top_k(scored, k)
            .into_iter()
            .map(|i| &self.messages[i])
            .collect()
    }

    /// Builds the seed for `query`: whitespace-separated tokens plus the style vector.
    ///
    /// The tokens are the words of the `config.retrieval.context_k` context
    /// messages (in retrieval order) followed by the words of `query` itself.
    pub fn build_seed(&self, query: &str) -> (Vec<String>, Vector) {
        let ctx_k = self.config.retrieval.context_k;
        let ctx_msgs = self.context_messages(query, ctx_k);
        let sv = self.style_vector();
        // Splitting each piece separately yields the same tokens as joining
        // the pieces with spaces first, without building the joined string.
        let tokens: Vec<String> = ctx_msgs
            .iter()
            .map(|m| m.content.as_str())
            .chain(std::iter::once(query))
            .flat_map(|text| text.split_whitespace())
            .map(|word| word.to_string())
            .collect();
        (tokens, sv)
    }

    /// Returns the indices of the (at most) `k` highest scores, best first.
    ///
    /// NaN scores are ranked last instead of aborting the sort. The sort is
    /// stable, so equal scores keep their input order.
    fn top_k(mut scored: Vec<(usize, f32)>, k: usize) -> Vec<usize> {
        for entry in scored.iter_mut() {
            if entry.1.is_nan() {
                entry.1 = f32::NEG_INFINITY;
            }
        }
        scored.sort_by(|a, b| b.1.total_cmp(&a.1));
        scored.truncate(k);
        scored.into_iter().map(|(i, _)| i).collect()
    }
}