use crate::automation::callable::{Callable, ExecutionContext, Signature, Value};
use crate::error::DbxResult;
/// Boxed, thread-safe closure backing a [`ScalarUDF`]: maps a slice of
/// argument values to a single result value.
type ScalarFn = Box<dyn Fn(&[Value]) -> DbxResult<Value> + Send + Sync>;

/// A user-defined scalar function: a named closure paired with a declared
/// [`Signature`] that incoming arguments are validated against.
pub struct ScalarUDF {
    /// Name reported through [`Callable::name`].
    name: String,
    /// Declared parameter/return types for this function.
    signature: Signature,
    /// The user-supplied implementation.
    func: ScalarFn,
}

// `Debug` cannot be derived because `ScalarFn` is an opaque closure (and
// `Signature` is not known to be `Debug` here), so only the name is printed
// and the remaining fields are elided.
impl std::fmt::Debug for ScalarUDF {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScalarUDF")
            .field("name", &self.name)
            .finish_non_exhaustive()
    }
}
impl ScalarUDF {
pub fn new<F>(name: impl Into<String>, signature: Signature, func: F) -> Self
where
F: Fn(&[Value]) -> DbxResult<Value> + Send + Sync + 'static,
{
Self {
name: name.into(),
signature,
func: Box::new(func),
}
}
}
impl Callable for ScalarUDF {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Checks `args` against the declared signature, then runs the wrapped
    /// closure and returns its result as-is. The execution context is unused.
    fn call(&self, _ctx: &ExecutionContext, args: &[Value]) -> DbxResult<Value> {
        // Validation happens first, so the closure only ever sees arguments
        // that passed `Signature::validate_args`.
        self.signature.validate_args(args)?;
        let func = &self.func;
        func(args)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automation::callable::DataType;
    use crate::automation::executor::ExecutionEngine;
    use crate::engine::Database;
    use std::sync::Arc;

    /// Fresh execution context over a throwaway in-memory database.
    fn test_ctx() -> ExecutionContext {
        let db = Database::open_in_memory().unwrap();
        ExecutionContext::new(Arc::new(db))
    }

    /// Non-variadic signature with the given parameter and return types.
    fn fixed_sig(params: Vec<DataType>, return_type: DataType) -> Signature {
        Signature {
            params,
            return_type,
            is_variadic: false,
        }
    }

    /// `Int -> Int` UDF named `name` that multiplies its argument by `factor`.
    fn int_multiplier(name: &str, factor: i64) -> ScalarUDF {
        ScalarUDF::new(
            name,
            fixed_sig(vec![DataType::Int], DataType::Int),
            move |args| {
                let x = args[0].as_i64()?;
                Ok(Value::Int(x * factor))
            },
        )
    }

    #[test]
    fn test_scalar_udf_basic() {
        let doubler = int_multiplier("double", 2);
        let ctx = test_ctx();
        let out = doubler.call(&ctx, &[Value::Int(21)]).unwrap();
        assert_eq!(out.as_i64().unwrap(), 42);
    }

    #[test]
    fn test_scalar_udf_string() {
        let upper = ScalarUDF::new(
            "upper",
            fixed_sig(vec![DataType::String], DataType::String),
            |args| {
                let text = args[0].as_str()?;
                Ok(Value::String(text.to_uppercase()))
            },
        );
        let ctx = test_ctx();
        let input = [Value::String("hello".to_string())];
        let out = upper.call(&ctx, &input).unwrap();
        assert_eq!(out.as_str().unwrap(), "HELLO");
    }

    #[test]
    fn test_scalar_udf_multiple_args() {
        let add = ScalarUDF::new(
            "add",
            fixed_sig(vec![DataType::Int, DataType::Int], DataType::Int),
            |args| {
                let lhs = args[0].as_i64()?;
                let rhs = args[1].as_i64()?;
                Ok(Value::Int(lhs + rhs))
            },
        );
        let ctx = test_ctx();
        let out = add.call(&ctx, &[Value::Int(10), Value::Int(32)]).unwrap();
        assert_eq!(out.as_i64().unwrap(), 42);
    }

    #[test]
    fn test_scalar_udf_type_validation() {
        let doubler = int_multiplier("double", 2);
        let ctx = test_ctx();
        // A String argument must be rejected for an Int parameter.
        let outcome = doubler.call(&ctx, &[Value::String("hello".to_string())]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_scalar_udf_with_engine() {
        let engine = ExecutionEngine::new();
        engine
            .register(Arc::new(int_multiplier("triple", 3)))
            .unwrap();
        let ctx = test_ctx();
        let out = engine.execute("triple", &ctx, &[Value::Int(14)]).unwrap();
        assert_eq!(out.as_i64().unwrap(), 42);
    }
}