vortex-tensor 0.68.0

Vortex tensor extension type
Documentation
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! TurboQuant decoding (dequantization) logic.
//!
//! Decompression produces unit-norm vectors. The original magnitudes are restored externally
//! by the [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm) ScalarFnArray wrapper.

use num_traits::Float;
use num_traits::FromPrimitive;
use vortex_array::Array;
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::ExtensionArray;
use vortex_array::arrays::FixedSizeListArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
use vortex_array::dtype::NativePType;
use vortex_array::dtype::Nullability;
use vortex_array::match_each_float_ptype;
use vortex_array::validity::Validity;
use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;

use crate::encodings::turboquant::TurboQuant;
use crate::encodings::turboquant::TurboQuantArrayExt;
use crate::encodings::turboquant::array::rotation::RotationMatrix;
use crate::encodings::turboquant::compute::float_from_f32;
use crate::vector::AnyVector;

/// Decode a `TurboQuantArray` back into a [`Vector`] extension array of unit-norm rows.
///
/// The result is an [`ExtensionArray`] carrying the input's (non-nullable) Vector dtype, whose
/// storage is a `FixedSizeListArray` of the original vector element type. Every row has unit L2
/// norm: magnitudes are not restored here but by the
/// [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm) ScalarFnArray wrapper.
///
/// [`Vector`]: crate::vector::Vector
pub fn execute_decompress(
    array: Array<TurboQuant>,
    ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
    let list_size = array.dimension();
    let dim = list_size as usize;
    let padded_dim = array.padded_dim() as usize;
    let row_count = array.len();
    let ext_dtype = array.dtype().as_extension().clone();
    let element_ptype = ext_dtype.metadata::<AnyVector>().element_ptype();

    // Nothing to decode: skip executing any children and emit an empty Vector array.
    if row_count == 0 {
        match_each_float_ptype!(element_ptype, |T| {
            let empty = PrimitiveArray::empty::<T>(Nullability::NonNullable).into_array();
            let storage =
                FixedSizeListArray::try_new(empty, list_size, Validity::NonNullable, 0)?;
            return Ok(ExtensionArray::new(ext_dtype, storage.into_array()).into_array());
        })
    }

    // The centroid codebook is always stored as f32.
    let centroid_values = array.centroids().clone().execute::<PrimitiveArray>(ctx)?;

    // Rotation signs are a FixedSizeListArray of bitpacked u8 values. Executing its flat
    // elements lets FastLanes unpack the 1-bit values into u8 0/1; RotationMatrix then expands
    // them once into u32 XOR masks so the per-row rotation can apply signs without branching.
    let num_rounds = array.num_rounds() as usize;
    let sign_bits = array
        .rotation_signs()
        .clone()
        .execute::<FixedSizeListArray>(ctx)?
        .elements()
        .clone()
        .execute::<PrimitiveArray>(ctx)?;
    let rotation = RotationMatrix::from_u8_slice(sign_bits.as_slice::<u8>(), dim, num_rounds)?;

    // Codes are a FixedSizeListArray; decoding works on its flat u8 elements.
    let code_values = array
        .codes()
        .clone()
        .execute::<FixedSizeListArray>(ctx)?
        .elements()
        .clone()
        .execute::<PrimitiveArray>(ctx)?;

    let centroids = centroid_values.as_slice::<f32>();
    let codes = code_values.as_slice::<u8>();

    // Dequantize and inverse-rotate in f32, then store as the Vector's element type. Norm
    // scaling is deliberately absent; the L2Denorm wrapper applies it.
    match_each_float_ptype!(element_ptype, |T| {
        let elements =
            decompress_typed::<T>(centroids, &rotation, codes, dim, padded_dim, row_count)?;
        let storage = FixedSizeListArray::try_new(
            elements.into_array(),
            list_size,
            Validity::NonNullable,
            row_count,
        )?;
        Ok(ExtensionArray::new(ext_dtype, storage.into_array()).into_array())
    })
}

/// Typed decompress: dequantizes in f32 and produces unit-norm output as `T`.
fn decompress_typed<T: NativePType + Float + FromPrimitive>(
    centroids: &[f32],
    rotation: &RotationMatrix,
    indices: &[u8],
    dim: usize,
    padded_dim: usize,
    num_rows: usize,
) -> VortexResult<PrimitiveArray> {
    let mut output = BufferMut::<T>::with_capacity(num_rows * dim);
    let mut dequantized = vec![0.0f32; padded_dim];
    let mut unrotated = vec![0.0f32; padded_dim];

    for row in 0..num_rows {
        let row_indices = &indices[row * padded_dim..(row + 1) * padded_dim];

        for idx in 0..padded_dim {
            dequantized[idx] = centroids[row_indices[idx] as usize];
        }

        rotation.inverse_rotate(&dequantized, &mut unrotated);

        for idx in 0..dim {
            output.push(float_from_f32::<T>(unrotated[idx]));
        }
    }

    Ok(PrimitiveArray::new::<T>(
        output.freeze(),
        Validity::NonNullable,
    ))
}