#![cfg(feature = "rhai-runtime")]
use std::sync::Arc;
use arrow_array::{ArrayRef, BinaryArray, LargeBinaryArray};
use arrow_schema::{DataType, Field};
use datafusion::scalar::ScalarValue;
use rhai::{Dynamic, Scope};
use smol_str::SmolStr;
use uni_plugin::errors::FnError;
use uni_plugin::traits::aggregate::{AggSignature, AggregatePluginFn, PluginAccumulator};
use uni_plugin::traits::scalar::ArgType;
use crate::dynamic_bridge::array_row_to_dynamic;
use crate::runtime::RhaiPluginRuntime;
/// An aggregate function whose logic is implemented by a Rhai script.
///
/// The script functions are resolved by name from `name`:
/// `<name>_init`, `<name>_accumulate`, `<name>_merge` and `<name>_finalize`.
#[derive(Debug)]
pub struct RhaiAggregateFn {
// Engine plus compiled AST shared with every accumulator this creates.
runtime: Arc<RhaiPluginRuntime>,
// Prefix used to build the script function names listed above.
name: SmolStr,
// Declared argument / return / state types of the aggregate.
signature: AggSignature,
}
impl RhaiAggregateFn {
    /// Creates an aggregate that runs against `runtime`, whose script
    /// functions are prefixed with `name`, and whose types are described by
    /// `signature`.
    #[must_use]
    pub fn new(
        runtime: Arc<RhaiPluginRuntime>,
        name: impl Into<SmolStr>,
        signature: AggSignature,
    ) -> Self {
        let name: SmolStr = name.into();
        Self {
            signature,
            name,
            runtime,
        }
    }
}
impl AggregatePluginFn for RhaiAggregateFn {
    fn signature(&self) -> &AggSignature {
        &self.signature
    }

