#![cfg(feature = "rhai-runtime")]
use std::sync::Arc;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::{DataType, Schema, SchemaRef};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::logical_expr::ColumnarValue;
use rhai::{Dynamic, Map, Scope};
use smol_str::SmolStr;
use uni_plugin::adapter_common::batch_builder::batch_into_stream;
use uni_plugin::errors::FnError;
use uni_plugin::traits::procedure::{ProcedureContext, ProcedurePlugin, ProcedureSignature};
use crate::dynamic_bridge::{OutBuilder, scalar_to_dynamic};
use crate::runtime::RhaiPluginRuntime;
/// Procedure plugin whose body is a named function inside a compiled Rhai script.
///
/// Each call to `invoke` runs the function `name` through `runtime` and turns the
/// returned value into a record batch shaped by `signature.yields`.
#[derive(Debug)]
pub struct RhaiProcedure {
/// Shared runtime; `invoke` reads its `engine` and `ast` fields.
runtime: Arc<RhaiPluginRuntime>,
/// Name of the Rhai function that is called on every invocation.
name: SmolStr,
/// Signature handed back by `signature()`; its `yields` define the output columns.
signature: ProcedureSignature,
}
impl RhaiProcedure {
    /// Builds a procedure that will call the Rhai function `name` via `runtime`.
    ///
    /// * `runtime` - shared runtime holding the engine and compiled script.
    /// * `name` - name of the script function; converted into a `SmolStr` once here.
    /// * `signature` - stored unchanged and returned by `ProcedurePlugin::signature`.
    ///
    /// No validation happens at construction time: a function name that does not
    /// exist in the script only surfaces as an error when `invoke` calls it.
    #[must_use]
    pub fn new(
        runtime: Arc<RhaiPluginRuntime>,
        name: impl Into<SmolStr>,
        signature: ProcedureSignature,
    ) -> Self {
        let name: SmolStr = name.into();
        Self {
            runtime,
            name,
            signature,
        }
    }
}
impl ProcedurePlugin for RhaiProcedure {
fn signature(&self) -> &ProcedureSignature {
&self.signature
}
fn invoke(
&self,
_ctx: ProcedureContext<'_>,
args: &[ColumnarValue],
) -> Result<SendableRecordBatchStream, FnError> {
let mut dyn_args: Vec<Dynamic> = Vec::with_capacity(args.len());
for (i, arg) in args.iter().enumerate() {
match arg {
ColumnarValue::Scalar(s) => {
let d = scalar_to_dynamic(s)
.map_err(|e| FnError::new(0x12, format!("procedure arg {i}: {e}")))?;
dyn_args.push(d);
}
ColumnarValue::Array(_) => {
return Err(FnError::new(
0x10,
format!("procedure arg {i} must be a scalar"),
));
}
}
}
let mut scope = Scope::new();
let result: Dynamic = self
.runtime
.engine
.call_fn(&mut scope, &self.runtime.ast, self.name.as_str(), dyn_args)
.map_err(|e| FnError::new(0x730, format!("Rhai procedure `{}`: {e}", self.name)))?;
let yield_schema = Arc::new(Schema::new(self.signature.yields.clone()));
let batch = dynamic_to_record_batch(result, &yield_schema)?;
Ok(batch_into_stream(batch))
}
}
fn dynamic_to_record_batch(d: Dynamic, schema: &SchemaRef) -> Result<RecordBatch, FnError> {
let rows: rhai::Array = d.try_cast().ok_or_else(|| {
FnError::new(
0x12,
String::from("Rhai procedure must return an array of row maps"),
)
})?;
let row_count = rows.len();
let mut builders: Vec<OutBuilder> = schema
.fields()
.iter()
.map(|f| OutBuilder::new(f.data_type(), row_count))
.collect::<Result<_, _>>()
.map_err(|e| FnError::new(0x11, e.to_string()))?;
for (i, row) in rows.into_iter().enumerate() {
let m: Map = row
.try_cast()
.ok_or_else(|| FnError::new(0x12, format!("procedure row {i} must be a map")))?;
for (field_idx, field) in schema.fields().iter().enumerate() {
let key = field.name();
let value = m.get(key.as_str()).cloned().unwrap_or(Dynamic::UNIT);
let value = coerce_for(field.data_type(), value)?;
builders[field_idx]
.push(value)
.map_err(|e| FnError::new(0x14, e.to_string()))?;
}
}
let columns: Vec<ArrayRef> = builders.into_iter().map(|b| b.finish()).collect();
RecordBatch::try_new(schema.clone(), columns)
.map_err(|e| FnError::new(0x15, format!("procedure batch: {e}")))
}
fn coerce_for(target: &DataType, value: Dynamic) -> Result<Dynamic, FnError> {
if value.is_unit() {
return Ok(value);
}
match target {
DataType::Float64 => match value.as_int() {
Ok(i) => Ok(Dynamic::from(i as f64)),
Err(_) => Ok(value),
},
DataType::Int64 => match value.as_float() {
Ok(f) => Ok(Dynamic::from(f as i64)),
Err(_) => Ok(value),
},
_ => Ok(value),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::build_engine;
    use crate::host_fns::RhaiHostFnRegistry;
    use crate::manifest::compile;
    use arrow_schema::Field;
    use futures::StreamExt;
    use uni_plugin::capability::SideEffects;
    use uni_plugin::traits::procedure::ProcedureMode;
    use uni_plugin::{CapabilitySet, PluginId};

    /// Compiles `script` on an engine built from an empty capability set and an
    /// empty host-function registry, and wraps the result in a runtime.
    fn build_runtime(script: &str) -> Arc<RhaiPluginRuntime> {
        let capabilities = CapabilitySet::new();
        let host_fns = RhaiHostFnRegistry::new();
        let engine = build_engine(&capabilities, &host_fns);
        let ast = compile(&engine, script).unwrap();
        RhaiPluginRuntime::new(PluginId::new("test.proc"), engine, ast)
    }

    /// A script function returning three row maps yields one 3x2 batch.
    #[tokio::test]
    async fn procedure_emits_rows() {
        let script = r#"
fn rows() {
[
#{ id: 1, name: "alice" },
#{ id: 2, name: "bob" },
#{ id: 3, name: "carol" },
]
}
"#;
        let yields = vec![
            Field::new("id", DataType::Int64, true),
            Field::new("name", DataType::Utf8, true),
        ];
        let signature = ProcedureSignature {
            args: vec![],
            yields,
            mode: ProcedureMode::Read,
            side_effects: SideEffects::ReadOnly,
            retry_contract: None,
            batch_input: None,
            docs: String::new(),
        };
        let procedure = RhaiProcedure::new(build_runtime(script), "rows", signature);

        let mut stream = procedure.invoke(ProcedureContext::new(), &[]).unwrap();
        let first_batch = stream.next().await.unwrap().unwrap();

        assert_eq!(first_batch.num_rows(), 3);
        assert_eq!(first_batch.num_columns(), 2);
    }
}