use super::error::{PersistentARTrieError, Result};
use super::nodes::node48::NO_CHILD;
use super::nodes::{
CompressedPrefix, Node, Node16, Node256, Node4, Node48, NodeHeader, MAX_PREFIX_LEN,
};
use super::swizzled_ptr::{NodeType, SwizzledPtr};
use std::io::{Read, Write};
use super::arena_manager::ArenaSlot;
use super::relative_encoding::{
encode_children, encode_sequential_siblings, try_decode_children,
try_decode_sequential_siblings, RelativeEncodingError,
};
fn io_err(e: std::io::Error) -> PersistentARTrieError {
PersistentARTrieError::io_error("serialization", "<buffer>", e)
}
/// Four-byte marker written at the start of every serialized node: `A`, `R`, `T`, NUL.
pub const NODE_MAGIC: [u8; 4] = [b'A', b'R', b'T', 0];
/// Version byte written by the original (fixed-layout) node format.
pub const FORMAT_VERSION: u8 = 1;
/// Version byte written by the v2 format (relative / sequential child encoding).
pub const FORMAT_VERSION_V2: u8 = 2;
/// Size in bytes of the on-disk node header:
/// magic (4) + version (1) + node type (1) + flags (1) + encoding flags (1)
/// + child count (2) + prefix length (1) + padding (1) + data size (4).
pub const SERIALIZED_HEADER_SIZE: usize = 4 + 1 + 1 + 1 + 1 + 2 + 1 + 1 + 4;
/// Bit flags stored in the `encoding_flags` byte of a serialized node header.
pub mod encoding_flags {
    /// Bit 7: child references are stored as offsets relative to the parent slot.
    pub const RELATIVE_OFFSETS: u8 = 1 << 7;
    /// Bit 6: children are stored using the sequential-sibling encoding.
    pub const SEQUENTIAL_SIBLINGS: u8 = 1 << 6;
    /// Bit 5: a length-prefixed value follows the node payload.
    pub const HAS_VALUE: u8 = 1 << 5;
}
/// Discriminant values stored in the `node_type` byte of a serialized node header.
pub mod node_types {
    /// Node holding up to 4 children.
    pub const NODE4: u8 = 4;
    /// Node holding up to 16 children.
    pub const NODE16: u8 = 16;
    /// Node holding up to 48 children.
    pub const NODE48: u8 = 48;
    /// Node holding up to 256 children; encoded as 0 (256 does not fit in a `u8`).
    pub const NODE256: u8 = 0;
}
/// Fixed-size header that precedes every serialized node.
///
/// The byte layout produced by `to_bytes` / consumed by `from_bytes` is
/// (multi-byte integers little-endian):
/// bytes 0..4 magic, 4 version, 5 node type, 6 flags, 7 encoding flags,
/// 8..10 child count, 10 prefix length, 11 padding, 12..16 data size.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SerializedNodeHeader {
/// Format marker; `validate` requires it to equal `NODE_MAGIC`.
pub magic: [u8; 4],
/// Format version; `validate` rejects values above `FORMAT_VERSION_V2`.
pub version: u8,
/// Node kind discriminant; one of the `node_types` constants.
pub node_type: u8,
/// Node flags copied verbatim from the in-memory `NodeHeader`.
pub flags: u8,
/// Bit set of `encoding_flags` values; written as 0 by the v1 constructor.
pub encoding_flags: u8,
/// Child count copied from the in-memory `NodeHeader`.
pub num_children: u16,
/// Compressed-prefix length; `validate` rejects values above `MAX_PREFIX_LEN`.
pub prefix_len: u8,
/// Padding byte; the constructors in this file write 0.
pub _padding: u8,
/// Payload size in bytes as supplied by the serializer (excludes this header).
pub data_size: u32,
}
impl SerializedNodeHeader {
    /// Builds a format-version-1 header from an in-memory node header.
    ///
    /// `encoding_flags` and the padding byte are set to 0; `data_size` is
    /// stored exactly as given.
    pub fn from_node_header(header: &NodeHeader, data_size: u32) -> Self {
        Self {
            magic: NODE_MAGIC,
            version: FORMAT_VERSION,
            node_type: header.node_type,
            flags: header.flags,
            encoding_flags: 0,
            num_children: header.num_children,
            prefix_len: header.prefix_len,
            _padding: 0,
            data_size,
        }
    }

    /// Builds a format-version-2 header from an in-memory node header.
    ///
    /// Identical to [`Self::from_node_header`] except for the version byte
    /// and the caller-supplied `encoding_flags`.
    pub fn from_node_header_v2(header: &NodeHeader, data_size: u32, encoding_flags: u8) -> Self {
        Self {
            version: FORMAT_VERSION_V2,
            encoding_flags,
            ..Self::from_node_header(header, data_size)
        }
    }

    /// Returns `true` when this is a v2 (or later) header with the
    /// `RELATIVE_OFFSETS` bit set.
    pub fn uses_relative_offsets(&self) -> bool {
        self.version >= FORMAT_VERSION_V2
            && (self.encoding_flags & encoding_flags::RELATIVE_OFFSETS) != 0
    }

    /// Returns `true` when this is a v2 (or later) header with the
    /// `SEQUENTIAL_SIBLINGS` bit set.
    pub fn uses_sequential_siblings(&self) -> bool {
        self.version >= FORMAT_VERSION_V2
            && (self.encoding_flags & encoding_flags::SEQUENTIAL_SIBLINGS) != 0
    }

    /// Converts back to an in-memory `NodeHeader`.
    ///
    /// Padding fields and the in-memory `version` field are reset to 0; the
    /// serialized format version is not carried over.
    pub fn to_node_header(&self) -> NodeHeader {
        NodeHeader {
            node_type: self.node_type,
            prefix_len: self.prefix_len,
            flags: self.flags,
            _padding: 0,
            num_children: self.num_children,
            _padding2: [0; 2],
            version: 0,
        }
    }

    /// Checks that the header describes a node this code can decode.
    ///
    /// # Errors
    ///
    /// * `InvalidMagic` when `magic` differs from `NODE_MAGIC`.
    /// * `UnsupportedVersion` when `version` exceeds `FORMAT_VERSION_V2`.
    /// * A `corrupted` error when the node type is unknown, the prefix length
    ///   exceeds `MAX_PREFIX_LEN`, or the child count exceeds what the node
    ///   type can hold.
    pub fn validate(&self) -> Result<()> {
        // The error type carries magics as u64; widen the 4 bytes little-endian
        // with the upper 4 bytes zero.
        let widen = |magic: [u8; 4]| u64::from(u32::from_le_bytes(magic));
        if self.magic != NODE_MAGIC {
            return Err(PersistentARTrieError::InvalidMagic {
                expected: widen(NODE_MAGIC),
                found: widen(self.magic),
            });
        }
        if self.version > FORMAT_VERSION_V2 {
            return Err(PersistentARTrieError::UnsupportedVersion {
                max_supported: FORMAT_VERSION_V2 as u32,
                found: self.version as u32,
            });
        }
        let capacity: u16 = match self.node_type {
            node_types::NODE4 => 4,
            node_types::NODE16 => 16,
            node_types::NODE48 => 48,
            node_types::NODE256 => 256,
            _ => {
                return Err(PersistentARTrieError::corrupted(format!(
                    "invalid node type: {}",
                    self.node_type
                )));
            }
        };
        if self.prefix_len as usize > MAX_PREFIX_LEN {
            return Err(PersistentARTrieError::corrupted(format!(
                "prefix length {} exceeds maximum {}",
                self.prefix_len, MAX_PREFIX_LEN
            )));
        }
        // A count larger than the node's child array can only come from a
        // corrupted header; rejecting it here keeps later count-driven
        // indexing from running past the arrays.
        if self.num_children > capacity {
            return Err(PersistentARTrieError::corrupted(format!(
                "child count {} exceeds capacity {} for node type {}",
                self.num_children, capacity, self.node_type
            )));
        }
        Ok(())
    }

    /// Encodes the header into its fixed 16-byte wire form
    /// (multi-byte fields little-endian).
    pub fn to_bytes(&self) -> [u8; SERIALIZED_HEADER_SIZE] {
        let mut bytes = [0u8; SERIALIZED_HEADER_SIZE];
        bytes[0..4].copy_from_slice(&self.magic);
        bytes[4] = self.version;
        bytes[5] = self.node_type;
        bytes[6] = self.flags;
        bytes[7] = self.encoding_flags;
        bytes[8..10].copy_from_slice(&self.num_children.to_le_bytes());
        bytes[10] = self.prefix_len;
        bytes[11] = self._padding;
        bytes[12..16].copy_from_slice(&self.data_size.to_le_bytes());
        bytes
    }

