Skip to main content

hirn_engine/consolidation/
pipeline.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use crate::graph_store::GraphStore;
5use tracing::Instrument;
6
7use super::*;
8
9// ═══════════════════════════════════════════════════════════════════════════
10// Consolidation Pipeline
11// ═══════════════════════════════════════════════════════════════════════════
12
/// Result from running the consolidation pipeline.
///
/// Counters are per run. `Default` yields an all-zero result, which is what an
/// idle run with nothing to do reports.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ConsolidationResult {
    /// Number of episodic records processed (after WHERE filtering).
    pub records_processed: usize,
    /// Number of segments created.
    pub segments_created: usize,
    /// Number of patterns detected (entity + temporal + causal).
    pub patterns_detected: usize,
    /// Number of causal edges discovered via temporal co-occurrence (Granger-like).
    pub causal_edges_discovered: usize,
    /// Number of narrative threads formed.
    pub threads_formed: usize,
    /// Number of communities detected (at level 0 of the community hierarchy).
    pub communities_detected: usize,
    /// Number of community summaries stored.
    pub community_summaries_stored: usize,
    /// Number of community-related edges created.
    pub community_edges_created: usize,
    /// Number of RAPTOR hierarchical summaries stored.
    pub raptor_summaries_stored: usize,
    /// Number of RAPTOR tree levels created.
    pub raptor_levels_created: usize,
    /// Number of RAPTOR provenance edges created.
    pub raptor_edges_created: usize,
    /// Number of new semantic records written by this run. Concepts that
    /// already existed are reused and not counted here.
    pub concepts_extracted: usize,
    /// Number of `derived_from` provenance edges created or repaired.
    pub provenance_edges_created: usize,
    /// Number of episodes archived (if `archive_after_consolidation` was true).
    pub episodes_archived: usize,
    /// Wall-clock execution time in milliseconds.
    pub execution_time_ms: f64,
}

impl ConsolidationResult {
    /// Returns true when the consolidation run produced durable state changes
    /// that can be credited against an outstanding interference backlog.
    ///
    /// Purely analytical counters (`records_processed`, `segments_created`,
    /// `patterns_detected`, `threads_formed`, `communities_detected`) and the
    /// execution time do not count as progress, because they do not persist
    /// anything.
    pub const fn made_progress(&self) -> bool {
        self.causal_edges_discovered > 0
            || self.community_summaries_stored > 0
            || self.community_edges_created > 0
            || self.raptor_summaries_stored > 0
            || self.raptor_levels_created > 0
            || self.raptor_edges_created > 0
            || self.concepts_extracted > 0
            || self.provenance_edges_created > 0
            || self.episodes_archived > 0
    }
}
63
64/// Builder for the consolidation pipeline.
65pub struct ConsolidateBuilder<'a> {
66    db: &'a HirnDB,
67    config: ConsolidationConfig,
68    where_conditions: Vec<WhereFilter>,
69    llm: Option<Arc<dyn hirn_core::embed::LlmProvider>>,
70    /// Agent ID for Cedar policy enforcement.
71    agent_id: Option<String>,
72}
73
/// A simple WHERE filter for consolidation.
///
/// The filter compares the episode field named by `field` against `value`
/// using `op`.
#[derive(Debug, Clone, PartialEq)]
pub struct WhereFilter {
    /// Name of the episode field to test (for example `"importance"`).
    pub field: String,
    /// Comparison operator applied as `field <op> value`.
    pub op: FilterOp,
    /// Right-hand side of the comparison.
    pub value: f64,
}

