/// Minimum similarity score required for a new message to stay on the current
/// topic: `detect_topic_shift` keeps the current tag when the best cosine
/// similarity it finds is greater than or equal to this value.
pub const TOPIC_CONTINUITY_THRESHOLD: f64 = 0.65;
/// Upper bound on how many entries of `recent_embeddings` `detect_topic_shift`
/// compares the new message against (entries are taken from the front of the
/// slice).
const LOOKBACK_TURNS: usize = 3;
/// Result of a topic-shift check, as returned by `detect_topic_shift`.
#[derive(Debug, Clone)]
pub struct TopicSegment {
/// Tag assigned to the message: either the caller's current tag (topic
/// continues) or a freshly formatted `topic-<n>` tag (new topic).
pub tag: String,
/// `true` when `tag` is a newly minted tag rather than the current one.
pub is_new_topic: bool,
/// Best cosine similarity found against the compared embeddings, never
/// below `0.0`; `0.0` when there was nothing to compare against.
pub continuity_score: f64,
}
/// Decides whether a new message continues the current topic or opens a new one.
///
/// The new message's embedding is compared (cosine similarity) against at most
/// `LOOKBACK_TURNS` entries taken from the front of `recent_embeddings`, and
/// the best score is kept. Scores below `0.0` are floored to `0.0` because the
/// maximum is folded starting from `0.0`.
///
/// NOTE(review): because entries are taken from the *front* of the slice, this
/// presumably expects callers to pass the most recent turns first — confirm
/// against the call sites.
///
/// # Parameters
/// * `current_tag` – tag of the topic in progress, or `None` if there is none.
/// * `next_tag_number` – number used to build the `topic-<n>` tag when a new
///   topic starts. It is only read here; advancing it is left to the caller.
/// * `new_message_embedding` – embedding of the message being classified.
/// * `recent_embeddings` – embeddings of earlier messages to compare against.
///
/// # Returns
/// * A new-topic segment (`topic-<next_tag_number>`, score `0.0`) when there is
///   no current tag or no embeddings to compare against.
/// * A continuing segment carrying `current_tag` when the best similarity is at
///   least [`TOPIC_CONTINUITY_THRESHOLD`].
/// * Otherwise a new-topic segment carrying the best similarity found.
pub fn detect_topic_shift(
    current_tag: Option<&str>,
    next_tag_number: u32,
    new_message_embedding: &[f32],
    recent_embeddings: &[Vec<f32>],
) -> TopicSegment {
    // Single place where a "new topic" result is built, so both exit paths
    // format the tag identically.
    let new_topic = |continuity_score: f64| TopicSegment {
        tag: format!("topic-{next_tag_number}"),
        is_new_topic: true,
        continuity_score,
    };

    // Without a topic in progress or any history there is nothing to continue.
    let tag = match current_tag {
        Some(tag) if !recent_embeddings.is_empty() => tag,
        _ => return new_topic(0.0),
    };

    // Folding from 0.0 floors negative similarities at zero; `f64::max` also
    // skips NaN operands, so a NaN score can never win.
    let best_similarity = recent_embeddings
        .iter()
        .take(LOOKBACK_TURNS)
        .map(|emb| cosine_similarity(new_message_embedding, emb))
        .fold(0.0_f64, f64::max);

    if best_similarity >= TOPIC_CONTINUITY_THRESHOLD {
        TopicSegment {
            tag: tag.to_string(),
            is_new_topic: false,
            continuity_score: best_similarity,
        }
    } else {
        new_topic(best_similarity)
    }
}
/// Computes the cosine similarity of two embeddings, accumulating in `f64`.
///
/// Returns a value in `[-1.0, 1.0]`. Degenerate inputs yield `0.0` instead of
/// an error or NaN:
/// * slices of different lengths, or empty slices;
/// * a (near-)zero-magnitude vector (norm product below `1e-10`);
/// * any NaN or infinite component (the norm product is then non-finite).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // One in-order pass accumulating the dot product and both squared norms.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    // A NaN denominator fails every `<` comparison, so it must be rejected
    // explicitly or the division below would return NaN.
    if !denom.is_finite() || denom < 1e-10 {
        return 0.0;
    }

    // Rounding can push the ratio a hair outside the mathematical range.
    (dot / denom).clamp(-1.0, 1.0)
}
/// Splits `messages` into those belonging to `current_tag` and all the rest.
///
/// `get_tag` extracts a message's tag, if it has one. A message lands in the
/// first vector only when its tag equals `current_tag`; untagged messages and
/// messages with any other tag land in the second. Both vectors hold
/// references into `messages` and keep the original relative order.
pub fn partition_by_topic<'a, T>(
    messages: &'a [T],
    current_tag: &str,
    get_tag: impl Fn(&T) -> Option<&str>,
) -> (Vec<&'a T>, Vec<&'a T>) {
    messages
        .iter()
        .partition(|&msg| get_tag(msg) == Some(current_tag))
}
/// Builds a one-line bracketed placeholder describing an earlier topic block.
///
/// The result has the form
/// `[Earlier topic (<topic_tag>, <message_count> messages): "<snippet>"]`,
/// where the snippet is at most the first 80 characters (Unicode scalar
/// values, not bytes) of `first_user_msg`. A trailing `...` is appended to the
/// snippet only when the message was actually truncated.
pub fn summarize_topic_block(
    topic_tag: &str,
    message_count: usize,
    first_user_msg: &str,
) -> String {
    const SNIPPET_MAX_CHARS: usize = 80;

    // The byte offset of the first character past the limit is always a valid
    // char boundary, so slicing there cannot split a multi-byte character.
    let (snippet, ellipsis) = match first_user_msg.char_indices().nth(SNIPPET_MAX_CHARS) {
        Some((cut, _)) => (&first_user_msg[..cut], "..."),
        None => (first_user_msg, ""),
    };

    format!("[Earlier topic ({topic_tag}, {message_count} messages): \"{snippet}{ellipsis}\"]")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 8-dimensional test embedding: component `i` is
    /// `sin(seed + 0.1 * i)`, so nearby seeds give nearly parallel vectors.
    fn make_embedding(seed: f32) -> Vec<f32> {
        let mut components = Vec::with_capacity(8);
        for i in 0..8 {
            components.push((seed + i as f32 * 0.1).sin());
        }
        components
    }

    /// With no current tag and no history, the first message opens `topic-1`.
    #[test]
    fn first_message_starts_new_topic() {
        let embedding = make_embedding(1.0);

        let segment = detect_topic_shift(None, 1, &embedding, &[]);

        assert!(segment.is_new_topic);
        assert_eq!(segment.tag, "topic-1");
    }

    /// A near-identical embedding keeps the current tag.
    #[test]
    fn similar_message_continues_topic() {
        let previous = make_embedding(1.0);
        let incoming = make_embedding(1.01);

        let segment = detect_topic_shift(Some("topic-1"), 2, &incoming, &[previous]);

        assert!(!segment.is_new_topic);
        assert_eq!(segment.tag, "topic-1");
        assert!(segment.continuity_score > TOPIC_CONTINUITY_THRESHOLD);
    }

    /// An unrelated embedding falls below the threshold and opens `topic-2`.
    #[test]
    fn dissimilar_message_starts_new_topic() {
        let previous = make_embedding(1.0);
        let incoming = make_embedding(100.0);

        let segment = detect_topic_shift(Some("topic-1"), 2, &incoming, &[previous]);

        assert!(segment.is_new_topic);
        assert_eq!(segment.tag, "topic-2");
        assert!(segment.continuity_score < TOPIC_CONTINUITY_THRESHOLD);
    }

    /// Only messages tagged with the requested topic end up in the first half.
    #[test]
    fn partition_separates_topics() {
        let tagged_messages = vec![
            ("hello", Some("topic-1")),
            ("world", Some("topic-1")),
            ("new thing", Some("topic-2")),
            ("back to hello", Some("topic-1")),
        ];

        let (matching, rest) = partition_by_topic(&tagged_messages, "topic-2", |entry| entry.1);

        assert_eq!(matching.len(), 1);
        assert_eq!(rest.len(), 3);
    }

    /// The summary mentions the tag, the message count and the message text.
    #[test]
    fn summarize_topic_block_formats_correctly() {
        let first_message = "cleanup Duncan's workspace and remove redundant items";

        let summary = summarize_topic_block("topic-1", 5, first_message);

        for expected in ["topic-1", "5 messages", "cleanup"] {
            assert!(summary.contains(expected));
        }
    }

    /// A vector compared with itself has similarity 1.
    #[test]
    fn cosine_similarity_identical_vectors() {
        let vector = [1.0_f32, 2.0, 3.0];

        let similarity = cosine_similarity(&vector, &vector);

        assert!((similarity - 1.0).abs() < 1e-6);
    }

    /// Perpendicular unit vectors have similarity 0.
    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let x_axis = [1.0_f32, 0.0, 0.0];
        let y_axis = [0.0_f32, 1.0, 0.0];

        let similarity = cosine_similarity(&x_axis, &y_axis);

        assert!(similarity.abs() < 1e-6);
    }
}