use std::any::Any;
use std::sync::Arc;
use datafusion::arrow::array::{Array, ArrayRef, Float64Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::Result as DfResult;
use datafusion::logical_expr::{
ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use nodedb_query::DEFAULT_RRF_K;
/// DataFusion scalar UDF exposed under the name `rrf_score`.
///
/// Every Float64 argument is treated as a rank. For each row the function
/// sums `1 / (DEFAULT_RRF_K + |rank| + 1)` over the arguments that are not
/// null in that row.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RrfScore {
/// Accepted argument signatures; built once in `new` and handed out by
/// `ScalarUDFImpl::signature`.
signature: Signature,
}
impl RrfScore {
    /// Creates the UDF with its argument signature.
    ///
    /// The signature accepts exactly two Float64 arguments, exactly three
    /// Float64 arguments, or a variadic list of Float64 arguments, and is
    /// declared `Volatility::Immutable`.
    pub fn new() -> Self {
        let two_ranks = TypeSignature::Exact(vec![DataType::Float64; 2]);
        let three_ranks = TypeSignature::Exact(vec![DataType::Float64; 3]);
        let any_ranks = TypeSignature::Variadic(vec![DataType::Float64]);

        let signature = Signature::one_of(
            vec![two_ranks, three_ranks, any_ranks],
            Volatility::Immutable,
        );
        Self { signature }
    }
}
impl Default for RrfScore {
fn default() -> Self {
Self::new()
}
}
impl ScalarUDFImpl for RrfScore {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"rrf_score"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DfResult<DataType> {
Ok(DataType::Float64)
}
fn invoke_with_args(
&self,
args: datafusion::logical_expr::ScalarFunctionArgs,
) -> DfResult<ColumnarValue> {
let num_rows = args.number_rows;
let arrays: Vec<ArrayRef> = args
.args
.iter()
.map(|a| match a {
ColumnarValue::Array(arr) => Ok(Arc::clone(arr)),
ColumnarValue::Scalar(s) => s.to_array(),
})
.collect::<DfResult<Vec<_>>>()?;
if arrays.is_empty() {
return Ok(ColumnarValue::Array(Arc::new(Float64Array::from(vec![
0.0f64;
num_rows
]))));
}
let len = arrays[0].len();
let mut scores = vec![0.0f64; len];
for arr in &arrays {
let rank_arr = arr.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
datafusion::error::DataFusionError::Internal(
"rrf_score: expected Float64 array".into(),
)
})?;
for (i, score) in scores.iter_mut().enumerate().take(len) {
if !rank_arr.is_null(i) {
let rank = rank_arr.value(i);
*score += 1.0 / (DEFAULT_RRF_K + rank.abs() + 1.0);
}
}
}
Ok(ColumnarValue::Array(Arc::new(Float64Array::from(scores))))
}
}