use arrow_schema::{DataType, Field};
use datafusion::logical_expr::Volatility;
use uni_plugin::adapter_common::arrow_types::argtype_to_arrow;
use uni_plugin::capability::SideEffects;
use uni_plugin::traits::aggregate::AggSignature;
use uni_plugin::traits::procedure::{NamedArgType, ProcedureMode, ProcedureSignature};
use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling};
use crate::error::ExtismError;
use crate::exports::{WireArgType, WireFnSignature};
/// Resolves a wire-level Arrow primitive name (for example `"int64"` or `"utf8"`)
/// into an Arrow [`DataType`].
///
/// The actual name table lives in
/// `uni_plugin::adapter_common::arrow_types::arrow_name_to_datatype`; this wrapper
/// only turns a failed lookup into this crate's error type.
///
/// # Errors
///
/// Returns [`ExtismError::ManifestInvalid`] when the name is not recognised.
pub fn arrow_name_to_datatype(name: &str) -> Result<DataType, ExtismError> {
    match uni_plugin::adapter_common::arrow_types::arrow_name_to_datatype(name) {
        Some(data_type) => Ok(data_type),
        None => Err(ExtismError::ManifestInvalid(format!(
            "unsupported arrow primitive name: `{name}`"
        ))),
    }
}
/// Converts a wire-format argument type into the host's internal [`ArgType`].
///
/// Primitive names and vector element names are resolved through
/// [`arrow_name_to_datatype`]. `Variadic` wrappers are converted recursively, so
/// arbitrarily nested wire types are handled.
///
/// # Errors
///
/// Returns [`ExtismError::ManifestInvalid`] if any Arrow name, including one nested
/// inside a `Variadic`, is unsupported.
pub fn wire_arg_to_internal(wire: &WireArgType) -> Result<ArgType, ExtismError> {
    let converted = match wire {
        WireArgType::CypherValue => ArgType::CypherValue,
        WireArgType::Primitive { arrow } => {
            let data_type = arrow_name_to_datatype(arrow)?;
            ArgType::Primitive(data_type)
        }
        WireArgType::Vector { len, element } => {
            let element = arrow_name_to_datatype(element)?;
            ArgType::Vector { len: *len, element }
        }
        WireArgType::Variadic { inner } => {
            let inner = wire_arg_to_internal(inner)?;
            ArgType::Variadic(Box::new(inner))
        }
    };
    Ok(converted)
}
/// Parses a wire volatility string into a DataFusion [`Volatility`].
///
/// The accepted spellings are exactly `"immutable"`, `"stable"` and `"volatile"`.
/// Matching is case-sensitive.
///
/// # Errors
///
/// Returns [`ExtismError::ManifestInvalid`] for any other string.
pub fn wire_volatility_to_internal(s: &str) -> Result<Volatility, ExtismError> {
    match s {
        "immutable" => Ok(Volatility::Immutable),
        "stable" => Ok(Volatility::Stable),
        "volatile" => Ok(Volatility::Volatile),
        unknown => Err(ExtismError::ManifestInvalid(format!(
            "unsupported volatility: `{unknown}`"
        ))),
    }
}
/// Parses a wire null-handling string into the internal [`NullHandling`] policy.
///
/// `"propagate"` maps to [`NullHandling::PropagateNulls`] and `"user_handled"`
/// maps to [`NullHandling::UserHandled`]. Matching is case-sensitive.
///
/// # Errors
///
/// Returns [`ExtismError::ManifestInvalid`] for any other string.
pub fn wire_null_handling_to_internal(s: &str) -> Result<NullHandling, ExtismError> {
    match s {
        "propagate" => Ok(NullHandling::PropagateNulls),
        "user_handled" => Ok(NullHandling::UserHandled),
        unknown => Err(ExtismError::ManifestInvalid(format!(
            "unsupported null_handling: `{unknown}`"
        ))),
    }
}
/// Converts a complete wire scalar-function signature into an internal [`FnSignature`].
///
/// The parts are translated in this order: argument types, return type, volatility,
/// then null handling. The first failure is returned as-is.
///
/// # Errors
///
/// Returns [`ExtismError::ManifestInvalid`] in either of these cases:
/// - an argument or return type names an unsupported Arrow type;
/// - the volatility or null-handling string is not recognised.
pub fn wire_fn_sig_to_internal(wire: &WireFnSignature) -> Result<FnSignature, ExtismError> {
    let mut args: Vec<ArgType> = Vec::with_capacity(wire.args.len());
    for wire_arg in &wire.args {
        args.push(wire_arg_to_internal(wire_arg)?);
    }

    let returns = wire_arg_to_internal(&wire.returns)?;
    let volatility = wire_volatility_to_internal(&wire.volatility)?;
    let null_handling = wire_null_handling_to_internal(&wire.null_handling)?;

    Ok(FnSignature {
        args,
        returns,
        volatility,
        null_handling,
    })
}
/// Builds the Arrow [`Field`] that holds an aggregate's intermediate state.
///
/// Only `WireArgType::Primitive` is accepted as a state type. The resulting field
/// is always named `"state"` and is nullable.
///
/// # Errors
///
/// Returns [`ExtismError::ManifestInvalid`] in either of these cases:
/// - the state type is not a primitive;
/// - the primitive's Arrow name is unsupported.
pub fn wire_state_to_field(wire: &WireArgType) -> Result<Field, ExtismError> {
    if let WireArgType::Primitive { arrow } = wire {
        let data_type = arrow_name_to_datatype(arrow)?;
        return Ok(Field::new("state", data_type, true));
    }
    Err(ExtismError::ManifestInvalid(format!(
        "aggregate state must be a Primitive Arrow type; got: {wire:?}"
    )))
}
/// Converts a wire aggregate-function signature, plus its declared state type, into
/// an internal [`AggSignature`].
///
/// The parts are translated in this order: argument types, return type, volatility,
/// then state. The state becomes a single field built by [`wire_state_to_field`].
/// The result always sets `supports_partial: true`.
///
/// NOTE(review): `wire_sig.null_handling` is not read or validated here, because
/// `AggSignature` has no slot for it. Confirm that this is intended.
///
/// # Errors
///
/// Returns [`ExtismError::ManifestInvalid`] in any of these cases:
/// - an argument, return or state type is unsupported;
/// - the state type is not a primitive;
/// - the volatility string is not recognised.
pub fn wire_agg_sig_to_internal(
    wire_sig: &WireFnSignature,
    wire_state: &WireArgType,
) -> Result<AggSignature, ExtismError> {
    let mut args: Vec<ArgType> = Vec::with_capacity(wire_sig.args.len());
    for wire_arg in &wire_sig.args {
        args.push(wire_arg_to_internal(wire_arg)?);
    }

    let returns = wire_arg_to_internal(&wire_sig.returns)?;
    let volatility = wire_volatility_to_internal(&wire_sig.volatility)?;
    let state_field = wire_state_to_field(wire_state)?;

    Ok(AggSignature {
        args,
        returns,
        state_fields: vec![state_field],
        volatility,
        supports_partial: true,
    })
}
/// Parses a wire procedure-mode string into a [`ProcedureMode`].
///
/// The accepted spellings are `"read"`, `"write"`, `"schema"` and `"dbms"`.
/// Matching is case-sensitive.
///
/// # Errors
///
/// Returns [`ExtismError::ManifestInvalid`] for any other string.
pub fn wire_proc_mode_to_internal(s: &str) -> Result<ProcedureMode, ExtismError> {
    match s {
        "read" => Ok(ProcedureMode::Read),
        "write" => Ok(ProcedureMode::Write),
        "schema" => Ok(ProcedureMode::Schema),
        "dbms" => Ok(ProcedureMode::Dbms),
        unknown => Err(ExtismError::ManifestInvalid(format!(
            "unsupported procedure mode: `{unknown}`"
        ))),
    }
}
/// Builds an internal [`ProcedureSignature`] from a procedure's wire argument types,
/// wire yield types and mode string.
///
/// The wire types carry no names, so positional names are synthesised:
/// - arguments are named `arg0`, `arg1`, …;
/// - yield columns are named `yield0`, `yield1`, ….
///
/// Arguments get no default and an empty doc string. Each yield column's Arrow type
/// comes from `argtype_to_arrow`, and every yield column is nullable.
///
/// The following fields are filled with fixed values:
/// - side effects: `SideEffects::default()`;
/// - retry contract: `None`;
/// - batch input: `None`;
/// - docs: empty.
///
/// Translation order is all arguments, then all yields, then the mode. The first
/// failure is returned.
///
/// # Errors
///
/// Returns [`ExtismError::ManifestInvalid`] in either of these cases:
/// - any argument or yield type is unsupported;
/// - the mode string is not recognised.
pub fn wire_proc_sig_to_internal(
    args: &[WireArgType],
    yields: &[WireArgType],
    mode: &str,
) -> Result<ProcedureSignature, ExtismError> {
    let mut named_args: Vec<NamedArgType> = Vec::with_capacity(args.len());
    for (position, wire_arg) in args.iter().enumerate() {
        let ty = wire_arg_to_internal(wire_arg)?;
        named_args.push(NamedArgType {
            name: format!("arg{position}").into(),
            ty,
            default: None,
            doc: String::new(),
        });
    }

    let mut yield_fields: Vec<Field> = Vec::with_capacity(yields.len());
    for (position, wire_yield) in yields.iter().enumerate() {
        let ty = wire_arg_to_internal(wire_yield)?;
        yield_fields.push(Field::new(
            format!("yield{position}"),
            argtype_to_arrow(&ty),
            true,
        ));
    }

    let parsed_mode = wire_proc_mode_to_internal(mode)?;

    Ok(ProcedureSignature {
        args: named_args,
        yields: yield_fields,
        mode: parsed_mode,
        side_effects: SideEffects::default(),
        retry_contract: None,
        batch_input: None,
        docs: String::new(),
    })
}
#[cfg(test)]
mod tests {
    //! Unit tests for the wire → internal translation helpers.
    //!
    //! Every public translation function has at least one happy-path test and one
    //! rejection test.
    use super::*;