/// Comparison operator used by [`WhereFilter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FilterOp {
    /// `field > value`
    Gt,
    /// `field < value`
    Lt,
    /// `field >= value`
    Gte,
    /// `field <= value`
    Lte,
    /// `field == value` (approximate, for floating-point fields)
    Eq,
}
90
91impl<'a> ConsolidateBuilder<'a> {
92    pub(crate) fn new(db: &'a HirnDB) -> Self {
93        Self {
94            db,
95            config: ConsolidationConfig::default(),
96            where_conditions: Vec::new(),
97            llm: None,
98            agent_id: None,
99        }
100    }
101
102    /// Set the topic similarity threshold.
103    #[must_use]
104    pub const fn topic_threshold(mut self, threshold: f32) -> Self {
105        self.config.topic_similarity_threshold = threshold;
106        self
107    }
108
109    /// Set the surprise threshold.
110    #[must_use]
111    pub const fn surprise_threshold(mut self, threshold: f32) -> Self {
112        self.config.surprise_threshold = threshold;
113        self
114    }
115
116    /// Set the temporal gap in seconds.
117    #[must_use]
118    pub const fn temporal_gap(mut self, seconds: i64) -> Self {
119        self.config.temporal_gap_seconds = seconds;
120        self
121    }
122
123    /// Set whether to archive source episodes after consolidation.
124    #[must_use]
125    pub const fn archive(mut self, archive: bool) -> Self {
126        self.config.archive_after_consolidation = archive;
127        self
128    }
129
130    /// Set the thread similarity threshold.
131    #[must_use]
132    pub const fn thread_threshold(mut self, threshold: f32) -> Self {
133        self.config.thread_similarity_threshold = threshold;
134        self
135    }
136
137    /// Set a full config.
138    #[must_use]
139    pub const fn config(mut self, config: ConsolidationConfig) -> Self {
140        self.config = config;
141        self
142    }
143
144    /// Add a WHERE condition to filter episodes.
145    #[must_use]
146    pub fn where_condition(mut self, field: &str, op: FilterOp, value: f64) -> Self {
147        self.where_conditions.push(WhereFilter {
148            field: field.to_string(),
149            op,
150            value,
151        });
152        self
153    }
154
155    /// Set an LLM provider for community summary generation.
156    #[must_use]
157    pub fn llm(mut self, llm: Arc<dyn hirn_core::embed::LlmProvider>) -> Self {
158        self.llm = Some(llm);
159        self
160    }
161
162    /// Enable RAPTOR hierarchical summarization.
163    #[must_use]
164    pub fn raptor(mut self, enabled: bool) -> Self {
165        self.config.raptor_enabled = enabled;
166        self
167    }
168
169    /// Set the agent ID for Cedar policy enforcement.
170    #[must_use]
171    pub fn agent_id(mut self, id: impl Into<String>) -> Self {
172        self.agent_id = Some(id.into());
173        self
174    }
175
176    /// Execute the consolidation pipeline.
177    pub async fn execute(self) -> HirnResult<ConsolidationResult> {
178        // Cedar policy enforcement.
179        let agent = self.agent_id.as_deref().unwrap_or("anonymous");
180        self.db
181            .enforce(
182                agent,
183                crate::policy::Action::Consolidate,
184                &self.db.config().default_realm,
185                "",
186            )
187            .await?;
188
189        execute_consolidation_pipeline(
190            self.db,
191            &self.config,
192            &self.where_conditions,
193            self.llm.as_ref(),
194        )
195        .await
196    }
197}
198
199/// Execute the full consolidation pipeline.
200pub async fn execute_consolidation_pipeline(
201    db: &HirnDB,
202    config: &ConsolidationConfig,
203    where_filters: &[WhereFilter],
204    llm: Option<&Arc<dyn hirn_core::embed::LlmProvider>>,
205) -> HirnResult<ConsolidationResult> {
206    // F-111: wrap entire pipeline with a total timeout so a series of slow LLM
207    // calls (e.g. RAPTOR 3 levels × 5 clusters = 15 calls at default 10 s each)
208    // cannot hold the consolidation lock indefinitely.
209    let result = tokio::time::timeout(
210        config.total_consolidation_timeout,
211        execute_consolidation_pipeline_inner(db, config, where_filters, llm)
212            .instrument(tracing::info_span!("consolidate")),
213    )
214    .await
215    .unwrap_or_else(|_| {
216        tracing::warn!(
217            timeout_secs = config.total_consolidation_timeout.as_secs(),
218            "consolidation pipeline exceeded total_consolidation_timeout; aborting pass"
219        );
220        Err(HirnError::Timeout(format!(
221            "consolidation exceeded {}s total_consolidation_timeout",
222            config.total_consolidation_timeout.as_secs()
223        )))
224    });
225
226    match &result {
227        Ok(result) => {
228            db.write_runtime().record_consolidation_success(result);
229        }
230        Err(_) => {
231            db.write_runtime().record_consolidation_failure();
232        }
233    }
234
235    result
236}
237
238async fn execute_consolidation_pipeline_inner(
239    db: &HirnDB,
240    config: &ConsolidationConfig,
241    where_filters: &[WhereFilter],
242    llm: Option<&Arc<dyn hirn_core::embed::LlmProvider>>,
243) -> HirnResult<ConsolidationResult> {
244    config.validate()?;
245    let start = Instant::now();
246
247    // Cursor-based incremental scan: only process episodes written after the
248    // last successful consolidation run.  This prevents reprocessing already-
249    // consolidated records on every pass and bounds the working-set size.
250    let cursor_ms = db.write_runtime().consolidation_cursor_ms();
251    let after_cursor = if cursor_ms > 0 {
252        Some(hirn_core::Timestamp::from_millis(cursor_ms))
253    } else {
254        None
255    };
256
257    // F-18 / F-95 FIX: Retrieve episodes in bounded batches to prevent OOM.
258    // Default batch size reduced from 10,000 → 1,000 (~4 MB working set).
259    let filter = crate::db::EpisodicFilter {
260        include_archived: false,
261        after: after_cursor,
262        limit: Some(config.consolidation_batch_size),
263        ..Default::default()
264    };
265    let mut episodes = db.list_episodes(&filter).await?;
266
267    // Apply WHERE filters.
268    if !where_filters.is_empty() {
269        episodes.retain(|ep| {
270            where_filters
271                .iter()
272                .all(|wf| episode_matches_filter(ep, wf))
273        });
274    }
275
276    if episodes.is_empty() {
277        // No new episodes since the cursor — but still run archive + provenance-repair
278        // for any existing semantic concepts whose source episodes haven't been archived
279        // yet or whose DerivedFrom edges were removed.
280        let (episodes_archived, provenance_edges_created) =
281            run_rerun_repair_pass(db, config).await;
282        return Ok(ConsolidationResult {
283            records_processed: 0,
284            segments_created: 0,
285            patterns_detected: 0,
286            causal_edges_discovered: 0,
287            threads_formed: 0,
288            communities_detected: 0,
289            community_summaries_stored: 0,
290            community_edges_created: 0,
291            raptor_summaries_stored: 0,
292            raptor_levels_created: 0,
293            raptor_edges_created: 0,
294            concepts_extracted: 0,
295            provenance_edges_created,
296            episodes_archived,
297            execution_time_ms: start.elapsed().as_secs_f64() * 1000.0,
298        });
299    }
300
301    // Sort episodes by timestamp.
302    episodes.sort_by_key(|e| e.timestamp);
303
304    let records_processed = episodes.len();
305
306    // 2. Segment.
307    let segments = segment_episodes(&episodes, config);
308    let segments_created = segments.len();
309
310    // 3. Detect patterns.
311    let patterns = detect_patterns(&segments, config, db).await;
312    let patterns_detected = patterns.entity_patterns.len()
313        + patterns.temporal_patterns.len()
314        + patterns.causal_patterns.len();
315
316    // 3.5. Causal discovery — discover new causal edges from temporal co-occurrence.
317    let causal_edges_discovered = discover_causal_edges(&episodes, db).await;
318
319    // 4. Form narrative threads.
320    let threads = form_narrative_threads(&segments, &patterns, config);
321    let threads_formed = threads.len();
322
323    // 4.5. Community detection on the persistent graph.
324    let community_config = CommunityConfig::default();
325    let community_result = detect_communities(db.graph_store(), &community_config).await?;
326    let communities_detected = if community_result.levels.is_empty() {
327        0
328    } else {
329        community_result.levels[0].len()
330    };
331
332    // 4.6. Generate community summaries (Stage 3.6) if LLM is available.
333    // F-058 FIX: Use incremental path when a previous community result is cached,
334    // skipping LLM summarization for unchanged communities.
335    let (community_summaries_stored, community_edges_created) = if let Some(llm) = llm {
336        let prev = db.take_cached_community_result();
337        let summary_result = if let Some(ref prev) = prev {
338            generate_community_summaries_incremental(
339                db,
340                llm,
341                prev,
342                &community_result,
343                50,
344                config.llm_timeout,
345            )
346            .await?
347        } else {
348            generate_community_summaries(db, llm, &community_result, 50, config.llm_timeout).await?
349        };
350        (
351            summary_result.summaries_stored,
352            summary_result.edges_created,
353        )
354    } else {
355        (0, 0)
356    };
357
358    // Cache the community result for incremental use in the next consolidation.
359    db.set_cached_community_result(community_result);
360
361    // 4.7. RAPTOR hierarchical summarization (R-008).
362    // Build a multi-level summary tree over semantic records for top-down retrieval.
363    let (raptor_summaries_stored, raptor_levels_created, raptor_edges_created) =
364        if config.raptor_enabled {
365            if let Some(llm) = llm {
366                let raptor_result = build_raptor_tree(db, llm, config).await?;
367                (
368                    raptor_result.summaries_stored,
369                    raptor_result.levels_created,
370                    raptor_result.edges_created,
371                )
372            } else {
373                (0, 0, 0)
374            }
375        } else {
376            (0, 0, 0)
377        };
378
379    // 5. Extract concepts (F-047: use LLM when available, heuristic fallback).
380    let concepts = extract_concepts(&threads, db, llm, config.llm_timeout).await;
381
382    // 6. Store concepts as semantic records via single batch append + provenance edges.
383    //    Provenance edges are only created after the batch write succeeds (transactional).
384    let agent = AgentId::well_known("consolidation");
385    let mut concepts_extracted = 0;
386    let mut provenance_edges_created = 0;
387
388    struct PendingConceptRecord {
389        record: SemanticRecord,
390        source_episode_ids: Vec<MemoryId>,
391    }
392
393    struct ResolvedConceptRecord {
394        semantic_id: MemoryId,
395        source_episode_ids: Vec<MemoryId>,
396    }
397
398    // 6a. Build all SemanticRecord objects while preserving rerun repair targets.
399    let mut pending_records: Vec<PendingConceptRecord> = Vec::new();
400    let mut resolved_records: Vec<ResolvedConceptRecord> = Vec::new();
401
402    for concept in &concepts {
403        // Reruns must continue provenance/archive repair even when the
404        // semantic concept already exists.
405        if let Ok(existing) = db.get_semantic_by_concept(&concept.concept_name).await {
406            let mut source_episode_ids = existing.source_episodes.clone();
407            source_episode_ids.extend(concept.source_episode_ids.iter().copied());
408            source_episode_ids.sort();
409            source_episode_ids.dedup();
410
411            resolved_records.push(ResolvedConceptRecord {
412                semantic_id: existing.id,
413                source_episode_ids,
414            });
415            continue;
416        }
417
418        let mut builder = SemanticRecord::builder()
419            .concept(&concept.concept_name)
420            .knowledge_type(concept.knowledge_type)
421            .description(&concept.description)
422            .confidence(concept.confidence)
423            .agent_id(agent.clone())
424            .origin(Origin::Consolidation);
425
426        if let Some(ref emb) = concept.embedding {
427            builder = builder.embedding(emb.clone());
428        }
429
430        for &source_id in &concept.source_episode_ids {
431            builder = builder.source_episode(source_id);
432        }
433
434        for &contra_id in &concept.contradiction_ids {
435            builder = builder.contradiction(contra_id);
436        }
437
438        let record = builder.build()?;
439        pending_records.push(PendingConceptRecord {
440            record,
441            source_episode_ids: concept.source_episode_ids.clone(),
442        });
443    }
444
445    // 6b. Single batch write — no partial summaries or orphaned edges on failure.
446    if !pending_records.is_empty() {
447        let records_to_store = pending_records
448            .iter()
449            .map(|pending| pending.record.clone())
450            .collect::<Vec<_>>();
451        let batch_results = db.batch_store_semantic(records_to_store).await;
452
453        for (result, pending) in batch_results.into_iter().zip(&pending_records) {
454            if let Ok(semantic_id) = result {
455                concepts_extracted += 1;
456                resolved_records.push(ResolvedConceptRecord {
457                    semantic_id,
458                    source_episode_ids: pending.source_episode_ids.clone(),
459                });
460            }
461        }
462    }
463
464    // 6c. Create or repair provenance edges for both newly written and
465    // previously existing consolidation concepts.
466    let mut consolidated_ids = HashSet::new();
467    for resolved in &resolved_records {
468        let mut existing_targets = match db
469            .cached_graph()
470            .get_edges_of_type(resolved.semantic_id, EdgeRelation::DerivedFrom)
471            .await
472        {
473            Ok(edges) => edges
474                .into_iter()
475                .filter_map(|edge| {
476                    if edge.source == resolved.semantic_id {
477                        Some(edge.target)
478                    } else if edge.target == resolved.semantic_id {
479                        Some(edge.source)
480                    } else {
481                        None
482                    }
483                })
484                .collect::<HashSet<_>>(),
485            Err(error) => {
486                tracing::warn!(
487                    semantic_id = %resolved.semantic_id,
488                    error = %error,
489                    "failed to inspect existing consolidation provenance edges"
490                );
491                HashSet::new()
492            }
493        };
494
495        for &source_id in &resolved.source_episode_ids {
496            consolidated_ids.insert(source_id);
497            if existing_targets.contains(&source_id) {
498                continue;
499            }
500
501            match db
502                .connect_with(
503                    resolved.semantic_id,
504                    source_id,
505                    EdgeRelation::DerivedFrom,
506                    1.0,
507                    Metadata::default(),
508                )
509                .await
510            {
511                Ok(_) => {
512                    provenance_edges_created += 1;
513                    existing_targets.insert(source_id);
514                }
515                Err(hirn_core::HirnError::AlreadyExists(error)) => {
516                    let repaired = match db
517                        .cached_graph()
518                        .get_edges_between(resolved.semantic_id, source_id)
519                        .await
520                    {
521                        Ok(edges) => edges.iter().any(|edge| {
522                            edge.relation == EdgeRelation::DerivedFrom
523                                && edge.source == resolved.semantic_id
524                                && edge.target == source_id
525                        }),
526                        Err(graph_error) => {
527                            tracing::warn!(
528                                semantic_id = %resolved.semantic_id,
529                                source_id = %source_id,
530                                error = %graph_error,
531                                "failed to verify consolidation provenance repair after duplicate edge write"
532                            );
533                            false
534                        }
535                    };
536
537                    if repaired {
538                        provenance_edges_created += 1;
539                        existing_targets.insert(source_id);
540                    } else {
541                        tracing::warn!(
542                            semantic_id = %resolved.semantic_id,
543                            source_id = %source_id,
544                            error = %error,
545                            "duplicate consolidation provenance edge write did not leave a repaired edge"
546                        );
547                    }
548                }
549                Err(error) => {
550                    tracing::warn!(
551                        semantic_id = %resolved.semantic_id,
552                        source_id = %source_id,
553                        error = %error,
554                        "failed to create consolidation provenance edge"
555                    );
556                }
557            }
558        }
559    }
560
561    // 7. Archive source episodes if configured.
562    let mut episodes_archived = 0;
563    if config.archive_after_consolidation && !consolidated_ids.is_empty() {
564        for id in consolidated_ids {
565            if db.archive_episode(id).await.is_ok() {
566                episodes_archived += 1;
567            }
568        }
569    }
570
571    // Advance the incremental consolidation cursor to the timestamp of the
572    // newest episode processed in this batch so the next run skips them.
573    if let Some(max_ts) = episodes.iter().map(|e| e.timestamp.millis()).max() {
574        db.write_runtime().advance_consolidation_cursor(max_ts);
575    }
576
577    let execution_time_ms = start.elapsed().as_secs_f64() * 1000.0;
578    metrics::histogram!(crate::metrics::CONSOLIDATION_DURATION_SECONDS)
579        .record(start.elapsed().as_secs_f64());
580    metrics::counter!(crate::metrics::CONSOLIDATION_TOTAL).increment(1);
581
582    db.emit(crate::event::MemoryEvent::Consolidated { records_processed })
583        .await;
584
585    Ok(ConsolidationResult {
586        records_processed,
587        segments_created,
588        patterns_detected,
589        causal_edges_discovered,
590        threads_formed,
591        communities_detected,
592        community_summaries_stored,
593        community_edges_created,
594        raptor_summaries_stored,
595        raptor_levels_created,
596        raptor_edges_created,
597        concepts_extracted,
598        provenance_edges_created,
599        episodes_archived,
600        execution_time_ms,
601    })
602}
603
604pub(super) fn episode_matches_filter(ep: &EpisodicRecord, filter: &WhereFilter) -> bool {
605    let val = match filter.field.as_str() {
606        "importance" => f64::from(ep.importance),
607        "surprise" => f64::from(ep.surprise),
608        "access_count" | "episodic.access_count" => ep.access_count as f64,
609        _ => return true,
610    };
611
612    match filter.op {
613        FilterOp::Gt => val > filter.value,
614        FilterOp::Lt => val < filter.value,
615        FilterOp::Gte => val >= filter.value,
616        FilterOp::Lte => val <= filter.value,
617        FilterOp::Eq => (val - filter.value).abs() < f64::EPSILON,
618    }
619}
620
621/// Archive + provenance-repair pass for consolidation reruns.
622///
623/// When the incremental cursor has advanced past all episodes, the main pipeline
624/// returns early without running the archive/provenance steps.  This helper runs
625/// those two operations over ALL existing semantic records so that:
626///
627/// 1. Source episodes that were consolidated in a previous pass can be archived
628///    when `archive_after_consolidation` is true.
629/// 2. Any `DerivedFrom` graph edges that were removed after the original
630///    consolidation run are recreated.
631///
632/// Returns `(episodes_archived, provenance_edges_created)`.
633async fn run_rerun_repair_pass(db: &HirnDB, config: &ConsolidationConfig) -> (usize, usize) {
634    let semantics = match db
635        .list_semantics(&crate::db::SemanticFilter::default())
636        .await
637    {
638        Ok(s) => s,
639        Err(error) => {
640            tracing::warn!(
641                error = %error,
642                "rerun repair pass: failed to load semantic records"
643            );
644            return (0, 0);
645        }
646    };
647
648    let mut episodes_archived = 0usize;
649    let mut provenance_edges_created = 0usize;
650
651    for sem in &semantics {
652        // Collect existing DerivedFrom targets so we only create missing ones.
653        let existing_targets = match db
654            .cached_graph()
655            .get_edges_of_type(sem.id, EdgeRelation::DerivedFrom)
656            .await
657        {
658            Ok(edges) => edges
659                .into_iter()
660                .filter_map(|edge| {
661                    if edge.source == sem.id {
662                        Some(edge.target)
663                    } else if edge.target == sem.id {
664                        Some(edge.source)
665                    } else {
666                        None
667                    }
668                })
669                .collect::<HashSet<_>>(),
670            Err(_) => HashSet::new(),
671        };
672
673        for &source_id in &sem.source_episodes {
674            // Archive only when configured.
675            if config.archive_after_consolidation && db.archive_episode(source_id).await.is_ok() {
676                episodes_archived += 1;
677            }
678
679            // Repair the provenance edge unconditionally.
680            if !existing_targets.contains(&source_id) {
681                match db
682                    .connect_with(
683                        sem.id,
684                        source_id,
685                        EdgeRelation::DerivedFrom,
686                        1.0,
687                        Metadata::default(),
688                    )
689                    .await
690                {
691                    Ok(_) | Err(HirnError::AlreadyExists(_)) => {
692                        provenance_edges_created += 1;
693                    }
694                    Err(error) => {
695                        tracing::warn!(
696                            semantic_id = %sem.id,
697                            source_id = %source_id,
698                            error = %error,
699                            "rerun repair pass: failed to recreate provenance edge"
700                        );
701                    }
702                }
703            }
704        }
705    }
706
707    (episodes_archived, provenance_edges_created)
708}
709
710/// Discover new causal edges from temporal co-occurrence (Granger-like).
711///
712/// Scans time-sorted episodes for pairs where A consistently precedes B
713/// within a 1-hour window. When evidence count ≥ 3, creates a `Causes`
714/// edge in the graph with strength and confidence proportional to evidence.
715/// The `consolidation_causal_window` config limits the number of episodes
716/// considered (0 = no limit). Returns the number of new edges created.
717async fn discover_causal_edges(episodes: &[EpisodicRecord], db: &HirnDB) -> usize {
718    if episodes.len() < 2 {
719        return 0;
720    }
721
722    let window = db.config().consolidation_causal_window;
723    let episodes = if window > 0 && episodes.len() > window {
724        &episodes[episodes.len() - window..]
725    } else {
726        episodes
727    };
728
729    let max_gap_ms: i64 = 3_600_000;
730    let min_evidence: usize = 3;
731
732    // Collect temporal co-occurrence: (content_key_a, content_key_b) → list of (id_a, id_b).
733    let mut pair_counts: HashMap<(String, String), Vec<(MemoryId, MemoryId)>> = HashMap::new();
734
735    for (i, ep_b) in episodes.iter().enumerate() {
736        let ts_b = ep_b.timestamp.timestamp_ms();
737        let key_b = truncate_content_key(&ep_b.content);
738
739        // Look backward at previous episodes within the time window.
740        for ep_a in episodes[..i].iter().rev() {
741            let ts_a = ep_a.timestamp.timestamp_ms();
742            let gap = ts_b - ts_a;
743            if gap > max_gap_ms {
744                break; // Episodes are sorted by time, so no more within window.
745            }
746            if gap <= 0 {
747                continue;
748            }
749            let key_a = truncate_content_key(&ep_a.content);
750            if key_a != key_b {
751                pair_counts
752                    .entry((key_a, key_b.clone()))
753                    .or_default()
754                    .push((ep_a.id, ep_b.id));
755            }
756        }
757    }
758
759    let store = db.graph_store();
760    let mut edges_created = 0;
761
762    for pairs in pair_counts.values() {
763        let count = pairs.len();
764        if count < min_evidence {
765            continue;
766        }
767
768        let strength = (count as f32 / 10.0).min(1.0);
769
770        // Use the last observed pair as representative.
771        if let Some(&(cause_id, effect_id)) = pairs.last() {
772            // Check if edge already exists to avoid duplicates.
773            let existing = store
774                .get_edges_of_type(cause_id, EdgeRelation::Causes)
775                .await
776                .unwrap_or_default();
777            if existing.iter().any(|e| e.target == effect_id) {
778                continue;
779            }
780
781            if store
782                .add_causal_edge(
783                    cause_id,
784                    effect_id,
785                    EdgeRelation::Causes,
786                    strength,
787                    Metadata::default(),
788                    hirn_graph::CausalEdgeData::new(strength, 0.5, count as u32)
789                        .with_mechanism("temporal_granger"),
790                )
791                .await
792                .is_ok()
793            {
794                edges_created += 1;
795                db.emit(crate::event::MemoryEvent::CausalEdgeDiscovered {
796                    cause: cause_id,
797                    effect: effect_id,
798                    strength,
799                })
800                .await;
801            } else {
802                tracing::debug!(
803                    %cause_id, %effect_id,
804                    "causal edge creation failed during discovery"
805                );
806            }
807        }
808    }
809
810    edges_created
811}
812
/// Normalizes episode content into a co-occurrence key: the first 50
/// characters (not bytes) of `content`, lowercased.
fn truncate_content_key(content: &str) -> String {
    const KEY_CHARS: usize = 50;
    // Byte offset of the 51st char, or the whole string when it is shorter.
    let cut = content
        .char_indices()
        .nth(KEY_CHARS)
        .map_or(content.len(), |(byte_idx, _)| byte_idx);
    content[..cut].to_lowercase()
}
816
// Integration-style tests for `execute_consolidation_pipeline`. Each test opens a
// fresh on-disk database, seeds it with a small, fully connected cluster of
// episodes, and runs the pipeline with a canned LLM.
//
// Several checks are guarded by `if result.communities_detected > 0` (or similar),
// so they are skipped rather than failed when community detection finds nothing
// on this tiny fixture.
#[cfg(test)]
mod tests {
    use super::*;

