Skip to main content

nodedb_graph/
csr_persist.rs

1//! CSR checkpoint serialization/deserialization and compaction.
2
3use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7use super::csr::CsrIndex;
8
9#[derive(Serialize, Deserialize)]
10struct CsrSnapshot {
11    nodes: Vec<String>,
12    labels: Vec<String>,
13    out_offsets: Vec<u32>,
14    out_targets: Vec<u32>,
15    out_labels: Vec<u16>,
16    in_offsets: Vec<u32>,
17    in_targets: Vec<u32>,
18    in_labels: Vec<u16>,
19    buffer_out: Vec<Vec<(u16, u32)>>,
20    buffer_in: Vec<Vec<(u16, u32)>>,
21    deleted: Vec<(u32, u16, u32)>,
22}
23
24impl CsrIndex {
25    /// Merge the mutable buffer into dense CSR arrays.
26    pub fn compact(&mut self) {
27        let n = self.id_to_node.len();
28        let mut new_out_edges: Vec<Vec<(u16, u32)>> = vec![Vec::new(); n];
29        let mut new_in_edges: Vec<Vec<(u16, u32)>> = vec![Vec::new(); n];
30
31        // Collect surviving dense edges.
32        for node in 0..n {
33            let node_id = node as u32;
34            let idx = node_id as usize;
35
36            if idx + 1 < self.out_offsets.len() {
37                let start = self.out_offsets[idx] as usize;
38                let end = self.out_offsets[idx + 1] as usize;
39                for i in start..end {
40                    let lid = self.out_labels[i];
41                    let dst = self.out_targets[i];
42                    if !self.deleted_edges.contains(&(node_id, lid, dst)) {
43                        new_out_edges[node].push((lid, dst));
44                    }
45                }
46            }
47
48            if idx + 1 < self.in_offsets.len() {
49                let start = self.in_offsets[idx] as usize;
50                let end = self.in_offsets[idx + 1] as usize;
51                for i in start..end {
52                    let lid = self.in_labels[i];
53                    let src = self.in_targets[i];
54                    if !self.deleted_edges.contains(&(src, lid, node_id)) {
55                        new_in_edges[node].push((lid, src));
56                    }
57                }
58            }
59        }
60
61        // Merge buffer edges.
62        for node in 0..n {
63            for &(lid, dst) in &self.buffer_out[node] {
64                if !new_out_edges[node]
65                    .iter()
66                    .any(|&(l, d)| l == lid && d == dst)
67                {
68                    new_out_edges[node].push((lid, dst));
69                }
70            }
71            for &(lid, src) in &self.buffer_in[node] {
72                if !new_in_edges[node]
73                    .iter()
74                    .any(|&(l, s)| l == lid && s == src)
75                {
76                    new_in_edges[node].push((lid, src));
77                }
78            }
79        }
80
81        // Build new dense arrays.
82        let (out_offsets, out_targets, out_labels) = build_dense(&new_out_edges);
83        let (in_offsets, in_targets, in_labels) = build_dense(&new_in_edges);
84
85        self.out_offsets = out_offsets;
86        self.out_targets = out_targets;
87        self.out_labels = out_labels;
88        self.in_offsets = in_offsets;
89        self.in_targets = in_targets;
90        self.in_labels = in_labels;
91
92        for buf in &mut self.buffer_out {
93            buf.clear();
94        }
95        for buf in &mut self.buffer_in {
96            buf.clear();
97        }
98        self.deleted_edges.clear();
99    }
100
101    /// Serialize the index to MessagePack bytes for storage.
102    pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
103        let snapshot = CsrSnapshot {
104            nodes: self.id_to_node.clone(),
105            labels: self.id_to_label.clone(),
106            out_offsets: self.out_offsets.clone(),
107            out_targets: self.out_targets.clone(),
108            out_labels: self.out_labels.clone(),
109            in_offsets: self.in_offsets.clone(),
110            in_targets: self.in_targets.clone(),
111            in_labels: self.in_labels.clone(),
112            buffer_out: self.buffer_out.clone(),
113            buffer_in: self.buffer_in.clone(),
114            deleted: self.deleted_edges.iter().copied().collect(),
115        };
116        match rmp_serde::to_vec_named(&snapshot) {
117            Ok(bytes) => bytes,
118            Err(e) => {
119                tracing::error!(error = %e, "CSR checkpoint serialization failed");
120                Vec::new()
121            }
122        }
123    }
124
125    /// Restore an index from a checkpoint snapshot.
126    pub fn from_checkpoint(bytes: &[u8]) -> Option<Self> {
127        let snap: CsrSnapshot = rmp_serde::from_slice(bytes).ok()?;
128
129        let node_to_id: HashMap<String, u32> = snap
130            .nodes
131            .iter()
132            .enumerate()
133            .map(|(i, n)| (n.clone(), i as u32))
134            .collect();
135        let label_to_id: HashMap<String, u16> = snap
136            .labels
137            .iter()
138            .enumerate()
139            .map(|(i, l)| (l.clone(), i as u16))
140            .collect();
141
142        Some(Self {
143            node_to_id,
144            id_to_node: snap.nodes,
145            label_to_id,
146            id_to_label: snap.labels,
147            out_offsets: snap.out_offsets,
148            out_targets: snap.out_targets,
149            out_labels: snap.out_labels,
150            in_offsets: snap.in_offsets,
151            in_targets: snap.in_targets,
152            in_labels: snap.in_labels,
153            buffer_out: snap.buffer_out,
154            buffer_in: snap.buffer_in,
155            deleted_edges: snap.deleted.into_iter().collect(),
156        })
157    }
158}
159
/// Build contiguous offset/target/label arrays from per-node edge lists.
///
/// `edges[i]` holds node `i`'s edges as `(label id, target id)` pairs. The
/// result is `(offsets, targets, labels)` where `targets` and `labels` are the
/// flattened pairs in input order and `offsets` has `edges.len() + 1` entries:
/// node `i` owns the range `offsets[i]..offsets[i + 1]`.
///
/// # Panics
///
/// Panics if the total number of edges does not fit in a `u32`, because the
/// offsets could not address them. Previously that case silently wrapped and
/// produced corrupt offsets.
pub(crate) fn build_dense(edges: &[Vec<(u16, u32)>]) -> (Vec<u32>, Vec<u32>, Vec<u16>) {
    let total: usize = edges.iter().map(Vec::len).sum();
    assert!(
        total <= u32::MAX as usize,
        "CSR edge count {} does not fit in u32 offsets",
        total
    );

    let mut offsets = Vec::with_capacity(edges.len() + 1);
    let mut targets = Vec::with_capacity(total);
    let mut labels = Vec::with_capacity(total);

    offsets.push(0);
    for node_edges in edges {
        for &(lid, target) in node_edges {
            targets.push(target);
            labels.push(lid);
        }
        // Cannot truncate: `targets.len()` never exceeds `total`, checked above.
        offsets.push(targets.len() as u32);
    }

    (offsets, targets, labels)
}