    #[test]
    fn translates_primitive_names() {
        assert_eq!(arrow_name_to_datatype("int64").unwrap(), DataType::Int64);
        assert_eq!(
            arrow_name_to_datatype("float64").unwrap(),
            DataType::Float64
        );
        assert_eq!(arrow_name_to_datatype("utf8").unwrap(), DataType::Utf8);
    }

    #[test]
    fn rejects_unknown_primitive_name() {
        let err = arrow_name_to_datatype("super_int").unwrap_err();
        assert!(matches!(err, ExtismError::ManifestInvalid(_)));
    }

    #[test]
    fn translates_argtype_variants() {
        let cv = wire_arg_to_internal(&WireArgType::CypherValue).unwrap();
        assert!(matches!(cv, ArgType::CypherValue));
        let p = wire_arg_to_internal(&WireArgType::Primitive {
            arrow: "float64".to_owned(),
        })
        .unwrap();
        assert!(matches!(p, ArgType::Primitive(DataType::Float64)));
        let v = wire_arg_to_internal(&WireArgType::Vector {
            len: 128,
            element: "float32".to_owned(),
        })
        .unwrap();
        match v {
            ArgType::Vector { len, element } => {
                assert_eq!(len, 128);
                assert_eq!(element, DataType::Float32);
            }
            _ => unreachable!(),
        }
        let var = wire_arg_to_internal(&WireArgType::Variadic {
            inner: Box::new(WireArgType::Primitive {
                arrow: "int64".to_owned(),
            }),
        })
        .unwrap();
        match var {
            ArgType::Variadic(inner) => {
                assert!(matches!(*inner, ArgType::Primitive(DataType::Int64)));
            }
            _ => unreachable!(),
        }
    }

