hirn_engine/consolidation/
pattern.rs1use super::*;
2
/// An entity, or a co-occurring pair of entities, that recurs across episode segments.
#[derive(Clone, Debug)]
pub struct Pattern {
    /// Entity names forming the pattern.
    ///
    /// This holds a single name for an entity pattern, or two names
    /// (lexicographically ordered) for a co-occurrence pair.
    pub entities: Vec<String>,
    /// Number of distinct segments in which the pattern was observed.
    pub frequency: usize,
    /// Indices into the analysed segment slice of the segments containing the pattern.
    pub segment_indices: Vec<usize>,
    /// Temporal spread of the occurrences, in `[0.0, 1.0]` (see `compute_diversity`).
    pub diversity_score: f64,
    /// Normalised mean topic embedding of the matching segments, if any segment had one.
    pub representative_embedding: Option<Vec<f32>>,
}
21
22#[derive(Debug, Clone)]
24pub struct TemporalPattern {
25 pub topic: String,
27 pub occurrences: Vec<usize>,
29 pub period_seconds: Option<i64>,
31 pub first_occurrence: Timestamp,
33 pub last_occurrence: Timestamp,
35}
36
/// A chain of `Causes` edges whose rendered form was seen repeatedly.
#[derive(Clone, Debug)]
pub struct CausalPattern {
    /// Labels of the records along the chain, from cause to final effect.
    pub chain: Vec<String>,
    /// How many starting records produced exactly this chain.
    pub occurrences: usize,
    /// Heuristic confidence in `[0.0, 1.0]`, derived from `occurrences`.
    pub confidence: f32,
}
47
48#[derive(Debug, Clone)]
50pub struct DetectedPatterns {
51 pub entity_patterns: Vec<Pattern>,
52 pub temporal_patterns: Vec<TemporalPattern>,
53 pub causal_patterns: Vec<CausalPattern>,
54}
55
56pub async fn detect_patterns(
58 segments: &[EpisodeSegment],
59 config: &ConsolidationConfig,
60 db: &HirnDB,
61) -> DetectedPatterns {
62 let entity_patterns = detect_entity_patterns(segments, config);
63 let temporal_patterns = detect_temporal_patterns(segments, config);
64 let causal_patterns = detect_causal_patterns(segments, db).await;
65
66 DetectedPatterns {
67 entity_patterns,
68 temporal_patterns,
69 causal_patterns,
70 }
71}
72
73pub(super) fn detect_entity_patterns(
75 segments: &[EpisodeSegment],
76 config: &ConsolidationConfig,
77) -> Vec<Pattern> {
78 let mut entity_segments: HashMap<String, Vec<usize>> = HashMap::new();
80 for (seg_idx, seg) in segments.iter().enumerate() {
81 let mut seen_entities: HashSet<String> = HashSet::new();
82 for rec in &seg.records {
83 for ent in &rec.entities {
84 seen_entities.insert(ent.name.clone());
85 }
86 }
87 for entity in seen_entities {
88 entity_segments.entry(entity).or_default().push(seg_idx);
89 }
90 }
91
92 let mut co_occurrence: HashMap<(String, String), Vec<usize>> = HashMap::new();
94 for (seg_idx, seg) in segments.iter().enumerate() {
95 let mut seg_entities: HashSet<String> = HashSet::new();
96 for rec in &seg.records {
97 for ent in &rec.entities {
98 seg_entities.insert(ent.name.clone());
99 }
100 }
101 let mut entities: Vec<String> = seg_entities.into_iter().collect();
102 entities.sort();
103 for i in 0..entities.len() {
104 for j in (i + 1)..entities.len() {
105 let pair = (entities[i].clone(), entities[j].clone());
106 co_occurrence.entry(pair).or_default().push(seg_idx);
107 }
108 }
109 }
110
111 let mut patterns = Vec::new();
112
113 for (entity, seg_indices) in &entity_segments {
115 if seg_indices.len() >= config.min_pattern_frequency {
116 let diversity = compute_diversity(segments, seg_indices);
117 let embedding = compute_pattern_embedding(segments, seg_indices);
118 patterns.push(Pattern {
119 entities: vec![entity.clone()],
120 frequency: seg_indices.len(),
121 segment_indices: seg_indices.clone(),
122 diversity_score: diversity,
123 representative_embedding: embedding,
124 });
125 }
126 }
127
128 for ((e1, e2), seg_indices) in &co_occurrence {
130 if seg_indices.len() >= config.min_pattern_frequency {
131 let diversity = compute_diversity(segments, seg_indices);
132 let embedding = compute_pattern_embedding(segments, seg_indices);
133 patterns.push(Pattern {
134 entities: vec![e1.clone(), e2.clone()],
135 frequency: seg_indices.len(),
136 segment_indices: seg_indices.clone(),
137 diversity_score: diversity,
138 representative_embedding: embedding,
139 });
140 }
141 }
142
143 patterns.sort_by(|a, b| {
145 let score_a = a.frequency as f64 * a.diversity_score;
146 let score_b = b.frequency as f64 * b.diversity_score;
147 score_b
148 .partial_cmp(&score_a)
149 .unwrap_or(std::cmp::Ordering::Equal)
150 });
151
152 patterns
153}
154
155fn compute_diversity(segments: &[EpisodeSegment], seg_indices: &[usize]) -> f64 {
158 if seg_indices.len() <= 1 || segments.is_empty() {
159 return 1.0;
160 }
161
162 let &first_idx = seg_indices.first().unwrap();
165 let &last_idx = seg_indices.last().unwrap();
166 if first_idx >= segments.len() || last_idx >= segments.len() {
167 return 1.0;
168 }
169
170 let first_seg_time = segments[first_idx].start_time.as_datetime();
171 let last_seg_time = segments[last_idx].end_time.as_datetime();
172
173 let total_first = segments.first().unwrap().start_time.as_datetime();
174 let total_last = segments.last().unwrap().end_time.as_datetime();
175
176 let pattern_span = last_seg_time
177 .signed_duration_since(first_seg_time)
178 .num_seconds() as f64;
179 let total_span = total_last
180 .signed_duration_since(total_first)
181 .num_seconds()
182 .max(1) as f64;
183
184 (pattern_span / total_span).clamp(0.0, 1.0)
185}
186
187fn compute_pattern_embedding(
189 segments: &[EpisodeSegment],
190 seg_indices: &[usize],
191) -> Option<Vec<f32>> {
192 let embeddings: Vec<&Vec<f32>> = seg_indices
193 .iter()
194 .filter_map(|&idx| segments.get(idx))
195 .filter_map(|seg| seg.topic_embedding.as_ref())
196 .collect();
197
198 if embeddings.is_empty() {
199 return None;
200 }
201
202 let dims = embeddings[0].len();
203 let mut mean = vec![0.0f32; dims];
204 for emb in &embeddings {
205 for (i, v) in emb.iter().enumerate() {
206 mean[i] += v;
207 }
208 }
209 let n = embeddings.len() as f32;
210 for v in &mut mean {
211 *v /= n;
212 }
213 let norm: f32 = mean.iter().map(|x| x * x).sum::<f32>().sqrt();
214 if norm > 0.0 {
215 for v in &mut mean {
216 *v /= norm;
217 }
218 }
219 Some(mean)
220}
221
222fn detect_temporal_patterns(
224 segments: &[EpisodeSegment],
225 config: &ConsolidationConfig,
226) -> Vec<TemporalPattern> {
227 let mut topic_occurrences: HashMap<String, Vec<usize>> = HashMap::new();
229 for (idx, seg) in segments.iter().enumerate() {
230 for entity in &seg.dominant_entities {
231 topic_occurrences
232 .entry(entity.clone())
233 .or_default()
234 .push(idx);
235 }
236 }
237
238 let mut patterns = Vec::new();
239
240 for (topic, occurrences) in &topic_occurrences {
241 if occurrences.len() < config.min_pattern_frequency {
242 continue;
243 }
244
245 let first = segments[occurrences[0]].start_time;
246 let Some(&last_idx) = occurrences.last() else {
247 continue;
248 };
249 let last = segments[last_idx].start_time;
250
251 let period = if occurrences.len() >= 3 {
253 estimate_period(segments, occurrences)
254 } else {
255 None
256 };
257
258 patterns.push(TemporalPattern {
259 topic: topic.clone(),
260 occurrences: occurrences.clone(),
261 period_seconds: period,
262 first_occurrence: first,
263 last_occurrence: last,
264 });
265 }
266
267 patterns
268}
269
270fn estimate_period(segments: &[EpisodeSegment], occurrences: &[usize]) -> Option<i64> {
272 if occurrences.len() < 3 {
273 return None;
274 }
275
276 let mut intervals: Vec<i64> = Vec::new();
277 for i in 1..occurrences.len() {
278 let prev_time = segments[occurrences[i - 1]].start_time.as_datetime();
279 let curr_time = segments[occurrences[i]].start_time.as_datetime();
280 let interval = curr_time.signed_duration_since(prev_time).num_seconds();
281 intervals.push(interval);
282 }
283
284 intervals.sort_unstable();
286 let median = intervals[intervals.len() / 2];
287
288 let mean = intervals.iter().sum::<i64>() as f64 / intervals.len() as f64;
290 if mean <= 0.0 {
291 return None;
292 }
293 let variance = intervals
294 .iter()
295 .map(|&x| {
296 let diff = x as f64 - mean;
297 diff * diff
298 })
299 .sum::<f64>()
300 / intervals.len() as f64;
301 let cv = variance.sqrt() / mean;
302
303 if cv < 0.5 {
304 Some(median)
305 } else {
306 None }
308}
309
310async fn detect_causal_patterns(segments: &[EpisodeSegment], db: &HirnDB) -> Vec<CausalPattern> {
312 let store = db.graph_store();
313 let mut chain_counts: HashMap<Vec<String>, usize> = HashMap::new();
314
315 for seg in segments {
316 for rec in &seg.records {
318 let causes_edges = store
319 .get_edges_of_type(rec.id, EdgeRelation::Causes)
320 .await
321 .unwrap_or_default();
322 if causes_edges.is_empty() {
323 continue;
324 }
325
326 let mut chain = vec![dominant_entity_name(rec)];
328 let mut current = rec.id;
329 let mut visited = HashSet::new();
330 visited.insert(current);
331
332 loop {
333 let edges = store
334 .get_edges_of_type(current, EdgeRelation::Causes)
335 .await
336 .unwrap_or_default();
337 let next = edges.into_iter().find(|e| !visited.contains(&e.target));
338 match next {
339 Some(edge) => {
340 visited.insert(edge.target);
341 let target_name = seg
343 .records
344 .iter()
345 .find(|r| r.id == edge.target)
346 .map_or_else(|| format!("{}", edge.target), dominant_entity_name);
347 chain.push(target_name);
348 current = edge.target;
349 }
350 None => break,
351 }
352 }
353
354 if chain.len() >= 2 {
355 *chain_counts.entry(chain).or_default() += 1;
356 }
357 }
358 }
359
360 chain_counts
361 .into_iter()
362 .filter(|(_, count)| *count >= 2)
363 .map(|(chain, occurrences)| CausalPattern {
364 confidence: (occurrences as f32 / 10.0).clamp(0.0, 1.0),
365 chain,
366 occurrences,
367 })
368 .collect()
369}
370
371fn dominant_entity_name(rec: &EpisodicRecord) -> String {
372 rec.entities.first().map_or_else(
373 || rec.content.chars().take(30).collect(),
374 |e| e.name.clone(),
375 )
376}