1use crate::{
12 core::{EntityId, KnowledgeGraph, Relationship, Result},
13 graph::temporal::{TemporalRange, TemporalRelationType},
14};
15use std::collections::VecDeque;
16use std::sync::Arc;
17
18#[derive(Debug, Clone)]
23pub struct CausalChain {
24 pub cause: EntityId,
26
27 pub effect: EntityId,
29
30 pub steps: Vec<CausalStep>,
32
33 pub total_confidence: f32,
35
36 pub temporal_consistency: bool,
38
39 pub time_span: Option<i64>,
41}
42
43#[derive(Debug, Clone)]
47pub struct CausalStep {
48 pub source: EntityId,
50
51 pub target: EntityId,
53
54 pub relation_type: String,
56
57 pub temporal_type: Option<TemporalRelationType>,
59
60 pub temporal_range: Option<TemporalRange>,
62
63 pub confidence: f32,
65
66 pub causal_strength: Option<f32>,
68}
69
70impl CausalStep {
71 pub fn from_relationship(rel: &Relationship) -> Self {
73 Self {
74 source: rel.source.clone(),
75 target: rel.target.clone(),
76 relation_type: rel.relation_type.clone(),
77 temporal_type: rel.temporal_type,
78 temporal_range: rel.temporal_range,
79 confidence: rel.confidence,
80 causal_strength: rel.causal_strength,
81 }
82 }
83
84 pub fn has_temporal_info(&self) -> bool {
86 self.temporal_range.is_some()
87 }
88
89 pub fn get_timestamp(&self) -> Option<i64> {
91 self.temporal_range.map(|tr| (tr.start + tr.end) / 2)
92 }
93}
94
95impl CausalChain {
96 pub fn calculate_confidence(&self) -> f32 {
100 if self.steps.is_empty() {
101 return 0.0;
102 }
103
104 let mut product = 1.0;
105 for step in &self.steps {
106 let weighted_confidence = if let Some(strength) = step.causal_strength {
108 step.confidence * (0.5 + 0.5 * strength) } else {
110 step.confidence * 0.7 };
112 product *= weighted_confidence;
113 }
114
115 product
116 }
117
118 pub fn check_temporal_consistency(&self) -> bool {
122 let mut prev_timestamp: Option<i64> = None;
123
124 for step in &self.steps {
125 if let Some(current_ts) = step.get_timestamp() {
126 if let Some(prev_ts) = prev_timestamp {
127 if current_ts < prev_ts {
129 return false; }
131 }
132 prev_timestamp = Some(current_ts);
133 }
134 }
135
136 true
137 }
138
139 pub fn calculate_time_span(&self) -> Option<i64> {
141 let first_timestamp = self.steps.first()?.get_timestamp()?;
142 let last_timestamp = self.steps.last()?.get_timestamp()?;
143
144 Some(last_timestamp - first_timestamp)
145 }
146
147 pub fn describe(&self) -> String {
149 let step_descriptions: Vec<String> = self
150 .steps
151 .iter()
152 .map(|s| format!("{} --[{}]--> {}", s.source.0, s.relation_type, s.target.0))
153 .collect();
154
155 format!(
156 "Causal chain (conf={:.2}, consistent={}): {}",
157 self.total_confidence,
158 self.temporal_consistency,
159 step_descriptions.join(" → ")
160 )
161 }
162}
163
164pub struct CausalAnalyzer {
168 graph: Arc<KnowledgeGraph>,
170
171 min_confidence: f32,
173
174 min_causal_strength: f32,
176
177 require_temporal_consistency: bool,
179}
180
181impl CausalAnalyzer {
182 pub fn new(graph: Arc<KnowledgeGraph>) -> Self {
188 Self {
189 graph,
190 min_confidence: 0.3,
191 min_causal_strength: 0.0, require_temporal_consistency: false, }
194 }
195
196 pub fn with_min_confidence(mut self, min_confidence: f32) -> Self {
198 self.min_confidence = min_confidence.clamp(0.0, 1.0);
199 self
200 }
201
202 pub fn with_min_causal_strength(mut self, min_causal_strength: f32) -> Self {
204 self.min_causal_strength = min_causal_strength.clamp(0.0, 1.0);
205 self
206 }
207
208 pub fn with_temporal_consistency(mut self, required: bool) -> Self {
210 self.require_temporal_consistency = required;
211 self
212 }
213
214 pub fn find_causal_chains(
226 &self,
227 cause: &EntityId,
228 effect: &EntityId,
229 max_depth: usize,
230 ) -> Result<Vec<CausalChain>> {
231 let mut chains = Vec::new();
232
233 let all_paths = self.find_all_paths(cause, effect, max_depth)?;
235
236 #[cfg(feature = "tracing")]
237 tracing::debug!(
238 cause = %cause.0,
239 effect = %effect.0,
240 paths_found = all_paths.len(),
241 "Found potential causal paths"
242 );
243
244 for path in all_paths {
246 let mut steps = Vec::new();
247
248 for i in 0..path.len() - 1 {
249 let source_id = &path[i];
250 let target_id = &path[i + 1];
251
252 if let Some(rel) = self.find_relationship(source_id, target_id) {
254 if self.is_causal_relationship(&rel) {
256 steps.push(CausalStep::from_relationship(&rel));
257 }
258 }
259 }
260
261 if !steps.is_empty() {
263 let mut chain = CausalChain {
264 cause: cause.clone(),
265 effect: effect.clone(),
266 steps,
267 total_confidence: 0.0,
268 temporal_consistency: false,
269 time_span: None,
270 };
271
272 chain.total_confidence = chain.calculate_confidence();
274 chain.temporal_consistency = chain.check_temporal_consistency();
275 chain.time_span = chain.calculate_time_span();
276
277 if self.require_temporal_consistency && !chain.temporal_consistency {
279 continue;
280 }
281
282 chains.push(chain);
283 }
284 }
285
286 chains.sort_by(|a, b| {
288 b.total_confidence
289 .partial_cmp(&a.total_confidence)
290 .unwrap_or(std::cmp::Ordering::Equal)
291 });
292
293 #[cfg(feature = "tracing")]
294 tracing::info!(causal_chains = chains.len(), "Found valid causal chains");
295
296 Ok(chains)
297 }
298
299 fn find_all_paths(
301 &self,
302 start: &EntityId,
303 end: &EntityId,
304 max_depth: usize,
305 ) -> Result<Vec<Vec<EntityId>>> {
306 let mut paths = Vec::new();
307 let mut queue: VecDeque<(EntityId, Vec<EntityId>)> = VecDeque::new();
308
309 queue.push_back((start.clone(), vec![start.clone()]));
310
311 while let Some((current, path)) = queue.pop_front() {
312 if path.len() > max_depth {
314 continue;
315 }
316
317 if current == *end {
319 paths.push(path);
320 continue;
321 }
322
323 for rel in self.graph.get_entity_relationships(¤t.0) {
325 let next = &rel.target;
326
327 if path.contains(next) {
329 continue;
330 }
331
332 if rel.confidence < self.min_confidence {
334 continue;
335 }
336
337 let mut new_path = path.clone();
338 new_path.push(next.clone());
339 queue.push_back((next.clone(), new_path));
340 }
341 }
342
343 Ok(paths)
344 }
345
346 fn find_relationship(&self, source: &EntityId, target: &EntityId) -> Option<Relationship> {
348 self.graph
349 .get_entity_relationships(&source.0)
350 .into_iter()
351 .find(|rel| rel.target == *target)
352 .cloned()
353 }
354
355 fn is_causal_relationship(&self, rel: &Relationship) -> bool {
357 if let Some(temporal_type) = rel.temporal_type {
359 if temporal_type.is_causal() {
360 if let Some(strength) = rel.causal_strength {
362 return strength >= self.min_causal_strength;
363 }
364 return true; }
366 }
367
368 let relation_lower = rel.relation_type.to_lowercase();
370 let causal_keywords = ["caused", "led_to", "resulted_in", "enabled", "triggered"];
371
372 causal_keywords.iter().any(|kw| relation_lower.contains(kw))
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Entity;

    /// Builds an [`EntityId`] from a string slice.
    fn id(raw: &str) -> EntityId {
        EntityId::new(raw.to_string())
    }

    /// Builds an `EVENT` entity with confidence 0.9.
    fn event(raw_id: &str, name: &str) -> Entity {
        Entity::new(id(raw_id), name.to_string(), "EVENT".to_string(), 0.9)
    }

    /// Builds a `CAUSED` relationship from `from` to `to` with the given confidence.
    fn caused(from: &str, to: &str, confidence: f32) -> Relationship {
        Relationship::new(id(from), id(to), "CAUSED".to_string(), confidence)
    }

    /// Creates a graph holding entities `a`, `b`, `c` with the given display names.
    fn graph_with_events(names: [&str; 3]) -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        for (raw_id, name) in ["a", "b", "c"].iter().zip(names.iter()) {
            graph.add_entity(event(raw_id, name)).unwrap();
        }
        graph
    }

    /// Graph with the chain a → b → c. Both hops are typed `Caused` and are
    /// stamped in increasing time order (100, then 200).
    fn create_test_graph_with_causal_chain() -> KnowledgeGraph {
        let mut graph = graph_with_events(["Event A", "Event B", "Event C"]);

        let first_hop = caused("a", "b", 0.8)
            .with_temporal_type(TemporalRelationType::Caused)
            .with_temporal_range(100, 100)
            .with_causal_strength(0.9);
        let second_hop = caused("b", "c", 0.85)
            .with_temporal_type(TemporalRelationType::Caused)
            .with_temporal_range(200, 200)
            .with_causal_strength(0.95);

        graph.add_relationship(first_hop).unwrap();
        graph.add_relationship(second_hop).unwrap();

        graph
    }

    #[test]
    fn test_causal_chain_creation() {
        let analyzer = CausalAnalyzer::new(Arc::new(create_test_graph_with_causal_chain()));

        let chains = analyzer.find_causal_chains(&id("a"), &id("c"), 5).unwrap();
        assert_eq!(chains.len(), 1, "Should find exactly one causal chain");

        let chain = &chains[0];
        assert_eq!(chain.steps.len(), 2, "Chain should have 2 steps (A→B, B→C)");
        assert!(
            chain.temporal_consistency,
            "Chain should be temporally consistent"
        );
        assert!(
            chain.total_confidence > 0.6,
            "Chain should have reasonable confidence"
        );
    }

    #[test]
    fn test_temporal_consistency_validation() {
        let mut graph = graph_with_events(["A", "B", "C"]);

        // The b → c hop is stamped *earlier* than a → b, so the chain runs
        // backwards in time.
        let forward = caused("a", "b", 0.8)
            .with_temporal_range(100, 100)
            .with_causal_strength(0.9);
        let backwards = caused("b", "c", 0.8)
            .with_temporal_range(50, 50)
            .with_causal_strength(0.9);
        graph.add_relationship(forward).unwrap();
        graph.add_relationship(backwards).unwrap();

        let analyzer = CausalAnalyzer::new(Arc::new(graph)).with_temporal_consistency(true);
        let chains = analyzer.find_causal_chains(&id("a"), &id("c"), 5).unwrap();

        assert_eq!(
            chains.len(),
            0,
            "Should reject temporally inconsistent chain"
        );
    }

    #[test]
    fn test_confidence_calculation() {
        let step = |from: &str, to: &str, confidence: f32, strength: f32| CausalStep {
            source: id(from),
            target: id(to),
            relation_type: "CAUSED".to_string(),
            temporal_type: Some(TemporalRelationType::Caused),
            temporal_range: None,
            confidence,
            causal_strength: Some(strength),
        };

        let chain = CausalChain {
            cause: id("a"),
            effect: id("c"),
            steps: vec![step("a", "b", 0.8, 0.9), step("b", "c", 0.9, 0.95)],
            total_confidence: 0.0,
            temporal_consistency: true,
            time_span: None,
        };

        // Expected: 0.8 * (0.5 + 0.5 * 0.9) * 0.9 * (0.5 + 0.5 * 0.95) ≈ 0.667
        let confidence = chain.calculate_confidence();

        assert!(
            confidence > 0.65 && confidence < 0.7,
            "Confidence calculation incorrect: {}",
            confidence
        );
    }
}