use std::sync::Arc;
use datafusion::arrow::array::{Array, ArrayRef, Float32Array, Float64Array, RecordBatch};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result as DfResult;
use datafusion_expr::ColumnarValue;
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility};
use tracing::{debug, warn};
use crate::model::fao::FaoOperator;
/// DataFusion scalar UDF adapter around an [`FaoOperator`].
///
/// The wrapped operator is executed on a batch built from the call arguments each
/// time the function is invoked. Equality and hashing of this type consider only
/// `name`.
#[derive(Debug)]
pub struct FaoScalarUdf {
/// Function name reported to DataFusion; initialised from `operator.function_id()`.
name: String,
/// Operator that is executed on the argument batch for every invocation.
operator: Arc<dyn FaoOperator>,
/// Accepted argument signature; built as `VariadicAny` with `Stable` volatility.
signature: Signature,
}
impl FaoScalarUdf {
    /// Builds a UDF that delegates to `operator`.
    ///
    /// The UDF name is the operator's `function_id()`, and the signature accepts any
    /// number of arguments of any type (`VariadicAny`) with `Stable` volatility.
    pub fn new(operator: Arc<dyn FaoOperator>) -> Self {
        Self {
            // Evaluated before `operator` is moved into the struct below.
            name: operator.function_id().to_string(),
            signature: Signature::new(TypeSignature::VariadicAny, Volatility::Stable),
            operator,
        }
    }
}
impl PartialEq for FaoScalarUdf {
    /// Two UDFs compare equal when their names match; the wrapped operator and the
    /// signature are not compared.
    fn eq(&self, rhs: &Self) -> bool {
        let (left, right) = (self.name.as_str(), rhs.name.as_str());
        left == right
    }
}
// Marker impl: equality compares only the `name` strings, so it is a total
// equivalence relation.
impl Eq for FaoScalarUdf {}
impl std::hash::Hash for FaoScalarUdf {
    /// Feeds only `name` into the hasher, matching the fields compared by `PartialEq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.name, state);
    }
}
impl ScalarUDFImpl for FaoScalarUdf {
    /// Exposes `self` as [`std::any::Any`] so the concrete type can be recovered by downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Name of the function (the wrapped operator's function id).
    fn name(&self) -> &str {
        &self.name
    }

    /// Argument signature configured at construction time.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The UDF always reports `Float64`, regardless of the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> DfResult<DataType> {
        Ok(DataType::Float64)
    }

    /// Runs the wrapped operator on the call arguments and returns its first output
    /// column as a `Float64` array.
    ///
    /// Steps:
    /// 1. Scalar arguments are broadcast to the length of the first array argument
    ///    (or to a single row when every argument is a scalar).
    /// 2. Every argument is converted to `Float32` and placed in a batch whose
    ///    columns are named `feature_0`, `feature_1`, ... Null slots are preserved,
    ///    and a field is declared nullable only when its column contains nulls.
    /// 3. The operator's async `execute` is driven to completion on the current
    ///    Tokio runtime via `block_in_place`.
    /// 4. Column 0 of the result is converted to `Float64`.
    ///
    /// # Errors
    ///
    /// * `Plan` - no arguments were supplied, or an argument has an unsupported type
    ///   (only `Float32`, `Float64`, `Int64` and `Int32` are accepted).
    /// * `Internal` - the input batch could not be built, the operator returned no
    ///   columns, or it returned a different number of rows than it was given.
    /// * `Execution` - no Tokio runtime is available, the runtime is single-threaded
    ///   (where `block_in_place` would panic), or the output column cannot be cast
    ///   to `Float64`.
    /// * `External` - the operator itself failed.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult<ColumnarValue> {
        let args = &args.args;
        if args.is_empty() {
            return Err(datafusion_common::DataFusionError::Plan(format!(
                "FAO UDF '{}' requires at least one argument",
                self.name
            )));
        }

        // Length used to broadcast scalars: the first array argument decides; with
        // only scalars a single row is evaluated. This is also the row count of the
        // batch handed to the operator.
        let num_rows = args
            .iter()
            .find_map(|arg| match arg {
                ColumnarValue::Array(arr) => Some(arr.len()),
                ColumnarValue::Scalar(_) => None,
            })
            .unwrap_or(1);

        let f32_columns: Vec<ArrayRef> = args
            .iter()
            .map(|arg| -> DfResult<ArrayRef> {
                let arr = match arg {
                    ColumnarValue::Array(arr) => Arc::clone(arr),
                    ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows)?,
                };
                fao_input_to_f32(&self.name, &arr)
            })
            .collect::<DfResult<Vec<_>>>()?;

        // Fields stay non-nullable for null-free columns; a column that actually
        // carries nulls is declared nullable so that batch construction does not
        // reject it and the operator can see which slots are missing.
        let input_fields: Vec<Field> = f32_columns
            .iter()
            .enumerate()
            .map(|(i, col)| {
                Field::new(format!("feature_{i}"), DataType::Float32, col.null_count() > 0)
            })
            .collect();
        let input_schema = Arc::new(Schema::new(input_fields));

        let input_batch = RecordBatch::try_new(input_schema, f32_columns).map_err(|e| {
            datafusion_common::DataFusionError::Internal(format!(
                "failed to build input batch for FAO '{}': {e}",
                self.name
            ))
        })?;

        debug!(
            fao = %self.name,
            rows = num_rows,
            "invoking FAO UDF inline"
        );

        // `Handle::current()` panics outside a runtime and `block_in_place` panics on
        // a current-thread runtime; surface both situations as query errors instead.
        let handle = tokio::runtime::Handle::try_current().map_err(|e| {
            datafusion_common::DataFusionError::Execution(format!(
                "FAO UDF '{}' must be invoked from within a Tokio runtime: {e}",
                self.name
            ))
        })?;
        if matches!(
            handle.runtime_flavor(),
            tokio::runtime::RuntimeFlavor::CurrentThread
        ) {
            return Err(datafusion_common::DataFusionError::Execution(format!(
                "FAO UDF '{}' requires a multi-threaded Tokio runtime",
                self.name
            )));
        }

        let operator = Arc::clone(&self.operator);
        let output_batch = tokio::task::block_in_place(move || {
            handle.block_on(async move { operator.execute(input_batch).await })
        })
        .map_err(|e| datafusion_common::DataFusionError::External(Box::new(e)))?;

        if output_batch.num_columns() == 0 {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "FAO '{}' returned no columns",
                self.name
            )));
        }

        let score_col = output_batch.column(0);
        // A scalar function must produce exactly one value per input row.
        if score_col.len() != num_rows {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "FAO '{}' returned {} rows for {} input rows",
                self.name,
                score_col.len(),
                num_rows
            )));
        }

        let result = fao_output_to_f64(&self.name, score_col)?;
        Ok(ColumnarValue::Array(result))
    }
}

