Skip to main content

dbx_core/automation/udf/
vectorized.rs

//! Vectorized UDF implementation.
//!
//! Vectorized UDF: 다수 값(컬럼/배열) → 다수 값(컬럼/배열) 변환 함수
//! (a function mapping many values — a column/array — to many values of the same length).
4
5use crate::automation::callable::{Callable, DataType, ExecutionContext, Signature, Value};
6use crate::error::{DbxError, DbxResult};
7
8/// Type alias for vectorized UDF function
9type VectorizedFn = Box<dyn Fn(&[Value]) -> DbxResult<Value> + Send + Sync>;
10
11/// Vectorized UDF (배열 단위 처리)
12pub struct VectorizedUDF {
13    name: String,
14    signature: Signature,
15    func: VectorizedFn,
16}
17
18impl VectorizedUDF {
19    /// 새 Vectorized UDF 생성
20    pub fn new<F>(name: impl Into<String>, signature: Signature, func: F) -> Self
21    where
22        F: Fn(&[Value]) -> DbxResult<Value> + Send + Sync + 'static,
23    {
24        Self {
25            name: name.into(),
26            signature,
27            func: Box::new(func),
28        }
29    }
30}
31
32impl Callable for VectorizedUDF {
33    fn call(&self, _ctx: &ExecutionContext, args: &[Value]) -> DbxResult<Value> {
34        // 1. Signature 인자 개수 및 Array 타입 검증
35        self.signature.validate_args(args)?;
36
37        // 2. 추가 검증: 입력된 모든 인자가 Array인지 재확인 (방어적 프로그래밍)
38        // (Signature.params에 DataType::Array로 명시되어 있다면 validate_args에서 이미 검증됨)
39        for arg in args {
40            if arg.data_type() != DataType::Array {
41                return Err(DbxError::TypeMismatch {
42                    expected: "Array".to_string(),
43                    actual: format!("{:?}", arg.data_type()),
44                });
45            }
46        }
47
48        // 3. 배열들의 길이가 모두 일치하는지 검증
49        let mut expected_len = None;
50        for arg in args {
51            let arr = arg.as_array()?;
52            match expected_len {
53                None => expected_len = Some(arr.len()),
54                Some(len) => {
55                    if len != arr.len() {
56                        return Err(DbxError::InvalidArguments(
57                            "Vectorized UDF arguments must have the same array length".to_string(),
58                        ));
59                    }
60                }
61            }
62        }
63
64        // 4. 함수 실행
65        let result = (self.func)(args)?;
66
67        // 5. 결과 타입 검증 (반드시 Array를 반환해야 함)
68        if result.data_type() != DataType::Array {
69            return Err(DbxError::TypeMismatch {
70                expected: "Array".to_string(),
71                actual: format!("{:?}", result.data_type()),
72            });
73        }
74
75        // 6. 반환 배열 길이 검증 (입력 길이와 같아야 함, 단 입력이 없었던 경우는 예외)
76        if let Some(expected) = expected_len {
77            let res_arr = result.as_array()?;
78            if res_arr.len() != expected {
79                return Err(DbxError::InvalidArguments(
80                    "Vectorized UDF result array length must match input array lengths".to_string(),
81                ));
82            }
83        }
84
85        Ok(result)
86    }
87
88    fn name(&self) -> &str {
89        &self.name
90    }
91
92    fn signature(&self) -> &Signature {
93        &self.signature
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automation::executor::ExecutionEngine;
    use crate::engine::Database;
    use std::sync::Arc;

    /// Signature with `arity` array parameters and an array return type.
    fn array_signature(arity: usize) -> Signature {
        Signature {
            params: (0..arity).map(|_| DataType::Array).collect(),
            return_type: DataType::Array,
            is_variadic: false,
        }
    }

    /// Fresh execution context backed by an in-memory database.
    fn test_ctx() -> ExecutionContext {
        ExecutionContext::new(Arc::new(Database::open_in_memory().unwrap()))
    }

    /// Builds a `Value::Array` of `Value::Int`s.
    fn int_array(values: &[i64]) -> Value {
        Value::Array(values.iter().map(|&v| Value::Int(v)).collect())
    }

    /// Unwraps an array of ints back into plain `i64`s (panics on any other shape).
    fn as_ints(value: &Value) -> Vec<i64> {
        value
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_i64().unwrap())
            .collect()
    }

    /// One-argument UDF applying `f` to every element of its input array (elements must be ints).
    fn int_map_udf(name: &str, f: fn(i64) -> i64) -> VectorizedUDF {
        VectorizedUDF::new(name, array_signature(1), move |args| {
            let mapped = args[0]
                .as_array()?
                .iter()
                .map(|v| -> DbxResult<Value> { Ok(Value::Int(f(v.as_i64()?))) })
                .collect::<DbxResult<Vec<_>>>()?;
            Ok(Value::Array(mapped))
        })
    }

    #[test]
    fn test_vectorized_udf_basic() {
        // Element-wise x * 2 over a single array argument.
        let udf = int_map_udf("vec_double", |x| x * 2);

        let result = udf
            .call(&test_ctx(), &[int_array(&[10, 20, 30])])
            .unwrap();

        assert_eq!(as_ints(&result), vec![20, 40, 60]);
    }

    #[test]
    fn test_vectorized_udf_multiple_args() {
        // Element-wise x + y over two array arguments (columns).
        let udf = VectorizedUDF::new("vec_add", array_signature(2), |args| {
            let xs = args[0].as_array()?;
            let ys = args[1].as_array()?;
            let sums = xs
                .iter()
                .zip(ys.iter())
                .map(|(x, y)| -> DbxResult<Value> {
                    Ok(Value::Int(x.as_i64()? + y.as_i64()?))
                })
                .collect::<DbxResult<Vec<_>>>()?;
            Ok(Value::Array(sums))
        });

        let result = udf
            .call(
                &test_ctx(),
                &[int_array(&[1, 2, 3]), int_array(&[10, 20, 30])],
            )
            .unwrap();

        assert_eq!(as_ints(&result), vec![11, 22, 33]);
    }

    #[test]
    fn test_vectorized_udf_mismatch_lengths() {
        // Differing input lengths must be rejected by the framework before the body runs.
        let udf = VectorizedUDF::new("vec_add", array_signature(2), |_args| -> DbxResult<Value> {
            panic!("UDF body must not run when input array lengths differ")
        });

        let result = udf.call(
            &test_ctx(),
            &[int_array(&[1, 2]), int_array(&[10, 20, 30])],
        );

        match result {
            Err(DbxError::InvalidArguments(msg)) => assert!(msg.contains("same array length")),
            other => panic!("Expected InvalidArguments error, got {:?}", other),
        }
    }

    #[test]
    fn test_vectorized_udf_result_length_mismatch() {
        // A UDF returning a different number of elements than it received must be rejected.
        let udf = VectorizedUDF::new("vec_drop", array_signature(1), |_args| {
            Ok(Value::Array(vec![]))
        });

        match udf.call(&test_ctx(), &[int_array(&[1, 2])]) {
            Err(DbxError::InvalidArguments(msg)) => {
                assert!(msg.contains("result array length must match"))
            }
            other => panic!("Expected InvalidArguments error, got {:?}", other),
        }
    }

    #[test]
    fn test_vectorized_udf_non_array_result() {
        // A vectorized UDF must return an array, never a scalar.
        let udf = VectorizedUDF::new("vec_scalar", array_signature(1), |_args| {
            Ok(Value::Int(1))
        });

        match udf.call(&test_ctx(), &[int_array(&[1, 2])]) {
            Err(DbxError::TypeMismatch { .. }) => {}
            other => panic!("Expected TypeMismatch error, got {:?}", other),
        }
    }

    #[test]
    fn test_vectorized_udf_rejects_scalar_argument() {
        // Scalar inputs are not valid for a vectorized UDF.
        let udf = int_map_udf("vec_double", |x| x * 2);

        assert!(udf.call(&test_ctx(), &[Value::Int(5)]).is_err());
    }

    #[test]
    fn test_vectorized_udf_with_engine() {
        let engine = ExecutionEngine::new();

        let udf = Arc::new(int_map_udf("vec_triple", |x| x * 3));
        engine.register(udf).unwrap();

        let result = engine
            .execute("vec_triple", &test_ctx(), &[int_array(&[5, 10])])
            .unwrap();

        assert_eq!(as_ints(&result), vec![15, 30]);
    }
}