    use crate::graph_store::GraphStore;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use hirn_core::embed::{ChatMessage, LlmOptions, LlmProvider};
    use hirn_core::episodic::EpisodicRecord;
    use hirn_core::types::{EventType, KnowledgeType, Layer};

    /// Deterministic LLM stub: always returns the same THEME / KEY_ENTITIES /
    /// SUMMARY response and counts how many times it was called.
    struct MockPipelineLlm {
        /// Number of `generate_text` calls observed (not asserted on by any test).
        calls: AtomicUsize,
    }

    impl MockPipelineLlm {
        fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmProvider for MockPipelineLlm {
        async fn generate_text(
            &self,
            _messages: &[ChatMessage],
            _options: &LlmOptions,
        ) -> hirn_core::HirnResult<String> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            // NOTE(review): presumably this mirrors the line format the community
            // summarizer parses — confirm against the summarizer if it changes.
            Ok("THEME: test theme\nKEY_ENTITIES: entity-a, entity-b\nSUMMARY: A community about testing.".into())
        }

        fn model_id(&self) -> &str {
            "mock-pipeline"
        }
    }

    /// Opens a fresh `HirnDB` in a temporary directory with 3-dimensional
    /// embeddings (matching the vectors used by `populate_db_for_pipeline`).
    async fn test_db() -> crate::db::HirnDB {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("test");
        let lance_path = dir.path().join("lance");
        let mut config = hirn_core::HirnConfig::default();
        config.db_path = db_path;
        config.embedding_dimensions = hirn_core::EmbeddingDimension::new_const(3);
        let storage: Arc<dyn hirn_storage::PhysicalStore> = hirn_storage::HirnDb::open(
            hirn_storage::HirnDbConfig::local(lance_path.to_str().unwrap()),
        )
        .await
        .unwrap()
        .store_arc();
        let db = crate::db::HirnDB::open_with_config(config, storage)
            .await
            .unwrap();
        // Dropping a `TempDir` deletes the directory; forgetting it keeps the
        // files alive for as long as `db` is used (the directory is leaked).
        std::mem::forget(dir);
        db
    }

