Skip to main content

hirn_engine/consolidation/
narrative.rs

1use super::*;
2
3// ═══════════════════════════════════════════════════════════════════════════
4// Narrative Thread Formation
5// ═══════════════════════════════════════════════════════════════════════════
6
/// A narrative thread — a coherent "story" spanning multiple segments.
///
/// Threads are produced by `form_narrative_threads`, which groups similar
/// segments; a thread aggregates the records of every segment it contains.
#[derive(Debug, Clone)]
pub struct NarrativeThread {
    /// Human-readable title built from the thread's most frequently
    /// mentioned entity names ("Unnamed Thread" when there are none).
    pub title: String,
    /// Indices of the segments composing this thread, into the segment
    /// slice the thread was built from.
    pub segment_indices: Vec<usize>,
    /// All record IDs in this thread (ordered by record timestamp).
    pub record_ids: Vec<MemoryId>,
    /// Record contents, in the same order as `record_ids`.
    pub contents: Vec<String>,
    /// Record summaries, in the same order as `record_ids`.
    pub summaries: Vec<String>,
    /// Timestamp of the earliest record (the default value if the thread
    /// has no records).
    pub start_time: Timestamp,
    /// Timestamp of the latest record (the default value if the thread
    /// has no records).
    pub end_time: Timestamp,
    /// Entity names mentioned by the thread's records, most frequent first.
    pub entities: Vec<String>,
    /// Tighter-threshold sub-clusters of this thread; empty unless the
    /// cluster was large enough to be checked and actually split.
    pub sub_threads: Vec<Self>,
    /// L2-normalised mean of the member segments' topic embeddings, or
    /// `None` if no member segment has one.
    pub embedding: Option<Vec<f32>>,
}
30
31/// F-021 FIX: Precomputed condensed distance matrix for O(N²) pairwise similarity.
32/// Stores upper-triangular entries in row-major order:
33///   index(i,j) = i * n - i*(i+1)/2 + j - i - 1   for i < j
34struct CondensedMatrix {
35    data: Vec<f32>,
36    n: usize,
37}
38
39impl CondensedMatrix {
40    fn new(segments: &[EpisodeSegment]) -> Self {
41        let n = segments.len();
42        let size = n * (n - 1) / 2;
43        let mut data = Vec::with_capacity(size);
44
45        // Pre-compute entity sets once to avoid re-allocating HashSets.
46        let entity_sets: Vec<HashSet<&str>> = segments
47            .iter()
48            .map(|s| s.dominant_entities.iter().map(String::as_str).collect())
49            .collect();
50
51        for i in 0..n {
52            for j in (i + 1)..n {
53                let embedding_sim =
54                    match (&segments[i].topic_embedding, &segments[j].topic_embedding) {
55                        (Some(ea), Some(eb)) => {
56                            1.0 - lance_linalg::distance::cosine_distance(ea, eb)
57                        }
58                        _ => 0.0,
59                    };
60                let intersection = entity_sets[i].intersection(&entity_sets[j]).count();
61                let union = entity_sets[i].union(&entity_sets[j]).count();
62                let entity_sim = if union > 0 {
63                    intersection as f32 / union as f32
64                } else {
65                    0.0
66                };
67                data.push(embedding_sim * 0.6 + entity_sim * 0.4);
68            }
69        }
70
71        Self { data, n }
72    }
73
74    #[inline]
75    fn get(&self, i: usize, j: usize) -> f32 {
76        let (a, b) = if i < j { (i, j) } else { (j, i) };
77        self.data[a * self.n - a * (a + 1) / 2 + b - a - 1]
78    }
79}
80
/// F-021 FIX: Disjoint-set forest (union-find) with path compression and
/// union by rank, used to group segment indices into connected components.
struct UnionFind {
    parent: Vec<usize>,
    rank: Vec<usize>,
}

