use std::ops::Add;
use num_traits::CheckedAdd;
use num_traits::CheckedSub;
use vortex_array::ArrayRef;
use vortex_array::ArrayView;
use vortex_array::IntoArray;
use vortex_array::LEGACY_SESSION;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::Primitive;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::dtype::NativePType;
use vortex_array::dtype::Nullability;
use vortex_array::match_each_integer_ptype;
use vortex_array::match_each_native_ptype;
use vortex_array::scalar::PValue;
use vortex_array::validity::Validity;
use vortex_buffer::BufferMut;
use vortex_buffer::trusted_len::TrustedLen;
use vortex_error::VortexResult;
use crate::Sequence;
use crate::SequenceArray;
use crate::SequenceData;
/// Iterator yielding `remaining` values, starting at `acc` and advancing by `step`
/// between consecutive values.
///
/// The accumulator is only advanced while at least one more value is still to be
/// yielded. This means the iterator never computes the value one step past the
/// last element, which could otherwise overflow when the final value sits at the
/// edge of `T`'s range.
struct SequenceIter<T> {
    /// Value returned by the next call to `next`.
    acc: T,
    /// Increment added to `acc` between consecutive values.
    step: T,
    /// Number of values still to be yielded.
    remaining: usize,
}

impl<T: Copy + Add<Output = T>> Iterator for SequenceIter<T> {
    type Item = T;

    #[inline]
    fn next(&mut self) -> Option<T> {
        // `checked_sub` yields `None` exactly when the iterator is exhausted.
        let left_after = self.remaining.checked_sub(1)?;
        let current = self.acc;
        self.remaining = left_after;
        // Skip the addition for the final element so we never step past the end.
        if left_after != 0 {
            self.acc = current + self.step;
        }
        Some(current)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let exact = self.remaining;
        (exact, Some(exact))
    }
}
// SAFETY: `size_hint` always reports `(remaining, Some(remaining))`, and `next`
// yields exactly that many further items. Each call returns `Some` and decrements
// `remaining` by one until it reaches zero, after which every call returns `None`.
// The reported length is therefore exact, which (assuming `TrustedLen` mirrors
// std's contract) is what the trait requires.
unsafe impl<T: Copy + Add<Output = T>> TrustedLen for SequenceIter<T> {}
/// Materializes a [`SequenceArray`] into a primitive array.
///
/// The output starts at the array's `base` and each subsequent element is produced
/// by adding `multiplier` to the previous one, in the array's native primitive type.
/// The output validity is built purely from the array's nullability via
/// `Validity::from`; no per-element validity is consulted.
///
/// # Errors
///
/// Returns an error if the stored base or multiplier scalar cannot be cast to the
/// array's primitive type.
#[inline]
pub fn sequence_decompress(array: &SequenceArray) -> VortexResult<ArrayRef> {
    /// Fills a buffer with `count` values starting at `start`, stepping by `step`.
    fn materialize<P: NativePType>(
        start: P,
        step: P,
        count: usize,
        nullability: Nullability,
    ) -> PrimitiveArray {
        let iter = SequenceIter {
            acc: start,
            step,
            remaining: count,
        };
        let buffer = BufferMut::from_trusted_len_iter(iter);
        PrimitiveArray::new(buffer, Validity::from(nullability))
    }

    let len = array.len();
    let nullability = array.dtype().nullability();
    let decoded = match_each_native_ptype!(array.ptype(), |P| {
        materialize(
            array.base().cast::<P>()?,
            array.multiplier().cast::<P>()?,
            len,
            nullability,
        )
    });
    Ok(decoded.into_array())
}
/// Attempts to encode a primitive array as an arithmetic sequence.
///
/// Returns `Ok(None)` when the array is not a candidate for this encoding:
/// - the array is empty;
/// - the array contains at least one invalid (null) element;
/// - the element type is a floating-point type.
///
/// Otherwise the integer values are handed to `encode_primitive_array`. That
/// function may also return `Ok(None)` if the values do not form an exact
/// progression.
///
/// # Errors
///
/// Propagates errors from the validity check and from building the encoded array.
pub fn sequence_encode(
    primitive_array: ArrayView<'_, Primitive>,
) -> VortexResult<Option<ArrayRef>> {
    if primitive_array.is_empty() {
        return Ok(None);
    }

    // The decoded form carries only nullability, not per-element validity,
    // so any null rules the array out.
    let mut ctx = LEGACY_SESSION.create_execution_ctx();
    let every_value_valid = primitive_array.array().all_valid(&mut ctx)?;
    if !every_value_valid {
        return Ok(None);
    }

    let ptype = primitive_array.ptype();
    if ptype.is_float() {
        return Ok(None);
    }

    let nullability = primitive_array.dtype().nullability();
    match_each_integer_ptype!(ptype, |P| {
        encode_primitive_array(primitive_array.as_slice::<P>(), nullability)
    })
}
/// Tries to encode `slice` as an arithmetic progression starting at its first element.
///
/// The step (`multiplier`) is taken from the difference of the first two elements.
/// Every adjacent pair is then verified with checked arithmetic.
///
/// Returns `Ok(None)` (leaving the data to other encodings) when:
/// - `slice` is empty;
/// - the difference of the first two elements overflows `P`;
/// - that difference is zero (the leading values repeat);
/// - `SequenceData::try_last` rejects the implied final value;
/// - any adjacent pair does not differ by exactly `multiplier`.
///
/// A single-element slice is encoded with a multiplier of zero.
///
/// # Errors
///
/// Propagates any error returned by `Sequence::try_new_typed` for an accepted slice.
fn encode_primitive_array<P: NativePType + Into<PValue> + CheckedAdd + CheckedSub>(
    slice: &[P],
    nullability: Nullability,
) -> VortexResult<Option<ArrayRef>> {
    let (base, second) = match slice {
        // Callers already reject empty input; nothing to encode either way.
        [] => return Ok(None),
        [only] => {
            return Sequence::try_new_typed(*only, P::zero(), nullability, 1)
                .map(|a| Some(a.into_array()));
        }
        [first, second, ..] => (*first, *second),
    };

    let Some(multiplier) = second.checked_sub(&base) else {
        return Ok(None);
    };
    // A zero step describes a constant run, which this encoder declines.
    if multiplier == P::zero() {
        return Ok(None);
    }
    // Cheap O(1) bail-out before scanning: an error here means the implied last
    // value of the progression cannot be produced for this length and type.
    if SequenceData::try_last(base.into(), multiplier.into(), P::PTYPE, slice.len()).is_err() {
        return Ok(None);
    }

    let is_sequence = slice
        .windows(2)
        .all(|pair| pair[0].checked_add(&multiplier) == Some(pair[1]));
    if !is_sequence {
        return Ok(None);
    }

    // Only build the encoded array once validation has succeeded.
    Sequence::try_new_typed(base, multiplier, nullability, slice.len())
        .map(|a| Some(a.into_array()))
}
#[cfg(test)]
mod tests {
    // NOTE(review): imported but unused; the `expect` marks this as intentional.
    // Confirm it is still needed (e.g. to keep a dev-dependency referenced).
    #[expect(unused_imports)]
    use itertools::Itertools;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;

