use std::collections::HashMap;
use zerompk::{FromMessagePack, ToMessagePack};
use super::index::CsrIndex;
/// Six-byte prefix that `checkpoint_to_bytes` writes in front of the rkyv payload.
/// `from_checkpoint` routes input starting with these bytes to the rkyv decoder and
/// everything else to the msgpack decoder.
const RKYV_MAGIC: &[u8; 6] = b"RKCSR\0";
/// MessagePack representation of a `CsrIndex` checkpoint.
///
/// In this file it is only ever decoded (`from_msgpack_checkpoint`, used for input that does
/// not start with `RKYV_MAGIC`) and then moved field-by-field into a `CsrSnapshotRkyv`;
/// `checkpoint_to_bytes` always writes the rkyv format. Presumably kept for reading older
/// checkpoints — TODO confirm.
///
/// NOTE(review): the wire layout is produced by the derive macros, so field order and types
/// should be treated as part of the on-disk format.
#[derive(ToMessagePack, FromMessagePack)]
struct CsrSnapshotMsgpack {
/// Node names; restored into `id_to_node`, with the position used as the `u32` node id.
nodes: Vec<String>,
/// Edge label names; restored into `id_to_label`, with the position used as the `u16` label id.
labels: Vec<String>,
/// Restored into `CsrIndex::out_offsets`.
out_offsets: Vec<u32>,
/// Restored into `CsrIndex::out_targets`.
out_targets: Vec<u32>,
/// Restored into `CsrIndex::out_labels`.
out_labels: Vec<u16>,
/// Restored into `CsrIndex::in_offsets`.
in_offsets: Vec<u32>,
/// Restored into `CsrIndex::in_targets`.
in_targets: Vec<u32>,
/// Restored into `CsrIndex::in_labels`.
in_labels: Vec<u16>,
/// Restored into `CsrIndex::buffer_out` (presumably per-node (label id, node id) pairs — TODO confirm).
buffer_out: Vec<Vec<(u16, u32)>>,
/// Restored into `CsrIndex::buffer_in` (same shape as `buffer_out`).
buffer_in: Vec<Vec<(u16, u32)>>,
/// Collected into `CsrIndex::deleted_edges` (presumably (source, label, target) ids — TODO confirm).
deleted: Vec<(u32, u16, u32)>,
/// Restored into `CsrIndex::has_weights`.
has_weights: bool,
/// Optional weight column restored into `CsrIndex::out_weights`.
out_weights: Option<Vec<f64>>,
/// Optional weight column restored into `CsrIndex::in_weights`.
in_weights: Option<Vec<f64>>,
/// Restored into `CsrIndex::buffer_out_weights`; on load it is replaced by one empty vector
/// per node when its length differs from the node count.
buffer_out_weights: Vec<Vec<f64>>,
/// Restored into `CsrIndex::buffer_in_weights`; same length rule as `buffer_out_weights`.
buffer_in_weights: Vec<Vec<f64>>,
}
/// rkyv representation of a `CsrIndex` checkpoint.
///
/// `checkpoint_to_bytes` fills it from the live index and serializes it after `RKYV_MAGIC`.
/// It also serves as the common intermediate form consumed by `from_snapshot_fields`.
///
/// NOTE(review): the archived layout is derived from this declaration, so field order and
/// types should be treated as part of the on-disk format.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
struct CsrSnapshotRkyv {
/// Copy of `id_to_node`; on load the position in this vector becomes the `u32` node id.
nodes: Vec<String>,
/// Copy of `id_to_label`; on load the position in this vector becomes the `u16` label id.
labels: Vec<String>,
/// Copy of `CsrIndex::out_offsets`.
out_offsets: Vec<u32>,
/// Copy of `CsrIndex::out_targets`.
out_targets: Vec<u32>,
/// Copy of `CsrIndex::out_labels`.
out_labels: Vec<u16>,
/// Copy of `CsrIndex::in_offsets`.
in_offsets: Vec<u32>,
/// Copy of `CsrIndex::in_targets`.
in_targets: Vec<u32>,
/// Copy of `CsrIndex::in_labels`.
in_labels: Vec<u16>,
/// Copy of `CsrIndex::buffer_out` (presumably per-node (label id, node id) pairs — TODO confirm).
buffer_out: Vec<Vec<(u16, u32)>>,
/// Copy of `CsrIndex::buffer_in` (same shape as `buffer_out`).
buffer_in: Vec<Vec<(u16, u32)>>,
/// Contents of `CsrIndex::deleted_edges` (presumably (source, label, target) ids — TODO confirm).
deleted: Vec<(u32, u16, u32)>,
/// Copy of `CsrIndex::has_weights`.
has_weights: bool,
/// Copy of `CsrIndex::out_weights`, `None` when the index has no such column.
out_weights: Option<Vec<f64>>,
/// Copy of `CsrIndex::in_weights`, `None` when the index has no such column.
in_weights: Option<Vec<f64>>,
/// Copy of `CsrIndex::buffer_out_weights`; on load it is replaced by one empty vector per
/// node when its length differs from the node count.
buffer_out_weights: Vec<Vec<f64>>,
/// Copy of `CsrIndex::buffer_in_weights`; same length rule as `buffer_out_weights`.
buffer_in_weights: Vec<Vec<f64>>,
}
impl CsrIndex {
pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
let snapshot = CsrSnapshotRkyv {
nodes: self.id_to_node.clone(),
labels: self.id_to_label.clone(),
out_offsets: self.out_offsets.clone(),
out_targets: self.out_targets.to_vec(),
out_labels: self.out_labels.to_vec(),
in_offsets: self.in_offsets.clone(),
in_targets: self.in_targets.to_vec(),
in_labels: self.in_labels.to_vec(),
buffer_out: self.buffer_out.clone(),
buffer_in: self.buffer_in.clone(),
deleted: self.deleted_edges.iter().copied().collect(),
has_weights: self.has_weights,
out_weights: self.out_weights.as_ref().map(|w| w.to_vec()),
in_weights: self.in_weights.as_ref().map(|w| w.to_vec()),
buffer_out_weights: self.buffer_out_weights.clone(),
buffer_in_weights: self.buffer_in_weights.clone(),
};
let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
.expect("CSR rkyv serialization should not fail");
let mut buf = Vec::with_capacity(RKYV_MAGIC.len() + rkyv_bytes.len());
buf.extend_from_slice(RKYV_MAGIC);
buf.extend_from_slice(&rkyv_bytes);
buf
}
pub fn from_checkpoint(bytes: &[u8]) -> Option<Self> {
if bytes.len() > RKYV_MAGIC.len() && &bytes[..RKYV_MAGIC.len()] == RKYV_MAGIC {
return Self::from_rkyv_checkpoint(&bytes[RKYV_MAGIC.len()..]);
}
Self::from_msgpack_checkpoint(bytes)
}
fn from_rkyv_checkpoint(bytes: &[u8]) -> Option<Self> {
let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(bytes.len());
aligned.extend_from_slice(bytes);
#[cfg(target_endian = "little")]
{
Self::from_rkyv_zero_copy(aligned)
}
#[cfg(not(target_endian = "little"))]
{
let snap: CsrSnapshotRkyv =
rkyv::from_bytes::<CsrSnapshotRkyv, rkyv::rancor::Error>(&aligned).ok()?;
Some(Self::from_snapshot_fields(snap))
}
}
#[cfg(target_endian = "little")]
fn from_rkyv_zero_copy(aligned: rkyv::util::AlignedVec) -> Option<Self> {
use super::dense_array::DenseArray;
let backing = std::sync::Arc::new(aligned);
let archived =
rkyv::access::<rkyv::Archived<CsrSnapshotRkyv>, rkyv::rancor::Error>(&backing).ok()?;
let out_targets = unsafe {
let s = archived.out_targets.as_slice();
DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
};
let out_labels = unsafe {
let s = archived.out_labels.as_slice();
DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u16>(), s.len())
};
let in_targets = unsafe {
let s = archived.in_targets.as_slice();
DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
};
let in_labels = unsafe {
let s = archived.in_labels.as_slice();
DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u16>(), s.len())
};
let out_weights = archived.out_weights.as_ref().map(|w| unsafe {
let s = w.as_slice();
DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<f64>(), s.len())
});
let in_weights = archived.in_weights.as_ref().map(|w| unsafe {
let s = w.as_slice();
DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<f64>(), s.len())
});
let snap: CsrSnapshotRkyv =
rkyv::from_bytes::<CsrSnapshotRkyv, rkyv::rancor::Error>(&backing).ok()?;
let node_to_id: HashMap<String, u32> = snap
.nodes
.iter()
.enumerate()
.map(|(i, n)| (n.clone(), i as u32))
.collect();
let label_to_id: HashMap<String, u16> = snap
.labels
.iter()
.enumerate()
.map(|(i, l)| (l.clone(), i as u16))
.collect();
let node_count = snap.nodes.len();
let access_counts = (0..node_count).map(|_| std::cell::Cell::new(0)).collect();
let buffer_out_weights = if snap.buffer_out_weights.len() == node_count {
snap.buffer_out_weights
} else {
vec![Vec::new(); node_count]
};
let buffer_in_weights = if snap.buffer_in_weights.len() == node_count {
snap.buffer_in_weights
} else {
vec![Vec::new(); node_count]
};
Some(Self {
node_to_id,
id_to_node: snap.nodes,
label_to_id,
id_to_label: snap.labels,
out_offsets: snap.out_offsets,
out_targets,
out_labels,
out_weights,
in_offsets: snap.in_offsets,
in_targets,
in_labels,
in_weights,
buffer_out: snap.buffer_out,
buffer_in: snap.buffer_in,
buffer_out_weights,
buffer_in_weights,
deleted_edges: snap.deleted.into_iter().collect(),
has_weights: snap.has_weights,
node_label_bits: vec![0; node_count],
node_label_to_id: HashMap::new(),
node_label_names: Vec::new(),
access_counts,
query_epoch: 0,
})
}
fn from_msgpack_checkpoint(bytes: &[u8]) -> Option<Self> {
let snap: CsrSnapshotMsgpack = zerompk::from_msgpack(bytes).ok()?;
Some(Self::from_snapshot_fields(CsrSnapshotRkyv {
nodes: snap.nodes,
labels: snap.labels,
out_offsets: snap.out_offsets,
out_targets: snap.out_targets,
out_labels: snap.out_labels,
in_offsets: snap.in_offsets,
in_targets: snap.in_targets,
in_labels: snap.in_labels,
buffer_out: snap.buffer_out,
buffer_in: snap.buffer_in,
deleted: snap.deleted,
has_weights: snap.has_weights,
out_weights: snap.out_weights,
in_weights: snap.in_weights,
buffer_out_weights: snap.buffer_out_weights,
buffer_in_weights: snap.buffer_in_weights,
}))
}
fn from_snapshot_fields(snap: CsrSnapshotRkyv) -> Self {
let node_to_id: HashMap<String, u32> = snap
.nodes
.iter()
.enumerate()
.map(|(i, n)| (n.clone(), i as u32))
.collect();
let label_to_id: HashMap<String, u16> = snap
.labels
.iter()
.enumerate()
.map(|(i, l)| (l.clone(), i as u16))
.collect();
let node_count = snap.nodes.len();
let access_counts = (0..node_count).map(|_| std::cell::Cell::new(0)).collect();
let buffer_out_weights = if snap.buffer_out_weights.len() == node_count {
snap.buffer_out_weights
} else {
vec![Vec::new(); node_count]
};
let buffer_in_weights = if snap.buffer_in_weights.len() == node_count {
snap.buffer_in_weights
} else {
vec![Vec::new(); node_count]
};
Self {
node_to_id,
id_to_node: snap.nodes,
label_to_id,
id_to_label: snap.labels,
out_offsets: snap.out_offsets,
out_targets: snap.out_targets.into(),
out_labels: snap.out_labels.into(),
out_weights: snap.out_weights.map(Into::into),
in_offsets: snap.in_offsets,
in_targets: snap.in_targets.into(),
in_labels: snap.in_labels.into(),
in_weights: snap.in_weights.map(Into::into),
buffer_out: snap.buffer_out,
buffer_in: snap.buffer_in,
buffer_out_weights,
buffer_in_weights,
deleted_edges: snap.deleted.into_iter().collect(),
has_weights: snap.has_weights,
node_label_bits: vec![0; node_count],
node_label_to_id: HashMap::new(),
node_label_names: Vec::new(),
access_counts,
query_epoch: 0,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::index::Direction;

    /// Serializes `index` and immediately loads the bytes back through the public API.
    fn reload(index: &CsrIndex) -> CsrIndex {
        let encoded = index.checkpoint_to_bytes();
        CsrIndex::from_checkpoint(&encoded).expect("roundtrip")
    }

    /// A compacted, unweighted graph keeps its nodes, edges and adjacency after a roundtrip.
    #[test]
    fn checkpoint_roundtrip_unweighted() {
        let mut original = CsrIndex::new();
        original.add_edge("a", "KNOWS", "b");
        original.add_edge("b", "KNOWS", "c");
        original.compact();

        let restored = reload(&original);

        assert_eq!(restored.node_count(), 3);
        assert_eq!(restored.edge_count(), 2);
        assert!(!restored.has_weights());

        let out_of_a = restored.neighbors("a", Some("KNOWS"), Direction::Out);
        assert_eq!(out_of_a.len(), 1);
        assert_eq!(out_of_a[0].1, "b");
    }

    /// Edge weights survive a roundtrip; the edge added without a weight reads back as 1.0.
    #[test]
    fn checkpoint_roundtrip_weighted() {
        let mut original = CsrIndex::new();
        original.add_edge_weighted("a", "R", "b", 2.5);
        original.add_edge_weighted("b", "R", "c", 7.0);
        original.add_edge("c", "R", "d");
        original.compact();

        let restored = reload(&original);

        assert!(restored.has_weights());
        assert_eq!(restored.edge_weight("a", "R", "b"), Some(2.5));
        assert_eq!(restored.edge_weight("b", "R", "c"), Some(7.0));
        assert_eq!(restored.edge_weight("c", "R", "d"), Some(1.0));
    }

    /// An edge that was never compacted is still present after a roundtrip.
    #[test]
    fn checkpoint_roundtrip_with_buffer() {
        let mut original = CsrIndex::new();
        original.add_edge("a", "L", "b");

        let restored = reload(&original);

        assert_eq!(restored.edge_count(), 1);
    }
}