    /// Fixed agent identity attached to every seeded episode.
    fn agent() -> AgentId {
        AgentId::new("test").unwrap()
    }

    /// Store episodes and wire them in the graph so community detection has edges.
    ///
    /// Seeds six near-identical "auth" observations (embeddings close to
    /// `[1, 0, 0]`, shared `("auth", "topic")` entity) and connects every pair
    /// with a `SimilarTo` edge of weight 0.9. Edge-creation errors are ignored.
    /// Returns the ids of the stored episodes in insertion order.
    async fn populate_db_for_pipeline(db: &crate::db::HirnDB) -> Vec<MemoryId> {
        let mut ids = Vec::new();
        // Create 6 episodes about "auth", all sharing the entity so they cluster.
        for i in 0..6 {
            let emb = match i % 3 {
                0 => vec![1.0, 0.0, 0.0],
                1 => vec![0.95, 0.05, 0.0],
                _ => vec![0.9, 0.1, 0.0],
            };
            let record = EpisodicRecord::builder()
                .event_type(EventType::Observation)
                .content(&format!("Auth episode {i}: JWT tokens used for API auth"))
                .summary(&format!("Auth episode {i}"))
                .importance(0.7)
                .surprise(0.5)
                .agent_id(agent())
                .embedding(emb)
                .entity("auth", "topic")
                .build()
                .unwrap();
            let id = db.remember_bypass_admission(record).await.unwrap();
            ids.push(id);
        }

        // Create graph edges between episodes so community detection finds structure.
        for i in 0..ids.len() {
            for j in (i + 1)..ids.len() {
                let _ = db
                    .connect_with(
                        ids[i],
                        ids[j],
                        EdgeRelation::SimilarTo,
                        0.9,
                        Metadata::default(),
                    )
                    .await;
            }
        }

        ids
    }