    #[test]
    fn translates_volatility() {
        assert!(matches!(
            wire_volatility_to_internal("immutable").unwrap(),
            Volatility::Immutable
        ));
        assert!(matches!(
            wire_volatility_to_internal("stable").unwrap(),
            Volatility::Stable
        ));
        assert!(matches!(
            wire_volatility_to_internal("volatile").unwrap(),
            Volatility::Volatile
        ));
        assert!(wire_volatility_to_internal("immortal").is_err());
    }

    #[test]
    fn translates_null_handling() {
        assert!(matches!(
            wire_null_handling_to_internal("propagate").unwrap(),
            NullHandling::PropagateNulls
        ));
        assert!(matches!(
            wire_null_handling_to_internal("user_handled").unwrap(),
            NullHandling::UserHandled
        ));
        assert!(wire_null_handling_to_internal("zombies").is_err());
    }

    #[test]
    fn translates_full_signature() {
        let wire = WireFnSignature {
            args: vec![
                WireArgType::Primitive {
                    arrow: "float64".to_owned(),
                },
                WireArgType::Primitive {
                    arrow: "float64".to_owned(),
                },
            ],
            returns: WireArgType::Primitive {
                arrow: "float64".to_owned(),
            },
            volatility: "immutable".to_owned(),
            null_handling: "propagate".to_owned(),
        };
        let sig = wire_fn_sig_to_internal(&wire).unwrap();
        assert_eq!(sig.args.len(), 2);
        assert!(matches!(sig.returns, ArgType::Primitive(DataType::Float64)));
        assert!(matches!(sig.volatility, Volatility::Immutable));
        assert!(matches!(sig.null_handling, NullHandling::PropagateNulls));
    }

