use std::fs;
use std::path::Path;
use std::sync::Arc;
use crc32fast::Hasher;
use crate::core::GLOBAL_INTERNER;
use crate::core::property::{PropertyMap, PropertyMapBuilder, PropertyValue};
use crate::storage::compression::decompress_with_limit;
use super::error::{IndexPersistenceError, Result};
use super::formats::{
GraphIndexData, GraphIndexDelta, PersistedPropertyMap, PersistedPropertyValue,
};
use super::{DELTA_MAGIC, GRAPH_MAGIC, MANIFEST_VERSION};
fn map_decompress_error(e: crate::core::error::Error) -> IndexPersistenceError {
match e {
crate::core::error::Error::Storage(
crate::core::error::StorageError::CapacityExceeded { current, limit, .. },
) => IndexPersistenceError::SizeLimitExceeded {
message: format!(
"Decompressed size {} exceeds limit {} (possible zip bomb)",
current, limit
),
},
_ => IndexPersistenceError::Serialization(format!("zstd decompression failed: {}", e)),
}
}
/// Converts an in-memory [`PropertyValue`] into its on-disk [`PersistedPropertyValue`].
///
/// Strings are interned in `GLOBAL_INTERNER` and stored by their `u32` id. Byte buffers
/// and vectors are copied into owned `Vec`s. Scalars are copied as-is.
///
/// # Errors
///
/// Returns `IndexPersistenceError::Serialization` if interning a string fails. It also
/// returns that error for `Array` and `SparseVector` values, which this format cannot
/// represent yet; failing here avoids silently dropping them.
pub fn persist_property_value(value: &PropertyValue) -> Result<PersistedPropertyValue> {
    let persisted = match value {
        PropertyValue::Null => PersistedPropertyValue::Null,
        PropertyValue::Bool(flag) => PersistedPropertyValue::Bool(*flag),
        PropertyValue::Int(n) => PersistedPropertyValue::Int(*n),
        PropertyValue::Float(x) => PersistedPropertyValue::Float(*x),
        PropertyValue::String(text) => {
            let id = GLOBAL_INTERNER.intern(text.as_ref()).map_err(|err| {
                IndexPersistenceError::Serialization(format!("Failed to intern string: {}", err))
            })?;
            PersistedPropertyValue::String(id.as_u32())
        }
        PropertyValue::Bytes(bytes) => PersistedPropertyValue::Bytes(bytes.to_vec()),
        PropertyValue::Vector(components) => PersistedPropertyValue::Vector(components.to_vec()),
        PropertyValue::Array(_) => {
            let reason = "Array properties are not yet supported for persistence. \
                          This prevents silent data loss. Support will be added in a future update.";
            return Err(IndexPersistenceError::Serialization(reason.to_string()));
        }
        PropertyValue::SparseVector(_) => {
            let reason = "SparseVector properties are not yet supported for index persistence. \
                          This prevents silent data loss. Support will be added in a future update.";
            return Err(IndexPersistenceError::Serialization(reason.to_string()));
        }
    };
    Ok(persisted)
}
/// Converts a [`PersistedPropertyValue`] read from disk back into a [`PropertyValue`].
///
/// String ids are resolved through `GLOBAL_INTERNER`. Byte buffers and vectors are
/// copied into freshly allocated `Arc` slices.
///
/// # Errors
///
/// * `Serialization` if a string id is unknown to the interner (reported as likely corruption).
/// * `SizeLimitExceeded` if a vector has more than `MAX_VECTOR_DIMENSIONS` components.
///   This rejects oversized vectors read from a file.
pub fn restore_property_value(persisted: &PersistedPropertyValue) -> Result<PropertyValue> {
    match persisted {
        PersistedPropertyValue::Null => Ok(PropertyValue::Null),
        PersistedPropertyValue::Bool(flag) => Ok(PropertyValue::Bool(*flag)),
        PersistedPropertyValue::Int(n) => Ok(PropertyValue::Int(*n)),
        PersistedPropertyValue::Float(x) => Ok(PropertyValue::Float(*x)),
        PersistedPropertyValue::String(id) => {
            // Kept from the original: silences a deprecation warning raised by one of the
            // interner APIs on this line (presumably `resolve` or `from_raw` — confirm).
            #[allow(deprecated)]
            let resolved = GLOBAL_INTERNER.resolve(crate::core::InternedString::from_raw(*id));
            resolved.map(PropertyValue::String).ok_or_else(|| {
                IndexPersistenceError::Serialization(format!(
                    "Failed to resolve interned string with ID: {}. \
                     This likely indicates data corruption.",
                    id
                ))
            })
        }
        PersistedPropertyValue::Bytes(bytes) => {
            Ok(PropertyValue::Bytes(Arc::from(bytes.as_slice())))
        }
        PersistedPropertyValue::Vector(components) => {
            let max_dims = super::MAX_VECTOR_DIMENSIONS;
            if components.len() > max_dims {
                return Err(IndexPersistenceError::SizeLimitExceeded {
                    message: format!(
                        "Vector dimension {} exceeds maximum allowed dimension {}",
                        components.len(),
                        max_dims
                    ),
                });
            }
            Ok(PropertyValue::Vector(Arc::from(components.as_slice())))
        }
    }
}
/// Converts every entry of a [`PropertyMap`] into its persisted form.
///
/// Keys are stored as their raw interned `u32` ids. Entries are emitted in the map's
/// iteration order.
///
/// # Errors
///
/// Stops at and returns the first error produced by [`persist_property_value`].
pub fn persist_property_map(props: &PropertyMap) -> Result<PersistedPropertyMap> {
    let entries = props
        .iter()
        .map(|(key, value)| persist_property_value(value).map(|stored| (key.as_u32(), stored)))
        .collect::<Result<Vec<_>>>()?;
    Ok(PersistedPropertyMap { entries })
}
pub fn restore_property_map(persisted: &PersistedPropertyMap) -> Result<PropertyMap> {
let mut builder = PropertyMapBuilder::new();
for (key_idx, value) in &persisted.entries {
let key_id = crate::core::InternedString::from_raw(*key_idx);
let val = restore_property_value(value)?;
builder = GLOBAL_INTERNER
.resolve_with(key_id, |key_str| builder.insert(key_str, val))
.ok_or_else(|| {
IndexPersistenceError::Serialization(format!(
"Failed to resolve interned property key with ID: {}. \
This likely indicates data corruption.",
key_idx
))
})?;
}
Ok(builder.build())
}
/// Serializes `data` with bitcode and atomically writes it to `path` without compression.
///
/// On-disk layout: `[bitcode payload][CRC32 of the payload, 4 bytes little-endian]`.
///
/// # Errors
///
/// Propagates any failure from `atomic_write`.
pub fn save_graph_index(data: &GraphIndexData, path: &Path) -> Result<()> {
    let mut file_bytes = bitcode::encode(data);
    let crc = {
        let mut hasher = Hasher::new();
        hasher.update(&file_bytes);
        hasher.finalize()
    };
    file_bytes.extend_from_slice(&crc.to_le_bytes());
    super::atomic_write(path, &file_bytes)?;
    Ok(())
}
/// Loads and validates a graph index written by [`save_graph_index`] or
/// [`save_graph_index_compressed`].
///
/// On-disk layout: `[payload][CRC32, 4 bytes little-endian]`. The `payload` is either
/// raw bitcode or a zstd frame wrapping it. The CRC always covers the *uncompressed*
/// bitcode.
///
/// Checks run in this order:
/// 1. On-disk size is at most `MAX_GRAPH_INDEX_FILE_SIZE`.
/// 2. Decompressed size is at most `MAX_GRAPH_DECOMPRESSED_SIZE` (zip-bomb guard).
/// 3. CRC32 matches.
/// 4. The payload decodes as bitcode.
/// 5. The magic equals `GRAPH_MAGIC`.
/// 6. `version <= MANIFEST_VERSION`; older versions are accepted.
///
/// # Errors
///
/// * `SizeLimitExceeded` when either size limit is exceeded.
/// * `Corrupted` when the file is shorter than the checksum trailer or the CRC32 differs.
/// * `Serialization` when zstd decompression fails for any other reason.
/// * `InvalidMagic` / `UnsupportedVersion` for a foreign or newer file.
/// * I/O and bitcode decoding errors are propagated.
pub fn load_graph_index(path: &Path) -> Result<GraphIndexData> {
    use std::io::Read;

    // First four bytes of every zstd frame.
    const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
    let max_file_size = super::MAX_GRAPH_INDEX_FILE_SIZE;

    // Check the size through the opened handle, then read through `take`. A file that
    // is replaced or grows after the check still cannot make us buffer more than
    // `max_file_size + 1` bytes; the extra byte is how overflow is detected.
    let file = fs::File::open(path)?;
    let file_len = file.metadata()?.len();
    if file_len > max_file_size {
        return Err(IndexPersistenceError::SizeLimitExceeded {
            message: format!(
                "Graph index file size {} exceeds limit {}",
                file_len, max_file_size
            ),
        });
    }
    let mut bytes = Vec::with_capacity(usize::try_from(file_len).unwrap_or(0));
    file.take(max_file_size.saturating_add(1))
        .read_to_end(&mut bytes)?;
    if bytes.len() as u64 > max_file_size {
        return Err(IndexPersistenceError::SizeLimitExceeded {
            message: format!(
                "Graph index file grew past limit {} while being read",
                max_file_size
            ),
        });
    }

    if bytes.len() < 4 {
        return Err(IndexPersistenceError::Corrupted {
            path: path.to_path_buf(),
            source: "File too small to contain CRC32 checksum".into(),
        });
    }
    let (data_slice, checksum_bytes) = bytes.split_at(bytes.len() - 4);
    let stored_checksum = u32::from_le_bytes(checksum_bytes.try_into().map_err(|_| {
        IndexPersistenceError::Corrupted {
            path: path.to_path_buf(),
            source: "Invalid CRC32 checksum format".into(),
        }
    })?);

    // The CRC covers the uncompressed bitcode, so a compressed payload is first
    // decompressed (bounded against zip bombs) and then verified.
    let decompressed_data;
    let data_to_verify = if data_slice.starts_with(&ZSTD_MAGIC) {
        decompressed_data = decompress_with_limit(data_slice, super::MAX_GRAPH_DECOMPRESSED_SIZE)
            .map_err(map_decompress_error)?;
        &decompressed_data[..]
    } else {
        data_slice
    };

    let mut hasher = Hasher::new();
    hasher.update(data_to_verify);
    let computed_checksum = hasher.finalize();
    if computed_checksum != stored_checksum {
        return Err(IndexPersistenceError::Corrupted {
            path: path.to_path_buf(),
            source: format!(
                "CRC32 checksum mismatch: expected {}, got {}",
                stored_checksum, computed_checksum
            )
            .into(),
        });
    }

    let data: GraphIndexData = bitcode::decode(data_to_verify)?;
    if data.magic != GRAPH_MAGIC {
        return Err(IndexPersistenceError::InvalidMagic {
            path: path.to_path_buf(),
            expected: GRAPH_MAGIC,
            got: data.magic,
        });
    }
    if data.version > MANIFEST_VERSION {
        return Err(IndexPersistenceError::UnsupportedVersion {
            found: data.version,
            supported: MANIFEST_VERSION,
        });
    }
    Ok(data)
}
/// Creates an empty graph index stamped with the current [`GRAPH_MAGIC`] and
/// [`MANIFEST_VERSION`].
///
/// Both counts are zero. Every node, edge, and adjacency (`outgoing_*` / `incoming_*`)
/// vector starts empty.
pub fn new_graph_index_data() -> GraphIndexData {
    GraphIndexData {
        // Header.
        magic: GRAPH_MAGIC,
        version: MANIFEST_VERSION,
        // Records.
        node_count: 0,
        edge_count: 0,
        nodes: Vec::default(),
        edges: Vec::default(),
        // Outgoing adjacency.
        outgoing_node_ids: Vec::default(),
        outgoing_offsets: Vec::default(),
        outgoing_neighbors: Vec::default(),
        // Incoming adjacency.
        incoming_node_ids: Vec::default(),
        incoming_offsets: Vec::default(),
        incoming_neighbors: Vec::default(),
    }
}
/// Serializes `data` with bitcode, compresses it with zstd at `compression_level`, and
/// atomically writes the result to `path`.
///
/// On-disk layout: `[zstd frame][CRC32 of the *uncompressed* bitcode, 4 bytes LE]`.
/// The checksum is taken before compression, matching how the loaders in this file
/// verify the payload after decompressing it.
///
/// # Errors
///
/// Returns `Serialization` if zstd compression fails. Otherwise propagates `atomic_write` errors.
pub fn save_graph_index_compressed(
    data: &GraphIndexData,
    path: &Path,
    compression_level: i32,
) -> Result<()> {
    let raw = bitcode::encode(data);
    let mut crc = Hasher::new();
    crc.update(&raw);
    let checksum = crc.finalize().to_le_bytes();

    let mut file_bytes = zstd::encode_all(raw.as_slice(), compression_level).map_err(|err| {
        IndexPersistenceError::Serialization(format!("zstd compression failed: {}", err))
    })?;
    file_bytes.extend_from_slice(&checksum);
    super::atomic_write(path, &file_bytes)?;
    Ok(())
}
/// Loads a graph index by memory-mapping the file instead of reading it into a buffer.
///
/// It accepts the same layout as [`load_graph_index`] and performs the same CRC32,
/// magic, and version checks. The on-disk size is bounded by `MAX_MMAP_FILE_SIZE`
/// instead. Compressed payloads are still decompressed into an owned buffer, capped at
/// `MAX_GRAPH_DECOMPRESSED_SIZE`.
///
/// # Errors
///
/// Returns the same error variants as `load_graph_index`, plus any error from creating
/// the mapping.
pub fn load_graph_index_mmap(path: &Path) -> Result<GraphIndexData> {
    use memmap2::Mmap;
    use std::fs::File;

    // First four bytes of every zstd frame.
    const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];

    let file = File::open(path)?;
    let file_len = file.metadata()?.len();
    if file_len > super::MAX_MMAP_FILE_SIZE {
        return Err(IndexPersistenceError::SizeLimitExceeded {
            message: format!(
                "Graph index file size {} exceeds sanity limit {}",
                file_len,
                super::MAX_MMAP_FILE_SIZE
            ),
        });
    }

    // SAFETY: the mapping is only read, and it is dropped before this function returns
    // (the decoded `GraphIndexData` is owned). Soundness still requires that no other
    // process truncates or rewrites the file in place while it is mapped. Saves in this
    // module go through `atomic_write`, which presumably replaces the file rather than
    // editing it in place — TODO confirm.
    let mmap = unsafe { Mmap::map(&file)? };

    let Some(payload_len) = mmap.len().checked_sub(4) else {
        return Err(IndexPersistenceError::Corrupted {
            path: path.to_path_buf(),
            source: "File too small to contain CRC32 checksum".into(),
        });
    };
    let (data_slice, checksum_bytes) = mmap.split_at(payload_len);
    let checksum_array: [u8; 4] =
        checksum_bytes
            .try_into()
            .map_err(|_| IndexPersistenceError::Corrupted {
                path: path.to_path_buf(),
                source: "Invalid CRC32 checksum format".into(),
            })?;
    let stored_checksum = u32::from_le_bytes(checksum_array);

    // The CRC covers the uncompressed bitcode, so decompress first (bounded) if needed.
    let decompressed_data;
    let payload = if data_slice.starts_with(&ZSTD_MAGIC) {
        decompressed_data = decompress_with_limit(data_slice, super::MAX_GRAPH_DECOMPRESSED_SIZE)
            .map_err(map_decompress_error)?;
        &decompressed_data[..]
    } else {
        data_slice
    };

    let computed_checksum = {
        let mut hasher = Hasher::new();
        hasher.update(payload);
        hasher.finalize()
    };
    if computed_checksum != stored_checksum {
        return Err(IndexPersistenceError::Corrupted {
            path: path.to_path_buf(),
            source: format!(
                "CRC32 checksum mismatch: expected {}, got {}",
                stored_checksum, computed_checksum
            )
            .into(),
        });
    }

    let data: GraphIndexData = bitcode::decode(payload)?;
    if data.magic != GRAPH_MAGIC {
        return Err(IndexPersistenceError::InvalidMagic {
            path: path.to_path_buf(),
            expected: GRAPH_MAGIC,
            got: data.magic,
        });
    }
    if data.version > MANIFEST_VERSION {
        return Err(IndexPersistenceError::UnsupportedVersion {
            found: data.version,
            supported: MANIFEST_VERSION,
        });
    }
    Ok(data)
}
pub fn save_graph_index_delta(
base: &GraphIndexData,
modified: &GraphIndexData,
path: &Path,
compression_level: i32,
) -> Result<()> {
let base_nodes: std::collections::HashMap<u64, &super::formats::PersistedNode> =
base.nodes.iter().map(|n| (n.id, n)).collect();
let modified_nodes: std::collections::HashMap<u64, &super::formats::PersistedNode> =
modified.nodes.iter().map(|n| (n.id, n)).collect();
let base_edges: std::collections::HashMap<u64, &super::formats::PersistedEdge> =
base.edges.iter().map(|e| (e.id, e)).collect();
let modified_edges: std::collections::HashMap<u64, &super::formats::PersistedEdge> =
modified.edges.iter().map(|e| (e.id, e)).collect();
let added_nodes: Vec<_> = modified
.nodes
.iter()
.filter(|node| !base_nodes.contains_key(&node.id))
.cloned()
.collect();
let modified_nodes_vec: Vec<_> = modified
.nodes
.iter()
.filter(|node| {
base_nodes
.get(&node.id)
.is_some_and(|base_node| *base_node != *node)
})
.cloned()
.collect();
let deleted_node_ids: Vec<_> = base
.nodes
.iter()
.filter(|node| !modified_nodes.contains_key(&node.id))
.map(|node| node.id)
.collect();
let added_edges: Vec<_> = modified
.edges
.iter()
.filter(|edge| !base_edges.contains_key(&edge.id))
.cloned()
.collect();
let modified_edges_vec: Vec<_> = modified
.edges
.iter()
.filter(|edge| {
base_edges
.get(&edge.id)
.is_some_and(|base_edge| *base_edge != *edge)
})
.cloned()
.collect();
let deleted_edge_ids: Vec<_> = base
.edges
.iter()
.filter(|edge| !modified_edges.contains_key(&edge.id))
.map(|edge| edge.id)
.collect();
let delta = GraphIndexDelta {
magic: DELTA_MAGIC,
version: MANIFEST_VERSION,
added_nodes,
modified_nodes: modified_nodes_vec,
deleted_node_ids,
added_edges,
modified_edges: modified_edges_vec,
deleted_edge_ids,
new_node_count: modified.node_count,
new_edge_count: modified.edge_count,
};
let encoded = bitcode::encode(&delta);
let mut hasher = Hasher::new();
hasher.update(&encoded);
let checksum = hasher.finalize();
let compressed = zstd::encode_all(&encoded[..], compression_level).map_err(|e| {
IndexPersistenceError::Serialization(format!("zstd compression failed: {}", e))
})?;
let mut data_with_checksum = compressed;
data_with_checksum.extend_from_slice(&checksum.to_le_bytes());
super::atomic_write(path, &data_with_checksum)?;
Ok(())
}
/// Loads the base graph index at `base_path` and applies the delta at `delta_path`.
///
/// The delta file is validated like a full index: size limit, bounded zstd
/// decompression, CRC32 over the uncompressed payload, `DELTA_MAGIC`, and
/// `version <= MANIFEST_VERSION`. It is then applied in this order:
/// 1. Deletions.
/// 2. In-place replacement of modified records.
/// 3. Appending of added records.
///
/// If `limit` is `Some(n)`, the node list and the edge list are each truncated to at
/// most `n` entries. `node_count` and `edge_count` are recomputed from the final list
/// lengths; the delta's stored counts are not used.
///
/// When the delta does not match the base it was computed from:
/// * Modified records whose id is absent from the base are ignored.
/// * Deleted ids that are absent from the base are ignored.
/// * Added records are appended without a duplicate-id check.
///
/// NOTE(review): the adjacency fields (`outgoing_*` / `incoming_*`) are carried over
/// from the base unchanged because the delta stores no adjacency data. They may
/// therefore be stale after edges change — confirm callers rebuild adjacency.
///
/// # Errors
///
/// Any error from [`load_graph_index`] for the base. For the delta: `SizeLimitExceeded`,
/// `Corrupted`, `Serialization`, `InvalidMagic`, or `UnsupportedVersion`, plus propagated
/// I/O and bitcode decoding errors.
pub fn load_graph_index_with_delta(
    base_path: &Path,
    delta_path: &Path,
    limit: Option<usize>,
) -> Result<GraphIndexData> {
    use std::collections::{HashMap, HashSet};
    use std::io::Read;

    // First four bytes of every zstd frame.
    const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
    let max_file_size = super::MAX_GRAPH_INDEX_FILE_SIZE;

    let mut base = load_graph_index(base_path)?;

    // Check the size through the opened handle, then read through `take`. A delta file
    // that is replaced or grows after the check still cannot make us buffer more than
    // `max_file_size + 1` bytes; the extra byte is how overflow is detected.
    let file = fs::File::open(delta_path)?;
    let file_len = file.metadata()?.len();
    if file_len > max_file_size {
        return Err(IndexPersistenceError::SizeLimitExceeded {
            message: format!(
                "Graph index delta file size {} exceeds limit {}",
                file_len, max_file_size
            ),
        });
    }
    let mut bytes = Vec::with_capacity(usize::try_from(file_len).unwrap_or(0));
    file.take(max_file_size.saturating_add(1))
        .read_to_end(&mut bytes)?;
    if bytes.len() as u64 > max_file_size {
        return Err(IndexPersistenceError::SizeLimitExceeded {
            message: format!(
                "Graph index delta file grew past limit {} while being read",
                max_file_size
            ),
        });
    }

    if bytes.len() < 4 {
        return Err(IndexPersistenceError::Corrupted {
            path: delta_path.to_path_buf(),
            source: "File too small to contain CRC32 checksum".into(),
        });
    }
    let (data_slice, checksum_bytes) = bytes.split_at(bytes.len() - 4);
    let stored_checksum = u32::from_le_bytes(checksum_bytes.try_into().map_err(|_| {
        IndexPersistenceError::Corrupted {
            path: delta_path.to_path_buf(),
            source: "Invalid CRC32 checksum format".into(),
        }
    })?);

    // The CRC covers the uncompressed bitcode, so decompress first (bounded) if needed.
    let decompressed_data;
    let data_to_verify = if data_slice.starts_with(&ZSTD_MAGIC) {
        decompressed_data = decompress_with_limit(data_slice, super::MAX_GRAPH_DECOMPRESSED_SIZE)
            .map_err(map_decompress_error)?;
        &decompressed_data[..]
    } else {
        data_slice
    };

    let mut hasher = Hasher::new();
    hasher.update(data_to_verify);
    let computed_checksum = hasher.finalize();
    if computed_checksum != stored_checksum {
        return Err(IndexPersistenceError::Corrupted {
            path: delta_path.to_path_buf(),
            source: format!(
                "CRC32 checksum mismatch: expected {}, got {}",
                stored_checksum, computed_checksum
            )
            .into(),
        });
    }

    let delta: GraphIndexDelta = bitcode::decode(data_to_verify)?;
    if delta.magic != DELTA_MAGIC {
        return Err(IndexPersistenceError::InvalidMagic {
            path: delta_path.to_path_buf(),
            expected: DELTA_MAGIC,
            got: delta.magic,
        });
    }
    if delta.version > MANIFEST_VERSION {
        return Err(IndexPersistenceError::UnsupportedVersion {
            found: delta.version,
            supported: MANIFEST_VERSION,
        });
    }

    // 1. Deletions.
    let deleted_nodes: HashSet<_> = delta.deleted_node_ids.into_iter().collect();
    let deleted_edges: HashSet<_> = delta.deleted_edge_ids.into_iter().collect();
    base.nodes.retain(|node| !deleted_nodes.contains(&node.id));
    base.edges.retain(|edge| !deleted_edges.contains(&edge.id));

    // 2. In-place replacement. Updates whose id has no match in the base are dropped.
    let mut node_updates: HashMap<_, _> = delta
        .modified_nodes
        .into_iter()
        .map(|node| (node.id, node))
        .collect();
    let mut edge_updates: HashMap<_, _> = delta
        .modified_edges
        .into_iter()
        .map(|edge| (edge.id, edge))
        .collect();
    for node in &mut base.nodes {
        if let Some(updated) = node_updates.remove(&node.id) {
            *node = updated;
        }
    }
    for edge in &mut base.edges {
        if let Some(updated) = edge_updates.remove(&edge.id) {
            *edge = updated;
        }
    }

    // 3. Additions.
    base.nodes.extend(delta.added_nodes);
    base.edges.extend(delta.added_edges);

    if let Some(max_items) = limit {
        base.nodes.truncate(max_items);
        base.edges.truncate(max_items);
    }
    base.node_count = base.nodes.len() as u64;
    base.edge_count = base.edges.len() as u64;
    Ok(base)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::index_persistence::formats::PersistedNode;
    use tempfile::tempdir;

    /// Every persistable scalar, string, bytes, and vector variant survives persist → restore.
    #[test]
    fn test_property_value_round_trip() {
        let samples = [
            PropertyValue::Null,
            PropertyValue::Bool(true),
            PropertyValue::Int(42),
            PropertyValue::Float(2.71),
            PropertyValue::String(Arc::from("test")),
            PropertyValue::Bytes(Arc::from(&[1u8, 2, 3][..])),
            PropertyValue::Vector(Arc::from(&[1.0f32, 2.0, 3.0][..])),
        ];
        for sample in &samples {
            let restored = persist_property_value(sample)
                .and_then(|persisted| restore_property_value(&persisted))
                .unwrap();
            assert_eq!(format!("{:?}", sample), format!("{:?}", restored));
        }
    }

    /// An uncompressed index with two nodes can be saved and loaded back.
    #[test]
    fn test_graph_index_round_trip() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("graph.idx");

        let mut data = new_graph_index_data();
        data.node_count = 2;
        for (id, label, version_id) in [(1, "Person", 1), (2, "Document", 2)] {
            data.nodes.push(PersistedNode {
                id,
                label_idx: GLOBAL_INTERNER.intern(label).unwrap().as_u32(),
                version_id,
                properties: PersistedPropertyMap { entries: vec![] },
            });
        }

        save_graph_index(&data, &path).unwrap();
        let loaded = load_graph_index(&path).unwrap();
        assert_eq!(loaded.node_count, 2);
        assert_eq!(loaded.nodes.len(), 2);
    }

    /// Array values are rejected instead of being silently dropped.
    #[test]
    fn test_array_property_errors() {
        let items = vec![
            PropertyValue::Int(1),
            PropertyValue::Int(2),
            PropertyValue::Int(3),
        ];
        let array_value = PropertyValue::Array(Arc::from(items));

        let message = persist_property_value(&array_value)
            .unwrap_err()
            .to_string();
        assert!(message.contains("Array properties are not yet supported"));
    }

    /// An interned-string id the interner has never issued is reported as an error.
    #[test]
    fn test_missing_string_interned_id_errors() {
        let message = restore_property_value(&PersistedPropertyValue::String(999999))
            .unwrap_err()
            .to_string();
        assert!(message.contains("Failed to resolve interned string"));
    }

    /// Vectors longer than `MAX_VECTOR_DIMENSIONS` are refused on restore.
    #[test]
    fn test_vector_size_limit_dos_protection() {
        let too_many = super::super::MAX_VECTOR_DIMENSIONS + 1;
        let persisted = PersistedPropertyValue::Vector(vec![0.0f32; too_many]);

        let message = restore_property_value(&persisted).unwrap_err().to_string();
        assert!(message.contains("Size limit exceeded"));
        assert!(message.contains("Vector dimension"));
    }

    /// A vector of exactly `MAX_VECTOR_DIMENSIONS` components is still accepted.
    #[test]
    fn test_vector_at_size_limit_allowed() {
        let max_dims = super::super::MAX_VECTOR_DIMENSIONS;
        let persisted = PersistedPropertyValue::Vector(vec![1.0f32; max_dims]);

        match restore_property_value(&persisted) {
            Ok(PropertyValue::Vector(v)) => assert_eq!(v.len(), max_dims),
            Ok(_) => panic!("Expected vector property"),
            Err(e) => panic!("vector at the size limit was rejected: {:?}", e),
        }
    }
}
#[cfg(test)]
mod zstd_bomb_tests {
    use super::*;
    use std::io::Write;
    use tempfile::tempdir;
    use zstd::stream::write::Encoder;

