use crate::automation::callable::{Callable, DataType, ExecutionContext, Signature, Value};
use crate::error::{DbxError, DbxResult};
/// Boxed batch implementation of a [`VectorizedUDF`]: it receives the whole argument
/// slice of a call and returns a single value (which `call` requires to be an `Array`).
type VectorizedFn = Box<dyn Fn(&[Value]) -> DbxResult<Value> + Send + Sync>;

/// A user-defined function that processes whole arrays per call instead of one
/// element at a time.
///
/// When invoked through [`Callable::call`], every argument must be an `Array`, all
/// argument arrays must have the same length, and the wrapped function must return
/// an `Array` of that same length; violations are reported as errors.
pub struct VectorizedUDF {
    /// Name returned by `Callable::name`.
    name: String,
    /// Declared parameter/return types; `call` runs `validate_args` against it first.
    signature: Signature,
    /// The user-supplied batch implementation.
    func: VectorizedFn,
}

// `Debug` cannot be derived because `func` is a boxed closure, so implement it by
// hand: show the identifying name and mark the remaining fields as elided.
impl std::fmt::Debug for VectorizedUDF {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VectorizedUDF")
            .field("name", &self.name)
            .finish_non_exhaustive()
    }
}
impl VectorizedUDF {
pub fn new<F>(name: impl Into<String>, signature: Signature, func: F) -> Self
where
F: Fn(&[Value]) -> DbxResult<Value> + Send + Sync + 'static,
{
Self {
name: name.into(),
signature,
func: Box::new(func),
}
}
}
impl Callable for VectorizedUDF {
    /// Runs the wrapped batch function, checking the shape of both inputs and output.
    ///
    /// Checks happen in this order and the first failure is returned:
    /// 1. `Signature::validate_args` on the raw arguments.
    /// 2. Every argument must be a `DataType::Array` (`DbxError::TypeMismatch` otherwise).
    /// 3. All argument arrays must share one length (`DbxError::InvalidArguments` otherwise).
    /// 4. Errors returned by the wrapped function are propagated as-is.
    /// 5. The result must be a `DataType::Array` (`DbxError::TypeMismatch` otherwise).
    /// 6. If at least one argument was supplied, the result array must have that same
    ///    length (`DbxError::InvalidArguments` otherwise).
    ///
    /// The execution context is not used.
    fn call(&self, _ctx: &ExecutionContext, args: &[Value]) -> DbxResult<Value> {
        self.signature.validate_args(args)?;

        // Type-check every argument before inspecting any lengths, so a non-array
        // argument is reported even when an earlier array already has a bad length.
        if let Some(actual) = args
            .iter()
            .map(|arg| arg.data_type())
            .find(|dt| *dt != DataType::Array)
        {
            return Err(DbxError::TypeMismatch {
                expected: "Array".to_string(),
                actual: format!("{:?}", actual),
            });
        }

        // Carry the first array's length through the fold; it stays `None` only
        // when there are no arguments at all.
        let batch_len = args.iter().try_fold(
            None,
            |seen: Option<usize>, arg| -> DbxResult<Option<usize>> {
                let len = arg.as_array()?.len();
                match seen {
                    Some(first) if first != len => Err(DbxError::InvalidArguments(
                        "Vectorized UDF arguments must have the same array length".to_string(),
                    )),
                    _ => Ok(Some(len)),
                }
            },
        )?;

        let output = (self.func)(args)?;

        let output_type = output.data_type();
        if output_type != DataType::Array {
            return Err(DbxError::TypeMismatch {
                expected: "Array".to_string(),
                actual: format!("{:?}", output_type),
            });
        }

        let length_matches = match batch_len {
            Some(expected) => output.as_array()?.len() == expected,
            None => true,
        };
        if !length_matches {
            return Err(DbxError::InvalidArguments(
                "Vectorized UDF result array length must match input array lengths".to_string(),
            ));
        }

        Ok(output)
    }

    /// The name this UDF was constructed with.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The signature this UDF was constructed with.
    fn signature(&self) -> &Signature {
        &self.signature
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automation::executor::ExecutionEngine;
    use crate::engine::Database;
    use std::sync::Arc;

    /// Non-variadic signature taking `arity` arrays and returning an array.
    fn array_signature(arity: usize) -> Signature {
        Signature {
            params: (0..arity).map(|_| DataType::Array).collect(),
            return_type: DataType::Array,
            is_variadic: false,
        }
    }

    /// Wraps a slice of integers into an `Array` value.
    fn int_array(xs: &[i64]) -> Value {
        Value::Array(xs.iter().map(|&x| Value::Int(x)).collect())
    }

    /// Unwraps an `Array` of integers, panicking on any other shape.
    fn as_ints(value: &Value) -> Vec<i64> {
        value
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_i64().unwrap())
            .collect()
    }

