use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use uuid::Uuid;
/// A labelled node in the causal graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalNode {
    /// Unique identifier; `CausalNode::new` assigns a random v4 UUID.
    pub id: Uuid,
    /// Short human-readable name.
    pub label: String,
    /// Free-form category string (e.g. `"code_pattern"`, `"error"`); not validated.
    pub node_type: String,
    /// Longer description; empty unless set via `with_description`.
    pub description: String,
    /// Optional vector embedding. Skipped by serde: it is never serialized, and a
    /// deserialized node gets `None` here.
    #[serde(skip)]
    pub embedding: Option<Vec<f32>>,
    /// Creation timestamp; `CausalNode::new` sets it from `Utc::now()`.
    pub created_at: DateTime<Utc>,
}
impl CausalNode {
    /// Builds a node with a fresh random id, the given `label` and `node_type`,
    /// an empty description, no embedding, and `created_at` set to the current UTC time.
    pub fn new(label: impl Into<String>, node_type: impl Into<String>) -> Self {
        let id = Uuid::new_v4();
        let label = label.into();
        let node_type = node_type.into();
        Self {
            id,
            label,
            node_type,
            description: String::default(),
            embedding: Option::None,
            created_at: Utc::now(),
        }
    }

    /// Builder: returns this node with its description replaced by `description`.
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: description.into(),
            ..self
        }
    }

    /// Builder: returns this node with `embedding` attached (replacing any previous one).
    pub fn with_embedding(self, embedding: Vec<f32>) -> Self {
        Self {
            embedding: Some(embedding),
            ..self
        }
    }
}
/// A directed, weighted link from a cause node id to an effect node id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalEdge {
    /// Unique identifier for this edge; `CausalEdge::new` assigns a random v4 UUID.
    pub id: Uuid,
    /// Id of the cause node.
    pub cause: Uuid,
    /// Id of the effect node.
    pub effect: Uuid,
    /// Free-form relationship label (e.g. `"causes"`, `"prevents"`).
    pub relationship: String,
    /// Edge weight. `CausalEdge::new` clamps it to `[0.0, 1.0]`, and `add_evidence`
    /// replaces it with a running mean of observations.
    pub strength: f32,
    /// Number of observations merged into this edge; `CausalEdge::new` starts it at 1.
    pub evidence_count: u32,
}
impl CausalEdge {
    /// Creates an edge from `cause` to `effect` with a fresh random id and an
    /// `evidence_count` of 1.
    ///
    /// `strength` is clamped to `[0.0, 1.0]`. A NaN input stays NaN, as with `f32::clamp`.
    pub fn new(
        cause: Uuid,
        effect: Uuid,
        relationship: impl Into<String>,
        strength: f32,
    ) -> Self {
        Self {
            id: Uuid::new_v4(),
            cause,
            effect,
            relationship: relationship.into(),
            strength: strength.clamp(0.0, 1.0),
            evidence_count: 1,
        }
    }

