// nodedb_graph/csr/memory.rs

//! Memory estimation, access tracking, and prefetching for the CSR index.
//!
//! Provides RAM usage estimation for SpillController integration,
//! hot/cold node identification via access counting, and CPU cache
//! prefetch hints for BFS traversal optimization.

use super::index::CsrIndex;

9impl CsrIndex {
10    /// Record a node access (called during traversals for hot/cold tracking).
11    pub fn record_access(&self, node_id: u32) {
12        let idx = node_id as usize;
13        if idx < self.access_counts.len() {
14            let c = &self.access_counts[idx];
15            c.set(c.get().saturating_add(1));
16        }
17    }
18
19    /// Get nodes with access count below `threshold` (cold nodes).
20    pub fn cold_nodes(&self, threshold: u32) -> Vec<u32> {
21        self.access_counts
22            .iter()
23            .enumerate()
24            .filter(|(_, c)| c.get() <= threshold)
25            .map(|(i, _)| i as u32)
26            .collect()
27    }
28
29    /// Number of hot nodes (access count > 0).
30    pub fn hot_node_count(&self) -> usize {
31        self.access_counts.iter().filter(|c| c.get() > 0).count()
32    }
33
34    /// Current query epoch (incremented per traversal call).
35    pub fn query_epoch(&self) -> u64 {
36        self.query_epoch
37    }
38
39    /// Reset access counters (called during compaction or periodically).
40    pub fn reset_access_counts(&mut self) {
41        self.access_counts.iter().for_each(|c| c.set(0));
42        self.query_epoch = 0;
43    }
44
45    /// Predictive prefetch: hint the OS to load a node's adjacency data
46    /// into the page cache before the traversal touches it.
47    ///
48    /// For in-memory dense CSR, this prefetches the cache line containing
49    /// the node's offset/target entries. For future mmap'd cold segments,
50    /// this would call `madvise(MADV_WILLNEED)` on the relevant pages.
51    ///
52    /// Called during BFS planning: when adding nodes to the next frontier,
53    /// prefetch their neighbors' data so it's resident when the BFS loop
54    /// reaches them on the next iteration.
55    #[inline]
56    pub fn prefetch_node(&self, node_id: u32) {
57        let idx = node_id as usize;
58        if idx + 1 < self.out_offsets.len() {
59            // SAFETY: We're just hinting the CPU to load this address.
60            // The offset is within bounds (checked above). This is a
61            // performance hint, not a correctness requirement.
62            #[cfg(target_arch = "x86_64")]
63            unsafe {
64                let ptr = self.out_offsets.as_ptr().add(idx) as *const u8;
65                std::arch::x86_64::_mm_prefetch(ptr as *const i8, std::arch::x86_64::_MM_HINT_T0);
66            }
67        }
68    }
69
70    /// Prefetch a batch of nodes (called during BFS frontier expansion).
71    pub fn prefetch_batch(&self, node_ids: &[u32]) {
72        for &id in node_ids {
73            self.prefetch_node(id);
74        }
75    }
76
77    /// Evaluate graph memory pressure and return promotion/demotion hints.
78    ///
79    /// Uses SpillController-compatible thresholds (90%/75% hysteresis):
80    /// - Above spill threshold: demote cold nodes (spill to potential mmap)
81    /// - Below restore threshold: promote warm nodes back to hot RAM
82    ///
83    /// `utilization` = estimated memory usage as percentage (0-100).
84    /// Returns `(nodes_to_demote, nodes_to_promote)` counts.
85    pub fn evaluate_memory_pressure(
86        &self,
87        utilization: u8,
88        spill_threshold: u8,
89        restore_threshold: u8,
90    ) -> (usize, usize) {
91        if utilization >= spill_threshold {
92            // Above spill threshold: identify cold nodes to demote.
93            let cold = self.cold_nodes(0);
94            (cold.len(), 0)
95        } else if utilization <= restore_threshold {
96            // Below restore threshold: all nodes can stay hot.
97            (0, self.node_count())
98        } else {
99            // In hysteresis band: no action.
100            (0, 0)
101        }
102    }
103
104    /// Estimated memory usage of the dense CSR in bytes.
105    ///
106    /// Used for SpillController utilization calculation.
107    pub fn estimated_memory_bytes(&self) -> usize {
108        let offsets = (self.out_offsets.len() + self.in_offsets.len()) * 4;
109        let targets = (self.out_targets.len() + self.in_targets.len()) * 4;
110        let labels = (self.out_labels.len() + self.in_labels.len()) * 2;
111        let weights = self.out_weights.as_ref().map_or(0, |w| w.len() * 8)
112            + self.in_weights.as_ref().map_or(0, |w| w.len() * 8);
113        let buffer: usize = self
114            .buffer_out
115            .iter()
116            .chain(self.buffer_in.iter())
117            .map(|b| b.len() * 6) // (u16 + u32) per entry
118            .sum();
119        let buffer_weights: usize = self
120            .buffer_out_weights
121            .iter()
122            .chain(self.buffer_in_weights.iter())
123            .map(|b| b.len() * 8)
124            .sum();
125        let interning = self.id_to_node.iter().map(|s| s.len() + 24).sum::<usize>()
126            + self.id_to_label.iter().map(|s| s.len() + 24).sum::<usize>();
127        offsets + targets + labels + weights + buffer + buffer_weights + interning
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn access_tracking() {
        let mut graph = CsrIndex::new();
        graph.add_edge("a", "L", "b");
        graph.add_edge("b", "L", "c");

        let a = graph.node_id("a").unwrap();
        assert_eq!(graph.hot_node_count(), 0);

        // Two hits on the same node still make exactly one hot node.
        for _ in 0..2 {
            graph.record_access(a);
        }
        assert_eq!(graph.hot_node_count(), 1);

        let never_touched = graph.cold_nodes(0);
        assert!(!never_touched.contains(&a));
        assert_eq!(never_touched.len(), 2); // "b" and "c" were never accessed

        graph.reset_access_counts();
        assert_eq!(graph.hot_node_count(), 0);
    }

    #[test]
    fn memory_estimation_includes_weights() {
        let mut plain = CsrIndex::new();
        plain.add_edge("a", "L", "b");

        let mut with_weights = CsrIndex::new();
        with_weights.add_edge_weighted("a", "L", "b", 5.0);

        // Storing weights cannot shrink the estimate.
        assert!(with_weights.estimated_memory_bytes() >= plain.estimated_memory_bytes());
    }

    #[test]
    fn evaluate_memory_pressure_hysteresis() {
        let mut graph = CsrIndex::new();
        graph.add_edge("a", "L", "b");

        // At or above the spill threshold: cold nodes become demotion candidates.
        let (demote, promote) = graph.evaluate_memory_pressure(95, 90, 75);
        assert!(demote > 0);
        assert_eq!(promote, 0);

        // At or below the restore threshold: everything may be promoted.
        let (demote, promote) = graph.evaluate_memory_pressure(60, 90, 75);
        assert_eq!(demote, 0);
        assert!(promote > 0);

        // Inside the hysteresis band: hold steady.
        let (demote, promote) = graph.evaluate_memory_pressure(80, 90, 75);
        assert_eq!((demote, promote), (0, 0));
    }
}