    /// Creates a fresh accumulator whose initial state is the return value of
    /// the script's `<name>_init()`.
    ///
    /// This method cannot return an error. If the init call fails, the
    /// accumulator starts with state `()` and stores the failure (code
    /// `0x723`) in its `init_error` field. That error is surfaced by the
    /// accumulator's `check_init` on later use.
    fn create_accumulator(&self) -> Box<dyn PluginAccumulator> {
        let init_name = format!("{}_init", self.name);
        let mut init_scope = Scope::new();
        let mut init_error = None;
        let state = self
            .runtime
            .engine
            .call_fn::<Dynamic>(&mut init_scope, &self.runtime.ast, &init_name, ())
            .unwrap_or_else(|e| {
                init_error = Some(FnError::new(
                    0x723,
                    format!("Rhai aggregate `{}` init failed: {e}", self.name),
                ));
                Dynamic::UNIT
            });
        let accumulator = RhaiAccumulator {
            runtime: Arc::clone(&self.runtime),
            name: self.name.clone(),
            state,
            input_types: self.signature.args.clone(),
            init_error,
        };
        Box::new(accumulator)
    }
}
/// Accumulator created by [`RhaiAggregateFn::create_accumulator`]; it threads
/// a script-defined state value through the Rhai aggregate functions.
pub struct RhaiAccumulator {
// Engine plus compiled AST used for every script call.
runtime: Arc<RhaiPluginRuntime>,
// Aggregate name; prefix of `<name>_accumulate` / `_merge` / `_finalize`.
name: SmolStr,
// Current state: passed into each script call and replaced by its result.
state: Dynamic,
// Declared argument types, used to convert input array cells to Rhai values.
input_types: Vec<ArgType>,
// Failure from `<name>_init`, kept so it can be reported on first use.
init_error: Option<FnError>,
}
impl std::fmt::Debug for RhaiAccumulator {
    /// Shows only the aggregate name; the runtime, state and other fields are
    /// elided and rendered as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RhaiAccumulator");
        out.field("name", &self.name);
        out.finish_non_exhaustive()
    }
}
impl RhaiAccumulator {
    /// Returns a clone of the stored init error if `<name>_init` failed when
    /// this accumulator was created, and `Ok(())` otherwise.
    fn check_init(&self) -> Result<(), FnError> {
        self.init_error
            .as_ref()
            .map_or(Ok(()), |err| Err(err.clone()))
    }
}
impl PluginAccumulator for RhaiAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<(), FnError> {
self.check_init()?;
let accumulate_fn = format!("{}_accumulate", self.name);
let n = values.first().map(|a| a.len()).unwrap_or(0);
for row in 0..n {
let mut dyn_args: Vec<Dynamic> = Vec::with_capacity(values.len() + 1);
dyn_args.push(self.state.clone());
for (i, arr) in values.iter().enumerate() {
let dt = primitive_datatype(&self.input_types, i)?;
let d = array_row_to_dynamic(arr, row, &dt)
.map_err(|e| FnError::new(0x12, e.to_string()))?;
dyn_args.push(d);
}
let mut scope = Scope::new();
let new_state = self
.runtime
.engine
.call_fn::<Dynamic>(&mut scope, &self.runtime.ast, &accumulate_fn, dyn_args)
.map_err(|e| FnError::new(0x720, format!("Rhai accumulate: {e}")))?;
self.state = new_state;
}
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<(), FnError> {
self.check_init()?;
let merge_fn = format!("{}_merge", self.name);
let Some(state_arr) = states.first() else {
return Ok(());
};
let n = state_arr.len();
for row in 0..n {
let bytes = peer_state_bytes(state_arr, row)?;
let peer_state = decode_state(&bytes)?;
let mut scope = Scope::new();
let new_state = self
.runtime
.engine
.call_fn::<Dynamic>(
&mut scope,
&self.runtime.ast,
&merge_fn,
(self.state.clone(), peer_state),
)
.map_err(|e| FnError::new(0x721, format!("Rhai merge: {e}")))?;
self.state = new_state;
}
Ok(())
}
fn state(&self) -> Result<Vec<ScalarValue>, FnError> {
self.check_init()?;
let bytes = encode_state(&self.state)?;
Ok(vec![ScalarValue::LargeBinary(Some(bytes))])
}
fn evaluate(&self) -> Result<ScalarValue, FnError> {
self.check_init()?;
let finalize_fn = format!("{}_finalize", self.name);
let mut scope = Scope::new();
let result = self
.runtime
.engine
.call_fn::<Dynamic>(
&mut scope,
&self.runtime.ast,
&finalize_fn,
(self.state.clone(),),
)
.map_err(|e| FnError::new(0x722, format!("Rhai finalize: {e}")))?;
dynamic_to_scalar_loose(result)
}
fn size(&self) -> usize {
std::mem::size_of::<Self>() + 64
}
}
/// Returns the Arrow type declared for argument `i`, which must be a
/// primitive.
///
/// # Errors
///
/// `0x10` when `args` has no entry at index `i`, or when that entry is not
/// an `ArgType::Primitive`.
fn primitive_datatype(args: &[ArgType], i: usize) -> Result<DataType, FnError> {
    let Some(declared) = args.get(i) else {
        return Err(FnError::new(0x10, format!("missing arg type {i}")));
    };
    if let ArgType::Primitive(dt) = declared {
        return Ok(dt.clone());
    }
    Err(FnError::new(
        0x10,
        format!("Rhai aggregate arg {i}: primitives only, got {declared:?}"),
    ))
}
/// Copies the serialized peer state at `row` of a state column.
///
/// A null row yields an empty buffer. The column may be `LargeBinary` (the
/// type declared by `rhai_state_fields`) or `Binary`.
///
/// # Errors
///
/// `0x12` if the column is neither `LargeBinary` nor `Binary`.
fn peer_state_bytes(arr: &ArrayRef, row: usize) -> Result<Vec<u8>, FnError> {
    if arr.is_null(row) {
        return Ok(Vec::new());
    }
    let any = arr.as_any();
    let slice = if let Some(large) = any.downcast_ref::<LargeBinaryArray>() {
        large.value(row)
    } else if let Some(small) = any.downcast_ref::<BinaryArray>() {
        small.value(row)
    } else {
        return Err(FnError::new(
            0x12,
            format!(
                "Rhai aggregate merge: expected Binary/LargeBinary state column, got {:?}",
                arr.data_type()
            ),
        ));
    };
    Ok(slice.to_vec())
}
/// Serializes an accumulator state to JSON bytes.
///
/// # Errors
///
/// `0x13` if `serde_json` fails to serialize the value.
fn encode_state(state: &Dynamic) -> Result<Vec<u8>, FnError> {
    match serde_json::to_vec(state) {
        Ok(encoded) => Ok(encoded),
        Err(e) => Err(FnError::new(0x13, format!("Rhai state encode: {e}"))),
    }
}
/// Decodes state bytes written by `encode_state` back into a Rhai value.
///
/// An empty buffer (what `peer_state_bytes` yields for a null row) decodes
/// to `()`.
///
/// # Errors
///
/// `0x13` if the bytes are not valid JSON, or if the JSON contains a value
/// that `serde_json_to_dynamic` rejects.
fn decode_state(bytes: &[u8]) -> Result<Dynamic, FnError> {
    if bytes.is_empty() {
        return Ok(Dynamic::UNIT);
    }
    serde_json::from_slice::<serde_json::Value>(bytes)
        .map_err(|e| FnError::new(0x13, format!("Rhai state decode: {e}")))
        .and_then(|parsed| {
            serde_json_to_dynamic(&parsed)
                .map_err(|e| FnError::new(0x13, format!("Rhai state value: {e}")))
        })
}
/// Converts a parsed JSON value into the equivalent Rhai value.
///
/// The mapping is:
/// - `null` becomes `()`;
/// - booleans and strings map directly;
/// - numbers that fit in `i64` become Rhai integers, and other numbers become
///   floats;
/// - arrays and objects are converted recursively into Rhai arrays and maps.
///
/// # Errors
///
/// Returns a message if a number can be read as neither `i64` nor `f64`.
pub fn serde_json_to_dynamic(v: &serde_json::Value) -> Result<Dynamic, String> {
    use serde_json::Value as J;
    let converted = match v {
        J::Null => Dynamic::UNIT,
        J::Bool(flag) => Dynamic::from(*flag),
        J::String(text) => Dynamic::from(text.clone()),
        J::Number(num) => num
            .as_i64()
            .map(Dynamic::from)
            .or_else(|| num.as_f64().map(Dynamic::from))
            .ok_or_else(|| format!("unrepresentable number: {num}"))?,
        J::Array(items) => Dynamic::from(
            items
                .iter()
                .map(serde_json_to_dynamic)
                .collect::<Result<rhai::Array, String>>()?,
        ),
        J::Object(entries) => {
            let mut map = rhai::Map::new();
            for (key, item) in entries {
                map.insert(key.as_str().into(), serde_json_to_dynamic(item)?);
            }
            Dynamic::from(map)
        }
    };
    Ok(converted)
}
/// Converts a Rhai finalize result into a [`ScalarValue`], choosing the
/// scalar type from the value's runtime type.
///
/// The mapping is:
/// - `()` becomes an untyped `ScalarValue::Null`;
/// - `bool` becomes `Boolean`, an integer becomes `Int64`, and a float
///   becomes `Float64`;
/// - a string or `char` becomes `Utf8`;
/// - anything else (arrays, maps, ...) becomes its JSON text as `LargeUtf8`.
///
/// The conversion is "loose": it does not consult the aggregate's declared
/// return type.
///
/// # Errors
///
/// `0x13` if a compound value cannot be serialized to JSON.
fn dynamic_to_scalar_loose(d: Dynamic) -> Result<ScalarValue, FnError> {
    if d.is_unit() {
        return Ok(ScalarValue::Null);
    }
    if let Ok(b) = d.as_bool() {
        return Ok(ScalarValue::Boolean(Some(b)));
    }
    if let Ok(i) = d.as_int() {
        return Ok(ScalarValue::Int64(Some(i)));
    }
    if let Ok(f) = d.as_float() {
        return Ok(ScalarValue::Float64(Some(f)));
    }
    // A char is a one-character string to SQL. Without this arm it fell
    // through to JSON encoding and surfaced as the quoted text `"c"`.
    if let Ok(c) = d.as_char() {
        return Ok(ScalarValue::Utf8(Some(c.to_string())));
    }
    // `into_string` consumes its receiver, so a clone is needed. Only pay for
    // it when the value is a string (a cheap ref-count bump). Cloning an array
    // or map here would deep-copy it just to learn it is not a string.
    if d.is_string() {
        if let Ok(s) = d.clone().into_string() {
            return Ok(ScalarValue::Utf8(Some(s)));
        }
    }
    let json = serde_json::to_string(&d).map_err(|e| FnError::new(0x13, e.to_string()))?;
    Ok(ScalarValue::LargeUtf8(Some(json)))
}
/// Intermediate-state schema shared by every Rhai aggregate.
///
/// It is a single nullable `LargeBinary` column named `rhai_state`, which
/// carries the JSON-encoded script state.
#[must_use]
pub fn rhai_state_fields() -> Vec<Field> {
    let state_column = Field::new("rhai_state", DataType::LargeBinary, true);
    vec![state_column]
}
/// Builds an [`AggSignature`] from textual type names.
///
/// A `returns` value of `map`, `object` or `any` (case-insensitive,
/// surrounding whitespace ignored) maps to `LargeUtf8`, because non-scalar
/// finalize results are surfaced as JSON text. Every signature uses the
/// shared `rhai_state_fields` schema and declares partial-aggregation
/// support.
///
/// # Errors
///
/// Propagates the first error from `type_name_to_argtype`. Arguments are
/// checked before the return type.
pub fn build_agg_signature(
    args: &[String],
    returns: &str,
    determinism: &str,
) -> Result<AggSignature, crate::error::RhaiError> {
    use crate::wire_translate::{determinism_to_volatility, type_name_to_argtype};
    let mut arg_types: Vec<ArgType> = Vec::with_capacity(args.len());
    for type_name in args {
        arg_types.push(type_name_to_argtype(type_name)?);
    }
    let normalized = returns.trim().to_ascii_lowercase();
    let returns_json_text = matches!(normalized.as_str(), "map" | "object" | "any");
    let return_type = if returns_json_text {
        ArgType::Primitive(DataType::LargeUtf8)
    } else {
        type_name_to_argtype(returns)?
    };
    Ok(AggSignature {
        args: arg_types,
        returns: return_type,
        state_fields: rhai_state_fields(),
        volatility: determinism_to_volatility(determinism),
        supports_partial: true,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::build_engine;
    use crate::host_fns::RhaiHostFnRegistry;
    use crate::manifest::compile;
    use arrow_array::Float64Array;
    use datafusion::logical_expr::Volatility;
    use uni_plugin::{CapabilitySet, PluginId};

    /// Compiles `script` into a runtime with no capabilities and no host fns.
    fn build_runtime(script: &str) -> Arc<RhaiPluginRuntime> {
        let engine = build_engine(&CapabilitySet::new(), &RhaiHostFnRegistry::new());
        let ast = compile(&engine, script).unwrap();
        RhaiPluginRuntime::new(PluginId::new("test.agg"), engine, ast)
    }

    /// Signature of a one-`Float64`-argument aggregate returning `Float64`.
    fn float64_unary_sig() -> AggSignature {
        AggSignature {
            args: vec![ArgType::Primitive(DataType::Float64)],
            returns: ArgType::Primitive(DataType::Float64),
            state_fields: rhai_state_fields(),
            volatility: Volatility::Immutable,
            supports_partial: true,
        }
    }

    /// Asserts `result` is a non-null `Float64` within 1e-9 of `expected`.
    fn assert_f64_near(result: ScalarValue, expected: f64) {
        match result {
            ScalarValue::Float64(Some(v)) => assert!((v - expected).abs() < 1e-9),
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[test]
    fn stats_aggregate_round_trips() {
        let script = r#"
fn stats_init() {
#{ n: 0, sum: 0.0, sum_sq: 0.0 }
}
fn stats_accumulate(state, x) {
state.n += 1;
state.sum += x;
state.sum_sq += x * x;
state
}
fn stats_merge(a, b) {
#{ n: a.n + b.n, sum: a.sum + b.sum, sum_sq: a.sum_sq + b.sum_sq }
}
fn stats_finalize(s) {
if s.n == 0 { return (); }
s.sum / s.n
}
"#;
        let agg = RhaiAggregateFn::new(build_runtime(script), "stats", float64_unary_sig());
        let mut acc = agg.create_accumulator();
        let xs: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]));
        acc.update_batch(&[xs]).unwrap();
        assert_f64_near(acc.evaluate().unwrap(), 2.5);
    }

    #[test]
    fn state_serializes_and_merges() {
        let script = r#"
fn sum_init() { 0.0 }
fn sum_accumulate(state, x) { state + x }
fn sum_merge(a, b) { a + b }
fn sum_finalize(s) { s }
"#;
        let agg = RhaiAggregateFn::new(build_runtime(script), "sum", float64_unary_sig());

        // First partial: 1 + 2 + 3.
        let mut first = agg.create_accumulator();
        let left: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
        first.update_batch(&[left]).unwrap();
        let exported = first.state().unwrap();
        let state_bytes = match &exported[0] {
            ScalarValue::LargeBinary(Some(b)) => b.clone(),
            other => panic!("expected LargeBinary, got {other:?}"),
        };

        // Second partial: 10 + 20, then merge the first partial's state in.
        let mut second = agg.create_accumulator();
        let right: ArrayRef = Arc::new(Float64Array::from(vec![10.0, 20.0]));
        second.update_batch(&[right]).unwrap();
        let peer_arr: ArrayRef = Arc::new(LargeBinaryArray::from(vec![state_bytes.as_slice()]));
        second.merge_batch(&[peer_arr]).unwrap();

        assert_f64_near(second.evaluate().unwrap(), 36.0);
    }
}