/// Downcasts `arr` to the concrete array type `T`, reporting `expected` on failure.
fn fao_downcast<'a, T: 'static>(arr: &'a ArrayRef, expected: &str) -> DfResult<&'a T> {
    arr.as_any().downcast_ref::<T>().ok_or_else(|| {
        datafusion_common::DataFusionError::Internal(format!("expected {expected}"))
    })
}

/// Converts one UDF argument column to `Float32`, keeping its null bitmap.
///
/// `Float32` columns are shared as-is; `Float64`, `Int64` and `Int32` columns are
/// converted element-wise with `as f32`. Any other type is rejected with a `Plan` error.
fn fao_input_to_f32(udf_name: &str, arr: &ArrayRef) -> DfResult<ArrayRef> {
    use datafusion::arrow::array::{Int32Array, Int64Array};

    // `unary` maps the value buffer and carries the validity bitmap over unchanged.
    let converted: Float32Array = match arr.data_type() {
        DataType::Float32 => return Ok(Arc::clone(arr)),
        DataType::Float64 => {
            fao_downcast::<Float64Array>(arr, "Float64Array")?.unary(|v| v as f32)
        }
        DataType::Int64 => fao_downcast::<Int64Array>(arr, "Int64Array")?.unary(|v| v as f32),
        DataType::Int32 => fao_downcast::<Int32Array>(arr, "Int32Array")?.unary(|v| v as f32),
        other => {
            return Err(datafusion_common::DataFusionError::Plan(format!(
                "FAO UDF '{}': unsupported input type {:?} — expected numeric",
                udf_name, other
            )))
        }
    };
    Ok(Arc::new(converted))
}

/// Converts the operator's score column to `Float64`, the type the UDF declares.
///
/// `Float64` is returned unchanged and `Float32` is widened with nulls preserved.
/// Any other type is logged and run through the Arrow cast kernel; a column that
/// cannot be cast yields an `Execution` error rather than a mistyped result.
fn fao_output_to_f64(udf_name: &str, column: &ArrayRef) -> DfResult<ArrayRef> {
    match column.data_type() {
        DataType::Float64 => Ok(Arc::clone(column)),
        DataType::Float32 => {
            let widened: Float64Array =
                fao_downcast::<Float32Array>(column, "Float32Array")?.unary(|v| f64::from(v));
            Ok(Arc::new(widened))
        }
        other => {
            warn!(
                fao = %udf_name,
                output_type = ?other,
                "unexpected FAO output type, casting to Float64"
            );
            datafusion::arrow::compute::cast(column.as_ref(), &DataType::Float64).map_err(|e| {
                datafusion_common::DataFusionError::Execution(format!(
                    "FAO '{}' output of type {:?} cannot be cast to Float64: {e}",
                    udf_name, other
                ))
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the signature shape used by the UDF and that a `FaoRef` keeps the
    /// function id it was built with. No `FaoScalarUdf` is constructed here.
    #[test]
    fn udf_creation() {
        use crate::model::fao::FaoRef;

        const UDF_NAME: &str = "test_model";

        let signature = Signature::new(TypeSignature::VariadicAny, Volatility::Stable);
        assert_eq!(signature.type_signature, TypeSignature::VariadicAny);

        let descriptor = FaoRef {
            function_id: String::from(UDF_NAME),
            version: String::from("1.0.0"),
            model_id: String::from("mock"),
            est_latency_ms: 1.0,
            est_accuracy: 0.95,
        };
        assert_eq!(descriptor.function_id, UDF_NAME);
    }
}