    /// Decodes a header from its 16-byte wire form.
    ///
    /// No validation is performed; call [`Self::validate`] on the result.
    pub fn from_bytes(bytes: &[u8; SERIALIZED_HEADER_SIZE]) -> Self {
        Self {
            magic: [bytes[0], bytes[1], bytes[2], bytes[3]],
            version: bytes[4],
            node_type: bytes[5],
            flags: bytes[6],
            encoding_flags: bytes[7],
            num_children: u16::from_le_bytes([bytes[8], bytes[9]]),
            prefix_len: bytes[10],
            _padding: bytes[11],
            data_size: u32::from_le_bytes([bytes[12], bytes[13], bytes[14], bytes[15]]),
        }
    }
}
/// Returns the number of bytes `serialize_node` reports for `node`:
/// the fixed header plus the prefix area plus the node-specific payload.
pub fn serialized_size(node: &Node) -> usize {
    let payload = prefix_size(node) + node_data_size(node);
    SERIALIZED_HEADER_SIZE + payload
}
/// Bytes occupied by the prefix area in the v1 format.
///
/// The prefix is stored as a full `MAX_PREFIX_LEN`-byte array whenever the
/// node has a non-empty prefix, and omitted entirely otherwise.
fn prefix_size(node: &Node) -> usize {
    match node.header().prefix_len {
        0 => 0,
        _ => MAX_PREFIX_LEN,
    }
}
/// Bytes occupied by the node-specific payload in the v1 format
/// (keys / index / bitmap followed by 8-byte child pointers).
///
/// Node4, Node16 and Node48 always store their full child arrays.
/// Node256 stores a 32-byte presence bitmap followed by one pointer per
/// non-null child; the count is taken from the child array itself, the
/// same criterion the writer uses, so the reported size always matches the
/// bytes written even if the header's child count is out of sync.
fn node_data_size(node: &Node) -> usize {
    const CHILD_PTR_LEN: usize = 8;
    match node {
        Node::N4(_) => 4 + 4 * CHILD_PTR_LEN,
        Node::N16(_) => 16 + 16 * CHILD_PTR_LEN,
        Node::N48(_) => 256 + 48 * CHILD_PTR_LEN,
        Node::N256(n) => {
            let present = n.children.iter().filter(|child| !child.is_null()).count();
            32 + present * CHILD_PTR_LEN
        }
    }
}
/// Writes `node` to `writer` in the v1 format: header, optional prefix
/// array, then the node-type-specific payload.
///
/// Returns the header size plus the computed payload size.
///
/// # Errors
///
/// Any write failure is converted with `io_err` and returned.
pub fn serialize_node<W: Write>(node: &Node, writer: &mut W) -> Result<usize> {
    let node_header = node.header();
    let payload_len = prefix_size(node) + node_data_size(node);

    let wire_header = SerializedNodeHeader::from_node_header(node_header, payload_len as u32);
    writer.write_all(&wire_header.to_bytes()).map_err(io_err)?;

    // The prefix array is only present when the node actually has a prefix.
    let has_prefix = node_header.prefix_len > 0;
    if has_prefix {
        writer.write_all(&node.prefix().bytes).map_err(io_err)?;
    }

    let body = match node {
        Node::N4(inner) => serialize_node4(inner, writer),
        Node::N16(inner) => serialize_node16(inner, writer),
        Node::N48(inner) => serialize_node48(inner, writer),
        Node::N256(inner) => serialize_node256(inner, writer),
    };
    body?;

    Ok(SERIALIZED_HEADER_SIZE + payload_len)
}
/// Writes a Node4 payload: the key array followed by every child pointer
/// as a little-endian raw `u64`.
fn serialize_node4<W: Write>(node: &Node4, writer: &mut W) -> Result<()> {
    writer.write_all(&node.keys).map_err(io_err)?;
    node.children.iter().try_for_each(|ptr| {
        let encoded = ptr.to_raw().to_le_bytes();
        writer.write_all(&encoded).map_err(io_err)
    })
}
/// Writes a Node16 payload: the key array followed by every child pointer
/// as a little-endian raw `u64`.
fn serialize_node16<W: Write>(node: &Node16, writer: &mut W) -> Result<()> {
    writer.write_all(&node.keys).map_err(io_err)?;
    node.children.iter().try_for_each(|ptr| {
        let encoded = ptr.to_raw().to_le_bytes();
        writer.write_all(&encoded).map_err(io_err)
    })
}
/// Writes a Node48 payload: the key-to-slot index array followed by every
/// child pointer as a little-endian raw `u64`.
fn serialize_node48<W: Write>(node: &Node48, writer: &mut W) -> Result<()> {
    writer.write_all(&node.index).map_err(io_err)?;
    node.children.iter().try_for_each(|ptr| {
        let encoded = ptr.to_raw().to_le_bytes();
        writer.write_all(&encoded).map_err(io_err)
    })
}
/// Writes a Node256 payload: a presence bitmap (four little-endian `u64`
/// words, bit `i` set when child `i` is non-null) followed by the raw
/// pointers of the non-null children in ascending key order.
fn serialize_node256<W: Write>(node: &Node256, writer: &mut W) -> Result<()> {
    // First pass: record which child positions are occupied.
    let mut presence = [0u64; 4];
    for (key, _) in node
        .children
        .iter()
        .enumerate()
        .filter(|(_, ptr)| !ptr.is_null())
    {
        presence[key / 64] |= 1u64 << (key % 64);
    }
    presence
        .iter()
        .try_for_each(|word| writer.write_all(&word.to_le_bytes()).map_err(io_err))?;

    // Second pass: emit only the occupied children.
    node.children
        .iter()
        .filter(|ptr| !ptr.is_null())
        .try_for_each(|ptr| writer.write_all(&ptr.to_raw().to_le_bytes()).map_err(io_err))
}
/// Reads one node in the v1 layout from `reader`.
///
/// The header is read and validated first, then the prefix array (only when
/// the header's prefix length is non-zero), then the payload for the node
/// type named in the header.
///
/// # Errors
///
/// Read failures are converted with `io_err`; header validation errors are
/// propagated; an unrecognised node type yields a `corrupted` error.
pub fn deserialize_node<R: Read>(reader: &mut R) -> Result<Node> {
    let mut raw_header = [0u8; SERIALIZED_HEADER_SIZE];
    reader.read_exact(&mut raw_header).map_err(io_err)?;
    let header = SerializedNodeHeader::from_bytes(&raw_header);
    header.validate()?;

    // Start from the empty prefix and overwrite its bytes only when present.
    let mut prefix = CompressedPrefix::empty();
    if header.prefix_len > 0 {
        let mut stored = [0u8; MAX_PREFIX_LEN];
        reader.read_exact(&mut stored).map_err(io_err)?;
        prefix = CompressedPrefix { bytes: stored };
    }

    let kind = header.node_type;
    if kind == node_types::NODE4 {
        deserialize_node4(reader, &header, prefix)
    } else if kind == node_types::NODE16 {
        deserialize_node16(reader, &header, prefix)
    } else if kind == node_types::NODE48 {
        deserialize_node48(reader, &header, prefix)
    } else if kind == node_types::NODE256 {
        deserialize_node256(reader, &header, prefix)
    } else {
        Err(PersistentARTrieError::corrupted(format!(
            "invalid node type: {}",
            header.node_type
        )))
    }
}
/// Reads a v1 Node4 payload: the key array, then one little-endian raw
/// `u64` pointer for every entry of the child array.
fn deserialize_node4<R: Read>(
    reader: &mut R,
    header: &SerializedNodeHeader,
    prefix: CompressedPrefix,
) -> Result<Node> {
    let mut decoded = Node4::new();
    decoded.header = header.to_node_header();
    decoded.prefix = prefix;

    reader.read_exact(&mut decoded.keys).map_err(io_err)?;
    decoded.children.iter_mut().try_for_each(|entry| -> Result<()> {
        let mut word = [0u8; 8];
        reader.read_exact(&mut word).map_err(io_err)?;
        *entry = SwizzledPtr::from_raw(u64::from_le_bytes(word));
        Ok(())
    })?;

    Ok(Node::N4(Box::new(decoded)))
}
/// Reads a v1 Node16 payload: the key array, then one little-endian raw
/// `u64` pointer for every entry of the child array.
fn deserialize_node16<R: Read>(
    reader: &mut R,
    header: &SerializedNodeHeader,
    prefix: CompressedPrefix,
) -> Result<Node> {
    let mut decoded = Node16::new();
    decoded.header = header.to_node_header();
    decoded.prefix = prefix;

    reader.read_exact(&mut decoded.keys).map_err(io_err)?;
    decoded.children.iter_mut().try_for_each(|entry| -> Result<()> {
        let mut word = [0u8; 8];
        reader.read_exact(&mut word).map_err(io_err)?;
        *entry = SwizzledPtr::from_raw(u64::from_le_bytes(word));
        Ok(())
    })?;

    Ok(Node::N16(Box::new(decoded)))
}
/// Reads a v1 Node48 payload: the key-to-slot index array, then one
/// little-endian raw `u64` pointer for every entry of the child array.
fn deserialize_node48<R: Read>(
    reader: &mut R,
    header: &SerializedNodeHeader,
    prefix: CompressedPrefix,
) -> Result<Node> {
    let mut decoded = Node48::new();
    decoded.header = header.to_node_header();
    decoded.prefix = prefix;

    reader.read_exact(&mut decoded.index).map_err(io_err)?;
    decoded.children.iter_mut().try_for_each(|entry| -> Result<()> {
        let mut word = [0u8; 8];
        reader.read_exact(&mut word).map_err(io_err)?;
        *entry = SwizzledPtr::from_raw(u64::from_le_bytes(word));
        Ok(())
    })?;

    Ok(Node::N48(Box::new(decoded)))
}
/// Reads a v1 Node256 payload: four little-endian `u64` presence-bitmap
/// words, then one raw `u64` pointer for each set bit in ascending key
/// order. Positions whose bit is clear keep the value from `Node256::new()`.
fn deserialize_node256<R: Read>(
    reader: &mut R,
    header: &SerializedNodeHeader,
    prefix: CompressedPrefix,
) -> Result<Node> {
    let mut decoded = Node256::new();
    decoded.header = header.to_node_header();
    decoded.prefix = prefix;

    // Every field in this payload is an 8-byte little-endian word.
    let mut next_word = || -> Result<u64> {
        let mut buf = [0u8; 8];
        reader.read_exact(&mut buf).map_err(io_err)?;
        Ok(u64::from_le_bytes(buf))
    };

    let mut presence = [0u64; 4];
    for slot in presence.iter_mut() {
        *slot = next_word()?;
    }

    for key in 0..256usize {
        let occupied = (presence[key / 64] >> (key % 64)) & 1 == 1;
        if occupied {
            decoded.children[key] = SwizzledPtr::from_raw(next_word()?);
        }
    }

    Ok(Node::N256(Box::new(decoded)))
}
/// Serializes `node` in the v1 format into a freshly allocated byte vector.
///
/// The vector is pre-sized with `serialized_size(node)`.
pub fn to_bytes(node: &Node) -> Result<Vec<u8>> {
    let expected_len = serialized_size(node);
    let mut out = Vec::with_capacity(expected_len);
    let _reported_len = serialize_node(node, &mut out)?;
    Ok(out)
}
/// Deserializes a v1-layout node from the start of `bytes`.
///
/// Trailing bytes after the node are not inspected.
pub fn from_bytes(bytes: &[u8]) -> Result<Node> {
    deserialize_node(&mut std::io::Cursor::new(bytes))
}
pub mod v2 {
use super::*;
/// Parameters controlling how `serialize_node_v2` encodes child references.
#[derive(Debug, Clone)]
pub struct SerializationContext {
    /// Arena slot of the node being serialized; child references are encoded against it.
    pub parent_slot: ArenaSlot,
    /// Sets the `RELATIVE_OFFSETS` bit in the header's encoding flags.
    pub use_relative: bool,
    /// Sets the `SEQUENTIAL_SIBLINGS` bit and selects the sequential-sibling encoder.
    pub use_sequential: bool,
    /// Slot of the first child, used only by the sequential-sibling encoder.
    pub first_child_slot: Option<ArenaSlot>,
}

impl SerializationContext {
    /// Context for relative-offset encoding without sequential siblings.
    pub fn new(parent_slot: ArenaSlot) -> Self {
        Self {
            parent_slot,
            use_relative: true,
            use_sequential: false,
            first_child_slot: None,
        }
    }

