use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use grafeo_common::storage::section::{Section, SectionType};
use grafeo_common::types::{EdgeId, NodeId};
use grafeo_common::utils::error::{Error, Result};
use parking_lot::RwLock;
#[cfg(feature = "lpg")]
use super::layered::LayeredStore;
const MAGIC: [u8; 4] = *b"GODL";
const FORMAT_VERSION: u8 = 1;
/// Storage section recording which overlay nodes and edges have been deleted.
///
/// The ids themselves live in `payload`. Dirty state is read from the attached layered
/// store when the `lpg` feature is enabled and one was supplied; otherwise it comes from
/// `local_dirty`.
pub struct OverlayDeletionsSection {
    /// Dirty flag used when no layered store is attached.
    local_dirty: AtomicBool,
    /// Deleted node and edge ids.
    payload: RwLock<DeletionsPayload>,
    /// Store the section was built from; `None` for sections created via `empty`.
    #[cfg(feature = "lpg")]
    layered: Option<Arc<LayeredStore>>,
}
/// In-memory set of deleted ids carried by `OverlayDeletionsSection`.
#[derive(Clone, Debug, Default)]
struct DeletionsPayload {
    /// Ids of deleted nodes.
    nodes: Vec<NodeId>,
    /// Ids of deleted edges.
    edges: Vec<EdgeId>,
}
impl OverlayDeletionsSection {
#[cfg(feature = "lpg")]
#[must_use]
pub fn from_layered(layered: Arc<LayeredStore>) -> Self {
let mut nodes = layered.snapshot_deleted_node_ids();
let mut edges = layered.snapshot_deleted_edge_ids();
nodes.sort_unstable();
nodes.dedup();
edges.sort_unstable();
edges.dedup();
Self {
payload: RwLock::new(DeletionsPayload { nodes, edges }),
layered: Some(layered),
local_dirty: AtomicBool::new(false),
}
}
#[must_use]
pub fn empty() -> Self {
Self {
payload: RwLock::new(DeletionsPayload::default()),
#[cfg(feature = "lpg")]
layered: None,
local_dirty: AtomicBool::new(false),
}
}
#[must_use]
pub fn deleted_node_ids(&self) -> Vec<NodeId> {
self.payload.read().nodes.clone()
}
#[must_use]
pub fn deleted_edge_ids(&self) -> Vec<EdgeId> {
self.payload.read().edges.clone()
}
#[must_use]
pub fn is_empty(&self) -> bool {
let p = self.payload.read();
p.nodes.is_empty() && p.edges.is_empty()
}
fn encode_payload(&self) -> Vec<u8> {
let p = self.payload.read();
let mut buf = Vec::with_capacity(8 + 8 + p.nodes.len() * 8 + 8 + p.edges.len() * 8 + 4);
buf.extend_from_slice(&MAGIC);
buf.push(FORMAT_VERSION);
buf.extend_from_slice(&[0u8; 3]);
buf.extend_from_slice(&(p.nodes.len() as u64).to_le_bytes());
for nid in &p.nodes {
buf.extend_from_slice(&nid.0.to_le_bytes());
}
buf.extend_from_slice(&(p.edges.len() as u64).to_le_bytes());
for eid in &p.edges {
buf.extend_from_slice(&eid.0.to_le_bytes());
}
let crc = crc32fast::hash(&buf);
buf.extend_from_slice(&crc.to_le_bytes());
buf
}
fn decode_payload(data: &[u8]) -> Result<DeletionsPayload> {
if data.len() < 8 + 8 + 8 + 4 {
return Err(Error::Serialization(
"OverlayDeletions section too short".into(),
));
}
if data[..4] != MAGIC {
return Err(Error::Serialization(format!(
"OverlayDeletions magic mismatch: expected {MAGIC:?}, got {:?}",
&data[..4],
)));
}
let version = data[4];
if version != FORMAT_VERSION {
return Err(Error::Serialization(format!(
"unsupported OverlayDeletions section version {version}, expected {FORMAT_VERSION}",
)));
}
let payload = &data[..data.len() - 4];
let stored_crc = u32::from_le_bytes(data[data.len() - 4..].try_into().unwrap());
let actual_crc = crc32fast::hash(payload);
if stored_crc != actual_crc {
return Err(Error::Serialization(format!(
"OverlayDeletions CRC mismatch: stored {stored_crc:#010X}, computed {actual_crc:#010X}",
)));
}
let mut pos = 8usize;
let read_u64 = |buf: &[u8], pos: &mut usize| -> Result<u64> {
if *pos + 8 > buf.len() {
return Err(Error::Serialization(
"OverlayDeletions truncated mid-entry".into(),
));
}
let v = u64::from_le_bytes(buf[*pos..*pos + 8].try_into().unwrap());
*pos += 8;
Ok(v)
};
let node_count_u64 = read_u64(data, &mut pos)?;
let node_count = usize::try_from(node_count_u64).map_err(|_| {
Error::Serialization(format!(
"OverlayDeletions node_count {node_count_u64} exceeds usize on this target",
))
})?;
if node_count
.checked_mul(8)
.map_or(true, |n| pos + n + 8 + 4 > data.len())
{
return Err(Error::Serialization(format!(
"OverlayDeletions node_count {node_count} exceeds section size",
)));
}
let mut nodes = Vec::with_capacity(node_count);
for _ in 0..node_count {
nodes.push(NodeId(read_u64(data, &mut pos)?));
}
let edge_count_u64 = read_u64(data, &mut pos)?;
let edge_count = usize::try_from(edge_count_u64).map_err(|_| {
Error::Serialization(format!(
"OverlayDeletions edge_count {edge_count_u64} exceeds usize on this target",
))
})?;
if edge_count
.checked_mul(8)
.map_or(true, |n| pos + n + 4 > data.len())
{
return Err(Error::Serialization(format!(
"OverlayDeletions edge_count {edge_count} exceeds section size",
)));
}
let mut edges = Vec::with_capacity(edge_count);
for _ in 0..edge_count {
edges.push(EdgeId(read_u64(data, &mut pos)?));
}
Ok(DeletionsPayload { nodes, edges })
}
pub fn take(&self) -> (Vec<NodeId>, Vec<EdgeId>) {
let mut p = self.payload.write();
let nodes = std::mem::take(&mut p.nodes);
let edges = std::mem::take(&mut p.edges);
(nodes, edges)
}
}
impl Section for OverlayDeletionsSection {
    /// Identifies this section as `SectionType::OverlayDeletions`.
    fn section_type(&self) -> SectionType {
        SectionType::OverlayDeletions
    }

