use std::fmt::Debug;
use std::fmt::Display;
use std::fmt::Formatter;
use std::hash::Hash;
use std::hash::Hasher;
use std::sync::Arc;
use itertools::Itertools as _;
use prost::Message as _;
use vortex_array::Array;
use vortex_array::ArrayEq;
use vortex_array::ArrayHash;
use vortex_array::ArrayId;
use vortex_array::ArrayParts;
use vortex_array::ArrayRef;
use vortex_array::ArrayView;
use vortex_array::Canonical;
use vortex_array::ExecutionCtx;
use vortex_array::ExecutionResult;
use vortex_array::IntoArray;
use vortex_array::LEGACY_SESSION;
use vortex_array::Precision;
use vortex_array::ToCanonical;
use vortex_array::VortexSessionExecute;
use vortex_array::accessor::ArrayAccessor;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::VarBinViewArray;
use vortex_array::arrays::varbinview::build_views::BinaryView;
use vortex_array::arrays::varbinview::build_views::MAX_BUFFER_LEN;
use vortex_array::buffer::BufferHandle;
use vortex_array::dtype::DType;
use vortex_array::scalar::Scalar;
use vortex_array::serde::ArrayChildren;
use vortex_array::validity::Validity;
use vortex_array::vtable::OperationsVTable;
use vortex_array::vtable::VTable;
use vortex_array::vtable::ValidityVTable;
use vortex_array::vtable::child_to_validity;
use vortex_array::vtable::validity_to_child;
use vortex_buffer::Alignment;
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_buffer::ByteBuffer;
use vortex_buffer::ByteBufferMut;
use vortex_error::VortexError;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_error::vortex_panic;
use vortex_mask::AllOr;
use vortex_session::VortexSession;
use vortex_session::registry::CachedId;
use crate::ZstdFrameMetadata;
use crate::ZstdMetadata;
/// Minimum number of frames (training samples) before training a shared zstd
/// dictionary is attempted; below this, frames are compressed without one.
const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
/// Integer type of the little-endian length prefix written before each
/// variable-length value in the interleaved byte layout (see `reconstruct_views`).
type ViewLen = u32;
/// A Vortex array whose payload is stored as zstd-compressed frames.
pub type ZstdArray = Array<Zstd>;
impl ArrayHash for ZstdData {
    /// Feeds every identity-defining component of this array into `state`.
    ///
    /// The dictionary's presence is hashed as a boolean discriminant before
    /// the dictionary itself, so `None` and `Some` cases cannot collide.
    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
        self.dictionary.is_some().hash(state);
        if let Some(dict) = &self.dictionary {
            dict.array_hash(state, precision);
        }
        for frame in &self.frames {
            frame.array_hash(state, precision);
        }
        self.unsliced_n_rows.hash(state);
        self.slice_start.hash(state);
        self.slice_stop.hash(state);
    }
}
impl ArrayEq for ZstdData {
    /// Structural equality: dictionaries, every frame, and the row/slice
    /// bookkeeping must all agree.
    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
        let dictionaries_eq = match (&self.dictionary, &other.dictionary) {
            (Some(lhs), Some(rhs)) => lhs.array_eq(rhs, precision),
            (None, None) => true,
            _ => false,
        };
        dictionaries_eq
            && self.frames.len() == other.frames.len()
            && self
                .frames
                .iter()
                .zip(&other.frames)
                .all(|(lhs, rhs)| lhs.array_eq(rhs, precision))
            && self.unsliced_n_rows == other.unsliced_n_rows
            && self.slice_start == other.slice_start
            && self.slice_stop == other.slice_stop
    }
}
impl VTable for Zstd {
    type ArrayData = ZstdData;
    type OperationsVTable = Self;
    type ValidityVTable = Self;

    fn id(&self) -> ArrayId {
        // Cached so repeated lookups do not re-intern the id string.
        static ID: CachedId = CachedId::new("vortex.zstd");
        *ID
    }

    /// Validates `data` against the declared dtype and length.
    ///
    /// Slot 0 is the validity child (see `SLOT_NAMES`); it is converted back
    /// into a `Validity` before delegating to `ZstdData::validate`.
    fn validate(
        &self,
        data: &Self::ArrayData,
        dtype: &DType,
        len: usize,
        slots: &[Option<ArrayRef>],
    ) -> VortexResult<()> {
        let validity = child_to_validity(&slots[0], dtype.nullability());
        data.validate(dtype, len, &validity)
    }

    // Buffer layout: optional dictionary first, then one buffer per frame.
    fn nbuffers(array: ArrayView<'_, Self>) -> usize {
        array.dictionary.is_some() as usize + array.frames.len()
    }

