use crate::compression::{CompressionType, decompress_with_prefix};
use crate::error::RiegeliError;
use crate::field_projection::FieldProjection;
use crate::proto_wire::{WireType, is_valid_proto_tag, tag_field_number, tag_wire_type};
use crate::simple_chunk::Chunk;
use crate::transpose::internal::{
SUBMESSAGE_WIRE_TYPE, has_data_buffer, has_subtype, message_id, subtype,
};
use crate::varint::{decode_u32, decode_u64, encode_u32};
/// Action the decoder performs when the state machine visits a node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CallbackType {
/// Nothing is written.
NoOp,
/// Record boundary: the current output position is pushed as a record limit.
MessageStart,
/// Pops the most recent submessage frame and writes that frame's tag and
/// the length of everything written since the frame was pushed.
SubmessageStart,
/// Pushes a submessage frame remembering the current output position.
SubmessageEnd,
/// Emits one non-proto record (length taken from the lengths buffer).
NonProto,
/// Writes the node's precomputed tag bytes and nothing else.
CopyTag,
/// Writes the tag followed by a varint rebuilt from `data_length` raw bytes.
Varint {
/// Number of bytes read from the node's buffer for the varint payload.
data_length: u8,
},
/// Writes the tag followed by 4 bytes read from the node's buffer.
Fixed32,
/// Writes the tag followed by 8 bytes read from the node's buffer.
Fixed64,
/// Writes the tag, a varint length and that many bytes from the node's buffer.
StringField,
/// Visiting this node is always reported as malformed data.
Failure,
}
/// One node of the decoding state machine built from the chunk header.
#[derive(Debug, Clone)]
struct StateMachineNode {
// Encoded tag bytes; inline varints additionally carry their value byte.
tag_data: Vec<u8>,
// Number of leading bytes of `tag_data` that are written to the output.
tag_data_size: usize,
// What to do when this node is visited.
callback_type: CallbackType,
// Index of the buffer this node reads from, if it reads one.
buffer_index: Option<usize>,
// Base index of the successor node; a transition byte may add an offset.
next_node_index: usize,
// When set, the step taken after this node does not consume a transition byte.
is_implicit: bool,
}
/// Forward-reading cursor over one data buffer of a transposed chunk.
///
/// A cursor is either backed by real bytes (`data`) or *pruned*: a pruned
/// cursor holds no data and answers every read with zeros, so decoding can
/// proceed for buffers that were deliberately not materialised.
#[derive(Debug)]
struct BufferCursor {
    /// Backing bytes (empty for a pruned cursor).
    data: Vec<u8>,
    /// Offset of the next unread byte in `data`.
    pos: usize,
    /// Whether this cursor stands in for a buffer that was not decompressed.
    pruned: bool,
    /// Zero-filled storage handed out by `read_exact` on a pruned cursor.
    scratch: Vec<u8>,
}

impl BufferCursor {
    /// Creates a cursor positioned at the start of `data`.
    fn new(data: Vec<u8>) -> Self {
        Self {
            data,
            pos: 0,
            pruned: false,
            scratch: Vec::new(),
        }
    }

    /// Creates a pruned cursor: all reads succeed and yield zeros.
    fn pruned() -> Self {
        Self {
            data: Vec::new(),
            pos: 0,
            pruned: true,
            scratch: Vec::new(),
        }
    }

    /// Reads exactly `n` bytes and advances the cursor.
    ///
    /// # Errors
    /// Returns `MalformedData` when fewer than `n` bytes remain. The bound is
    /// computed with checked arithmetic so an oversized `n` is reported as an
    /// error instead of overflowing `pos + n`.
    fn read_exact(&mut self, n: usize) -> Result<&[u8], RiegeliError> {
        if self.pruned {
            self.scratch.clear();
            self.scratch.resize(n, 0u8);
            return Ok(&self.scratch);
        }
        let start = self.pos;
        let end = start
            .checked_add(n)
            .filter(|&end| end <= self.data.len())
            .ok_or_else(|| {
                RiegeliError::MalformedData("buffer underflow in transpose chunk".to_string())
            })?;
        self.pos = end;
        Ok(&self.data[start..end])
    }

    /// Decodes a 32-bit varint at the cursor (always `0` when pruned).
    fn read_varint32(&mut self) -> Result<u32, RiegeliError> {
        if self.pruned {
            return Ok(0);
        }
        let remaining = &self.data[self.pos..];
        let (val, consumed) = decode_u32(remaining).map_err(|e| {
            RiegeliError::MalformedData(format!("varint32 decode error in buffer: {e}"))
        })?;
        self.pos += consumed;
        Ok(val)
    }

    /// Decodes a 64-bit varint at the cursor (always `0` when pruned).
    fn read_varint64(&mut self) -> Result<u64, RiegeliError> {
        if self.pruned {
            return Ok(0);
        }
        let remaining = &self.data[self.pos..];
        let (val, consumed) = decode_u64(remaining).map_err(|e| {
            RiegeliError::MalformedData(format!("varint64 decode error in buffer: {e}"))
        })?;
        self.pos += consumed;
        Ok(val)
    }

    /// Reads a single byte (always `0` when pruned).
    fn read_byte(&mut self) -> Result<u8, RiegeliError> {
        if self.pruned {
            return Ok(0);
        }
        let b = *self.data.get(self.pos).ok_or_else(|| {
            RiegeliError::MalformedData("buffer underflow reading byte".to_string())
        })?;
        self.pos += 1;
        Ok(b)
    }

    /// Returns `true` when no unread bytes remain; pruned cursors are always empty.
    fn is_empty(&self) -> bool {
        if self.pruned {
            return true;
        }
        self.pos >= self.data.len()
    }

    /// Returns the current read offset so it can be restored later.
    fn save_pos(&self) -> usize {
        self.pos
    }

    /// Resets the read offset to a value previously returned by `save_pos`.
    fn restore_pos(&mut self, saved: usize) {
        self.pos = saved;
    }

    /// Copies the backing bytes and current offset into an independent
    /// snapshot, so a look-ahead scan does not disturb this cursor.
    fn clone_for_needed_buffers(&self, _num_buffers: usize) -> BufferCursorSnapshot {
        BufferCursorSnapshot {
            data: self.data.clone(),
            pos: self.pos,
        }
    }
}
/// Detached copy of a cursor's bytes and offset, used for look-ahead scans
/// that must not move the original cursor.
struct BufferCursorSnapshot {
    data: Vec<u8>,
    pos: usize,
}

impl BufferCursorSnapshot {
    /// Decodes the next 32-bit varint, advancing on success.
    ///
    /// Returns `None` (leaving the offset untouched) when decoding fails.
    fn read_varint32(&mut self) -> Option<u32> {
        match decode_u32(&self.data[self.pos..]) {
            Ok((value, used)) => {
                self.pos += used;
                Some(value)
            }
            Err(_) => None,
        }
    }
}
/// Output buffer that is filled from the end towards the beginning.
///
/// Bytes are stored internally in reverse order; `into_forward` flips them
/// once at the end to obtain the final, front-to-back byte sequence.
struct BackwardBuffer {
    data: Vec<u8>,
}

impl BackwardBuffer {
    /// Creates an empty buffer with room for `capacity` bytes.
    fn new(capacity: usize) -> Self {
        let data = Vec::with_capacity(capacity);
        Self { data }
    }

    /// Number of bytes written so far.
    fn pos(&self) -> usize {
        self.data.len()
    }

    /// Prepends `bytes` to the logical output (stored reversed internally).
    fn write(&mut self, bytes: &[u8]) {
        self.data.extend(bytes.iter().rev());
    }