    /// Format version written into the payload header (`FORMAT_VERSION`).
    fn version(&self) -> u8 {
        FORMAT_VERSION
    }

    /// Encodes the current deletions; encoding itself never fails.
    fn serialize(&self) -> Result<Vec<u8>> {
        let encoded = self.encode_payload();
        Ok(encoded)
    }

    /// Replaces the in-memory deletions with those decoded from `data` and clears the
    /// local dirty flag. If decoding fails, the existing payload is left untouched.
    fn deserialize(&mut self, data: &[u8]) -> Result<()> {
        let decoded = Self::decode_payload(data)?;
        {
            let mut slot = self.payload.write();
            *slot = decoded;
        }
        self.local_dirty.store(false, Ordering::Release);
        Ok(())
    }

    /// Reports the attached layered store's deletion dirtiness when one is present;
    /// otherwise reports the local flag.
    fn is_dirty(&self) -> bool {
        #[cfg(feature = "lpg")]
        if let Some(store) = &self.layered {
            return store.deletions_dirty();
        }
        self.local_dirty.load(Ordering::Acquire)
    }

    /// Clears dirtiness on the attached layered store if present; otherwise clears the
    /// local flag.
    fn mark_clean(&self) {
        #[cfg(feature = "lpg")]
        if let Some(store) = &self.layered {
            store.mark_deletions_clean();
            return;
        }
        self.local_dirty.store(false, Ordering::Release);
    }

    /// Approximates memory use as the byte size of the stored ids. Vector spare
    /// capacity and lock overhead are not counted.
    fn memory_usage(&self) -> usize {
        let guard = self.payload.read();
        let node_bytes = std::mem::size_of::<NodeId>() * guard.nodes.len();
        let edge_bytes = std::mem::size_of::<EdgeId>() * guard.edges.len();
        node_bytes + edge_bytes
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a detached section holding exactly `nodes` and `edges`, marked dirty.
    fn section_with(nodes: Vec<NodeId>, edges: Vec<EdgeId>) -> OverlayDeletionsSection {
        OverlayDeletionsSection {
            payload: RwLock::new(DeletionsPayload { nodes, edges }),
            #[cfg(feature = "lpg")]
            layered: None,
            local_dirty: AtomicBool::new(true),
        }
    }

    /// Encodes a payload whose header declares `node_count` nodes and `edge_count` edges
    /// but carries no id entries. It is sealed with a valid CRC so decoding reaches the
    /// count checks.
    fn forge_counts(node_count: u64, edge_count: u64) -> Vec<u8> {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&MAGIC);
        bytes.push(FORMAT_VERSION);
        bytes.extend_from_slice(&[0u8; 3]);
        bytes.extend_from_slice(&node_count.to_le_bytes());
        bytes.extend_from_slice(&edge_count.to_le_bytes());
        let crc = crc32fast::hash(&bytes);
        bytes.extend_from_slice(&crc.to_le_bytes());
        bytes
    }