    /// The pipeline processes every seeded episode and creates at least one segment.
    #[tokio::test(flavor = "multi_thread")]
    async fn full_consolidation_pipeline_with_communities() {
        let db = test_db().await;
        let _ids = populate_db_for_pipeline(&db).await;
        let llm: Arc<dyn LlmProvider> = Arc::new(MockPipelineLlm::new());

        let config = ConsolidationConfig::default();
        let result = execute_consolidation_pipeline(&db, &config, &[], Some(&llm))
            .await
            .unwrap();

        // Verify episodes were processed.
        assert!(
            result.records_processed >= 6,
            "expected >= 6 records processed, got {}",
            result.records_processed
        );
        // Segmentation should produce at least 1 segment.
        assert!(result.segments_created >= 1);
    }

    /// When communities are detected, their summaries are persisted and the
    /// record keyed `community-0-0` is a `KnowledgeType::Community` semantic record.
    #[tokio::test(flavor = "multi_thread")]
    async fn community_summaries_in_semantic_store_after_pipeline() {
        let db = test_db().await;
        let _ids = populate_db_for_pipeline(&db).await;
        let llm: Arc<dyn LlmProvider> = Arc::new(MockPipelineLlm::new());

        let config = ConsolidationConfig::default();
        let result = execute_consolidation_pipeline(&db, &config, &[], Some(&llm))
            .await
            .unwrap();

        if result.communities_detected > 0 {
            // Community summaries should have been stored.
            assert!(
                result.community_summaries_stored > 0,
                "expected community summaries when communities detected"
            );

            // Verify at least one community record exists in semantic store,
            // reporting the lookup error on failure.
            let record = db
                .get_semantic_by_concept("community-0-0")
                .await
                .unwrap_or_else(|e| {
                    panic!("community-0-0 should exist in semantic store: {e:?}")
                });
            assert_eq!(record.knowledge_type, KnowledgeType::Community);
        }
    }