    /// Folds one more observation into this edge.
    ///
    /// `strength` becomes the running arithmetic mean of all observations, using
    /// `evidence_count` (after incrementing) as the sample size.
    ///
    /// `observed_strength` is first clamped to `[0.0, 1.0]`, matching [`CausalEdge::new`].
    /// This keeps the mean in that range as long as the current `strength` already is.
    /// A NaN observation is not filtered and makes `strength` NaN.
    ///
    /// The counter saturates at `u32::MAX` rather than overflowing.
    pub fn add_evidence(&mut self, observed_strength: f32) {
        // Same bounds as the constructor; previously an out-of-range observation
        // could drag the mean outside [0, 1].
        let observed = observed_strength.clamp(0.0, 1.0);
        self.evidence_count = self.evidence_count.saturating_add(1);
        let n = self.evidence_count as f32;
        self.strength = ((n - 1.0) * self.strength + observed) / n;
    }
}
/// A many-to-many causal link: a set of cause node ids jointly related to a set of
/// effect node ids.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hyperedge {
    /// Unique identifier; `Hyperedge::new` assigns a random v4 UUID.
    pub id: Uuid,
    /// Ids of the cause nodes.
    pub causes: Vec<Uuid>,
    /// Ids of the effect nodes.
    pub effects: Vec<Uuid>,
    /// Free-form relationship label (e.g. `"causes"`).
    pub relationship: String,
    /// Edge weight; `Hyperedge::new` clamps it to `[0.0, 1.0]`.
    pub strength: f32,
    /// Longer description; `Hyperedge::new` leaves it empty.
    pub description: String,
}
impl Hyperedge {
pub fn new(
causes: Vec<Uuid>,
effects: Vec<Uuid>,
relationship: impl Into<String>,
strength: f32,
) -> Self {
Self {
id: Uuid::new_v4(),
causes,
effects,
relationship: relationship.into(),
strength: strength.clamp(0.0, 1.0),
description: String::new(),
}
}
}
/// In-memory causal graph. It holds nodes, pairwise cause-to-effect edges, and
/// many-to-many hyperedges.
pub struct CausalMemory {
    /// Nodes keyed by their `id`.
    nodes: HashMap<Uuid, CausalNode>,
    /// Pairwise edges; `add_edge` merges duplicates on `(cause, effect, relationship)`.
    edges: Vec<CausalEdge>,
    /// Hyperedges in insertion order; `add_hyperedge` does not deduplicate.
    hyperedges: Vec<Hyperedge>,
}
impl CausalMemory {
pub fn new() -> Self {
Self {
nodes: HashMap::new(),
edges: Vec::new(),
hyperedges: Vec::new(),
}
}
pub fn add_node(&mut self, node: CausalNode) -> Uuid {
let id = node.id;
self.nodes.insert(id, node);
id
}
pub fn add_edge(&mut self, edge: CausalEdge) {
if let Some(existing) = self.edges.iter_mut().find(|e| {
e.cause == edge.cause && e.effect == edge.effect && e.relationship == edge.relationship
}) {
existing.add_evidence(edge.strength);
} else {
self.edges.push(edge);
}
}
pub fn add_hyperedge(&mut self, hyperedge: Hyperedge) {
self.hyperedges.push(hyperedge);
}
pub fn get_node(&self, id: Uuid) -> Option<&CausalNode> {
self.nodes.get(&id)
}
pub fn find_causes(&self, effect: Uuid) -> Vec<(&CausalEdge, Option<&CausalNode>)> {
self.edges
.iter()
.filter(|e| e.effect == effect)
.map(|e| (e, self.nodes.get(&e.cause)))
.collect()
}
pub fn find_effects(&self, cause: Uuid) -> Vec<(&CausalEdge, Option<&CausalNode>)> {
self.edges
.iter()
.filter(|e| e.cause == cause)
.map(|e| (e, self.nodes.get(&e.effect)))
.collect()
}
pub fn trace_chain(&self, start: Uuid, max_depth: usize) -> Vec<(Uuid, usize, f32)> {
let mut visited: HashSet<Uuid> = HashSet::new();
let mut result: Vec<(Uuid, usize, f32)> = Vec::new();
let mut queue: Vec<(Uuid, usize, f32)> = vec![(start, 0, 1.0)];
while let Some((current, depth, cumulative_strength)) = queue.pop() {
if depth > max_depth || visited.contains(¤t) {
continue;
}
visited.insert(current);
if current != start {
result.push((current, depth, cumulative_strength));
}
for edge in self.edges.iter().filter(|e| e.cause == current) {
let new_strength = cumulative_strength * edge.strength;
if new_strength > 0.1 {
queue.push((edge.effect, depth + 1, new_strength));
}
}
}
result
}
pub fn find_by_relationship(&self, relationship: &str) -> Vec<&CausalEdge> {
self.edges
.iter()
.filter(|e| e.relationship == relationship)
.collect()
}
pub fn strongest_relationships(&self, limit: usize) -> Vec<&CausalEdge> {
let mut edges: Vec<_> = self.edges.iter().collect();
edges.sort_by(|a, b| {
b.strength
.partial_cmp(&a.strength)
.unwrap_or(std::cmp::Ordering::Equal)
});
edges.into_iter().take(limit).collect()
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn edge_count(&self) -> usize {
self.edges.len()
}
pub fn hyperedge_count(&self) -> usize {
self.hyperedges.len()
}
}
impl Default for CausalMemory {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts a node with the given label, type, and description; returns its id.
    fn described(memory: &mut CausalMemory, label: &str, kind: &str, description: &str) -> Uuid {
        memory.add_node(CausalNode::new(label, kind).with_description(description))
    }

