use arrow_array::RecordBatch;
use crate::Result;
use super::{
DirectEncoder, binding::BoundDirectBatch, invalid_payload, layout,
layout::build_fixed_width_row_range_layout,
};
/// Binds `batch` to `encoder` and measures the resulting row layout.
///
/// Unlike `measure_cell_lengths`, there is no zero-row shortcut here: an empty batch is
/// still bound and measured.
///
/// # Errors
/// Propagates any error from `BoundDirectBatch::new` or from measuring the bound batch.
pub(crate) fn measure_layout(
encoder: &DirectEncoder,
batch: &RecordBatch,
) -> Result<layout::RowLayout> {
BoundDirectBatch::new(encoder, batch)?.measure_layout()
}
/// Measures the encoded length of every cell in `batch`.
///
/// A batch with no rows yields an empty vector immediately, without binding it to
/// `encoder`. Otherwise the batch is bound and the work is delegated to
/// `BoundDirectBatch::measure_cell_lengths`.
///
/// # Errors
/// Propagates any error from binding the batch or from measuring its cells.
pub(crate) fn measure_cell_lengths(
    encoder: &DirectEncoder,
    batch: &RecordBatch,
) -> Result<Vec<usize>> {
    if batch.num_rows() > 0 {
        BoundDirectBatch::new(encoder, batch)?.measure_cell_lengths()
    } else {
        // No rows means no cells; skip binding entirely.
        Ok(Vec::new())
    }
}
/// Byte-length measurements for a batch: one length per cell, one per row, and their total.
///
/// Normally built with `MeasuredDirectBatch::new`. That constructor checks that
/// `cell_lengths` holds exactly `row_count * column_count` entries, then derives
/// `row_lengths` and `payload_len` from them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct MeasuredDirectBatch {
/// Number of measured rows.
row_count: usize,
/// Number of columns in every row.
column_count: usize,
/// Row-major cell lengths: cell `(row, column)` is at `row * column_count + column`.
cell_lengths: Vec<usize>,
/// One entry per row: 1 plus the sum of that row's cell lengths.
row_lengths: Vec<usize>,
/// Sum of all `row_lengths`.
payload_len: usize,
}
impl MeasuredDirectBatch {
    /// Returns a measurement with zero rows, `column_count` columns and an empty payload.
    pub(super) fn empty(column_count: usize) -> Self {
        Self {
            column_count,
            row_count: 0,
            payload_len: 0,
            cell_lengths: Vec::new(),
            row_lengths: Vec::new(),
        }
    }

    /// Builds a measurement from row-major per-cell lengths.
    ///
    /// Row lengths and the total payload length are derived from `cell_lengths`.
    ///
    /// # Errors
    /// Returns an invalid-payload error in any of these cases:
    /// - `row_count * column_count` overflows `usize`;
    /// - `cell_lengths` does not hold exactly that many entries;
    /// - summing the lengths overflows.
    pub(super) fn new(
        row_count: usize,
        column_count: usize,
        cell_lengths: Vec<usize>,
    ) -> Result<Self> {
        let Some(expected_cell_count) = row_count.checked_mul(column_count) else {
            return Err(invalid_payload("measured cell count overflowed usize"));
        };
        let actual_cell_count = cell_lengths.len();
        if actual_cell_count != expected_cell_count {
            return Err(invalid_payload(format!(
                "measured cell length count {actual_cell_count} does not match row count {row_count} and column count {column_count}"
            )));
        }
        let (row_lengths, payload_len) =
            measure_row_lengths(row_count, column_count, &cell_lengths)?;
        Ok(Self {
            row_count,
            column_count,
            cell_lengths,
            row_lengths,
            payload_len,
        })
    }

    /// Number of measured rows.
    pub(crate) const fn row_count(&self) -> usize {
        self.row_count
    }

    /// Number of columns per measured row.
    pub(crate) const fn column_count(&self) -> usize {
        self.column_count
    }

    /// Sum of all row lengths (each row counts its cell lengths plus one).
    pub(crate) const fn payload_len(&self) -> usize {
        self.payload_len
    }

    /// Splits the rows into consecutive ranges whose summed row lengths stay within
    /// `max_payload_bytes`.
    ///
    /// Rows are packed greedily in order. A single row longer than the limit is still
    /// emitted, alone, as its own range. An empty measurement yields no ranges.
    ///
    /// # Errors
    /// Returns an invalid-payload error if `max_payload_bytes` is zero or a running byte
    /// total overflows `usize`.
    pub(crate) fn row_ranges(&self, max_payload_bytes: usize) -> Result<Vec<MeasuredRowRange>> {
        if max_payload_bytes == 0 {
            return Err(invalid_payload(
                "direct row range byte limit must be greater than zero",
            ));
        }
        let mut ranges = Vec::new();
        // The range currently being filled, plus the bytes it already holds.
        let mut current = MeasuredRowRange { start: 0, len: 0 };
        let mut current_bytes = 0usize;
        for (row_index, &row_len) in self.row_lengths.iter().enumerate() {
            let grown = current_bytes
                .checked_add(row_len)
                .ok_or_else(|| invalid_payload("measured row range length overflowed usize"))?;
            let exceeds_limit = grown > max_payload_bytes;
            if current.len != 0 && exceeds_limit {
                // Close the current range and start a new one at this row.
                ranges.push(current);
                current = MeasuredRowRange {
                    start: row_index,
                    len: 1,
                };
                current_bytes = row_len;
            } else {
                current.len += 1;
                current_bytes = grown;
            }
        }
        if current.len != 0 {
            ranges.push(current);
        }
        Ok(ranges)
    }

