use std::sync::Arc;
use arrow_schema::DataType;
use datafusion::logical_expr::{ColumnarValue, Volatility};
use crate::errors::FnError;
/// A scalar function supplied by a plugin.
///
/// Implementors must be `Send + Sync`. Each implementor exposes its declared
/// [`FnSignature`] and an evaluation entry point operating on DataFusion
/// [`ColumnarValue`]s.
pub trait ScalarPluginFn: Send + Sync {
    /// Returns the signature this function was declared with.
    fn signature(&self) -> &FnSignature;

    /// Evaluates the function over `args`.
    ///
    /// `rows` is presumably the number of rows in the batch being evaluated
    /// (TODO confirm against the host adapter).
    ///
    /// # Errors
    ///
    /// Returns an [`FnError`] when the implementation cannot produce a value.
    fn invoke(&self, args: &[ColumnarValue], rows: usize) -> Result<ColumnarValue, FnError>;
}
/// Declared shape of a plugin scalar function: argument types, return type,
/// volatility, and null-handling mode.
#[derive(Clone, Debug)]
pub struct FnSignature {
    /// Declared argument types.
    pub args: Vec<ArgType>,
    /// Declared return type.
    pub returns: ArgType,
    /// DataFusion volatility classification of the function.
    pub volatility: Volatility,
    /// Null-handling mode; `FnSignature::new` initializes this to
    /// [`NullHandling::PropagateNulls`].
    pub null_handling: NullHandling,
}
impl FnSignature {
    /// Builds a signature from its argument types, return type, and volatility.
    ///
    /// `null_handling` is always set to [`NullHandling::PropagateNulls`]; the
    /// field is public, so callers can overwrite it afterwards.
    #[must_use]
    pub fn new(args: Vec<ArgType>, returns: ArgType, volatility: Volatility) -> Self {
        // The constructor takes no null-handling argument, so fix the default here.
        let null_handling = NullHandling::PropagateNulls;
        Self {
            args,
            returns,
            volatility,
            null_handling,
        }
    }
}
/// How null arguments are treated for a plugin function.
///
/// The `Default` value is [`NullHandling::PropagateNulls`], matching what
/// `FnSignature::new` assigns. `Hash` is derived so the mode can be used in
/// hashed collections and keys.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum NullHandling {
    /// Default mode. Presumably the host yields null for a null argument
    /// without calling the function (TODO confirm against the host adapter).
    #[default]
    PropagateNulls,
    /// Presumably the function receives null arguments and handles them itself
    /// (TODO confirm against the host adapter).
    UserHandled,
}
/// Type descriptor used for the arguments and return value of a plugin function.
#[derive(Clone, Debug)]
pub enum ArgType {
    /// A plain Arrow [`DataType`].
    Primitive(DataType),
    /// NOTE(review): presumably a dynamically typed Cypher value; the concrete
    /// representation is not visible in this file.
    CypherValue,
    /// A vector described by a length and an element type.
    Vector {
        /// Number of elements (presumably a fixed length; TODO confirm).
        len: usize,
        /// Arrow type of each element.
        element: DataType,
    },
    /// Wraps another `ArgType`; presumably denotes a variable number of
    /// arguments of that type (TODO confirm against the host adapter).
    Variadic(Box<ArgType>),
}
/// Pairs a [`FnSignature`] with a shared callable of type `F`.
///
/// `Debug` is implemented by hand rather than derived, so `F` itself does not
/// need to implement `Debug`.
pub struct RowFn<F> {
    /// Signature reported through `ScalarPluginFn::signature`.
    signature: FnSignature,
    /// The wrapped callable. Nothing in this file reads it, hence the
    /// `dead_code` allowance below.
    #[allow(
        dead_code,
        reason = "row evaluation is wired by uni-query host adapter; field held for downstream extraction"
    )]
    inner: Arc<F>,
}
impl<F> std::fmt::Debug for RowFn<F> {
    /// Formats as `RowFn { signature: .., .. }`; `inner` is omitted, so no
    /// `F: Debug` bound is required.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RowFn");
        out.field("signature", &self.signature);
        // Mark the output as non-exhaustive because `inner` is not printed.
        out.finish_non_exhaustive()
    }
}
impl<F> RowFn<F> {
    /// Wraps `f` in an [`Arc`] and stores it alongside `signature`.
    #[must_use]
    pub fn new(signature: FnSignature, f: F) -> Self {
        let inner = Arc::new(f);
        Self { signature, inner }
    }
}
impl<F> ScalarPluginFn for RowFn<F>
where
    F: Send + Sync + 'static,
{
    /// Returns the signature supplied at construction.
    fn signature(&self) -> &FnSignature {
        &self.signature
    }

    /// Always fails: this implementation never calls the wrapped callable.
    ///
    /// # Errors
    ///
    /// Unconditionally returns an [`FnError`] with code `0xDEAD` whose message
    /// says the call must be intercepted by the host adapter.
    fn invoke(&self, _args: &[ColumnarValue], _rows: usize) -> Result<ColumnarValue, FnError> {
        let message = "RowFn::invoke must be intercepted by the host adapter; \
                       see uni-query::custom_functions::register_row_fn";
        Err(FnError::new(0xDEAD, message))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `FnSignature::new` stores the given arguments and sets the
    /// null-handling mode to `PropagateNulls`.
    #[test]
    fn signature_constructor() {
        let arg_types = vec![ArgType::Primitive(DataType::Float64)];
        let return_type = ArgType::Primitive(DataType::Float64);
        let signature = FnSignature::new(arg_types, return_type, Volatility::Immutable);

        assert_eq!(signature.args.len(), 1);
        assert_eq!(signature.null_handling, NullHandling::PropagateNulls);
    }

    /// The derived `Debug` output of a `Vector` names the variant and its length.
    #[test]
    fn arg_type_variants_round_trip_in_debug() {
        let vector_type = ArgType::Vector {
            len: 384,
            element: DataType::Float32,
        };
        let rendered = format!("{vector_type:?}");

        assert!(rendered.contains("Vector"));
        assert!(rendered.contains("384"));
    }
}