    use crate::sequence_encode;

    #[test]
    fn test_encode_array_success() {
        let primitive_array = PrimitiveArray::from_iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        let encoded = sequence_encode(primitive_array.as_view()).unwrap();
        assert!(encoded.is_some());
        let decoded = encoded.unwrap().to_primitive();
        assert_arrays_eq!(decoded, primitive_array);
    }

    #[test]
    fn test_encode_array_1_success() {
        let primitive_array = PrimitiveArray::from_iter([0]);
        let encoded = sequence_encode(primitive_array.as_view()).unwrap();
        assert!(encoded.is_some());
        let decoded = encoded.unwrap().to_primitive();
        assert_arrays_eq!(decoded, primitive_array);
    }

    #[test]
    fn test_encode_array_fail() {
        let primitive_array = PrimitiveArray::from_iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0]);
        let encoded = sequence_encode(primitive_array.as_view()).unwrap();
        assert!(encoded.is_none());
    }

    #[test]
    fn test_encode_array_fail_oob() {
        // Despite the name, this constant array is rejected by the zero-multiplier
        // guard, not by an overflow check. The overflow paths are covered by
        // `test_encode_array_fail_step_overflow` and
        // `test_encode_array_fail_last_overflow`.
        let primitive_array = PrimitiveArray::from_iter(vec![100i8; 1000]);
        let encoded = sequence_encode(primitive_array.as_view()).unwrap();
        assert!(encoded.is_none());
    }

    #[test]
    fn test_encode_array_fail_step_overflow() {
        // 100 - (-100) = 200 does not fit in i8, so no multiplier can be derived.
        let primitive_array = PrimitiveArray::from_iter([-100i8, 100]);
        let encoded = sequence_encode(primitive_array.as_view()).unwrap();
        assert!(encoded.is_none());
    }

    #[test]
    fn test_encode_array_fail_last_overflow() {
        // A step of 20 from 100 would reach 140 at index 2, which overflows i8.
        let primitive_array = PrimitiveArray::from_iter([100i8, 120, 0]);
        let encoded = sequence_encode(primitive_array.as_view()).unwrap();
        assert!(encoded.is_none());
    }

    #[test]
    fn test_encode_float_array_is_rejected() {
        let primitive_array = PrimitiveArray::from_iter([0.0f64, 1.0, 2.0]);
        let encoded = sequence_encode(primitive_array.as_view()).unwrap();
        assert!(encoded.is_none());
    }

    #[test]
    fn test_encode_empty_array_is_rejected() {
        let primitive_array = PrimitiveArray::from_iter(Vec::<i32>::new());
        let encoded = sequence_encode(primitive_array.as_view()).unwrap();
        assert!(encoded.is_none());
    }

    #[test]
    fn test_encode_all_u8_values() {
        let primitive_array = PrimitiveArray::from_iter(0u8..=255);
        let encoded = sequence_encode(primitive_array.as_view()).unwrap();
        assert!(encoded.is_some());
        let decoded = encoded.unwrap().to_primitive();
        assert_arrays_eq!(decoded, primitive_array);
    }
}