use std::hash::Hash;
use std::hash::Hasher;
use std::sync::Arc;
use prost::Message;
use vortex_array::Array;
use vortex_array::ArrayEq;
use vortex_array::ArrayHash;
use vortex_array::ArrayId;
use vortex_array::ArrayParts;
use vortex_array::ArrayRef;
use vortex_array::ArrayView;
use vortex_array::ExecutionCtx;
use vortex_array::ExecutionResult;
use vortex_array::Precision;
use vortex_array::buffer::BufferHandle;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::dtype::PType;
use vortex_array::serde::ArrayChildren;
use vortex_array::validity::Validity;
use vortex_array::vtable::VTable;
use vortex_array::vtable::ValidityVTable;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_ensure_eq;
use vortex_error::vortex_err;
use vortex_error::vortex_panic;
use vortex_session::VortexSession;
use crate::encodings::turboquant::TurboQuantData;
use crate::encodings::turboquant::array::slots::Slot;
use crate::encodings::turboquant::compute::rules::PARENT_KERNELS;
use crate::encodings::turboquant::compute::rules::RULES;
use crate::encodings::turboquant::metadata::TurboQuantMetadata;
use crate::encodings::turboquant::scheme::decompress::execute_decompress;
use crate::vector::AnyVector;
use crate::vector::VectorMatcherMetadata;
/// Vtable marker type for the TurboQuant array encoding.
///
/// The per-array state is carried by `TurboQuantData` (the vtable's `ArrayData`); this unit
/// struct only identifies the encoding.
#[derive(Debug, Clone)]
pub struct TurboQuant;
impl TurboQuant {
    /// Encoding identifier for TurboQuant arrays.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant");

    /// Smallest vector dimension accepted by [`TurboQuant::validate_dtype`].
    pub const MIN_DIMENSION: u32 = 128;

    /// Maximum code bit width. Codes are stored as `u8`, so this cannot exceed 8.
    pub const MAX_BIT_WIDTH: u8 = 8;

    /// Number of centroids at [`TurboQuant::MAX_BIT_WIDTH`], i.e. `2^MAX_BIT_WIDTH`.
    pub const MAX_CENTROIDS: usize = 1usize << (Self::MAX_BIT_WIDTH as usize);

    /// Checks that `dtype` is a `Vector` extension type suitable for TurboQuant and returns
    /// its vector metadata.
    ///
    /// # Errors
    ///
    /// Returns an error if `dtype` is not a `Vector` extension type, or if its dimension is
    /// smaller than [`TurboQuant::MIN_DIMENSION`].
    pub fn validate_dtype(dtype: &DType) -> VortexResult<VectorMatcherMetadata> {
        let vector_metadata = dtype
            .as_extension_opt()
            .and_then(|ext| ext.metadata_opt::<AnyVector>())
            .ok_or_else(|| {
                vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}")
            })?;
        let dimensions = vector_metadata.dimensions();
        vortex_ensure!(
            dimensions >= Self::MIN_DIMENSION,
            "TurboQuant requires dimension >= {}, got {dimensions}",
            Self::MIN_DIMENSION
        );
        Ok(vector_metadata)
    }

    /// Builds a TurboQuant array from its children after validating them against `dtype`.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `TurboQuantData::validate` for these inputs.
    pub fn try_new_array(
        dtype: DType,
        codes: ArrayRef,
        centroids: ArrayRef,
        rotation_signs: ArrayRef,
    ) -> VortexResult<TurboQuantArray> {
        TurboQuantData::validate(&dtype, &codes, &centroids, &rotation_signs)?;
        // SAFETY: the inputs were checked by `TurboQuantData::validate` just above.
        Ok(unsafe { Self::new_array_unchecked(dtype, codes, centroids, rotation_signs) })
    }

    /// Builds a TurboQuant array from its children without validating them in release builds.
    ///
    /// `bit_width` is derived from `log2(centroids.len())`, or 0 when there are no centroids.
    /// `num_rounds` is `rotation_signs.len()`. The array length is `codes.len()`.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that the arguments satisfy `TurboQuantData::validate`. That
    /// check only runs when `debug_assertions` are enabled. In particular, `dtype` must be a
    /// `Vector` extension type; otherwise this panics. `rotation_signs.len()` must also fit in
    /// a `u8`, because it is narrowed with a plain cast.
    pub unsafe fn new_array_unchecked(
        dtype: DType,
        codes: ArrayRef,
        centroids: ArrayRef,
        rotation_signs: ArrayRef,
    ) -> TurboQuantArray {
        #[cfg(debug_assertions)]
        TurboQuantData::validate(&dtype, &codes, &centroids, &rotation_signs)
            .vortex_expect("[DEBUG ASSERTION]: TurboQuantData arrays are invalid");
        let len = codes.len();
        let dimension = dtype
            .as_extension_opt()
            .vortex_expect("we validated the dtype")
            .metadata_opt::<AnyVector>()
            .vortex_expect("we validated that this is a vector")
            .dimensions();
        // An empty centroid table encodes `bit_width == 0`. Otherwise the table holds
        // `2^bit_width` entries, so the bit width is its trailing-zero count.
        let bit_width = if centroids.is_empty() {
            0
        } else {
            #[expect(
                clippy::cast_possible_truncation,
                reason = "bit_width is guaranteed <= 8"
            )]
            (centroids.len().trailing_zeros() as u8)
        };
        #[expect(
            clippy::cast_possible_truncation,
            reason = "num_rounds fits in u8 by the caller's invariants"
        )]
        let num_rounds = rotation_signs.len() as u8;
        // SAFETY: the parameters are derived from inputs the caller guarantees to be valid
        // (see this function's `# Safety` section).
        let data = unsafe { TurboQuantData::new_unchecked(dimension, bit_width, num_rounds) };
        let parts = ArrayParts::new(TurboQuant, dtype, len, data)
            .with_slots(TurboQuantData::make_slots(codes, centroids, rotation_signs));
        // SAFETY: the parts are built from caller-validated children and consistent metadata.
        unsafe { Array::from_parts_unchecked(parts) }
    }
}
/// A TurboQuant-encoded array: the generic `Array` specialized to the [`TurboQuant`] vtable.
pub type TurboQuantArray = Array<TurboQuant>;
impl VTable for TurboQuant {
    type ArrayData = TurboQuantData;
    type OperationsVTable = TurboQuant;
    type ValidityVTable = TurboQuant;

