dbx_core/automation/udf/
vectorized.rs1use crate::automation::callable::{Callable, DataType, ExecutionContext, Signature, Value};
6use crate::error::{DbxError, DbxResult};
7
/// Boxed function object stored by a vectorized UDF.
///
/// It receives the full argument slice and returns one `Value` (or an error).
/// The `Send + Sync` bounds are required of every stored function.
type VectorizedFn = Box<dyn Fn(&[Value]) -> DbxResult<Value> + Send + Sync>;
10
/// A user-defined function that is invoked once with whole-array arguments.
///
/// The `Callable` implementation in this file checks that every argument and
/// the result are arrays before and after running the stored function.
pub struct VectorizedUDF {
    /// Name returned by `Callable::name`.
    name: String,
    /// Signature used to validate arguments and returned by `Callable::signature`.
    signature: Signature,
    /// The wrapped function; called with the argument slice unchanged.
    func: VectorizedFn,
}
17
18impl VectorizedUDF {
19 pub fn new<F>(name: impl Into<String>, signature: Signature, func: F) -> Self
21 where
22 F: Fn(&[Value]) -> DbxResult<Value> + Send + Sync + 'static,
23 {
24 Self {
25 name: name.into(),
26 signature,
27 func: Box::new(func),
28 }
29 }
30}
31
32impl Callable for VectorizedUDF {
33 fn call(&self, _ctx: &ExecutionContext, args: &[Value]) -> DbxResult<Value> {
34 self.signature.validate_args(args)?;
36
37 for arg in args {
40 if arg.data_type() != DataType::Array {
41 return Err(DbxError::TypeMismatch {
42 expected: "Array".to_string(),
43 actual: format!("{:?}", arg.data_type()),
44 });
45 }
46 }
47
48 let mut expected_len = None;
50 for arg in args {
51 let arr = arg.as_array()?;
52 match expected_len {
53 None => expected_len = Some(arr.len()),
54 Some(len) => {
55 if len != arr.len() {
56 return Err(DbxError::InvalidArguments(
57 "Vectorized UDF arguments must have the same array length".to_string(),
58 ));
59 }
60 }
61 }
62 }
63
64 let result = (self.func)(args)?;
66
67 if result.data_type() != DataType::Array {
69 return Err(DbxError::TypeMismatch {
70 expected: "Array".to_string(),
71 actual: format!("{:?}", result.data_type()),
72 });
73 }
74
75 if let Some(expected) = expected_len {
77 let res_arr = result.as_array()?;
78 if res_arr.len() != expected {
79 return Err(DbxError::InvalidArguments(
80 "Vectorized UDF result array length must match input array lengths".to_string(),
81 ));
82 }
83 }
84
85 Ok(result)
86 }
87
88 fn name(&self) -> &str {
89 &self.name
90 }
91
92 fn signature(&self) -> &Signature {
93 &self.signature
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automation::executor::ExecutionEngine;
    use crate::engine::Database;
    use std::sync::Arc;

    /// A one-argument UDF doubles every element of its input array.
    #[test]
    fn test_vectorized_udf_basic() {
        let signature = Signature {
            params: vec![DataType::Array],
            return_type: DataType::Array,
            is_variadic: false,
        };
        let udf = VectorizedUDF::new("vec_double", signature, |args| {
            let source = args[0].as_array()?;
            let mut doubled = Vec::with_capacity(source.len());
            for item in source.iter() {
                doubled.push(Value::Int(item.as_i64()? * 2));
            }
            Ok(Value::Array(doubled))
        });

        let database = Arc::new(Database::open_in_memory().unwrap());
        let ctx = ExecutionContext::new(database);

        let input = Value::Array(vec![Value::Int(10), Value::Int(20), Value::Int(30)]);
        let result = udf.call(&ctx, &[input]).unwrap();

        let output = result.as_array().unwrap();
        assert_eq!(output.len(), 3);
        // Indexing (rather than zipping) keeps a too-short output a hard failure.
        for (index, want) in [20, 40, 60].iter().enumerate() {
            assert_eq!(output[index].as_i64().unwrap(), *want);
        }
    }

    /// A two-argument UDF adds its input arrays element by element.
    #[test]
    fn test_vectorized_udf_multiple_args() {
        let signature = Signature {
            params: vec![DataType::Array, DataType::Array],
            return_type: DataType::Array,
            is_variadic: false,
        };
        let udf = VectorizedUDF::new("vec_add", signature, |args| {
            let left = args[0].as_array()?;
            let right = args[1].as_array()?;

            let mut sums = Vec::with_capacity(left.len());
            for (lhs, rhs) in left.iter().zip(right.iter()) {
                sums.push(Value::Int(lhs.as_i64()? + rhs.as_i64()?));
            }
            Ok(Value::Array(sums))
        });

        let database = Arc::new(Database::open_in_memory().unwrap());
        let ctx = ExecutionContext::new(database);

        let lhs = Value::Array(vec![Value::Int(1), Value::Int(2), Value::Int(3)]);
        let rhs = Value::Array(vec![Value::Int(10), Value::Int(20), Value::Int(30)]);

        let result = udf.call(&ctx, &[lhs, rhs]).unwrap();

        let output = result.as_array().unwrap();
        assert_eq!(output.len(), 3);
        for (index, want) in [11, 22, 33].iter().enumerate() {
            assert_eq!(output[index].as_i64().unwrap(), *want);
        }
    }

    /// Input arrays of different lengths are rejected with `InvalidArguments`.
    #[test]
    fn test_vectorized_udf_mismatch_lengths() {
        let signature = Signature {
            params: vec![DataType::Array, DataType::Array],
            return_type: DataType::Array,
            is_variadic: false,
        };
        let udf = VectorizedUDF::new("vec_add", signature, |_args| Ok(Value::Array(vec![])));

        let database = Arc::new(Database::open_in_memory().unwrap());
        let ctx = ExecutionContext::new(database);

        let shorter = Value::Array(vec![Value::Int(1), Value::Int(2)]);
        let longer = Value::Array(vec![Value::Int(10), Value::Int(20), Value::Int(30)]);

        let outcome = udf.call(&ctx, &[shorter, longer]);
        assert!(outcome.is_err());
        if let DbxError::InvalidArguments(msg) = outcome.unwrap_err() {
            assert!(msg.contains("same array length"));
        } else {
            panic!("Expected InvalidArguments error");
        }
    }

    /// A vectorized UDF registered with the engine can be executed by name.
    #[test]
    fn test_vectorized_udf_with_engine() {
        let engine = ExecutionEngine::new();

        let signature = Signature {
            params: vec![DataType::Array],
            return_type: DataType::Array,
            is_variadic: false,
        };
        let udf = Arc::new(VectorizedUDF::new("vec_triple", signature, |args| {
            let source = args[0].as_array()?;
            let mut tripled = Vec::with_capacity(source.len());
            for item in source.iter() {
                tripled.push(Value::Int(item.as_i64()? * 3));
            }
            Ok(Value::Array(tripled))
        }));

        engine.register(udf).unwrap();

        let database = Arc::new(Database::open_in_memory().unwrap());
        let ctx = ExecutionContext::new(database);
        let input = Value::Array(vec![Value::Int(5), Value::Int(10)]);
        let result = engine.execute("vec_triple", &ctx, &[input]).unwrap();

        let output = result.as_array().unwrap();
        for (index, want) in [15, 30].iter().enumerate() {
            assert_eq!(output[index].as_i64().unwrap(), *want);
        }
    }
}