    #[test]
    fn test_causal_graph() {
        let mut memory = CausalMemory::new();
        let unwrap_id = described(
            &mut memory,
            "Using unwrap()",
            "code_pattern",
            "Calling .unwrap() on Option/Result",
        );
        let panic_id = described(
            &mut memory,
            "Runtime panic",
            "error",
            "Program crashes at runtime",
        );
        let option_handling_id = described(
            &mut memory,
            "Proper Option handling",
            "code_pattern",
            "Using match or if-let",
        );
        let reliability_id = described(
            &mut memory,
            "Code reliability",
            "quality",
            "Code works correctly in edge cases",
        );

        let links = [
            (unwrap_id, panic_id, "causes", 0.8),
            (option_handling_id, reliability_id, "improves", 0.9),
            (option_handling_id, panic_id, "prevents", 0.95),
        ];
        for (cause, effect, relationship, strength) in links {
            memory.add_edge(CausalEdge::new(cause, effect, relationship, strength));
        }

        let panic_causes = memory.find_causes(panic_id);
        assert!(!panic_causes.is_empty());
        let direct_cause = panic_causes
            .iter()
            .find(|(e, _)| e.relationship == "causes");
        assert!(direct_cause.is_some());
        assert!(direct_cause.unwrap().1.unwrap().label.contains("unwrap"));

        let handling_effects = memory.find_effects(option_handling_id);
        assert_eq!(handling_effects.len(), 2);
    }

    #[test]
    fn test_causal_chain() {
        let mut memory = CausalMemory::new();
        let [a, b, c, d] =
            ["A", "B", "C", "D"].map(|label| memory.add_node(CausalNode::new(label, "concept")));
        for (cause, effect, strength) in [(a, b, 0.9), (b, c, 0.8), (c, d, 0.7)] {
            memory.add_edge(CausalEdge::new(cause, effect, "causes", strength));
        }

        let chain = memory.trace_chain(a, 10);
        assert_eq!(chain.len(), 3);
        let b_entry = chain.iter().find(|(id, _, _)| *id == b).unwrap();
        assert_eq!(b_entry.1, 1);
        let d_entry = chain.iter().find(|(id, _, _)| *id == d).unwrap();
        assert_eq!(d_entry.1, 3);
        // Strength can only decay along a chain of weights in [0, 1].
        assert!(d_entry.2 < b_entry.2);
    }

    #[test]
    fn test_evidence_accumulation() {
        let mut memory = CausalMemory::new();
        let (cause, effect) = (Uuid::new_v4(), Uuid::new_v4());
        // Three observations of the same (cause, effect, relationship) collapse into one edge.
        for strength in [0.8, 0.9, 0.85] {
            memory.add_edge(CausalEdge::new(cause, effect, "causes", strength));
        }

        assert_eq!(memory.edge_count(), 1);
        let edge = &memory.edges[0];
        assert_eq!(edge.evidence_count, 3);
        assert!(edge.strength > 0.8 && edge.strength < 0.9);
    }

    #[test]
    fn test_hyperedge() {
        let mut memory = CausalMemory::new();
        let [fuel, spark, oxygen, fire] = [
            ("Fuel", "resource"),
            ("Spark", "event"),
            ("Oxygen", "resource"),
            ("Fire", "outcome"),
        ]
        .map(|(label, kind)| memory.add_node(CausalNode::new(label, kind)));

        memory.add_hyperedge(Hyperedge::new(
            vec![fuel, spark, oxygen],
            vec![fire],
            "causes",
            0.99,
        ));

        assert_eq!(memory.hyperedge_count(), 1);
        assert_eq!(memory.hyperedges[0].causes.len(), 3);
    }
}