Skip to main content

hirn_engine/consolidation/
pattern.rs

1use super::*;
2
3// ═══════════════════════════════════════════════════════════════════════════
4// Pattern Detection
5// ═══════════════════════════════════════════════════════════════════════════
6
/// A recurring entity pattern found across episode segments.
///
/// A pattern is described by the entities it involves, how often and where it
/// shows up, and two derived summaries (time-span diversity and an averaged
/// embedding).
#[derive(Clone, Debug)]
pub struct Pattern {
    /// Names of the entities that make up this pattern.
    pub entities: Vec<String>,
    /// How many segments contain the pattern.
    pub frequency: usize,
    /// Positions (within the analysed segment slice) of the segments that
    /// contain the pattern.
    pub segment_indices: Vec<usize>,
    /// Diversity score: how much of the overall time span the pattern covers.
    pub diversity_score: f64,
    /// Averaged embedding of the segments in which the pattern appears, when
    /// at least one of them carries an embedding.
    pub representative_embedding: Option<Vec<f32>>,
}
21
22/// A temporal pattern: recurring topic over time.
23#[derive(Debug, Clone)]
24pub struct TemporalPattern {
25    /// The dominant entity or topic.
26    pub topic: String,
27    /// Segment indices where this topic recurs.
28    pub occurrences: Vec<usize>,
29    /// Estimated period in seconds (if periodic), None if irregular.
30    pub period_seconds: Option<i64>,
31    /// First occurrence timestamp.
32    pub first_occurrence: Timestamp,
33    /// Last occurrence timestamp.
34    pub last_occurrence: Timestamp,
35}
36
/// A causal chain that has been observed more than once.
#[derive(Clone, Debug)]
pub struct CausalPattern {
    /// Names of the links of the chain, ordered from cause to effect.
    pub chain: Vec<String>,
    /// How many times the chain was observed.
    pub occurrences: usize,
    /// Confidence in the chain, derived from how consistently it was observed.
    pub confidence: f32,
}
47
48/// All detected patterns from a set of segments.
49#[derive(Debug, Clone)]
50pub struct DetectedPatterns {
51    pub entity_patterns: Vec<Pattern>,
52    pub temporal_patterns: Vec<TemporalPattern>,
53    pub causal_patterns: Vec<CausalPattern>,
54}
55
56/// Detect patterns across episode segments.
57pub async fn detect_patterns(
58    segments: &[EpisodeSegment],
59    config: &ConsolidationConfig,
60    db: &HirnDB,
61) -> DetectedPatterns {
62    let entity_patterns = detect_entity_patterns(segments, config);
63    let temporal_patterns = detect_temporal_patterns(segments, config);
64    let causal_patterns = detect_causal_patterns(segments, db).await;
65
66    DetectedPatterns {
67        entity_patterns,
68        temporal_patterns,
69        causal_patterns,
70    }
71}
72
73/// Detect entity frequency and co-occurrence patterns.
74pub(super) fn detect_entity_patterns(
75    segments: &[EpisodeSegment],
76    config: &ConsolidationConfig,
77) -> Vec<Pattern> {
78    // Count entity appearances across segments.
79    let mut entity_segments: HashMap<String, Vec<usize>> = HashMap::new();
80    for (seg_idx, seg) in segments.iter().enumerate() {
81        let mut seen_entities: HashSet<String> = HashSet::new();
82        for rec in &seg.records {
83            for ent in &rec.entities {
84                seen_entities.insert(ent.name.clone());
85            }
86        }
87        for entity in seen_entities {
88            entity_segments.entry(entity).or_default().push(seg_idx);
89        }
90    }
91
92    // Detect co-occurrence patterns (entities that consistently appear together).
93    let mut co_occurrence: HashMap<(String, String), Vec<usize>> = HashMap::new();
94    for (seg_idx, seg) in segments.iter().enumerate() {
95        let mut seg_entities: HashSet<String> = HashSet::new();
96        for rec in &seg.records {
97            for ent in &rec.entities {
98                seg_entities.insert(ent.name.clone());
99            }
100        }
101        let mut entities: Vec<String> = seg_entities.into_iter().collect();
102        entities.sort();
103        for i in 0..entities.len() {
104            for j in (i + 1)..entities.len() {
105                let pair = (entities[i].clone(), entities[j].clone());
106                co_occurrence.entry(pair).or_default().push(seg_idx);
107            }
108        }
109    }
110
111    let mut patterns = Vec::new();
112
113    // Single-entity patterns.
114    for (entity, seg_indices) in &entity_segments {
115        if seg_indices.len() >= config.min_pattern_frequency {
116            let diversity = compute_diversity(segments, seg_indices);
117            let embedding = compute_pattern_embedding(segments, seg_indices);
118            patterns.push(Pattern {
119                entities: vec![entity.clone()],
120                frequency: seg_indices.len(),
121                segment_indices: seg_indices.clone(),
122                diversity_score: diversity,
123                representative_embedding: embedding,
124            });
125        }
126    }
127
128    // Co-occurrence patterns.
129    for ((e1, e2), seg_indices) in &co_occurrence {
130        if seg_indices.len() >= config.min_pattern_frequency {
131            let diversity = compute_diversity(segments, seg_indices);
132            let embedding = compute_pattern_embedding(segments, seg_indices);
133            patterns.push(Pattern {
134                entities: vec![e1.clone(), e2.clone()],
135                frequency: seg_indices.len(),
136                segment_indices: seg_indices.clone(),
137                diversity_score: diversity,
138                representative_embedding: embedding,
139            });
140        }
141    }
142
143    // Sort by frequency × diversity (descending).
144    patterns.sort_by(|a, b| {
145        let score_a = a.frequency as f64 * a.diversity_score;
146        let score_b = b.frequency as f64 * b.diversity_score;
147        score_b
148            .partial_cmp(&score_a)
149            .unwrap_or(std::cmp::Ordering::Equal)
150    });
151
152    patterns
153}
154
155/// Compute diversity score for a pattern across segments.
156/// Measures the time span covered relative to total time span.
157fn compute_diversity(segments: &[EpisodeSegment], seg_indices: &[usize]) -> f64 {
158    if seg_indices.len() <= 1 || segments.is_empty() {
159        return 1.0;
160    }
161
162    // Safety: seg_indices.len() > 1 guarantees first()/last() succeed.
163    // Bounds-check indices against segments to guard against corruption.
164    let &first_idx = seg_indices.first().unwrap();
165    let &last_idx = seg_indices.last().unwrap();
166    if first_idx >= segments.len() || last_idx >= segments.len() {
167        return 1.0;
168    }
169
170    let first_seg_time = segments[first_idx].start_time.as_datetime();
171    let last_seg_time = segments[last_idx].end_time.as_datetime();
172
173    let total_first = segments.first().unwrap().start_time.as_datetime();
174    let total_last = segments.last().unwrap().end_time.as_datetime();
175
176    let pattern_span = last_seg_time
177        .signed_duration_since(first_seg_time)
178        .num_seconds() as f64;
179    let total_span = total_last
180        .signed_duration_since(total_first)
181        .num_seconds()
182        .max(1) as f64;
183
184    (pattern_span / total_span).clamp(0.0, 1.0)
185}
186
187/// Compute a representative embedding for a pattern.
188fn compute_pattern_embedding(
189    segments: &[EpisodeSegment],
190    seg_indices: &[usize],
191) -> Option<Vec<f32>> {
192    let embeddings: Vec<&Vec<f32>> = seg_indices
193        .iter()
194        .filter_map(|&idx| segments.get(idx))
195        .filter_map(|seg| seg.topic_embedding.as_ref())
196        .collect();
197
198    if embeddings.is_empty() {
199        return None;
200    }
201
202    let dims = embeddings[0].len();
203    let mut mean = vec![0.0f32; dims];
204    for emb in &embeddings {
205        for (i, v) in emb.iter().enumerate() {
206            mean[i] += v;
207        }
208    }
209    let n = embeddings.len() as f32;
210    for v in &mut mean {
211        *v /= n;
212    }
213    let norm: f32 = mean.iter().map(|x| x * x).sum::<f32>().sqrt();
214    if norm > 0.0 {
215        for v in &mut mean {
216            *v /= norm;
217        }
218    }
219    Some(mean)
220}
221
222/// Detect temporal patterns (recurring topics over time).
223fn detect_temporal_patterns(
224    segments: &[EpisodeSegment],
225    config: &ConsolidationConfig,
226) -> Vec<TemporalPattern> {
227    // Group segments by dominant entity.
228    let mut topic_occurrences: HashMap<String, Vec<usize>> = HashMap::new();
229    for (idx, seg) in segments.iter().enumerate() {
230        for entity in &seg.dominant_entities {
231            topic_occurrences
232                .entry(entity.clone())
233                .or_default()
234                .push(idx);
235        }
236    }
237
238    let mut patterns = Vec::new();
239
240    for (topic, occurrences) in &topic_occurrences {
241        if occurrences.len() < config.min_pattern_frequency {
242            continue;
243        }
244
245        let first = segments[occurrences[0]].start_time;
246        let Some(&last_idx) = occurrences.last() else {
247            continue;
248        };
249        let last = segments[last_idx].start_time;
250
251        // Estimate period if we have enough data points.
252        let period = if occurrences.len() >= 3 {
253            estimate_period(segments, occurrences)
254        } else {
255            None
256        };
257
258        patterns.push(TemporalPattern {
259            topic: topic.clone(),
260            occurrences: occurrences.clone(),
261            period_seconds: period,
262            first_occurrence: first,
263            last_occurrence: last,
264        });
265    }
266
267    patterns
268}
269
270/// Estimate period of recurring pattern using median inter-occurrence interval.
271fn estimate_period(segments: &[EpisodeSegment], occurrences: &[usize]) -> Option<i64> {
272    if occurrences.len() < 3 {
273        return None;
274    }
275
276    let mut intervals: Vec<i64> = Vec::new();
277    for i in 1..occurrences.len() {
278        let prev_time = segments[occurrences[i - 1]].start_time.as_datetime();
279        let curr_time = segments[occurrences[i]].start_time.as_datetime();
280        let interval = curr_time.signed_duration_since(prev_time).num_seconds();
281        intervals.push(interval);
282    }
283
284    // Compute median interval.
285    intervals.sort_unstable();
286    let median = intervals[intervals.len() / 2];
287
288    // Check if intervals are reasonably consistent (coefficient of variation < 0.5).
289    let mean = intervals.iter().sum::<i64>() as f64 / intervals.len() as f64;
290    if mean <= 0.0 {
291        return None;
292    }
293    let variance = intervals
294        .iter()
295        .map(|&x| {
296            let diff = x as f64 - mean;
297            diff * diff
298        })
299        .sum::<f64>()
300        / intervals.len() as f64;
301    let cv = variance.sqrt() / mean;
302
303    if cv < 0.5 {
304        Some(median)
305    } else {
306        None // Too irregular to be a pattern.
307    }
308}
309
310/// Detect recurring causal chain patterns using graph edges.
311async fn detect_causal_patterns(segments: &[EpisodeSegment], db: &HirnDB) -> Vec<CausalPattern> {
312    let store = db.graph_store();
313    let mut chain_counts: HashMap<Vec<String>, usize> = HashMap::new();
314
315    for seg in segments {
316        // For each record in the segment, follow causal edges to find chains.
317        for rec in &seg.records {
318            let causes_edges = store
319                .get_edges_of_type(rec.id, EdgeRelation::Causes)
320                .await
321                .unwrap_or_default();
322            if causes_edges.is_empty() {
323                continue;
324            }
325
326            // Build chain from this record following Causes edges.
327            let mut chain = vec![dominant_entity_name(rec)];
328            let mut current = rec.id;
329            let mut visited = HashSet::new();
330            visited.insert(current);
331
332            loop {
333                let edges = store
334                    .get_edges_of_type(current, EdgeRelation::Causes)
335                    .await
336                    .unwrap_or_default();
337                let next = edges.into_iter().find(|e| !visited.contains(&e.target));
338                match next {
339                    Some(edge) => {
340                        visited.insert(edge.target);
341                        // Try to find entity name from target record.
342                        let target_name = seg
343                            .records
344                            .iter()
345                            .find(|r| r.id == edge.target)
346                            .map_or_else(|| format!("{}", edge.target), dominant_entity_name);
347                        chain.push(target_name);
348                        current = edge.target;
349                    }
350                    None => break,
351                }
352            }
353
354            if chain.len() >= 2 {
355                *chain_counts.entry(chain).or_default() += 1;
356            }
357        }
358    }
359
360    chain_counts
361        .into_iter()
362        .filter(|(_, count)| *count >= 2)
363        .map(|(chain, occurrences)| CausalPattern {
364            confidence: (occurrences as f32 / 10.0).clamp(0.0, 1.0),
365            chain,
366            occurrences,
367        })
368        .collect()
369}
370
371fn dominant_entity_name(rec: &EpisodicRecord) -> String {
372    rec.entities.first().map_or_else(
373        || rec.content.chars().take(30).collect(),
374        |e| e.name.clone(),
375    )
376}