use num_traits::Float;
use num_traits::FromPrimitive;
use vortex_array::Array;
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::ExtensionArray;
use vortex_array::arrays::FixedSizeListArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
use vortex_array::dtype::NativePType;
use vortex_array::dtype::Nullability;
use vortex_array::match_each_float_ptype;
use vortex_array::validity::Validity;
use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;

use crate::encodings::turboquant::TurboQuant;
use crate::encodings::turboquant::TurboQuantArrayExt;
use crate::encodings::turboquant::array::rotation::RotationMatrix;
use crate::encodings::turboquant::compute::float_from_f32;
use crate::vector::AnyVector;
pub fn execute_decompress(
array: Array<TurboQuant>,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let dim = array.dimension() as usize;
let padded_dim = array.padded_dim() as usize;
let num_rows = array.len();
let ext_dtype = array.dtype().as_extension().clone();
let element_ptype = ext_dtype.metadata::<AnyVector>().element_ptype();
if num_rows == 0 {
match_each_float_ptype!(element_ptype, |T| {
let elements = PrimitiveArray::empty::<T>(Nullability::NonNullable);
let fsl = FixedSizeListArray::try_new(
elements.into_array(),
array.dimension(),
Validity::NonNullable,
0,
)?;
return Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array());
})
}
let centroids_prim = array.centroids().clone().execute::<PrimitiveArray>(ctx)?;
let centroids = centroids_prim.as_slice::<f32>();
let num_rounds = array.num_rounds() as usize;
let signs_fsl = array
.rotation_signs()
.clone()
.execute::<FixedSizeListArray>(ctx)?;
let signs_prim = signs_fsl
.elements()
.clone()
.execute::<PrimitiveArray>(ctx)?;
let rotation = RotationMatrix::from_u8_slice(signs_prim.as_slice::<u8>(), dim, num_rounds)?;
let codes_fsl = array.codes().clone().execute::<FixedSizeListArray>(ctx)?;
let codes_prim = codes_fsl
.elements()
.clone()
.execute::<PrimitiveArray>(ctx)?;
let indices = codes_prim.as_slice::<u8>();
match_each_float_ptype!(element_ptype, |T| {
decompress_typed::<T>(centroids, &rotation, indices, dim, padded_dim, num_rows).and_then(
|elements| {
let fsl = FixedSizeListArray::try_new(
elements.into_array(),
array.dimension(),
Validity::NonNullable,
num_rows,
)?;
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
},
)
})
}
fn decompress_typed<T: NativePType + Float + FromPrimitive>(
centroids: &[f32],
rotation: &RotationMatrix,
indices: &[u8],
dim: usize,
padded_dim: usize,
num_rows: usize,
) -> VortexResult<PrimitiveArray> {
let mut output = BufferMut::<T>::with_capacity(num_rows * dim);
let mut dequantized = vec![0.0f32; padded_dim];
let mut unrotated = vec![0.0f32; padded_dim];
for row in 0..num_rows {
let row_indices = &indices[row * padded_dim..(row + 1) * padded_dim];
for idx in 0..padded_dim {
dequantized[idx] = centroids[row_indices[idx] as usize];
}
rotation.inverse_rotate(&dequantized, &mut unrotated);
for idx in 0..dim {
output.push(float_from_f32::<T>(unrotated[idx]));
}
}
Ok(PrimitiveArray::new::<T>(
output.freeze(),
Validity::NonNullable,
))
}