    /// Size of one zero-filled chunk fed to the encoder (1 MiB).
    const CHUNK_LEN: usize = 1024 * 1024;

    /// Builds a zstd frame that expands to `uncompressed_mb` MiB of zeros.
    ///
    /// The frame is followed by an all-zero 4-byte trailer so the bytes have the same
    /// layout as an on-disk index (payload + CRC32).
    fn create_zstd_bomb(uncompressed_mb: usize) -> Vec<u8> {
        let zeros = vec![0u8; CHUNK_LEN];
        let mut encoder = Encoder::new(Vec::new(), 1).unwrap();
        for _ in 0..uncompressed_mb {
            encoder.write_all(&zeros).unwrap();
        }
        let mut bomb = encoder.finish().unwrap();
        bomb.extend_from_slice(&[0; 4]);
        bomb
    }

    /// Asserts that a loader rejected its input with `SizeLimitExceeded`.
    fn assert_size_limit_exceeded(result: Result<GraphIndexData>) {
        assert!(
            matches!(result, Err(IndexPersistenceError::SizeLimitExceeded { .. })),
            "Expected SizeLimitExceeded, got: {:?}",
            result,
        );
    }

    #[test]
    fn test_zstd_bomb_blocked_by_load_graph_index() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("bomb.idx");
        let bomb = create_zstd_bomb(200);
        assert!(bomb.len() < 500_000, "Bomb compressed size: {}", bomb.len());
        std::fs::write(&path, &bomb).unwrap();

