use crate::ast::DistanceMetric;
use crate::datafusion_planner::vector_ops;
use arrow::array::ArrayRef;
use arrow::datatypes::DataType;
use datafusion::logical_expr::{ScalarUDF, Signature, Volatility};
use datafusion::physical_plan::ColumnarValue;
use std::sync::{Arc, LazyLock};
type UdfFunc =
Arc<dyn Fn(&[ColumnarValue]) -> datafusion::error::Result<ColumnarValue> + Send + Sync>;
fn vector_distance_func(
args: &[ColumnarValue],
metric: &DistanceMetric,
) -> datafusion::error::Result<ColumnarValue> {
if args.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(
"vector_distance requires exactly 2 arguments".to_string(),
));
}
match (&args[0], &args[1]) {
(ColumnarValue::Array(left_arr), ColumnarValue::Array(right_arr)) => {
let left_vectors = vector_ops::extract_vectors(left_arr)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let right_vectors = vector_ops::extract_vectors(right_arr)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let distances: Vec<f32> = if right_vectors.len() == 1 {
vector_ops::compute_vector_distances(&left_vectors, &right_vectors[0], metric)
} else if left_vectors.len() == 1 {
vector_ops::compute_vector_distances(&right_vectors, &left_vectors[0], metric)
} else if left_vectors.len() == right_vectors.len() {
left_vectors
.iter()
.zip(right_vectors.iter())
.map(|(l, r)| match metric {
DistanceMetric::L2 => vector_ops::l2_distance(l, r),
DistanceMetric::Cosine => vector_ops::cosine_distance(l, r),
DistanceMetric::Dot => vector_ops::dot_product_distance(l, r),
})
.collect()
} else {
return Err(datafusion::error::DataFusionError::Execution(format!(
"Vector count mismatch: left has {} vectors, right has {}",
left_vectors.len(),
right_vectors.len()
)));
};
let result = Arc::new(arrow::array::Float32Array::from(distances)) as ArrayRef;
Ok(ColumnarValue::Array(result))
}
(ColumnarValue::Array(left_arr), ColumnarValue::Scalar(right_scalar)) => {
let left_vectors = vector_ops::extract_vectors(left_arr)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let query_vector = vector_ops::extract_single_vector_from_scalar(right_scalar)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let distances =
vector_ops::compute_vector_distances(&left_vectors, &query_vector, metric);
let result = Arc::new(arrow::array::Float32Array::from(distances)) as ArrayRef;
Ok(ColumnarValue::Array(result))
}
(ColumnarValue::Scalar(left_scalar), ColumnarValue::Array(right_arr)) => {
let right_vectors = vector_ops::extract_vectors(right_arr)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let query_vector = vector_ops::extract_single_vector_from_scalar(left_scalar)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let distances =
vector_ops::compute_vector_distances(&right_vectors, &query_vector, metric);
let result = Arc::new(arrow::array::Float32Array::from(distances)) as ArrayRef;
Ok(ColumnarValue::Array(result))
}
(ColumnarValue::Scalar(left_scalar), ColumnarValue::Scalar(right_scalar)) => {
let left_vec = vector_ops::extract_single_vector_from_scalar(left_scalar)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let right_vec = vector_ops::extract_single_vector_from_scalar(right_scalar)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let distance = match metric {
DistanceMetric::L2 => vector_ops::l2_distance(&left_vec, &right_vec),
DistanceMetric::Cosine => vector_ops::cosine_distance(&left_vec, &right_vec),
DistanceMetric::Dot => vector_ops::dot_product_distance(&left_vec, &right_vec),
};
Ok(ColumnarValue::Scalar(
datafusion::scalar::ScalarValue::Float32(Some(distance)),
))
}
}
}
static VECTOR_DISTANCE_L2_UDF: LazyLock<Arc<ScalarUDF>> = LazyLock::new(|| {
let func = move |args: &[ColumnarValue]| -> datafusion::error::Result<ColumnarValue> {
vector_distance_func(args, &DistanceMetric::L2)
};
Arc::new(ScalarUDF::new_from_impl(VectorDistanceUDF {
name: "vector_distance_l2".to_string(),
func: Arc::new(func),
metric: DistanceMetric::L2,
signature: Signature::any(2, Volatility::Immutable),
}))
});
static VECTOR_DISTANCE_COSINE_UDF: LazyLock<Arc<ScalarUDF>> = LazyLock::new(|| {
let func = move |args: &[ColumnarValue]| -> datafusion::error::Result<ColumnarValue> {
vector_distance_func(args, &DistanceMetric::Cosine)
};
Arc::new(ScalarUDF::new_from_impl(VectorDistanceUDF {
name: "vector_distance_cosine".to_string(),
func: Arc::new(func),
metric: DistanceMetric::Cosine,
signature: Signature::any(2, Volatility::Immutable),
}))
});
static VECTOR_DISTANCE_DOT_UDF: LazyLock<Arc<ScalarUDF>> = LazyLock::new(|| {
let func = move |args: &[ColumnarValue]| -> datafusion::error::Result<ColumnarValue> {
vector_distance_func(args, &DistanceMetric::Dot)
};
Arc::new(ScalarUDF::new_from_impl(VectorDistanceUDF {
name: "vector_distance_dot".to_string(),
func: Arc::new(func),
metric: DistanceMetric::Dot,
signature: Signature::any(2, Volatility::Immutable),
}))
});
/// Returns the shared distance UDF for `metric`.
///
/// Each metric's UDF lives in a lazily initialised static. Callers receive a
/// new `Arc` handle to that same instance.
pub(crate) fn create_vector_distance_udf(metric: &DistanceMetric) -> Arc<ScalarUDF> {
    let shared: &Arc<ScalarUDF> = match metric {
        DistanceMetric::L2 => &*VECTOR_DISTANCE_L2_UDF,
        DistanceMetric::Cosine => &*VECTOR_DISTANCE_COSINE_UDF,
        DistanceMetric::Dot => &*VECTOR_DISTANCE_DOT_UDF,
    };
    Arc::clone(shared)
}
fn vector_similarity_func(
args: &[ColumnarValue],
metric: &DistanceMetric,
) -> datafusion::error::Result<ColumnarValue> {
if args.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(
"vector_similarity requires exactly 2 arguments".to_string(),
));
}
let compute_single_similarity = |l: &[f32], r: &[f32]| match metric {
DistanceMetric::L2 => {
let dist = vector_ops::l2_distance(l, r);
if dist == 0.0 {
1.0
} else {
1.0 / (1.0 + dist)
}
}
DistanceMetric::Cosine => vector_ops::cosine_similarity(l, r),
DistanceMetric::Dot => vector_ops::dot_product_similarity(l, r),
};
match (&args[0], &args[1]) {
(ColumnarValue::Array(left_arr), ColumnarValue::Array(right_arr)) => {
let left_vectors = vector_ops::extract_vectors(left_arr)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let right_vectors = vector_ops::extract_vectors(right_arr)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let similarities: Vec<f32> = if right_vectors.len() == 1 {
vector_ops::compute_vector_similarities(&left_vectors, &right_vectors[0], metric)
} else if left_vectors.len() == 1 {
vector_ops::compute_vector_similarities(&right_vectors, &left_vectors[0], metric)
} else if left_vectors.len() == right_vectors.len() {
left_vectors
.iter()
.zip(right_vectors.iter())
.map(|(l, r)| compute_single_similarity(l, r))
.collect()
} else {
return Err(datafusion::error::DataFusionError::Execution(format!(
"Vector count mismatch: left has {} vectors, right has {}",
left_vectors.len(),
right_vectors.len()
)));
};
let result = Arc::new(arrow::array::Float32Array::from(similarities)) as ArrayRef;
Ok(ColumnarValue::Array(result))
}
(ColumnarValue::Array(left_arr), ColumnarValue::Scalar(right_scalar)) => {
let left_vectors = vector_ops::extract_vectors(left_arr)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let query_vector = vector_ops::extract_single_vector_from_scalar(right_scalar)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let similarities =
vector_ops::compute_vector_similarities(&left_vectors, &query_vector, metric);
let result = Arc::new(arrow::array::Float32Array::from(similarities)) as ArrayRef;
Ok(ColumnarValue::Array(result))
}
(ColumnarValue::Scalar(left_scalar), ColumnarValue::Array(right_arr)) => {
let right_vectors = vector_ops::extract_vectors(right_arr)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let query_vector = vector_ops::extract_single_vector_from_scalar(left_scalar)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let similarities =
vector_ops::compute_vector_similarities(&right_vectors, &query_vector, metric);
let result = Arc::new(arrow::array::Float32Array::from(similarities)) as ArrayRef;
Ok(ColumnarValue::Array(result))
}
(ColumnarValue::Scalar(left_scalar), ColumnarValue::Scalar(right_scalar)) => {
let left_vec = vector_ops::extract_single_vector_from_scalar(left_scalar)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let right_vec = vector_ops::extract_single_vector_from_scalar(right_scalar)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let similarity = compute_single_similarity(&left_vec, &right_vec);
Ok(ColumnarValue::Scalar(
datafusion::scalar::ScalarValue::Float32(Some(similarity)),
))
}
}
}
static VECTOR_SIMILARITY_L2_UDF: LazyLock<Arc<ScalarUDF>> = LazyLock::new(|| {
let func = move |args: &[ColumnarValue]| -> datafusion::error::Result<ColumnarValue> {
vector_similarity_func(args, &DistanceMetric::L2)
};
Arc::new(ScalarUDF::new_from_impl(VectorSimilarityUDF {
name: "vector_similarity_l2".to_string(),
func: Arc::new(func),
metric: DistanceMetric::L2,
signature: Signature::any(2, Volatility::Immutable),
}))
});
static VECTOR_SIMILARITY_COSINE_UDF: LazyLock<Arc<ScalarUDF>> = LazyLock::new(|| {
let func = move |args: &[ColumnarValue]| -> datafusion::error::Result<ColumnarValue> {
vector_similarity_func(args, &DistanceMetric::Cosine)
};
Arc::new(ScalarUDF::new_from_impl(VectorSimilarityUDF {
name: "vector_similarity_cosine".to_string(),
func: Arc::new(func),
metric: DistanceMetric::Cosine,
signature: Signature::any(2, Volatility::Immutable),
}))
});
static VECTOR_SIMILARITY_DOT_UDF: LazyLock<Arc<ScalarUDF>> = LazyLock::new(|| {
let func = move |args: &[ColumnarValue]| -> datafusion::error::Result<ColumnarValue> {
vector_similarity_func(args, &DistanceMetric::Dot)
};
Arc::new(ScalarUDF::new_from_impl(VectorSimilarityUDF {
name: "vector_similarity_dot".to_string(),
func: Arc::new(func),
metric: DistanceMetric::Dot,
signature: Signature::any(2, Volatility::Immutable),
}))
});
/// Returns the shared similarity UDF for `metric`.
///
/// Each metric's UDF lives in a lazily initialised static. Callers receive a
/// new `Arc` handle to that same instance.
pub(crate) fn create_vector_similarity_udf(metric: &DistanceMetric) -> Arc<ScalarUDF> {
    let shared: &Arc<ScalarUDF> = match metric {
        DistanceMetric::L2 => &*VECTOR_SIMILARITY_L2_UDF,
        DistanceMetric::Cosine => &*VECTOR_SIMILARITY_COSINE_UDF,
        DistanceMetric::Dot => &*VECTOR_SIMILARITY_DOT_UDF,
    };
    Arc::clone(shared)
}
/// `ScalarUDFImpl` backing the `vector_distance_*` functions.
///
/// Evaluation is forwarded to `func`. `name` and `metric` identify the
/// instance for `Debug` and equality; `name` alone is used for hashing.
struct VectorDistanceUDF {
    /// Function name reported through `ScalarUDFImpl::name`.
    name: String,
    /// Metric this instance was built for.
    metric: DistanceMetric,
    /// Argument signature reported to the planner.
    signature: Signature,
    /// Callback that `invoke_with_args` forwards the call arguments to.
    func: UdfFunc,
}
impl std::fmt::Debug for VectorDistanceUDF {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `func` is an opaque closure, so only the identifying fields are printed.
        let mut builder = f.debug_struct("VectorDistanceUDF");
        builder.field("name", &self.name);
        builder.field("metric", &self.metric);
        builder.finish()
    }
}
impl datafusion::logical_expr::ScalarUDFImpl for VectorDistanceUDF {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn name(&self) -> &str {
        self.name.as_str()
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    /// Distances are always `Float32`, whatever the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
        Ok(DataType::Float32)
    }
    fn invoke_with_args(
        &self,
        args: datafusion::logical_expr::ScalarFunctionArgs,
    ) -> datafusion::error::Result<ColumnarValue> {
        let call_args = &args.args;
        (self.func)(call_args)
    }
}
impl PartialEq for VectorDistanceUDF {
    fn eq(&self, other: &Self) -> bool {
        self.metric == other.metric && self.name == other.name
    }
}
impl Eq for VectorDistanceUDF {}
impl std::hash::Hash for VectorDistanceUDF {
    /// Hashes only `name`. Equal values always share a name, so this stays
    /// consistent with `PartialEq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.name, state);
    }
}
/// `ScalarUDFImpl` backing the `vector_similarity_*` functions.
///
/// Evaluation is forwarded to `func`. `name` and `metric` identify the
/// instance for `Debug` and equality; `name` alone is used for hashing.
struct VectorSimilarityUDF {
    /// Function name reported through `ScalarUDFImpl::name`.
    name: String,
    /// Metric this instance was built for.
    metric: DistanceMetric,
    /// Argument signature reported to the planner.
    signature: Signature,
    /// Callback that `invoke_with_args` forwards the call arguments to.
    func: UdfFunc,
}
impl std::fmt::Debug for VectorSimilarityUDF {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `func` is an opaque closure, so only the identifying fields are printed.
        let mut builder = f.debug_struct("VectorSimilarityUDF");
        builder.field("name", &self.name);
        builder.field("metric", &self.metric);
        builder.finish()
    }
}
impl datafusion::logical_expr::ScalarUDFImpl for VectorSimilarityUDF {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn name(&self) -> &str {
        self.name.as_str()
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    /// Similarities are always `Float32`, whatever the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
        Ok(DataType::Float32)
    }
    fn invoke_with_args(
        &self,
        args: datafusion::logical_expr::ScalarFunctionArgs,
    ) -> datafusion::error::Result<ColumnarValue> {
        let call_args = &args.args;
        (self.func)(call_args)
    }
}
impl PartialEq for VectorSimilarityUDF {
    fn eq(&self, other: &Self) -> bool {
        self.metric == other.metric && self.name == other.name
    }
}
impl Eq for VectorSimilarityUDF {}
impl std::hash::Hash for VectorSimilarityUDF {
    /// Hashes only `name`. Equal values always share a name, so this stays
    /// consistent with `PartialEq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.name, state);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, FixedSizeListArray, Float32Array};
    use arrow::datatypes::{DataType, Field};
    use std::f32::consts::SQRT_2;
    use std::sync::Arc;