    /// When community summaries are stored, community edges are reported and
    /// the `community-0-0` node is present in the graph with outgoing edges.
    #[tokio::test(flavor = "multi_thread")]
    async fn community_edges_in_graph_after_pipeline() {
        let db = test_db().await;
        let _ids = populate_db_for_pipeline(&db).await;
        let llm: Arc<dyn LlmProvider> = Arc::new(MockPipelineLlm::new());

        let config = ConsolidationConfig::default();
        let result = execute_consolidation_pipeline(&db, &config, &[], Some(&llm))
            .await
            .unwrap();

        if result.communities_detected > 0 && result.community_summaries_stored > 0 {
            // DerivedFrom + PartOf edges should have been created.
            assert!(
                result.community_edges_created > 0,
                "expected community edges when summaries were stored"
            );

            // The community record must exist at this point; silently skipping
            // a failed lookup would let the graph checks below pass vacuously.
            let community_record = db
                .get_semantic_by_concept("community-0-0")
                .await
                .unwrap_or_else(|e| {
                    panic!("community-0-0 should exist once summaries are stored: {e:?}")
                });

            // Verify community nodes exist in graph.
            assert!(
                db.cached_graph()
                    .has_node(community_record.id)
                    .await
                    .unwrap(),
                "community node should appear in the authoritative graph view"
            );

            // Check edges from community to members.
            let edges = db
                .cached_graph()
                .get_edges(community_record.id)
                .await
                .unwrap();
            assert!(
                !edges.is_empty(),
                "community node should have edges to members"
            );
        }
    }

