uni_plugin_extism/
adapter.rs1use std::sync::Arc;
20
21use arrow::array::RecordBatch;
22use arrow_schema::{Field, Schema, SchemaRef};
23use datafusion::logical_expr::ColumnarValue;
24use uni_plugin::QName;
25use uni_plugin::errors::FnError;
26use uni_plugin::traits::scalar::{FnSignature, ScalarPluginFn};
27
28use crate::adapter_common::{acquire, extism_err_to_fn_err, sanitize_qname};
29use crate::ipc::{decode_batch, encode_batch};
30use crate::pool::ExtismInstancePool;
31
32pub(crate) fn scalar_export_name(qname: &QName) -> String {
41 format!("invoke_{}", sanitize_qname(qname))
42}
43
44pub struct ExtismScalarFn {
52 pool: Arc<ExtismInstancePool<extism::Plugin>>,
53 qname: QName,
54 export_name: String,
55 sig: FnSignature,
56}
57
58impl std::fmt::Debug for ExtismScalarFn {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("ExtismScalarFn")
61 .field("qname", &self.qname)
62 .field("export_name", &self.export_name)
63 .field("signature", &self.sig)
64 .finish_non_exhaustive()
65 }
66}
67
68impl ExtismScalarFn {
69 #[must_use]
71 pub fn new(
72 pool: Arc<ExtismInstancePool<extism::Plugin>>,
73 qname: QName,
74 sig: FnSignature,
75 ) -> Self {
76 let export_name = scalar_export_name(&qname);
77 Self {
78 pool,
79 qname,
80 export_name,
81 sig,
82 }
83 }
84
85 fn args_to_batch(&self, args: &[ColumnarValue], rows: usize) -> Result<RecordBatch, FnError> {
90 let arrays: Vec<arrow::array::ArrayRef> = args
91 .iter()
92 .map(|c| {
93 c.clone().into_array(rows).map_err(|e| {
94 FnError::new(
95 FnError::CODE_TYPE_COERCION,
96 format!("ColumnarValue::into_array: {e}"),
97 )
98 })
99 })
100 .collect::<Result<_, _>>()?;
101 let fields: Vec<Field> = arrays
102 .iter()
103 .enumerate()
104 .map(|(i, a)| Field::new(format!("arg{i}"), a.data_type().clone(), true))
105 .collect();
106 let schema: SchemaRef = Arc::new(Schema::new(fields));
107 RecordBatch::try_new(schema, arrays).map_err(|e| {
108 FnError::new(
109 FnError::CODE_TYPE_COERCION,
110 format!("RecordBatch assembly: {e}"),
111 )
112 })
113 }
114}
115
116impl ScalarPluginFn for ExtismScalarFn {
117 fn signature(&self) -> &FnSignature {
118 &self.sig
119 }
120
121 fn invoke(&self, args: &[ColumnarValue], rows: usize) -> Result<ColumnarValue, FnError> {
122 let batch = self.args_to_batch(args, rows)?;
123 let bytes = encode_batch(&batch).map_err(extism_err_to_fn_err)?;
124
125 let mut leased = acquire(&self.pool)?;
126 let out_bytes: Vec<u8> = {
127 let plugin = leased.get_mut();
128 let out: &[u8] = plugin
129 .call(&self.export_name, bytes.as_slice())
130 .map_err(|e| {
131 FnError::new(
132 FnError::CODE_UNEXPECTED_NULL,
133 format!("extism call `{}` failed: {e}", self.export_name),
134 )
135 })?;
136 out.to_vec()
138 };
139 drop(leased);
140
141 let out_batch = decode_batch(&out_bytes)
142 .map_err(extism_err_to_fn_err)?
143 .ok_or_else(|| {
144 FnError::new(
145 FnError::CODE_UNEXPECTED_NULL,
146 format!("plugin `{}` returned an empty IPC stream", self.export_name),
147 )
148 })?;
149
150 if out_batch.num_columns() != 1 {
151 return Err(FnError::new(
152 FnError::CODE_TYPE_COERCION,
153 format!(
154 "plugin `{}` returned {} columns; scalar fns must return exactly 1",
155 self.export_name,
156 out_batch.num_columns()
157 ),
158 ));
159 }
160 Ok(ColumnarValue::Array(out_batch.column(0).clone()))
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// The export name is `invoke_` plus the sanitized qualified name.
    #[test]
    fn scalar_export_name_format() {
        let qname = QName::parse("geo.haversine").expect("valid");
        let export = scalar_export_name(&qname);
        assert_eq!(export, "invoke_geo_haversine");
    }
}