    /// Returns buffer `idx` following the layout described in `nbuffers`:
    /// when a dictionary exists it occupies index 0 and frames shift up by 1.
    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
        if let Some(dict) = &array.dictionary {
            if idx == 0 {
                return BufferHandle::new_host(dict.clone());
            }
            BufferHandle::new_host(array.frames[idx - 1].clone())
        } else {
            BufferHandle::new_host(array.frames[idx].clone())
        }
    }

    /// Human-readable name for buffer `idx`, mirroring `buffer`'s indexing.
    fn buffer_name(array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
        if array.dictionary.is_some() {
            if idx == 0 {
                Some("dictionary".to_string())
            } else {
                Some(format!("frame_{}", idx - 1))
            }
        } else {
            Some(format!("frame_{idx}"))
        }
    }

    /// Serializes the prost-encoded metadata (frame sizes + dictionary size).
    fn serialize(
        array: ArrayView<'_, Self>,
        _session: &VortexSession,
    ) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(array.metadata.clone().encode_to_vec()))
    }

    /// Rebuilds a Zstd array from serialized metadata, buffers, and children.
    ///
    /// `metadata.dictionary_size == 0` signals that no dictionary buffer was
    /// written, in which case every buffer is a frame; otherwise buffer 0 is
    /// the dictionary (matching `buffer`/`buffer_name` above).
    fn deserialize(
        &self,
        dtype: &DType,
        len: usize,
        metadata: &[u8],
        buffers: &[BufferHandle],
        children: &dyn ArrayChildren,
        _session: &VortexSession,
    ) -> VortexResult<ArrayParts<Self>> {
        let metadata = ZstdMetadata::decode(metadata)?;
        // Zero children: validity is implied by the dtype's nullability.
        // One child: an explicit validity array.
        let validity = if children.is_empty() {
            Validity::from(dtype.nullability())
        } else if children.len() == 1 {
            let validity = children.get(0, &Validity::DTYPE, len)?;
            Validity::Array(validity)
        } else {
            vortex_bail!("ZstdArray expected 0 or 1 child, got {}", children.len());
        };
        let (dictionary_buffer, compressed_buffers) = if metadata.dictionary_size == 0 {
            (
                None,
                buffers
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        } else {
            (
                Some(buffers[0].clone().try_to_host_sync()?),
                buffers[1..]
                    .iter()
                    .map(|b| b.clone().try_to_host_sync())
                    .collect::<VortexResult<Vec<_>>>()?,
            )
        };
        let slots = vec![validity_to_child(&validity, len)];
        let data = ZstdData::new(dictionary_buffer, compressed_buffers, metadata, len);
        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
    }

    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
        SLOT_NAMES[idx].to_string()
    }

    /// Decompresses the array's current slice and executes it to completion.
    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
        // The validity child spans the unsliced rows; `decompress` narrows it.
        let unsliced_validity =
            child_to_validity(&array.as_ref().slots()[0], array.dtype().nullability());
        array
            .data()
            .decompress(array.dtype(), &unsliced_validity, ctx)?
            .execute::<ArrayRef>(ctx)
            .map(ExecutionResult::done)
    }

    /// Delegates parent-reduction decisions to this crate's rewrite rules.
    fn reduce_parent(
        array: ArrayView<'_, Self>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        crate::rules::RULES.evaluate(array, parent, child_idx)
    }
}
/// Marker vtable type for the zstd encoding; the array's state lives in
/// [`ZstdData`].
#[derive(Clone, Debug)]
pub struct Zstd;
impl Zstd {
    /// Validates `data` against `dtype`/`validity` and assembles a [`ZstdArray`].
    ///
    /// Note: the validity child is stored over the *unsliced* row count —
    /// `ZstdData::validate` enforces that the validity length matches it.
    pub fn try_new(dtype: DType, data: ZstdData, validity: Validity) -> VortexResult<ZstdArray> {
        let len = data.len();
        data.validate(&dtype, len, &validity)?;
        let slots = vec![validity_to_child(&validity, data.unsliced_n_rows())];
        // The `validate` call above has already checked these parts, so
        // skipping re-validation in `from_parts_unchecked` is sound.
        Ok(unsafe {
            Array::from_parts_unchecked(ArrayParts::new(Zstd, dtype, len, data).with_slots(slots))
        })
    }

