use std::collections::{BTreeMap, BinaryHeap};
use crate::chunk_header::{ChunkHeader, ChunkType};
use crate::compression::{
CompressOptions, CompressionType, compress_length_prefixed, compress_with_prefix,
};
use crate::error::RiegeliError;
use crate::proto_wire::{WireType, is_proto_message, tag_field_number, tag_wire_type};
use crate::simple_chunk::Chunk;
use crate::transpose::internal::{
MAX_RECURSION_DEPTH, MAX_VARINT_INLINE, SUBMESSAGE_WIRE_TYPE, has_data_buffer, has_subtype,
message_id, subtype,
};
use crate::varint::{decode_u32, decode_u64, encode_u32, encode_u64};
/// Largest offset from a list base that a single transition can encode
/// (`emit_transition_bytes` stores an offset as `offset << 2` in a `u8`).
/// A block of states therefore holds at most `MAX_TRANSITION + 1` entries.
const MAX_TRANSITION: u32 = 63;
/// Minimum number of transitions from a tag to one destination for that
/// destination to be placed in the tag's private list up front.
const MIN_COUNT_FOR_STATE: usize = 10;
/// Sentinel for "no position / not assigned" (state indices, bases, positions).
const INVALID_POS: u32 = u32::MAX;
/// Identifies a field node: the message id of the enclosing message plus the
/// raw proto tag. `tag == 0` is used for pseudo-nodes such as the start of a
/// message or a non-proto record.
///
/// Field order matters: the derived `Ord` compares `parent_message_id` first
/// and then `tag`, which orders the `BTreeMap<NodeId, _>` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
struct NodeId {
parent_message_id: u32,
tag: u32,
}
/// Kind of data buffer that a field's payload is appended to.
///
/// The discriminant indexes `DataBuffers::inner`, so the explicit values must
/// stay dense and start at 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u32)]
enum BufferType {
Varint = 0,
Fixed32 = 1,
Fixed64 = 2,
String = 3,
NonProto = 4,
}
/// Every buffer type, in the order `collect_and_sort_buffers` lays their
/// buffers out in the chunk.
const ALL_BUFFER_TYPES: [BufferType; 5] = [
BufferType::Varint,
BufferType::Fixed32,
BufferType::Fixed64,
BufferType::String,
BufferType::NonProto,
];
/// Statistics for one destination of a source tag's outgoing transitions.
#[derive(Debug, Clone)]
struct DestInfo {
    /// State index of the destination in the source tag's private list.
    /// `0` marks a destination selected for the private list but not yet
    /// placed; `INVALID_POS` means it is reached through the public list.
    pos: u32,
    /// Number of transitions from the source tag to this destination, counted
    /// while replaying `encoded_tags` back to front.
    num_transitions: usize,
}
impl DestInfo {
    /// Returns an unplaced destination with no recorded transitions.
    fn new() -> Self {
        let pos = INVALID_POS;
        DestInfo {
            num_transitions: 0,
            pos,
        }
    }
}
/// A distinct (node, subtype) combination occurring in the encoded stream.
#[derive(Debug)]
struct EncodedTagInfo {
    node_id: NodeId,
    subtype: u8,
    /// Outgoing transitions keyed by destination index in `tags_list`.
    dest_info: BTreeMap<u32, DestInfo>,
    /// Incoming transitions not (yet) accounted for by some private list.
    num_incoming_transitions: usize,
    /// Index of this tag's state in the public list, or `INVALID_POS`.
    state_machine_pos: u32,
    /// Index of the NoOp in this tag's private list that leads to the public
    /// list, or `INVALID_POS`.
    public_list_noop_pos: u32,
    /// First state of this tag's private list, or a computed base; starts
    /// as `INVALID_POS`.
    base: u32,
}
impl EncodedTagInfo {
    /// Creates an entry without transitions and without assigned positions.
    fn new(node_id: NodeId, subtype: u8) -> Self {
        let unset = INVALID_POS;
        EncodedTagInfo {
            node_id,
            subtype,
            dest_info: Default::default(),
            num_incoming_transitions: 0,
            state_machine_pos: unset,
            public_list_noop_pos: unset,
            base: unset,
        }
    }
}
/// One state of the decoder state machine being built.
#[derive(Debug, Clone)]
struct StateInfo {
    /// Index into `tags_list`, or `INVALID_POS` for a NoOp state.
    etag_index: u32,
    /// For NoOp states: the base their outgoing transition offsets refer to.
    base: u32,
    /// NoOp state whose block contains this state, or `INVALID_POS`.
    canonical_source: u32,
}
impl StateInfo {
    /// Creates a state that is not (yet) part of any NoOp block.
    fn new(etag_index: u32, base: u32) -> Self {
        StateInfo {
            canonical_source: INVALID_POS,
            etag_index,
            base,
        }
    }
}
/// Entry of the priority queues used to lay out a list of states.
///
/// `dest_index` is an index into `tags_list`, `INVALID_POS` for the NoOp that
/// leads to the public list, or `num_tags + k` for the k-th NoOp created while
/// building the list. `num_transitions` is the entry's weight.
#[derive(Debug, Clone, Eq, PartialEq)]
struct PriorityQueueEntry {
    dest_index: u32,
    num_transitions: usize,
}
impl Ord for PriorityQueueEntry {
    /// Orders entries so that `BinaryHeap` (a max-heap) pops the entry with the
    /// *fewest* transitions first. Ties pop the larger `dest_index` first, which
    /// keeps the output reproducible.
    ///
    /// States are assigned from the end of their list backwards in pop order,
    /// so the heaviest destinations land closest to the list base. There they
    /// get small offsets, and offset 0 can be run-length merged by
    /// `emit_transition_bytes`. Light destinations get grouped behind NoOp
    /// blocks.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other
            .num_transitions
            .cmp(&self.num_transitions)
            .then_with(|| self.dest_index.cmp(&other.dest_index))
    }
}
impl PartialOrd for PriorityQueueEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Byte buffer whose chunks are emitted in reverse chunk order.
///
/// Chunks are appended front to back. `write_to` outputs the last chunk first,
/// while the bytes inside each chunk keep their original order.
#[derive(Default)]
struct BackwardBuffer {
    data: Vec<u8>,
    /// Length of each pushed chunk. Stored as `usize` so that chunks longer
    /// than `u32::MAX` bytes (e.g. huge non-proto records) are not truncated,
    /// which would make `write_to` slice at the wrong offsets.
    sizes: Vec<usize>,
}
impl BackwardBuffer {
    /// Appends one chunk.
    fn push_chunk(&mut self, chunk: &[u8]) {
        self.data.extend_from_slice(chunk);
        self.sizes.push(chunk.len());
    }
    /// Total number of bytes over all chunks.
    fn len(&self) -> usize {
        self.data.len()
    }
    /// Whether no bytes have been pushed (empty chunks do not count).
    fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
    /// Appends all chunks to `out`, last pushed chunk first.
    fn write_to(&self, out: &mut Vec<u8>) {
        out.reserve(self.data.len());
        let mut end = self.data.len();
        for &size in self.sizes.iter().rev() {
            let start = end - size;
            out.extend_from_slice(&self.data[start..end]);
            end = start;
        }
    }
}
/// A data buffer together with the field node whose payloads it holds.
struct BufferWithMetadata {
node_id: NodeId,
data: BackwardBuffer,
}
/// Bookkeeping for one `NodeId`.
struct MessageNode {
/// Message id assigned to this node (used as the parent id of its
/// submessage fields).
message_id: u32,
/// Indexed by subtype: the `tags_list` index for that subtype, or
/// `u32::MAX` if none has been created yet.
encoded_tag_pos: Vec<u32>,
}
/// Parsing state saved when entering a length-delimited submessage.
struct MessageFrame {
/// `tags_list` index of the end-of-submessage tag emitted on exit.
end_sub_tag_idx: u32,
/// Message id of the enclosing message, restored on exit.
parent_message_id: u32,
/// End offset of the enclosing message, restored on exit.
parent_end_pos: usize,
}
/// Data buffers grouped by `BufferType`, indexed by the type's discriminant.
struct DataBuffers {
    inner: [Vec<BufferWithMetadata>; 5],
}
impl DataBuffers {
    /// Creates five empty per-type buffer lists.
    fn new() -> Self {
        DataBuffers {
            inner: Default::default(),
        }
    }
    /// Returns the buffer list for buffer type `bt`.
    fn get_mut(&mut self, bt: BufferType) -> &mut Vec<BufferWithMetadata> {
        let slot = bt as usize;
        &mut self.inner[slot]
    }
}
/// Encoder for Riegeli transposed chunks.
///
/// Records are split into per-field data buffers. A state machine plus a
/// transition stream describe how to reassemble them.
pub struct TransposeChunkEncoder {
/// Compression applied to the header, the buckets and the transitions.
compression: CompressionType,
compress_opts: CompressOptions,
/// Field nodes keyed by (parent message id, tag).
message_nodes: BTreeMap<NodeId, MessageNode>,
/// Next message id handed out by `get_or_create_node`.
next_message_id: u32,
/// Distinct (node, subtype) pairs; `encoded_tags` holds indices into it.
tags_list: Vec<EncodedTagInfo>,
/// `tags_list` indices in record order; replayed back to front when encoding.
encoded_tags: Vec<u32>,
/// Field payload buffers grouped by `BufferType`.
data: DataBuffers,
/// Varint-encoded lengths of the non-proto records.
nonproto_lengths: BackwardBuffer,
num_records: u64,
/// Sum of the byte lengths of all added records.
decoded_data_size: u64,
/// Target uncompressed bucket size; `u64::MAX` means a single bucket.
bucket_size: u64,
}
impl TransposeChunkEncoder {
pub fn new(compression: CompressionType) -> Self {
Self {
compression,
compress_opts: CompressOptions::default(),
message_nodes: BTreeMap::new(),
next_message_id: message_id::ROOT + 1,
tags_list: Vec::new(),
encoded_tags: Vec::new(),
data: DataBuffers::new(),
nonproto_lengths: BackwardBuffer::default(),
num_records: 0,
decoded_data_size: 0,
bucket_size: u64::MAX,
}
}
/// Builder-style setter for the compression options.
pub fn compress_opts(self, opts: CompressOptions) -> Self {
    let mut encoder = self;
    encoder.compress_opts = opts;
    encoder
}
/// Builder-style setter for the target uncompressed bucket size.
/// `u64::MAX` (the default) keeps every buffer in one bucket.
pub fn bucket_size(self, size: u64) -> Self {
    let mut encoder = self;
    encoder.bucket_size = size;
    encoder
}
/// Adds one record to the chunk.
///
/// Records accepted by `is_proto_message` are split into fields. Anything else
/// is stored verbatim in the non-proto buffer, together with its varint length.
///
/// # Errors
/// Propagates `MalformedData` from field parsing of proto records.
pub fn add_record(&mut self, data: &[u8]) -> Result<(), RiegeliError> {
    self.num_records += 1;
    self.decoded_data_size += data.len() as u64;
    if is_proto_message(data) {
        let start_of_message = NodeId {
            parent_message_id: message_id::START_OF_MESSAGE,
            tag: 0,
        };
        let start_tag = self.get_pos_in_tags_list(start_of_message, subtype::TRIVIAL);
        self.encoded_tags.push(start_tag);
        return self.add_message(data, message_id::ROOT, 0);
    }
    let non_proto = NodeId {
        parent_message_id: message_id::NON_PROTO,
        tag: 0,
    };
    let non_proto_tag = self.get_pos_in_tags_list(non_proto, subtype::TRIVIAL);
    self.encoded_tags.push(non_proto_tag);
    self.get_buffer(non_proto, BufferType::NonProto)
        .push_chunk(data);
    self.nonproto_lengths
        .push_chunk(&encode_u64(data.len() as u64));
    Ok(())
}
/// Consumes the encoder and produces the transposed chunk.
///
/// # Errors
/// Propagates compression errors.
pub fn encode(mut self) -> Result<Chunk, RiegeliError> {
    match self.num_records {
        0 => self.encode_empty(),
        _ => {
            let (all_buffers, buffer_sizes, buffer_pos) = self.collect_and_sort_buffers();
            let state_machine = self.create_state_machine();
            self.ensure_last_state_explicit();
            let buckets = self.create_compressed_buckets(&all_buffers)?;
            let header =
                self.build_header(&buckets, &buffer_sizes, &buffer_pos, &state_machine)?;
            self.assemble_chunk(header, &buckets, &state_machine)
        }
    }
}
/// Sorts and serializes all data buffers.
///
/// Buffers are visited per type in `ALL_BUFFER_TYPES` order. Within a type
/// they are sorted by size, then by `NodeId` (parent id, then tag). The
/// non-proto length buffer, if non-empty, comes last.
///
/// Returns the serialized buffers, their sizes, and each node's buffer index.
/// The length buffer has no `NodeId`, so it gets no index.
fn collect_and_sort_buffers(&mut self) -> (Vec<Vec<u8>>, Vec<u64>, BTreeMap<NodeId, u32>) {
    let mut positions: BTreeMap<NodeId, u32> = BTreeMap::new();
    let mut serialized: Vec<Vec<u8>> = Vec::new();
    let mut sizes: Vec<u64> = Vec::new();
    for buf_type in ALL_BUFFER_TYPES.iter().copied() {
        let buffers = self.data.get_mut(buf_type);
        // Derived `Ord` on NodeId compares parent_message_id, then tag.
        buffers.sort_by_key(|buffer| (buffer.data.len(), buffer.node_id));
        for buffer in buffers.iter() {
            positions.insert(buffer.node_id, serialized.len() as u32);
            sizes.push(buffer.data.len() as u64);
            let mut bytes = Vec::with_capacity(buffer.data.len());
            buffer.data.write_to(&mut bytes);
            serialized.push(bytes);
        }
    }
    let lengths = &self.nonproto_lengths;
    if !lengths.is_empty() {
        sizes.push(lengths.len() as u64);
        let mut bytes = Vec::with_capacity(lengths.len());
        lengths.write_to(&mut bytes);
        serialized.push(bytes);
    }
    (serialized, sizes, positions)
}
/// Concatenates the serialized buffers into buckets and compresses each one.
///
/// With the default `bucket_size` (`u64::MAX`) all buffers share one bucket.
/// Otherwise buffers are appended in order. The current bucket is closed
/// before a buffer when the bucket size plus half that buffer's size reaches
/// `bucket_size`. A bucket is never closed while empty.
///
/// Returns the buckets produced by `compress_with_prefix`; empty input yields
/// no buckets.
///
/// # Errors
/// Propagates compression failures from `compress_with_prefix`.
fn create_compressed_buckets(
    &self,
    all_buffers: &[Vec<u8>],
) -> Result<Vec<Vec<u8>>, RiegeliError> {
    if all_buffers.is_empty() {
        return Ok(Vec::new());
    }
    if self.bucket_size == u64::MAX {
        let bucket_data = all_buffers.concat();
        let compressed =
            compress_with_prefix(&bucket_data, self.compression, self.compress_opts)?;
        return Ok(vec![compressed]);
    }
    let mut compressed_buckets: Vec<Vec<u8>> = Vec::new();
    let mut current_bucket: Vec<u8> = Vec::new();
    let mut current_size: u64 = 0;
    for buf in all_buffers {
        let buf_size = buf.len() as u64;
        // Close the bucket if adding (half of) this buffer reaches the target.
        if !current_bucket.is_empty() && current_size + buf_size / 2 >= self.bucket_size {
            compressed_buckets.push(compress_with_prefix(
                &current_bucket,
                self.compression,
                self.compress_opts,
            )?);
            // Reuse the allocation for the next bucket.
            current_bucket.clear();
            current_size = 0;
        }
        current_bucket.extend_from_slice(buf);
        current_size += buf_size;
    }
    if !current_bucket.is_empty() {
        compressed_buckets.push(compress_with_prefix(
            &current_bucket,
            self.compression,
            self.compress_opts,
        )?);
    }
    Ok(compressed_buckets)
}
/// Serializes the uncompressed chunk header.
///
/// Layout: bucket count, buffer count, each compressed bucket's length, each
/// buffer's uncompressed size, then the state machine section written by
/// `write_states_and_data`.
///
/// # Errors
/// Never returns an error in the current implementation.
fn build_header(
    &self,
    compressed_buckets: &[Vec<u8>],
    buffer_sizes: &[u64],
    buffer_pos: &BTreeMap<NodeId, u32>,
    state_machine: &[StateInfo],
) -> Result<Vec<u8>, RiegeliError> {
    let mut header: Vec<u8> = Vec::new();
    let counts = [compressed_buckets.len() as u32, buffer_sizes.len() as u32];
    for &count in counts.iter() {
        header.extend_from_slice(&encode_u32(count));
    }
    let bucket_lengths = compressed_buckets
        .iter()
        .map(|bucket| bucket.len() as u64);
    for length in bucket_lengths.chain(buffer_sizes.iter().copied()) {
        header.extend_from_slice(&encode_u64(length));
    }
    self.write_states_and_data(&mut header, buffer_pos, state_machine);
    Ok(header)
}
/// Appends the serialized state machine to `header`.
///
/// Written in this order:
/// - the number of states;
/// - one varint per state: a tag, a `message_id` marker (`NO_OP`,
///   `START_OF_SUBMESSAGE`), or a parent message id for `tag == 0` pseudo-nodes;
/// - one base per state;
/// - the subtype byte of every state whose tag `has_subtype`;
/// - the buffer index of every state that owns a data buffer;
/// - the index of the state to start from.
fn write_states_and_data(
&self,
header: &mut Vec<u8>,
buffer_pos: &BTreeMap<NodeId, u32>,
state_machine: &[StateInfo],
) {
let num_sm = state_machine.len() as u32;
// These per-state sections are collected here and appended after the state list.
let mut subtype_to_write: Vec<u8> = Vec::new();
let mut buffer_index_to_write: Vec<u32> = Vec::new();
let mut base_to_write: Vec<u32> = Vec::new();
header.extend_from_slice(&encode_u32(num_sm));
for state_info in state_machine {
// NoOp state: write the marker and the state's own base (which may still
// be INVALID_POS if no base was computed for it).
if state_info.etag_index == INVALID_POS {
header.extend_from_slice(&encode_u32(message_id::NO_OP));
base_to_write.push(state_info.base);
continue;
}
let etag = &self.tags_list[state_info.etag_index as usize];
let node_id = &etag.node_id;
if node_id.tag != 0 {
let wt = tag_wire_type(node_id.tag);
let is_string = wt == Some(WireType::LengthDelimited);
if is_string && etag.subtype == subtype::LENGTH_DELIMITED_START_OF_SUBMESSAGE {
header.extend_from_slice(&encode_u32(message_id::START_OF_SUBMESSAGE));
} else if is_string && etag.subtype == subtype::LENGTH_DELIMITED_END_OF_SUBMESSAGE {
// End of submessage: the tag is rewritten with SUBMESSAGE_WIRE_TYPE in
// place of the length-delimited wire type.
let submsg_tag =
node_id.tag + SUBMESSAGE_WIRE_TYPE - WireType::LengthDelimited as u32;
header.extend_from_slice(&encode_u32(submsg_tag));
} else {
header.extend_from_slice(&encode_u32(node_id.tag));
if has_subtype(node_id.tag) {
subtype_to_write.push(etag.subtype);
}
if has_data_buffer(node_id.tag, etag.subtype) {
let lookup_node = NodeId {
parent_message_id: node_id.parent_message_id,
tag: node_id.tag,
};
// NOTE(review): a missing buffer silently maps to index 0 — confirm
// this cannot happen for states that own a buffer.
let idx = buffer_pos.get(&lookup_node).copied().unwrap_or(0);
buffer_index_to_write.push(idx);
}
}
} else {
// Pseudo-node (start of message / non-proto): written as its parent id.
header.extend_from_slice(&encode_u32(node_id.parent_message_id));
if node_id.parent_message_id == message_id::NON_PROTO {
let np_node_id = NodeId {
parent_message_id: message_id::NON_PROTO,
tag: 0,
};
let idx = buffer_pos.get(&np_node_id).copied().unwrap_or(0);
buffer_index_to_write.push(idx);
}
}
let etag_ref = &self.tags_list[state_info.etag_index as usize];
if etag_ref.base != INVALID_POS {
// Tags with exactly one destination get `num_sm` added to their base.
// NOTE(review): presumably this tells the decoder the transition is
// implicit (build_transitions writes nothing for such tags) — confirm.
let implicit_offset = if etag_ref.dest_info.len() == 1 {
num_sm
} else {
0
};
base_to_write.push(etag_ref.base + implicit_offset);
} else {
base_to_write.push(0);
}
}
for &value in &base_to_write {
header.extend_from_slice(&encode_u32(value));
}
header.extend_from_slice(&subtype_to_write);
for &value in &buffer_index_to_write {
header.extend_from_slice(&encode_u32(value));
}
// Starting state: the first state whose tag equals the last element of
// `encoded_tags` (tags are replayed back to front). Falls back to 0 if no
// state matches.
let first_tag_pos = if self.encoded_tags.is_empty() {
0u32
} else {
let last_etag = *self.encoded_tags.last().unwrap();
state_machine
.iter()
.position(|s| s.etag_index == last_etag)
.unwrap_or(0) as u32
};
header.extend_from_slice(&encode_u32(first_tag_pos));
}
/// Assembles the final chunk data: a compression-type byte, the compressed
/// length-prefixed header, the compressed buckets, and the compressed
/// transition stream.
///
/// # Errors
/// Propagates compression failures.
fn assemble_chunk(
    &self,
    raw_header: Vec<u8>,
    compressed_buckets: &[Vec<u8>],
    state_machine: &[StateInfo],
) -> Result<Chunk, RiegeliError> {
    let raw_transitions = self.build_transitions(state_machine);
    let compressed_transitions =
        compress_with_prefix(&raw_transitions, self.compression, self.compress_opts)?;
    let length_prefixed_header =
        compress_length_prefixed(&raw_header, self.compression, self.compress_opts)?;
    let mut chunk_data: Vec<u8> = vec![self.compression as u8];
    chunk_data.extend_from_slice(&length_prefixed_header);
    compressed_buckets
        .iter()
        .for_each(|bucket| chunk_data.extend_from_slice(bucket));
    chunk_data.extend_from_slice(&compressed_transitions);
    let header = ChunkHeader::from_parts(
        &chunk_data,
        ChunkType::Transposed,
        self.num_records,
        self.decoded_data_size,
    );
    Ok(Chunk {
        header,
        data: chunk_data,
    })
}
/// Counts transitions between consecutive encoded tags.
///
/// `encoded_tags` is replayed back to front, so element `i + 1` transitions
/// into element `i`. Each transition increments the source's `dest_info`
/// count and the destination's incoming count. The tag replayed first (the
/// last element) is given at least one incoming transition.
fn collect_transition_statistics(&mut self) {
    let last = match self.encoded_tags.last() {
        Some(&tag) => tag,
        None => return,
    };
    for pair in self.encoded_tags.windows(2) {
        let (dest, source) = (pair[0], pair[1]);
        self.tags_list[source as usize]
            .dest_info
            .entry(dest)
            .or_insert_with(DestInfo::new)
            .num_transitions += 1;
        self.tags_list[dest as usize].num_incoming_transitions += 1;
    }
    let first_replayed = &mut self.tags_list[last as usize];
    if first_replayed.num_incoming_transitions == 0 {
        first_replayed.num_incoming_transitions = 1;
    }
}
/// Forces the tag replayed last (`encoded_tags[0]`) to have more than one
/// destination.
///
/// If it has exactly one, a dummy destination (that key + 1, unplaced) is
/// added. The `dest_info.len() == 1` special cases in `write_states_and_data`
/// and `build_transitions` then no longer apply to it.
fn ensure_last_state_explicit(&mut self) {
    let final_tag = match self.encoded_tags.first() {
        Some(&tag) => tag as usize,
        None => return,
    };
    let dest_info = &mut self.tags_list[final_tag].dest_info;
    if dest_info.len() != 1 {
        return;
    }
    let only_dest = *dest_info.keys().next().unwrap();
    dest_info.entry(only_dest + 1).or_insert_with(DestInfo::new);
}
/// Builds the decoder state machine.
///
/// Private lists come first, then the public list starting at
/// `public_list_base`, and finally the base indices are computed. With no
/// encoded tags, the machine is a single NoOp state with base 0.
fn create_state_machine(&mut self) -> Vec<StateInfo> {
    if self.encoded_tags.is_empty() {
        return vec![StateInfo::new(INVALID_POS, 0)];
    }
    self.collect_transition_statistics();
    self.mark_frequent_transitions();
    let mut states: Vec<StateInfo> = Vec::new();
    let mut public_noops: Vec<(u32, u32)> = Vec::new();
    self.build_private_lists(&mut states, &mut public_noops);
    let public_list_base = states.len() as u32;
    self.build_public_list(&mut states);
    self.compute_base_indices(public_list_base, &public_noops, &mut states);
    states
}
/// Selects, for every tag, the destinations reached at least
/// `MIN_COUNT_FOR_STATE` times to go into that tag's private list.
///
/// A selected destination gets `pos = 0`, the "in list" marker read by
/// `build_private_lists`. Its transitions are also subtracted from the
/// destination's `num_incoming_transitions`.
fn mark_frequent_transitions(&mut self) {
    const IN_LIST_MARKER: u32 = 0;
    for source in 0..self.tags_list.len() {
        let frequent: Vec<(u32, usize)> = self.tags_list[source]
            .dest_info
            .iter()
            .filter(|(_, info)| info.num_transitions >= MIN_COUNT_FOR_STATE)
            .map(|(&dest, info)| (dest, info.num_transitions))
            .collect();
        for (dest, count) in frequent {
            self.tags_list[dest as usize].num_incoming_transitions -= count;
            if let Some(info) = self.tags_list[source].dest_info.get_mut(&dest) {
                info.pos = IN_LIST_MARKER;
            }
        }
    }
}
/// Builds a private list of states for each tag that has eligible
/// destinations, appending the states to `state_machine`.
///
/// A destination is eligible if `mark_frequent_transitions` marked it
/// (`pos == 0`), or if all of its remaining incoming transitions come from this
/// tag. All other destinations are reached through one NoOp state that jumps
/// to the public list; that NoOp is recorded in `public_list_noops` as
/// `(tag index, state index)`. If exactly one destination would be left for the
/// public list, it is put into the private list instead.
///
/// Layout: entries are popped from a priority queue (see `PriorityQueueEntry`)
/// in blocks of at most `MAX_TRANSITION + 1`. Each finished block is replaced
/// in the queue by one NoOp entry carrying the block's total weight. States
/// are written from the end of the list backwards, and the tag's `base` is
/// the first state of its list.
fn build_private_lists(
&mut self,
state_machine: &mut Vec<StateInfo>,
public_list_noops: &mut Vec<(u32, u32)>,
) {
// Marker written by mark_frequent_transitions for pre-selected destinations.
let k_in_list_pos: u32 = 0;
let num_tags = self.tags_list.len();
for tag_id in 0..num_tags {
let mut tag_priority: BinaryHeap<PriorityQueueEntry> = BinaryHeap::new();
let mut excluded_state: Option<PriorityQueueEntry> = None;
let mut num_excluded_transitions: usize = 0;
let dest_keys: Vec<u32> = self.tags_list[tag_id].dest_info.keys().copied().collect();
let sz = dest_keys.len() as u32;
for &dest_key in &dest_keys {
let di = &self.tags_list[tag_id].dest_info[&dest_key];
let di_pos = di.pos;
let di_num = di.num_transitions;
if di_pos == k_in_list_pos
|| di_num == self.tags_list[dest_key as usize].num_incoming_transitions
{
// Pre-marked destinations were already subtracted from the incoming
// count; exclusive ones are subtracted here.
if di_pos != k_in_list_pos {
self.tags_list[dest_key as usize].num_incoming_transitions -= di_num;
}
tag_priority.push(PriorityQueueEntry {
dest_index: dest_key,
num_transitions: di_num,
});
} else {
num_excluded_transitions += di_num;
excluded_state = Some(PriorityQueueEntry {
dest_index: dest_key,
num_transitions: di_num,
});
}
}
let mut num_states = tag_priority.len() as u32;
if num_states == 0 {
continue;
}
// Exactly one excluded destination: include it rather than adding a NoOp.
if num_states + 1 == sz {
num_states += 1;
if let Some(es) = excluded_state.take() {
self.tags_list[es.dest_index as usize].num_incoming_transitions -=
es.num_transitions;
tag_priority.push(es);
}
}
// Some destinations remain public: add the NoOp that jumps to the public list.
if num_states != sz {
tag_priority.push(PriorityQueueEntry {
dest_index: INVALID_POS,
num_transitions: num_excluded_transitions,
});
num_states += 1;
}
self.tags_list[tag_id].base = state_machine.len() as u32;
// Each NoOp replaces a full block (MAX_TRANSITION + 1 entries) by one entry,
// so this many NoOps are needed for the last block to fit.
let noop_nodes = if num_states <= MAX_TRANSITION + 1 {
0u32
} else {
(num_states - 2) / MAX_TRANSITION
};
num_states += noop_nodes;
let mut prev_state = state_machine.len() as u32 + num_states;
state_machine.resize(
prev_state as usize,
StateInfo::new(INVALID_POS, INVALID_POS),
);
// The first (possibly partial) block absorbs the remainder so that all
// later blocks are full.
let mut block_size = (num_states - 1) % (MAX_TRANSITION + 1) + 1;
// noop_base[k] is the lowest state index of the block behind NoOp k.
let mut noop_base: Vec<u32> = Vec::new();
loop {
let mut total_block_weight: usize = 0;
for _ in 0..block_size {
let entry = tag_priority.pop().unwrap();
total_block_weight += entry.num_transitions;
let node_index = entry.dest_index;
if node_index == INVALID_POS {
// NoOp to the public list; its base is set in compute_base_indices.
prev_state -= 1;
state_machine[prev_state as usize] =
StateInfo::new(INVALID_POS, INVALID_POS);
self.tags_list[tag_id].public_list_noop_pos = prev_state;
public_list_noops.push((tag_id as u32, prev_state));
} else if node_index >= num_tags as u32 {
// NoOp for an earlier block: point it at the block and record it as
// the canonical source of every state in that block.
let base = noop_base[(node_index - num_tags as u32) as usize];
prev_state -= 1;
state_machine[prev_state as usize] = StateInfo::new(INVALID_POS, base);
for j in 0..=MAX_TRANSITION {
let idx = (j + base) as usize;
if idx >= state_machine.len() {
break;
}
state_machine[idx].canonical_source = prev_state;
}
} else {
// Regular destination: remember its private-list position.
prev_state -= 1;
state_machine[prev_state as usize] =
StateInfo::new(node_index, INVALID_POS);
self.tags_list[tag_id]
.dest_info
.get_mut(&node_index)
.unwrap()
.pos = prev_state;
}
}
if tag_priority.is_empty() {
break;
}
tag_priority.push(PriorityQueueEntry {
dest_index: num_tags as u32 + noop_base.len() as u32,
num_transitions: total_block_weight,
});
noop_base.push(prev_state);
block_size = MAX_TRANSITION + 1;
}
}
}
/// Builds the public list and appends it to `state_machine`.
///
/// The list holds one state per tag whose `num_incoming_transitions` is still
/// non-zero after the private lists were built. Each tag's public state is
/// stored in `state_machine_pos`. The block/NoOp layout is the same as in
/// `build_private_lists`: states are written backwards, and each full block is
/// replaced by a NoOp that is the canonical source of the block's states.
fn build_public_list(&mut self, state_machine: &mut Vec<StateInfo>) {
let num_tags = self.tags_list.len();
let mut tag_priority: BinaryHeap<PriorityQueueEntry> = BinaryHeap::new();
for i in 0..num_tags {
if self.tags_list[i].num_incoming_transitions != 0 {
tag_priority.push(PriorityQueueEntry {
dest_index: i as u32,
num_transitions: self.tags_list[i].num_incoming_transitions,
});
}
}
let mut num_states = tag_priority.len() as u32;
if num_states == 0 {
return;
}
// Same NoOp count / first-block-size arithmetic as in build_private_lists.
let noop_nodes = if num_states <= MAX_TRANSITION + 1 {
0u32
} else {
(num_states - 2) / MAX_TRANSITION
};
num_states += noop_nodes;
let mut prev_node = state_machine.len() as u32 + num_states;
state_machine.resize(prev_node as usize, StateInfo::new(INVALID_POS, INVALID_POS));
let mut block_size = (num_states - 1) % (MAX_TRANSITION + 1) + 1;
// noop_base[k] is the lowest state index of the block behind NoOp k.
let mut noop_base: Vec<u32> = Vec::new();
loop {
let mut total_block_weight: usize = 0;
for _ in 0..block_size {
let entry = tag_priority.pop().unwrap();
total_block_weight += entry.num_transitions;
let node_index = entry.dest_index;
if node_index >= num_tags as u32 {
// NoOp for an earlier block.
let base = noop_base[(node_index - num_tags as u32) as usize];
prev_node -= 1;
state_machine[prev_node as usize] = StateInfo::new(INVALID_POS, base);
for j in 0..=MAX_TRANSITION {
let idx = (j + base) as usize;
if idx >= state_machine.len() {
break;
}
state_machine[idx].canonical_source = prev_node;
}
} else {
prev_node -= 1;
state_machine[prev_node as usize] = StateInfo::new(node_index, INVALID_POS);
self.tags_list[node_index as usize].state_machine_pos = prev_node;
}
}
if tag_priority.is_empty() {
break;
}
tag_priority.push(PriorityQueueEntry {
dest_index: num_tags as u32 + noop_base.len() as u32,
num_transitions: total_block_weight,
});
noop_base.push(prev_node);
block_size = MAX_TRANSITION + 1;
}
}
/// Fills in bases that the list builders left unset.
///
/// First, each public-list NoOp (from `public_list_noops`) gets the `min_pos`
/// found by `find_optimal_base` for its tag. Then every tag without a private
/// list gets its own `min_pos` as base. An `INVALID_POS` result leaves the
/// value unchanged.
fn compute_base_indices(
    &mut self,
    public_list_base: u32,
    public_list_noops: &[(u32, u32)],
    state_machine: &mut [StateInfo],
) {
    let sm_len = state_machine.len();
    for &(tag_index, noop_state) in public_list_noops {
        let (_, min_pos) =
            self.find_optimal_base(tag_index, public_list_base, state_machine, sm_len);
        if min_pos != INVALID_POS {
            state_machine[noop_state as usize].base = min_pos;
        }
    }
    for tag_index in 0..self.tags_list.len() {
        if self.tags_list[tag_index].base == INVALID_POS {
            let (_, min_pos) = self.find_optimal_base(
                tag_index as u32,
                public_list_base,
                state_machine,
                sm_len,
            );
            if min_pos != INVALID_POS {
                self.tags_list[tag_index].base = min_pos;
            }
        }
    }
}
/// Searches for a base from which a tag's public-list destinations can be
/// reached.
///
/// Only destinations without a private-list position (`pos == INVALID_POS`)
/// that have a public state are considered. For each one, the loop walks
/// `canonical_source` links:
/// - it moves `base` up while `base` is above the target;
/// - it moves the target up while the target is more than `MAX_TRANSITION`
///   past `base`.
///
/// Returns `(base, min_pos)`: `min_pos` is the smallest position visited, or
/// `INVALID_POS` if no destination qualified. `compute_base_indices` only uses
/// `min_pos`. `_sm_len` is unused.
fn find_optimal_base(
&self,
tag_index: u32,
public_list_base: u32,
state_machine: &[StateInfo],
_sm_len: usize,
) -> (u32, u32) {
let mut base = INVALID_POS;
let mut min_pos = INVALID_POS;
let dest_keys: Vec<(u32, u32)> = self.tags_list[tag_index as usize]
.dest_info
.iter()
.map(|(&k, v)| (k, v.pos))
.collect();
for (dest_key, di_pos) in dest_keys {
// Destinations in the private list do not need the public list.
if di_pos != INVALID_POS {
continue;
}
let mut pos = self.tags_list[dest_key as usize].state_machine_pos;
if pos == INVALID_POS {
continue;
}
// NOTE(review): termination relies on public-list positions being
// >= public_list_base, so resetting `base` to it ends the `base > pos`
// case — confirm.
while base > pos || (base != INVALID_POS && pos - base > MAX_TRANSITION) {
if base > pos {
let cs = if base == INVALID_POS {
state_machine[pos as usize].canonical_source
} else {
let cs_of_base = state_machine[base as usize].canonical_source;
if cs_of_base == INVALID_POS {
base = public_list_base;
continue;
}
min_pos = min_pos.min(cs_of_base);
state_machine[cs_of_base as usize].canonical_source
};
if cs == INVALID_POS {
base = public_list_base;
} else {
base = state_machine[cs as usize].base;
}
} else {
let cs = state_machine[pos as usize].canonical_source;
if cs == INVALID_POS {
break;
}
pos = cs;
}
}
min_pos = min_pos.min(pos);
}
(base, min_pos)
}
/// Builds the uncompressed transition stream consumed by the decoder.
///
/// `encoded_tags` is replayed from its last element to its first. For each
/// step from `prev_etag` to the next `tag`:
/// - if `prev_etag` has exactly one destination, nothing is written (the
///   transition is implicit);
/// - if `tag` has a private-list position under `prev_etag`, that is the target;
/// - otherwise, the offsets to `prev_etag`'s public-list NoOp (if any) are
///   written first, the current base becomes that NoOp's base, and the target
///   is `tag`'s public state;
/// - offsets from the current base to the target are then written, routed
///   through NoOp `canonical_source` states when the target is out of range.
///
/// Offsets are packed by `emit_transition_bytes`, and the pending byte is
/// flushed at the end.
fn build_transitions(&self, state_machine: &[StateInfo]) -> Vec<u8> {
if self.encoded_tags.is_empty() {
return Vec::new();
}
let mut transitions: Vec<u8> = Vec::new();
let mut last_transition: Option<u8> = None;
let prev_etag_start = *self.encoded_tags.last().unwrap();
let mut prev_etag = prev_etag_start;
let mut current_base = self.tags_list[prev_etag as usize].base;
let n = self.encoded_tags.len();
for i in (0..n - 1).rev() {
let tag = self.encoded_tags[i];
// Single destination: implicit transition, nothing to write.
if self.tags_list[prev_etag as usize].dest_info.len() == 1 {
prev_etag = tag;
current_base = self.tags_list[prev_etag as usize].base;
continue;
}
let private_pos = self.tags_list[prev_etag as usize]
.dest_info
.get(&tag)
.map(|di| di.pos)
.unwrap_or(INVALID_POS);
let mut pos = private_pos;
if pos == INVALID_POS {
// Not in the private list: first jump to the public-list NoOp.
let public_noop_pos = self.tags_list[prev_etag as usize].public_list_noop_pos;
if public_noop_pos != INVALID_POS {
let orig_pos = public_noop_pos;
let mut noop_pos = public_noop_pos;
let mut write_buf: Vec<u8> = Vec::new();
// Climb through canonical sources until the NoOp is in range; offsets
// are collected in reverse and flipped below.
// NOTE(review): unlike the loop further down, this does not guard
// against `cs == INVALID_POS` — confirm it cannot occur.
while current_base > noop_pos
|| (current_base != INVALID_POS && noop_pos - current_base > MAX_TRANSITION)
{
let cs = state_machine[noop_pos as usize].canonical_source;
write_buf.push((noop_pos - state_machine[cs as usize].base) as u8);
noop_pos = cs;
}
write_buf.push((noop_pos - current_base) as u8);
write_buf.reverse();
emit_transition_bytes(&write_buf, &mut transitions, &mut last_transition);
current_base = state_machine[orig_pos as usize].base;
}
pos = self.tags_list[tag as usize].state_machine_pos;
}
// NOTE(review): nothing is written for this step when no base or target
// is known — confirm the decoder never needs a transition here.
if current_base == INVALID_POS || pos == INVALID_POS {
prev_etag = tag;
current_base = self.tags_list[prev_etag as usize].base;
continue;
}
let mut write_buf: Vec<u8> = Vec::new();
let mut target = pos;
while current_base > target
|| (current_base != INVALID_POS && target - current_base > MAX_TRANSITION)
{
let cs = state_machine[target as usize].canonical_source;
if cs == INVALID_POS || cs as usize >= state_machine.len() {
break;
}
write_buf.push((target - state_machine[cs as usize].base) as u8);
target = cs;
}
write_buf.push((target - current_base) as u8);
write_buf.reverse();
emit_transition_bytes(&write_buf, &mut transitions, &mut last_transition);
prev_etag = tag;
current_base = self.tags_list[prev_etag as usize].base;
}
// Flush the pending (possibly run-length extended) byte.
if let Some(lt) = last_transition {
transitions.push(lt);
}
transitions
}
/// Encodes a chunk that holds no records.
///
/// The header is four zero varints: no buckets, no buffers, no states, and
/// starting state 0. It is compressed and length-prefixed, and no buckets or
/// transitions follow.
///
/// # Errors
/// Propagates compression failures.
fn encode_empty(&self) -> Result<Chunk, RiegeliError> {
    let zero = encode_u32(0);
    let mut header: Vec<u8> = Vec::new();
    for _ in 0..4 {
        header.extend_from_slice(&zero);
    }
    let compressed_header =
        compress_length_prefixed(&header, self.compression, self.compress_opts)?;
    let mut chunk_data: Vec<u8> = vec![self.compression as u8];
    chunk_data.extend_from_slice(&compressed_header);
    let chunk_header = ChunkHeader::from_parts(&chunk_data, ChunkType::Transposed, 0, 0);
    Ok(Chunk {
        header: chunk_header,
        data: chunk_data,
    })
}
/// Returns the message id of `node_id`. The first time a node is seen, it is
/// assigned the next free id.
fn get_or_create_node(&mut self, node_id: NodeId) -> u32 {
    let next_message_id = &mut self.next_message_id;
    let node = self.message_nodes.entry(node_id).or_insert_with(|| {
        let message_id = *next_message_id;
        *next_message_id += 1;
        MessageNode {
            message_id,
            encoded_tag_pos: Vec::new(),
        }
    });
    node.message_id
}
/// Returns the `tags_list` index for `(node_id, st)`, creating the entry on
/// first use.
fn get_pos_in_tags_list(&mut self, node_id: NodeId, st: u8) -> u32 {
    self.get_or_create_node(node_id);
    let tags_list = &mut self.tags_list;
    let node = match self.message_nodes.get_mut(&node_id) {
        Some(node) => node,
        // Not reachable in practice: the node was created just above.
        None => return u32::MAX,
    };
    let slot = usize::from(st);
    if slot >= node.encoded_tag_pos.len() {
        node.encoded_tag_pos.resize(slot + 1, u32::MAX);
    }
    let entry = &mut node.encoded_tag_pos[slot];
    if *entry == u32::MAX {
        *entry = tags_list.len() as u32;
        tags_list.push(EncodedTagInfo::new(node_id, st));
    }
    *entry
}
/// Returns the data buffer of type `buf_type` for `node_id`, creating it on
/// first use. Lookup is a linear scan over the buffers of that type.
fn get_buffer(&mut self, node_id: NodeId, buf_type: BufferType) -> &mut BackwardBuffer {
    let buffers = self.data.get_mut(buf_type);
    let existing = buffers.iter().position(|b| b.node_id == node_id);
    let index = match existing {
        Some(found) => found,
        None => {
            buffers.push(BufferWithMetadata {
                node_id,
                data: BackwardBuffer::default(),
            });
            buffers.len() - 1
        }
    };
    &mut buffers[index].data
}
/// Walks the fields of a serialized proto message.
///
/// Records one encoded tag per field in `encoded_tags`, plus start/end
/// markers for submessages, and appends field payloads to the per-node data
/// buffers.
///
/// Submessages are handled iteratively. Entering one pushes a `MessageFrame`
/// and narrows `current_parent` / `current_end` to the submessage. Once
/// parsing reaches a frame's end, the frame's end-of-submessage tag is emitted
/// and the enclosing state is restored.
///
/// # Errors
/// `MalformedData` for an undecodable tag or value, field number 0, an invalid
/// wire type, or a length-delimited field that overruns its enclosing message.
fn add_message(
&mut self,
data: &[u8],
parent_message_id: u32,
depth: usize,
) -> Result<(), RiegeliError> {
let mut pos = 0usize;
let mut message_stack: Vec<MessageFrame> = Vec::new();
let mut current_parent = parent_message_id;
let mut current_end = data.len();
while pos < current_end {
let (tag, consumed) = decode_u32(&data[pos..])
.map_err(|e| RiegeliError::MalformedData(format!("tag decode: {e}")))?;
pos += consumed;
let field_num = tag_field_number(tag);
if field_num == 0 {
return Err(RiegeliError::MalformedData(
"field number 0 in proto".to_string(),
));
}
let node_id = NodeId {
parent_message_id: current_parent,
tag,
};
let node_mid = self.get_or_create_node(node_id);
match tag_wire_type(tag) {
Some(WireType::Varint) => {
self.encode_varint_field(data, &mut pos, node_id)?;
}
Some(WireType::Fixed32) => {
self.encode_fixed32_field(data, &mut pos, node_id)?;
}
Some(WireType::Fixed64) => {
self.encode_fixed64_field(data, &mut pos, node_id)?;
}
Some(WireType::LengthDelimited) => {
let entered_submessage = self.encode_length_delimited_field(
data,
&mut pos,
node_id,
node_mid,
current_end,
depth,
&message_stack,
&mut current_parent,
&mut current_end,
)?;
// Entered a non-empty submessage: keep parsing inside it. No frame
// can end here because pos now points at the submessage's start.
if let Some(frame) = entered_submessage {
message_stack.push(frame);
continue;
}
}
// Group markers carry no payload; only the tag is recorded.
Some(WireType::StartGroup) => {
let idx = self.get_pos_in_tags_list(node_id, subtype::TRIVIAL);
self.encoded_tags.push(idx);
}
Some(WireType::EndGroup) => {
let idx = self.get_pos_in_tags_list(node_id, subtype::TRIVIAL);
self.encoded_tags.push(idx);
}
None => {
return Err(RiegeliError::MalformedData(format!(
"invalid wire type in tag {tag}"
)));
}
}
// Close every submessage that ends at the current position; nested
// submessages can end at the same offset.
while pos >= current_end && !message_stack.is_empty() {
let frame = message_stack.pop().ok_or_else(|| {
RiegeliError::MalformedData("message stack empty".to_string())
})?;
self.encoded_tags.push(frame.end_sub_tag_idx);
current_parent = frame.parent_message_id;
current_end = frame.parent_end_pos;
}
}
Ok(())
}
/// Encodes a varint field whose value starts at `*pos` and advances `*pos`
/// past it.
///
/// A single-byte value up to `MAX_VARINT_INLINE` is folded into the subtype
/// and needs no buffer. Any other value gets a length-based subtype, and its
/// bytes, with continuation bits cleared, go to the node's `Varint` buffer.
///
/// # Errors
/// `MalformedData` if the varint cannot be decoded.
fn encode_varint_field(
    &mut self,
    data: &[u8],
    pos: &mut usize,
    node_id: NodeId,
) -> Result<(), RiegeliError> {
    let start = *pos;
    let (_, vlen) = decode_u64(&data[start..])
        .map_err(|e| RiegeliError::MalformedData(format!("varint value: {e}")))?;
    let raw = &data[start..start + vlen];
    *pos = start + vlen;
    let inline_value = match raw {
        [byte] if *byte <= MAX_VARINT_INLINE => Some(*byte),
        _ => None,
    };
    match inline_value {
        Some(value) => {
            let tag_pos = self.get_pos_in_tags_list(node_id, subtype::VARINT_INLINE_0 + value);
            self.encoded_tags.push(tag_pos);
        }
        None => {
            let tag_pos = self.get_pos_in_tags_list(node_id, subtype::VARINT_1 + (vlen as u8 - 1));
            self.encoded_tags.push(tag_pos);
            let stripped: Vec<u8> = raw.iter().map(|b| b & 0x7F).collect();
            self.get_buffer(node_id, BufferType::Varint)
                .push_chunk(&stripped);
        }
    }
    Ok(())
}
/// Encodes a fixed32 field: records its tag, appends the 4 value bytes to the
/// node's `Fixed32` buffer and advances `*pos` past them.
///
/// # Errors
/// `MalformedData` if fewer than 4 bytes remain.
fn encode_fixed32_field(
    &mut self,
    data: &[u8],
    pos: &mut usize,
    node_id: NodeId,
) -> Result<(), RiegeliError> {
    let start = *pos;
    let bytes = match data.get(start..start + 4) {
        Some(bytes) => bytes,
        None => return Err(RiegeliError::MalformedData("truncated fixed32".to_string())),
    };
    let tag_pos = self.get_pos_in_tags_list(node_id, subtype::TRIVIAL);
    self.encoded_tags.push(tag_pos);
    *pos = start + 4;
    self.get_buffer(node_id, BufferType::Fixed32)
        .push_chunk(bytes);
    Ok(())
}
/// Encodes a fixed64 field: records its tag, appends the 8 value bytes to the
/// node's `Fixed64` buffer and advances `*pos` past them.
///
/// # Errors
/// `MalformedData` if fewer than 8 bytes remain.
fn encode_fixed64_field(
    &mut self,
    data: &[u8],
    pos: &mut usize,
    node_id: NodeId,
) -> Result<(), RiegeliError> {
    let start = *pos;
    let bytes = match data.get(start..start + 8) {
        Some(bytes) => bytes,
        None => return Err(RiegeliError::MalformedData("truncated fixed64".to_string())),
    };
    let tag_pos = self.get_pos_in_tags_list(node_id, subtype::TRIVIAL);
    self.encoded_tags.push(tag_pos);
    *pos = start + 8;
    self.get_buffer(node_id, BufferType::Fixed64)
        .push_chunk(bytes);
    Ok(())
}
/// Handles a length-delimited field whose length varint starts at `*pos`
/// (just after the tag).
///
/// The field is treated as a submessage if three conditions hold: the value is
/// non-empty, the nesting depth (`depth` plus open frames) is below
/// `MAX_RECURSION_DEPTH`, and `is_proto_message` accepts the value. In that
/// case a start-of-submessage tag is emitted and parsing is redirected into
/// the value: `*current_parent` becomes `node_mid`, `*current_end_mut`
/// becomes the value's end, and `*pos` its start. The frame to restore on
/// exit is returned.
///
/// Otherwise the field is stored as a string: the length prefix and the value
/// bytes go to the node's `String` buffer, `*pos` moves past the value, and
/// `None` is returned.
///
/// `current_end` is the end of the enclosing message; `add_message` passes the
/// same value that `*current_end_mut` holds on entry.
///
/// # Errors
/// `MalformedData` if the length cannot be decoded or the value extends past
/// `current_end`.
#[allow(clippy::too_many_arguments)]
fn encode_length_delimited_field(
&mut self,
data: &[u8],
pos: &mut usize,
node_id: NodeId,
node_mid: u32,
current_end: usize,
depth: usize,
message_stack: &[MessageFrame],
current_parent: &mut u32,
current_end_mut: &mut usize,
) -> Result<Option<MessageFrame>, RiegeliError> {
let length_pos = *pos;
let (length, llen) = decode_u32(&data[*pos..])
.map_err(|e| RiegeliError::MalformedData(format!("length: {e}")))?;
*pos += llen;
let value_pos = *pos;
// NOTE(review): on 32-bit targets this addition could overflow for huge
// lengths; presumably prevented by earlier is_proto_message validation — confirm.
let value_end = *pos + length as usize;
if value_end > current_end {
return Err(RiegeliError::MalformedData(
"length-delimited field overflow".to_string(),
));
}
let total_depth = depth + message_stack.len();
if length > 0
&& total_depth < MAX_RECURSION_DEPTH
&& is_proto_message(&data[value_pos..value_end])
{
let start_sub_idx =
self.get_pos_in_tags_list(node_id, subtype::LENGTH_DELIMITED_START_OF_SUBMESSAGE);
self.encoded_tags.push(start_sub_idx);
// The end tag is emitted later, when add_message pops this frame.
let end_sub_idx =
self.get_pos_in_tags_list(node_id, subtype::LENGTH_DELIMITED_END_OF_SUBMESSAGE);
let frame = MessageFrame {
end_sub_tag_idx: end_sub_idx,
parent_message_id: *current_parent,
parent_end_pos: *current_end_mut,
};
*current_parent = node_mid;
*current_end_mut = value_end;
*pos = value_pos;
return Ok(Some(frame));
}
let idx = self.get_pos_in_tags_list(node_id, subtype::LENGTH_DELIMITED_STRING);
self.encoded_tags.push(idx);
// The stored bytes include the length prefix, not just the value.
let string_bytes = &data[length_pos..value_end];
let buffer = self.get_buffer(node_id, BufferType::String);
buffer.push_chunk(string_bytes);
*pos = value_end;
Ok(None)
}
}
/// Appends transition offsets to `transitions`, run-length packing zeros.
///
/// Each byte stores `offset << 2` in its upper bits and a repeat count (0..=3)
/// in its low two bits. A zero offset that follows a pending byte whose count
/// is below 3 just increments that count. The most recent byte stays pending
/// in `last_transition` so later zeros can still merge into it; the caller
/// flushes it at the end.
fn emit_transition_bytes(
    write_buf: &[u8],
    transitions: &mut Vec<u8>,
    last_transition: &mut Option<u8>,
) {
    for &offset in write_buf {
        *last_transition = Some(match *last_transition {
            Some(pending) if offset == 0 && pending & 3 < 3 => pending + 1,
            Some(pending) => {
                transitions.push(pending);
                offset << 2
            }
            None => offset << 2,
        });
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transpose::decoder::TransposeChunkDecoder;
/// Encodes `records` into one transposed chunk, checks the chunk type, and
/// returns every record read back by the decoder.
fn roundtrip(records: &[&[u8]], compression: CompressionType) -> Vec<Vec<u8>> {
    let mut encoder = TransposeChunkEncoder::new(compression);
    records
        .iter()
        .for_each(|record| encoder.add_record(record).expect("add_record"));
    let chunk = encoder.encode().expect("encode");
    assert_eq!(chunk.header.chunk_type().unwrap(), ChunkType::Transposed);
    let mut decoder = TransposeChunkDecoder::new(chunk).expect("decoder");
    std::iter::from_fn(|| decoder.read_record().expect("read_record")).collect()
}
#[test]
fn test_roundtrip_single_proto() {
    // Field 1, varint 42.
    let record = vec![0x08, 0x2A];
    let decoded = roundtrip(&[&record], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], record);
}
#[test]
fn test_roundtrip_multiple_proto_records() {
    // Field 1 as varint 5, 10 and 42.
    let r0 = vec![0x08, 0x05];
    let r1 = vec![0x08, 0x0A];
    let r2 = vec![0x08, 0x2A];
    let decoded = roundtrip(&[&r0, &r1, &r2], CompressionType::None);
    assert_eq!(decoded.len(), 3);
    for (got, expected) in decoded.iter().zip([&r0, &r1, &r2].iter()) {
        assert_eq!(got, *expected);
    }
}
#[test]
fn test_zero_records() {
    let chunk = TransposeChunkEncoder::new(CompressionType::None)
        .encode()
        .expect("encode");
    assert_eq!(chunk.header.num_records(), 0);
    assert_eq!(chunk.header.chunk_type().unwrap(), ChunkType::Transposed);
    // The header's data size must describe the payload exactly.
    assert_eq!(chunk.header.data_size(), chunk.data.len() as u64);
    let mut decoder = TransposeChunkDecoder::new(chunk).expect("decoder");
    assert!(decoder.read_record().unwrap().is_none());
}
#[test]
fn test_nonproto_roundtrip() {
let record = vec![0x0F, 0x01, 0x02, 0x03];
let result = roundtrip(&[&record], CompressionType::None);
assert_eq!(result.len(), 1);
assert_eq!(result[0], record);
}
#[test]
fn test_nonproto_hello() {
let record = b"hello world!";
let result = roundtrip(&[record.as_slice()], CompressionType::None);
assert_eq!(result.len(), 1);
assert_eq!(result[0], record);
}
#[test]
fn test_nested_submessage() {
let record = vec![0x0A, 0x02, 0x10, 0x2A];
let result = roundtrip(&[&record], CompressionType::None);
assert_eq!(result.len(), 1);
assert_eq!(result[0], record);
}
#[test]
fn test_complex_proto() {
    // One message exercising every wire type:
    //   field 1 varint 42, field 2 fixed32, field 3 fixed64,
    //   field 4 string "abc", field 5 submessage { field 1 = 7 }.
    let message: Vec<u8> = [
        &[0x08, 0x2A][..],
        &[0x15, 0x01, 0x02, 0x03, 0x04],
        &[0x19, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08],
        &[0x22, 0x03, 0x61, 0x62, 0x63],
        &[0x2A, 0x02, 0x08, 0x07],
    ]
    .concat();
    let decoded = roundtrip(&[&message], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], message);
}
#[test]
fn test_1000_proto_records() {
    // Record i: varint field 1 = i, fixed32 field 2 = i.
    let records: Vec<Vec<u8>> = (0u32..1000)
        .map(|i| {
            let mut rec = vec![0x08];
            rec.extend_from_slice(&encode_u64(i as u64));
            rec.push(0x15);
            rec.extend_from_slice(&i.to_le_bytes());
            rec
        })
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 1000);
    for (i, (got, expected)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
}
#[test]
fn test_mixed_proto_nonproto() {
    let first_proto = vec![0x08, 0x2A];
    let raw = vec![0x0F, 0xAA, 0xBB];
    let second_proto = vec![0x08, 0x01];
    let decoded = roundtrip(&[&first_proto, &raw, &second_proto], CompressionType::None);
    assert_eq!(decoded.len(), 3);
    assert_eq!(decoded[0], first_proto);
    assert_eq!(decoded[1], raw);
    assert_eq!(decoded[2], second_proto);
}
#[test]
fn test_no_implicit_loops() {
    // Ten identical records: the resulting state machine must still be accepted
    // by the decoder.
    let record = vec![0x08, 0x01];
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    for _ in 0..10 {
        encoder.add_record(&record).unwrap();
    }
    let decoder = TransposeChunkDecoder::new(encoder.encode().unwrap());
    assert!(decoder.is_ok(), "decoder should not report implicit loops");
}
#[test]
fn test_inline_varint_values() {
    // Small varint values; the test is named after the encoder's inline varint handling.
    (0u8..=3).for_each(|val| {
        let record = vec![0x08, val];
        let decoded = roundtrip(&[&record], CompressionType::None);
        assert_eq!(decoded.len(), 1, "failed for value {val}");
        assert_eq!(decoded[0], record, "failed for value {val}");
    });
}
#[test]
fn test_multibyte_varint() {
    // Field 1 = 300, which needs two varint bytes.
    let message: Vec<u8> = vec![0x08, 0xAC, 0x02];
    let decoded = roundtrip(&[&message], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], message);
}
#[test]
fn test_empty_proto_record() {
    let empty: Vec<u8> = Vec::new();
    let decoded = roundtrip(&[&empty], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], empty);
}
#[test]
fn test_string_field_roundtrip() {
    // Field 2, length-delimited "hello".
    let message: Vec<u8> = [&[0x12u8, 0x05][..], b"hello"].concat();
    let decoded = roundtrip(&[&message], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], message);
}
#[test]
fn test_multiple_nonproto() {
    let inputs: [Vec<u8>; 3] = [vec![0xFF, 0x01], vec![0xFF, 0x02, 0x03], vec![0xFF]];
    let refs: Vec<&[u8]> = inputs.iter().map(Vec::as_slice).collect();
    let decoded = roundtrip(&refs, CompressionType::None);
    assert_eq!(decoded.len(), 3);
    for (got, want) in decoded.iter().zip(&inputs) {
        assert_eq!(got, want);
    }
}
#[test]
#[cfg(feature = "brotli")]
fn test_roundtrip_brotli() {
    let message: Vec<u8> = vec![0x08, 0x2A];
    let decoded = roundtrip(&[&message], CompressionType::Brotli);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], message);
}
#[test]
fn test_optimized_fewer_transitions() {
    // Record i: varint field 1 = i, fixed32 field 2 = i, varint field 3 = 2 * i.
    let records: Vec<Vec<u8>> = (0u32..1000)
        .map(|i| {
            let mut rec = vec![0x08];
            rec.extend_from_slice(&encode_u64(i as u64));
            rec.push(0x15);
            rec.extend_from_slice(&i.to_le_bytes());
            rec.push(0x18);
            rec.extend_from_slice(&encode_u64((i * 2) as u64));
            rec
        })
        .collect();
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    for rec in &records {
        encoder.add_record(rec).unwrap();
    }
    let chunk = encoder.encode().unwrap();
    // Capture the size first so the chunk can be moved into the decoder
    // instead of being cloned.
    let encoded_len = chunk.data.len();
    let mut decoder = TransposeChunkDecoder::new(chunk).unwrap();
    let out: Vec<Vec<u8>> = std::iter::from_fn(|| decoder.read_record().unwrap()).collect();
    assert_eq!(out.len(), 1000);
    for (i, (got, expected)) in out.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
    assert!(encoded_len > 0);
}
#[test]
#[cfg(feature = "brotli")]
fn test_transpose_smaller_than_simple_brotli() {
    use crate::simple_chunk::SimpleChunkEncoder;
    const PAYLOAD: &[u8] = b"AAAAAAAAAA";
    // Record i: varint field 1 = i, then fields 2, 3 and 4 each holding PAYLOAD.
    let records: Vec<Vec<u8>> = (0u32..1000)
        .map(|i| {
            let mut rec = vec![0x08];
            rec.extend_from_slice(&encode_u64(i as u64));
            for tag in [0x12u8, 0x1A, 0x22] {
                rec.push(tag);
                rec.push(PAYLOAD.len() as u8);
                rec.extend_from_slice(PAYLOAD);
            }
            rec
        })
        .collect();
    let mut transposed = TransposeChunkEncoder::new(CompressionType::Brotli);
    for rec in &records {
        transposed.add_record(rec).unwrap();
    }
    let transposed_len = transposed.encode().unwrap().data.len();
    let mut simple = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
    for rec in &records {
        simple.add_record(rec);
    }
    let simple_len = simple.encode().unwrap().data.len();
    assert!(
        transposed_len < simple_len,
        "transpose+brotli ({}) should be smaller than simple+brotli ({})",
        transposed_len,
        simple_len
    );
}
#[test]
fn test_implicit_transitions_used() {
    // 100 records with the same two-field layout, so at least one tag is
    // expected to always be followed by the same destination.
    let records: Vec<Vec<u8>> = (0u32..100)
        .map(|i| {
            let mut rec = vec![0x08];
            rec.extend_from_slice(&encode_u64(i as u64));
            rec.push(0x15);
            rec.extend_from_slice(&i.to_le_bytes());
            rec
        })
        .collect();
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    for rec in &records {
        encoder.add_record(rec).unwrap();
    }
    encoder.collect_transition_statistics();
    encoder.ensure_last_state_explicit();
    let has_single_dest = encoder
        .tags_list
        .iter()
        .any(|tag_info| tag_info.dest_info.len() == 1);
    assert!(
        has_single_dest,
        "Expected at least one tag with a single destination (implicit transition)"
    );
}
#[test]
fn test_noop_bridging_many_tags() {
    // Record i: varint field 1 = i, then varint field (i + 2) = 1. That gives 80
    // distinct second fields, more than the 64 the assertion below refers to.
    let records: Vec<Vec<u8>> = (0u32..80)
        .map(|i| {
            let mut rec = vec![0x08];
            rec.extend_from_slice(&encode_u64(i as u64));
            rec.extend_from_slice(&encode_u32((i + 2) << 3));
            rec.push(0x01);
            rec
        })
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 80);
    for (i, (got, expected)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    for rec in &records {
        encoder.add_record(rec).unwrap();
    }
    let states = encoder.create_state_machine();
    let noop_count = states
        .iter()
        .filter(|state| state.etag_index == INVALID_POS)
        .count();
    assert!(
        noop_count > 0,
        "Expected NoOp bridging states for >64 unique tags, got {} total states",
        states.len()
    );
}
#[test]
fn test_wide_schema_roundtrip() {
    // 100 records, each with varint fields 1..=100; field f of record i holds i + f.
    let records: Vec<Vec<u8>> = (0u32..100)
        .map(|i| {
            let mut rec = Vec::new();
            for field_num in 1u32..=100 {
                rec.extend_from_slice(&encode_u32(field_num << 3));
                rec.extend_from_slice(&encode_u64((i + field_num) as u64));
            }
            rec
        })
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 100);
    for (i, (got, expected)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
}
#[test]
fn test_no_implicit_loops_optimized() {
    // 100 single-field records; the encoded state machine must decode cleanly.
    let records: Vec<Vec<u8>> = (0u32..100)
        .map(|i| {
            let mut rec = vec![0x08];
            rec.extend_from_slice(&encode_u64(i as u64));
            rec
        })
        .collect();
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    for rec in &records {
        encoder.add_record(rec).unwrap();
    }
    let decoder = TransposeChunkDecoder::new(encoder.encode().unwrap());
    assert!(decoder.is_ok(), "decoder should not report implicit loops");
}
/// Appends the base-128 varint encoding of `v` to `buf`.
///
/// Bytes go out least-significant 7-bit group first, and every byte except the
/// last has its high bit set.
fn push_varint(buf: &mut Vec<u8>, mut v: u64) {
    while v >= 0x80 {
        buf.push((v & 0x7F) as u8 | 0x80);
        v >>= 7;
    }
    buf.push(v as u8);
}
#[test]
fn test_1000_mixed_field_records() {
    // Record i: varint field 1 = i, fixed32 field 2 = i.
    let mut records: Vec<Vec<u8>> = Vec::with_capacity(1000);
    for i in 0u32..1000 {
        let mut rec = vec![0x08];
        rec.extend_from_slice(&encode_u64(i as u64));
        rec.push(0x15);
        rec.extend_from_slice(&i.to_le_bytes());
        records.push(rec);
    }
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 1000);
    for (i, (got, want)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, want, "record {i} mismatch");
    }
}
#[test]
fn test_large_nonproto_roundtrip() {
    // 64 KiB of 0xFF: not parseable as protobuf.
    let blob = vec![0xFF_u8; 1 << 16];
    let decoded = roundtrip(&[&blob], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], blob);
}
#[test]
fn test_deep_nested_submessage_roundtrip() {
    // Start from { field 2 = 99 } and wrap it in five levels of length-delimited field 1.
    let nested = (0..5).fold(vec![0x10, 0x63_u8], |inner, _| {
        let mut outer = vec![0x0A, inner.len() as u8];
        outer.extend_from_slice(&inner);
        outer
    });
    let decoded = roundtrip(&[&nested], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], nested);
}
#[test]
fn test_1000_same_schema_roundtrip() {
    // Record i: varint field 1 = i, fixed32 field 2 = i.
    let mut records: Vec<Vec<u8>> = Vec::with_capacity(1000);
    for i in 0u32..1000 {
        let mut rec = vec![0x08];
        rec.extend_from_slice(&encode_u64(i as u64));
        rec.push(0x15);
        rec.extend_from_slice(&i.to_le_bytes());
        records.push(rec);
    }
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 1000);
    for (i, (got, want)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, want, "record {i} mismatch");
    }
}
#[test]
fn test_deep_nesting_roundtrip() {
    // Wraps { field 1 = value } in `depth` levels of length-delimited field 1.
    fn build_nested(depth: usize, value: u8) -> Vec<u8> {
        (0..depth).fold(vec![0x08, value], |inner, _| {
            let mut rec = vec![0x0A, inner.len() as u8];
            rec.extend_from_slice(&inner);
            rec
        })
    }
    let records: Vec<Vec<u8>> = (0u8..20).map(|value| build_nested(10, value)).collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 20);
    for (i, (got, want)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, want, "record {i} mismatch");
    }
}
#[test]
fn test_alternating_schemas() {
    // Even records: varint 1 + fixed32 2. Odd records: varint 3 + string 4 ("xy").
    let records: Vec<Vec<u8>> = (0u32..200)
        .map(|i| {
            let mut rec = Vec::new();
            if i % 2 == 0 {
                rec.push(0x08);
                rec.extend_from_slice(&encode_u64(i as u64));
                rec.push(0x15);
                rec.extend_from_slice(&i.to_le_bytes());
            } else {
                rec.push(0x18);
                rec.extend_from_slice(&encode_u64(i as u64));
                rec.extend_from_slice(&[0x22, 2, b'x', b'y']);
            }
            rec
        })
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 200);
    for (i, (got, want)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, want, "record {i} mismatch");
    }
}
#[test]
fn test_many_nonproto_records() {
    // Lengths cycle through 1..=50; the second byte (when present) varies with i.
    let records: Vec<Vec<u8>> = (0u32..500)
        .map(|i| {
            let len = (i % 50) as usize + 1;
            let mut rec = vec![0xFF_u8; len];
            if let Some(second) = rec.get_mut(1) {
                *second = (i & 0xFF) as u8;
            }
            rec
        })
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 500);
    for (i, (got, want)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, want, "record {i} mismatch");
    }
}
#[test]
fn test_mixed_proto_nonproto_heavy() {
    // Rotates between a raw record, a varint record and a string record.
    let records: Vec<Vec<u8>> = (0u32..300)
        .map(|i| match i % 3 {
            0 => vec![0xFF, (i & 0xFF) as u8],
            1 => {
                let mut rec = vec![0x08];
                rec.extend_from_slice(&encode_u64(i as u64));
                rec
            }
            _ => {
                let s = format!("val{i}");
                let mut rec = vec![0x12, s.len() as u8];
                rec.extend_from_slice(s.as_bytes());
                rec
            }
        })
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 300);
    for (i, (got, want)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, want, "record {i} mismatch");
    }
}
#[test]
fn test_deeply_nested_proto_roundtrip() {
    // Innermost message has varint fields 1..=3; ten levels of field 1 wrap it.
    let nested = (0..10).fold(vec![0x08, 42, 0x10, 43, 0x18, 44], |inner, _| {
        let mut rec = vec![0x0A];
        push_varint(&mut rec, inner.len() as u64);
        rec.extend_from_slice(&inner);
        rec
    });
    let decoded = roundtrip(&[&nested], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], nested);
}
#[test]
fn test_large_record_roundtrip() {
    // Varint field 1 = 999 followed by a 4 KiB string in field 2.
    let blob = [0xAB_u8; 4096];
    let mut message = vec![0x08];
    push_varint(&mut message, 999);
    message.push(0x12);
    push_varint(&mut message, blob.len() as u64);
    message.extend_from_slice(&blob);
    let decoded = roundtrip(&[&message], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], message);
}
#[test]
fn test_single_byte_all_values() {
    // Every possible one-byte record, proto-looking or not.
    let records: Vec<Vec<u8>> = (0..=u8::MAX).map(|byte| vec![byte]).collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 256);
    for (i, (got, want)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, want, "record {i} mismatch");
    }
}
#[test]
fn test_alternating_proto_nonproto() {
    // Even i: varint fields 1..=3 holding i, i + 1, i + 2. Odd i: a raw 3-byte record.
    let records: Vec<Vec<u8>> = (0u64..100)
        .map(|i| {
            if i % 2 != 0 {
                return vec![0xFF, 0xFE, i as u8];
            }
            let mut rec = Vec::new();
            for (tag, value) in [(0x08u8, i), (0x10, i + 1), (0x18, i + 2)] {
                rec.push(tag);
                push_varint(&mut rec, value);
            }
            rec
        })
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 100);
    for (i, (got, want)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, want, "record {i} mismatch");
    }
}
fn make_proto_3_varints_int(a: u64, b: u64, c: u64) -> Vec<u8> {
let mut rec = Vec::new();
rec.push(0x08);
push_varint(&mut rec, a);
rec.push(0x10);
push_varint(&mut rec, b);
rec.push(0x18);
push_varint(&mut rec, c);
rec
}
/// Appends the base-128 varint encoding of the 32-bit value `v` to `buf`.
fn push_varint_u32_int(buf: &mut Vec<u8>, mut v: u32) {
    while v >= 0x80 {
        buf.push((v & 0x7F) as u8 | 0x80);
        v >>= 7;
    }
    buf.push(v as u8);
}
/// Round-trips `records` through a transpose encoder configured with
/// `bucket_size`, and returns the decoded records in order.
fn transpose_roundtrip_with_bucket_size(
    records: &[Vec<u8>],
    compression: CompressionType,
    bucket_size: u64,
) -> Vec<Vec<u8>> {
    let mut encoder = TransposeChunkEncoder::new(compression).bucket_size(bucket_size);
    for record in records {
        encoder.add_record(record).expect("add_record");
    }
    let mut decoder =
        TransposeChunkDecoder::new(encoder.encode().expect("encode")).expect("decoder");
    std::iter::from_fn(|| decoder.read_record().expect("read_record")).collect()
}
#[test]
fn criterion_13_4_multi_bucket_produces_multiple_buckets() {
    use crate::varint::{decode_u32, decode_u64};
    let records: Vec<Vec<u8>> = (0..200)
        .map(|i| make_proto_3_varints_int(i, i * 2, i * 3))
        .collect();
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None).bucket_size(64);
    for rec in &records {
        encoder.add_record(rec).expect("add_record");
    }
    let chunk = encoder.encode().expect("encode");
    // Layout checked here: compression byte, varint header length, header bytes.
    // The first varint of the header is read as the bucket count.
    let num_buckets = {
        let bytes = &chunk.data;
        assert_eq!(bytes[0], 0x00, "compression type should be None");
        let (header_len, varint_len) = decode_u64(&bytes[1..]).expect("header length");
        let start = 1 + varint_len;
        let header = &bytes[start..start + header_len as usize];
        decode_u32(header).expect("num_buckets").0
    };
    assert!(
        num_buckets > 1,
        "expected multiple buckets with bucket_size=64, got {num_buckets}"
    );
    let mut decoder = TransposeChunkDecoder::new(chunk).expect("decoder");
    let out: Vec<_> = std::iter::from_fn(|| decoder.read_record().expect("read_record")).collect();
    assert_eq!(out.len(), records.len());
    for (i, (expected, actual)) in records.iter().zip(&out).enumerate() {
        assert_eq!(expected, actual, "mismatch at record {i}");
    }
}
#[test]
fn criterion_13_4_multi_bucket_with_brotli() {
    let records: Vec<Vec<u8>> = (0..100)
        .map(|i| make_proto_3_varints_int(i, i + 100, i + 200))
        .collect();
    let mut encoder = TransposeChunkEncoder::new(CompressionType::Brotli).bucket_size(128);
    for rec in &records {
        encoder.add_record(rec).expect("add_record");
    }
    let mut decoder =
        TransposeChunkDecoder::new(encoder.encode().expect("encode")).expect("decoder");
    let out: Vec<_> = std::iter::from_fn(|| decoder.read_record().expect("read_record")).collect();
    assert_eq!(out.len(), records.len());
    for (i, (expected, actual)) in records.iter().zip(&out).enumerate() {
        assert_eq!(expected, actual, "mismatch at record {i}");
    }
}
#[test]
fn criterion_13_4_multi_bucket_nonproto_roundtrip() {
    let records: Vec<Vec<u8>> = (0..50)
        .map(|i| format!("record number {i}: some arbitrary data here").into_bytes())
        .collect();
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None).bucket_size(100);
    for rec in &records {
        encoder.add_record(rec).expect("add_record");
    }
    let mut decoder =
        TransposeChunkDecoder::new(encoder.encode().expect("encode")).expect("decoder");
    let out: Vec<_> = std::iter::from_fn(|| decoder.read_record().expect("read_record")).collect();
    assert_eq!(out.len(), records.len());
    for (i, (expected, actual)) in records.iter().zip(&out).enumerate() {
        assert_eq!(expected, actual, "mismatch at record {i}");
    }
}
#[test]
fn criterion_13_4_single_bucket_default() {
    let records: Vec<Vec<u8>> = (0..10).map(|i| make_proto_3_varints_int(i, i, i)).collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), records.len());
    for (i, (expected, actual)) in records.iter().zip(&result).enumerate() {
        assert_eq!(expected, actual, "mismatch at record {i}");
    }
}
#[test]
fn criterion_13_5_transpose_brotli_smaller_than_simple_brotli() {
    use crate::simple_chunk::SimpleChunkEncoder;
    let records: Vec<Vec<u8>> = (0..10_000)
        .map(|i| make_proto_3_varints_int(i, i * 2, i * 3))
        .collect();
    let mut transposed = TransposeChunkEncoder::new(CompressionType::Brotli);
    for rec in &records {
        transposed.add_record(rec).expect("add_record");
    }
    let transpose_size = transposed.encode().expect("encode").data.len();
    let mut simple = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
    for rec in &records {
        simple.add_record(rec);
    }
    let simple_size = simple.encode().expect("encode").data.len();
    assert!(
        transpose_size < simple_size,
        "transpose+Brotli ({transpose_size}) should be smaller than simple+Brotli ({simple_size})"
    );
}
#[test]
fn criterion_13_7_many_unique_fields_no_overflow() {
    // Ten records, each with 200 varint fields; record `batch` uses field
    // numbers 200 * batch + 1 ..= 200 * batch + 200, so 2000 distinct tags overall.
    let records: Vec<Vec<u8>> = (0..10u32)
        .map(|batch| {
            let mut rec = Vec::new();
            for field_num in 1..=200u32 {
                push_varint_u32_int(&mut rec, (field_num + batch * 200) << 3);
                push_varint(&mut rec, (field_num + batch * 100) as u64);
            }
            rec
        })
        .collect();
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    for rec in &records {
        encoder.add_record(rec).expect("add_record");
    }
    let mut decoder =
        TransposeChunkDecoder::new(encoder.encode().expect("encode")).expect("decoder");
    let out: Vec<_> = std::iter::from_fn(|| decoder.read_record().expect("read_record")).collect();
    assert_eq!(out.len(), records.len());
    for (i, (expected, actual)) in records.iter().zip(&out).enumerate() {
        assert_eq!(expected, actual, "mismatch at record {i}");
    }
}
#[test]
fn criterion_13_7_corrupted_num_states_no_panic() {
    use crate::simple_chunk::Chunk;
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    for rec in &[make_proto_3_varints_int(1, 2, 3)] {
        encoder.add_record(rec).expect("add_record");
    }
    let chunk = encoder.encode().expect("encode");
    // Keep at most the first 10 bytes of the payload (truncate is a no-op on
    // shorter data).
    let mut corrupted_data = chunk.data.clone();
    corrupted_data.truncate(10);
    let corrupted_chunk = Chunk {
        header: chunk.header.clone(),
        data: corrupted_data,
    };
    let result = TransposeChunkDecoder::new(corrupted_chunk);
    assert!(result.is_err(), "corrupted chunk should produce an error");
}
/// Returns `len` deterministic pseudo-random bytes derived from `seed`.
///
/// The state advances with a 64-bit linear congruential step. The seed is
/// stepped once before the first output, and each output byte is the state
/// shifted right by 33 and truncated to 8 bits.
fn simple_hash_int(seed: u64, len: usize) -> Vec<u8> {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    let advance = |state: u64| state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    std::iter::successors(Some(advance(advance(seed))), |&state| Some(advance(state)))
        .take(len)
        .map(|state| (state >> 33) as u8)
        .collect()
}
#[test]
fn eval_13_1_varint32_overflow_regression() {
    // The leading varint F8 80 80 80 10 decodes to 2^32 + 120, which does not fit in u32.
    let record: Vec<u8> = vec![0xF8, 0x80, 0x80, 0x80, 0x10, 0x00];
    let compressions = [
        CompressionType::None,
        CompressionType::Brotli,
        CompressionType::Zstd,
    ];
    for compression in compressions {
        let result = roundtrip(&[record.as_slice()], compression);
        assert_eq!(result.len(), 1);
        assert_eq!(
            result[0], record,
            "regression case failed for {compression:?}"
        );
    }
}
#[test]
fn eval_13_3_proptest_regression_vector() {
    // Same bytes as the varint32 overflow case, written in decimal.
    let vector: Vec<u8> = vec![248, 128, 128, 128, 16, 0];
    let decoded = roundtrip(&[&vector], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], vector);
}
#[test]
fn eval_13_4_bucket_size_1() {
    let records: Vec<Vec<u8>> = (0..50)
        .map(|i| make_proto_3_varints_int(i, i * 2, i * 3))
        .collect();
    let result = transpose_roundtrip_with_bucket_size(&records, CompressionType::None, 1);
    assert_eq!(result.len(), records.len());
    for (i, (expected, actual)) in records.iter().zip(&result).enumerate() {
        assert_eq!(expected, actual, "mismatch at record {i}");
    }
}
#[test]
fn eval_13_4_bucket_size_0() {
    let records = vec![make_proto_3_varints_int(1, 2, 3)];
    let decoded = transpose_roundtrip_with_bucket_size(&records, CompressionType::None, 0);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], records[0]);
}
#[test]
fn eval_13_4_multi_bucket_zstd() {
    let records: Vec<Vec<u8>> = (0..100)
        .map(|i| make_proto_3_varints_int(i, i + 50, i + 100))
        .collect();
    let result = transpose_roundtrip_with_bucket_size(&records, CompressionType::Zstd, 64);
    assert_eq!(result.len(), records.len());
    for (i, (expected, actual)) in records.iter().zip(&result).enumerate() {
        assert_eq!(expected, actual, "mismatch at record {i}");
    }
}
#[test]
fn eval_13_4_multi_bucket_wide_record() {
    // A single record carrying varint fields 1..=100, where field f holds f.
    let wide = (1..=100u32).fold(Vec::new(), |mut rec, field_num| {
        push_varint(&mut rec, (field_num << 3) as u64);
        push_varint(&mut rec, field_num as u64);
        rec
    });
    let records = vec![wide];
    let decoded = transpose_roundtrip_with_bucket_size(&records, CompressionType::Brotli, 32);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], records[0]);
}
#[test]
fn eval_13_4_multi_bucket_empty_records() {
    let records: Vec<Vec<u8>> = std::iter::repeat_with(Vec::new).take(50).collect();
    let result = transpose_roundtrip_with_bucket_size(&records, CompressionType::None, 16);
    assert_eq!(result.len(), records.len());
    for (i, (expected, actual)) in records.iter().zip(&result).enumerate() {
        assert_eq!(expected, actual, "mismatch at record {i}");
    }
}
#[test]
fn eval_13_4_multi_bucket_through_writer_reader() {
    let records: Vec<Vec<u8>> = (0..300)
        .map(|i| make_proto_3_varints_int(i, i * 3, i * 7))
        .collect();
    let bucket_sizes = [1u64, 32, 128, 256, u64::MAX];
    for bucket_size in bucket_sizes {
        let mut encoder =
            TransposeChunkEncoder::new(CompressionType::Brotli).bucket_size(bucket_size);
        for rec in &records {
            encoder.add_record(rec).expect("add_record");
        }
        let mut decoder =
            TransposeChunkDecoder::new(encoder.encode().expect("encode")).expect("decoder");
        let out: Vec<_> =
            std::iter::from_fn(|| decoder.read_record().expect("read_record")).collect();
        assert_eq!(
            out.len(),
            records.len(),
            "count mismatch for bucket_size={bucket_size}"
        );
        for (i, (expected, actual)) in records.iter().zip(&out).enumerate() {
            assert_eq!(
                expected, actual,
                "mismatch at record {i} for bucket_size={bucket_size}"
            );
        }
    }
}
#[test]
fn eval_13_5_transpose_zstd_vs_simple_zstd() {
    use crate::simple_chunk::SimpleChunkEncoder;
    let records: Vec<Vec<u8>> = (0..10_000)
        .map(|i| make_proto_3_varints_int(i, i * 2, i * 3))
        .collect();
    let mut transposed = TransposeChunkEncoder::new(CompressionType::Zstd);
    for rec in &records {
        transposed.add_record(rec).expect("add_record");
    }
    let transpose_size = transposed.encode().expect("encode").data.len();
    let mut simple = SimpleChunkEncoder::with_compression(CompressionType::Zstd);
    for rec in &records {
        simple.add_record(rec);
    }
    let simple_size = simple.encode().expect("encode").data.len();
    assert!(
        transpose_size < simple_size,
        "transpose+Zstd ({transpose_size}) should be smaller than simple+Zstd ({simple_size})"
    );
}
#[test]
fn eval_13_7_max_varint_values() {
    // Fields 1 and 2 both hold u64::MAX (ten-byte varints).
    let mut message = Vec::new();
    for tag in [0x08u8, 0x10] {
        message.push(tag);
        push_varint(&mut message, u64::MAX);
    }
    let decoded = roundtrip(&[message.as_slice()], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], message);
}
#[test]
fn eval_13_7_truncated_chunk_no_panic() {
    use crate::simple_chunk::Chunk;
    let records = vec![make_proto_3_varints_int(1, 2, 3); 10];
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    for rec in &records {
        encoder.add_record(rec).expect("add_record");
    }
    let chunk = encoder.encode().expect("encode");
    for truncate_at in [0, 1, 5, 10, chunk.data.len() / 2] {
        // Keep the original header but only a prefix of the payload.
        let kept = truncate_at.min(chunk.data.len());
        let corrupted_chunk = Chunk {
            header: chunk.header.clone(),
            data: chunk.data[..kept].to_vec(),
        };
        let result = TransposeChunkDecoder::new(corrupted_chunk);
        assert!(
            result.is_err(),
            "truncation at {truncate_at} should produce an error"
        );
    }
}
#[test]
fn eval_13_6_multi_block_brotli_block_headers_valid() {
    use crate::block_header::BlockHeader;
    use crate::compression::CompressionType;
    use crate::record_writer::{RecordWriter, WriterOptions};
    use std::io::Cursor;
    const BLOCK_SIZE: usize = 65536;
    const HEADER_LEN: usize = 24;
    let records: Vec<Vec<u8>> = (0..1000).map(|i| simple_hash_int(i as u64, 128)).collect();
    let opts = WriterOptions::new()
        .compression(CompressionType::Brotli)
        .transpose(true)
        .chunk_size(4096);
    let mut cursor = Cursor::new(Vec::<u8>::new());
    {
        // Scoped so the writer is dropped before the buffer is taken back.
        let mut writer = RecordWriter::new(&mut cursor, opts).expect("writer::new");
        for rec in &records {
            writer.write_record(rec).expect("write_record");
        }
        writer.flush().expect("flush");
    }
    let file_bytes = cursor.into_inner();
    // Every block boundary that still has room for a full header must carry a
    // valid block header.
    let mut boundaries_checked = 0;
    for offset in (0..file_bytes.len())
        .step_by(BLOCK_SIZE)
        .take_while(|offset| offset + HEADER_LEN <= file_bytes.len())
    {
        let bytes: [u8; HEADER_LEN] = file_bytes[offset..offset + HEADER_LEN]
            .try_into()
            .unwrap();
        let header = BlockHeader::from_bytes(bytes);
        assert!(header.is_valid(), "invalid block header at offset {offset}");
        boundaries_checked += 1;
    }
    assert!(
        boundaries_checked >= 2,
        "expected >= 2 block boundaries, got {boundaries_checked} (file size: {})",
        file_bytes.len()
    );
}
#[test]
fn test_roundtrip_varied_proto_wire_types() {
    // One record per wire type: varint, fixed32, fixed64, length-delimited.
    let varint = vec![0x08, 0x2A];
    let fixed32 = vec![0x15, 0x01, 0x02, 0x03, 0x04];
    let fixed64 = vec![0x19, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
    let string = vec![0x22, 0x03, 0x61, 0x62, 0x63];
    let decoded = roundtrip(
        &[&varint, &fixed32, &fixed64, &string],
        CompressionType::None,
    );
    assert_eq!(decoded.len(), 4);
    assert_eq!(decoded[0], varint);
    assert_eq!(decoded[1], fixed32);
    assert_eq!(decoded[2], fixed64);
    assert_eq!(decoded[3], string);
}
#[test]
fn test_nonproto_then_proto() {
    // A raw record first, then a valid proto record.
    let records: Vec<Vec<u8>> = vec![vec![0xFF], vec![0x08, 0x2A]];
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let decoded = roundtrip(&refs, CompressionType::None);
    assert_eq!(decoded.len(), 2);
    assert_eq!(decoded[0], records[0]);
    assert_eq!(decoded[1], records[1]);
}
#[test]
fn test_many_unique_field_tags() {
    // Twenty single-field records, each using a different varint field number
    // (1..=20) with value 1.
    let records: Vec<Vec<u8>> = (1u32..=20)
        .map(|field_num| {
            let mut rec = Vec::new();
            rec.extend_from_slice(&encode_u32(field_num << 3));
            rec.push(0x01);
            rec
        })
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 20);
    // Checking only the count would let a decoder that mixes up tags or values
    // pass, so compare every record byte-for-byte.
    for (i, (got, expected)) in result.iter().zip(records.iter()).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
}
#[test]
fn test_writer_reader_roundtrip() {
    use crate::record_reader::{ReaderOptions, RecordReader};
    use crate::record_writer::{RecordWriter, WriterOptions};
    // A proto record, a plain-text record and a nested-message record, written
    // through the full writer/reader stack with transposition enabled.
    let records: Vec<Vec<u8>> = vec![
        vec![0x08, 0x2A],
        b"hello world".to_vec(),
        vec![0x0A, 0x02, 0x10, 0x2A],
    ];
    let mut buf: Vec<u8> = Vec::new();
    {
        let opts = WriterOptions::new()
            .compression(CompressionType::None)
            .transpose(true);
        let mut writer = RecordWriter::new(std::io::Cursor::new(&mut buf), opts).unwrap();
        for rec in &records {
            writer.write_record(rec).unwrap();
        }
        writer.close().unwrap();
    }
    let mut reader =
        RecordReader::new(std::io::Cursor::new(&buf), ReaderOptions::new()).unwrap();
    let result: Vec<_> = std::iter::from_fn(|| reader.read_record().unwrap()).collect();
    assert_eq!(result.len(), records.len());
    for (i, (got, expected)) in result.iter().zip(&records).enumerate() {
        assert_eq!(
            got, expected,
            "record {i} mismatch in writer/reader roundtrip"
        );
    }
}
#[test]
#[cfg(feature = "brotli")]
fn test_writer_reader_brotli_100_records() {
    use crate::record_reader::{ReaderOptions, RecordReader};
    use crate::record_writer::{RecordWriter, WriterOptions};
    // Record i: varint field 1 = i.
    let records: Vec<Vec<u8>> = (0u32..100)
        .map(|i| {
            let mut rec = vec![0x08];
            rec.extend_from_slice(&encode_u64(i as u64));
            rec
        })
        .collect();
    let mut buf: Vec<u8> = Vec::new();
    {
        let opts = WriterOptions::new()
            .compression(CompressionType::Brotli)
            .transpose(true);
        let mut writer = RecordWriter::new(std::io::Cursor::new(&mut buf), opts).unwrap();
        for rec in &records {
            writer.write_record(rec).unwrap();
        }
        writer.close().unwrap();
    }
    let mut reader =
        RecordReader::new(std::io::Cursor::new(&buf), ReaderOptions::new()).unwrap();
    let result: Vec<_> = std::iter::from_fn(|| reader.read_record().unwrap()).collect();
    assert_eq!(result.len(), records.len());
    for (i, (got, expected)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "brotli transpose record {i} mismatch");
    }
}
#[test]
fn test_varint_boundary_value_4() {
let record = vec![0x08, 0x04];
let result = roundtrip(&[&record], CompressionType::None);
assert_eq!(result.len(), 1);
assert_eq!(
result[0], record,
"value 4 (first non-inline) must round-trip"
);
}
#[test]
fn test_varint_boundary_value_127() {
let record = vec![0x08, 0x7F];
let result = roundtrip(&[&record], CompressionType::None);
assert_eq!(result.len(), 1);
assert_eq!(result[0], record, "value 127 must round-trip");
}
#[test]
fn test_varint_boundary_value_128() {
let record = vec![0x08, 0x80, 0x01];
let result = roundtrip(&[&record], CompressionType::None);
assert_eq!(result.len(), 1);
assert_eq!(result[0], record, "value 128 must round-trip");
}
#[test]
fn test_high_field_number_1000() {
let tag_bytes = encode_u32((1000 << 3) | 0);
let mut record = tag_bytes;
record.push(0x2A); let result = roundtrip(&[&record], CompressionType::None);
assert_eq!(result.len(), 1);
assert_eq!(result[0], record, "high field number must round-trip");
}
#[test]
fn test_empty_string_field() {
    // Field 2, wire type 2 (length-delimited, tag byte 0x12), zero-length payload.
    let encoded: Vec<u8> = vec![0x12, 0x00];
    let decoded = roundtrip(&[encoded.as_slice()], CompressionType::None);
    assert_eq!(decoded.len(), 1);
    assert_eq!(decoded[0], encoded, "empty string field must round-trip");
}
#[test]
fn test_multiple_records_different_schemas() {
    let records: Vec<Vec<u8>> = vec![
        // Field 1, varint 1.
        vec![0x08, 0x01],
        // Field 2, fixed32 (four payload bytes).
        vec![0x15, 0x01, 0x02, 0x03, 0x04],
        // Field 4, length-delimited "hi".
        vec![0x22, 0x02, 0x68, 0x69],
        // Unterminated varint (both bytes have the continuation bit set), so
        // this is not a complete protobuf message.
        vec![0xFF, 0xAB],
    ];
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let decoded = roundtrip(&refs, CompressionType::None);
    assert_eq!(decoded.len(), 4);
    for (i, (got, expected)) in decoded.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
}
#[test]
fn test_decoded_data_size_matches() {
    // A varint record, a byte string ending in an unterminated varint, and
    // another varint record; the header must account for all of their bytes.
    let records: Vec<Vec<u8>> =
        vec![vec![0x08, 0x01], vec![0xFF, 0xAA, 0xBB], vec![0x08, 0x02]];
    let mut encoder = TransposeChunkEncoder::new(CompressionType::None);
    let mut expected_size = 0u64;
    for record in &records {
        expected_size += record.len() as u64;
        encoder.add_record(record).unwrap();
    }
    let chunk = encoder.encode().unwrap();
    assert_eq!(
        chunk.header.decoded_data_size(),
        expected_size,
        "decoded_data_size must equal sum of record sizes"
    );
    assert_eq!(chunk.header.num_records(), 3);
}
/// Builds a protobuf-encoded record made of `count` consecutive varint fields.
///
/// Field numbers run from `start` to `start + count - 1`; field `start + i`
/// carries the varint value `value_seed + i`.
fn build_wide_proto(start: u32, count: u32, value_seed: u32) -> Vec<u8> {
    let mut rec = Vec::new();
    for i in 0..count {
        // Wire type 0 (varint) leaves the low three tag bits clear, so the tag
        // is just the shifted field number.
        rec.extend_from_slice(&encode_u32((start + i) << 3));
        rec.extend_from_slice(&encode_u64(u64::from(value_seed + i)));
    }
    rec
}
#[test]
fn test_implicit_transitions_500_records() {
    // 500 records with the same single-field shape (field 1, varint value i),
    // so the encoder sees the same tag sequence over and over.
    let records: Vec<Vec<u8>> = (0u32..500)
        .map(|i| {
            let mut rec = vec![0x08];
            rec.extend_from_slice(&encode_u64(u64::from(i)));
            rec
        })
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 500);
    for (i, (got, expected)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
}
#[test]
fn test_noop_bridging_70_unique_fields() {
    // 70 single-field records, each with a distinct field number (1..=70),
    // wire type 0 (varint) and value 1. That is more distinct tags than
    // MAX_TRANSITION (63); presumably this forces no-op bridging states.
    let records: Vec<Vec<u8>> = (1u32..=70)
        .map(|field_num| {
            // Wire type 0: the tag is just the shifted field number.
            let mut rec = encode_u32(field_num << 3);
            rec.push(0x01);
            rec
        })
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 70);
    for (i, (got, expected)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
}
#[test]
fn test_noop_bridging_100_unique_fields_per_record() {
    // Ten records, each carrying varint fields 1..=100; record k seeds its
    // values at k * 100.
    let records: Vec<Vec<u8>> = (0u32..10)
        .map(|k| build_wide_proto(1, 100, k * 100))
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let decoded = roundtrip(&refs, CompressionType::None);
    assert_eq!(decoded.len(), 10);
    for (i, (got, expected)) in decoded.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
}
#[test]
fn test_noop_bridging_200_unique_fields() {
    // Five records, each carrying varint fields 1..=200; record k seeds its
    // values at k * 200.
    let records: Vec<Vec<u8>> = (0u32..5)
        .map(|k| build_wide_proto(1, 200, k * 200))
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let decoded = roundtrip(&refs, CompressionType::None);
    assert_eq!(decoded.len(), 5);
    for (i, (got, expected)) in decoded.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
}
#[test]
#[cfg(feature = "brotli")]
fn test_optimized_brotli_100_records() {
    // 100 single-field records (field 1, varint value i) round-tripped through
    // a Brotli-compressed transposed chunk.
    let records: Vec<Vec<u8>> = (0u32..100)
        .map(|i| {
            let mut rec = vec![0x08];
            rec.extend_from_slice(&encode_u64(u64::from(i)));
            rec
        })
        .collect();
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::Brotli);
    assert_eq!(result.len(), 100);
    for (i, (got, expected)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
}
#[test]
fn test_wide_schema_mixed_types() {
    // 50 records, each with 100 fields:
    //   1..=50   varint, value i + f
    //   51..=75  fixed32, little-endian i + f
    //   76..=100 length-delimited, payload "hi"
    let mut records: Vec<Vec<u8>> = Vec::with_capacity(50);
    for i in 0u32..50 {
        let mut rec = Vec::new();
        for f in 1u32..=50 {
            // Wire type 0 (varint): the tag is just the shifted field number.
            rec.extend_from_slice(&encode_u32(f << 3));
            rec.extend_from_slice(&encode_u64(u64::from(i + f)));
        }
        for f in 51u32..=75 {
            // Wire type 5 (fixed32): four little-endian bytes.
            rec.extend_from_slice(&encode_u32((f << 3) | 5));
            rec.extend_from_slice(&(i + f).to_le_bytes());
        }
        for f in 76u32..=100 {
            // Wire type 2 (length-delimited): length 2, then "hi".
            rec.extend_from_slice(&encode_u32((f << 3) | 2));
            rec.extend_from_slice(&[2, b'h', b'i']);
        }
        records.push(rec);
    }
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 50);
    for (i, (got, expected)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
}
#[test]
fn test_no_implicit_loops_many_unique_tags() {
    // 80 single-field records with distinct field numbers 1..=80, wire type 0
    // (varint) and value 1: more distinct tags than MAX_TRANSITION (63).
    let records: Vec<Vec<u8>> = (1u32..=80)
        .map(|field_num| {
            // Wire type 0: the tag is just the shifted field number.
            let mut rec = encode_u32(field_num << 3);
            rec.push(0x01);
            rec
        })
        .collect();
    let mut enc = TransposeChunkEncoder::new(CompressionType::None);
    for rec in &records {
        enc.add_record(rec).unwrap();
    }
    let chunk = enc.encode().unwrap();
    let dec = TransposeChunkDecoder::new(chunk);
    assert!(dec.is_ok(), "decoder should not report implicit loops");
    // Accepting the state machine is not enough: the records it drives must
    // also come back byte-for-byte.
    let mut dec = dec.unwrap();
    let mut decoded: Vec<Vec<u8>> = Vec::with_capacity(records.len());
    while let Some(rec) = dec.read_record().unwrap() {
        decoded.push(rec);
    }
    assert_eq!(decoded, records);
}
#[test]
fn test_high_field_numbers_10000_and_20000() {
    // Each record carries field 10000 (value i) and field 20000 (value 2i),
    // both wire type 0 (varint); these field numbers need multi-byte tags.
    let mut records: Vec<Vec<u8>> = Vec::with_capacity(10);
    for i in 0u32..10 {
        let mut rec = Vec::new();
        // Wire type 0: the tag is just the shifted field number.
        rec.extend_from_slice(&encode_u32(10000u32 << 3));
        rec.extend_from_slice(&encode_u64(u64::from(i)));
        rec.extend_from_slice(&encode_u32(20000u32 << 3));
        rec.extend_from_slice(&encode_u64(u64::from(i * 2)));
        records.push(rec);
    }
    let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
    let result = roundtrip(&refs, CompressionType::None);
    assert_eq!(result.len(), 10);
    for (i, (got, expected)) in result.iter().zip(&records).enumerate() {
        assert_eq!(got, expected, "record {i} mismatch");
    }
}
/// Encodes `records` into a single transposed chunk using `compression`,
/// decodes that chunk, and asserts the decoded records match the inputs in
/// both count and content.
///
/// Panics (failing the calling test) if adding a record, encoding, building
/// the decoder, or reading a record returns an error.
fn transpose_roundtrip_proptest_check(records: &[Vec<u8>], compression: CompressionType) {
    let mut encoder = TransposeChunkEncoder::new(compression);
    for record in records {
        encoder.add_record(record).expect("add_record");
    }
    let encoded_chunk = encoder.encode().expect("encode");
    let mut decoder = TransposeChunkDecoder::new(encoded_chunk).expect("decoder");
    // Pull records until the decoder reports the end of the chunk (`None`).
    let decoded: Vec<Vec<u8>> =
        std::iter::from_fn(|| decoder.read_record().expect("read_record")).collect();
    assert_eq!(
        decoded.len(),
        records.len(),
        "record count mismatch: wrote {} read {}",
        records.len(),
        decoded.len()
    );
    for (index, (want, have)) in records.iter().zip(&decoded).enumerate() {
        assert_eq!(
            want,
            have,
            "record {} mismatch: expected {} bytes, got {} bytes",
            index,
            want.len(),
            have.len()
        );
    }
}
// Property test: an arbitrary batch of arbitrary byte strings must survive a
// transposed-chunk encode/decode round trip with no compression.
proptest::proptest! {
// 500 generated cases per run.
#![proptest_config(proptest::prelude::ProptestConfig::with_cases(500))]
#[test]
// Ignored by default; presumably because each case can hold up to 500 records
// of up to 4096 bytes, which makes the run slow. Run it explicitly with
// `cargo test -- --ignored`.
#[ignore]
fn proptest_transpose_roundtrip_none(
// 0..=500 records, each 0..=4096 arbitrary bytes (not necessarily valid
// protobuf).
records in proptest::collection::vec(
proptest::collection::vec(proptest::prelude::any::<u8>(), 0..=4096),
0..=500,
)
) {
transpose_roundtrip_proptest_check(&records, CompressionType::None);
}
}
}