// graphrag_core/retrieval/causal_analysis.rs

//! Causal Chain Analysis (Phase 2.3)
//!
//! This module implements causal chain discovery and temporal validation.
//! It finds paths between a cause entity and an effect entity and validates
//! the temporal ordering of the steps along each path.
//!
//! Example: the query "What caused the fall of Athens?" should find:
//! Plague → Weakened Athens → Sparta attacked → Athens fell
//!
//! Each chain is checked for temporal consistency: step timestamps must not
//! go backwards along the chain (t1 ≤ t2 ≤ t3).
10
11use crate::{
12    core::{EntityId, KnowledgeGraph, Relationship, Result},
13    graph::temporal::{TemporalRange, TemporalRelationType},
14};
15use std::collections::VecDeque;
16use std::sync::Arc;
17
18/// A complete causal chain from cause to effect
19///
20/// Represents a series of causal steps connecting two entities.
21/// All steps must be temporally ordered (earlier causes → later effects).
22#[derive(Debug, Clone)]
23pub struct CausalChain {
24    /// Starting entity (the cause)
25    pub cause: EntityId,
26
27    /// Ending entity (the effect)
28    pub effect: EntityId,
29
30    /// Intermediate causal steps
31    pub steps: Vec<CausalStep>,
32
33    /// Overall confidence of the chain (product of step confidences)
34    pub total_confidence: f32,
35
36    /// Whether the chain is temporally consistent (all steps ordered correctly)
37    pub temporal_consistency: bool,
38
39    /// Total time span of the chain (if temporal data available)
40    pub time_span: Option<i64>,
41}
42
43/// A single step in a causal chain
44///
45/// Represents one causal relationship in a chain.
46#[derive(Debug, Clone)]
47pub struct CausalStep {
48    /// Source entity of this step
49    pub source: EntityId,
50
51    /// Target entity of this step
52    pub target: EntityId,
53
54    /// Type of causal relationship
55    pub relation_type: String,
56
57    /// Temporal type (Caused, Enabled, etc.)
58    pub temporal_type: Option<TemporalRelationType>,
59
60    /// When this step occurred
61    pub temporal_range: Option<TemporalRange>,
62
63    /// Confidence of this causal link
64    pub confidence: f32,
65
66    /// Strength of causality (0.0-1.0)
67    pub causal_strength: Option<f32>,
68}
69
70impl CausalStep {
71    /// Create a causal step from a relationship
72    pub fn from_relationship(rel: &Relationship) -> Self {
73        Self {
74            source: rel.source.clone(),
75            target: rel.target.clone(),
76            relation_type: rel.relation_type.clone(),
77            temporal_type: rel.temporal_type,
78            temporal_range: rel.temporal_range,
79            confidence: rel.confidence,
80            causal_strength: rel.causal_strength,
81        }
82    }
83
84    /// Check if this step has temporal information
85    pub fn has_temporal_info(&self) -> bool {
86        self.temporal_range.is_some()
87    }
88
89    /// Get the midpoint timestamp of this step (for ordering)
90    pub fn get_timestamp(&self) -> Option<i64> {
91        self.temporal_range.map(|tr| (tr.start + tr.end) / 2)
92    }
93}
94
95impl CausalChain {
96    /// Calculate the total confidence of the chain
97    ///
98    /// Uses product of step confidences, weighted by causal strengths
99    pub fn calculate_confidence(&self) -> f32 {
100        if self.steps.is_empty() {
101            return 0.0;
102        }
103
104        let mut product = 1.0;
105        for step in &self.steps {
106            // Weight confidence by causal strength if available
107            let weighted_confidence = if let Some(strength) = step.causal_strength {
108                step.confidence * (0.5 + 0.5 * strength) // Range: 0.5*conf to 1.0*conf
109            } else {
110                step.confidence * 0.7 // Default weight for non-causal
111            };
112            product *= weighted_confidence;
113        }
114
115        product
116    }
117
118    /// Check temporal consistency of the chain
119    ///
120    /// Returns true if all steps are temporally ordered (t1 < t2 < t3...)
121    pub fn check_temporal_consistency(&self) -> bool {
122        let mut prev_timestamp: Option<i64> = None;
123
124        for step in &self.steps {
125            if let Some(current_ts) = step.get_timestamp() {
126                if let Some(prev_ts) = prev_timestamp {
127                    // Check if current step happened after previous
128                    if current_ts < prev_ts {
129                        return false; // Temporal violation
130                    }
131                }
132                prev_timestamp = Some(current_ts);
133            }
134        }
135
136        true
137    }
138
139    /// Calculate the time span of the chain
140    pub fn calculate_time_span(&self) -> Option<i64> {
141        let first_timestamp = self.steps.first()?.get_timestamp()?;
142        let last_timestamp = self.steps.last()?.get_timestamp()?;
143
144        Some(last_timestamp - first_timestamp)
145    }
146
147    /// Get a human-readable description of the chain
148    pub fn describe(&self) -> String {
149        let step_descriptions: Vec<String> = self
150            .steps
151            .iter()
152            .map(|s| format!("{} --[{}]--> {}", s.source.0, s.relation_type, s.target.0))
153            .collect();
154
155        format!(
156            "Causal chain (conf={:.2}, consistent={}): {}",
157            self.total_confidence,
158            self.temporal_consistency,
159            step_descriptions.join(" → ")
160        )
161    }
162}
163
164/// Analyzer for finding causal chains in the knowledge graph
165///
166/// Uses depth-first search with temporal validation to find causal paths.
167pub struct CausalAnalyzer {
168    /// Reference to the knowledge graph
169    graph: Arc<KnowledgeGraph>,
170
171    /// Minimum confidence threshold for causal steps
172    min_confidence: f32,
173
174    /// Minimum causal strength to consider a relationship causal
175    min_causal_strength: f32,
176
177    /// Whether to require temporal consistency
178    require_temporal_consistency: bool,
179}
180
181impl CausalAnalyzer {
182    /// Create a new causal analyzer
183    ///
184    /// # Arguments
185    ///
186    /// * `graph` - Reference to the knowledge graph
187    pub fn new(graph: Arc<KnowledgeGraph>) -> Self {
188        Self {
189            graph,
190            min_confidence: 0.3,
191            min_causal_strength: 0.0, // Accept all relationships by default
192            require_temporal_consistency: false, // Lenient by default
193        }
194    }
195
196    /// Set minimum confidence threshold
197    pub fn with_min_confidence(mut self, min_confidence: f32) -> Self {
198        self.min_confidence = min_confidence.clamp(0.0, 1.0);
199        self
200    }
201
202    /// Set minimum causal strength threshold
203    pub fn with_min_causal_strength(mut self, min_causal_strength: f32) -> Self {
204        self.min_causal_strength = min_causal_strength.clamp(0.0, 1.0);
205        self
206    }
207
208    /// Enable/disable temporal consistency requirement
209    pub fn with_temporal_consistency(mut self, required: bool) -> Self {
210        self.require_temporal_consistency = required;
211        self
212    }
213
214    /// Find all causal chains between cause and effect
215    ///
216    /// # Arguments
217    ///
218    /// * `cause` - Starting entity ID
219    /// * `effect` - Target entity ID
220    /// * `max_depth` - Maximum chain length (number of steps)
221    ///
222    /// # Returns
223    ///
224    /// Vector of causal chains, sorted by confidence (highest first)
225    pub fn find_causal_chains(
226        &self,
227        cause: &EntityId,
228        effect: &EntityId,
229        max_depth: usize,
230    ) -> Result<Vec<CausalChain>> {
231        let mut chains = Vec::new();
232
233        // Use BFS to find all paths
234        let all_paths = self.find_all_paths(cause, effect, max_depth)?;
235
236        #[cfg(feature = "tracing")]
237        tracing::debug!(
238            cause = %cause.0,
239            effect = %effect.0,
240            paths_found = all_paths.len(),
241            "Found potential causal paths"
242        );
243
244        // Convert paths to causal chains
245        for path in all_paths {
246            let mut steps = Vec::new();
247
248            for i in 0..path.len() - 1 {
249                let source_id = &path[i];
250                let target_id = &path[i + 1];
251
252                // Find the relationship between these entities
253                if let Some(rel) = self.find_relationship(source_id, target_id) {
254                    // Check if this is a causal relationship
255                    if self.is_causal_relationship(&rel) {
256                        steps.push(CausalStep::from_relationship(&rel));
257                    }
258                }
259            }
260
261            // Only create chain if we have causal steps
262            if !steps.is_empty() {
263                let mut chain = CausalChain {
264                    cause: cause.clone(),
265                    effect: effect.clone(),
266                    steps,
267                    total_confidence: 0.0,
268                    temporal_consistency: false,
269                    time_span: None,
270                };
271
272                // Calculate properties
273                chain.total_confidence = chain.calculate_confidence();
274                chain.temporal_consistency = chain.check_temporal_consistency();
275                chain.time_span = chain.calculate_time_span();
276
277                // Filter by temporal consistency if required
278                if self.require_temporal_consistency && !chain.temporal_consistency {
279                    continue;
280                }
281
282                chains.push(chain);
283            }
284        }
285
286        // Sort by confidence (highest first)
287        chains.sort_by(|a, b| {
288            b.total_confidence
289                .partial_cmp(&a.total_confidence)
290                .unwrap_or(std::cmp::Ordering::Equal)
291        });
292
293        #[cfg(feature = "tracing")]
294        tracing::info!(causal_chains = chains.len(), "Found valid causal chains");
295
296        Ok(chains)
297    }
298
299    /// Find all paths between two entities using BFS
300    fn find_all_paths(
301        &self,
302        start: &EntityId,
303        end: &EntityId,
304        max_depth: usize,
305    ) -> Result<Vec<Vec<EntityId>>> {
306        let mut paths = Vec::new();
307        let mut queue: VecDeque<(EntityId, Vec<EntityId>)> = VecDeque::new();
308
309        queue.push_back((start.clone(), vec![start.clone()]));
310
311        while let Some((current, path)) = queue.pop_front() {
312            // Check depth limit
313            if path.len() > max_depth {
314                continue;
315            }
316
317            // Found the target
318            if current == *end {
319                paths.push(path);
320                continue;
321            }
322
323            // Explore neighbors
324            for rel in self.graph.get_entity_relationships(&current.0) {
325                let next = &rel.target;
326
327                // Avoid cycles
328                if path.contains(next) {
329                    continue;
330                }
331
332                // Check if relationship meets minimum confidence
333                if rel.confidence < self.min_confidence {
334                    continue;
335                }
336
337                let mut new_path = path.clone();
338                new_path.push(next.clone());
339                queue.push_back((next.clone(), new_path));
340            }
341        }
342
343        Ok(paths)
344    }
345
346    /// Find a relationship between two entities
347    fn find_relationship(&self, source: &EntityId, target: &EntityId) -> Option<Relationship> {
348        self.graph
349            .get_entity_relationships(&source.0)
350            .into_iter()
351            .find(|rel| rel.target == *target)
352            .cloned()
353    }
354
355    /// Check if a relationship is causal
356    fn is_causal_relationship(&self, rel: &Relationship) -> bool {
357        // Check temporal type
358        if let Some(temporal_type) = rel.temporal_type {
359            if temporal_type.is_causal() {
360                // Check causal strength threshold
361                if let Some(strength) = rel.causal_strength {
362                    return strength >= self.min_causal_strength;
363                }
364                return true; // Has causal type but no strength specified
365            }
366        }
367
368        // Check relation type contains causal keywords
369        let relation_lower = rel.relation_type.to_lowercase();
370        let causal_keywords = ["caused", "led_to", "resulted_in", "enabled", "triggered"];
371
372        causal_keywords.iter().any(|kw| relation_lower.contains(kw))
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Entity;

    /// Builds an [`EntityId`] from a string slice.
    fn id(raw: &str) -> EntityId {
        EntityId::new(raw.to_string())
    }

    /// Builds an `EVENT` entity with confidence 0.9.
    fn event(raw_id: &str, name: &str) -> Entity {
        Entity::new(id(raw_id), name.to_string(), "EVENT".to_string(), 0.9)
    }

    /// Builds a `CAUSED` relationship between two entities.
    fn caused(source: &str, target: &str, confidence: f32) -> Relationship {
        Relationship::new(id(source), id(target), "CAUSED".to_string(), confidence)
    }

    /// Graph with the causal chain A → B → C, where A caused B at time 100
    /// and B caused C at time 200.
    fn create_test_graph_with_causal_chain() -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();

        for (raw_id, name) in [("a", "Event A"), ("b", "Event B"), ("c", "Event C")] {
            graph.add_entity(event(raw_id, name)).unwrap();
        }

        // A caused B (at time 100)
        let rel_ab = caused("a", "b", 0.8)
            .with_temporal_type(TemporalRelationType::Caused)
            .with_temporal_range(100, 100)
            .with_causal_strength(0.9);

        // B caused C (at time 200)
        let rel_bc = caused("b", "c", 0.85)
            .with_temporal_type(TemporalRelationType::Caused)
            .with_temporal_range(200, 200)
            .with_causal_strength(0.95);

        graph.add_relationship(rel_ab).unwrap();
        graph.add_relationship(rel_bc).unwrap();

        graph
    }

    #[test]
    fn test_causal_chain_creation() {
        let analyzer = CausalAnalyzer::new(Arc::new(create_test_graph_with_causal_chain()));

        let chains = analyzer
            .find_causal_chains(&id("a"), &id("c"), 5)
            .unwrap();

        assert_eq!(chains.len(), 1, "Should find exactly one causal chain");

        let chain = &chains[0];
        assert_eq!(chain.steps.len(), 2, "Chain should have 2 steps (A→B, B→C)");
        assert!(
            chain.temporal_consistency,
            "Chain should be temporally consistent"
        );
        assert!(
            chain.total_confidence > 0.6,
            "Chain should have reasonable confidence"
        );
    }

    #[test]
    fn test_temporal_consistency_validation() {
        let mut graph = KnowledgeGraph::new();

        for (raw_id, name) in [("a", "A"), ("b", "B"), ("c", "C")] {
            graph.add_entity(event(raw_id, name)).unwrap();
        }

        // A→B at time 100, B→C at time 50 (temporal violation!)
        let rel_ab = caused("a", "b", 0.8)
            .with_temporal_range(100, 100)
            .with_causal_strength(0.9);

        let rel_bc = caused("b", "c", 0.8)
            .with_temporal_range(50, 50) // Earlier than A→B!
            .with_causal_strength(0.9);

        graph.add_relationship(rel_ab).unwrap();
        graph.add_relationship(rel_bc).unwrap();

        // Require consistency so the violating chain must be dropped.
        let analyzer = CausalAnalyzer::new(Arc::new(graph)).with_temporal_consistency(true);

        let chains = analyzer
            .find_causal_chains(&id("a"), &id("c"), 5)
            .unwrap();

        assert_eq!(
            chains.len(),
            0,
            "Should reject temporally inconsistent chain"
        );
    }

    #[test]
    fn test_confidence_calculation() {
        // Untimed `CAUSED` step with the given confidence and causal strength.
        let make_step = |source: &str, target: &str, confidence: f32, strength: f32| CausalStep {
            source: id(source),
            target: id(target),
            relation_type: "CAUSED".to_string(),
            temporal_type: Some(TemporalRelationType::Caused),
            temporal_range: None,
            confidence,
            causal_strength: Some(strength),
        };

        let chain = CausalChain {
            cause: id("a"),
            effect: id("c"),
            steps: vec![
                make_step("a", "b", 0.8, 0.9),
                make_step("b", "c", 0.9, 0.95),
            ],
            total_confidence: 0.0,
            temporal_consistency: true,
            time_span: None,
        };

        let confidence = chain.calculate_confidence();

        // Confidence should be product of weighted confidences
        // step1: 0.8 * (0.5 + 0.5*0.9) = 0.8 * 0.95 = 0.76
        // step2: 0.9 * (0.5 + 0.5*0.95) = 0.9 * 0.975 = 0.8775
        // product: 0.76 * 0.8775 ≈ 0.667
        assert!(
            confidence > 0.65 && confidence < 0.7,
            "Confidence calculation incorrect: {}",
            confidence
        );
    }
}