    /// Compresses a `VarBinView` array without training a zstd dictionary.
    pub fn from_var_bin_view_without_dict(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<ZstdArray> {
        let validity = vbv.validity()?;
        Self::try_new(
            vbv.dtype().clone(),
            ZstdData::from_var_bin_view_without_dict(vbv, level, values_per_frame)?,
            validity,
        )
    }

    /// Compresses a primitive array (dictionary training enabled).
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<ZstdArray> {
        let validity = parray.validity()?;
        Self::try_new(
            parray.dtype().clone(),
            ZstdData::from_primitive(parray, level, values_per_frame)?,
            validity,
        )
    }

    /// Compresses a `VarBinView` array (dictionary training enabled).
    pub fn from_var_bin_view(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<ZstdArray> {
        let validity = vbv.validity()?;
        Self::try_new(
            vbv.dtype().clone(),
            ZstdData::from_var_bin_view(vbv, level, values_per_frame)?,
            validity,
        )
    }

    /// Decompresses `array`'s current slice back into a canonical array.
    pub fn decompress(array: &ZstdArray, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
        // The validity child spans the unsliced rows; `decompress` narrows it.
        let unsliced_validity =
            child_to_validity(&array.as_ref().slots()[0], array.dtype().nullability());
        array
            .data()
            .decompress(array.dtype(), &unsliced_validity, ctx)
    }
}
/// Number of child slots a Zstd array carries (just the validity child).
pub(super) const NUM_SLOTS: usize = 1;
/// Names of the child slots, indexed by slot position.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
/// Owned payload of a zstd-compressed array.
///
/// Valid values are compressed into independently decompressible `frames`,
/// optionally sharing a trained zstd `dictionary`. Slicing is lazy:
/// `slice_start..slice_stop` selects a window of the `unsliced_n_rows`
/// logical rows without touching the compressed bytes.
#[derive(Clone, Debug)]
pub struct ZstdData {
    /// Trained zstd dictionary shared by all frames, if one was trained.
    pub(crate) dictionary: Option<ByteBuffer>,
    /// Zstd-compressed frames; one per entry in `metadata.frames`.
    pub(crate) frames: Vec<ByteBuffer>,
    /// Per-frame uncompressed sizes / value counts plus the dictionary size.
    pub(crate) metadata: ZstdMetadata,
    // Total logical rows before any slicing.
    unsliced_n_rows: usize,
    // Half-open slice window over the unsliced rows.
    slice_start: usize,
    slice_stop: usize,
}
impl Display for ZstdData {
    /// Renders the row count and slice bounds, e.g. `nrows: 10, slice: 2..5`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "nrows: {rows}, slice: {start}..{stop}",
            rows = self.unsliced_n_rows,
            start = self.slice_start,
            stop = self.slice_stop
        )
    }
}
/// Decomposed form of [`ZstdData`] together with its validity, produced by
/// [`ZstdData::into_parts`].
pub struct ZstdDataParts {
    pub dictionary: Option<ByteBuffer>,
    pub frames: Vec<ByteBuffer>,
    pub metadata: ZstdMetadata,
    pub validity: Validity,
    pub n_rows: usize,
    pub slice_start: usize,
    pub slice_stop: usize,
}
/// Output of `ZstdData::compress_values`: compressed frames, their metadata,
/// and the optional trained dictionary.
#[derive(Debug)]
struct Frames {
    dictionary: Option<ByteBuffer>,
    frames: Vec<ByteBuffer>,
    frame_metas: Vec<ZstdFrameMetadata>,
}
/// Picks a zstd dictionary size budget of roughly 1% of the uncompressed
/// payload, bounded to the range [256 B, 100 KiB].
fn choose_max_dict_size(uncompressed_size: usize) -> usize {
    const MIN_DICT_SIZE: usize = 256;
    const MAX_DICT_SIZE: usize = 100 * 1024;
    let budget = uncompressed_size / 100;
    budget.max(MIN_DICT_SIZE).min(MAX_DICT_SIZE)
}
/// Filters `parray` down to just its valid (non-null) elements.
fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult<PrimitiveArray> {
    let mut ctx = LEGACY_SESSION.create_execution_ctx();
    let array_len = parray.as_ref().len();
    let valid_mask = parray.as_ref().validity()?.to_mask(array_len, &mut ctx)?;
    let filtered = parray.filter(valid_mask)?;
    Ok(filtered.to_primitive())
}
/// Gathers the valid (non-null) values of `vbv` into one flat buffer of
/// interleaved `[u32 le length][bytes]` records.
///
/// Returns the buffer together with the byte offset at which each value's
/// record begins (later used to pick frame boundaries).
fn collect_valid_vbv(vbv: &VarBinViewArray) -> VortexResult<(ByteBuffer, Vec<usize>)> {
    let mask = vbv.as_ref().validity()?.to_mask(
        vbv.as_ref().len(),
        &mut LEGACY_SESSION.create_execution_ctx(),
    )?;
    let buffer_and_value_byte_indices = match mask.bit_buffer() {
        // No valid values: nothing to encode.
        AllOr::None => (Buffer::empty(), Vec::new()),
        _ => {
            // Reserve room for all value bytes plus one length prefix per
            // valid value, to avoid reallocation while appending.
            let mut buffer = BufferMut::with_capacity(
                usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer")
                    + mask.true_count() * size_of::<ViewLen>(),
            );
            let mut value_byte_indices = Vec::new();
            vbv.with_iterator(|iterator| {
                // `flatten` skips `None` entries, yielding only valid values.
                for value in iterator.flatten() {
                    value_byte_indices.push(buffer.len());
                    buffer
                        .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter());
                    buffer.extend_from_slice(value);
                }
                Ok::<_, VortexError>(())
            })?;
            (buffer.freeze(), value_byte_indices)
        }
    };
    Ok(buffer_and_value_byte_indices)
}
/// Rebuilds binary views and their backing buffers from the interleaved
/// `[u32 le length][bytes]` layout written by `collect_valid_vbv`.
///
/// The flat `buffer` is cut into segments of at most `max_buffer_len` bytes —
/// except that a single value larger than the cap still gets one (oversized)
/// segment, since records are never split. Each view addresses its value by
/// (segment index, offset within that segment).
pub fn reconstruct_views(
    buffer: &ByteBuffer,
    max_buffer_len: usize,
) -> (Vec<ByteBuffer>, Buffer<BinaryView>) {
    let mut views = BufferMut::<BinaryView>::empty();
    let mut buffers = Vec::new();
    // Start of the segment currently being accumulated, and the read cursor.
    let mut segment_start: usize = 0;
    let mut offset = 0;
    while offset < buffer.len() {
        // Read the little-endian length prefix of the next record.
        let str_len = ViewLen::from_le_bytes(
            buffer
                .get(offset..offset + size_of::<ViewLen>())
                .vortex_expect("corrupted zstd length")
                .try_into()
                .ok()
                .vortex_expect("must fit ViewLen size"),
        ) as usize;
        let value_data_offset = offset + size_of::<ViewLen>();
        let local_offset = value_data_offset - segment_start;
        // Close the current segment if this value would push it past the cap
        // — unless the segment is still empty, in which case the oversized
        // value gets a segment of its own.
        if local_offset + str_len > max_buffer_len && offset > segment_start {
            buffers.push(buffer.slice(segment_start..offset));
            segment_start = offset;
        }
        // Recompute after a possible segment split (shadows the value above).
        let local_offset = u32::try_from(value_data_offset - segment_start)
            .vortex_expect("local offset within segment must fit in u32");
        // `buffers.len()` is the index the in-progress segment will receive
        // when it is eventually pushed below.
        let buf_index = u32::try_from(buffers.len()).vortex_expect("buffer index must fit in u32");
        let value = &buffer[value_data_offset..value_data_offset + str_len];
        views.push(BinaryView::make_view(value, buf_index, local_offset));
        offset = value_data_offset + str_len;
    }
    // Flush the trailing segment, if any bytes remain in it.
    if segment_start < buffer.len() {
        buffers.push(buffer.slice(segment_start..buffer.len()));
    }
    (buffers, views.freeze())
}
impl ZstdData {
    /// Builds an unsliced `ZstdData` covering `n_rows` logical rows.
    pub fn new(
        dictionary: Option<ByteBuffer>,
        frames: Vec<ByteBuffer>,
        metadata: ZstdMetadata,
        n_rows: usize,
    ) -> Self {
        Self {
            dictionary,
            frames,
            metadata,
            unsliced_n_rows: n_rows,
            slice_start: 0,
            slice_stop: n_rows,
        }
    }