    fn id(&self) -> ArrayId {
        Self::ID
    }

    /// Checks that `data`, `dtype`, `len` and `slots` describe a consistent TurboQuant array.
    ///
    /// The slot children are validated with `TurboQuantData::validate`. The cached
    /// `dimension`, `bit_width` and `num_rounds` on `data` are then cross-checked against the
    /// dtype and slot lengths they are derived from.
    fn validate(
        &self,
        data: &Self::ArrayData,
        dtype: &DType,
        len: usize,
        slots: &[Option<ArrayRef>],
    ) -> VortexResult<()> {
        vortex_ensure_eq!(
            slots.len(),
            Slot::COUNT,
            "TurboQuantArray got incorrect amount of slots",
        );
        let codes = slots[Slot::Codes as usize]
            .as_ref()
            .ok_or_else(|| vortex_err!("TurboQuantArray missing codes slot"))?;
        let centroids = slots[Slot::Centroids as usize]
            .as_ref()
            .ok_or_else(|| vortex_err!("TurboQuantArray missing centroids slot"))?;
        let rotation_signs = slots[Slot::RotationSigns as usize]
            .as_ref()
            .ok_or_else(|| vortex_err!("TurboQuantArray missing rotation_signs slot"))?;
        vortex_ensure_eq!(
            codes.len(),
            len,
            "TurboQuant codes length does not match outer length",
        );
        TurboQuantData::validate(dtype, codes, centroids, rotation_signs)?;
        vortex_ensure_eq!(data.dimension, Self::validate_dtype(dtype)?.dimensions());
        // Mirrors the derivation in `TurboQuant::new_array_unchecked`.
        let expected_bit_width = if centroids.is_empty() {
            0
        } else {
            u8::try_from(centroids.len().trailing_zeros())
                .map_err(|_| vortex_err!("centroids bit_width does not fit in u8"))?
        };
        vortex_ensure_eq!(
            data.bit_width,
            expected_bit_width,
            "TurboQuant bit_width does not match centroids slot",
        );
        let expected_num_rounds = u8::try_from(rotation_signs.len())
            .map_err(|_| vortex_err!("rotation_signs num_rounds does not fit in u8"))?;
        vortex_ensure_eq!(
            data.num_rounds,
            expected_num_rounds,
            "TurboQuant num_rounds does not match rotation_signs slot",
        );
        Ok(())
    }

    /// TurboQuant owns no buffers; all of its data lives in slot children.
    fn nbuffers(_array: ArrayView<Self>) -> usize {
        0
    }

    /// Always panics, because TurboQuant has no buffers (see `nbuffers`).
    fn buffer(_array: ArrayView<Self>, idx: usize) -> BufferHandle {
        vortex_panic!("TurboQuantArray buffer index {idx} out of bounds")
    }

    fn buffer_name(_array: ArrayView<Self>, _idx: usize) -> Option<String> {
        None
    }