    /// Builds a `FixedSizeList<Float32>` array with one row per input vector.
    ///
    /// Assumes every vector has the dimension of the first one. Panics on an
    /// empty input.
    fn create_vector_array(vectors: Vec<Vec<f32>>) -> ArrayRef {
        let dim = vectors[0].len() as i32;
        let values: Vec<f32> = vectors.into_iter().flatten().collect();
        let field = Arc::new(Field::new("item", DataType::Float32, true));
        let value_array = Arc::new(Float32Array::from(values));
        Arc::new(FixedSizeListArray::try_new(field, dim, value_array, None).unwrap())
    }

    /// Builds a single-row `FixedSizeList<Float32>` scalar from `vec`.
    fn create_vector_scalar(vec: Vec<f32>) -> datafusion::scalar::ScalarValue {
        use datafusion::scalar::ScalarValue;
        let dim = vec.len() as i32;
        let field = Arc::new(Field::new("item", DataType::Float32, true));
        let values = Arc::new(Float32Array::from(vec));
        let list_array = FixedSizeListArray::try_new(field, dim, values, None).unwrap();
        ScalarValue::try_from_array(&list_array, 0).unwrap()
    }

    /// Unwraps an array result into its `f32` values, panicking on any other shape.
    fn array_values(result: ColumnarValue) -> Vec<f32> {
        let ColumnarValue::Array(arr) = result else {
            panic!("Expected array result");
        };
        arr.as_any()
            .downcast_ref::<Float32Array>()
            .expect("Expected Float32Array result")
            .values()
            .to_vec()
    }