    /// After community summaries are stored, the graph contains at least one
    /// Semantic-layer node.
    #[tokio::test(flavor = "multi_thread")]
    async fn community_nodes_in_graph_after_consolidation() {
        let db = test_db().await;
        let _ids = populate_db_for_pipeline(&db).await;
        let llm: Arc<dyn LlmProvider> = Arc::new(MockPipelineLlm::new());

        let config = ConsolidationConfig::default();
        let result = execute_consolidation_pipeline(&db, &config, &[], Some(&llm))
            .await
            .unwrap();

        if result.community_summaries_stored > 0 {
            // Collect every Semantic-layer node in the graph. The fixture seeds
            // only episodic records, so these come from this pipeline run. This is
            // a coarse check: the nodes are not necessarily all community nodes.
            let all_nodes = db.cached_graph().node_ids().await.unwrap();
            let mut semantic_nodes = Vec::new();
            for id in &all_nodes {
                if db.cached_graph().node_layer(*id).await.unwrap() == Some(Layer::Semantic) {
                    semantic_nodes.push(*id);
                }
            }

            assert!(
                !semantic_nodes.is_empty(),
                "graph should contain semantic (community) nodes after consolidation"
            );
        }
    }

    /// A consolidation run that makes progress clears an interference backlog
    /// that previously requested consolidation.
    #[tokio::test(flavor = "multi_thread")]
    async fn consolidation_feedback_reduces_interference_backlog_on_progress() {
        let db = test_db().await;
        let llm: Arc<dyn LlmProvider> = Arc::new(MockPipelineLlm::new());

        // NOTE(review): argument meanings of `accumulate_interference` are defined
        // in `write_path`; these values are chosen so that it requests consolidation.
        let action = db.write_runtime().accumulate_interference(
            0.4,
            hirn_core::types::Namespace::default(),
            0.3,
            300,
        );
        assert!(matches!(
            action,
            crate::db::write_path::InterferenceAction::TriggerConsolidation { .. }
        ));
        assert!(db.write_runtime().interference_snapshot().awaiting_feedback);

        let _ids = populate_db_for_pipeline(&db).await;
        let result =
            execute_consolidation_pipeline(&db, &ConsolidationConfig::default(), &[], Some(&llm))
                .await
                .unwrap();
        assert!(result.made_progress());

        // Progress should fully reset the backlog and the pending-feedback flag.
        let snapshot = db.write_runtime().interference_snapshot();
        assert_eq!(snapshot.backlog_score, 0.0);
        assert_eq!(snapshot.namespace_count, 0);
        assert!(!snapshot.awaiting_feedback);
    }
}