use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::Extension;
use vortex_array::arrays::ExtensionArray;
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::dtype::DType;
use vortex_array::dtype::NativePType;
use vortex_array::dtype::Nullability;
use vortex_array::dtype::extension::ExtDType;
use vortex_array::extension::EmptyMetadata;
use vortex_array::scalar::PValue;
use vortex_array::scalar::Scalar;
use vortex_array::scalar_fn::fns::operators::Operator;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use crate::encodings::turboquant::TurboQuantConfig;
use crate::encodings::turboquant::turboquant_encode_unchecked;
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
use crate::vector::Vector;
pub fn compress_turboquant(data: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
let l2_denorm = normalize_as_l2_denorm(data, ctx)?;
let normalized = l2_denorm.child_at(0).clone();
let norms = l2_denorm.child_at(1).clone();
let num_rows = l2_denorm.len();
let Some(normalized_ext) = normalized.as_opt::<Extension>() else {
vortex_bail!("normalize_as_l2_denorm must produce an Extension array child");
};
let config = TurboQuantConfig::default();
let tq = unsafe { turboquant_encode_unchecked(normalized_ext, &config, ctx) }?;
Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array())
}
/// Builds a [`Vector`] extension array that repeats `query` for `num_rows` rows.
///
/// Each element of `query` becomes a non-nullable primitive [`Scalar`]. The elements are
/// combined into one non-nullable fixed-size-list scalar, which is wrapped in a
/// [`ConstantArray`] of length `num_rows`. That constant array is the storage of the returned
/// [`ExtensionArray`].
///
/// # Errors
///
/// Returns an error if `ExtDType::<Vector>::try_new` rejects the storage dtype.
pub fn build_constant_query_vector<T: NativePType + Into<PValue>>(
    query: &[T],
    num_rows: usize,
) -> VortexResult<ArrayRef> {
    let elements: Vec<Scalar> = query
        .iter()
        .copied()
        .map(|value| Scalar::primitive(value, Nullability::NonNullable))
        .collect();

    let list_scalar = Scalar::fixed_size_list(
        DType::Primitive(T::PTYPE, Nullability::NonNullable),
        elements,
        Nullability::NonNullable,
    );
    let repeated = ConstantArray::new(list_scalar, num_rows).into_array();

    // The extension dtype is derived from the storage dtype, so build it before `repeated` is
    // moved into the extension array.
    let vector_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, repeated.dtype().clone())?;
    Ok(ExtensionArray::new(vector_dtype.erased(), repeated).into_array())
}
/// Builds the array expression `cosine_similarity(data, query) > threshold`.
///
/// `query` is broadcast to `data.len()` rows with [`build_constant_query_vector`] and combined
/// with `data` in a [`CosineSimilarity`] array. `threshold` is broadcast to the same length as
/// a non-nullable constant. The two are compared with [`Operator::Gt`], so a similarity equal
/// to `threshold` is compared with strict greater-than.
///
/// # Errors
///
/// Returns an error if the query vector cannot be built, if
/// `CosineSimilarity::try_new_array` rejects its inputs, or if the final comparison fails.
pub fn build_similarity_search_tree<T: NativePType + Into<PValue>>(
    data: ArrayRef,
    query: &[T],
    threshold: T,
) -> VortexResult<ArrayRef> {
    // Capture the length first: `data` is moved into the cosine-similarity array below.
    let row_count = data.len();

    let similarity = CosineSimilarity::try_new_array(
        data,
        build_constant_query_vector(query, row_count)?,
        row_count,
    )?
    .into_array();

    let cutoff = ConstantArray::new(
        Scalar::primitive(threshold, Nullability::NonNullable),
        row_count,
    )
    .into_array();

    similarity.binary(cutoff, Operator::Gt)
}
#[cfg(test)]
mod tests {
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::Extension;
    use vortex_array::arrays::ExtensionArray;
    use vortex_array::arrays::FixedSizeListArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::bool::BoolArrayExt;
    use vortex_array::dtype::extension::ExtDType;
    use vortex_array::extension::EmptyMetadata;
    use vortex_array::session::ArraySession;
    use vortex_array::validity::Validity;
    use vortex_buffer::BufferMut;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use super::build_constant_query_vector;
    use super::build_similarity_search_tree;
    use super::compress_turboquant;
    use crate::vector::Vector;