    /// Deserializes `bytes` into a fresh section, asserts that it fails, and returns the
    /// error text.
    fn deserialize_err(bytes: &[u8], why: &str) -> String {
        let mut section = OverlayDeletionsSection::empty();
        section.deserialize(bytes).expect_err(why).to_string()
    }

    #[test]
    fn roundtrip_empty_payload() {
        let section = OverlayDeletionsSection::empty();
        let bytes = section.serialize().unwrap();
        let mut roundtrip = OverlayDeletionsSection::empty();
        roundtrip.deserialize(&bytes).unwrap();
        assert!(roundtrip.is_empty());
        assert!(roundtrip.deleted_node_ids().is_empty());
        assert!(roundtrip.deleted_edge_ids().is_empty());
    }

    #[test]
    fn roundtrip_mixed_payload() {
        let section = section_with(
            vec![NodeId(1), NodeId(7), NodeId(42)],
            vec![EdgeId(3), EdgeId(99)],
        );
        let bytes = section.serialize().unwrap();
        let mut roundtrip = OverlayDeletionsSection::empty();
        roundtrip.deserialize(&bytes).unwrap();
        assert_eq!(
            roundtrip.deleted_node_ids(),
            vec![NodeId(1), NodeId(7), NodeId(42)]
        );
        assert_eq!(roundtrip.deleted_edge_ids(), vec![EdgeId(3), EdgeId(99)]);
        // A detached section is clean right after loading from bytes.
        assert!(!roundtrip.is_dirty());
    }

    #[test]
    fn rejects_truncated_input() {
        let err = deserialize_err(&[0u8; 27], "input shorter than an empty payload must fail");
        assert!(err.contains("too short"));
    }

    #[test]
    fn rejects_bad_magic() {
        let mut bytes = OverlayDeletionsSection::empty().serialize().unwrap();
        bytes[0] = b'X';
        // Re-seal the CRC so the failure is attributable to the magic alone.
        let new_crc = crc32fast::hash(&bytes[..bytes.len() - 4]);
        let crc_offset = bytes.len() - 4;
        bytes[crc_offset..].copy_from_slice(&new_crc.to_le_bytes());
        let err = deserialize_err(&bytes, "bad magic must fail");
        assert!(err.contains("magic"));
    }

    #[test]
    fn rejects_unsupported_version() {
        let mut bytes = OverlayDeletionsSection::empty().serialize().unwrap();
        bytes[4] = FORMAT_VERSION.wrapping_add(1);
        let err = deserialize_err(&bytes, "unknown version must fail");
        assert!(err.contains("version"));
    }

    #[test]
    fn rejects_crc_mismatch() {
        let original = section_with(vec![NodeId(11)], vec![]);
        let mut bytes = original.serialize().unwrap();
        // Byte 16 is the first byte of the first node id.
        bytes[16] ^= 0xFF;
        let err = deserialize_err(&bytes, "CRC mismatch must fail");
        assert!(err.contains("CRC mismatch"));
    }

    #[test]
    fn rejects_unreasonable_node_count() {
        let bytes = forge_counts(u64::MAX, 0);
        let err = deserialize_err(&bytes, "absurd node_count must fail");
        assert!(err.contains("node_count"));
    }

    #[test]
    fn rejects_node_count_exceeding_section() {
        let err = deserialize_err(&forge_counts(5, 0), "node_count past end must fail");
        assert!(err.contains("node_count"));
    }

    #[test]
    fn rejects_edge_count_exceeding_section() {
        let err = deserialize_err(&forge_counts(0, 3), "edge_count past end must fail");
        assert!(err.contains("edge_count"));
    }

    #[test]
    fn failed_deserialize_keeps_previous_payload() {
        let bytes = section_with(vec![NodeId(5)], vec![]).serialize().unwrap();
        let mut target = OverlayDeletionsSection::empty();
        target.deserialize(&bytes).unwrap();
        assert!(target.deserialize(&[0u8; 4]).is_err());
        assert_eq!(target.deleted_node_ids(), vec![NodeId(5)]);
    }

    #[test]
    fn take_drains_payload() {
        let section = section_with(vec![NodeId(2)], vec![EdgeId(8)]);
        let (nodes, edges) = section.take();
        assert_eq!(nodes, vec![NodeId(2)]);
        assert_eq!(edges, vec![EdgeId(8)]);
        assert!(section.is_empty());
    }
}