impl UnionFind {
    /// Start with `n` singleton sets: `{0}`, `{1}`, …, `{n - 1}`.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0; n];
        Self { parent, rank }
    }

    /// Return the representative of the set containing `x`, re-pointing
    /// every node on the walked path straight at that representative.
    fn find(&mut self, x: usize) -> usize {
        // Pass 1: climb to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Pass 2: compress the path just walked.
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merge the sets containing `x` and `y`.
    ///
    /// The lower-rank root is attached under the higher-rank one. On a tie,
    /// `x`'s root becomes the parent and its rank is bumped. Returns `false`
    /// if `x` and `y` were already in the same set, `true` otherwise.
    fn union(&mut self, x: usize, y: usize) -> bool {
        let (root_x, root_y) = (self.find(x), self.find(y));
        if root_x == root_y {
            return false;
        }
        let (rank_x, rank_y) = (self.rank[root_x], self.rank[root_y]);
        if rank_x < rank_y {
            self.parent[root_x] = root_y;
        } else {
            self.parent[root_y] = root_x;
            if rank_x == rank_y {
                self.rank[root_x] += 1;
            }
        }
        true
    }
}
119
120/// Form narrative threads from segments using single-linkage clustering
121/// with a precomputed condensed distance matrix.
122///
123/// **F-021 FIX:** Replaced O(N⁴) hierarchical agglomerative clustering with
124/// O(N² log N) sorted-edge single-linkage via union-find. The pairwise
125/// similarity matrix is computed once (O(N²·D)), sorted once (O(N² log N)),
126/// then edges are greedily merged above the threshold.
127pub fn form_narrative_threads(
128    segments: &[EpisodeSegment],
129    _patterns: &DetectedPatterns,
130    config: &ConsolidationConfig,
131) -> Vec<NarrativeThread> {
132    if segments.is_empty() {
133        return Vec::new();
134    }
135    if segments.len() == 1 {
136        return vec![thread_from_segment_group(segments, &[0])];
137    }
138
139    let n = segments.len();
140    let matrix = CondensedMatrix::new(segments);
141
142    // Collect all (similarity, i, j) edges and sort descending.
143    let mut edges: Vec<(f32, usize, usize)> = Vec::with_capacity(n * (n - 1) / 2);
144    for i in 0..n {
145        for j in (i + 1)..n {
146            let sim = matrix.get(i, j);
147            if sim >= config.thread_similarity_threshold {
148                edges.push((sim, i, j));
149            }
150        }
151    }
152    edges.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
153
154    // Single-linkage merge via union-find.
155    let mut uf = UnionFind::new(n);
156    for &(_sim, i, j) in &edges {
157        uf.union(i, j);
158    }
159
160    // Collect clusters from union-find roots.
161    let mut cluster_map: HashMap<usize, Vec<usize>> = HashMap::new();
162    for i in 0..n {
163        cluster_map.entry(uf.find(i)).or_default().push(i);
164    }
165
166    // Convert clusters to narrative threads.
167    let mut threads: Vec<NarrativeThread> = Vec::new();
168    for (_root, cluster) in &cluster_map {
169        let mut thread = thread_from_segment_group(segments, cluster);
170
171        // Sub-thread detection for large clusters.
172        if cluster.len() >= 4 {
173            let sub_threads = detect_sub_threads(segments, cluster, config, &matrix);
174            if sub_threads.len() > 1 {
175                thread.sub_threads = sub_threads;
176            }
177        }
178
179        threads.push(thread);
180    }
181
182    // Sort threads by start time.
183    threads.sort_by_key(|thread| thread.start_time);
184
185    threads
186}
187
188/// Create a narrative thread from a group of segment indices.
189fn thread_from_segment_group(segments: &[EpisodeSegment], indices: &[usize]) -> NarrativeThread {
190    let mut all_records: Vec<&EpisodicRecord> = Vec::new();
191    let mut all_entities: HashMap<String, usize> = HashMap::new();
192    let mut embeddings: Vec<&Vec<f32>> = Vec::new();
193
194    for &idx in indices {
195        let seg = &segments[idx];
196        for rec in &seg.records {
197            all_records.push(rec);
198            for ent in &rec.entities {
199                *all_entities.entry(ent.name.clone()).or_default() += 1;
200            }
201        }
202        if let Some(ref emb) = seg.topic_embedding {
203            embeddings.push(emb);
204        }
205    }
206
207    // Sort records by timestamp.
208    all_records.sort_by_key(|r| r.timestamp);
209
210    // Get top entities for title.
211    let mut entity_list: Vec<(String, usize)> = all_entities.into_iter().collect();
212    entity_list.sort_by_key(|item| std::cmp::Reverse(item.1));
213    let top_entities: Vec<String> = entity_list
214        .iter()
215        .take(3)
216        .map(|(name, _)| name.clone())
217        .collect();
218
219    let title = if top_entities.is_empty() {
220        "Unnamed Thread".to_string()
221    } else {
222        top_entities.join(", ")
223    };
224
225    let entities: Vec<String> = entity_list.into_iter().map(|(name, _)| name).collect();
226
227    // Mean embedding.
228    let embedding = if embeddings.is_empty() {
229        None
230    } else {
231        let dims = embeddings[0].len();
232        let mut mean = vec![0.0f32; dims];
233        for emb in &embeddings {
234            for (i, v) in emb.iter().enumerate() {
235                mean[i] += v;
236            }
237        }
238        let n = embeddings.len() as f32;
239        for v in &mut mean {
240            *v /= n;
241        }
242        let norm: f32 = mean.iter().map(|x| x * x).sum::<f32>().sqrt();
243        if norm > 0.0 {
244            for v in &mut mean {
245                *v /= norm;
246            }
247        }
248        Some(mean)
249    };
250
251    let (start_time, end_time) = match (all_records.first(), all_records.last()) {
252        (Some(first), Some(last)) => (first.timestamp, last.timestamp),
253        _ => (Timestamp::default(), Timestamp::default()),
254    };
255
256    NarrativeThread {
257        title,
258        segment_indices: indices.to_vec(),
259        record_ids: all_records.iter().map(|r| r.id).collect(),
260        contents: all_records.iter().map(|r| r.content.clone()).collect(),
261        summaries: all_records.iter().map(|r| r.summary.clone()).collect(),
262        start_time,
263        end_time,
264        entities,
265        sub_threads: Vec::new(),
266        embedding,
267    }
268}
269
270/// Detect sub-threads within a cluster by re-clustering with a tighter threshold.
271///
272/// **F-021 FIX:** Reuses the precomputed `CondensedMatrix` instead of
273/// recomputing O(C²·|A|·|B|) similarities. Union-find single-linkage
274/// on within-cluster pairs.
275fn detect_sub_threads(
276    segments: &[EpisodeSegment],
277    cluster: &[usize],
278    config: &ConsolidationConfig,
279    matrix: &CondensedMatrix,
280) -> Vec<NarrativeThread> {
281    let tighter_threshold = config.thread_similarity_threshold + 0.15;
282    let m = cluster.len();
283
284    // Collect within-cluster edges that exceed the tighter threshold.
285    let mut edges: Vec<(f32, usize, usize)> = Vec::new();
286    for ci in 0..m {
287        for cj in (ci + 1)..m {
288            let sim = matrix.get(cluster[ci], cluster[cj]);
289            if sim >= tighter_threshold {
290                edges.push((sim, ci, cj));
291            }
292        }
293    }
294    edges.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
295
296    let mut uf = UnionFind::new(m);
297    for &(_sim, i, j) in &edges {
298        uf.union(i, j);
299    }
300
301    let mut sub_map: HashMap<usize, Vec<usize>> = HashMap::new();
302    for ci in 0..m {
303        sub_map.entry(uf.find(ci)).or_default().push(cluster[ci]);
304    }
305
306    sub_map
307        .into_values()
308        .map(|c| thread_from_segment_group(segments, &c))
309        .collect()
310}