hirn_engine/consolidation/
narrative.rs1use super::*;
2
/// A group of related episode segments treated as one storyline.
///
/// Instances are produced by `form_narrative_threads` (via
/// `thread_from_segment_group`) in this module.
#[derive(Debug, Clone)]
pub struct NarrativeThread {
    /// Up to three of the most frequently mentioned entity names joined with
    /// ", ", or "Unnamed Thread" when the records mention no entities.
    pub title: String,
    /// Indices into the input segment slice that make up this thread.
    pub segment_indices: Vec<usize>,
    /// IDs of every record in the thread, in timestamp order.
    pub record_ids: Vec<MemoryId>,
    /// Record contents, parallel to `record_ids`.
    pub contents: Vec<String>,
    /// Record summaries, parallel to `record_ids`.
    pub summaries: Vec<String>,
    /// Timestamp of the earliest record; `Timestamp::default()` when there are no records.
    pub start_time: Timestamp,
    /// Timestamp of the latest record; `Timestamp::default()` when there are no records.
    pub end_time: Timestamp,
    /// Every entity name mentioned by the thread's records, most frequently mentioned first.
    pub entities: Vec<String>,
    /// Finer-grained threads found inside this one. Left empty unless the thread
    /// spans at least four segments and splits into more than one sub-thread.
    pub sub_threads: Vec<Self>,
    /// L2-normalised mean of the segments' topic embeddings; `None` when no
    /// segment in the thread has a topic embedding.
    pub embedding: Option<Vec<f32>>,
}
30
31struct CondensedMatrix {
35 data: Vec<f32>,
36 n: usize,
37}
38
39impl CondensedMatrix {
40 fn new(segments: &[EpisodeSegment]) -> Self {
41 let n = segments.len();
42 let size = n * (n - 1) / 2;
43 let mut data = Vec::with_capacity(size);
44
45 let entity_sets: Vec<HashSet<&str>> = segments
47 .iter()
48 .map(|s| s.dominant_entities.iter().map(String::as_str).collect())
49 .collect();
50
51 for i in 0..n {
52 for j in (i + 1)..n {
53 let embedding_sim =
54 match (&segments[i].topic_embedding, &segments[j].topic_embedding) {
55 (Some(ea), Some(eb)) => {
56 1.0 - lance_linalg::distance::cosine_distance(ea, eb)
57 }
58 _ => 0.0,
59 };
60 let intersection = entity_sets[i].intersection(&entity_sets[j]).count();
61 let union = entity_sets[i].union(&entity_sets[j]).count();
62 let entity_sim = if union > 0 {
63 intersection as f32 / union as f32
64 } else {
65 0.0
66 };
67 data.push(embedding_sim * 0.6 + entity_sim * 0.4);
68 }
69 }
70
71 Self { data, n }
72 }
73
74 #[inline]
75 fn get(&self, i: usize, j: usize) -> f32 {
76 let (a, b) = if i < j { (i, j) } else { (j, i) };
77 self.data[a * self.n - a * (a + 1) / 2 + b - a - 1]
78 }
79}
80
/// Disjoint-set forest with union by rank and full path compression.
struct UnionFind {
    /// `parent[x]` is the parent of `x`; roots point at themselves.
    parent: Vec<usize>,
    /// Upper bound on the height of the tree rooted at each index.
    rank: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets `{0}, {1}, …, {n - 1}`.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0usize; n];
        UnionFind { parent, rank }
    }

    /// Returns the root of the set containing `x`, re-pointing every node on
    /// the walked path directly at that root.
    fn find(&mut self, x: usize) -> usize {
        // Pass 1: climb to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Pass 2: compress the path we just walked.
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }

        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if the sets were distinct and have been merged, and
    /// `false` if `x` and `y` were already in the same set.
    fn union(&mut self, x: usize, y: usize) -> bool {
        let (root_x, root_y) = (self.find(x), self.find(y));
        if root_x == root_y {
            return false;
        }

        if self.rank[root_x] < self.rank[root_y] {
            self.parent[root_x] = root_y;
        } else {
            // Ties (and higher rank on the `x` side) attach `y`'s tree under `x`'s.
            self.parent[root_y] = root_x;
            if self.rank[root_x] == self.rank[root_y] {
                self.rank[root_x] += 1;
            }
        }
        true
    }
}
119
120pub fn form_narrative_threads(
128 segments: &[EpisodeSegment],
129 _patterns: &DetectedPatterns,
130 config: &ConsolidationConfig,
131) -> Vec<NarrativeThread> {
132 if segments.is_empty() {
133 return Vec::new();
134 }
135 if segments.len() == 1 {
136 return vec![thread_from_segment_group(segments, &[0])];
137 }
138
139 let n = segments.len();
140 let matrix = CondensedMatrix::new(segments);
141
142 let mut edges: Vec<(f32, usize, usize)> = Vec::with_capacity(n * (n - 1) / 2);
144 for i in 0..n {
145 for j in (i + 1)..n {
146 let sim = matrix.get(i, j);
147 if sim >= config.thread_similarity_threshold {
148 edges.push((sim, i, j));
149 }
150 }
151 }
152 edges.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
153
154 let mut uf = UnionFind::new(n);
156 for &(_sim, i, j) in &edges {
157 uf.union(i, j);
158 }
159
160 let mut cluster_map: HashMap<usize, Vec<usize>> = HashMap::new();
162 for i in 0..n {
163 cluster_map.entry(uf.find(i)).or_default().push(i);
164 }
165
166 let mut threads: Vec<NarrativeThread> = Vec::new();
168 for (_root, cluster) in &cluster_map {
169 let mut thread = thread_from_segment_group(segments, cluster);
170
171 if cluster.len() >= 4 {
173 let sub_threads = detect_sub_threads(segments, cluster, config, &matrix);
174 if sub_threads.len() > 1 {
175 thread.sub_threads = sub_threads;
176 }
177 }
178
179 threads.push(thread);
180 }
181
182 threads.sort_by_key(|thread| thread.start_time);
184
185 threads
186}
187
188fn thread_from_segment_group(segments: &[EpisodeSegment], indices: &[usize]) -> NarrativeThread {
190 let mut all_records: Vec<&EpisodicRecord> = Vec::new();
191 let mut all_entities: HashMap<String, usize> = HashMap::new();
192 let mut embeddings: Vec<&Vec<f32>> = Vec::new();
193
194 for &idx in indices {
195 let seg = &segments[idx];
196 for rec in &seg.records {
197 all_records.push(rec);
198 for ent in &rec.entities {
199 *all_entities.entry(ent.name.clone()).or_default() += 1;
200 }
201 }
202 if let Some(ref emb) = seg.topic_embedding {
203 embeddings.push(emb);
204 }
205 }
206
207 all_records.sort_by_key(|r| r.timestamp);
209
210 let mut entity_list: Vec<(String, usize)> = all_entities.into_iter().collect();
212 entity_list.sort_by_key(|item| std::cmp::Reverse(item.1));
213 let top_entities: Vec<String> = entity_list
214 .iter()
215 .take(3)
216 .map(|(name, _)| name.clone())
217 .collect();
218
219 let title = if top_entities.is_empty() {
220 "Unnamed Thread".to_string()
221 } else {
222 top_entities.join(", ")
223 };
224
225 let entities: Vec<String> = entity_list.into_iter().map(|(name, _)| name).collect();
226
227 let embedding = if embeddings.is_empty() {
229 None
230 } else {
231 let dims = embeddings[0].len();
232 let mut mean = vec![0.0f32; dims];
233 for emb in &embeddings {
234 for (i, v) in emb.iter().enumerate() {
235 mean[i] += v;
236 }
237 }
238 let n = embeddings.len() as f32;
239 for v in &mut mean {
240 *v /= n;
241 }
242 let norm: f32 = mean.iter().map(|x| x * x).sum::<f32>().sqrt();
243 if norm > 0.0 {
244 for v in &mut mean {
245 *v /= norm;
246 }
247 }
248 Some(mean)
249 };
250
251 let (start_time, end_time) = match (all_records.first(), all_records.last()) {
252 (Some(first), Some(last)) => (first.timestamp, last.timestamp),
253 _ => (Timestamp::default(), Timestamp::default()),
254 };
255
256 NarrativeThread {
257 title,
258 segment_indices: indices.to_vec(),
259 record_ids: all_records.iter().map(|r| r.id).collect(),
260 contents: all_records.iter().map(|r| r.content.clone()).collect(),
261 summaries: all_records.iter().map(|r| r.summary.clone()).collect(),
262 start_time,
263 end_time,
264 entities,
265 sub_threads: Vec::new(),
266 embedding,
267 }
268}
269
270fn detect_sub_threads(
276 segments: &[EpisodeSegment],
277 cluster: &[usize],
278 config: &ConsolidationConfig,
279 matrix: &CondensedMatrix,
280) -> Vec<NarrativeThread> {
281 let tighter_threshold = config.thread_similarity_threshold + 0.15;
282 let m = cluster.len();
283
284 let mut edges: Vec<(f32, usize, usize)> = Vec::new();
286 for ci in 0..m {
287 for cj in (ci + 1)..m {
288 let sim = matrix.get(cluster[ci], cluster[cj]);
289 if sim >= tighter_threshold {
290 edges.push((sim, ci, cj));
291 }
292 }
293 }
294 edges.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
295
296 let mut uf = UnionFind::new(m);
297 for &(_sim, i, j) in &edges {
298 uf.union(i, j);
299 }
300
301 let mut sub_map: HashMap<usize, Vec<usize>> = HashMap::new();
302 for ci in 0..m {
303 sub_map.entry(uf.find(ci)).or_default().push(cluster[ci]);
304 }
305
306 sub_map
307 .into_values()
308 .map(|c| thread_from_segment_group(segments, &c))
309 .collect()
310}