use vortex_array::ArrayRef;
use vortex_array::ArrayView;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::Extension;
use vortex_array::arrays::FixedSizeListArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::dict::DictArray;
use vortex_array::arrays::extension::ExtensionArrayExt;
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
use vortex_array::dtype::Nullability;
use vortex_array::validity::Validity;
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use crate::encodings::turboquant::MAX_BIT_WIDTH;
use crate::encodings::turboquant::MIN_DIMENSION;
use crate::encodings::turboquant::centroids::compute_centroid_boundaries;
use crate::encodings::turboquant::centroids::compute_or_get_centroids;
use crate::encodings::turboquant::centroids::find_nearest_centroid;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
use crate::scalar_fns::sorf_transform::SorfMatrix;
use crate::scalar_fns::sorf_transform::SorfOptions;
use crate::scalar_fns::sorf_transform::SorfTransform;
use crate::types::vector::AnyVector;
use crate::types::vector::Vector;
use crate::utils::cast_to_f32;
/// Configuration for TurboQuant vector quantization.
///
/// See the [`Default`] implementation for the default values.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TurboQuantConfig {
    /// Number of bits per quantized coordinate. Must lie in `1..=MAX_BIT_WIDTH`;
    /// this is checked by `turboquant_encode_unchecked`.
    pub bit_width: u8,
    /// Seed passed to the SORF rotation applied before quantization.
    pub seed: u64,
    /// Number of rounds used when building the SORF rotation.
    pub num_rounds: u8,
}
impl Default for TurboQuantConfig {
fn default() -> Self {
Self {
bit_width: MAX_BIT_WIDTH,
seed: 42,
num_rounds: 3,
}
}
}
/// Encodes a vector array with TurboQuant.
///
/// The input is first split by [`normalize_as_l2_denorm`] into a normalized
/// vector child (child 0) and a norms child (child 1). Only the normalized
/// child is quantized, via [`turboquant_encode_unchecked`]. The quantized
/// result is then re-wrapped in an [`L2Denorm`] array together with the
/// original norms.
///
/// # Errors
///
/// Propagates any error from normalization, from quantization, or from
/// building the outer `L2Denorm` array.
pub fn turboquant_encode(
    input: ArrayRef,
    config: &TurboQuantConfig,
    ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
    let denorm = normalize_as_l2_denorm(input, ctx)?;
    let row_count = denorm.len();
    let unit_vectors = denorm.child_at(0).clone();
    let row_norms = denorm.child_at(1).clone();

    let unit_ext = unit_vectors
        .as_opt::<Extension>()
        .vortex_expect("normalize_as_l2_denorm always produces an Extension array child");

    // SAFETY: `unit_ext` is the normalized child produced by
    // `normalize_as_l2_denorm` above. NOTE(review): this assumes the child meets
    // whatever precondition `turboquant_encode_unchecked` requires — confirm.
    let quantized = unsafe { turboquant_encode_unchecked(unit_ext, config, ctx) }?;

    // SAFETY: `quantized` and `row_norms` both derive from `denorm`, whose length
    // is `row_count`. NOTE(review): this relies on quantization preserving the
    // row count — confirm against `L2Denorm::new_array_unchecked`'s contract.
    let encoded = unsafe { L2Denorm::new_array_unchecked(quantized, row_norms, row_count) }?;
    Ok(encoded.into_array())
}
/// Quantizes an already-normalized vector extension array with TurboQuant.
///
/// The steps are:
///
/// 1. Each row is zero-padded to the next power of two.
/// 2. The row is rotated with a SORF matrix built from `config.seed` and
///    `config.num_rounds`.
/// 3. Every rotated coordinate is replaced by the index of its nearest
///    centroid.
///
/// The result is a [`SorfTransform`] array. It wraps a vector array whose
/// storage is a dictionary-encoded `FixedSizeListArray` of `u8` codes over
/// `f32` centroids.
///
/// # Errors
///
/// Returns an error if:
///
/// - `config.bit_width` is outside `1..=MAX_BIT_WIDTH`;
/// - the vector dimension is below `MIN_DIMENSION`;
/// - executing the storage array fails; or
/// - building any intermediate array fails.
///
/// # Safety
///
/// This function does not normalize its input. NOTE(review): the exact
/// contract is not visible here. Presumably callers must pass vectors that are
/// already unit-normalized, as produced by [`normalize_as_l2_denorm`]. Confirm
/// before relying on this.
pub unsafe fn turboquant_encode_unchecked(
    ext: ArrayView<Extension>,
    config: &TurboQuantConfig,
    ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
    // Validate the configuration before any work is done, so an invalid config
    // fails fast without materializing the storage array.
    vortex_ensure!(
        (1..=MAX_BIT_WIDTH).contains(&config.bit_width),
        "bit_width must be 1-{MAX_BIT_WIDTH}, got {}",
        config.bit_width
    );

    let ext_dtype = ext.dtype().clone();
    let storage = ext.storage_array();
    let fsl = storage.clone().execute::<FixedSizeListArray>(ctx)?;

    let dimension = fsl.list_size();
    vortex_ensure!(
        dimension >= MIN_DIMENSION,
        "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimension}",
    );

    let vector_metadata = ext_dtype.as_extension().metadata::<AnyVector>();
    // Both the empty and non-empty paths wrap their result with the same options.
    let sorf_options = SorfOptions {
        seed: config.seed,
        num_rounds: config.num_rounds,
        dimensions: dimension,
        element_ptype: vector_metadata.element_ptype(),
    };

    let num_rows = fsl.len();
    let padded_vector = if fsl.is_empty() {
        // Nothing to quantize. Build an empty dictionary-encoded list array with
        // the same padded list size that the non-empty path would produce.
        let padded_dim = dimension.next_power_of_two();
        let empty_codes = PrimitiveArray::empty::<u8>(Nullability::NonNullable);
        let empty_centroids = PrimitiveArray::empty::<f32>(Nullability::NonNullable);
        let empty_dict =
            DictArray::try_new(empty_codes.into_array(), empty_centroids.into_array())?;
        let empty_fsl = FixedSizeListArray::try_new(
            empty_dict.into_array(),
            padded_dim,
            Validity::NonNullable,
            0,
        )?;
        Vector::try_new_vector_array(empty_fsl.into_array())?
    } else {
        let core =
            turboquant_quantize_core(&fsl, config.seed, config.bit_width, config.num_rounds, ctx)?;
        let quantized_fsl =
            build_quantized_fsl(num_rows, core.all_indices, core.centroids, core.padded_dim)?;
        Vector::try_new_vector_array(quantized_fsl)?
    };

    // `num_rows` is 0 on the empty path, matching the row count of the empty array.
    Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array())
}
/// Output of `turboquant_quantize_core`: the codebook and the per-coordinate
/// codes for every row.
struct QuantizationResult {
    /// Length of each zero-padded (power-of-two) row. `all_indices` holds
    /// `padded_dim` codes per row.
    padded_dim: usize,
    /// Centroid values forming the codebook that `all_indices` refers to.
    centroids: Buffer<f32>,
    /// One centroid index per padded coordinate, stored row after row.
    all_indices: Buffer<u8>,
}
/// Rotates and scalar-quantizes every row of `fsl`.
///
/// For each row:
///
/// 1. The row's `dimension` elements, cast to `f32`, are copied into a
///    zero-padded buffer of length `dimension.next_power_of_two()`.
/// 2. The buffer is rotated with a [`SorfMatrix`] built from `seed` and
///    `num_rounds`.
/// 3. Each rotated coordinate is mapped to the index of its nearest centroid.
///
/// The centroids come from the codebook returned by `compute_or_get_centroids`
/// for the padded dimension and `bit_width`. Codes are emitted row after row,
/// `padded_dim` per row.
///
/// # Errors
///
/// Propagates errors from building the rotation, executing or casting the
/// elements, or computing the centroids.
///
/// # Panics
///
/// Panics if the padded dimension does not fit in a `u32`.
fn turboquant_quantize_core(
    fsl: &FixedSizeListArray,
    seed: u64,
    bit_width: u8,
    num_rounds: u8,
    ctx: &mut ExecutionCtx,
) -> VortexResult<QuantizationResult> {
    let dimension = fsl.list_size() as usize;
    let num_rows = fsl.len();
    let padded_dim = dimension.next_power_of_two();
    let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?;
    let padded_dim_u32 =
        u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");

    let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
    let f32_elements = cast_to_f32(elements_prim)?;

    let centroids = compute_or_get_centroids(padded_dim_u32, bit_width)?;
    // The boundaries do not change between rows, so compute them once.
    let boundaries = compute_centroid_boundaries(&centroids);

    let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
    // Scratch buffers are reused across rows. `padded[dimension..]` is never
    // written after this zero-initialization, because `rotate` only borrows
    // `padded` immutably. The padding tail therefore stays zero for every row.
    let mut padded = vec![0.0f32; padded_dim];
    let mut rotated = vec![0.0f32; padded_dim];
    let f32_slice = f32_elements.as_slice();
    for row in 0..num_rows {
        let x = &f32_slice[row * dimension..(row + 1) * dimension];
        padded[..dimension].copy_from_slice(x);
        rotation.rotate(&padded, &mut rotated);
        for &value in &rotated {
            all_indices.push(find_nearest_centroid(value, &boundaries));
        }
    }

    Ok(QuantizationResult {
        centroids,
        all_indices: all_indices.freeze(),
        padded_dim,
    })
}
/// Wraps per-row codes and a centroid codebook into a dictionary-encoded
/// `FixedSizeListArray` with `num_rows` lists of `padded_dim` elements each.
///
/// The codes, the centroids, and the resulting list array are all
/// non-nullable.
///
/// # Panics
///
/// Panics if `padded_dim` does not fit in a `u32`.
fn build_quantized_fsl(
    num_rows: usize,
    all_indices: Buffer<u8>,
    centroids: Buffer<f32>,
    padded_dim: usize,
) -> VortexResult<ArrayRef> {
    let dictionary = DictArray::try_new(
        PrimitiveArray::new::<u8>(all_indices, Validity::NonNullable).into_array(),
        PrimitiveArray::new::<f32>(centroids, Validity::NonNullable).into_array(),
    )?;

    let list_size =
        u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
    let fsl = FixedSizeListArray::try_new(
        dictionary.into_array(),
        list_size,
        Validity::NonNullable,
        num_rows,
    )?;
    Ok(fsl.into_array())
}