    /// Unwraps a `Float32` scalar result, panicking on any other shape.
    fn scalar_value(result: ColumnarValue) -> f32 {
        let ColumnarValue::Scalar(datafusion::scalar::ScalarValue::Float32(Some(value))) = result
        else {
            panic!("Expected Float32 scalar result for scalar vs scalar");
        };
        value
    }

    #[test]
    fn test_vector_distance_l2_udf() {
        let left = create_vector_array(vec![vec![1.0, 0.0, 0.0]]);
        let right = create_vector_array(vec![vec![0.0, 1.0, 0.0]]);
        let args = vec![ColumnarValue::Array(left), ColumnarValue::Array(right)];
        let values =
            array_values(super::vector_distance_func(&args, &DistanceMetric::L2).unwrap());
        assert_eq!(values.len(), 1);
        assert!((values[0] - SQRT_2).abs() < 0.01);
    }

    #[test]
    fn test_vector_distance_cosine_udf() {
        let left = create_vector_array(vec![vec![1.0, 0.0, 0.0]]);
        let right = create_vector_array(vec![vec![1.0, 0.0, 0.0]]);
        let args = vec![ColumnarValue::Array(left), ColumnarValue::Array(right)];
        let values =
            array_values(super::vector_distance_func(&args, &DistanceMetric::Cosine).unwrap());
        assert_eq!(values, vec![0.0]);
    }

