// roboticus_agent/topic.rs

1//! Topic segmentation for session conversation threads.
2//!
3//! Detects topic changes within a session so the context builder can
4//! prioritize messages from the current topic over stale threads. This
5//! prevents the model from confusing concurrent conversation topics
6//! (e.g. workspace cleanup vs. security monitoring vs. Go migration)
7//! in long-running sessions.
8//!
9//! ## Architecture
10//!
//! - **Detection**: Compares the user's new message against the embeddings
//!   of the last few messages (up to `LOOKBACK_TURNS`) via cosine similarity
//!   and keeps the best score. A best score below
//!   `TOPIC_CONTINUITY_THRESHOLD` signals a topic shift.
14//! - **Tagging**: Each message gets a `topic_tag` (e.g. `"topic-3"`) stored
15//!   in the database. Sequential messages on the same topic share a tag.
16//! - **Context assembly**: `build_context_with_budget` uses topic tags to
17//!   ensure the current topic's messages are prioritized within the token
18//!   budget, with older topics getting a compressed summary instead of
19//!   full message history.
20
/// Minimum cosine similarity at which a new message is still treated as
/// part of the current conversation thread.
///
/// The comparison is inclusive: a score equal to this value continues the
/// topic, anything lower starts a new one.
///
/// Rough guide to similarity scores:
/// - 0.85 and above: very similar (same topic, follow-up question)
/// - 0.65 to 0.85: related but drifting (adjacent topic)
/// - below 0.65: a different topic entirely
pub const TOPIC_CONTINUITY_THRESHOLD: f64 = 0.65;

