// Source: converge_knowledge/agentic/causal.rs

//! Causal Memory - Hypergraph Relationships
//!
//! Implements a causal knowledge graph where agents can:
//! 1. Record cause-effect relationships
//! 2. Build hypergraph structures (edges connecting multiple nodes)
//! 3. Query causal chains and relationships
//! 4. Reason about consequences of actions
8
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use uuid::Uuid;
13
14/// A node in the causal graph.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct CausalNode {
17    /// Unique identifier.
18    pub id: Uuid,
19
20    /// Node label/name.
21    pub label: String,
22
23    /// Node type/category.
24    pub node_type: String,
25
26    /// Description.
27    pub description: String,
28
29    /// Embedding for similarity search.
30    #[serde(skip)]
31    pub embedding: Option<Vec<f32>>,
32
33    /// When this node was created.
34    pub created_at: DateTime<Utc>,
35}
36
37impl CausalNode {
38    /// Create a new causal node.
39    pub fn new(label: impl Into<String>, node_type: impl Into<String>) -> Self {
40        Self {
41            id: Uuid::new_v4(),
42            label: label.into(),
43            node_type: node_type.into(),
44            description: String::new(),
45            embedding: None,
46            created_at: Utc::now(),
47        }
48    }
49
50    /// Add description.
51    pub fn with_description(mut self, description: impl Into<String>) -> Self {
52        self.description = description.into();
53        self
54    }
55
56    /// Add embedding.
57    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
58        self.embedding = Some(embedding);
59        self
60    }
61}
62
63/// A directed edge representing causation.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct CausalEdge {
66    /// Unique identifier.
67    pub id: Uuid,
68
69    /// Source node (cause).
70    pub cause: Uuid,
71
72    /// Target node (effect).
73    pub effect: Uuid,
74
75    /// Relationship type (e.g., "causes", "prevents", "enables").
76    pub relationship: String,
77
78    /// Strength of the causal relationship (0.0 to 1.0).
79    pub strength: f32,
80
81    /// Number of observations supporting this edge.
82    pub evidence_count: u32,
83}
84
85impl CausalEdge {
86    /// Create a new causal edge.
87    pub fn new(cause: Uuid, effect: Uuid, relationship: impl Into<String>, strength: f32) -> Self {
88        Self {
89            id: Uuid::new_v4(),
90            cause,
91            effect,
92            relationship: relationship.into(),
93            strength: strength.clamp(0.0, 1.0),
94            evidence_count: 1,
95        }
96    }
97
98    /// Add evidence (increases count and adjusts strength).
99    pub fn add_evidence(&mut self, observed_strength: f32) {
100        self.evidence_count += 1;
101        // Bayesian update of strength
102        let n = self.evidence_count as f32;
103        self.strength = ((n - 1.0) * self.strength + observed_strength) / n;
104    }
105}
106
107/// A hyperedge connecting multiple nodes.
108///
109/// Unlike regular edges which connect two nodes, hyperedges can
110/// represent complex relationships like "A AND B cause C".
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct Hyperedge {
113    /// Unique identifier.
114    pub id: Uuid,
115
116    /// Source nodes (all must be present for the effect).
117    pub causes: Vec<Uuid>,
118
119    /// Target nodes (effects).
120    pub effects: Vec<Uuid>,
121
122    /// Relationship type.
123    pub relationship: String,
124
125    /// Strength.
126    pub strength: f32,
127
128    /// Description.
129    pub description: String,
130}
131
132impl Hyperedge {
133    /// Create a new hyperedge.
134    pub fn new(
135        causes: Vec<Uuid>,
136        effects: Vec<Uuid>,
137        relationship: impl Into<String>,
138        strength: f32,
139    ) -> Self {
140        Self {
141            id: Uuid::new_v4(),
142            causes,
143            effects,
144            relationship: relationship.into(),
145            strength: strength.clamp(0.0, 1.0),
146            description: String::new(),
147        }
148    }
149}
150
151/// Causal memory store.
152pub struct CausalMemory {
153    nodes: HashMap<Uuid, CausalNode>,
154    edges: Vec<CausalEdge>,
155    hyperedges: Vec<Hyperedge>,
156}
157
158impl CausalMemory {
159    /// Create a new causal memory.
160    pub fn new() -> Self {
161        Self {
162            nodes: HashMap::new(),
163            edges: Vec::new(),
164            hyperedges: Vec::new(),
165        }
166    }
167
168    /// Add a node.
169    pub fn add_node(&mut self, node: CausalNode) -> Uuid {
170        let id = node.id;
171        self.nodes.insert(id, node);
172        id
173    }
174
175    /// Add an edge.
176    pub fn add_edge(&mut self, edge: CausalEdge) {
177        // Check if similar edge exists
178        if let Some(existing) = self.edges.iter_mut().find(|e| {
179            e.cause == edge.cause && e.effect == edge.effect && e.relationship == edge.relationship
180        }) {
181            existing.add_evidence(edge.strength);
182        } else {
183            self.edges.push(edge);
184        }
185    }
186
187    /// Add a hyperedge.
188    pub fn add_hyperedge(&mut self, hyperedge: Hyperedge) {
189        self.hyperedges.push(hyperedge);
190    }
191
192    /// Get a node by ID.
193    pub fn get_node(&self, id: Uuid) -> Option<&CausalNode> {
194        self.nodes.get(&id)
195    }
196
197    /// Find causes of a given effect.
198    pub fn find_causes(&self, effect: Uuid) -> Vec<(&CausalEdge, Option<&CausalNode>)> {
199        self.edges
200            .iter()
201            .filter(|e| e.effect == effect)
202            .map(|e| (e, self.nodes.get(&e.cause)))
203            .collect()
204    }
205
206    /// Find effects of a given cause.
207    pub fn find_effects(&self, cause: Uuid) -> Vec<(&CausalEdge, Option<&CausalNode>)> {
208        self.edges
209            .iter()
210            .filter(|e| e.cause == cause)
211            .map(|e| (e, self.nodes.get(&e.effect)))
212            .collect()
213    }
214
215    /// Trace causal chain from cause to all reachable effects.
216    pub fn trace_chain(&self, start: Uuid, max_depth: usize) -> Vec<(Uuid, usize, f32)> {
217        let mut visited: HashSet<Uuid> = HashSet::new();
218        let mut result: Vec<(Uuid, usize, f32)> = Vec::new();
219        let mut queue: Vec<(Uuid, usize, f32)> = vec![(start, 0, 1.0)];
220
221        while let Some((current, depth, cumulative_strength)) = queue.pop() {
222            if depth > max_depth || visited.contains(&current) {
223                continue;
224            }
225            visited.insert(current);
226
227            if current != start {
228                result.push((current, depth, cumulative_strength));
229            }
230
231            // Find all effects
232            for edge in self.edges.iter().filter(|e| e.cause == current) {
233                let new_strength = cumulative_strength * edge.strength;
234                if new_strength > 0.1 {
235                    // Prune weak chains
236                    queue.push((edge.effect, depth + 1, new_strength));
237                }
238            }
239        }
240
241        result
242    }
243
244    /// Find all edges of a specific relationship type.
245    pub fn find_by_relationship(&self, relationship: &str) -> Vec<&CausalEdge> {
246        self.edges
247            .iter()
248            .filter(|e| e.relationship == relationship)
249            .collect()
250    }
251
252    /// Get strongest causal relationships.
253    pub fn strongest_relationships(&self, limit: usize) -> Vec<&CausalEdge> {
254        let mut edges: Vec<_> = self.edges.iter().collect();
255        edges.sort_by(|a, b| {
256            b.strength
257                .partial_cmp(&a.strength)
258                .unwrap_or(std::cmp::Ordering::Equal)
259        });
260        edges.into_iter().take(limit).collect()
261    }
262
263    /// Node count.
264    pub fn node_count(&self) -> usize {
265        self.nodes.len()
266    }
267
268    /// Edge count.
269    pub fn edge_count(&self) -> usize {
270        self.edges.len()
271    }
272
273    /// Hyperedge count.
274    pub fn hyperedge_count(&self) -> usize {
275        self.hyperedges.len()
276    }
277}
278
279impl Default for CausalMemory {
280    fn default() -> Self {
281        Self::new()
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small graph about `unwrap()` and panics, then checks that
    /// cause and effect lookups return the expected edges.
    #[test]
    fn test_causal_graph() {
        let mut memory = CausalMemory::new();

        // Four concept nodes.
        let unwrap_node = CausalNode::new("Using unwrap()", "code_pattern")
            .with_description("Calling .unwrap() on Option/Result");
        let unwrap_id = memory.add_node(unwrap_node);

        let panic_node =
            CausalNode::new("Runtime panic", "error").with_description("Program crashes at runtime");
        let panic_id = memory.add_node(panic_node);

        let handling_node = CausalNode::new("Proper Option handling", "code_pattern")
            .with_description("Using match or if-let");
        let option_handling_id = memory.add_node(handling_node);

        let reliability_node = CausalNode::new("Code reliability", "quality")
            .with_description("Code works correctly in edge cases");
        let reliability_id = memory.add_node(reliability_node);

        // Three distinct relationships between them.
        let relationships = [
            (unwrap_id, panic_id, "causes", 0.8),
            (option_handling_id, reliability_id, "improves", 0.9),
            (option_handling_id, panic_id, "prevents", 0.95),
        ];
        for (cause, effect, relationship, strength) in relationships {
            memory.add_edge(CausalEdge::new(cause, effect, relationship, strength));
        }

        // Two edges point at the panic node ("causes" and "prevents"); pick
        // out the "causes" one and check which node it originates from.
        let panic_causes = memory.find_causes(panic_id);
        assert!(!panic_causes.is_empty());

        let direct_cause = panic_causes
            .iter()
            .find(|(edge, _)| edge.relationship == "causes");
        assert!(direct_cause.is_some());

        let (_, cause_node) = direct_cause.unwrap();
        assert!(cause_node.unwrap().label.contains("unwrap"));

        // Proper handling has two outgoing edges: "improves" and "prevents".
        let handling_effects = memory.find_effects(option_handling_id);
        assert_eq!(handling_effects.len(), 2);
    }

    /// Builds the linear chain A → B → C → D and checks that tracing from A
    /// reports B, C and D with the right depths and decaying strength.
    #[test]
    fn test_causal_chain() {
        let mut memory = CausalMemory::new();

        let a = memory.add_node(CausalNode::new("A", "concept"));
        let b = memory.add_node(CausalNode::new("B", "concept"));
        let c = memory.add_node(CausalNode::new("C", "concept"));
        let d = memory.add_node(CausalNode::new("D", "concept"));

        for (cause, effect, strength) in [(a, b, 0.9), (b, c, 0.8), (c, d, 0.7)] {
            memory.add_edge(CausalEdge::new(cause, effect, "causes", strength));
        }

        let chain = memory.trace_chain(a, 10);

        // Everything except the start node is reported.
        assert_eq!(chain.len(), 3);

        let &(_, b_depth, b_strength) = chain.iter().find(|(id, _, _)| *id == b).unwrap();
        assert_eq!(b_depth, 1);

        let &(_, d_depth, d_strength) = chain.iter().find(|(id, _, _)| *id == d).unwrap();
        assert_eq!(d_depth, 3);

        // Strength multiplies along the chain, so the far end is weaker.
        assert!(d_strength < b_strength);
    }

    /// Adds the same relationship three times and checks that the
    /// observations are merged into one edge whose strength is their running
    /// average.
    #[test]
    fn test_evidence_accumulation() {
        let mut memory = CausalMemory::new();

        let cause = Uuid::new_v4();
        let effect = Uuid::new_v4();

        // Three observations of the same (cause, effect, relationship).
        for observed in [0.8, 0.9, 0.85] {
            memory.add_edge(CausalEdge::new(cause, effect, "causes", observed));
        }

        // They collapse into a single edge...
        assert_eq!(memory.edge_count(), 1);

        // ...that remembers how many observations it has absorbed...
        let edge = &memory.edges[0];
        assert_eq!(edge.evidence_count, 3);

        // ...and whose strength lies strictly between the extremes.
        assert!(edge.strength > 0.8 && edge.strength < 0.9);
    }

    /// Stores the conjunctive relationship (Fuel AND Spark AND Oxygen) → Fire
    /// as a single hyperedge.
    #[test]
    fn test_hyperedge() {
        let mut memory = CausalMemory::new();

        let fuel = memory.add_node(CausalNode::new("Fuel", "resource"));
        let spark = memory.add_node(CausalNode::new("Spark", "event"));
        let oxygen = memory.add_node(CausalNode::new("Oxygen", "resource"));
        let fire = memory.add_node(CausalNode::new("Fire", "outcome"));

        let combustion = Hyperedge::new(vec![fuel, spark, oxygen], vec![fire], "causes", 0.99);
        memory.add_hyperedge(combustion);

        assert_eq!(memory.hyperedge_count(), 1);
        assert_eq!(memory.hyperedges[0].causes.len(), 3);
    }
}