    /// Sums the row lengths of rows `start_row..start_row + row_count`.
    ///
    /// # Errors
    /// Returns an invalid-payload error if the range falls outside the measured rows or a
    /// sum overflows `usize`.
    pub(crate) fn range_payload_len(&self, start_row: usize, row_count: usize) -> Result<usize> {
        self.check_range(start_row, row_count)?;
        let Some(end_row) = start_row.checked_add(row_count) else {
            return Err(invalid_payload("direct row range end overflowed usize"));
        };
        let mut total = 0usize;
        for &row_len in &self.row_lengths[start_row..end_row] {
            total = total
                .checked_add(row_len)
                .ok_or_else(|| invalid_payload("measured row range length overflowed usize"))?;
        }
        Ok(total)
    }

    /// Returns the measured length of the cell at (`row_index`, `column_index`).
    ///
    /// # Errors
    /// Returns an invalid-payload error if either index is outside the measured shape.
    pub(super) fn cell_len(&self, row_index: usize, column_index: usize) -> Result<usize> {
        self.check_range(row_index, 1)?;
        let column_count = self.column_count;
        if column_index >= column_count {
            return Err(invalid_payload(format!(
                "direct measured column index {column_index} is outside measured column count {column_count}"
            )));
        }
        let index = row_index
            .checked_mul(column_count)
            .and_then(|row_start| row_start.checked_add(column_index))
            .ok_or_else(|| invalid_payload("measured cell length index overflowed usize"))?;
        match self.cell_lengths.get(index) {
            Some(&len) => Ok(len),
            None => Err(invalid_payload(format!(
                "measured cell length index {index} is outside measured cell length count {}",
                self.cell_lengths.len()
            ))),
        }
    }

    /// Builds the row layout for rows `start_row..start_row + row_count`.
    ///
    /// Checks the range first, then delegates to `build_fixed_width_row_range_layout`
    /// with this batch's column count and cell lengths.
    ///
    /// # Errors
    /// Returns an invalid-payload error for an out-of-range request. Otherwise propagates
    /// any error from the layout builder.
    pub(super) fn range_layout(
        &self,
        start_row: usize,
        row_count: usize,
    ) -> Result<layout::RowLayout> {
        self.check_range(start_row, row_count)?;
        build_fixed_width_row_range_layout(
            start_row,
            row_count,
            self.column_count,
            &self.cell_lengths,
        )
    }

    /// Succeeds only when `start_row..start_row + row_count` lies within the measured rows.
    ///
    /// # Errors
    /// Returns an invalid-payload error if the range end overflows `usize` or exceeds
    /// `row_count`.
    pub(super) fn check_range(&self, start_row: usize, row_count: usize) -> Result<()> {
        let Some(end_row) = start_row.checked_add(row_count) else {
            return Err(invalid_payload("direct row range end overflowed usize"));
        };
        if end_row <= self.row_count {
            return Ok(());
        }
        Err(invalid_payload(format!(
            "direct measured row range {start_row}..{end_row} is outside measured row count {}",
            self.row_count
        )))
    }
}
/// Sums row-major per-cell lengths into per-row lengths and a total payload length.
///
/// Cell `(row, column)` is read from `cell_lengths[row * column_count + column]`.
///
/// Each row's length starts at 1 before its cells are added, so a row counts at least one
/// byte even with zero columns. Presumably this is a per-row separator/framing byte;
/// TODO confirm against the encoder.
///
/// # Errors
/// Returns an invalid-payload error if a row length or the running payload total
/// overflows `usize`.
///
/// # Panics
/// Panics if `cell_lengths` has fewer than `row_count * column_count` entries.
/// `MeasuredDirectBatch::new` validates the length before calling this.
fn measure_row_lengths(
    row_count: usize,
    column_count: usize,
    cell_lengths: &[usize],
) -> Result<(Vec<usize>, usize)> {
    let mut lengths = Vec::with_capacity(row_count);
    let mut total = 0usize;
    for row in 0..row_count {
        let row_start = row * column_count;
        let row_len = (0..column_count).try_fold(1usize, |acc, column| {
            acc.checked_add(cell_lengths[row_start + column])
                .ok_or_else(|| invalid_payload("measured row length overflowed usize"))
        })?;
        total = total
            .checked_add(row_len)
            .ok_or_else(|| invalid_payload("measured payload length overflowed usize"))?;
        lengths.push(row_len);
    }
    Ok((lengths, total))
}
/// A contiguous run of measured rows, as produced by `MeasuredDirectBatch::row_ranges`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct MeasuredRowRange {
/// Index of the first row in the range.
pub(crate) start: usize,
/// Number of rows in the range.
pub(crate) len: usize,
}
#[cfg(test)]
mod tests {
    use super::{MeasuredDirectBatch, MeasuredRowRange};
    use crate::{DiagnosticCode, Error};