    #[test]
    fn test_vector_similarity_cosine_udf() {
        let left = create_vector_array(vec![vec![1.0, 0.0, 0.0]]);
        let right = create_vector_array(vec![vec![1.0, 0.0, 0.0]]);
        let args = vec![ColumnarValue::Array(left), ColumnarValue::Array(right)];
        let values =
            array_values(super::vector_similarity_func(&args, &DistanceMetric::Cosine).unwrap());
        assert_eq!(values, vec![1.0]);
    }

    #[test]
    fn test_vector_distance_broadcast() {
        let left = create_vector_array(vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ]);
        let right = create_vector_array(vec![vec![1.0, 0.0, 0.0]]);
        let args = vec![ColumnarValue::Array(left), ColumnarValue::Array(right)];
        let values =
            array_values(super::vector_distance_func(&args, &DistanceMetric::L2).unwrap());
        assert_eq!(values.len(), 3);
        assert_eq!(values[0], 0.0);
        assert!((values[1] - SQRT_2).abs() < 0.01);
        assert!((values[2] - SQRT_2).abs() < 0.01);
    }

    #[test]
    fn test_vector_distance_pairwise() {
        // Equal-length multi-row arrays are compared row by row.
        let left = create_vector_array(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]);
        let right = create_vector_array(vec![vec![1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0]]);
        let args = vec![ColumnarValue::Array(left), ColumnarValue::Array(right)];
        let values =
            array_values(super::vector_distance_func(&args, &DistanceMetric::L2).unwrap());
        assert_eq!(values.len(), 2);
        assert_eq!(values[0], 0.0);
        assert!((values[1] - SQRT_2).abs() < 0.01);
    }

    #[test]
    fn test_vector_distance_count_mismatch() {
        let left = create_vector_array(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        let right = create_vector_array(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]]);
        let args = vec![ColumnarValue::Array(left), ColumnarValue::Array(right)];
        let err = super::vector_distance_func(&args, &DistanceMetric::L2).unwrap_err();
        assert!(err.to_string().contains("Vector count mismatch"));
    }

    #[test]
    fn test_vector_distance_wrong_arg_count() {
        let left = create_vector_array(vec![vec![1.0, 0.0, 0.0]]);
        let args = vec![ColumnarValue::Array(left)];
        let result = super::vector_distance_func(&args, &DistanceMetric::L2);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("requires exactly 2 arguments"));
    }

    #[test]
    fn test_vector_distance_array_vs_scalar() {
        let left = create_vector_array(vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ]);
        let right_scalar = create_vector_scalar(vec![1.0, 0.0, 0.0]);
        let args = vec![
            ColumnarValue::Array(left),
            ColumnarValue::Scalar(right_scalar),
        ];
        let values =
            array_values(super::vector_distance_func(&args, &DistanceMetric::L2).unwrap());
        assert_eq!(values.len(), 3);
        assert_eq!(values[0], 0.0);
        assert!((values[1] - SQRT_2).abs() < 0.01);
        assert!((values[2] - SQRT_2).abs() < 0.01);
    }

    #[test]
    fn test_vector_distance_scalar_vs_array() {
        let left_scalar = create_vector_scalar(vec![1.0, 0.0, 0.0]);
        let right = create_vector_array(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]);
        let args = vec![
            ColumnarValue::Scalar(left_scalar),
            ColumnarValue::Array(right),
        ];
        let values =
            array_values(super::vector_distance_func(&args, &DistanceMetric::Cosine).unwrap());
        assert_eq!(values, vec![0.0, 1.0]);
    }

    #[test]
    fn test_vector_distance_scalar_vs_scalar() {
        let left_scalar = create_vector_scalar(vec![1.0, 0.0, 0.0]);
        let right_scalar = create_vector_scalar(vec![0.0, 1.0, 0.0]);
        let args = vec![
            ColumnarValue::Scalar(left_scalar),
            ColumnarValue::Scalar(right_scalar),
        ];
        let dist =
            scalar_value(super::vector_distance_func(&args, &DistanceMetric::L2).unwrap());
        assert!((dist - SQRT_2).abs() < 0.01);
    }

    #[test]
    fn test_vector_similarity_array_vs_scalar() {
        let left = create_vector_array(vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.707, 0.707, 0.0],
        ]);
        let right_scalar = create_vector_scalar(vec![1.0, 0.0, 0.0]);
        let args = vec![
            ColumnarValue::Array(left),
            ColumnarValue::Scalar(right_scalar),
        ];
        let values =
            array_values(super::vector_similarity_func(&args, &DistanceMetric::Cosine).unwrap());
        assert_eq!(values.len(), 3);
        assert_eq!(values[0], 1.0);
        assert_eq!(values[1], 0.0);
        assert!((values[2] - 0.707).abs() < 0.01);
    }

    #[test]
    fn test_vector_similarity_scalar_vs_scalar() {
        let left_scalar = create_vector_scalar(vec![1.0, 0.0, 0.0]);
        let right_scalar = create_vector_scalar(vec![1.0, 0.0, 0.0]);
        let args = vec![
            ColumnarValue::Scalar(left_scalar),
            ColumnarValue::Scalar(right_scalar),
        ];
        let sim =
            scalar_value(super::vector_similarity_func(&args, &DistanceMetric::Cosine).unwrap());
        assert_eq!(sim, 1.0);
    }

    #[test]
    fn test_vector_distance_dot_product_with_scalar() {
        let left = create_vector_array(vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]]);
        let right_scalar = create_vector_scalar(vec![1.0, 0.0, 0.0]);
        let args = vec![
            ColumnarValue::Array(left),
            ColumnarValue::Scalar(right_scalar),
        ];
        let values =
            array_values(super::vector_distance_func(&args, &DistanceMetric::Dot).unwrap());
        assert_eq!(values.len(), 2);
        assert_eq!(values[0], -1.0);
        assert!((values[1] + 0.9).abs() < 0.01);
    }
}