        assert_size_limit_exceeded(load_graph_index(&path));
    }

    #[test]
    fn test_zstd_bomb_blocked_by_load_graph_index_mmap() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("bomb_mmap.idx");
        std::fs::write(&path, create_zstd_bomb(200)).unwrap();

        assert_size_limit_exceeded(load_graph_index_mmap(&path));
    }

    #[test]
    fn test_zstd_bomb_blocked_by_load_graph_index_with_delta() {
        let dir = tempdir().unwrap();
        let base_path = dir.path().join("base.idx");
        save_graph_index(&new_graph_index_data(), &base_path).unwrap();

        let delta_path = dir.path().join("bomb_delta.idx");
        std::fs::write(&delta_path, create_zstd_bomb(200)).unwrap();

        assert_size_limit_exceeded(load_graph_index_with_delta(&base_path, &delta_path, None));
    }

    #[test]
    fn test_legitimate_compressed_index_loads_fine() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("compressed.idx");
        save_graph_index_compressed(&new_graph_index_data(), &path, 3).unwrap();

        let loaded = load_graph_index(&path).unwrap();
        assert_eq!(loaded.node_count, 0);
    }
}
// Tests for the graph delta save/load functions live in `graph_delta_tests.rs`;
// the `#[path]` attribute below points the module at that file.
#[cfg(test)]
#[path = "graph_delta_tests.rs"]
mod delta_tests;