Skip to main content

nodedb_graph/csr/
memory.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Memory estimation, access tracking, and prefetching for the CSR index.
4//!
5//! Provides RAM usage estimation for SpillController integration,
6//! hot/cold node identification via access counting, and CPU cache
7//! prefetch hints for BFS traversal optimization.
8
9use super::index::CsrIndex;
10
11impl CsrIndex {
12    /// Record a node access (called during traversals for hot/cold tracking).
13    pub fn record_access(&self, node_id: u32) {
14        let idx = node_id as usize;
15        if idx < self.access_counts.len() {
16            let c = &self.access_counts[idx];
17            c.set(c.get().saturating_add(1));
18        }
19    }
20
21    /// Get nodes with access count below `threshold` (cold nodes).
22    pub fn cold_nodes(&self, threshold: u32) -> Vec<u32> {
23        self.access_counts
24            .iter()
25            .enumerate()
26            .filter(|(_, c)| c.get() <= threshold)
27            .map(|(i, _)| i as u32)
28            .collect()
29    }
30
31    /// Number of hot nodes (access count > 0).
32    pub fn hot_node_count(&self) -> usize {
33        self.access_counts.iter().filter(|c| c.get() > 0).count()
34    }
35
36    /// Current query epoch (incremented per traversal call).
37    pub fn query_epoch(&self) -> u64 {
38        self.query_epoch
39    }
40
41    /// Reset access counters (called during compaction or periodically).
42    pub fn reset_access_counts(&mut self) {
43        self.access_counts.iter().for_each(|c| c.set(0));
44        self.query_epoch = 0;
45    }
46
47    /// Predictive prefetch: hint the OS to load a node's adjacency data
48    /// into the page cache before the traversal touches it.
49    ///
50    /// For in-memory dense CSR, this prefetches the cache line containing
51    /// the node's offset/target entries. For future mmap'd cold segments,
52    /// this would call `madvise(MADV_WILLNEED)` on the relevant pages.
53    ///
54    /// Called during BFS planning: when adding nodes to the next frontier,
55    /// prefetch their neighbors' data so it's resident when the BFS loop
56    /// reaches them on the next iteration.
57    #[inline]
58    pub fn prefetch_node(&self, node_id: u32) {
59        let idx = node_id as usize;
60        if idx + 1 < self.out_offsets.len() {
61            // SAFETY: We're just hinting the CPU to load this address.
62            // The offset is within bounds (checked above). This is a
63            // performance hint, not a correctness requirement.
64            #[cfg(target_arch = "x86_64")]
65            unsafe {
66                let ptr = self.out_offsets.as_ptr().add(idx) as *const u8;
67                std::arch::x86_64::_mm_prefetch(ptr as *const i8, std::arch::x86_64::_MM_HINT_T0);
68            }
69        }
70    }
71
72    /// Prefetch a batch of nodes (called during BFS frontier expansion).
73    pub fn prefetch_batch(&self, node_ids: &[u32]) {
74        for &id in node_ids {
75            self.prefetch_node(id);
76        }
77    }
78
79    /// Evaluate graph memory pressure and return promotion/demotion hints.
80    ///
81    /// Uses SpillController-compatible thresholds (90%/75% hysteresis):
82    /// - Above spill threshold: demote cold nodes (spill to potential mmap)
83    /// - Below restore threshold: promote warm nodes back to hot RAM
84    ///
85    /// `utilization` = estimated memory usage as percentage (0-100).
86    /// Returns `(nodes_to_demote, nodes_to_promote)` counts.
87    pub fn evaluate_memory_pressure(
88        &self,
89        utilization: u8,
90        spill_threshold: u8,
91        restore_threshold: u8,
92    ) -> (usize, usize) {
93        if utilization >= spill_threshold {
94            // Above spill threshold: identify cold nodes to demote.
95            let cold = self.cold_nodes(0);
96            (cold.len(), 0)
97        } else if utilization <= restore_threshold {
98            // Below restore threshold: all nodes can stay hot.
99            (0, self.node_count())
100        } else {
101            // In hysteresis band: no action.
102            (0, 0)
103        }
104    }
105
106    /// Estimated memory usage of the dense CSR in bytes.
107    ///
108    /// Used for SpillController utilization calculation.
109    pub fn estimated_memory_bytes(&self) -> usize {
110        let offsets = (self.out_offsets.len() + self.in_offsets.len()) * 4;
111        let targets = (self.out_targets.len() + self.in_targets.len()) * 4;
112        let labels = (self.out_labels.len() + self.in_labels.len()) * 2;
113        let weights = self.out_weights.as_ref().map_or(0, |w| w.len() * 8)
114            + self.in_weights.as_ref().map_or(0, |w| w.len() * 8);
115        let buffer: usize = self
116            .buffer_out
117            .iter()
118            .chain(self.buffer_in.iter())
119            .map(|b| b.len() * 6) // (u16 + u32) per entry
120            .sum();
121        let buffer_weights: usize = self
122            .buffer_out_weights
123            .iter()
124            .chain(self.buffer_in_weights.iter())
125            .map(|b| b.len() * 8)
126            .sum();
127        let interning = self.id_to_node.iter().map(|s| s.len() + 24).sum::<usize>()
128            + self.id_to_label.iter().map(|s| s.len() + 24).sum::<usize>();
129        offsets + targets + labels + weights + buffer + buffer_weights + interning
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Accessed nodes count as hot, untouched ones as cold, and a reset
    /// clears every counter.
    #[test]
    fn access_tracking() {
        let mut graph = CsrIndex::new();
        for (src, dst) in [("a", "b"), ("b", "c")] {
            graph.add_edge(src, "L", dst).unwrap();
        }

        let accessed = graph.node_id_raw("a").unwrap();
        assert_eq!(graph.hot_node_count(), 0);

        // Touch "a" twice; it is still a single hot node.
        for _ in 0..2 {
            graph.record_access(accessed);
        }
        assert_eq!(graph.hot_node_count(), 1);

        let never_touched = graph.cold_nodes(0);
        assert!(!never_touched.contains(&accessed));
        assert_eq!(never_touched.len(), 2); // b and c

        graph.reset_access_counts();
        assert_eq!(graph.hot_node_count(), 0);
    }

    /// A weighted graph must never be estimated smaller than the same graph
    /// without weights.
    #[test]
    fn memory_estimation_includes_weights() {
        let mut plain = CsrIndex::new();
        plain.add_edge("a", "L", "b").unwrap();
        let plain_bytes = plain.estimated_memory_bytes();

        let mut with_weights = CsrIndex::new();
        with_weights.add_edge_weighted("a", "L", "b", 5.0).unwrap();
        let weighted_bytes = with_weights.estimated_memory_bytes();

        // Weighted graph uses at least as much memory.
        assert!(weighted_bytes >= plain_bytes);
    }

    /// The three utilization regions (spill, restore, hysteresis band) each
    /// yield the expected demote/promote hints.
    #[test]
    fn evaluate_memory_pressure_hysteresis() {
        let mut graph = CsrIndex::new();
        graph.add_edge("a", "L", "b").unwrap();

        // Above spill threshold.
        let above = graph.evaluate_memory_pressure(95, 90, 75);
        assert!(above.0 > 0);
        assert_eq!(above.1, 0);

        // Below restore threshold.
        let below = graph.evaluate_memory_pressure(60, 90, 75);
        assert_eq!(below.0, 0);
        assert!(below.1 > 0);

        // In hysteresis band.
        let inside = graph.evaluate_memory_pressure(80, 90, 75);
        assert_eq!(inside.0, 0);
        assert_eq!(inside.1, 0);
    }
}