    /// Serializes `bit_width` and `num_rounds`. The dimension is recovered from the dtype on
    /// deserialization, so it is not stored.
    fn serialize(
        array: ArrayView<'_, Self>,
        _session: &VortexSession,
    ) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(
            TurboQuantMetadata::new(array.bit_width, array.num_rounds).encode_to_vec(),
        ))
    }

    /// Reconstructs a TurboQuant array from serialized metadata and children.
    ///
    /// The children are read in this order, where `padded_dim` is the vector dimension
    /// rounded up to a power of two:
    /// 0. codes: `len` rows of `FixedSizeList<u8, padded_dim>`;
    /// 1. centroids: `2^bit_width` `f32` values, or none when `bit_width == 0`;
    /// 2. rotation signs: `num_rounds` rows of `FixedSizeList<u8, padded_dim>`, or none for
    ///    empty arrays.
    ///
    /// # Errors
    ///
    /// Returns an error if the metadata is malformed, if `bit_width` or `num_rounds` is zero
    /// for a non-empty array, if `bit_width` exceeds [`TurboQuant::MAX_BIT_WIDTH`], if `dtype`
    /// is not a non-nullable vector of a supported dimension, or if a child cannot be read.
    fn deserialize(
        &self,
        dtype: &DType,
        len: usize,
        metadata: &[u8],
        _buffers: &[BufferHandle],
        children: &dyn ArrayChildren,
        _session: &VortexSession,
    ) -> VortexResult<ArrayParts<Self>> {
        let metadata = TurboQuantMetadata::decode(metadata)?;
        let bit_width = metadata.bit_width()?;
        let num_rounds = metadata.num_rounds()?;
        vortex_ensure!(
            bit_width > 0 || len == 0,
            "bit_width == 0 is only valid for empty arrays, got len={len}"
        );
        // `bit_width` comes from serialized (untrusted) bytes and is used as a shift amount
        // below. Bound it so `1usize << bit_width` cannot overflow.
        vortex_ensure!(
            bit_width <= TurboQuant::MAX_BIT_WIDTH,
            "TurboQuant bit_width must be <= {}, got {bit_width}",
            TurboQuant::MAX_BIT_WIDTH
        );
        vortex_ensure!(
            num_rounds > 0 || len == 0,
            "num_rounds == 0 is only valid for empty arrays, got len={len}"
        );
        let vector_metadata = TurboQuant::validate_dtype(dtype)?;
        let dimensions = vector_metadata.dimensions();
        vortex_ensure!(
            !dtype.is_nullable(),
            "TurboQuant dtype must be non-nullable during deserialization"
        );
        // `u32::next_power_of_two` overflows for dimensions above 2^31. Reject those instead
        // of panicking (debug builds) or wrapping (release builds).
        let padded_dim = dimensions.checked_next_power_of_two().ok_or_else(|| {
            vortex_err!("TurboQuant dimension {dimensions} is too large to pad to a power of two")
        })?;
        // Codes and rotation signs share the same non-nullable `FixedSizeList<u8, padded_dim>`
        // layout.
        let u8_rows_dtype = DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
            padded_dim,
            Nullability::NonNullable,
        );
        let codes_array = children.get(0, &u8_rows_dtype, len)?;
        let num_centroids = if bit_width == 0 {
            0
        } else {
            1usize << bit_width
        };
        let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let centroids = children.get(1, &centroids_dtype, num_centroids)?;
        let signs_len = if len == 0 { 0 } else { num_rounds as usize };
        let rotation_signs = children.get(2, &u8_rows_dtype, signs_len)?;
        Ok(ArrayParts::new(
            TurboQuant,
            dtype.clone(),
            len,
            TurboQuantData {
                dimension: dimensions,
                bit_width,
                num_rounds,
            },
        )
        .with_slots(TurboQuantData::make_slots(
            codes_array,
            centroids,
            rotation_signs,
        )))
    }

    fn slot_name(_array: ArrayView<Self>, idx: usize) -> String {
        Slot::from_index(idx).name().to_string()
    }

    /// Executes the array by fully decompressing it with `execute_decompress`.
    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
        Ok(ExecutionResult::done(execute_decompress(array, ctx)?))
    }

    /// Delegates parent execution to the registered `PARENT_KERNELS`.
    fn execute_parent(
        array: ArrayView<Self>,
        parent: &ArrayRef,
        child_idx: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
    }

    /// Delegates parent reduction to the registered `RULES`.
    fn reduce_parent(
        array: ArrayView<Self>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        RULES.evaluate(array, parent, child_idx)
    }
}
/// Validity for TurboQuant arrays: unconditionally reported as [`Validity::NonNullable`].
impl ValidityVTable<TurboQuant> for TurboQuant {
    fn validity(_view: ArrayView<'_, TurboQuant>) -> VortexResult<Validity> {
        let validity = Validity::NonNullable;
        Ok(validity)
    }
}
/// Hashes the scalar encoding parameters of a TurboQuant array. `precision` is ignored.
impl ArrayHash for TurboQuantData {
    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
        // Exhaustive destructuring: adding a field to `TurboQuantData` fails to compile here
        // until the hash is updated. Tuples hash their elements in order, with no length
        // prefix.
        let Self {
            dimension,
            bit_width,
            num_rounds,
        } = self;
        (dimension, bit_width, num_rounds).hash(state);
    }
}
/// Two TurboQuant data blocks are equal when all scalar encoding parameters match.
/// `precision` is ignored.
impl ArrayEq for TurboQuantData {
    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
        // Exhaustive destructuring keeps this in sync with `TurboQuantData`'s fields.
        let Self {
            dimension,
            bit_width,
            num_rounds,
        } = self;
        (*dimension, *bit_width, *num_rounds)
            == (other.dimension, other.bit_width, other.num_rounds)
    }
}