use vortex_buffer::BitBufferMut;
use vortex_buffer::BufferMut;
use vortex_buffer::ByteBufferMut;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_panic;
use vortex_mask::Mask;
use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::VarBinArray;
use crate::arrays::VarBinVTable;
use crate::arrays::dict::TakeExecute;
use crate::dtype::DType;
use crate::dtype::IntegerPType;
use crate::executor::ExecutionCtx;
use crate::match_each_integer_ptype;
use crate::validity::Validity;
impl TakeExecute for VarBinVTable {
    /// Gathers the elements of `array` found at the positions listed in `indices`.
    ///
    /// The offsets child and `indices` are both executed into [`PrimitiveArray`]s before any
    /// bytes are copied. The result dtype is the dtype of `array` with its nullability unioned
    /// with the nullability of `indices`.
    ///
    /// The output offsets keep the signedness of the source offsets and are emitted as 32-bit
    /// values unless the source offsets are 64-bit, in which case they stay 64-bit.
    ///
    /// On success this always returns `Ok(Some(_))`.
    ///
    /// # Panics
    ///
    /// Reaches `unreachable!` if the offsets do not have an integer ptype.
    fn take(
        array: &VarBinArray,
        indices: &ArrayRef,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        let offsets = array.offsets().to_array().execute::<PrimitiveArray>(ctx)?;
        let data = array.bytes();
        let indices = indices.to_array().execute::<PrimitiveArray>(ctx)?;

        let index_nullability = indices.dtype().nullability();
        let result_dtype = array.dtype().clone().union_nullability(index_nullability);

        let source_validity = array.validity_mask()?;
        let index_validity = indices.validity_mask()?;

        let taken = match_each_integer_ptype!(indices.ptype(), |I| {
            // Shared by every arm below; exactly one arm runs, so moving the owned arguments
            // (dtype and the two masks) into it is fine.
            let positions = indices.as_slice::<I>();
            let raw_bytes = data.as_slice();

            match offsets.ptype() {
                // Unsigned offsets up to 32 bits wide produce `u32` output offsets.
                PType::U8 => take::<I, u8, u32>(
                    result_dtype,
                    offsets.as_slice::<u8>(),
                    raw_bytes,
                    positions,
                    source_validity,
                    index_validity,
                ),
                PType::U16 => take::<I, u16, u32>(
                    result_dtype,
                    offsets.as_slice::<u16>(),
                    raw_bytes,
                    positions,
                    source_validity,
                    index_validity,
                ),
                PType::U32 => take::<I, u32, u32>(
                    result_dtype,
                    offsets.as_slice::<u32>(),
                    raw_bytes,
                    positions,
                    source_validity,
                    index_validity,
                ),
                // Signed offsets up to 32 bits wide produce `i32` output offsets.
                PType::I8 => take::<I, i8, i32>(
                    result_dtype,
                    offsets.as_slice::<i8>(),
                    raw_bytes,
                    positions,
                    source_validity,
                    index_validity,
                ),
                PType::I16 => take::<I, i16, i32>(
                    result_dtype,
                    offsets.as_slice::<i16>(),
                    raw_bytes,
                    positions,
                    source_validity,
                    index_validity,
                ),
                PType::I32 => take::<I, i32, i32>(
                    result_dtype,
                    offsets.as_slice::<i32>(),
                    raw_bytes,
                    positions,
                    source_validity,
                    index_validity,
                ),
                // 64-bit offsets keep their width.
                PType::U64 => take::<I, u64, u64>(
                    result_dtype,
                    offsets.as_slice::<u64>(),
                    raw_bytes,
                    positions,
                    source_validity,
                    index_validity,
                ),
                PType::I64 => take::<I, i64, i64>(
                    result_dtype,
                    offsets.as_slice::<i64>(),
                    raw_bytes,
                    positions,
                    source_validity,
                    index_validity,
                ),
                _ => unreachable!("invalid PType for offsets"),
            }
        });

        taken.map(|varbin| Some(varbin.into_array()))
    }
}
/// Gathers variable-length byte ranges out of an (`offsets`, `data`) pair.
///
/// Element `i` of the source occupies `data[offsets[i]..offsets[i + 1]]`. For every entry of
/// `indices` the corresponding range is appended to a fresh byte buffer, and a new offsets
/// buffer of type `NewOffset` (with `indices.len() + 1` entries, starting at zero) is built
/// alongside it.
///
/// If either `validity_mask` (source rows) or `indices_validity_mask` (index rows) is not
/// all-true, the work is delegated to [`take_nullable`]. Otherwise the result validity is
/// derived from the nullability of `dtype` alone.
///
/// # Panics
///
/// * if an index cannot be converted to `usize` or is out of bounds for `offsets`;
/// * if an element length or offset cannot be converted to `usize`;
/// * if the accumulated output length does not fit in `usize` or in `NewOffset`. The running
///   total is checked on every step so that a wrapped offset can never be stored in the result.
fn take<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
    dtype: DType,
    offsets: &[Offset],
    data: &[u8],
    indices: &[Index],
    validity_mask: Mask,
    indices_validity_mask: Mask,
) -> VortexResult<VarBinArray> {
    if !validity_mask.all_true() || !indices_validity_mask.all_true() {
        return Ok(take_nullable::<Index, Offset, NewOffset>(
            dtype,
            offsets,
            data,
            indices,
            validity_mask,
            indices_validity_mask,
        ));
    }

    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
    new_offsets.push(NewOffset::zero());

    // First pass: compute the output offsets. The running length lives in `usize` and is
    // converted with a checked cast for each element. A plain `+=` on `NewOffset` wraps around
    // silently when overflow checks are disabled, which would hand corrupt offsets to
    // `new_unchecked` below.
    let mut total_len: usize = 0;
    for &idx in indices {
        let idx = idx
            .to_usize()
            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
        let start = offsets[idx];
        let stop = offsets[idx + 1];
        let len = (stop - start)
            .to_usize()
            .vortex_expect("Failed to cast element length to usize");
        total_len = total_len
            .checked_add(len)
            .vortex_expect("offset type overflow");
        new_offsets.push(NewOffset::from(total_len).vortex_expect("offset type overflow"));
    }

    // Second pass: copy the selected byte ranges, now that the exact output size is known.
    let mut new_data = ByteBufferMut::with_capacity(total_len);
    for idx in indices {
        let idx = idx
            .to_usize()
            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
        let start = offsets[idx]
            .to_usize()
            .vortex_expect("Failed to cast max offset to usize");
        let stop = offsets[idx + 1]
            .to_usize()
            .vortex_expect("Failed to cast max offset to usize");
        new_data.extend_from_slice(&data[start..stop]);
    }

    // No nulls can occur on this path, so validity follows directly from the dtype.
    let array_validity = Validity::from(dtype.nullability());

    // SAFETY: `new_unchecked` skips validation. The offsets built above start at zero, are
    // non-decreasing, have `indices.len() + 1` entries, and the last one equals the number of
    // bytes written to `new_data`. NOTE(review): confirm this covers every invariant that
    // `VarBinArray::new_unchecked` documents.
    unsafe {
        Ok(VarBinArray::new_unchecked(
            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
            new_data.freeze(),
            dtype,
            array_validity,
        ))
    }
}
/// Null-aware variant of the gather over an (`offsets`, `data`) pair.
///
/// Output row `i` is valid only when `indices_validity` is set at `i` *and* `data_validity` is
/// set at the source row that `indices[i]` points to. Valid rows copy
/// `data[offsets[row]..offsets[row + 1]]`; invalid rows contribute no bytes and simply repeat
/// the previous output offset. The value stored at an invalid index position is never read.
///
/// The returned array carries an explicit validity bitmap with one bit per entry of `indices`.
///
/// # Panics
///
/// * if a valid index cannot be converted to `usize` or is out of bounds;
/// * if an element length or offset cannot be converted to `usize`;
/// * if the accumulated output length does not fit in `usize` or in `NewOffset`. The running
///   total is checked on every step so that a wrapped offset can never be stored in the result.
fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPType>(
    dtype: DType,
    offsets: &[Offset],
    data: &[u8],
    indices: &[Index],
    data_validity: Mask,
    indices_validity: Mask,
) -> VarBinArray {
    let mut new_offsets = BufferMut::<NewOffset>::with_capacity(indices.len() + 1);
    new_offsets.push(NewOffset::zero());

    // The running length is tracked in `usize`; `current_offset` is always its checked
    // conversion to `NewOffset`. A plain `+=` on `NewOffset` wraps around silently when
    // overflow checks are disabled, which would hand corrupt offsets to `new_unchecked` below.
    let mut total_len: usize = 0;
    let mut current_offset = NewOffset::zero();

    let mut validity_buffer = BitBufferMut::with_capacity(indices.len());
    // Source rows of the valid output rows, in output order, for the copy pass.
    let mut valid_indices = Vec::with_capacity(indices.len());

    for (position, data_idx) in indices.iter().enumerate() {
        // Resolve the source row, or `None` when either the index or the row it points to is
        // null. A null index is deliberately not converted: its stored value is meaningless.
        let source_row = if indices_validity.value(position) {
            let row = data_idx
                .to_usize()
                .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
            if data_validity.value(row) {
                Some(row)
            } else {
                None
            }
        } else {
            None
        };

        match source_row {
            Some(row) => {
                validity_buffer.append(true);
                let start = offsets[row];
                let stop = offsets[row + 1];
                let len = (stop - start)
                    .to_usize()
                    .vortex_expect("Failed to cast element length to usize");
                total_len = total_len
                    .checked_add(len)
                    .vortex_expect("offset type overflow");
                current_offset = NewOffset::from(total_len).vortex_expect("offset type overflow");
                valid_indices.push(row);
            }
            None => validity_buffer.append(false),
        }
        // Null rows repeat the previous offset, i.e. they are zero-length.
        new_offsets.push(current_offset);
    }

    // Copy pass: only valid rows contribute bytes.
    let mut new_data = ByteBufferMut::with_capacity(total_len);
    for data_idx in valid_indices {
        let start = offsets[data_idx]
            .to_usize()
            .vortex_expect("Failed to cast max offset to usize");
        let stop = offsets[data_idx + 1]
            .to_usize()
            .vortex_expect("Failed to cast max offset to usize");
        new_data.extend_from_slice(&data[start..stop]);
    }

    let array_validity = Validity::from(validity_buffer.freeze());

    // SAFETY: `new_unchecked` skips validation. The offsets built above start at zero, are
    // non-decreasing, have `indices.len() + 1` entries, and the last one equals the number of
    // bytes written to `new_data`; the validity bitmap has one bit per output row.
    // NOTE(review): confirm this covers every invariant that `VarBinArray::new_unchecked`
    // documents.
    unsafe {
        VarBinArray::new_unchecked(
            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
            new_data.freeze(),
            dtype,
            array_validity,
        )
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;

    use crate::DynArray;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// The nullability of the indices must be reflected in the dtype of the taken array, even
    /// when no index is actually null.
    #[test]
    fn test_null_take() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        // Non-nullable indices leave the dtype non-nullable.
        let plain_indices: PrimitiveArray = (0..1).collect();
        let taken = strings.take(plain_indices.into_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::NonNullable));

        // Nullable indices make the result nullable.
        let nullable_indices: PrimitiveArray = PrimitiveArray::from_option_iter(vec![Some(0)]);
        let taken = strings.take(nullable_indices.into_array()).unwrap();
        assert_eq!(taken.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    /// Runs the shared take conformance suite over a few representative `VarBinArray`s.
    #[rstest]
    #[case(VarBinArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(
        [Some("hello"), None, Some("test"), Some("data"), None],
        DType::Utf8(Nullability::Nullable),
    ))]
    #[case(VarBinArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(Nullability::NonNullable),
    ))]
    #[case(VarBinArray::from_iter(["single"].map(Some), DType::Utf8(Nullability::NonNullable)))]
    fn test_take_varbin_conformance(#[case] array: VarBinArray) {
        let input = array.into_array();
        test_take_conformance(&input);
    }

    /// A single 128-byte element stored with `u8` offsets is taken three times. The 384 output
    /// bytes cannot be addressed with `u8` offsets, so the take must widen the offsets.
    #[test]
    fn test_take_overflow() {
        let scream = "a".repeat(128);

        let array = VarBinArray::new(
            buffer![0u8, 128u8].into_array(),
            ByteBuffer::copy_from(scream.as_bytes()),
            DType::Utf8(Nullability::NonNullable),
            Validity::NonNullable,
        );

        let indices = buffer![0u32; 3].into_array();
        let taken = array.take(indices.to_array()).unwrap();

        let expected = VarBinViewArray::from_iter(
            vec![Some(scream); 3],
            DType::Utf8(Nullability::NonNullable),
        );
        assert_arrays_eq!(expected, taken);
    }
}