    /// Context for sequential-sibling encoding starting at `first_child_slot`.
    pub fn sequential(parent_slot: ArenaSlot, first_child_slot: ArenaSlot) -> Self {
        Self {
            use_sequential: true,
            first_child_slot: Some(first_child_slot),
            ..Self::new(parent_slot)
        }
    }

    /// Returns the header `encoding_flags` byte implied by this context.
    pub fn encoding_flags(&self) -> u8 {
        let relative_bit = if self.use_relative {
            encoding_flags::RELATIVE_OFFSETS
        } else {
            0
        };
        let sequential_bit = if self.use_sequential {
            encoding_flags::SEQUENTIAL_SIBLINGS
        } else {
            0
        };
        relative_bit | sequential_bit
    }
}
/// Parameters needed to decode a v2 node.
#[derive(Debug, Clone)]
pub struct DeserializationContext {
    /// Arena slot of the node being decoded; passed to the child-slot decoders.
    pub parent_slot: ArenaSlot,
}

impl DeserializationContext {
    /// Creates a context for decoding the node stored at `parent_slot`.
    pub fn new(parent_slot: ArenaSlot) -> Self {
        DeserializationContext { parent_slot }
    }
}
/// Converts a child-slot decoding failure into a `corrupted` error.
fn relative_decode_err(err: RelativeEncodingError) -> PersistentARTrieError {
    let message = format!("invalid relative child encoding: {}", err);
    PersistentARTrieError::corrupted(message)
}
/// Decodes `count` child slots from `data`, relative to `parent`.
///
/// Uses the sequential-sibling decoder when `uses_sequential` is set and
/// the per-child decoder otherwise. Returns the decoder's result (child
/// slots plus a byte count) or a `corrupted` error.
fn decode_v2_child_slots(
    data: &[u8],
    parent: ArenaSlot,
    count: usize,
    uses_sequential: bool,
) -> Result<(Vec<ArenaSlot>, usize)> {
    let decoded = if uses_sequential {
        try_decode_sequential_siblings(data, parent, count)
    } else {
        try_decode_children(data, parent, count)
    };
    decoded.map_err(relative_decode_err)
}
/// Reads the child node-type byte at `offset` in a v2 node payload.
///
/// A byte that does not convert to a `NodeType` is treated as
/// `NodeType::Node4` rather than rejected.
///
/// # Errors
///
/// Returns a `corrupted` error when `offset` lies beyond the end of `data`.
fn read_v2_node_type(data: &[u8], offset: usize) -> Result<NodeType> {
    match data.get(offset) {
        Some(&tag) => Ok(NodeType::try_from(tag).unwrap_or(NodeType::Node4)),
        None => Err(PersistentARTrieError::corrupted(format!(
            "missing relative child node type at offset {} in {} byte node payload",
            offset,
            data.len()
        ))),
    }
}
/// Collects the arena slots of `node`'s children, in storage order.
///
/// Node4 and Node16 examine only the first `num_children` entries;
/// Node48 examines its first 48 entries and Node256 every entry. Children
/// for which `as_arena_slot()` returns `None` are skipped.
///
/// Iteration is bounded by the child array itself, so a header whose
/// `num_children` exceeds the array length (e.g. from corrupted input)
/// no longer causes an out-of-bounds panic.
pub fn collect_child_slots(node: &Node) -> Vec<ArenaSlot> {
    match node {
        Node::N4(n) => n
            .children
            .iter()
            .take(n.header.num_children as usize)
            .filter_map(|child| child.as_arena_slot())
            .collect(),
        Node::N16(n) => n
            .children
            .iter()
            .take(n.header.num_children as usize)
            .filter_map(|child| child.as_arena_slot())
            .collect(),
        Node::N48(n) => n
            .children
            .iter()
            .take(48)
            .filter_map(|child| child.as_arena_slot())
            .collect(),
        Node::N256(n) => n
            .children
            .iter()
            .filter_map(|child| child.as_arena_slot())
            .collect(),
    }
}
pub fn collect_child_slots_and_types(node: &Node) -> Vec<(ArenaSlot, NodeType)> {
let mut result = Vec::new();
match node {
Node::N4(n) => {
for i in 0..n.header.num_children as usize {
if let (Some(slot), Some(node_type)) = (
n.children[i].as_arena_slot(),
n.children[i].disk_location().map(|loc| loc.node_type),
) {
result.push((slot, node_type));
}
}
}
Node::N16(n) => {
for i in 0..n.header.num_children as usize {
if let (Some(slot), Some(node_type)) = (
n.children[i].as_arena_slot(),
n.children[i].disk_location().map(|loc| loc.node_type),
) {
result.push((slot, node_type));
}
}
}
Node::N48(n) => {
for i in 0..48 {
if let (Some(slot), Some(node_type)) = (
n.children[i].as_arena_slot(),
n.children[i].disk_location().map(|loc| loc.node_type),
) {
result.push((slot, node_type));
}
}
}
Node::N256(n) => {
for child in &n.children {
if let (Some(slot), Some(node_type)) = (
child.as_arena_slot(),
child.disk_location().map(|loc| loc.node_type),
) {
result.push((slot, node_type));
}
}
}
}
result
}
/// Estimates the v2 serialized size of `node` under `ctx`.
///
/// The estimate is header + prefix array (if any) + key/index/bitmap area
/// + encoded child references + one type byte per child. In sequential
/// mode only the first child's reference is counted; in non-relative mode
/// children are counted as 8-byte raw pointers with no type bytes.
/// `serialize_node_v2` uses the result only as an allocation hint.
pub fn estimate_serialized_size_v2(node: &Node, ctx: &SerializationContext) -> usize {
    let child_count = node.header().num_children as usize;

    let prefix_bytes = match node.header().prefix_len {
        0 => 0,
        _ => MAX_PREFIX_LEN,
    };

    let key_bytes = match node {
        Node::N4(_) => 4,
        Node::N16(_) => 16,
        Node::N48(_) => 256,
        Node::N256(_) => 32,
    };

    let (child_bytes, type_bytes) = match (ctx.use_sequential, ctx.use_relative) {
        (true, _) => {
            let first_ref = match ctx.first_child_slot {
                Some(first_child) => {
                    super::super::relative_encoding::encoded_size(ctx.parent_slot, first_child)
                }
                None => 0,
            };
            (first_ref, child_count)
        }
        (false, true) => {
            let mut total: usize = 0;
            for child in collect_child_slots(node) {
                total += super::super::relative_encoding::encoded_size(ctx.parent_slot, child);
            }
            (total, child_count)
        }
        (false, false) => (child_count * 8, 0),
    };

    SERIALIZED_HEADER_SIZE + prefix_bytes + key_bytes + child_bytes + type_bytes
}
/// Serializes `node` in the v2 format and returns the bytes.
///
/// Layout: header, optional prefix array, key/index/bitmap area, encoded
/// child references, then one node-type byte per collected child.
///
/// Child references are produced by the sequential-sibling encoder when
/// `ctx.use_sequential` is set (nothing is emitted if `first_child_slot` is
/// `None`) and by the per-child encoder otherwise.
///
/// NOTE(review): the per-child encoder is used even when
/// `ctx.use_relative` is false, while the header flags then omit
/// `RELATIVE_OFFSETS`; confirm that combination is never used, since the
/// v2 readers interpret such a header as raw 8-byte pointers.
pub fn serialize_node_v2(node: &Node, ctx: &SerializationContext) -> Result<Vec<u8>> {
    let typed_children = collect_child_slots_and_types(node);
    let slots: Vec<ArenaSlot> = typed_children.iter().map(|&(slot, _)| slot).collect();

    // Encode child references up front so their length can go into the header.
    let mut encoded_children = Vec::new();
    match (ctx.use_sequential, ctx.first_child_slot) {
        (true, Some(first_child)) => {
            encode_sequential_siblings(ctx.parent_slot, first_child, &mut encoded_children);
        }
        (true, None) => {}
        (false, _) => {
            encode_children(ctx.parent_slot, &slots, &mut encoded_children);
        }
    }

    let node_header = node.header();
    let has_prefix = node_header.prefix_len > 0;
    let prefix_bytes = if has_prefix { MAX_PREFIX_LEN } else { 0 };
    let key_bytes = match node {
        Node::N4(_) => 4,
        Node::N16(_) => 16,
        Node::N48(_) => 256,
        Node::N256(_) => 32,
    };
    // One type byte is written per collected child in every mode.
    let type_bytes = typed_children.len();
    let payload_len = prefix_bytes + key_bytes + encoded_children.len() + type_bytes;

    let wire_header = SerializedNodeHeader::from_node_header_v2(
        node_header,
        payload_len as u32,
        ctx.encoding_flags(),
    );

    let mut out = Vec::with_capacity(estimate_serialized_size_v2(node, ctx));
    out.extend_from_slice(&wire_header.to_bytes());
    if has_prefix {
        out.extend_from_slice(&node.prefix().bytes);
    }

    match node {
        Node::N4(n) => out.extend_from_slice(&n.keys),
        Node::N16(n) => out.extend_from_slice(&n.keys),
        Node::N48(n) => out.extend_from_slice(&n.index),
        Node::N256(n) => {
            // Presence bitmap: bit `i` set when child `i` is non-null.
            let mut presence = [0u64; 4];
            for (key, _) in n
                .children
                .iter()
                .enumerate()
                .filter(|(_, ptr)| !ptr.is_null())
            {
                presence[key / 64] |= 1u64 << (key % 64);
            }
            for word in presence.iter() {
                out.extend_from_slice(&word.to_le_bytes());
            }
        }
    }

    out.extend_from_slice(&encoded_children);
    out.extend(typed_children.iter().map(|(_, node_type)| *node_type as u8));
    Ok(out)
}
/// Appends an optional value to an already-serialized node.
///
/// When `value_bytes` is `Some`, the `HAS_VALUE` bit is set in the header's
/// encoding-flags byte (index 7) and the value is appended as a
/// little-endian `u32` length followed by the bytes. When it is `None`,
/// `node_bytes` is returned unchanged.
///
/// # Panics
///
/// Panics if a value is supplied and `node_bytes` has fewer than 8 bytes.
pub fn append_node_value(mut node_bytes: Vec<u8>, value_bytes: Option<&[u8]>) -> Vec<u8> {
    let payload = match value_bytes {
        Some(payload) => payload,
        None => return node_bytes,
    };
    node_bytes[7] |= encoding_flags::HAS_VALUE;
    let length_prefix = (payload.len() as u32).to_le_bytes();
    node_bytes.extend_from_slice(&length_prefix);
    node_bytes.extend_from_slice(payload);
    node_bytes
}
/// Extracts the value appended to a serialized node, if any.
///
/// The value is expected directly after the node payload, i.e. at
/// `SERIALIZED_HEADER_SIZE + data_size`, as a little-endian `u32` length
/// followed by that many bytes.
///
/// Returns `None` when `data` is shorter than a header, when the
/// `HAS_VALUE` bit is not set in the encoding-flags byte, or when the
/// length prefix or value bytes would extend past the end of `data`.
///
/// All offsets are computed with checked arithmetic and checked slicing,
/// so oversized `data_size` / length fields cannot overflow `usize` (a
/// concern on 32-bit targets) or index out of bounds.
pub fn read_node_value(data: &[u8]) -> Option<Vec<u8>> {
    const LEN_PREFIX: usize = 4;

    let header = data.get(..SERIALIZED_HEADER_SIZE)?;
    if header[7] & encoding_flags::HAS_VALUE == 0 {
        return None;
    }

    let data_size = u32::from_le_bytes([header[12], header[13], header[14], header[15]]) as usize;
    let value_start = SERIALIZED_HEADER_SIZE.checked_add(data_size)?;
    let trailer = data.get(value_start..)?;
    if trailer.len() < LEN_PREFIX {
        return None;
    }

    let (len_bytes, rest) = trailer.split_at(LEN_PREFIX);
    let len =
        u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], len_bytes[3]]) as usize;
    rest.get(..len).map(|value| value.to_vec())
}
/// Decodes one node from `data` using the v2-aware readers.
///
/// The header is read and validated, the prefix array is read when the
/// header's prefix length is non-zero, and the rest of `data` is handed to
/// the reader for the header's node type.
///
/// # Errors
///
/// Truncated header/prefix reads are converted with `io_err`; header
/// validation errors are propagated; an unrecognised node type yields a
/// `corrupted` error.
pub fn deserialize_node_v2(data: &[u8], ctx: &DeserializationContext) -> Result<Node> {
    let mut cursor = std::io::Cursor::new(data);

    let mut raw_header = [0u8; SERIALIZED_HEADER_SIZE];
    cursor.read_exact(&mut raw_header).map_err(io_err)?;
    let header = SerializedNodeHeader::from_bytes(&raw_header);
    header.validate()?;

    // Start from the empty prefix and overwrite it only when one is stored.
    let mut prefix = CompressedPrefix::empty();
    if header.prefix_len > 0 {
        let mut stored = [0u8; MAX_PREFIX_LEN];
        cursor.read_exact(&mut stored).map_err(io_err)?;
        prefix = CompressedPrefix { bytes: stored };
    }

    // Everything after the header and prefix belongs to the node payload.
    let consumed = cursor.position() as usize;
    let payload = &data[consumed..];

    let kind = header.node_type;
    if kind == node_types::NODE4 {
        deserialize_node4_v2(&header, prefix, payload, ctx)
    } else if kind == node_types::NODE16 {
        deserialize_node16_v2(&header, prefix, payload, ctx)
    } else if kind == node_types::NODE48 {
        deserialize_node48_v2(&header, prefix, payload, ctx)
    } else if kind == node_types::NODE256 {
        deserialize_node256_v2(&header, prefix, payload, ctx)
    } else {
        Err(PersistentARTrieError::corrupted(format!(
            "invalid node type: {}",
            header.node_type
        )))
    }
}
fn deserialize_node4_v2(
header: &SerializedNodeHeader,
prefix: CompressedPrefix,
data: &[u8],
ctx: &DeserializationContext,
) -> Result<Node> {
let mut node = Node4::new();
node.header = header.to_node_header();
node.prefix = prefix;
node.keys.copy_from_slice(&data[..4]);
let num_children = header.num_children as usize;
if header.uses_sequential_siblings() {
let (children, bytes_consumed) =
decode_v2_child_slots(&data[4..], ctx.parent_slot, num_children, true)?;
let types_start = 4 + bytes_consumed;
for (i, slot) in children.into_iter().enumerate() {
let node_type = read_v2_node_type(data, types_start + i)?;
node.children[i] = SwizzledPtr::from_arena_slot(slot, node_type);
}
} else if header.uses_relative_offsets() {
let (children, bytes_consumed) =
decode_v2_child_slots(&data[4..], ctx.parent_slot, num_children, false)?;
let types_start = 4 + bytes_consumed;
for (i, slot) in children.into_iter().enumerate() {
let node_type = read_v2_node_type(data, types_start + i)?;
node.children[i] = SwizzledPtr::from_arena_slot(slot, node_type);
}
} else {
for i in 0..num_children {
let offset = 4 + i * 8;
let raw = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
node.children[i] = SwizzledPtr::from_raw(raw);
}
}
Ok(Node::N4(Box::new(node)))
}
fn deserialize_node16_v2(
header: &SerializedNodeHeader,
prefix: CompressedPrefix,
data: &[u8],
ctx: &DeserializationContext,
) -> Result<Node> {
let mut node = Node16::new();
node.header = header.to_node_header();
node.prefix = prefix;
node.keys.copy_from_slice(&data[..16]);
let num_children = header.num_children as usize;
if header.uses_sequential_siblings() {
let (children, bytes_consumed) =
decode_v2_child_slots(&data[16..], ctx.parent_slot, num_children, true)?;
let types_start = 16 + bytes_consumed;
for (i, slot) in children.into_iter().enumerate() {
let node_type = read_v2_node_type(data, types_start + i)?;
node.children[i] = SwizzledPtr::from_arena_slot(slot, node_type);
}
} else if header.uses_relative_offsets() {
let (children, bytes_consumed) =
decode_v2_child_slots(&data[16..], ctx.parent_slot, num_children, false)?;
let types_start = 16 + bytes_consumed;
for (i, slot) in children.into_iter().enumerate() {
let node_type = read_v2_node_type(data, types_start + i)?;
node.children[i] = SwizzledPtr::from_arena_slot(slot, node_type);
}
} else {
for i in 0..num_children {
let offset = 16 + i * 8;
let raw = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
node.children[i] = SwizzledPtr::from_raw(raw);
}
}
Ok(Node::N16(Box::new(node)))
}
/// Decodes a Node48 payload (`data` starts right after the header/prefix).
///
/// The payload begins with the 256-byte key-to-slot index. The distinct
/// slot numbers referenced by the index (excluding `NO_CHILD`), in
/// ascending order, determine where the i-th stored child is placed.
/// Children follow either as encoded slots plus one node-type byte each
/// (sequential-sibling or relative-offset encoding) or as `num_children`
/// raw little-endian `u64` pointers.
///
/// # Errors
///
/// Returns a `corrupted` error, instead of panicking, when the payload is
/// too short for the index or a raw pointer, when more children are stored
/// than the index references, or when the index references a slot beyond
/// the child array. Decoder and type-byte errors are propagated.
fn deserialize_node48_v2(
    header: &SerializedNodeHeader,
    prefix: CompressedPrefix,
    data: &[u8],
    ctx: &DeserializationContext,
) -> Result<Node> {
    const INDEX_LEN: usize = 256;
    const RAW_PTR_LEN: usize = 8;

    let mut node = Node48::new();
    node.header = header.to_node_header();
    node.prefix = prefix;

    let index_bytes = data.get(..INDEX_LEN).ok_or_else(|| {
        PersistentARTrieError::corrupted(format!(
            "node48 payload truncated: need {} index bytes, found {}",
            INDEX_LEN,
            data.len()
        ))
    })?;
    node.index.copy_from_slice(index_bytes);

    // Distinct referenced slots in ascending order, via a presence table
    // (same result as de-duplicating and sorting, without the quadratic scan).
    let mut seen = [false; 256];
    for &slot in node.index.iter() {
        if slot != NO_CHILD {
            seen[slot as usize] = true;
        }
    }
    let used_slots: Vec<u8> = (0..=u8::MAX).filter(|&slot| seen[slot as usize]).collect();

    let capacity = node.children.len();
    let num_children = header.num_children as usize;
    let slot_out_of_range = |slot: usize| {
        PersistentARTrieError::corrupted(format!(
            "node48 index references child slot {} beyond capacity {}",
            slot, capacity
        ))
    };

    let uses_sequential = header.uses_sequential_siblings();
    if uses_sequential || header.uses_relative_offsets() {
        let (children, bytes_consumed) = decode_v2_child_slots(
            &data[INDEX_LEN..],
            ctx.parent_slot,
            num_children,
            uses_sequential,
        )?;
        // Type bytes follow the encoded slots; offsets are relative to `data`.
        let types_start = INDEX_LEN + bytes_consumed;
        for (i, child_slot) in children.into_iter().enumerate() {
            let actual_slot = match used_slots.get(i) {
                Some(&slot) => slot as usize,
                None => {
                    return Err(PersistentARTrieError::corrupted(format!(
                        "node48 relative child count {} exceeds index entries {}",
                        num_children,
                        used_slots.len()
                    )));
                }
            };
            let node_type = read_v2_node_type(data, types_start + i)?;
            let target = node
                .children
                .get_mut(actual_slot)
                .ok_or_else(|| slot_out_of_range(actual_slot))?;
            *target = SwizzledPtr::from_arena_slot(child_slot, node_type);
        }
    } else {
        for i in 0..num_children {
            let actual_slot = match used_slots.get(i) {
                Some(&slot) => slot as usize,
                None => {
                    return Err(PersistentARTrieError::corrupted(format!(
                        "node48 child count {} exceeds index entries {}",
                        num_children,
                        used_slots.len()
                    )));
                }
            };
            let offset = INDEX_LEN + i * RAW_PTR_LEN;
            let chunk = data.get(offset..offset + RAW_PTR_LEN).ok_or_else(|| {
                PersistentARTrieError::corrupted(format!(
                    "node48 payload truncated: missing child pointer at offset {} in {} byte node payload",
                    offset,
                    data.len()
                ))
            })?;
            let mut word = [0u8; RAW_PTR_LEN];
            word.copy_from_slice(chunk);
            let target = node
                .children
                .get_mut(actual_slot)
                .ok_or_else(|| slot_out_of_range(actual_slot))?;
            *target = SwizzledPtr::from_raw(u64::from_le_bytes(word));
        }
    }
    Ok(Node::N48(Box::new(node)))
}
/// Decodes a Node256 payload (`data` starts right after the header/prefix).
///
/// The payload begins with a 32-byte presence bitmap (four little-endian
/// `u64` words; bit `i` set means child `i` is stored). Children follow in
/// ascending key order, either as encoded slots plus one node-type byte
/// each (sequential-sibling or relative-offset encoding) or as raw
/// little-endian `u64` pointers.
///
/// # Errors
///
/// Returns a `corrupted` error, instead of panicking, when the payload is
/// too short for the bitmap or a raw pointer, or when the bitmap marks
/// more children than were decoded. Decoder and type-byte errors are
/// propagated.
fn deserialize_node256_v2(
    header: &SerializedNodeHeader,
    prefix: CompressedPrefix,
    data: &[u8],
    ctx: &DeserializationContext,
) -> Result<Node> {
    const BITMAP_LEN: usize = 32;
    const WORD_LEN: usize = 8;

    let mut node = Node256::new();
    node.header = header.to_node_header();
    node.prefix = prefix;

    let bitmap_bytes = data.get(..BITMAP_LEN).ok_or_else(|| {
        PersistentARTrieError::corrupted(format!(
            "node256 payload truncated: need {} bitmap bytes, found {}",
            BITMAP_LEN,
            data.len()
        ))
    })?;
    let mut bitmap = [0u64; 4];
    for (word, chunk) in bitmap.iter_mut().zip(bitmap_bytes.chunks_exact(WORD_LEN)) {
        let mut buf = [0u8; WORD_LEN];
        buf.copy_from_slice(chunk);
        *word = u64::from_le_bytes(buf);
    }
    let is_present = |key: usize| bitmap[key / 64] & (1u64 << (key % 64)) != 0;

    let num_children = header.num_children as usize;
    let children_start = BITMAP_LEN;

    let uses_sequential = header.uses_sequential_siblings();
    if uses_sequential || header.uses_relative_offsets() {
        let (children, bytes_consumed) = decode_v2_child_slots(
            &data[children_start..],
            ctx.parent_slot,
            num_children,
            uses_sequential,
        )?;
        // Type bytes follow the encoded slots; offsets are relative to `data`.
        let types_start = children_start + bytes_consumed;
        let mut child_idx = 0;
        for key in 0..256 {
            if !is_present(key) {
                continue;
            }
            let child_slot = match children.get(child_idx) {
                Some(&slot) => slot,
                None => {
                    return Err(PersistentARTrieError::corrupted(format!(
                        "node256 bitmap references more children than header count {}",
                        num_children
                    )));
                }
            };
            let node_type = read_v2_node_type(data, types_start + child_idx)?;
            node.children[key] = SwizzledPtr::from_arena_slot(child_slot, node_type);
            child_idx += 1;
        }
    } else {
        let mut child_idx = 0;
        for key in 0..256 {
            if !is_present(key) {
                continue;
            }
            let offset = children_start + child_idx * WORD_LEN;
            let chunk = data.get(offset..offset + WORD_LEN).ok_or_else(|| {
                PersistentARTrieError::corrupted(format!(
                    "node256 payload truncated: missing child pointer at offset {} in {} byte node payload",
                    offset,
                    data.len()
                ))
            })?;
            let mut buf = [0u8; WORD_LEN];
            buf.copy_from_slice(chunk);
            node.children[key] = SwizzledPtr::from_raw(u64::from_le_bytes(buf));
            child_idx += 1;
        }
    }
    Ok(Node::N256(Box::new(node)))
}
}
pub use v2::{
collect_child_slots, deserialize_node_v2, estimate_serialized_size_v2, serialize_node_v2,
DeserializationContext, SerializationContext,
};
#[cfg(test)]
mod tests {
use super::*;
use crate::persistent_artrie::nodes::{flags, ArtNode};
use crate::persistent_artrie::NodeType;
#[test]
fn test_header_roundtrip() {
let header = SerializedNodeHeader {
magic: NODE_MAGIC,
version: FORMAT_VERSION,
node_type: node_types::NODE4,
flags: flags::IS_FINAL,
encoding_flags: 0,
num_children: 3,
prefix_len: 5,
_padding: 0,
data_size: 100,
};
let bytes = header.to_bytes();
let restored = SerializedNodeHeader::from_bytes(&bytes);
assert_eq!(restored.magic, NODE_MAGIC);
assert_eq!(restored.version, FORMAT_VERSION);
assert_eq!(restored.node_type, node_types::NODE4);
assert_eq!(restored.flags, flags::IS_FINAL);
assert_eq!(restored.num_children, 3);
assert_eq!(restored.prefix_len, 5);
assert_eq!(restored.data_size, 100);
}
#[test]
fn test_header_validation() {
let mut header = SerializedNodeHeader {
magic: NODE_MAGIC,
version: FORMAT_VERSION,
node_type: node_types::NODE4,
flags: 0,
encoding_flags: 0,
num_children: 0,
prefix_len: 0,
_padding: 0,
data_size: 0,
};
assert!(header.validate().is_ok());
header.magic = *b"BAD\0";
assert!(matches!(
header.validate(),
Err(PersistentARTrieError::InvalidMagic { .. })
));
header.magic = NODE_MAGIC;
header.version = 255;
assert!(matches!(
header.validate(),
Err(PersistentARTrieError::UnsupportedVersion { .. })
));
header.version = FORMAT_VERSION;
header.node_type = 99;
assert!(matches!(
header.validate(),
Err(PersistentARTrieError::CorruptedFile { .. })
));
header.node_type = node_types::NODE4;
header.prefix_len = 20;
assert!(matches!(
header.validate(),
Err(PersistentARTrieError::CorruptedFile { .. })
));
}
#[test]
fn test_node4_roundtrip() {
let mut node4 = Node4::new();
node4.prefix = CompressedPrefix::from_bytes(b"test");
node4.header.prefix_len = 4;
node4.header.set_final(true);
node4
.add_child(b'a', SwizzledPtr::on_disk(100, 0, NodeType::Node4))
.expect("add child a");
node4
.add_child(b'b', SwizzledPtr::on_disk(200, 0, NodeType::Node16))
.expect("add child b");
let node = Node::N4(Box::new(node4));
let bytes = to_bytes(&node).expect("serialize");
let restored = from_bytes(&bytes).expect("deserialize");
assert!(matches!(restored, Node::N4(_)));
assert_eq!(restored.header().prefix_len, 4);
assert!(restored.header().is_final());
assert_eq!(restored.header().num_children, 2);
assert!(restored.find_child(b'a').is_some());
assert!(restored.find_child(b'b').is_some());
assert!(restored.find_child(b'c').is_none());
}
#[test]
fn test_node16_roundtrip() {
let mut node16 = Node16::new();
node16.prefix = CompressedPrefix::from_bytes(b"prefix");
node16.header.prefix_len = 6;
for i in 0..8 {
node16
.add_child(b'a' + i, SwizzledPtr::on_disk(i as u32, 0, NodeType::Node4))
.expect("add child");
}
let node = Node::N16(Box::new(node16));
let bytes = to_bytes(&node).expect("serialize");
let restored = from_bytes(&bytes).expect("deserialize");
assert!(matches!(restored, Node::N16(_)));
assert_eq!(restored.header().prefix_len, 6);
assert_eq!(restored.header().num_children, 8);
for i in 0..8 {
assert!(restored.find_child(b'a' + i).is_some());
}
}
#[test]
fn test_node48_roundtrip() {
let mut node48 = Node48::new();
for key in [0, 50, 100, 150, 200, 255u8] {
node48
.add_child(key, SwizzledPtr::on_disk(key as u32, 0, NodeType::Node4))
.expect("add child");
}
let node = Node::N48(Box::new(node48));
let bytes = to_bytes(&node).expect("serialize");
let restored = from_bytes(&bytes).expect("deserialize");
assert!(matches!(restored, Node::N48(_)));
assert_eq!(restored.header().num_children, 6);
for key in [0, 50, 100, 150, 200, 255u8] {
assert!(
restored.find_child(key).is_some(),
"should find key {}",
key
);
}
}
#[test]
fn test_node256_roundtrip() {
let mut node256 = Node256::new();
for key in [0, 64, 128, 192, 255u8] {
node256
.add_child(key, SwizzledPtr::on_disk(key as u32, 0, NodeType::Node4))
.expect("add child");
}
let node = Node::N256(Box::new(node256));
let bytes = to_bytes(&node).expect("serialize");
let restored = from_bytes(&bytes).expect("deserialize");
assert!(matches!(restored, Node::N256(_)));
assert_eq!(restored.header().num_children, 5);
for key in [0, 64, 128, 192, 255u8] {
assert!(
restored.find_child(key).is_some(),
"should find key {}",
key
);
}
assert!(restored.find_child(1).is_none());
}
#[test]
fn test_node256_sparse_bitmap() {
let mut node256 = Node256::new();
node256
.add_child(0, SwizzledPtr::on_disk(1, 0, NodeType::Node4))
.expect("add child 0");
node256
.add_child(255, SwizzledPtr::on_disk(2, 0, NodeType::Node4))
.expect("add child 255");
let node = Node::N256(Box::new(node256));
let bytes = to_bytes(&node).expect("serialize");
assert_eq!(bytes.len(), 16 + 32 + 16);
let restored = from_bytes(&bytes).expect("deserialize");
assert_eq!(restored.header().num_children, 2);
assert!(restored.find_child(0).is_some());
assert!(restored.find_child(255).is_some());
assert!(restored.find_child(128).is_none());
}
#[test]
fn test_serialized_size_calculation() {
let node4 = Node::N4(Box::new(Node4::new()));
assert_eq!(serialized_size(&node4), 16 + 0 + (4 + 32));
let mut node4_with_prefix = Node4::new();
node4_with_prefix.prefix = CompressedPrefix::from_bytes(b"test");
node4_with_prefix.header.prefix_len = 4;
let node4_p = Node::N4(Box::new(node4_with_prefix));
assert_eq!(serialized_size(&node4_p), 16 + 12 + (4 + 32));
let node16 = Node::N16(Box::new(Node16::new()));
assert_eq!(serialized_size(&node16), 16 + 0 + (16 + 128));
let node48 = Node::N48(Box::new(Node48::new()));
assert_eq!(serialized_size(&node48), 16 + 0 + (256 + 384));
let mut node256 = Node256::new();
for i in 0..5 {
node256
.add_child(i, SwizzledPtr::on_disk(i as u32, 0, NodeType::Node4))
.expect("add");
}
let node256_node = Node::N256(Box::new(node256));
assert_eq!(serialized_size(&node256_node), 16 + 0 + (32 + 5 * 8)); }
#[test]
fn test_empty_node_roundtrip() {
for create_node in [
|| Node::N4(Box::new(Node4::new())),
|| Node::N16(Box::new(Node16::new())),
|| Node::N48(Box::new(Node48::new())),
|| Node::N256(Box::new(Node256::new())),
] {
let node = create_node();
let bytes = to_bytes(&node).expect("serialize");
let restored = from_bytes(&bytes).expect("deserialize");
assert_eq!(restored.header().num_children, 0);
}
}
#[test]
fn test_deserialize_truncated_header() {
let truncated = vec![0u8; 10];
let result = from_bytes(&truncated);
assert!(result.is_err());
}
#[test]
fn test_deserialize_invalid_magic() {
let mut data = vec![0u8; 32];
data[0..4].copy_from_slice(b"BAD!");
let result = from_bytes(&data);
assert!(matches!(
result,
Err(PersistentARTrieError::InvalidMagic { .. })
));
}
#[test]
fn test_deserialize_unsupported_version() {
let header = SerializedNodeHeader {
magic: NODE_MAGIC,
version: 255, node_type: node_types::NODE4,
flags: 0,
encoding_flags: 0,
num_children: 0,
prefix_len: 0,
_padding: 0,
data_size: 0,
};
let bytes = header.to_bytes();
let result = from_bytes(&bytes);
assert!(matches!(
result,
Err(PersistentARTrieError::UnsupportedVersion { .. })
));
}
#[test]
fn test_deserialize_invalid_node_type() {
let header = SerializedNodeHeader {
magic: NODE_MAGIC,
version: FORMAT_VERSION,
node_type: 99, flags: 0,
encoding_flags: 0,
num_children: 0,
prefix_len: 0,
_padding: 0,
data_size: 0,
};
let bytes = header.to_bytes();
let result = from_bytes(&bytes);
assert!(matches!(
result,
Err(PersistentARTrieError::CorruptedFile { .. })
));
}
/// The header advertises an 8-byte prefix and 50 bytes of data, but only four
/// bytes follow it; decoding must fail.
#[test]
fn test_deserialize_truncated_prefix() {
    let claimed = SerializedNodeHeader {
        magic: NODE_MAGIC,
        version: FORMAT_VERSION,
        node_type: node_types::NODE4,
        flags: 0,
        encoding_flags: 0,
        num_children: 0,
        prefix_len: 8,
        _padding: 0,
        data_size: 50,
    };
    // Header followed by just four payload bytes.
    let mut record = claimed.to_bytes().to_vec();
    record.extend_from_slice(&[0u8; 4]);
    assert!(from_bytes(&record).is_err());
}
/// Patches the header of a serialized empty `Node4` so it claims four
/// children, cuts the record down to 20 bytes, and expects decoding to fail.
#[test]
fn test_deserialize_truncated_children_node4() {
    let mut record = to_bytes(&Node::N4(Box::new(Node4::new()))).expect("serialize");

    // Decode the header, overwrite the child count, and write it back in place.
    let raw_head: [u8; SERIALIZED_HEADER_SIZE] = record[..SERIALIZED_HEADER_SIZE]
        .try_into()
        .expect("header slice should be 16 bytes");
    let mut patched = SerializedNodeHeader::from_bytes(&raw_head);
    patched.num_children = 4;
    record[..SERIALIZED_HEADER_SIZE].copy_from_slice(&patched.to_bytes());

    record.truncate(20);
    assert!(from_bytes(&record).is_err());
}
/// Decoding an empty byte slice must return an error.
#[test]
fn test_deserialize_empty_data() {
    assert!(from_bytes(&[]).is_err());
}
/// A final `Node4` carrying an 8-byte prefix must keep both its prefix length
/// and its final flag across a serialize/deserialize round trip.
#[test]
fn test_serialize_roundtrip_with_max_prefix() {
    const PREFIX: &[u8; 8] = b"12345678";
    let prefix_len = PREFIX.len() as u8;

    let mut inner = Node4::new();
    inner.header.set_final(true);
    inner.header.prefix_len = prefix_len;
    inner.prefix = CompressedPrefix::from_bytes(PREFIX);

    let encoded = to_bytes(&Node::N4(Box::new(inner))).expect("serialize");
    let decoded = from_bytes(&encoded).expect("deserialize");
    assert_eq!(decoded.header().prefix_len, prefix_len);
    assert!(decoded.header().is_final());
}
/// A header declaring `prefix_len` 20 must be reported as `CorruptedFile`.
#[test]
fn test_deserialize_invalid_prefix_len() {
    // NOTE(review): presumably larger than MAX_PREFIX_LEN — confirm against its definition.
    const OVERSIZED_PREFIX_LEN: u8 = 20;
    let raw = SerializedNodeHeader {
        magic: NODE_MAGIC,
        version: FORMAT_VERSION,
        node_type: node_types::NODE4,
        flags: 0,
        encoding_flags: 0,
        num_children: 0,
        prefix_len: OVERSIZED_PREFIX_LEN,
        _padding: 0,
        data_size: 50,
    }
    .to_bytes();
    let outcome = from_bytes(&raw);
    let rejected = matches!(outcome, Err(PersistentARTrieError::CorruptedFile { .. }));
    assert!(rejected);
}
/// Every node variant must serialize to a non-empty record and deserialize
/// back to a node of the same type with the same child count.
#[test]
fn test_serialize_all_node_types() {
    let nodes: Vec<Node> = vec![
        Node::N4(Box::new(Node4::new())),
        Node::N16(Box::new(Node16::new())),
        Node::N48(Box::new(Node48::new())),
        Node::N256(Box::new(Node256::new())),
    ];
    for node in nodes {
        let node_ty = node.header().node_type;
        let bytes = to_bytes(&node).expect("serialize");
        assert!(!bytes.is_empty());
        let restored = from_bytes(&bytes).expect("deserialize");
        // Without this check a reader that decoded every empty node as the
        // same variant would still pass, since all four have zero children.
        assert_eq!(
            restored.header().node_type,
            node_ty,
            "node type must survive the round trip"
        );
        assert_eq!(restored.header().num_children, node.header().num_children);
    }
}
/// Pins the numeric values of the `node_types` constants.
#[test]
fn test_node_type_constants() {
    let pinned: [(u8, u8); 4] = [
        (node_types::NODE4, 4),
        (node_types::NODE16, 16),
        (node_types::NODE48, 48),
        (node_types::NODE256, 0),
    ];
    for (actual, expected) in pinned {
        assert_eq!(actual, expected);
    }
}
/// `SERIALIZED_HEADER_SIZE` must be 16 and must equal the length of an
/// encoded header.
#[test]
fn test_header_size() {
    assert_eq!(SERIALIZED_HEADER_SIZE, 16);
    let encoded = SerializedNodeHeader {
        magic: NODE_MAGIC,
        version: FORMAT_VERSION,
        node_type: node_types::NODE4,
        flags: 0,
        encoding_flags: 0,
        num_children: 0,
        prefix_len: 0,
        _padding: 0,
        data_size: 0,
    }
    .to_bytes();
    assert_eq!(encoded.len(), SERIALIZED_HEADER_SIZE);
}
/// For every combination of `IS_FINAL` and `IS_DIRTY`, the restored node must
/// report `is_final()` exactly when `IS_FINAL` was set before serialization.
#[test]
fn test_all_flag_combinations() {
    let flag_combinations = [
        0u8,
        flags::IS_FINAL,
        flags::IS_DIRTY,
        flags::IS_FINAL | flags::IS_DIRTY,
    ];
    for flags_val in flag_combinations {
        let mut node4 = Node4::new();
        node4.header.flags = flags_val;
        let node = Node::N4(Box::new(node4));
        let bytes = to_bytes(&node).expect("serialize");
        let restored = from_bytes(&bytes).expect("deserialize");
        // Check both directions: a reader that always reported "final" would
        // slip past a check that only runs when IS_FINAL is set.
        let expect_final = flags_val & flags::IS_FINAL != 0;
        assert_eq!(
            restored.header().is_final(),
            expect_final,
            "IS_FINAL must round-trip exactly for flags {flags_val:#04x}"
        );
    }
}
/// Returns `true` when `record` has a byte at offset 7 and that byte has the
/// `HAS_VALUE` bit set; records of seven bytes or fewer yield `false`.
///
/// Offset 7 is where the `encoding_flags` field falls in the header's field
/// order (4-byte magic, then version, node type, flags).
fn record_has_value_flag(record: &[u8]) -> bool {
    match record.get(7) {
        Some(&flags_byte) => flags_byte & encoding_flags::HAS_VALUE != 0,
        None => false,
    }
}
fn sample_nodes_with_children(parent: ArenaSlot, child_count: usize) -> Vec<Node> {
let make = |mut add: Box<dyn FnMut(u8, SwizzledPtr)>| {
for i in 0..child_count {
let slot = ArenaSlot::new(parent.arena_id, parent.slot_id + 1 + i as u32);
add(i as u8, SwizzledPtr::from_arena_slot(slot, NodeType::Node4));
}
};
let mut n4 = Node4::new();
make(Box::new(|k, p| {
let _ = n4.add_child(k, p);
}));
let mut n16 = Node16::new();
make(Box::new(|k, p| {
let _ = n16.add_child(k, p);
}));
let mut n48 = Node48::new();
make(Box::new(|k, p| {
let _ = n48.add_child(k, p);
}));
let mut n256 = Node256::new();
make(Box::new(|k, p| {
let _ = n256.add_child(k, p);
}));
vec![
Node::N4(Box::new(n4)),
Node::N16(Box::new(n16)),
Node::N48(Box::new(n48)),
Node::N256(Box::new(n256)),
]
}
/// For every node variant and several child counts, a value appended to a v2
/// record must set `HAS_VALUE`, read back byte-for-byte, and leave the node
/// decodable with the same child count.
#[test]
fn test_value_blob_roundtrip_all_node_types() {
    let value: &[u8] = &[0x2A, 0x00, 0xFF, 0x01, 0x10, 0x20, 0x30, 0x40];
    let parent = ArenaSlot::new(2, 10);
    let ser_ctx = SerializationContext::new(parent);
    let de_ctx = DeserializationContext::new(parent);

    for child_count in [0usize, 1, 3] {
        for node in sample_nodes_with_children(parent, child_count) {
            let node_ty = node.header().node_type;
            let structural = serialize_node_v2(&node, &ser_ctx).expect("serialize");
            let bytes = v2::append_node_value(structural, Some(value));

            let flagged = record_has_value_flag(&bytes);
            assert!(
                flagged,
                "HAS_VALUE must be set for a valued record (type {node_ty}, {child_count} children)"
            );

            let read_back = v2::read_node_value(&bytes);
            assert_eq!(
                read_back.as_deref(),
                Some(value),
                "value bytes must round-trip exactly (type {node_ty}, {child_count} children)"
            );

            let restored = deserialize_node_v2(&bytes, &de_ctx).expect("deserialize");
            let expected_children = node.header().num_children;
            assert_eq!(
                restored.header().num_children,
                expected_children,
                "structure must survive (type {node_ty}, {child_count} children)"
            );
        }
    }
}
/// Appending `None` as the value must return the v2 record unchanged: same
/// bytes, no `HAS_VALUE` bit, and no value on read-back.
#[test]
fn test_value_less_record_byte_identical() {
    let parent = ArenaSlot::new(5, 100);
    let ser_ctx = SerializationContext::new(parent);

    for child_count in [0usize, 1, 3, 5] {
        for node in sample_nodes_with_children(parent, child_count) {
            let node_ty = node.header().node_type;
            let legacy = serialize_node_v2(&node, &ser_ctx).expect("legacy serialize");
            let via_none = v2::append_node_value(legacy.clone(), None);

            assert_eq!(
                legacy, via_none,
                "value-less record must be byte-identical to the legacy layout \
                 (type {node_ty}, {child_count} children)"
            );

            let flagged = record_has_value_flag(&via_none);
            assert!(
                !flagged,
                "value-less record must NOT set HAS_VALUE (type {node_ty}, {child_count} children)"
            );

            let read_back = v2::read_node_value(&via_none);
            assert!(
                read_back.is_none(),
                "value-less record must read back no value (type {node_ty}, {child_count} children)"
            );
        }
    }
}
/// A record produced by `serialize_node_v2` alone (no value appended) must
/// read back as having no value and must still deserialize with the same
/// child count.
#[test]
fn test_legacy_value_less_record_reads_none() {
    let parent = ArenaSlot::new(0, 7);
    let ser_ctx = SerializationContext::new(parent);
    let de_ctx = DeserializationContext::new(parent);

    for child_count in [0usize, 2, 4] {
        for node in sample_nodes_with_children(parent, child_count) {
            let node_ty = node.header().node_type;
            let legacy_bytes = serialize_node_v2(&node, &ser_ctx).expect("legacy serialize");

            let read_back = v2::read_node_value(&legacy_bytes);
            assert!(
                read_back.is_none(),
                "legacy value-less record must read back as no-value (type {node_ty})"
            );

            let restored = deserialize_node_v2(&legacy_bytes, &de_ctx).expect("legacy reader");
            let expected_children = node.header().num_children;
            assert_eq!(restored.header().num_children, expected_children);
        }
    }
}
/// Boundary sizes for the value blob: an empty value must still set
/// `HAS_VALUE` and read back as `Some(empty)`, and a 4096-byte value must
/// read back unchanged.
#[test]
fn test_value_blob_empty_and_large() {
    let parent = ArenaSlot::new(1, 1);
    let ser_ctx = SerializationContext::new(parent);
    let base = serialize_node_v2(&Node::N4(Box::new(Node4::new())), &ser_ctx).expect("serialize");

    let empty = v2::append_node_value(base.clone(), Some(&[]));
    assert!(
        record_has_value_flag(&empty),
        "empty value still sets HAS_VALUE"
    );
    let empty_read_back = v2::read_node_value(&empty);
    assert_eq!(
        empty_read_back,
        Some(Vec::new()),
        "empty value must round-trip as Some(empty), distinct from None"
    );

    // 4096 bytes cycling through 0..=250, i.e. byte i equals i % 251.
    let large: Vec<u8> = (0u8..251).cycle().take(4096).collect();
    let big = v2::append_node_value(base, Some(&large));
    let large_read_back = v2::read_node_value(&big);
    assert_eq!(
        large_read_back.as_deref(),
        Some(large.as_slice()),
        "large value must round-trip exactly"
    );
}
/// Appending a value may change a record in only two ways: the byte at
/// offset 7 gains `HAS_VALUE`, and the record grows by the value length plus
/// four bytes. Every other pre-existing byte must be untouched.
#[test]
fn test_valued_record_only_grows_by_value_blob() {
    // Offset of the encoding_flags byte within the record.
    const FLAGS_AT: usize = 7;
    // Bytes added on top of the value itself — presumably a length prefix; TODO confirm.
    const BLOB_OVERHEAD: usize = 4;

    let parent = ArenaSlot::new(3, 30);
    let ser_ctx = SerializationContext::new(parent);
    let value: &[u8] = &[1, 2, 3, 4, 5, 6, 7];

    for node in sample_nodes_with_children(parent, 2) {
        let node_ty = node.header().node_type;
        let plain = serialize_node_v2(&node, &ser_ctx).expect("value-less");
        let valued = v2::append_node_value(plain.clone(), Some(value));

        let expected_len = plain.len() + BLOB_OVERHEAD + value.len();
        assert_eq!(
            valued.len(),
            expected_len,
            "valued record must grow by exactly the value blob (type {node_ty})"
        );
        assert_eq!(
            &valued[..FLAGS_AT],
            &plain[..FLAGS_AT],
            "header bytes before encoding_flags must be unchanged (type {node_ty})"
        );
        assert_eq!(
            valued[FLAGS_AT],
            plain[FLAGS_AT] | encoding_flags::HAS_VALUE,
            "encoding_flags must gain exactly the HAS_VALUE bit (type {node_ty})"
        );
        assert_eq!(
            &valued[FLAGS_AT + 1..plain.len()],
            &plain[FLAGS_AT + 1..],
            "structural bytes after the flags byte must be unchanged (type {node_ty})"
        );
    }
}
}