    /// Checks internal invariants against the declared `dtype`, sliced `len`,
    /// and (unsliced) `validity`.
    ///
    /// # Errors
    /// Returns an error when the dtype is unsupported, the slice window is
    /// inconsistent with `len`/the unsliced row count, the validity length
    /// disagrees with the unsliced row count, or the metadata disagrees with
    /// the dictionary/frame buffers.
    pub fn validate(&self, dtype: &DType, len: usize, validity: &Validity) -> VortexResult<()> {
        vortex_ensure!(
            matches!(
                dtype,
                DType::Primitive(..) | DType::Binary(_) | DType::Utf8(_)
            ),
            "Unsupported dtype for Zstd array: {dtype}"
        );
        vortex_ensure!(
            self.slice_start <= self.slice_stop,
            "Invalid slice range {}..{}",
            self.slice_start,
            self.slice_stop
        );
        vortex_ensure!(
            self.slice_stop <= self.unsliced_n_rows,
            "Slice stop {} exceeds unsliced row count {}",
            self.slice_stop,
            self.unsliced_n_rows
        );
        vortex_ensure!(
            self.slice_stop - self.slice_start == len,
            "Slice length {} does not match array length {}",
            self.slice_stop - self.slice_start,
            len
        );
        // The validity child covers the unsliced rows, not just the slice.
        if let Some(validity_len) = validity.maybe_len() {
            vortex_ensure!(
                validity_len == self.unsliced_n_rows,
                "Validity length {} does not match unsliced row count {}",
                validity_len,
                self.unsliced_n_rows
            );
        }
        match &self.dictionary {
            Some(dictionary) => vortex_ensure!(
                usize::try_from(self.metadata.dictionary_size)? == dictionary.len(),
                "Dictionary size metadata {} does not match buffer size {}",
                self.metadata.dictionary_size,
                dictionary.len()
            ),
            None => vortex_ensure!(
                self.metadata.dictionary_size == 0,
                "Dictionary metadata present without dictionary buffer"
            ),
        }
        vortex_ensure!(
            self.frames.len() == self.metadata.frames.len(),
            "Frame count {} does not match metadata frame count {}",
            self.frames.len(),
            self.metadata.frames.len()
        );
        Ok(())
    }

