// converge_knowledge/agentic/causal.rs
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use uuid::Uuid;

14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct CausalNode {
17 pub id: Uuid,
19
20 pub label: String,
22
23 pub node_type: String,
25
26 pub description: String,
28
29 #[serde(skip)]
31 pub embedding: Option<Vec<f32>>,
32
33 pub created_at: DateTime<Utc>,
35}
36
37impl CausalNode {
38 pub fn new(label: impl Into<String>, node_type: impl Into<String>) -> Self {
40 Self {
41 id: Uuid::new_v4(),
42 label: label.into(),
43 node_type: node_type.into(),
44 description: String::new(),
45 embedding: None,
46 created_at: Utc::now(),
47 }
48 }
49
50 pub fn with_description(mut self, description: impl Into<String>) -> Self {
52 self.description = description.into();
53 self
54 }
55
56 pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
58 self.embedding = Some(embedding);
59 self
60 }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct CausalEdge {
66 pub id: Uuid,
68
69 pub cause: Uuid,
71
72 pub effect: Uuid,
74
75 pub relationship: String,
77
78 pub strength: f32,
80
81 pub evidence_count: u32,
83}
84
85impl CausalEdge {
86 pub fn new(cause: Uuid, effect: Uuid, relationship: impl Into<String>, strength: f32) -> Self {
88 Self {
89 id: Uuid::new_v4(),
90 cause,
91 effect,
92 relationship: relationship.into(),
93 strength: strength.clamp(0.0, 1.0),
94 evidence_count: 1,
95 }
96 }
97
98 pub fn add_evidence(&mut self, observed_strength: f32) {
100 self.evidence_count += 1;
101 let n = self.evidence_count as f32;
103 self.strength = ((n - 1.0) * self.strength + observed_strength) / n;
104 }
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct Hyperedge {
113 pub id: Uuid,
115
116 pub causes: Vec<Uuid>,
118
119 pub effects: Vec<Uuid>,
121
122 pub relationship: String,
124
125 pub strength: f32,
127
128 pub description: String,
130}
131
132impl Hyperedge {
133 pub fn new(
135 causes: Vec<Uuid>,
136 effects: Vec<Uuid>,
137 relationship: impl Into<String>,
138 strength: f32,
139 ) -> Self {
140 Self {
141 id: Uuid::new_v4(),
142 causes,
143 effects,
144 relationship: relationship.into(),
145 strength: strength.clamp(0.0, 1.0),
146 description: String::new(),
147 }
148 }
149}
150
151pub struct CausalMemory {
153 nodes: HashMap<Uuid, CausalNode>,
154 edges: Vec<CausalEdge>,
155 hyperedges: Vec<Hyperedge>,
156}
157
158impl CausalMemory {
159 pub fn new() -> Self {
161 Self {
162 nodes: HashMap::new(),
163 edges: Vec::new(),
164 hyperedges: Vec::new(),
165 }
166 }
167
168 pub fn add_node(&mut self, node: CausalNode) -> Uuid {
170 let id = node.id;
171 self.nodes.insert(id, node);
172 id
173 }
174
175 pub fn add_edge(&mut self, edge: CausalEdge) {
177 if let Some(existing) = self.edges.iter_mut().find(|e| {
179 e.cause == edge.cause && e.effect == edge.effect && e.relationship == edge.relationship
180 }) {
181 existing.add_evidence(edge.strength);
182 } else {
183 self.edges.push(edge);
184 }
185 }
186
187 pub fn add_hyperedge(&mut self, hyperedge: Hyperedge) {
189 self.hyperedges.push(hyperedge);
190 }
191
192 pub fn get_node(&self, id: Uuid) -> Option<&CausalNode> {
194 self.nodes.get(&id)
195 }
196
197 pub fn find_causes(&self, effect: Uuid) -> Vec<(&CausalEdge, Option<&CausalNode>)> {
199 self.edges
200 .iter()
201 .filter(|e| e.effect == effect)
202 .map(|e| (e, self.nodes.get(&e.cause)))
203 .collect()
204 }
205
206 pub fn find_effects(&self, cause: Uuid) -> Vec<(&CausalEdge, Option<&CausalNode>)> {
208 self.edges
209 .iter()
210 .filter(|e| e.cause == cause)
211 .map(|e| (e, self.nodes.get(&e.effect)))
212 .collect()
213 }
214
215 pub fn trace_chain(&self, start: Uuid, max_depth: usize) -> Vec<(Uuid, usize, f32)> {
217 let mut visited: HashSet<Uuid> = HashSet::new();
218 let mut result: Vec<(Uuid, usize, f32)> = Vec::new();
219 let mut queue: Vec<(Uuid, usize, f32)> = vec![(start, 0, 1.0)];
220
221 while let Some((current, depth, cumulative_strength)) = queue.pop() {
222 if depth > max_depth || visited.contains(¤t) {
223 continue;
224 }
225 visited.insert(current);
226
227 if current != start {
228 result.push((current, depth, cumulative_strength));
229 }
230
231 for edge in self.edges.iter().filter(|e| e.cause == current) {
233 let new_strength = cumulative_strength * edge.strength;
234 if new_strength > 0.1 {
235 queue.push((edge.effect, depth + 1, new_strength));
237 }
238 }
239 }
240
241 result
242 }
243
244 pub fn find_by_relationship(&self, relationship: &str) -> Vec<&CausalEdge> {
246 self.edges
247 .iter()
248 .filter(|e| e.relationship == relationship)
249 .collect()
250 }
251
252 pub fn strongest_relationships(&self, limit: usize) -> Vec<&CausalEdge> {
254 let mut edges: Vec<_> = self.edges.iter().collect();
255 edges.sort_by(|a, b| {
256 b.strength
257 .partial_cmp(&a.strength)
258 .unwrap_or(std::cmp::Ordering::Equal)
259 });
260 edges.into_iter().take(limit).collect()
261 }
262
263 pub fn node_count(&self) -> usize {
265 self.nodes.len()
266 }
267
268 pub fn edge_count(&self) -> usize {
270 self.edges.len()
271 }
272
273 pub fn hyperedge_count(&self) -> usize {
275 self.hyperedges.len()
276 }
277}
278
279impl Default for CausalMemory {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small unwrap / panic / Option-handling graph and checks direct
    /// cause and effect lookups.
    #[test]
    fn test_causal_graph() {
        let mut graph = CausalMemory::new();

        let unwrap_node = graph.add_node(
            CausalNode::new("Using unwrap()", "code_pattern")
                .with_description("Calling .unwrap() on Option/Result"),
        );
        let panic_node = graph.add_node(
            CausalNode::new("Runtime panic", "error")
                .with_description("Program crashes at runtime"),
        );
        let handling_node = graph.add_node(
            CausalNode::new("Proper Option handling", "code_pattern")
                .with_description("Using match or if-let"),
        );
        let reliability_node = graph.add_node(
            CausalNode::new("Code reliability", "quality")
                .with_description("Code works correctly in edge cases"),
        );

        for edge in [
            CausalEdge::new(unwrap_node, panic_node, "causes", 0.8),
            CausalEdge::new(handling_node, reliability_node, "improves", 0.9),
            CausalEdge::new(handling_node, panic_node, "prevents", 0.95),
        ] {
            graph.add_edge(edge);
        }

        // The "causes" edge into the panic node must lead back to the unwrap node.
        let causes_of_panic = graph.find_causes(panic_node);
        assert!(!causes_of_panic.is_empty());
        let direct = causes_of_panic
            .iter()
            .find(|(edge, _)| edge.relationship == "causes");
        assert!(direct.is_some());
        let (_, cause) = direct.unwrap();
        assert!(cause.unwrap().label.contains("unwrap"));

        // Option handling has two outgoing edges: "improves" and "prevents".
        assert_eq!(graph.find_effects(handling_node).len(), 2);
    }

    /// A -> B -> C -> D: every node is reached, depth counts hops, and strength
    /// decays multiplicatively along the chain.
    #[test]
    fn test_causal_chain() {
        let mut graph = CausalMemory::new();

        let ids: Vec<Uuid> = ["A", "B", "C", "D"]
            .iter()
            .map(|label| graph.add_node(CausalNode::new(*label, "concept")))
            .collect();
        let (a, b, c, d) = (ids[0], ids[1], ids[2], ids[3]);

        for (cause, effect, strength) in [(a, b, 0.9), (b, c, 0.8), (c, d, 0.7)] {
            graph.add_edge(CausalEdge::new(cause, effect, "causes", strength));
        }

        let chain = graph.trace_chain(a, 10);
        assert_eq!(chain.len(), 3);

        let entry_for = |target: Uuid| *chain.iter().find(|(id, _, _)| *id == target).unwrap();
        let b_entry = entry_for(b);
        assert_eq!(b_entry.1, 1);
        let d_entry = entry_for(d);
        assert_eq!(d_entry.1, 3);
        assert!(d_entry.2 < b_entry.2);
    }

    /// Repeated observations of the same edge merge into one averaged edge.
    #[test]
    fn test_evidence_accumulation() {
        let mut graph = CausalMemory::new();

        // The endpoints do not need to exist as nodes for an edge to be recorded.
        let cause = Uuid::new_v4();
        let effect = Uuid::new_v4();

        for observed in [0.8, 0.9, 0.85] {
            graph.add_edge(CausalEdge::new(cause, effect, "causes", observed));
        }

        assert_eq!(graph.edge_count(), 1);
        let merged = &graph.edges[0];
        assert_eq!(merged.evidence_count, 3);
        assert!(merged.strength > 0.8 && merged.strength < 0.9);
    }

    /// A single hyperedge with three causes and one effect.
    #[test]
    fn test_hyperedge() {
        let mut graph = CausalMemory::new();

        let fuel = graph.add_node(CausalNode::new("Fuel", "resource"));
        let spark = graph.add_node(CausalNode::new("Spark", "event"));
        let oxygen = graph.add_node(CausalNode::new("Oxygen", "resource"));
        let fire = graph.add_node(CausalNode::new("Fire", "outcome"));

        let combustion = Hyperedge::new(vec![fuel, spark, oxygen], vec![fire], "causes", 0.99);
        graph.add_hyperedge(combustion);

        assert_eq!(graph.hyperedge_count(), 1);
        assert_eq!(graph.hyperedges[0].causes.len(), 3);
    }
}