use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::BoolArray;
use vortex_array::arrays::ConstantArray;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::dtype::DType;
use vortex_array::scalar::Scalar;
use vortex_array::scalar_fn::fns::binary::CompareKernel;
use vortex_array::scalar_fn::fns::operators::CompareOperator;
use vortex_array::scalar_fn::fns::operators::Operator;
use vortex_array::validity::Validity;
use vortex_buffer::BitBuffer;
use vortex_buffer::ByteBuffer;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use crate::FSST;
use crate::FSSTArray;
impl CompareKernel for FSST {
    /// Compares an [`FSSTArray`] against `rhs`.
    ///
    /// Only a constant right-hand side has a specialized implementation here (see
    /// `compare_fsst_constant`). For any other `rhs` this returns `Ok(None)`, signalling that
    /// the caller should use its generic comparison instead.
    fn compare(
        lhs: &FSSTArray,
        rhs: &ArrayRef,
        operator: CompareOperator,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        let Some(constant) = rhs.as_constant() else {
            // Non-constant RHS: no fast path available.
            return Ok(None);
        };
        compare_fsst_constant(lhs, &constant, operator, ctx)
    }
}
/// Specialized comparison of an [`FSSTArray`] against a constant `Binary`/`Utf8` scalar.
///
/// Two fast paths avoid decompressing `left`:
///
/// * When `right` is empty, `Gte` is all-true and `Lt` is all-false. Every other operator
///   reduces to comparing each value's uncompressed length against zero. The result's
///   validity is `left`'s validity, widened to `right`'s nullability.
/// * For `Eq`/`NotEq`, `right` is compressed with `left`'s own compressor and compared against
///   `left`'s compressed codes. This relies on the compressor producing identical codes for
///   identical input.
///
/// Returns `Ok(None)` when no fast path applies. This covers a null `right` and non-equality
/// operators against a non-empty `right`. In those cases the caller falls back to the generic
/// comparison.
///
/// # Errors
///
/// Returns an error if `right` is neither `Binary` nor `Utf8`, or if building or executing the
/// length comparison fails.
fn compare_fsst_constant(
    left: &FSSTArray,
    right: &Scalar,
    operator: CompareOperator,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
    // A null constant has no value to compare against. Defer to the generic comparison rather
    // than panicking on the non-null expectations below.
    if right.is_null() {
        return Ok(None);
    }

    let is_rhs_empty = match right.dtype() {
        DType::Binary(_) => right
            .as_binary()
            .is_empty()
            .vortex_expect("RHS should not be null"),
        DType::Utf8(_) => right
            .as_utf8()
            .is_empty()
            .vortex_expect("RHS should not be null"),
        _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
    };

    if is_rhs_empty {
        let buffer = match operator {
            // Every value is >= "".
            CompareOperator::Gte => BitBuffer::new_set(left.len()),
            // No value is < "".
            CompareOperator::Lt => BitBuffer::new_unset(left.len()),
            // Eq/Lte hold exactly for empty values and NotEq/Gt exactly for non-empty ones,
            // so comparing each uncompressed length against zero gives the answer.
            _ => {
                let lengths = left.uncompressed_lengths();
                let zeros =
                    ConstantArray::new(Scalar::zero_value(lengths.dtype()), lengths.len());
                lengths
                    .to_array()
                    .binary(zeros.into_array(), Operator::from(operator))?
                    .execute(ctx)?
            }
        };
        let validity = Validity::copy_from_array(&left.clone().into_array())?
            .union_nullability(right.dtype().nullability());
        return Ok(Some(BoolArray::new(buffer, validity).into_array()));
    }

    // Against a non-empty constant, only equality is answered on the compressed codes. The code
    // does not rely on compressed byte order matching value order.
    if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) {
        return Ok(None);
    }

    let compressor = left.compressor();
    let encoded_buffer = match left.dtype() {
        DType::Utf8(_) => {
            let value = right
                .as_utf8()
                .value()
                .vortex_expect("Expected non-null scalar");
            ByteBuffer::from(compressor.compress(value.as_bytes()))
        }
        DType::Binary(_) => {
            let value = right
                .as_binary()
                .value()
                .vortex_expect("Expected non-null scalar");
            ByteBuffer::from(compressor.compress(value.as_slice()))
        }
        _ => unreachable!("FSSTArray can only have string or binary data type"),
    };

    let encoded_scalar = Scalar::binary(
        encoded_buffer,
        left.dtype().nullability() | right.dtype().nullability(),
    );
    let rhs = ConstantArray::new(encoded_scalar, left.len());

    left.codes()
        .clone()
        .into_array()
        .binary(rhs.into_array(), Operator::from(operator))
        .map(Some)
}
#[cfg(test)]
mod tests {
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    /// Eq/NotEq against a non-empty constant, and against a null constant.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let lhs = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let lhs = fsst_compress(lhs, &compressor);

        let rhs = ConstantArray::new("world", lhs.len());

        let equals = lhs
            .clone()
            .into_array()
            .binary(rhs.clone().into_array(), Operator::Eq)
            .unwrap()
            .to_bool();
        assert_eq!(equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_arrays_eq!(
            &equals,
            &BoolArray::from_iter([Some(false), None, Some(true), None, Some(false)])
        );

        let not_equals = lhs
            .clone()
            .into_array()
            .binary(rhs.into_array(), Operator::NotEq)
            .unwrap()
            .to_bool();
        assert_eq!(not_equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_arrays_eq!(
            &not_equals,
            &BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)])
        );

        // Comparing against a null constant yields an all-null result.
        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        let equals_null = lhs
            .clone()
            .into_array()
            .binary(null_rhs.clone().into_array(), Operator::Eq)
            .unwrap();
        assert_arrays_eq!(
            &equals_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );

        let noteq_null = lhs
            .into_array()
            .binary(null_rhs.into_array(), Operator::NotEq)
            .unwrap();
        assert_arrays_eq!(
            &noteq_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );
    }

    /// Every comparison operator against the empty string (the empty-RHS fast path).
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst_empty_rhs() {
        let lhs = VarBinArray::from_iter(
            [Some("hello"), None, Some(""), Some("world")],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let lhs = fsst_compress(lhs, &compressor);

        let rhs = ConstantArray::new("", lhs.len());

        let cases = [
            (Operator::Eq, [Some(false), None, Some(true), Some(false)]),
            (Operator::NotEq, [Some(true), None, Some(false), Some(true)]),
            (Operator::Gt, [Some(true), None, Some(false), Some(true)]),
            (Operator::Gte, [Some(true), None, Some(true), Some(true)]),
            (Operator::Lt, [Some(false), None, Some(false), Some(false)]),
            (Operator::Lte, [Some(false), None, Some(true), Some(false)]),
        ];
        for (op, expected) in cases {
            let result = lhs
                .clone()
                .into_array()
                .binary(rhs.clone().into_array(), op)
                .unwrap()
                .to_bool();
            assert_arrays_eq!(&result, &BoolArray::from_iter(expected));
        }
    }
}