    /// Narrows the slice window to `start..stop` *relative to the current
    /// slice*, without touching the compressed payload.
    ///
    /// # Panics
    /// Panics if the new window extends past the current slice stop.
    pub(crate) fn with_slice(&self, start: usize, stop: usize) -> Self {
        let new_start = self.slice_start + start;
        let new_stop = self.slice_start + stop;
        assert!(
            new_start <= self.slice_stop,
            "new slice start {new_start} exceeds end {}",
            self.slice_stop
        );
        assert!(
            new_stop <= self.slice_stop,
            "new slice stop {new_stop} exceeds end {}",
            self.slice_stop
        );
        Self {
            slice_start: new_start,
            slice_stop: new_stop,
            ..self.clone()
        }
    }

    /// Compresses `value_bytes` into one zstd frame per entry of
    /// `frame_byte_starts`, optionally training a shared dictionary first.
    ///
    /// A dictionary is trained only when `use_dictionary` is set and there are
    /// at least `MIN_SAMPLES_FOR_DICTIONARY` frames to train on.
    fn compress_values(
        value_bytes: &ByteBuffer,
        frame_byte_starts: &[usize],
        level: i32,
        values_per_frame: usize,
        n_values: usize,
        use_dictionary: bool,
    ) -> VortexResult<Frames> {
        let n_frames = frame_byte_starts.len();
        // Byte length of each frame; the last frame runs to the buffer end.
        let mut sample_sizes = Vec::with_capacity(n_frames);
        for i in 0..n_frames {
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());
            sample_sizes.push(frame_byte_end - frame_byte_starts[i]);
        }
        debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
        let (dictionary, mut compressor) = if !use_dictionary
            || sample_sizes.len() < MIN_SAMPLES_FOR_DICTIONARY
        {
            (None, zstd::bulk::Compressor::new(level)?)
        } else {
            let max_dict_size = choose_max_dict_size(value_bytes.len());
            let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
                .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
            let compressor = zstd::bulk::Compressor::with_dictionary(level, &dict)?;
            (Some(ByteBuffer::from(dict)), compressor)
        };
        let mut frame_metas = vec![];
        let mut frames = vec![];
        for i in 0..n_frames {
            let frame_byte_end = frame_byte_starts
                .get(i + 1)
                .copied()
                .unwrap_or(value_bytes.len());
            let uncompressed = &value_bytes.slice(frame_byte_starts[i]..frame_byte_end);
            let compressed = compressor
                .compress(uncompressed)
                .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
            frame_metas.push(ZstdFrameMetadata {
                uncompressed_size: uncompressed.len() as u64,
                // The final frame may hold fewer than `values_per_frame` values.
                n_values: values_per_frame.min(n_values - i * values_per_frame) as u64,
            });
            frames.push(ByteBuffer::from(compressed));
        }
        Ok(Frames {
            dictionary,
            frames,
            frame_metas,
        })
    }

    /// Compresses a primitive array with dictionary training enabled.
    pub fn from_primitive(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, true)
    }

    /// Compresses a primitive array without training a dictionary.
    pub fn from_primitive_without_dict(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_primitive_impl(parray, level, values_per_frame, false)
    }

    /// Shared implementation behind the primitive constructors.
    ///
    /// A `values_per_frame` of 0 means "one frame for the whole array".
    fn from_primitive_impl(
        parray: &PrimitiveArray,
        level: i32,
        values_per_frame: usize,
        use_dictionary: bool,
    ) -> VortexResult<Self> {
        let byte_width = parray.ptype().byte_width();
        // Nulls are dropped up front; only valid values get compressed.
        let values = collect_valid_primitive(parray)?;
        let n_values = values.len();
        let values_per_frame = if values_per_frame > 0 {
            values_per_frame
        } else {
            // Whole array in a single frame. `max(1)` guards the empty /
            // all-null case: a zero value here would make `step_width` 0 and
            // `step_by(0)` below panics unconditionally.
            n_values.max(1)
        };
        let value_bytes = values.buffer_handle().try_to_host_sync()?;
        let alignment = *value_bytes.alignment();
        // Round the frame stride up so every frame starts on an aligned
        // boundary of the value buffer.
        let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
        let frame_byte_starts = (0..n_values * byte_width)
            .step_by(step_width)
            .collect::<Vec<_>>();
        let Frames {
            dictionary,
            frames,
            frame_metas,
        } = Self::compress_values(
            &value_bytes,
            &frame_byte_starts,
            level,
            values_per_frame,
            n_values,
            use_dictionary,
        )?;
        let metadata = ZstdMetadata {
            dictionary_size: dictionary
                .as_ref()
                .map_or(0, |dict| dict.len())
                .try_into()?,
            frames: frame_metas,
        };
        Ok(ZstdData::new(dictionary, frames, metadata, parray.len()))
    }

    /// Compresses a `VarBinView` array with dictionary training enabled.
    pub fn from_var_bin_view(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, true)
    }

    /// Compresses a `VarBinView` array without training a dictionary.
    pub fn from_var_bin_view_without_dict(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Self> {
        Self::from_var_bin_view_impl(vbv, level, values_per_frame, false)
    }

    /// Shared implementation behind the `VarBinView` constructors.
    ///
    /// A `values_per_frame` of 0 means "one frame for the whole array".
    fn from_var_bin_view_impl(
        vbv: &VarBinViewArray,
        level: i32,
        values_per_frame: usize,
        use_dictionary: bool,
    ) -> VortexResult<Self> {
        // Valid values are flattened into `[len][bytes]` records; the index
        // vector marks where each record starts.
        let (value_bytes, value_byte_indices) = collect_valid_vbv(vbv)?;
        let n_values = value_byte_indices.len();
        let values_per_frame = if values_per_frame > 0 {
            values_per_frame
        } else {
            // Whole array in one frame; `max(1)` keeps `step_by` below from
            // panicking on a zero step when there are no valid values.
            n_values.max(1)
        };
        let frame_byte_starts = (0..n_values)
            .step_by(values_per_frame)
            .map(|i| value_byte_indices[i])
            .collect::<Vec<_>>();
        let Frames {
            dictionary,
            frames,
            frame_metas,
        } = Self::compress_values(
            &value_bytes,
            &frame_byte_starts,
            level,
            values_per_frame,
            n_values,
            use_dictionary,
        )?;
        let metadata = ZstdMetadata {
            dictionary_size: dictionary
                .as_ref()
                .map_or(0, |dict| dict.len())
                .try_into()?,
            frames: frame_metas,
        };
        Ok(ZstdData::new(dictionary, frames, metadata, vbv.len()))
    }

    /// Compresses a canonical array, returning `None` for unsupported kinds.
    pub fn from_canonical(
        canonical: &Canonical,
        level: i32,
        values_per_frame: usize,
    ) -> VortexResult<Option<Self>> {
        match canonical {
            Canonical::Primitive(parray) => Ok(Some(ZstdData::from_primitive(
                parray,
                level,
                values_per_frame,
            )?)),
            Canonical::VarBinView(vbv) => Ok(Some(ZstdData::from_var_bin_view(
                vbv,
                level,
                values_per_frame,
            )?)),
            _ => Ok(None),
        }
    }

    /// Canonicalizes `array` and compresses it.
    ///
    /// # Errors
    /// Fails if the canonical form is neither Primitive nor VarBinView.
    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
        Self::from_canonical(&array.to_canonical()?, level, values_per_frame)?
            .ok_or_else(|| vortex_err!("Zstd can only encode Primitive and VarBinView arrays"))
    }

    /// Width of one value in bytes; variable-length dtypes are treated as a
    /// stream of single bytes.
    fn byte_width(dtype: &DType) -> usize {
        if dtype.is_primitive() {
            dtype.as_ptype().byte_width()
        } else {
            1
        }
    }

    /// Decompresses the current slice window back into a canonical array.
    ///
    /// Only frames overlapping the slice (measured in *value* index space,
    /// i.e. after dropping nulls) are decompressed; whole frames before the
    /// slice are skipped and tracked via `n_skipped_values`.
    fn decompress(
        &self,
        dtype: &DType,
        unsliced_validity: &Validity,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        let byte_width = Self::byte_width(dtype);
        let slice_n_rows = self.slice_stop - self.slice_start;
        // Map the row-space slice bounds into value-space (valid-only) bounds.
        let slice_value_indices = unsliced_validity
            .execute_mask(self.unsliced_n_rows, ctx)?
            .valid_counts_for_indices(&[self.slice_start, self.slice_stop]);
        let slice_value_idx_start = slice_value_indices[0];
        let slice_value_idx_stop = slice_value_indices[1];
        // Select the frames overlapping the requested value range.
        let mut frames_to_decompress = vec![];
        let mut value_idx_start = 0;
        let mut uncompressed_size_to_decompress = 0;
        let mut n_skipped_values = 0;
        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
            if value_idx_start >= slice_value_idx_stop {
                break;
            }
            let frame_uncompressed_size = usize::try_from(frame_meta.uncompressed_size)
                .vortex_expect("Uncompressed size must fit in usize");
            // n_values == 0 falls back to deriving the count from the byte
            // size — only meaningful for fixed-width values. NOTE(review):
            // presumably a compatibility path for metadata written without
            // n_values; confirm against older writers.
            let frame_n_values = if frame_meta.n_values == 0 {
                frame_uncompressed_size / byte_width
            } else {
                usize::try_from(frame_meta.n_values).vortex_expect("frame size must fit usize")
            };
            let value_idx_stop = value_idx_start + frame_n_values;
            if value_idx_stop > slice_value_idx_start {
                frames_to_decompress.push(frame);
                uncompressed_size_to_decompress += frame_uncompressed_size;
            } else {
                n_skipped_values += frame_n_values;
            }
            value_idx_start = value_idx_stop;
        }
        let mut decompressor = if let Some(dictionary) = &self.dictionary {
            zstd::bulk::Decompressor::with_dictionary(dictionary)?
        } else {
            zstd::bulk::Decompressor::new()?
        };
        let mut decompressed = ByteBufferMut::with_capacity_aligned(
            uncompressed_size_to_decompress,
            Alignment::new(byte_width),
        );
        // SAFETY: every byte up to `uncompressed_size_to_decompress` is
        // overwritten by the decompression loop below; the length check after
        // the loop panics if that did not happen.
        unsafe {
            decompressed.set_len(uncompressed_size_to_decompress);
        }
        let mut uncompressed_start = 0;
        for frame in frames_to_decompress {
            let uncompressed_written = decompressor
                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
            uncompressed_start += uncompressed_written;
        }
        if uncompressed_start != uncompressed_size_to_decompress {
            vortex_panic!(
                "Zstd metadata or frames were corrupt; expected {} bytes but decompressed {}",
                uncompressed_size_to_decompress,
                uncompressed_start
            );
        }
        let decompressed = decompressed.freeze();
        // Narrow validity to the slice, then reconcile its nullability with
        // the target dtype.
        let mut slice_validity = unsliced_validity.slice(self.slice_start..self.slice_stop)?;
        if !dtype.is_nullable() && !matches!(slice_validity, Validity::NonNullable) {
            assert!(
                matches!(slice_validity, Validity::AllValid),
                "ZSTD array expects to be non-nullable but there are nulls after decompression"
            );
            slice_validity = Validity::NonNullable;
        } else if dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) {
            slice_validity = Validity::AllValid;
        }
        match dtype {
            DType::Primitive(..) => {
                // Trim the decompressed bytes down to just the sliced values.
                let slice_values_buffer = decompressed.slice(
                    (slice_value_idx_start - n_skipped_values) * byte_width
                        ..(slice_value_idx_stop - n_skipped_values) * byte_width,
                );
                let primitive = PrimitiveArray::from_values_byte_buffer(
                    slice_values_buffer,
                    dtype.as_ptype(),
                    slice_validity,
                    slice_n_rows,
                );
                Ok(primitive.into_array())
            }
            DType::Binary(_) | DType::Utf8(_) => {
                match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() {
                    AllOr::All => {
                        // Every row valid: views map 1:1 onto rows.
                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
                        let valid_views = all_views.slice(
                            slice_value_idx_start - n_skipped_values
                                ..slice_value_idx_stop - n_skipped_values,
                        );
                        // SAFETY: views were just built from `buffers`, so
                        // every view is in-bounds by construction.
                        Ok(unsafe {
                            VarBinViewArray::new_unchecked(
                                valid_views,
                                Arc::from(buffers),
                                dtype.clone(),
                                slice_validity,
                            )
                        }
                        .into_array())
                    }
                    // No valid rows: the whole slice is null.
                    AllOr::None => Ok(ConstantArray::new(
                        Scalar::null(dtype.clone()),
                        slice_n_rows,
                    )
                    .into_array()),
                    AllOr::Some(valid_indices) => {
                        let (buffers, all_views) = reconstruct_views(&decompressed, MAX_BUFFER_LEN);
                        let valid_views = all_views.slice(
                            slice_value_idx_start - n_skipped_values
                                ..slice_value_idx_stop - n_skipped_values,
                        );
                        // Scatter valid views back to their row positions;
                        // null rows keep a zeroed view.
                        let mut views = BufferMut::<BinaryView>::zeroed(slice_n_rows);
                        for (view, index) in valid_views.into_iter().zip_eq(valid_indices) {
                            views[*index] = view
                        }
                        // SAFETY: views were just built from `buffers`, so
                        // every view is in-bounds by construction.
                        Ok(unsafe {
                            VarBinViewArray::new_unchecked(
                                views.freeze(),
                                Arc::from(buffers),
                                dtype.clone(),
                                slice_validity,
                            )
                        }
                        .into_array())
                    }
                }
            }
            _ => vortex_panic!("Unsupported dtype for Zstd array: {}", dtype),
        }
    }

    /// Length of the current slice window, in rows.
    #[inline]
    pub fn len(&self) -> usize {
        self.slice_stop - self.slice_start
    }

    /// True when the current slice window is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.slice_stop == self.slice_start
    }

    /// Decomposes into [`ZstdDataParts`], pairing the data with `validity`.
    pub fn into_parts(self, validity: Validity) -> ZstdDataParts {
        ZstdDataParts {
            dictionary: self.dictionary,
            frames: self.frames,
            metadata: self.metadata,
            validity,
            n_rows: self.unsliced_n_rows,
            slice_start: self.slice_start,
            slice_stop: self.slice_stop,
        }
    }

    pub(crate) fn slice_start(&self) -> usize {
        self.slice_start
    }

    pub(crate) fn slice_stop(&self) -> usize {
        self.slice_stop
    }

    pub(crate) fn unsliced_n_rows(&self) -> usize {
        self.unsliced_n_rows
    }
}
impl ValidityVTable<Zstd> for Zstd {
    /// Reconstructs validity from the validity child slot and narrows it to
    /// the array's current slice window.
    fn validity(array: ArrayView<'_, Zstd>) -> VortexResult<Validity> {
        let nullability = array.dtype().nullability();
        let unsliced = child_to_validity(&array.slots()[0], nullability);
        unsliced.slice(array.slice_start()..array.slice_stop())
    }
}
impl OperationsVTable<Zstd> for Zstd {
    /// Materializes a single element by decompressing a one-row slice and
    /// reading its first scalar.
    fn scalar_at(
        array: ArrayView<'_, Zstd>,
        index: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Scalar> {
        let validity = child_to_validity(&array.slots()[0], array.dtype().nullability());
        let one_row = array.data().with_slice(index, index + 1);
        let decompressed = one_row.decompress(array.dtype(), &validity, ctx)?;
        decompressed.execute_scalar(0, ctx)
    }
}
#[cfg(test)]
#[expect(clippy::cast_possible_truncation)]
mod tests {
    use vortex_buffer::ByteBuffer;
    use super::reconstruct_views;
    use crate::array::BinaryView;