    #[test]
    fn test_vectorized_udf_basic() {
        let udf = VectorizedUDF::new("vec_double", array_signature(1), |args| {
            let input_array = args[0].as_array()?;
            let mut result = Vec::with_capacity(input_array.len());
            for val in input_array.iter() {
                result.push(Value::Int(val.as_i64()? * 2));
            }
            Ok(Value::Array(result))
        });
        let ctx = ExecutionContext::new(Arc::new(Database::open_in_memory().unwrap()));
        let result = udf.call(&ctx, &[int_array(&[10, 20, 30])]).unwrap();
        assert_eq!(as_ints(&result), vec![20, 40, 60]);
    }

    #[test]
    fn test_vectorized_udf_multiple_args() {
        let udf = VectorizedUDF::new("vec_add", array_signature(2), |args| {
            let x_arr = args[0].as_array()?;
            let y_arr = args[1].as_array()?;
            let mut result = Vec::with_capacity(x_arr.len());
            // Lengths are already validated equal by `call`, so zipping loses nothing.
            for (x, y) in x_arr.iter().zip(y_arr.iter()) {
                result.push(Value::Int(x.as_i64()? + y.as_i64()?));
            }
            Ok(Value::Array(result))
        });
        let ctx = ExecutionContext::new(Arc::new(Database::open_in_memory().unwrap()));
        let result = udf
            .call(&ctx, &[int_array(&[1, 2, 3]), int_array(&[10, 20, 30])])
            .unwrap();
        assert_eq!(as_ints(&result), vec![11, 22, 33]);
    }

    #[test]
    fn test_vectorized_udf_mismatch_lengths() {
        // The body is never reached: the length check must fail first.
        let udf = VectorizedUDF::new("vec_add", array_signature(2), |_args| {
            Ok(Value::Array(vec![]))
        });
        let ctx = ExecutionContext::new(Arc::new(Database::open_in_memory().unwrap()));
        let result = udf.call(&ctx, &[int_array(&[1, 2]), int_array(&[10, 20, 30])]);
        assert!(result.is_err());
        match result.unwrap_err() {
            DbxError::InvalidArguments(msg) => assert!(msg.contains("same array length")),
            other => panic!("Expected InvalidArguments error, got {:?}", other),
        }
    }

    #[test]
    fn test_vectorized_udf_rejects_non_array_argument() {
        let udf = VectorizedUDF::new("vec_noop", array_signature(1), |_args| {
            Ok(Value::Array(vec![]))
        });
        let ctx = ExecutionContext::new(Arc::new(Database::open_in_memory().unwrap()));
        // Either signature validation or the explicit Array check must refuse a scalar.
        assert!(udf.call(&ctx, &[Value::Int(1)]).is_err());
    }

    #[test]
    fn test_vectorized_udf_rejects_non_array_result() {
        let udf = VectorizedUDF::new("vec_scalar", array_signature(1), |_args| {
            Ok(Value::Int(0))
        });
        let ctx = ExecutionContext::new(Arc::new(Database::open_in_memory().unwrap()));
        match udf.call(&ctx, &[int_array(&[1])]).unwrap_err() {
            DbxError::TypeMismatch { expected, .. } => assert_eq!(expected, "Array"),
            other => panic!("Expected TypeMismatch error, got {:?}", other),
        }
    }

    #[test]
    fn test_vectorized_udf_result_length_mismatch() {
        // Returns fewer elements than it was given.
        let udf = VectorizedUDF::new("vec_truncate", array_signature(1), |_args| {
            Ok(Value::Array(vec![]))
        });
        let ctx = ExecutionContext::new(Arc::new(Database::open_in_memory().unwrap()));
        match udf.call(&ctx, &[int_array(&[1, 2, 3])]).unwrap_err() {
            DbxError::InvalidArguments(msg) => assert!(msg.contains("result array length")),
            other => panic!("Expected InvalidArguments error, got {:?}", other),
        }
    }

    #[test]
    fn test_vectorized_udf_with_engine() {
        let engine = ExecutionEngine::new();
        let udf = Arc::new(VectorizedUDF::new(
            "vec_triple",
            array_signature(1),
            |args| {
                let input_array = args[0].as_array()?;
                let mut result = Vec::with_capacity(input_array.len());
                for val in input_array.iter() {
                    result.push(Value::Int(val.as_i64()? * 3));
                }
                Ok(Value::Array(result))
            },
        ));
        engine.register(udf).unwrap();
        let ctx = ExecutionContext::new(Arc::new(Database::open_in_memory().unwrap()));
        let result = engine
            .execute("vec_triple", &ctx, &[int_array(&[5, 10])])
            .unwrap();
        assert_eq!(as_ints(&result), vec![15, 30]);
    }
}