    #[test]
    fn translates_procedure_modes() {
        assert!(matches!(
            wire_proc_mode_to_internal("read"),
            Ok(ProcedureMode::Read)
        ));
        assert!(matches!(
            wire_proc_mode_to_internal("write"),
            Ok(ProcedureMode::Write)
        ));
        assert!(matches!(
            wire_proc_mode_to_internal("schema"),
            Ok(ProcedureMode::Schema)
        ));
        assert!(matches!(
            wire_proc_mode_to_internal("dbms"),
            Ok(ProcedureMode::Dbms)
        ));
        // Matching is case-sensitive and exhaustive over the four spellings.
        assert!(matches!(
            wire_proc_mode_to_internal("READ"),
            Err(ExtismError::ManifestInvalid(_))
        ));
        assert!(matches!(
            wire_proc_mode_to_internal("admin"),
            Err(ExtismError::ManifestInvalid(_))
        ));
    }

    #[test]
    fn state_field_requires_primitive() {
        let field = wire_state_to_field(&WireArgType::Primitive {
            arrow: "int64".to_owned(),
        })
        .unwrap();
        assert_eq!(field.name(), "state");
        assert_eq!(field.data_type(), &DataType::Int64);
        assert!(field.is_nullable());

        assert!(matches!(
            wire_state_to_field(&WireArgType::CypherValue),
            Err(ExtismError::ManifestInvalid(_))
        ));
        assert!(matches!(
            wire_state_to_field(&WireArgType::Primitive {
                arrow: "super_int".to_owned(),
            }),
            Err(ExtismError::ManifestInvalid(_))
        ));
    }

    #[test]
    fn translates_aggregate_signature() {
        let wire_sig = WireFnSignature {
            args: vec![WireArgType::Primitive {
                arrow: "float64".to_owned(),
            }],
            returns: WireArgType::Primitive {
                arrow: "float64".to_owned(),
            },
            volatility: "immutable".to_owned(),
            null_handling: "propagate".to_owned(),
        };
        let state = WireArgType::Primitive {
            arrow: "float64".to_owned(),
        };
        let sig = wire_agg_sig_to_internal(&wire_sig, &state).unwrap();
        assert_eq!(sig.args.len(), 1);
        assert!(matches!(sig.returns, ArgType::Primitive(DataType::Float64)));
        assert!(matches!(sig.volatility, Volatility::Immutable));
        assert_eq!(sig.state_fields.len(), 1);
        assert_eq!(sig.state_fields[0].name(), "state");
        assert_eq!(sig.state_fields[0].data_type(), &DataType::Float64);
        assert!(sig.supports_partial);

        // A non-primitive state type must be rejected even when the rest is valid.
        assert!(matches!(
            wire_agg_sig_to_internal(&wire_sig, &WireArgType::CypherValue),
            Err(ExtismError::ManifestInvalid(_))
        ));
    }

    #[test]
    fn translates_procedure_signature() {
        let args = [
            WireArgType::Primitive {
                arrow: "int64".to_owned(),
            },
            WireArgType::CypherValue,
        ];
        let yields = [WireArgType::Primitive {
            arrow: "utf8".to_owned(),
        }];
        let sig = wire_proc_sig_to_internal(&args, &yields, "read").unwrap();
        assert_eq!(sig.args.len(), 2);
        assert!(matches!(sig.args[0].ty, ArgType::Primitive(DataType::Int64)));
        assert!(matches!(sig.args[1].ty, ArgType::CypherValue));
        assert_eq!(sig.yields.len(), 1);
        assert_eq!(sig.yields[0].name(), "yield0");
        assert!(sig.yields[0].is_nullable());
        assert!(matches!(sig.mode, ProcedureMode::Read));

        // The mode is validated after all argument and yield types.
        assert!(matches!(
            wire_proc_sig_to_internal(&args, &yields, "bogus"),
            Err(ExtismError::ManifestInvalid(_))
        ));
        let bad_yields = [WireArgType::Primitive {
            arrow: "super_int".to_owned(),
        }];
        assert!(matches!(
            wire_proc_sig_to_internal(&args, &bad_yields, "read"),
            Err(ExtismError::ManifestInvalid(_))
        ));
    }
}