    /// Builds the `[u32 le length][bytes]` interleaved layout that
    /// `reconstruct_views` consumes.
    fn make_interleaved(strings: &[&[u8]]) -> ByteBuffer {
        let mut bytes = Vec::new();
        for value in strings {
            bytes.extend_from_slice(&(value.len() as u32).to_le_bytes());
            bytes.extend_from_slice(value);
        }
        ByteBuffer::copy_from(bytes.as_slice())
    }

    #[test]
    fn test_reconstruct_views_no_split() {
        let inputs: &[&[u8]] = &[b"hello", b"world"];
        let interleaved = make_interleaved(inputs);
        let (buffers, views) = reconstruct_views(&interleaved, 1024);
        // Everything fits in one segment; offsets skip the 4-byte length prefix.
        assert_eq!(buffers.len(), 1);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(b"hello", 0, 4));
        assert_eq!(views[1], BinaryView::make_view(b"world", 0, 13));
    }

    #[test]
    fn test_reconstruct_views_split_across_segments() {
        let inputs: &[&[u8]] = &[b"aaaaaaaaaaaaa", b"bbbbbbbbbbbbb"];
        let interleaved = make_interleaved(inputs);
        let (buffers, views) = reconstruct_views(&interleaved, 20);
        // The second record would overflow the 20-byte cap, forcing a second
        // segment; each view then starts right after its own length prefix.
        assert_eq!(buffers.len(), 2);
        assert_eq!(views.len(), 2);
        assert_eq!(views[0], BinaryView::make_view(b"aaaaaaaaaaaaa", 0, 4));
        assert_eq!(views[1], BinaryView::make_view(b"bbbbbbbbbbbbb", 1, 4));
    }
}
}