    #[test]
    fn measured_direct_batch_builds_payload_lengths() {
        let measured = MeasuredDirectBatch::new(2, 3, vec![4, 1, 8, 4, 5, 1]).unwrap();
        assert_eq!(measured.row_count(), 2);
        assert_eq!(measured.column_count(), 3);
        assert_eq!(measured.payload_len(), 25);
        assert_eq!(measured.range_payload_len(0, 1).unwrap(), 14);
        assert_eq!(measured.range_payload_len(1, 1).unwrap(), 11);
        assert_eq!(measured.cell_len(1, 1).unwrap(), 5);
    }

    #[test]
    fn measured_direct_batch_ranges_split_by_payload_byte_limit() {
        let measured = MeasuredDirectBatch::new(4, 1, vec![4, 4, 4, 4]).unwrap();
        assert_eq!(measured.payload_len(), 20);
        assert_eq!(
            measured.row_ranges(10).unwrap(),
            [
                MeasuredRowRange { start: 0, len: 2 },
                MeasuredRowRange { start: 2, len: 2 },
            ]
        );
        assert_eq!(
            measured.row_ranges(4).unwrap(),
            [
                MeasuredRowRange { start: 0, len: 1 },
                MeasuredRowRange { start: 1, len: 1 },
                MeasuredRowRange { start: 2, len: 1 },
                MeasuredRowRange { start: 3, len: 1 },
            ]
        );
    }

    #[test]
    fn measured_direct_batch_gives_oversized_rows_their_own_range() {
        // Row lengths are 10, 2 and 2 (each row's single cell plus one).
        let measured = MeasuredDirectBatch::new(3, 1, vec![9, 1, 1]).unwrap();
        assert_eq!(measured.payload_len(), 14);
        assert_eq!(
            measured.row_ranges(4).unwrap(),
            [
                MeasuredRowRange { start: 0, len: 1 },
                MeasuredRowRange { start: 1, len: 2 },
            ]
        );
    }

    #[test]
    fn empty_measured_batch_has_no_rows_or_ranges() {
        let measured = MeasuredDirectBatch::empty(3);
        assert_eq!(measured.row_count(), 0);
        assert_eq!(measured.column_count(), 3);
        assert_eq!(measured.payload_len(), 0);
        assert!(measured.row_ranges(10).unwrap().is_empty());
        assert_eq!(measured.range_payload_len(0, 0).unwrap(), 0);
    }

    #[test]
    fn measured_direct_batch_rejects_invalid_cell_length_count() {
        let err = MeasuredDirectBatch::new(2, 2, vec![1, 2, 3])
            .expect_err("cell length count must match shape");
        assert_direct_encoding_invalid_payload(err);
    }

    #[test]
    fn measured_direct_batch_rejects_invalid_ranges() {
        let measured = MeasuredDirectBatch::new(2, 1, vec![4, 4]).unwrap();
        assert_direct_encoding_invalid_payload(
            measured
                .range_payload_len(1, 2)
                .expect_err("range past measured rows must fail"),
        );
        assert_direct_encoding_invalid_payload(
            measured
                .range_payload_len(usize::MAX, 1)
                .expect_err("overflowing range end must fail"),
        );
        assert_direct_encoding_invalid_payload(
            measured
                .row_ranges(0)
                .expect_err("zero byte limit must fail"),
        );
    }

    #[test]
    fn measured_direct_batch_rejects_invalid_cell_coordinates() {
        let measured = MeasuredDirectBatch::new(2, 3, vec![4, 1, 8, 4, 5, 1]).unwrap();
        assert_direct_encoding_invalid_payload(
            measured
                .cell_len(2, 0)
                .expect_err("row past measured rows must fail"),
        );
        assert_direct_encoding_invalid_payload(
            measured
                .cell_len(0, 3)
                .expect_err("column past measured columns must fail"),
        );
    }

    /// Asserts that `err` is a direct-encoding error carrying exactly one
    /// `DirectEncodingInvalidPayload` diagnostic.
    fn assert_direct_encoding_invalid_payload(err: Error) {
        let Error::DirectEncoding { diagnostics } = err else {
            panic!("expected direct encoding error");
        };
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(
            diagnostics.all()[0].code(),
            DiagnosticCode::DirectEncodingInvalidPayload
        );
    }
}