/// Upper bound on how many recent embeddings are compared against a new
/// message when checking topic continuity.
///
/// Looking back further than the single latest turn tolerates users who
/// interleave topics quickly.
const LOOKBACK_TURNS: usize = 3;
34
/// Outcome of classifying one message against the current topic thread.
///
/// Derives `PartialEq` so whole segments can be compared directly (for
/// example with `assert_eq!`). `Eq` cannot be derived because
/// `continuity_score` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct TopicSegment {
    /// Opaque tag identifying this topic thread (e.g. "topic-1", "topic-2").
    pub tag: String,
    /// `true` when this message opens a new topic, `false` when it continues
    /// the previous one.
    pub is_new_topic: bool,
    /// Similarity score against the previous topic (0.0 if first message).
    pub continuity_score: f64,
}
45
46/// Detect whether a new user message is a continuation of the current topic
47/// or a shift to a new one.
48///
49/// Uses embedding cosine similarity against recent assistant responses.
50/// Returns the topic tag to assign to this message.
51///
52/// `current_tag`: the topic tag of the most recent message (e.g. "topic-2").
53///   Pass `None` for the first message in a session.
54/// `next_tag_number`: the next available topic number (e.g. 3 if "topic-2" is current).
55/// `new_message_embedding`: embedding of the user's new message.
56/// `recent_embeddings`: embeddings of the most recent `LOOKBACK_TURNS` messages
57///   (from newest to oldest).
58pub fn detect_topic_shift(
59    current_tag: Option<&str>,
60    next_tag_number: u32,
61    new_message_embedding: &[f32],
62    recent_embeddings: &[Vec<f32>],
63) -> TopicSegment {
64    if recent_embeddings.is_empty() || current_tag.is_none() {
65        // First message or no history — start topic-1.
66        return TopicSegment {
67            tag: format!("topic-{next_tag_number}"),
68            is_new_topic: true,
69            continuity_score: 0.0,
70        };
71    }
72
73    // Compare against the most recent messages and take the MAX similarity.
74    // This handles interleaved topics where the user might briefly address
75    // one topic before returning to another.
76    let max_similarity = recent_embeddings
77        .iter()
78        .take(LOOKBACK_TURNS)
79        .map(|emb| cosine_similarity(new_message_embedding, emb))
80        .fold(0.0_f64, f64::max);
81
82    if max_similarity >= TOPIC_CONTINUITY_THRESHOLD {
83        // Continuing the current topic.
84        TopicSegment {
85            tag: current_tag.unwrap_or("topic-1").to_string(),
86            is_new_topic: false,
87            continuity_score: max_similarity,
88        }
89    } else {
90        // Topic shift detected.
91        TopicSegment {
92            tag: format!("topic-{next_tag_number}"),
93            is_new_topic: true,
94            continuity_score: max_similarity,
95        }
96    }
97}
98
/// Computes the cosine of the angle between two embedding vectors.
///
/// All arithmetic is carried out in `f64`. Returns `0.0` when the slices
/// are empty, differ in length, or when the product of their norms is
/// below `1e-10` (which guards the division for near-zero vectors).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // One pass accumulating the dot product and both squared norms.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let denom = sq_norm_a.sqrt() * sq_norm_b.sqrt();
    if denom < 1e-10 {
        return 0.0;
    }
    dot / denom
}
114
/// Splits `messages` into those tagged with `current_tag` and everything
/// else.
///
/// `get_tag` extracts the optional topic tag from a message. Messages with
/// no tag (`None`) are placed in the off-topic group.
///
/// Returns `(current_topic_messages, off_topic_messages)`. Both vectors
/// hold references into `messages` and keep the input (chronological) order.
pub fn partition_by_topic<'a, T>(
    messages: &'a [T],
    current_tag: &str,
    get_tag: impl Fn(&T) -> Option<&str>,
) -> (Vec<&'a T>, Vec<&'a T>) {
    messages
        .iter()
        .partition(|&msg| get_tag(msg) == Some(current_tag))
}
136
/// Generate a compressed summary line for an off-topic message block.
///
/// Used in context assembly to represent trimmed topics as a single system
/// note rather than including all their messages verbatim.
///
/// The snippet is at most the first 80 characters of `first_user_msg`,
/// counted as `char`s so multi-byte text is never split mid-character.
/// A trailing `...` is appended only when the message was actually
/// truncated, so a short message is not presented as if it had been cut.
///
/// Returns a string of the form
/// `[Earlier topic (<tag>, <n> messages): "<snippet>"]`.
pub fn summarize_topic_block(
    topic_tag: &str,
    message_count: usize,
    first_user_msg: &str,
) -> String {
    const SNIPPET_MAX_CHARS: usize = 80;

    let mut chars = first_user_msg.chars();
    let snippet: String = chars.by_ref().take(SNIPPET_MAX_CHARS).collect();
    // Anything still left in the iterator means the message was cut short.
    let ellipsis = if chars.next().is_some() { "..." } else { "" };

    format!("[Earlier topic ({topic_tag}, {message_count} messages): \"{snippet}{ellipsis}\"]")
}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 8-dimensional pseudo-embedding derived from `seed`.
    fn make_embedding(seed: f32) -> Vec<f32> {
        let mut values = Vec::with_capacity(8);
        for i in 0..8 {
            values.push((seed + i as f32 * 0.1).sin());
        }
        values
    }

    #[test]
    fn first_message_starts_new_topic() {
        let segment = detect_topic_shift(None, 1, &make_embedding(1.0), &[]);
        assert!(segment.is_new_topic);
        assert_eq!(segment.tag, "topic-1");
    }

    #[test]
    fn similar_message_continues_topic() {
        // Seeds 1.0 and 1.01 produce nearly parallel vectors.
        let history = vec![make_embedding(1.0)];
        let incoming = make_embedding(1.01);
        let segment = detect_topic_shift(Some("topic-1"), 2, &incoming, &history);
        assert!(!segment.is_new_topic);
        assert_eq!(segment.tag, "topic-1");
        assert!(segment.continuity_score > TOPIC_CONTINUITY_THRESHOLD);
    }

    #[test]
    fn dissimilar_message_starts_new_topic() {
        // Seeds 1.0 and 100.0 produce very different vectors.
        let history = vec![make_embedding(1.0)];
        let incoming = make_embedding(100.0);
        let segment = detect_topic_shift(Some("topic-1"), 2, &incoming, &history);
        assert!(segment.is_new_topic);
        assert_eq!(segment.tag, "topic-2");
        assert!(segment.continuity_score < TOPIC_CONTINUITY_THRESHOLD);
    }

    #[test]
    fn partition_separates_topics() {
        let tagged = vec![
            ("hello", Some("topic-1")),
            ("world", Some("topic-1")),
            ("new thing", Some("topic-2")),
            ("back to hello", Some("topic-1")),
        ];
        let (on_topic, off_topic) = partition_by_topic(&tagged, "topic-2", |entry| entry.1);
        assert_eq!(on_topic.len(), 1);
        assert_eq!(off_topic.len(), 3);
    }

    #[test]
    fn summarize_topic_block_formats_correctly() {
        let line = summarize_topic_block(
            "topic-1",
            5,
            "cleanup Duncan's workspace and remove redundant items",
        );
        for expected in ["topic-1", "5 messages", "cleanup"] {
            assert!(line.contains(expected));
        }
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = [1.0_f32, 2.0, 3.0];
        let score = cosine_similarity(&v, &v);
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let x_axis = [1.0_f32, 0.0, 0.0];
        let y_axis = [0.0_f32, 1.0, 0.0];
        let score = cosine_similarity(&x_axis, &y_axis);
        assert!(score.abs() < 1e-6);
    }
}