use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::arrays::ConstantArray;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::dtype::NativePType;
use vortex_array::dtype::Nullability;
use vortex_array::scalar::PValue;
use vortex_array::scalar::Scalar;
use vortex_array::scalar_fn::fns::operators::Operator;
use vortex_error::VortexResult;
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
use crate::types::vector::Vector;
/// Builds an array expression that flags the rows of `data` whose cosine
/// similarity to `query` is strictly greater than `threshold`.
///
/// The expression is assembled from three pieces, each sized to `data.len()`:
/// a vector array produced from `query` by [`Vector::constant_array`], a
/// [`CosineSimilarity`] array over `data` and that query array, and a
/// non-nullable [`ConstantArray`] holding `threshold`. The two sides are then
/// combined with [`Operator::Gt`].
///
/// # Errors
///
/// Propagates any error returned by `Vector::constant_array`,
/// `CosineSimilarity::try_new_array`, or the final `binary` call.
pub fn build_similarity_search_tree<T: NativePType + Into<PValue>>(
    data: ArrayRef,
    query: &[T],
    threshold: T,
) -> VortexResult<ArrayRef> {
    // Captured up front because `data` is moved into the similarity array below.
    let row_count = data.len();

    // Left-hand side of the comparison: similarity of `data` against the query array.
    let query_array = Vector::constant_array(query, row_count)?;
    let similarity = CosineSimilarity::try_new_array(data, query_array)?.into_array();

    // Right-hand side: the threshold as a non-nullable constant with the same length.
    let cutoff = ConstantArray::new(
        Scalar::primitive(threshold, Nullability::NonNullable),
        row_count,
    )
    .into_array();

    // Strict `>`: a similarity exactly equal to the threshold does not match.
    similarity.binary(cutoff, Operator::Gt)
}
#[cfg(test)]
mod tests {
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::bool::BoolArrayExt;
    use vortex_error::VortexResult;

    use super::build_similarity_search_tree;
    use crate::tests::SESSION;
    use crate::utils::test_helpers::vector_array;

    /// Executes the tree built by `build_similarity_search_tree` into a
    /// `BoolArray` and checks the resulting bits against a 0.5 threshold.
    #[test]
    fn similarity_search_tree_executes_to_bool_array() -> VortexResult<()> {
        // Twelve values passed with a leading `3`; the length assertion below
        // expects four results, i.e. presumably four 3-element vectors.
        let data = vector_array(
            3,
            &[
                1.0f32, 0.0, 0.0, //
                0.0, 1.0, 0.0, //
                0.0, 0.0, 1.0, //
                1.0, 0.0, 0.0, //
            ],
        )?;
        let query = [1.0f32, 0.0, 0.0];

        let tree = build_similarity_search_tree(data, &query, 0.5)?;

        let mut ctx = SESSION.create_execution_ctx();
        let executed: BoolArray = tree.execute(&mut ctx)?;
        let mask = executed.to_bit_buffer();

        assert_eq!(mask.len(), 4);

        // Only the first and last rows equal the query; the middle two are orthogonal to it.
        let expected = [true, false, false, true];
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(mask.value(row), want);
        }

        Ok(())
    }
}