/// Minimum cosine similarity at which a new message is treated as a
/// continuation of the current topic. `detect_topic_shift` compares the best
/// similarity it finds against this value with `>=`.
pub const TOPIC_CONTINUITY_THRESHOLD: f64 = 0.65;

/// Upper bound on how many entries of `recent_embeddings` are compared against
/// a new message. It is applied with `.take(..)`, so the leading entries of the
/// slice are the ones inspected.
const LOOKBACK_TURNS: usize = 3;
34
/// Outcome of a topic-shift decision for a single message.
///
/// Derives `PartialEq` so that callers and tests can compare whole segments
/// directly (`Eq` is not derivable because of the `f64` field).
#[derive(Debug, Clone, PartialEq)]
pub struct TopicSegment {
    /// Tag assigned to the message: either the tag that was already active or
    /// a freshly formatted `topic-{n}` tag.
    pub tag: String,
    /// `true` when the message was judged to start a new topic.
    pub is_new_topic: bool,
    /// Best similarity found against the recent history; `0.0` when there was
    /// nothing to compare against.
    pub continuity_score: f64,
}
45
46pub fn detect_topic_shift(
59 current_tag: Option<&str>,
60 next_tag_number: u32,
61 new_message_embedding: &[f32],
62 recent_embeddings: &[Vec<f32>],
63) -> TopicSegment {
64 if recent_embeddings.is_empty() || current_tag.is_none() {
65 return TopicSegment {
67 tag: format!("topic-{next_tag_number}"),
68 is_new_topic: true,
69 continuity_score: 0.0,
70 };
71 }
72
73 let max_similarity = recent_embeddings
77 .iter()
78 .take(LOOKBACK_TURNS)
79 .map(|emb| cosine_similarity(new_message_embedding, emb))
80 .fold(0.0_f64, f64::max);
81
82 if max_similarity >= TOPIC_CONTINUITY_THRESHOLD {
83 TopicSegment {
85 tag: current_tag.unwrap_or("topic-1").to_string(),
86 is_new_topic: false,
87 continuity_score: max_similarity,
88 }
89 } else {
90 TopicSegment {
92 tag: format!("topic-{next_tag_number}"),
93 is_new_topic: true,
94 continuity_score: max_similarity,
95 }
96 }
97}
98
/// Cosine similarity of two embedding vectors, accumulated in `f64`.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or
/// when the product of their norms is below `1e-10` (e.g. a zero vector).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // One pass gathers the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&lhs, &rhs)| {
            let (lhs, rhs) = (f64::from(lhs), f64::from(rhs));
            (dot + lhs * rhs, sq_a + lhs * lhs, sq_b + rhs * rhs)
        },
    );

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    // Guard against dividing by a (near-)zero magnitude.
    if magnitude < 1e-10 {
        0.0
    } else {
        dot / magnitude
    }
}
114
/// Splits `messages` into those tagged with `current_tag` and everything else.
///
/// `get_tag` extracts the optional tag of a message. Messages without a tag
/// land in the second vector. Relative order is preserved in both halves.
///
/// Returns `(matching_current_tag, all_others)` as vectors of references into
/// `messages`.
pub fn partition_by_topic<'a, T>(
    messages: &'a [T],
    current_tag: &str,
    get_tag: impl Fn(&T) -> Option<&str>,
) -> (Vec<&'a T>, Vec<&'a T>) {
    messages
        .iter()
        .partition(|&msg| get_tag(msg) == Some(current_tag))
}
136
/// Renders a one-line placeholder describing an earlier topic block.
///
/// # Arguments
/// * `topic_tag` - tag of the summarized topic.
/// * `message_count` - number of messages in the block.
/// * `first_user_msg` - text quoted as a snippet; at most 80 characters
///   (Unicode scalar values, not bytes) are kept.
///
/// # Returns
/// A string of the form `[Earlier topic (<tag>, <n> messages): "<snippet>"]`.
/// A trailing `...` is added to the snippet only when the message was
/// actually truncated.
pub fn summarize_topic_block(
    topic_tag: &str,
    message_count: usize,
    first_user_msg: &str,
) -> String {
    const SNIPPET_MAX_CHARS: usize = 80;

    let mut chars = first_user_msg.chars();
    let snippet: String = chars.by_ref().take(SNIPPET_MAX_CHARS).collect();
    // Anything left in the iterator means the message was cut short.
    let ellipsis = if chars.next().is_some() { "..." } else { "" };

    format!("[Earlier topic ({topic_tag}, {message_count} messages): \"{snippet}{ellipsis}\"]")
}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic 8-dimensional embedding derived from `seed`.
    fn make_embedding(seed: f32) -> Vec<f32> {
        let mut values = Vec::with_capacity(8);
        for i in 0..8 {
            values.push((seed + i as f32 * 0.1).sin());
        }
        values
    }

    #[test]
    fn first_message_starts_new_topic() {
        let embedding = make_embedding(1.0);
        let segment = detect_topic_shift(None, 1, &embedding, &[]);

        assert!(segment.is_new_topic);
        assert_eq!(segment.tag, "topic-1");
    }

    #[test]
    fn similar_message_continues_topic() {
        let previous = make_embedding(1.0);
        // A tiny seed change yields a nearly identical embedding.
        let incoming = make_embedding(1.01);
        let segment = detect_topic_shift(Some("topic-1"), 2, &incoming, &[previous]);

        assert!(!segment.is_new_topic);
        assert_eq!(segment.tag, "topic-1");
        assert!(segment.continuity_score > TOPIC_CONTINUITY_THRESHOLD);
    }

    #[test]
    fn dissimilar_message_starts_new_topic() {
        let previous = make_embedding(1.0);
        // A distant seed yields an embedding that falls below the threshold.
        let incoming = make_embedding(100.0);
        let segment = detect_topic_shift(Some("topic-1"), 2, &incoming, &[previous]);

        assert!(segment.is_new_topic);
        assert_eq!(segment.tag, "topic-2");
        assert!(segment.continuity_score < TOPIC_CONTINUITY_THRESHOLD);
    }

    #[test]
    fn partition_separates_topics() {
        let tagged = vec![
            ("hello", Some("topic-1")),
            ("world", Some("topic-1")),
            ("new thing", Some("topic-2")),
            ("back to hello", Some("topic-1")),
        ];
        let (matching, rest) = partition_by_topic(&tagged, "topic-2", |entry| entry.1);

        assert_eq!(matching.len(), 1);
        assert_eq!(rest.len(), 3);
    }

    #[test]
    fn summarize_topic_block_formats_correctly() {
        let first_message = "cleanup Duncan's workspace and remove redundant items";
        let rendered = summarize_topic_block("topic-1", 5, first_message);

        for expected in ["topic-1", "5 messages", "cleanup"] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let vector = vec![1.0, 2.0, 3.0];
        let score = cosine_similarity(&vector, &vector);

        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];
        let score = cosine_similarity(&x_axis, &y_axis);

        assert!(score.abs() < 1e-6);
    }
}