    /// Builds a non-nullable `Vector` extension array from a flat, row-major slice of `f32`.
    ///
    /// `values.len()` must be a multiple of `dim`; each consecutive run of `dim` values becomes
    /// one row.
    fn vector_array(dim: u32, values: &[f32]) -> VortexResult<ArrayRef> {
        let width = dim as usize;
        assert_eq!(values.len() % width, 0);
        let row_count = values.len() / width;

        let mut flat = BufferMut::<f32>::with_capacity(values.len());
        for value in values.iter().copied() {
            flat.push(value);
        }

        let elements = PrimitiveArray::new::<f32>(flat.freeze(), Validity::NonNullable);
        let rows = FixedSizeListArray::try_new(
            elements.into_array(),
            dim,
            Validity::NonNullable,
            row_count,
        )?;

        let vector_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, rows.dtype().clone())?;
        Ok(ExtensionArray::new(vector_dtype.erased(), rows.into_array()).into_array())
    }

    /// Creates an empty session with `ArraySession` registered, used to obtain execution
    /// contexts in the tests below.
    fn test_session() -> VortexSession {
        let session = VortexSession::empty();
        session.with::<ArraySession>()
    }

    /// The broadcast query must have the requested length and be an extension array.
    #[test]
    fn constant_query_vector_has_vector_extension_dtype() -> VortexResult<()> {
        let query = [1.0f32, 0.0, 0.0, 0.0];
        let broadcast = build_constant_query_vector(&query, 5)?;

        assert_eq!(broadcast.len(), 5);
        assert!(broadcast.as_opt::<Extension>().is_some());
        Ok(())
    }

    /// Rows equal to the query pass the 0.5 threshold; orthogonal rows do not.
    #[test]
    fn similarity_search_tree_executes_to_bool_array() -> VortexResult<()> {
        #[rustfmt::skip]
        let rows: [f32; 12] = [
            1.0, 0.0, 0.0, // row 0: same as the query
            0.0, 1.0, 0.0, // row 1: orthogonal to the query
            0.0, 0.0, 1.0, // row 2: orthogonal to the query
            1.0, 0.0, 0.0, // row 3: same as the query
        ];
        let data = vector_array(3, &rows)?;
        let query = [1.0f32, 0.0, 0.0];

        let tree = build_similarity_search_tree(data, &query, 0.5)?;
        let mut ctx = test_session().create_execution_ctx();
        let matches: BoolArray = tree.execute(&mut ctx)?;
        let bits = matches.to_bit_buffer();

        let expected = [true, false, false, true];
        assert_eq!(bits.len(), 4);
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(bits.value(row), want);
        }
        Ok(())
    }

    /// After TurboQuant compression, the row identical to the query must still clear a 0.95
    /// cosine-similarity threshold.
    #[test]
    fn turboquant_roundtrip_preserves_ranking() -> VortexResult<()> {
        const DIM: u32 = 128;
        const NUM_ROWS: usize = 6;

        let query: Vec<f32> = (0..DIM as usize)
            .map(|i| ((i as f32) * 0.017).sin())
            .collect();

        let mut values = Vec::<f32>::with_capacity(NUM_ROWS * DIM as usize);
        // Row 0: the query itself.
        values.extend_from_slice(&query);
        // Row 1: the query with a small per-element offset added.
        values.extend(
            query
                .iter()
                .enumerate()
                .map(|(i, q)| q + 0.05 * ((i as f32) * 0.03).cos()),
        );
        // Rows 2..NUM_ROWS: values from a separate sine pattern that depends on the row index.
        values.extend((2..NUM_ROWS).flat_map(|row| {
            (0..DIM as usize).map(move |i| ((row as f32 * 1.3 + i as f32) * 0.07).sin())
        }));

        let data = vector_array(DIM, &values)?;
        let mut ctx = test_session().create_execution_ctx();
        let compressed = compress_turboquant(data, &mut ctx)?;
        assert_eq!(compressed.len(), NUM_ROWS);

        let tree = build_similarity_search_tree(compressed, &query, 0.95)?;
        let matches: BoolArray = tree.execute(&mut ctx)?;
        let bits = matches.to_bit_buffer();
        assert_eq!(bits.len(), NUM_ROWS);
        assert!(
            bits.value(0),
            "row 0 (identical to query) must match at threshold 0.95 even after TurboQuant"
        );
        Ok(())
    }
}