Skip to main content

agentic_memory/v3/indexes/
causal.rs

1//! DAG-based causal index. Tracks decision chains: what led to what.
2
3use super::{Index, IndexResult};
4use crate::v3::block::{Block, BlockContent, BlockHash, BlockType};
5use std::collections::{HashMap, HashSet, VecDeque};
6
7/// DAG-based causal index.
8pub struct CausalIndex {
9    /// Forward edges: block -> blocks it caused
10    forward: HashMap<u64, Vec<u64>>,
11
12    /// Backward edges: block -> blocks that caused it
13    backward: HashMap<u64, Vec<u64>>,
14
15    /// Decision blocks (entry points for causal queries)
16    decisions: HashSet<u64>,
17
18    /// Block hashes
19    hashes: HashMap<u64, BlockHash>,
20}
21
22impl CausalIndex {
23    pub fn new() -> Self {
24        Self {
25            forward: HashMap::new(),
26            backward: HashMap::new(),
27            decisions: HashSet::new(),
28            hashes: HashMap::new(),
29        }
30    }
31
32    /// Add causal link: cause -> effect
33    pub fn add_link(&mut self, cause: u64, effect: u64) {
34        self.forward.entry(cause).or_default().push(effect);
35        self.backward.entry(effect).or_default().push(cause);
36    }
37
38    /// Get all blocks that led to this block (ancestors)
39    pub fn get_ancestors(&self, sequence: u64, max_depth: usize) -> Vec<IndexResult> {
40        let mut result = Vec::new();
41        let mut visited = HashSet::new();
42        let mut queue = VecDeque::new();
43
44        queue.push_back((sequence, 0));
45
46        while let Some((current, depth)) = queue.pop_front() {
47            if depth > max_depth || visited.contains(&current) {
48                continue;
49            }
50            visited.insert(current);
51
52            if current != sequence {
53                if let Some(&hash) = self.hashes.get(&current) {
54                    result.push(IndexResult {
55                        block_sequence: current,
56                        block_hash: hash,
57                        score: 1.0 - (depth as f32 / max_depth as f32),
58                    });
59                }
60            }
61
62            if let Some(causes) = self.backward.get(&current) {
63                for &cause in causes {
64                    queue.push_back((cause, depth + 1));
65                }
66            }
67        }
68
69        result
70    }
71
72    /// Get all blocks that resulted from this block (descendants)
73    pub fn get_descendants(&self, sequence: u64, max_depth: usize) -> Vec<IndexResult> {
74        let mut result = Vec::new();
75        let mut visited = HashSet::new();
76        let mut queue = VecDeque::new();
77
78        queue.push_back((sequence, 0));
79
80        while let Some((current, depth)) = queue.pop_front() {
81            if depth > max_depth || visited.contains(&current) {
82                continue;
83            }
84            visited.insert(current);
85
86            if current != sequence {
87                if let Some(&hash) = self.hashes.get(&current) {
88                    result.push(IndexResult {
89                        block_sequence: current,
90                        block_hash: hash,
91                        score: 1.0 - (depth as f32 / max_depth as f32),
92                    });
93                }
94            }
95
96            if let Some(effects) = self.forward.get(&current) {
97                for &effect in effects {
98                    queue.push_back((effect, depth + 1));
99                }
100            }
101        }
102
103        result
104    }
105
106    /// Get all decision blocks
107    pub fn get_decisions(&self) -> Vec<IndexResult> {
108        self.decisions
109            .iter()
110            .filter_map(|&seq| {
111                self.hashes.get(&seq).map(|&hash| IndexResult {
112                    block_sequence: seq,
113                    block_hash: hash,
114                    score: 1.0,
115                })
116            })
117            .collect()
118    }
119
120    /// Get decision chain leading to a block
121    pub fn get_decision_chain(&self, sequence: u64) -> Vec<IndexResult> {
122        self.get_ancestors(sequence, 100)
123            .into_iter()
124            .filter(|r| self.decisions.contains(&r.block_sequence))
125            .collect()
126    }
127}
128
129impl Default for CausalIndex {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
135impl Index for CausalIndex {
136    fn index(&mut self, block: &Block) {
137        self.hashes.insert(block.sequence, block.hash);
138
139        if matches!(block.block_type, BlockType::Decision) {
140            self.decisions.insert(block.sequence);
141        }
142
143        // Extract causal links from content
144        match &block.content {
145            BlockContent::Decision {
146                evidence_blocks, ..
147            } => {
148                for evidence_hash in evidence_blocks {
149                    for (&seq, &hash) in &self.hashes {
150                        if &hash == evidence_hash {
151                            self.add_link(seq, block.sequence);
152                            break;
153                        }
154                    }
155                }
156            }
157            // Tool results are caused by tool calls
158            BlockContent::Tool { .. } if block.sequence > 0 => {
159                self.add_link(block.sequence - 1, block.sequence);
160            }
161            _ => {}
162        }
163
164        // Default: previous block causes current block
165        if block.sequence > 0 {
166            self.add_link(block.sequence - 1, block.sequence);
167        }
168    }
169
170    fn remove(&mut self, sequence: u64) {
171        self.forward.remove(&sequence);
172        self.backward.remove(&sequence);
173        self.decisions.remove(&sequence);
174        self.hashes.remove(&sequence);
175
176        for edges in self.forward.values_mut() {
177            edges.retain(|&s| s != sequence);
178        }
179        for edges in self.backward.values_mut() {
180            edges.retain(|&s| s != sequence);
181        }
182    }
183
184    fn rebuild(&mut self, blocks: impl Iterator<Item = Block>) {
185        self.forward.clear();
186        self.backward.clear();
187        self.decisions.clear();
188        self.hashes.clear();
189        for block in blocks {
190            self.index(&block);
191        }
192    }
193}