    /// Consumes the buffer and returns its contents in forward order.
    fn into_forward(self) -> Vec<u8> {
        let mut forward = self.data;
        forward.reverse();
        forward
    }
}
/// Bookkeeping for a submessage whose end has been seen but whose header
/// (tag + length) has not been written yet.
struct SubmessageFrame {
// Output position recorded when the frame was pushed.
end_pos: usize,
// Index of the node that pushed the frame; its tag bytes are written later.
node_index: usize,
}
/// Per-state tables read from the decompressed chunk header.
struct StateMetadata {
// One raw tag (or reserved message id) per state.
tags: Vec<u32>,
// One raw successor index per state, as stored in the header.
next_node_indices: Vec<u32>,
// One subtype byte for each tag that carries a subtype, in state order.
subtypes_bytes: Vec<u8>,
// Number of states declared by the header.
num_states: usize,
}
/// Everything extracted from a transposed chunk before the state machine is built.
struct ParsedHeader {
// One cursor per data buffer (pruned cursors for buffers that were skipped).
buffers: Vec<BufferCursor>,
// Number of states declared by the header.
num_states: usize,
// Raw tag per state.
tags: Vec<u32>,
// Raw successor index per state.
next_node_indices: Vec<u32>,
// Subtype bytes for the tags that carry one.
subtypes_bytes: Vec<u8>,
// Header cursor; the buffer indices and the first node index are still unread.
hdr: BufferCursor,
// Number of data buffers declared by the header.
num_buffers: usize,
// Compression applied to the header, buckets and transitions.
compression_type: CompressionType,
// Still-compressed transition bytes (the tail of the chunk data).
transitions_compressed: Vec<u8>,
}
/// Result of turning the header tables into executable state machine nodes.
struct BuiltStateMachine {
// Real nodes followed by padding nodes whose callback is `Failure`.
nodes: Vec<StateMachineNode>,
// Index of the node where decoding starts.
first_node: usize,
// Whether any node emits non-proto records.
has_nonproto_op: bool,
}
/// Decoder for a transposed chunk: all records are decoded up front and then
/// handed out one at a time by `read_record`.
pub struct TransposeChunkDecoder {
// Concatenated bytes of all decoded records.
data: Vec<u8>,
// End offset (exclusive) of each record within `data`.
limits: Vec<usize>,
// Index of the next record to return.
next_yield: usize,
}
impl TransposeChunkDecoder {
/// Decodes every record of `chunk` without applying a field projection.
///
/// # Errors
/// Propagates any error returned by `new_with_projection`.
pub fn new(chunk: Chunk) -> Result<Self, RiegeliError> {
Self::new_with_projection(chunk, None)
}
/// Decodes every record of `chunk`, optionally restricted by `projection`.
///
/// A projection for which `is_all()` holds is treated like `None`. With an
/// active projection, each decoded record is passed through
/// `FieldProjection::apply` and the record limits are recomputed.
///
/// # Errors
/// Returns `MalformedData` (or a decompression error) when the chunk cannot
/// be parsed or decoded.
pub fn new_with_projection(
    chunk: Chunk,
    projection: Option<&FieldProjection>,
) -> Result<Self, RiegeliError> {
    let record_count = chunk.header.num_records();
    let decoded_size = chunk.header.decoded_data_size();
    // A chunk without records needs no parsing at all.
    if record_count == 0 {
        return Ok(Self {
            data: Vec::new(),
            limits: Vec::new(),
            next_yield: 0,
        });
    }

    let active_projection: Option<&FieldProjection> = projection.filter(|p| !p.is_all());
    let mut parsed = Self::parse_header_and_buckets_with_projection(&chunk, active_projection)?;
    let built = Self::build_state_machine_nodes(&mut parsed)?;

    // Non-proto record lengths live in the last buffer.
    let nonproto_lengths_index = match (built.has_nonproto_op, parsed.num_buffers) {
        (false, _) => None,
        (true, 0) => {
            return Err(RiegeliError::MalformedData(
                "nonproto op but no buffers".to_string(),
            ));
        }
        (true, buffer_count) => Some(buffer_count - 1),
    };

    let mut transitions = BufferCursor::new(decompress_with_prefix(
        &parsed.transitions_compressed,
        parsed.compression_type,
    )?);
    let (decoded_data, limits) = decode_all_records(
        record_count,
        decoded_size,
        &built.nodes,
        &mut parsed.buffers,
        &mut transitions,
        built.first_node,
        nonproto_lengths_index,
    )?;

    let (data, limits) = match active_projection {
        None => (decoded_data, limits),
        Some(proj) => {
            let mut projected_data: Vec<u8> = Vec::with_capacity(decoded_data.len());
            let mut projected_limits: Vec<usize> = Vec::with_capacity(limits.len());
            let mut record_start = 0usize;
            for record_end in limits {
                let projected = proj.apply(&decoded_data[record_start..record_end]);
                projected_data.extend_from_slice(&projected);
                projected_limits.push(projected_data.len());
                record_start = record_end;
            }
            (projected_data, projected_limits)
        }
    };

    Ok(Self {
        data,
        limits,
        next_yield: 0,
    })
}
fn parse_header_and_buckets_with_projection(
chunk: &Chunk,
projection: Option<&FieldProjection>,
) -> Result<ParsedHeader, RiegeliError> {
let data = &chunk.data;
let mut pos: usize = 0;
if data.is_empty() {
return Err(RiegeliError::MalformedData(
"transpose chunk data is empty".to_string(),
));
}
let compression_type = CompressionType::try_from(data[0])?;
pos += 1;
let (header_length, consumed) = decode_u64(&data[pos..])
.map_err(|e| RiegeliError::MalformedData(format!("reading header_length: {e}")))?;
pos += consumed;
let header_end = pos + header_length as usize;
if header_end > data.len() {
return Err(RiegeliError::MalformedData(
"transpose header extends past chunk data".to_string(),
));
}
let header_compressed = &data[pos..header_end];
pos = header_end;
let header_data = decompress_with_prefix(header_compressed, compression_type)?;
let mut hdr = BufferCursor::new(header_data);
let num_buckets = hdr.read_varint32()? as usize;
let num_buffers = hdr.read_varint32()? as usize;
let mut bucket_compressed_sizes = Vec::with_capacity(num_buckets);
for _ in 0..num_buckets {
bucket_compressed_sizes.push(hdr.read_varint64()? as usize);
}
let mut buffer_uncompressed_sizes = Vec::with_capacity(num_buffers);
for _ in 0..num_buffers {
buffer_uncompressed_sizes.push(hdr.read_varint64()? as usize);
}
let sm = Self::parse_state_metadata(&mut hdr)?;
let (bucket_compressed_data, new_pos) =
Self::read_bucket_data(data, pos, &bucket_compressed_sizes)?;
let transitions_compressed = data[new_pos..].to_vec();
let needed_buffers = if let Some(proj) = projection {
let buf_idx_scan_pos = hdr.save_pos();
let mut snap = hdr.clone_for_needed_buffers(num_buffers);
let result = Self::compute_needed_buffers_from_scan(
&sm.tags,
&sm.subtypes_bytes,
&mut snap,
num_buffers,
proj,
);
hdr.restore_pos(buf_idx_scan_pos);
result
} else {
vec![true; num_buffers]
};
let buffers = Self::decompress_into_buffers_with_pruning(
&bucket_compressed_data,
&buffer_uncompressed_sizes,
compression_type,
&needed_buffers,
)?;
Ok(ParsedHeader {
buffers,
num_states: sm.num_states,
tags: sm.tags,
next_node_indices: sm.next_node_indices,
subtypes_bytes: sm.subtypes_bytes,
hdr,
num_buffers,
compression_type,
transitions_compressed,
})
}
/// Determines which data buffers must be materialised for `projection`.
///
/// Walks the state tags in order, reading one buffer index from `snap` for
/// every state that owns a data buffer. A buffer is marked as needed when it
/// belongs to a non-proto state or when the projection includes the state's
/// field number. If any non-proto state exists, the last buffer is marked as
/// needed as well.
///
/// Returns one flag per buffer (`true` = must be decompressed).
// NOTE(review): the field number of every tag is tested with
// `includes_top_level_field`, although the tag table does not show nesting
// depth here — confirm that fields nested inside an included submessage
// cannot have their buffers pruned by mistake.
fn compute_needed_buffers_from_scan(
    tags: &[u32],
    subtypes_bytes: &[u8],
    snap: &mut BufferCursorSnapshot,
    num_buffers: usize,
    projection: &FieldProjection,
) -> Vec<bool> {
    if num_buffers == 0 {
        return Vec::new();
    }
    let mut needed = vec![false; num_buffers];
    let mut subtype_idx: usize = 0;
    let mut has_nonproto = false;

    for &raw_tag in tags {
        // `reads_buffer`: does this state consume a buffer index from the header?
        // `field_number_opt`: proto field number, or `None` for non-proto data.
        let (reads_buffer, field_number_opt) = if raw_tag == message_id::NO_OP
            || raw_tag == message_id::START_OF_MESSAGE
            || raw_tag == message_id::START_OF_SUBMESSAGE
        {
            (false, None)
        } else if raw_tag == message_id::NON_PROTO {
            has_nonproto = true;
            (true, None)
        } else {
            let mut tag = raw_tag;
            let st: u8 = if tag & 7 == crate::transpose::internal::SUBMESSAGE_WIRE_TYPE {
                // Submessage markers are stored with a reserved wire type.
                tag = tag - crate::transpose::internal::SUBMESSAGE_WIRE_TYPE
                    + WireType::LengthDelimited as u32;
                crate::transpose::internal::subtype::LENGTH_DELIMITED_END_OF_SUBMESSAGE
            } else if is_valid_proto_tag(tag) && has_subtype(tag) {
                let stored = subtypes_bytes.get(subtype_idx).copied().unwrap_or(0);
                subtype_idx += 1;
                stored
            } else {
                crate::transpose::internal::subtype::TRIVIAL
            };
            if is_valid_proto_tag(tag) && has_data_buffer(tag, st) {
                (true, Some(tag_field_number(tag)))
            } else {
                (false, None)
            }
        };

        if !reads_buffer {
            continue;
        }
        if let Some(raw_index) = snap.read_varint32() {
            let buffer_index = raw_index as usize;
            if buffer_index < num_buffers {
                let wanted = field_number_opt
                    .map_or(true, |fn_num| projection.includes_top_level_field(fn_num));
                if wanted {
                    needed[buffer_index] = true;
                }
            }
        }
    }

    if has_nonproto && num_buffers > 0 {
        needed[num_buffers - 1] = true;
    }
    needed
}
/// Splits `data`, starting at `pos`, into consecutive buckets of the given
/// compressed sizes.
///
/// Returns the copied bucket bytes together with the offset just past the
/// last bucket.
///
/// # Errors
/// Returns `MalformedData` when a bucket would extend past `data`. Sizes are
/// untrusted, so the end offset is computed with `checked_add`: an oversized
/// value is reported as an error rather than overflowing.
fn read_bucket_data(
    data: &[u8],
    mut pos: usize,
    bucket_compressed_sizes: &[usize],
) -> Result<(Vec<Vec<u8>>, usize), RiegeliError> {
    let mut bucket_compressed_data = Vec::with_capacity(bucket_compressed_sizes.len());
    for &size in bucket_compressed_sizes {
        let end = pos
            .checked_add(size)
            .filter(|&end| end <= data.len())
            .ok_or_else(|| {
                RiegeliError::MalformedData("bucket data extends past chunk data".to_string())
            })?;
        bucket_compressed_data.push(data[pos..end].to_vec());
        pos = end;
    }
    Ok((bucket_compressed_data, pos))
}
/// Decompresses the buckets and slices them into per-buffer cursors.
///
/// Buffers are laid out back to back inside the buckets; when the current
/// bucket is exhausted the next one is decompressed. Buffers whose entry in
/// `needed_buffers` is `false` are represented by pruned cursors (a missing
/// entry counts as needed).
///
/// # Errors
/// Returns `MalformedData` when a buffer does not fit in the current bucket.
/// Buffer sizes are untrusted, so the end offset is computed with
/// `checked_add` to avoid overflow.
fn decompress_into_buffers_with_pruning(
    bucket_compressed_data: &[Vec<u8>],
    buffer_uncompressed_sizes: &[usize],
    compression_type: CompressionType,
    needed_buffers: &[bool],
) -> Result<Vec<BufferCursor>, RiegeliError> {
    let num_buckets = bucket_compressed_data.len();
    let num_buffers = buffer_uncompressed_sizes.len();
    let mut buffers: Vec<BufferCursor> = Vec::with_capacity(num_buffers);
    let mut bucket_index: usize = 0;
    let mut bucket_decompressed: Vec<u8> = Vec::new();
    let mut bucket_pos: usize = 0;
    if num_buckets > 0 && num_buffers > 0 {
        bucket_decompressed =
            decompress_with_prefix(&bucket_compressed_data[0], compression_type)?;
    }
    for (i, &buf_size) in buffer_uncompressed_sizes.iter().enumerate() {
        // Move on to the next bucket(s) once the current one is used up.
        while bucket_pos >= bucket_decompressed.len() && bucket_index + 1 < num_buckets {
            bucket_index += 1;
            bucket_decompressed = decompress_with_prefix(
                &bucket_compressed_data[bucket_index],
                compression_type,
            )?;
            bucket_pos = 0;
        }
        let end = bucket_pos
            .checked_add(buf_size)
            .filter(|&end| end <= bucket_decompressed.len())
            .ok_or_else(|| {
                RiegeliError::MalformedData(format!(
                    "buffer {} (size {}) exceeds bucket {} data (len {})",
                    i,
                    buf_size,
                    bucket_index,
                    bucket_decompressed.len()
                ))
            })?;
        let is_needed = needed_buffers.get(i).copied().unwrap_or(true);
        buffers.push(if is_needed {
            BufferCursor::new(bucket_decompressed[bucket_pos..end].to_vec())
        } else {
            BufferCursor::pruned()
        });
        // Pruned buffers still occupy their slot inside the bucket.
        bucket_pos = end;
    }
    Ok(buffers)
}
/// Reads the state tables from the header cursor.
///
/// Layout as read here: a varint state count, then that many varint tags,
/// then that many varint successor indices, then one subtype byte for every
/// tag that is a valid proto tag carrying a subtype.
///
/// # Errors
/// Returns `MalformedData` when the header ends before all tables are read.
fn parse_state_metadata(hdr: &mut BufferCursor) -> Result<StateMetadata, RiegeliError> {
    let num_states = hdr.read_varint32()? as usize;
    // `num_states` is untrusted. Each entry takes at least one header byte,
    // so the header length bounds how many entries can actually be read;
    // never preallocate more than that.
    let capacity = num_states.min(hdr.data.len());

    let mut tags = Vec::with_capacity(capacity);
    for _ in 0..num_states {
        tags.push(hdr.read_varint32()?);
    }
    let mut next_node_indices = Vec::with_capacity(capacity);
    for _ in 0..num_states {
        next_node_indices.push(hdr.read_varint32()?);
    }

    let num_subtypes = tags
        .iter()
        .filter(|&&tag| is_valid_proto_tag(tag) && has_subtype(tag))
        .count();
    let mut subtypes_bytes = Vec::with_capacity(num_subtypes);
    for _ in 0..num_subtypes {
        subtypes_bytes.push(hdr.read_byte()?);
    }

    Ok(StateMetadata {
        tags,
        next_node_indices,
        subtypes_bytes,
        num_states,
    })
}
/// Builds the executable state machine from the parsed header tables.
///
/// A raw successor index of `num_states` or more marks the node as implicit;
/// the real successor is then `raw - num_states`. After the real nodes, 0xFF
/// `Failure` nodes are appended so that a transition offset pointing past
/// the real nodes lands on a node that reports an error.
///
/// # Errors
/// Returns `MalformedData` when a successor index or the first node index is
/// out of range, when a node cannot be built, or when the implicit
/// transitions form a loop.
fn build_state_machine_nodes(
    parsed: &mut ParsedHeader,
) -> Result<BuiltStateMachine, RiegeliError> {
    let num_states = parsed.num_states;
    let num_buffers = parsed.num_buffers;
    let hdr = &mut parsed.hdr;
    let mut nodes: Vec<StateMachineNode> = Vec::with_capacity(num_states.saturating_add(0xFF));
    let mut subtype_idx: usize = 0;
    let mut has_nonproto_op = false;

    for (i, (&raw_tag, &next_raw)) in parsed
        .tags
        .iter()
        .zip(&parsed.next_node_indices)
        .enumerate()
    {
        let next_raw = next_raw as usize;
        let (is_implicit, next_node_idx) = match next_raw.checked_sub(num_states) {
            None => (false, next_raw),
            Some(adjusted) if adjusted < num_states => (true, adjusted),
            Some(adjusted) => {
                return Err(RiegeliError::MalformedData(format!(
                    "node index {} too large (num_states={})",
                    adjusted, num_states
                )));
            }
        };
        let node = Self::build_single_node(
            raw_tag,
            next_node_idx,
            is_implicit,
            i,
            hdr,
            num_buffers,
            &parsed.subtypes_bytes,
            &mut subtype_idx,
            &mut has_nonproto_op,
        )?;
        nodes.push(node);
    }

    // The first node must be a real state. This also rejects a header with
    // zero states, which would otherwise let an arbitrary `first_node` index
    // past the padded node table.
    let first_node = hdr.read_varint32()? as usize;
    if first_node >= num_states {
        return Err(RiegeliError::MalformedData(format!(
            "first_node {} >= num_states {}",
            first_node, num_states
        )));
    }

    // Padding: transition offsets are at most 6 bits, so targets beyond the
    // real nodes fall onto these failure nodes instead of out of bounds.
    let padded_len = nodes.len() + 0xFF;
    nodes.resize(
        padded_len,
        StateMachineNode {
            tag_data: Vec::new(),
            tag_data_size: 0,
            callback_type: CallbackType::Failure,
            buffer_index: None,
            next_node_index: 0,
            is_implicit: false,
        },
    );

    if contains_implicit_loop(&nodes, num_states) {
        return Err(RiegeliError::MalformedData(
            "state machine contains an implicit loop".to_string(),
        ));
    }
    Ok(BuiltStateMachine {
        nodes,
        first_node,
        has_nonproto_op,
    })
}
#[allow(clippy::too_many_arguments)]
fn build_single_node(
raw_tag: u32,
next_node_idx: usize,
is_implicit: bool,
state_index: usize,
hdr: &mut BufferCursor,
num_buffers: usize,
subtypes_bytes: &[u8],
subtype_idx: &mut usize,
has_nonproto_op: &mut bool,
) -> Result<StateMachineNode, RiegeliError> {
match raw_tag {
t if t == message_id::NO_OP => {
return Ok(StateMachineNode {
tag_data: Vec::new(),
tag_data_size: 0,
callback_type: CallbackType::NoOp,
buffer_index: None,
next_node_index: next_node_idx,
is_implicit,
});
}
t if t == message_id::NON_PROTO => {
let buf_idx = hdr.read_varint32()? as usize;
if buf_idx >= num_buffers {
return Err(RiegeliError::MalformedData(
"nonproto buffer index too large".to_string(),
));
}
*has_nonproto_op = true;
return Ok(StateMachineNode {
tag_data: Vec::new(),
tag_data_size: 0,
callback_type: CallbackType::NonProto,
buffer_index: Some(buf_idx),
next_node_index: next_node_idx,
is_implicit,
});
}
t if t == message_id::START_OF_MESSAGE => {
return Ok(StateMachineNode {
tag_data: Vec::new(),
tag_data_size: 0,
callback_type: CallbackType::MessageStart,
buffer_index: None,
next_node_index: next_node_idx,
is_implicit,
});
}
t if t == message_id::START_OF_SUBMESSAGE => {
return Ok(StateMachineNode {
tag_data: Vec::new(),
tag_data_size: 0,
callback_type: CallbackType::SubmessageStart,
buffer_index: None,
next_node_index: next_node_idx,
is_implicit,
});
}
_ => {}
}
Self::build_proto_tag_node(
raw_tag,
next_node_idx,
is_implicit,
state_index,
hdr,
num_buffers,
subtypes_bytes,
subtype_idx,
)
}
/// Builds the node for a state whose tag is a proto field tag.
///
/// A tag stored with the reserved submessage wire type is rewritten to a
/// length-delimited tag with the end-of-submessage subtype. Tags that carry a
/// subtype take it from `subtypes_bytes`; tags that own a data buffer read
/// their buffer index from `hdr`. For inline varints the value byte
/// (`subtype - VARINT_INLINE_0`) is appended to the tag bytes.
///
/// # Errors
/// Returns `MalformedData` for an invalid tag, a missing subtype byte, an
/// out-of-range buffer index or an unsupported wire type / subtype pair.
#[allow(clippy::too_many_arguments)]
fn build_proto_tag_node(
    raw_tag: u32,
    next_node_idx: usize,
    is_implicit: bool,
    state_index: usize,
    hdr: &mut BufferCursor,
    num_buffers: usize,
    subtypes_bytes: &[u8],
    subtype_idx: &mut usize,
) -> Result<StateMachineNode, RiegeliError> {
    let mut tag = raw_tag;
    let mut st: u8 = subtype::TRIVIAL;
    if tag & 7 == SUBMESSAGE_WIRE_TYPE {
        tag = tag - SUBMESSAGE_WIRE_TYPE + WireType::LengthDelimited as u32;
        st = subtype::LENGTH_DELIMITED_END_OF_SUBMESSAGE;
    }
    if !is_valid_proto_tag(tag) {
        return Err(RiegeliError::MalformedData(format!(
            "invalid tag {} in state {}",
            tag, state_index
        )));
    }
    let tag_bytes = encode_u32(tag);
    let tag_length = tag_bytes.len();

    if has_subtype(tag) {
        // Checked lookup: a header with too few subtype bytes must be
        // reported as malformed data, not crash the decoder.
        st = *subtypes_bytes.get(*subtype_idx).ok_or_else(|| {
            RiegeliError::MalformedData(format!(
                "missing subtype for tag {} in state {}",
                tag, state_index
            ))
        })?;
        *subtype_idx += 1;
    }

    let buf_idx = if has_data_buffer(tag, st) {
        let idx = hdr.read_varint32()? as usize;
        if idx >= num_buffers {
            return Err(RiegeliError::MalformedData(
                "buffer index too large".to_string(),
            ));
        }
        Some(idx)
    } else {
        None
    };

    let wt = tag_wire_type(tag);
    let callback_type = Self::callback_for_wire_type(wt, st)?;

    let mut tag_data_vec = tag_bytes;
    let tag_data_size = if wt == Some(WireType::Varint) && st >= subtype::VARINT_INLINE_0 {
        // Inline varint: the single value byte travels with the tag.
        tag_data_vec.push(st - subtype::VARINT_INLINE_0);
        tag_length + 1
    } else {
        tag_length
    };

    Ok(StateMachineNode {
        tag_data: tag_data_vec,
        tag_data_size,
        callback_type,
        buffer_index: buf_idx,
        next_node_index: next_node_idx,
        is_implicit,
    })
}
/// Selects the callback for a proto tag from its wire type and subtype.
///
/// Varints at or above `VARINT_INLINE_0` only copy the tag (the value is
/// inline); other varints read `st + 1` payload bytes. Group tags only copy
/// the tag.
///
/// # Errors
/// Returns `MalformedData` for a missing wire type or an unknown
/// length-delimited subtype.
fn callback_for_wire_type(wt: Option<WireType>, st: u8) -> Result<CallbackType, RiegeliError> {
    let callback = match wt {
        None => {
            return Err(RiegeliError::MalformedData(
                "invalid wire type in tag".to_string(),
            ));
        }
        Some(WireType::Varint) if st >= subtype::VARINT_INLINE_0 => CallbackType::CopyTag,
        Some(WireType::Varint) => CallbackType::Varint {
            data_length: st + 1,
        },
        Some(WireType::Fixed32) => CallbackType::Fixed32,
        Some(WireType::Fixed64) => CallbackType::Fixed64,
        Some(WireType::LengthDelimited) if st == subtype::LENGTH_DELIMITED_STRING => {
            CallbackType::StringField
        }
        Some(WireType::LengthDelimited)
            if st == subtype::LENGTH_DELIMITED_END_OF_SUBMESSAGE =>
        {
            CallbackType::SubmessageEnd
        }
        Some(WireType::LengthDelimited) => {
            return Err(RiegeliError::MalformedData(format!(
                "unknown LengthDelimited subtype {st}"
            )));
        }
        Some(WireType::StartGroup) | Some(WireType::EndGroup) => CallbackType::CopyTag,
    };
    Ok(callback)
}
/// Returns a copy of the next decoded record, or `Ok(None)` once all records
/// have been handed out.
pub fn read_record(&mut self) -> Result<Option<Vec<u8>>, RiegeliError> {
    let index = self.next_yield;
    let end = match self.limits.get(index) {
        Some(&end) => end,
        None => return Ok(None),
    };
    // A record starts where the previous one ended (or at 0 for the first).
    let start = match index.checked_sub(1) {
        Some(previous) => self.limits[previous],
        None => 0,
    };
    self.next_yield = index + 1;
    Ok(Some(self.data[start..end].to_vec()))
}
}
/// Mutable state shared by the node callbacks while the state machine runs.
struct DecodeState<'a> {
// Output buffer, filled back to front.
dest: &'a mut BackwardBuffer,
// Data buffers the nodes read their payloads from.
buffers: &'a mut [BufferCursor],
// Frames of submessages whose header has not been written yet.
submessage_stack: &'a mut Vec<SubmessageFrame>,
// Output position recorded at each record boundary.
limits: &'a mut Vec<usize>,
// Number of records the chunk header declares.
num_records: u64,
// Index of the buffer holding non-proto record lengths, if any.
nonproto_lengths_index: Option<usize>,
}
/// Performs the action of `node` (located at `current_node_idx` in `nodes`)
/// against the decode state.
///
/// # Errors
/// Returns `MalformedData` when a record boundary is reached with open
/// submessages, when more records than declared are produced, when a
/// `Failure` node is visited, or when a field writer fails.
fn execute_node_action(
    node: &StateMachineNode,
    current_node_idx: usize,
    nodes: &[StateMachineNode],
    state: &mut DecodeState<'_>,
) -> Result<(), RiegeliError> {
    match node.callback_type {
        CallbackType::NoOp => Ok(()),
        CallbackType::Failure => Err(RiegeliError::MalformedData(
            "hit failure node in state machine".into(),
        )),
        CallbackType::MessageStart => {
            if !state.submessage_stack.is_empty() {
                return Err(RiegeliError::MalformedData(
                    "submessages still open at record boundary".into(),
                ));
            }
            if state.limits.len() as u64 == state.num_records {
                return Err(RiegeliError::MalformedData("too many records".into()));
            }
            let boundary = state.dest.pos();
            state.limits.push(boundary);
            Ok(())
        }
        CallbackType::SubmessageEnd => {
            // Output is produced backwards, so the end of a submessage is
            // seen first; remember where it is until its start is reached.
            let frame = SubmessageFrame {
                end_pos: state.dest.pos(),
                node_index: current_node_idx,
            };
            state.submessage_stack.push(frame);
            Ok(())
        }
        CallbackType::SubmessageStart => {
            write_submessage_header(state.dest, nodes, state.submessage_stack)
        }
        CallbackType::NonProto => write_nonproto_record(node, state),
        CallbackType::CopyTag => {
            state.dest.write(&node.tag_data[..node.tag_data_size]);
            Ok(())
        }
        CallbackType::Varint { data_length } => {
            write_varint_field(node, state.dest, state.buffers, data_length)
        }
        CallbackType::Fixed32 => write_fixed_field(node, state.dest, state.buffers, 4),
        CallbackType::Fixed64 => write_fixed_field(node, state.dest, state.buffers, 8),
        CallbackType::StringField => write_string_field(node, state.dest, state.buffers),
    }
}
/// Writes the tag and length prefix of the innermost open submessage.
///
/// Pops the top frame; the submessage length is the number of bytes written
/// since that frame was pushed, and the tag comes from the node that pushed it.
///
/// # Errors
/// Returns `MalformedData` when no frame is open, when the output position is
/// before the frame's position, or when the length exceeds `u32::MAX`.
fn write_submessage_header(
    dest: &mut BackwardBuffer,
    nodes: &[StateMachineNode],
    submessage_stack: &mut Vec<SubmessageFrame>,
) -> Result<(), RiegeliError> {
    let frame = match submessage_stack.pop() {
        Some(frame) => frame,
        None => {
            return Err(RiegeliError::MalformedData(
                "submessage stack underflow".into(),
            ));
        }
    };
    let length = dest
        .pos()
        .checked_sub(frame.end_pos)
        .ok_or_else(|| RiegeliError::MalformedData("destination position decreased".into()))?;
    if length > u32::MAX as usize {
        return Err(RiegeliError::MalformedData("submessage too large".into()));
    }
    let end_node = &nodes[frame.node_index];
    let len_varint = encode_u32(length as u32);
    // The buffer grows backwards: writing the length first and the tag second
    // yields "tag, length" in the final forward order.
    dest.write(&len_varint);
    dest.write(&end_node.tag_data[..end_node.tag_data_size]);
    Ok(())
}
/// Emits one non-proto record and records its boundary.
///
/// The record length is a varint read from the lengths buffer; that many
/// bytes are then copied from the node's own buffer to the output.
///
/// # Errors
/// Returns `MalformedData` when the lengths buffer or the node's buffer index
/// is missing, when a buffer read fails, when submessages are still open, or
/// when more records than declared are produced.
fn write_nonproto_record(
    node: &StateMachineNode,
    state: &mut DecodeState<'_>,
) -> Result<(), RiegeliError> {
    let lengths_idx = match state.nonproto_lengths_index {
        Some(idx) => idx,
        None => {
            return Err(RiegeliError::MalformedData(
                "nonproto op but no lengths buffer".into(),
            ));
        }
    };
    let record_len = state.buffers[lengths_idx].read_varint32()? as usize;

    let data_idx = match node.buffer_index {
        Some(idx) => idx,
        None => {
            return Err(RiegeliError::MalformedData(
                "nonproto node missing buffer index".into(),
            ));
        }
    };
    let payload = state.buffers[data_idx].read_exact(record_len)?.to_vec();
    state.dest.write(&payload);

    // A non-proto record is a complete record: same boundary rules apply.
    if !state.submessage_stack.is_empty() {
        return Err(RiegeliError::MalformedData(
            "submessages still open at nonproto record".into(),
        ));
    }
    if state.limits.len() as u64 == state.num_records {
        return Err(RiegeliError::MalformedData("too many records".into()));
    }
    let boundary = state.dest.pos();
    state.limits.push(boundary);
    Ok(())
}
/// Writes a varint field: the node's tag followed by `data_length` payload
/// bytes read from the node's buffer.
///
/// Every payload byte except the last gets its continuation bit (0x80) set
/// before it is written.
///
/// # Errors
/// Returns `MalformedData` when the node has no buffer index or the buffer
/// read fails.
fn write_varint_field(
    node: &StateMachineNode,
    dest: &mut BackwardBuffer,
    buffers: &mut [BufferCursor],
    data_length: u8,
) -> Result<(), RiegeliError> {
    let buf_idx = match node.buffer_index {
        Some(idx) => idx,
        None => {
            return Err(RiegeliError::MalformedData(
                "varint node missing buffer index".into(),
            ));
        }
    };
    let payload = buffers[buf_idx].read_exact(data_length as usize)?;
    let tag = &node.tag_data[..node.tag_data_size];

    let mut field = Vec::with_capacity(tag.len() + payload.len());
    field.extend_from_slice(tag);
    if let Some((&final_byte, leading)) = payload.split_last() {
        field.extend(leading.iter().map(|&byte| byte | 0x80));
        field.push(final_byte);
    }
    dest.write(&field);
    Ok(())
}
/// Writes a fixed-width field: the node's tag followed by `n` bytes read from
/// the node's buffer.
///
/// # Errors
/// Returns `MalformedData` when the node has no buffer index or the buffer
/// read fails.
fn write_fixed_field(
    node: &StateMachineNode,
    dest: &mut BackwardBuffer,
    buffers: &mut [BufferCursor],
    n: usize,
) -> Result<(), RiegeliError> {
    let buf_idx = match node.buffer_index {
        Some(idx) => idx,
        None => {
            return Err(RiegeliError::MalformedData(
                "fixed field node missing buffer index".into(),
            ));
        }
    };
    let payload = buffers[buf_idx].read_exact(n)?;
    // Backward buffer: payload first, tag second gives "tag, payload".
    dest.write(payload);
    dest.write(&node.tag_data[..node.tag_data_size]);
    Ok(())
}
/// Writes a length-delimited field: the node's tag, the varint length and the
/// payload bytes. Length and payload are both read from the node's buffer.
///
/// # Errors
/// Returns `MalformedData` when the node has no buffer index or a buffer read
/// fails.
fn write_string_field(
    node: &StateMachineNode,
    dest: &mut BackwardBuffer,
    buffers: &mut [BufferCursor],
) -> Result<(), RiegeliError> {
    let buf_idx = match node.buffer_index {
        Some(idx) => idx,
        None => {
            return Err(RiegeliError::MalformedData(
                "string node missing buffer index".into(),
            ));
        }
    };
    let source = &mut buffers[buf_idx];
    let payload_len = source.read_varint32()? as usize;
    let payload = source.read_exact(payload_len)?;
    let len_varint = encode_u32(payload_len as u32);
    // Backward buffer: write in reverse of the final "tag, length, payload" order.
    dest.write(payload);
    dest.write(&len_varint);
    dest.write(&node.tag_data[..node.tag_data_size]);
    Ok(())
}
/// Computes the node to visit after `current_node_idx`.
///
/// Returns `(next_index, remaining_iters, done)`:
/// * With `num_iters != 0` the successor is taken without reading a
///   transition byte; the counter is decremented unless the successor is
///   implicit.
/// * With `num_iters == 0` one transition byte is read: its upper six bits
///   are an offset added to the successor index and its lower two bits are a
///   repeat count (plus one when the target is implicit). If no transition
///   bytes remain, `done` is `true`.
///
/// # Errors
/// Returns `MalformedData` when the offset leads past the node table or the
/// transition byte cannot be read.
fn advance_state_machine(
    current_node_idx: usize,
    num_iters: i32,
    nodes: &[StateMachineNode],
    transitions: &mut BufferCursor,
) -> Result<(usize, i32, bool), RiegeliError> {
    let successor = nodes[current_node_idx].next_node_index;

    if num_iters != 0 {
        let consumed = i32::from(!nodes[successor].is_implicit);
        return Ok((successor, num_iters - consumed, false));
    }

    if transitions.is_empty() {
        return Ok((successor, 0, true));
    }
    let transition = transitions.read_byte()?;
    let target = successor + usize::from(transition >> 2);
    if target >= nodes.len() {
        return Err(RiegeliError::MalformedData(
            "transition offset overflow".into(),
        ));
    }
    let remaining = i32::from(transition & 3) + i32::from(nodes[target].is_implicit);
    Ok((target, remaining, false))
}
/// Runs the state machine from `first_node` until the transitions are
/// exhausted, producing the backward output buffer and the record limits
/// (output positions at record boundaries, in the order they were reached).
///
/// `decoded_data_size` and `num_records` are only used as preallocation
/// hints (and `num_records` as the expected record count). Both come from the
/// chunk header and are untrusted, so the hints are clamped; the vectors
/// still grow as needed.
///
/// # Errors
/// Returns `MalformedData` when `first_node` is outside `nodes`, when any
/// node action or transition fails, when submessages remain open at the end,
/// or when the number of records differs from `num_records`.
fn run_state_machine(
    num_records: u64,
    decoded_data_size: u64,
    nodes: &[StateMachineNode],
    buffers: &mut [BufferCursor],
    transitions: &mut BufferCursor,
    first_node: usize,
    nonproto_lengths_index: Option<usize>,
) -> Result<(BackwardBuffer, Vec<usize>), RiegeliError> {
    // Upper bounds for up-front allocation driven by header-declared sizes.
    const MAX_PREALLOCATED_BYTES: u64 = 64 << 20;
    const MAX_PREALLOCATED_RECORDS: u64 = 1 << 20;

    if first_node >= nodes.len() {
        return Err(RiegeliError::MalformedData(format!(
            "first_node {} out of range ({} nodes)",
            first_node,
            nodes.len()
        )));
    }

    let mut dest = BackwardBuffer::new(decoded_data_size.min(MAX_PREALLOCATED_BYTES) as usize);
    let mut limits: Vec<usize> =
        Vec::with_capacity(num_records.min(MAX_PREALLOCATED_RECORDS) as usize);
    let mut submessage_stack: Vec<SubmessageFrame> = Vec::new();
    let mut current_node_idx = first_node;
    // An implicit first node continues without consuming a transition byte.
    let mut num_iters: i32 = if nodes[current_node_idx].is_implicit {
        1
    } else {
        0
    };
    {
        let mut state = DecodeState {
            dest: &mut dest,
            buffers,
            submessage_stack: &mut submessage_stack,
            limits: &mut limits,
            num_records,
            nonproto_lengths_index,
        };
        loop {
            execute_node_action(
                &nodes[current_node_idx],
                current_node_idx,
                nodes,
                &mut state,
            )?;
            let (next, iters, done) =
                advance_state_machine(current_node_idx, num_iters, nodes, transitions)?;
            if done {
                break;
            }
            current_node_idx = next;
            num_iters = iters;
        }
    }
    if !submessage_stack.is_empty() {
        return Err(RiegeliError::MalformedData(
            "submessages still open after decode".into(),
        ));
    }
    if limits.len() as u64 != num_records {
        return Err(RiegeliError::MalformedData(format!(
            "expected {} records, got {}",
            num_records,
            limits.len()
        )));
    }
    Ok((dest, limits))
}
/// Converts the backward output and its limits into forward-ordered record
/// bytes and record end offsets.
///
/// `limits` holds output positions measured while writing backwards. The
/// last entry must equal the total size and stays in place; all earlier
/// entries are reversed and replaced by `total_size - limit`, which turns
/// them into forward end offsets.
///
/// # Errors
/// Returns `MalformedData` when the last limit differs from the total size or
/// when the converted limits are not non-decreasing and within the data.
fn finalize_records(
    dest: BackwardBuffer,
    mut limits: Vec<usize>,
    num_records: u64,
) -> Result<(Vec<u8>, Vec<usize>), RiegeliError> {
    let total_size = dest.pos();
    match limits.last() {
        Some(&last_limit) if last_limit != total_size => {
            return Err(RiegeliError::MalformedData(format!(
                "last limit {} != total size {}",
                last_limit, total_size
            )));
        }
        _ => {}
    }

    // Mirror every limit except the final one around `total_size`.
    if let Some((_, leading)) = limits.split_last_mut() {
        leading.reverse();
        for limit in leading.iter_mut() {
            *limit = total_size - *limit;
        }
    }

    // Each limit must lie within the data and not precede its predecessor.
    let start = 0usize;
    let well_ordered = std::iter::once(&start)
        .chain(limits.iter())
        .zip(limits.iter())
        .all(|(&prev, &end)| end <= total_size && prev <= end);
    if !well_ordered {
        return Err(RiegeliError::MalformedData(
            "record boundary out of range".into(),
        ));
    }

    let _ = num_records;
    Ok((dest.into_forward(), limits))
}
/// Decodes all records of a chunk: runs the state machine and then converts
/// its backward output into forward record bytes and end offsets.
///
/// # Errors
/// Propagates errors from `run_state_machine` and `finalize_records`.
fn decode_all_records(
    num_records: u64,
    decoded_data_size: u64,
    nodes: &[StateMachineNode],
    buffers: &mut [BufferCursor],
    transitions: &mut BufferCursor,
    first_node: usize,
    nonproto_lengths_index: Option<usize>,
) -> Result<(Vec<u8>, Vec<usize>), RiegeliError> {
    run_state_machine(
        num_records,
        decoded_data_size,
        nodes,
        buffers,
        transitions,
        first_node,
        nonproto_lengths_index,
    )
    .and_then(|(backward, raw_limits)| finalize_records(backward, raw_limits, num_records))
}
/// Reports whether following implicit nodes can cycle forever.
///
/// From every not-yet-visited node, successors are followed for as long as
/// the current node is implicit. Reaching a node already visited during the
/// same walk means the implicit transitions form a loop. Reaching a node
/// visited by an earlier walk, a non-implicit node, or an out-of-range index
/// ends the walk.
fn contains_implicit_loop(nodes: &[StateMachineNode], _num_states: usize) -> bool {
    let node_count = nodes.len();
    // 0 = unvisited; otherwise the id of the walk that first reached the node.
    let mut visited_by = vec![0usize; node_count];

    for start in 0..node_count {
        if visited_by[start] != 0 {
            continue;
        }
        // Start indices are unique, so `start + 1` is a unique non-zero walk id.
        let walk_id = start + 1;
        visited_by[start] = walk_id;

        let mut cursor = start;
        while nodes[cursor].is_implicit {
            cursor = nodes[cursor].next_node_index;
            match visited_by.get(cursor).copied() {
                None => break,
                Some(owner) if owner == walk_id => return true,
                Some(0) => visited_by[cursor] = walk_id,
                Some(_) => break,
            }
        }
    }
    false
}
#[cfg(test)]
mod tests {
use super::*;
use crate::chunk_header::{ChunkHeader, ChunkType};
use crate::varint::encode_u64;
// Test helper: serialises a transposed chunk from explicit state tables,
// buffers and transition bytes.
//
// Header layout written here, in order: bucket count, buffer count, the
// single bucket's size (when there are buffers), each buffer's size, state
// count, all tags, all next-node indices, subtype bytes, buffer indices and
// finally `first_node`. Header and buffer bytes are appended as-is; this
// helper applies no compression itself.
fn build_transpose_chunk(
compression: CompressionType,
num_records: u64,
decoded_data_size: u64,
states: &[TestState],
buffers_data: &[Vec<u8>],
transitions: &[u8],
first_node: u32,
) -> Chunk {
let mut header_bytes: Vec<u8> = Vec::new();
// All buffers are placed in one bucket (or none when there are no buffers).
let num_buckets: u32 = if buffers_data.is_empty() { 0 } else { 1 };
let num_buffers = buffers_data.len() as u32;
header_bytes.extend_from_slice(&encode_u32(num_buckets));
header_bytes.extend_from_slice(&encode_u32(num_buffers));
let total_buf_size: usize = buffers_data.iter().map(|b| b.len()).sum();
if num_buckets > 0 {
header_bytes.extend_from_slice(&encode_u64(total_buf_size as u64));
}
for buf in buffers_data {
header_bytes.extend_from_slice(&encode_u64(buf.len() as u64));
}
let num_states = states.len() as u32;
header_bytes.extend_from_slice(&encode_u32(num_states));
for state in states {
header_bytes.extend_from_slice(&encode_u32(state.tag));
}
for state in states {
header_bytes.extend_from_slice(&encode_u32(state.next_node));
}
// Subtype bytes exist only for tags that carry a subtype.
for state in states {
let tag = state.tag;
if is_valid_proto_tag(tag) && has_subtype(tag) {
header_bytes.push(state.subtype);
}
}
// Buffer indices: written for non-proto states and for proto tags that own
// a data buffer.
for state in states {
let mut tag = state.tag;
let mut st = state.subtype;
// Tags below 8 are treated as reserved message ids, not proto tags.
if tag < 8 {
if tag == message_id::NON_PROTO {
header_bytes.extend_from_slice(&encode_u32(state.buffer_index));
}
continue;
}
let wire_raw = tag & 7;
if wire_raw == SUBMESSAGE_WIRE_TYPE {
tag = tag - SUBMESSAGE_WIRE_TYPE + WireType::LengthDelimited as u32;
st = subtype::LENGTH_DELIMITED_END_OF_SUBMESSAGE;
}
if has_data_buffer(tag, st) {
header_bytes.extend_from_slice(&encode_u32(state.buffer_index));
}
}
header_bytes.extend_from_slice(&encode_u32(first_node));
// Chunk data: compression byte, header length, header, buffers, transitions.
let mut chunk_data: Vec<u8> = Vec::new();
chunk_data.push(compression as u8);
chunk_data.extend_from_slice(&encode_u64(header_bytes.len() as u64));
chunk_data.extend_from_slice(&header_bytes);
for buf in buffers_data {
chunk_data.extend_from_slice(buf);
}
chunk_data.extend_from_slice(transitions);
let chunk_header = ChunkHeader::from_parts(
&chunk_data,
ChunkType::Transposed,
num_records,
decoded_data_size,
);
Chunk {
header: chunk_header,
data: chunk_data,
}
}
// Test description of one state machine state, as serialised by
// `build_transpose_chunk`.
struct TestState {
// Raw tag (or reserved message id) of the state.
tag: u32,
// Raw next-node index as stored in the header.
next_node: u32,
// Subtype byte; only serialised when the tag carries a subtype.
subtype: u8,
// Buffer index; only serialised when the state owns a data buffer.
buffer_index: u32,
}
impl TestState {
// Convenience constructor taking the fields in declaration order.
fn new(tag: u32, next_node: u32, subtype: u8, buffer_index: u32) -> Self {
Self {
tag,
next_node,
subtype,
buffer_index,
}
}
}
// A chunk declaring zero records decodes successfully and yields no records.
#[test]
fn test_zero_records() {
let mut header_bytes: Vec<u8> = Vec::new();
// Four zero varints form the header.
header_bytes.extend_from_slice(&encode_u32(0)); header_bytes.extend_from_slice(&encode_u32(0)); header_bytes.extend_from_slice(&encode_u32(0)); header_bytes.extend_from_slice(&encode_u32(0));
let mut chunk_data: Vec<u8> = Vec::new();
// Compression byte 0x00, then the header length and the header itself.
chunk_data.push(0x00); chunk_data.extend_from_slice(&encode_u64(header_bytes.len() as u64));
chunk_data.extend_from_slice(&header_bytes);
let chunk_header = ChunkHeader::from_parts(&chunk_data, ChunkType::Transposed, 0, 0);
let chunk = Chunk {
header: chunk_header,
data: chunk_data,
};
let mut dec = TransposeChunkDecoder::new(chunk).expect("new ok");
assert!(dec.read_record().unwrap().is_none());
}
// A single non-proto state with a 5-byte data buffer and a lengths buffer
// decodes to the record "hello".
#[test]
fn test_single_nonproto_record() {
let nonproto_data = b"hello".to_vec();
// The lengths buffer holds the record length as a varint.
let mut nonproto_lengths = Vec::new();
nonproto_lengths.extend_from_slice(&encode_u32(5));
let states = vec![TestState::new(message_id::NON_PROTO, 0, 0, 0)];
// The lengths buffer is passed last.
let chunk = build_transpose_chunk(
CompressionType::None,
1,
5,
&states,
&[nonproto_data, nonproto_lengths],
&[],
0,
);
let mut dec = TransposeChunkDecoder::new(chunk).expect("new ok");
let rec = dec.read_record().unwrap().expect("should have record");
assert_eq!(rec, b"hello");
assert!(dec.read_record().unwrap().is_none());
}
// One proto record containing a varint, a fixed32, a fixed64, a string and a
// nested submessage is reconstructed byte for byte.
#[test]
fn test_single_proto_record() {
let expected: Vec<u8> = vec![
0x08, 0x2A, 0x15, 0x04, 0x03, 0x02, 0x01, 0x19, 0x08, 0x07, 0x06, 0x05, 0x04, 0x03,
0x02, 0x01, 0x22, 0x03, 0x61, 0x62, 0x63, 0x2A, 0x02, 0x08, 0x07,
];
// One buffer per field payload, matching the buffer indices in `states`.
let buf0 = vec![0x2A]; let buf1 = vec![0x04, 0x03, 0x02, 0x01]; let buf2 = vec![0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]; let buf3 = vec![0x03, 0x61, 0x62, 0x63]; let buf4 = vec![0x07];
let num_states = 8u32;
// Every state except the last uses `num_states + k` as its next index, which
// the decoder treats as an implicit transition to state `k`. The states are
// listed from the end of the record towards its start.
// NOTE(review): 0x2E presumably is field 5 with the submessage wire type — confirm.
let states = vec![
TestState::new(0x2E, num_states + 1, 0, 0),
TestState::new(0x08, num_states + 2, subtype::VARINT_1, 4),
TestState::new(message_id::START_OF_SUBMESSAGE, num_states + 3, 0, 0),
TestState::new(0x22, num_states + 4, 0, 3),
TestState::new(0x19, num_states + 5, 0, 2),
TestState::new(0x15, num_states + 6, 0, 1),
TestState::new(0x08, num_states + 7, subtype::VARINT_1, 0),
TestState::new(message_id::START_OF_MESSAGE, 0, 0, 0),
];
let chunk = build_transpose_chunk(
CompressionType::None,
1,
expected.len() as u64,
&states,
&[buf0, buf1, buf2, buf3, buf4],
&[], 0, );
let mut dec = TransposeChunkDecoder::new(chunk).expect("new ok");
let rec = dec.read_record().unwrap().expect("should have record");
assert_eq!(rec, expected, "proto record mismatch");
assert!(dec.read_record().unwrap().is_none());
}
// A chunk with one proto record and one non-proto record yields both, with
// the proto record returned first.
#[test]
fn test_mixed_proto_nonproto() {
// buf0: varint payload; buf1: non-proto bytes; buf2: non-proto lengths (last buffer).
let buf0 = vec![0x01];
let buf1 = b"xyz".to_vec();
let mut buf2 = Vec::new();
buf2.extend_from_slice(&encode_u32(3));
let num_states = 3u32;
let states = vec![
TestState::new(message_id::NON_PROTO, num_states + 1, 0, 1),
TestState::new(0x08, num_states + 2, subtype::VARINT_1, 0),
TestState::new(message_id::START_OF_MESSAGE, 0, 0, 0),
];
let chunk = build_transpose_chunk(
CompressionType::None,
2,
5, &states,
&[buf0, buf1, buf2],
&[], 0,
);
let mut dec = TransposeChunkDecoder::new(chunk).expect("new ok");
let rec0 = dec.read_record().unwrap().expect("should have record 0");
let rec1 = dec.read_record().unwrap().expect("should have record 1");
assert!(dec.read_record().unwrap().is_none());
assert_eq!(rec0, vec![0x08, 0x01], "record 0 should be proto");
assert_eq!(rec1, b"xyz", "record 1 should be nonproto");
}
/// Checks that a multi-byte varint is rebuilt with its continuation bit:
/// the buffer holds `0x2C 0x02` while the expected record contains
/// `0xAC 0x02` after the `0x08` tag.
#[test]
fn test_varint_high_bit_restoration() {
    let expected: Vec<u8> = vec![0x08, 0xAC, 0x02];
    let stored_varint = vec![0x2C, 0x02];

    let state_count: u32 = 2;
    let states = vec![
        // `VARINT_1 + 1` is presumably the two-byte varint subtype — TODO confirm.
        TestState::new(0x08, state_count + 1, subtype::VARINT_1 + 1, 0),
        TestState::new(message_id::START_OF_MESSAGE, 0, 0, 0),
    ];

    let chunk = build_transpose_chunk(
        CompressionType::None,
        1,
        expected.len() as u64,
        &states,
        &[stored_varint],
        &[],
        0,
    );

    let mut decoder = TransposeChunkDecoder::new(chunk).expect("new ok");
    let record = decoder.read_record().unwrap().expect("record");
    assert_eq!(
        record, expected,
        "multi-byte varint high-bit restoration failed"
    );
}
/// Decodes a chunk with no data buffers whose single field state uses
/// `subtype::VARINT_INLINE_0`, expecting the record `[0x08, 0x00]`.
#[test]
fn test_inline_varint() {
    let expected: Vec<u8> = vec![0x08, 0x00];

    let state_count: u32 = 2;
    let states = vec![
        TestState::new(0x08, state_count + 1, subtype::VARINT_INLINE_0, 0),
        TestState::new(message_id::START_OF_MESSAGE, 0, 0, 0),
    ];

    // No buffers and no transitions are supplied.
    let chunk = build_transpose_chunk(
        CompressionType::None,
        1,
        expected.len() as u64,
        &states,
        &[],
        &[],
        0,
    );

    let mut decoder = TransposeChunkDecoder::new(chunk).expect("new ok");
    let record = decoder.read_record().unwrap().expect("record");
    assert_eq!(record, expected, "inline varint mismatch");
}
/// Flips two bytes near the end of a hand-built non-proto chunk, rebuilds the
/// header, and then accepts either a successfully constructed decoder (the
/// result of the following read is ignored) or a
/// `RiegeliError::MalformedData` from construction. Any other error variant
/// fails the test.
///
/// NOTE(review): the bytes are only flipped when the chunk is longer than 20
/// bytes; for a shorter chunk this test runs against uncorrupted data.
#[test]
fn test_corrupted_bucket() {
    let payload = b"hello".to_vec();
    let lengths: Vec<u8> = encode_u32(5).to_vec();
    let states = vec![TestState::new(message_id::NON_PROTO, 0, 0, 0)];
    let mut chunk = build_transpose_chunk(
        CompressionType::None,
        1,
        5,
        &states,
        &[payload, lengths],
        &[],
        0,
    );

    // Invert the third-to-last and second-to-last bytes of the chunk data.
    let data_len = chunk.data.len();
    if data_len > 20 {
        for byte in &mut chunk.data[data_len - 3..data_len - 1] {
            *byte ^= 0xFF;
        }
    }
    // Recompute the header from the (possibly modified) data.
    chunk.header = ChunkHeader::from_parts(&chunk.data, ChunkType::Transposed, 1, 5);

    match TransposeChunkDecoder::new(chunk) {
        Ok(mut decoder) => {
            let _ = decoder.read_record();
        }
        Err(RiegeliError::MalformedData(_)) => {}
        Err(other) => panic!("expected MalformedData, got {other:?}"),
    }
}
/// Decodes a hand-built three-record chunk that supplies an explicit
/// transitions byte, and checks each single-varint record in order.
#[test]
fn test_multiple_records_with_transitions() {
    let expected_records: Vec<Vec<u8>> =
        vec![vec![0x08, 0x05], vec![0x08, 0x0A], vec![0x08, 0x2A]];
    // The buffer lists the varint values in the reverse of the expected
    // record order.
    let varint_values = vec![0x2A, 0x0A, 0x05];

    let state_count: u32 = 2;
    let states = vec![
        TestState::new(0x08, state_count + 1, subtype::VARINT_1, 0),
        TestState::new(message_id::START_OF_MESSAGE, 0, 0, 0),
    ];
    let transitions = vec![0x01];
    let total_decoded = expected_records.iter().map(Vec::len).sum::<usize>() as u64;

    let chunk = build_transpose_chunk(
        CompressionType::None,
        3,
        total_decoded,
        &states,
        &[varint_values],
        &transitions,
        0,
    );

    let mut decoder = TransposeChunkDecoder::new(chunk).expect("new ok");
    for (i, expected) in expected_records.iter().enumerate() {
        match decoder.read_record().unwrap() {
            Some(record) => assert_eq!(&record, expected, "record {i} mismatch"),
            None => panic!("expected record {i}"),
        }
    }
    assert!(decoder.read_record().unwrap().is_none());
}
/// Encodes `records` with a `TransposeChunkEncoder` using `compression`,
/// decodes the resulting chunk, and returns every decoded record in the
/// order the decoder yields them.
///
/// Panics (via `expect`) if adding a record, encoding, constructing the
/// decoder, or reading a record fails.
fn roundtrip_via_encoder(records: &[&[u8]], compression: CompressionType) -> Vec<Vec<u8>> {
    use crate::transpose::encoder::TransposeChunkEncoder;

    let mut encoder = TransposeChunkEncoder::new(compression);
    for &record in records {
        encoder.add_record(record).expect("add_record");
    }

    let mut decoder =
        TransposeChunkDecoder::new(encoder.encode().expect("encode")).expect("decoder");
    // Pull records until the decoder reports `None`.
    std::iter::from_fn(|| decoder.read_record().expect("read_record")).collect()
}
/// Encodes a chunk without adding any record and checks that the header
/// reports zero records of type `Transposed`, and that the decoder returns
/// `None` on two consecutive reads.
#[test]
fn test_zero_records_via_encoder() {
    use crate::chunk_header::ChunkType;
    use crate::transpose::encoder::TransposeChunkEncoder;

    let chunk = TransposeChunkEncoder::new(CompressionType::None)
        .encode()
        .expect("encode");
    assert_eq!(chunk.header.num_records(), 0);
    assert_eq!(chunk.header.chunk_type().unwrap(), ChunkType::Transposed);

    let mut decoder = TransposeChunkDecoder::new(chunk).expect("decoder");
    for msg in [
        "first call should be None",
        "second call should also be None",
    ] {
        assert!(decoder.read_record().unwrap().is_none(), "{msg}");
    }
}
/// Round-trips a five-byte binary record through the encoder and decoder.
/// The first byte `0x0F` carries wire type 7, which protobuf does not
/// define, so the record is presumably stored as non-proto.
#[test]
fn test_nonproto_binary() {
    let payload: Vec<u8> = vec![0x0F, 0xDE, 0xAD, 0xBE, 0xEF];
    let decoded = roundtrip_via_encoder(&[payload.as_slice()], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], payload);
}
/// Round-trips a record consisting of the single byte `0xFF`.
#[test]
fn test_nonproto_single_byte() {
    let payload = vec![0xFF];
    let decoded = roundtrip_via_encoder(&[payload.as_slice()], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], payload);
}
/// Round-trips a single zero-length record and checks that exactly one
/// empty record is returned.
#[test]
fn test_nonproto_empty_record_roundtrip() {
    let payload: Vec<u8> = Vec::new();
    let decoded = roundtrip_via_encoder(&[payload.as_slice()], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], payload);
}
/// Round-trips one record that contains a varint field, a fixed32 field, a
/// fixed64 field and two length-delimited fields, and requires the decoded
/// bytes to match the input exactly.
#[test]
fn test_all_wire_types_roundtrip() {
    let message: Vec<u8> = vec![
        0x08, 0x2A, 0x15, 0x01, 0x02, 0x03, 0x04, 0x19, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
        0x07, 0x08, 0x22, 0x03, 0x61, 0x62, 0x63, 0x2A, 0x02, 0x08, 0x07,
    ];
    let decoded = roundtrip_via_encoder(&[message.as_slice()], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(
        decoded[0], message,
        "proto with all wire types must round-trip exactly"
    );
}
/// Round-trips four separate one-field records whose varint value is 0, 1,
/// 2 and 3 respectively, each through its own encoder/decoder pair.
#[test]
fn test_inline_varint_0_through_3() {
    for val in [0u8, 1, 2, 3] {
        let message = vec![0x08, val];
        let decoded = roundtrip_via_encoder(&[message.as_slice()], CompressionType::None);
        assert_eq!(decoded.len(), 1, "inline varint {val}");
        assert_eq!(decoded[0], message, "inline varint {val} mismatch");
    }
}
/// Round-trips a record whose field 1 holds the ten-byte varint encoding of
/// `u64::MAX` (nine `0xFF` bytes followed by `0x01`).
#[test]
fn test_max_varint64_roundtrip() {
    let message: Vec<u8> = vec![
        0x08, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01,
    ];
    let decoded = roundtrip_via_encoder(&[message.as_slice()], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], message, "max varint64 must round-trip");
}
/// Round-trips a record with two varint fields: field 1 = 10 (`08 0A`) and
/// field 2 = 20 (`10 14`).
#[test]
fn test_multiple_fields_same_type() {
    let message: Vec<u8> = vec![0x08, 0x0A, 0x10, 0x14];
    let decoded = roundtrip_via_encoder(&[message.as_slice()], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], message);
}
/// Round-trips three records in one chunk (a varint message, a three-byte
/// binary blob, and another varint message) and checks that order and
/// content are preserved.
#[test]
fn test_mixed_proto_nonproto_interleaved() {
    let proto1 = vec![0x08, 0x2A];
    let nonproto = vec![0xFF, 0xAA, 0xBB];
    let proto2 = vec![0x10, 0x01];

    let inputs = [proto1.as_slice(), nonproto.as_slice(), proto2.as_slice()];
    let decoded = roundtrip_via_encoder(&inputs, CompressionType::None);

    assert_eq!(decoded.len(), 3);
    assert_eq!(decoded[0], proto1, "record 0 proto mismatch");
    assert_eq!(decoded[1], nonproto, "record 1 nonproto mismatch");
    assert_eq!(decoded[2], proto2, "record 2 proto mismatch");
}
/// Round-trips twenty records in one chunk. Every third record (index
/// divisible by 3) is the two-byte blob `[0xFF, i]`; the others are a
/// `0x08` tag followed by the varint encoding of the index.
#[test]
fn test_many_mixed_records() {
    let records: Vec<Vec<u8>> = (0u32..20)
        .map(|i| {
            if i % 3 == 0 {
                vec![0xFF, i as u8]
            } else {
                let mut message = vec![0x08];
                message.extend_from_slice(&encode_u64(u64::from(i)));
                message
            }
        })
        .collect();

    let inputs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let decoded = roundtrip_via_encoder(&inputs, CompressionType::None);

    assert_eq!(decoded.len(), records.len());
    // Lengths are equal at this point, so indexing `decoded` is in bounds.
    for (i, expected) in records.iter().enumerate() {
        assert_eq!(&decoded[i], expected, "record {i} mismatch");
    }
}
/// Inverts the last five bytes of an encoder-produced chunk, rebuilds the
/// header, and checks that constructing the decoder and reading one record
/// does not panic. Errors are accepted; only a panic fails the test.
///
/// NOTE(review): the bytes are only inverted when the chunk is longer than
/// 10 bytes; for a shorter chunk the data is left untouched.
#[test]
fn test_corrupted_bucket_no_panic() {
    use crate::chunk_header::ChunkHeader;
    use crate::transpose::encoder::TransposeChunkEncoder;

    let record = vec![0x08, 0x2A];
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    encoder.add_record(&record).unwrap();
    let mut chunk = encoder.encode().unwrap();

    let len = chunk.data.len();
    if len > 10 {
        for byte in &mut chunk.data[len - 5..] {
            *byte ^= 0xFF;
        }
    }
    // Recompute the header from the (possibly modified) data.
    chunk.header = ChunkHeader::from_parts(
        &chunk.data,
        crate::chunk_header::ChunkType::Transposed,
        1,
        2,
    );

    let outcome = std::panic::catch_unwind(|| {
        if let Ok(mut decoder) = TransposeChunkDecoder::new(chunk) {
            let _ = decoder.read_record();
        }
    });
    assert!(outcome.is_ok(), "corrupted bucket must not panic");
}
/// Keeps only the first half of an encoder-produced chunk's data, builds a
/// matching header for it, and checks that constructing a decoder from the
/// truncated chunk does not panic (the construction result is discarded).
#[test]
fn test_truncated_chunk_data_no_panic() {
    use crate::chunk_header::ChunkHeader;
    use crate::chunk_header::ChunkType;
    use crate::simple_chunk::Chunk;
    use crate::transpose::encoder::TransposeChunkEncoder;

    let record = vec![0x08, 0x2A];
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    encoder.add_record(&record).unwrap();
    let encoded = encoder.encode().unwrap();

    let half = encoded.data.len() / 2;
    let truncated_data = encoded.data[..half].to_vec();
    let truncated_chunk = Chunk {
        header: ChunkHeader::from_parts(&truncated_data, ChunkType::Transposed, 1, 2),
        data: truncated_data,
    };

    let outcome = std::panic::catch_unwind(|| {
        let _ = TransposeChunkDecoder::new(truncated_chunk);
    });
    assert!(outcome.is_ok(), "truncated chunk must not panic");
}
/// Builds a `Transposed` chunk whose data is empty (with a header declaring
/// one record and decoded size 5) and checks that decoder construction
/// returns an error.
#[test]
fn test_empty_chunk_data_returns_err() {
    use crate::chunk_header::ChunkHeader;
    use crate::chunk_header::ChunkType;
    use crate::simple_chunk::Chunk;

    let empty_data = Vec::<u8>::new();
    let chunk = Chunk {
        header: ChunkHeader::from_parts(&empty_data, ChunkType::Transposed, 1, 5),
        data: empty_data,
    };

    let outcome = TransposeChunkDecoder::new(chunk);
    assert!(outcome.is_err(), "empty chunk data should be Err");
}
/// Writes two separate in-memory files, one with default writer options and
/// one with `.transpose(true)`, then reads each back with the same default
/// `ReaderOptions` and checks that both yield their two records in order.
#[test]
fn test_interleaved_simple_and_transpose_files() {
    use crate::compression::CompressionType;
    use crate::record_reader::{ReaderOptions, RecordReader};
    use crate::record_writer::{RecordWriter, WriterOptions};

    // File written without the transpose option. The inner scope ends the
    // writer's mutable borrow of the buffer before it is read back.
    let mut simple_file: Vec<u8> = Vec::new();
    {
        let mut writer = RecordWriter::new(
            std::io::Cursor::new(&mut simple_file),
            WriterOptions::new().compression(CompressionType::None),
        )
        .unwrap();
        writer.write_record(b"simple_record_1").unwrap();
        writer.write_record(b"simple_record_2").unwrap();
        writer.close().unwrap();
    }

    // File written with the transpose option enabled.
    let mut transposed_file: Vec<u8> = Vec::new();
    {
        let mut writer = RecordWriter::new(
            std::io::Cursor::new(&mut transposed_file),
            WriterOptions::new()
                .compression(CompressionType::None)
                .transpose(true),
        )
        .unwrap();
        writer.write_record(b"transpose_record_1").unwrap();
        writer.write_record(b"transpose_record_2").unwrap();
        writer.close().unwrap();
    }

    // Read everything back from the first file.
    let mut simple_reader =
        RecordReader::new(std::io::Cursor::new(&simple_file), ReaderOptions::new()).unwrap();
    let simple_records: Vec<_> =
        std::iter::from_fn(|| simple_reader.read_record().unwrap()).collect();
    assert_eq!(simple_records.len(), 2);
    assert_eq!(simple_records[0], b"simple_record_1");
    assert_eq!(simple_records[1], b"simple_record_2");

    // Read everything back from the transposed file.
    let mut transposed_reader = RecordReader::new(
        std::io::Cursor::new(&transposed_file),
        ReaderOptions::new(),
    )
    .unwrap();
    let transpose_records: Vec<_> =
        std::iter::from_fn(|| transposed_reader.read_record().unwrap()).collect();
    assert_eq!(transpose_records.len(), 2);
    assert_eq!(transpose_records[0], b"transpose_record_1");
    assert_eq!(transpose_records[1], b"transpose_record_2");
}
/// Round-trips a record made of tag `0x12` (field 2, length-delimited), a
/// varint length of 5000, and 5000 bytes of `0x41`.
#[test]
fn test_large_proto_string_field_roundtrip() {
    let mut message = vec![0x12];
    message.extend_from_slice(&encode_u32(5000));
    // Append the 5000-byte payload.
    let total_len = message.len() + 5000;
    message.resize(total_len, 0x41);

    let decoded = roundtrip_via_encoder(&[message.as_slice()], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], message, "large string field must round-trip");
}
/// After the only record of a chunk has been read, three further reads must
/// each return `Ok(None)`.
#[test]
fn test_repeated_read_after_none() {
    use crate::transpose::encoder::TransposeChunkEncoder;

    let record = vec![0x08, 0x01];
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    encoder.add_record(&record).unwrap();
    let mut decoder = TransposeChunkDecoder::new(encoder.encode().unwrap()).unwrap();

    assert!(decoder.read_record().unwrap().is_some());
    for _ in 0..3 {
        assert!(decoder.read_record().unwrap().is_none());
    }
}
}