1use std::sync::Arc;
13
14use arrow_array::ArrayRef;
15use arrow_array::builder::{BooleanBuilder, Float64Builder, Int64Builder, StringBuilder};
16use arrow_schema::DataType;
17use datafusion::logical_expr::{ColumnarValue, Volatility};
18use uni_common::Value;
19use uni_cypher::ast::Expr;
20use uni_plugin::FnError;
21use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling, ScalarPluginFn};
22
23use crate::decode::{array_value_at, eval_err_to_fn, stringify};
24use crate::eval::eval_expr;
25
26pub struct DeclaredScalarFn {
33 body: Expr,
34 arg_names: Vec<String>,
35 signature: FnSignature,
36}
37
38impl std::fmt::Debug for DeclaredScalarFn {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("DeclaredScalarFn")
41 .field("arg_names", &self.arg_names)
42 .field("return_type", &self.signature.returns)
43 .finish_non_exhaustive()
44 }
45}
46
47impl DeclaredScalarFn {
48 #[must_use]
54 pub fn new(body: Expr, arg_names: Vec<String>, signature: FnSignature) -> Self {
55 Self {
56 body,
57 arg_names,
58 signature,
59 }
60 }
61
62 #[must_use]
65 pub fn build_signature(returns: DataType, args: &[(String, DataType)]) -> FnSignature {
66 FnSignature {
67 args: args
68 .iter()
69 .map(|(_, t)| ArgType::Primitive(t.clone()))
70 .collect(),
71 returns: ArgType::Primitive(returns),
72 volatility: Volatility::Volatile,
73 null_handling: NullHandling::UserHandled,
74 }
75 }
76}
77
78impl ScalarPluginFn for DeclaredScalarFn {
79 fn signature(&self) -> &FnSignature {
80 &self.signature
81 }
82
83 fn invoke(&self, args: &[ColumnarValue], rows: usize) -> Result<ColumnarValue, FnError> {
84 if args.len() != self.arg_names.len() {
85 return Err(FnError::new(
86 FnError::CODE_TYPE_COERCION,
87 format!(
88 "declared scalar fn expected {} args, got {}",
89 self.arg_names.len(),
90 args.len()
91 ),
92 ));
93 }
94 let row_count = rows.max(1);
95 let columns: Vec<ArrayRef> = args
96 .iter()
97 .map(|cv| columnar_to_array(cv, row_count))
98 .collect::<Result<_, _>>()?;
99
100 let return_dt = match &self.signature.returns {
101 ArgType::Primitive(dt) => dt.clone(),
102 other => {
103 return Err(FnError::new(
104 FnError::CODE_TYPE_COERCION,
105 format!("declared fn return type not supported: {other:?}"),
106 ));
107 }
108 };
109
110 let out = build_output(&return_dt, row_count, |row| {
111 let mut bindings = std::collections::HashMap::with_capacity(columns.len());
112 for (i, col) in columns.iter().enumerate() {
113 bindings.insert(self.arg_names[i].clone(), array_value_at(col, row)?);
114 }
115 eval_expr(&self.body, &bindings).map_err(eval_err_to_fn)
116 })?;
117
118 Ok(ColumnarValue::Array(out))
119 }
120}
121
122fn columnar_to_array(cv: &ColumnarValue, rows: usize) -> Result<ArrayRef, FnError> {
123 match cv {
124 ColumnarValue::Array(a) => Ok(Arc::clone(a)),
125 ColumnarValue::Scalar(s) => s
126 .to_array_of_size(rows)
127 .map_err(|e| FnError::new(FnError::CODE_TYPE_COERCION, format!("scalar→array: {e}"))),
128 }
129}
130
131fn build_output(
132 dt: &DataType,
133 rows: usize,
134 mut row_value: impl FnMut(usize) -> Result<Value, FnError>,
135) -> Result<ArrayRef, FnError> {
136 match dt {
137 DataType::Utf8 => {
138 let mut b = StringBuilder::with_capacity(rows, rows * 8);
139 for row in 0..rows {
140 match row_value(row)? {
141 Value::Null => b.append_null(),
142 Value::String(s) => b.append_value(s),
143 other => b.append_value(stringify(&other)),
144 }
145 }
146 Ok(Arc::new(b.finish()))
147 }
148 DataType::Int64 => {
149 let mut b = Int64Builder::with_capacity(rows);
150 for row in 0..rows {
151 match row_value(row)? {
152 Value::Null => b.append_null(),
153 Value::Int(i) => b.append_value(i),
154 Value::Float(f) => b.append_value(f as i64),
155 other => {
156 return Err(FnError::new(
157 FnError::CODE_TYPE_COERCION,
158 format!("expected Int64, got {other:?}"),
159 ));
160 }
161 }
162 }
163 Ok(Arc::new(b.finish()))
164 }
165 DataType::Float64 => {
166 let mut b = Float64Builder::with_capacity(rows);
167 for row in 0..rows {
168 match row_value(row)? {
169 Value::Null => b.append_null(),
170 Value::Int(i) => b.append_value(i as f64),
171 Value::Float(f) => b.append_value(f),
172 other => {
173 return Err(FnError::new(
174 FnError::CODE_TYPE_COERCION,
175 format!("expected Float64, got {other:?}"),
176 ));
177 }
178 }
179 }
180 Ok(Arc::new(b.finish()))
181 }
182 DataType::Boolean => {
183 let mut b = BooleanBuilder::with_capacity(rows);
184 for row in 0..rows {
185 match row_value(row)? {
186 Value::Null => b.append_null(),
187 Value::Bool(v) => b.append_value(v),
188 other => {
189 return Err(FnError::new(
190 FnError::CODE_TYPE_COERCION,
191 format!("expected Boolean, got {other:?}"),
192 ));
193 }
194 }
195 }
196 Ok(Arc::new(b.finish()))
197 }
198 other => Err(FnError::new(
199 FnError::CODE_TYPE_COERCION,
200 format!("declared fn return type {other:?} not supported"),
201 )),
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Array, StringArray};
    use datafusion::scalar::ScalarValue;
    use uni_cypher::parse_expression;

    /// Builds a declared function whose arguments and return value are all
    /// `Utf8`, with `body` parsed as an expression.
    fn fn_string(body: &str, arg_names: &[&str]) -> DeclaredScalarFn {
        let parsed = parse_expression(body).unwrap();

        let mut names: Vec<String> = Vec::with_capacity(arg_names.len());
        let mut typed: Vec<(String, DataType)> = Vec::with_capacity(arg_names.len());
        for name in arg_names {
            names.push((*name).to_owned());
            typed.push(((*name).to_owned(), DataType::Utf8));
        }

        let signature = DeclaredScalarFn::build_signature(DataType::Utf8, &typed);
        DeclaredScalarFn::new(parsed, names, signature)
    }

    /// Wraps a string literal as a non-null `Utf8` scalar argument.
    fn utf8(text: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(text.to_owned())))
    }

    /// Two scalar string arguments are concatenated by the declared body.
    #[test]
    fn invoke_string_concat_via_scalars() {
        let f = fn_string("$first + ' ' + $last", &["first", "last"]);

        let out = f.invoke(&[utf8("Ada"), utf8("Lovelace")], 1).unwrap();

        let ColumnarValue::Array(arr) = out else {
            panic!("expected array")
        };
        let strings = arr.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(strings.value(0), "Ada Lovelace");
    }

    /// Passing fewer arguments than declared is reported as a coercion error.
    #[test]
    fn invoke_arity_mismatch() {
        let f = fn_string("$first + ' ' + $last", &["first", "last"]);

        let err = f.invoke(&[utf8("a")], 1).unwrap_err();

        assert_eq!(err.code, FnError::CODE_TYPE_COERCION);
    }
}