// dbx_core/engine/udf_api.rs

//! UDF API for Database
//!
//! Methods on [`Database`] for registering, invoking, and listing
//! user-defined functions (UDFs).

use std::sync::Arc;

use crate::automation::ScalarUDF;
use crate::automation::callable::{ExecutionContext, Signature, Value};
use crate::engine::Database;
use crate::error::DbxResult;

11impl Database {
12    /// Scalar UDF 등록
13    ///
14    /// # 예제
15    ///
16    /// ```rust
17    /// use dbx_core::Database;
18    /// use dbx_core::automation::callable::{DataType, Signature, Value};
19    ///
20    /// # fn main() -> dbx_core::DbxResult<()> {
21    /// let db = Database::open_in_memory()?;
22    ///
23    /// // UDF: x * 2
24    /// db.register_scalar_udf(
25    ///     "double",
26    ///     Signature {
27    ///         params: vec![DataType::Int],
28    ///         return_type: DataType::Int,
29    ///         is_variadic: false,
30    ///     },
31    ///     |args| {
32    ///         let x = args[0].as_i64()?;
33    ///         Ok(Value::Int(x * 2))
34    ///     },
35    /// )?;
36    ///
37    /// // UDF 호출
38    /// let result = db.call_udf("double", &[Value::Int(21)])?;
39    /// assert_eq!(result.as_i64()?, 42);
40    /// # Ok(())
41    /// # }
42    /// ```
43    pub fn register_scalar_udf<F>(
44        &self,
45        name: impl Into<String>,
46        signature: Signature,
47        func: F,
48    ) -> DbxResult<()>
49    where
50        F: Fn(&[Value]) -> DbxResult<Value> + Send + Sync + 'static,
51    {
52        let udf = Arc::new(ScalarUDF::new(name, signature, func));
53        self.automation_engine.register(udf)
54    }
55
56    /// UDF 호출
57    ///
58    /// # 예제
59    ///
60    /// ```rust
61    /// use dbx_core::Database;
62    /// use dbx_core::automation::callable::{DataType, Signature, Value};
63    ///
64    /// # fn main() -> dbx_core::DbxResult<()> {
65    /// let db = Database::open_in_memory()?;
66    ///
67    /// db.register_scalar_udf(
68    ///     "add",
69    ///     Signature {
70    ///         params: vec![DataType::Int, DataType::Int],
71    ///         return_type: DataType::Int,
72    ///         is_variadic: false,
73    ///     },
74    ///     |args| {
75    ///         let x = args[0].as_i64()?;
76    ///         let y = args[1].as_i64()?;
77    ///         Ok(Value::Int(x + y))
78    ///     },
79    /// )?;
80    ///
81    /// let result = db.call_udf("add", &[Value::Int(10), Value::Int(32)])?;
82    /// assert_eq!(result.as_i64()?, 42);
83    /// # Ok(())
84    /// # }
85    /// ```
86    pub fn call_udf(&self, name: &str, args: &[Value]) -> DbxResult<Value> {
87        // Use a temporary in-memory DB for ExecutionContext
88        // Note: Ideally ExecutionContext would accept Option<Arc<Database>>
89        // to avoid this allocation, but that requires a broader refactor.
90        // For now, this is sufficient as UDFs rarely need the DB context.
91        let temp_db = Arc::new(Database::open_in_memory()?);
92        let ctx = ExecutionContext::new(temp_db);
93        self.automation_engine.execute(name, &ctx, args)
94    }
95
96    /// 등록된 UDF 목록 조회
97    pub fn list_udfs(&self) -> DbxResult<Vec<String>> {
98        self.automation_engine.list()
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automation::callable::{DataType, Signature, Value};

    /// Registering one UDF and calling it yields the UDF's computed result.
    #[test]
    fn test_register_and_call_udf() {
        let db = Database::open_in_memory().unwrap();

        // Register the UDF: x * 3
        db.register_scalar_udf(
            "triple",
            Signature {
                params: vec![DataType::Int],
                return_type: DataType::Int,
                is_variadic: false,
            },
            |args| {
                let x = args[0].as_i64()?;
                Ok(Value::Int(x * 3))
            },
        )
        .unwrap();

        // Call the UDF
        let result = db.call_udf("triple", &[Value::Int(14)]).unwrap();
        assert_eq!(result.as_i64().unwrap(), 42);
    }

    /// Several UDFs with different arities can coexist and be called by name.
    #[test]
    fn test_multiple_udfs() {
        let db = Database::open_in_memory().unwrap();

        // UDF 1: double
        db.register_scalar_udf(
            "double",
            Signature {
                params: vec![DataType::Int],
                return_type: DataType::Int,
                is_variadic: false,
            },
            |args| {
                let x = args[0].as_i64()?;
                Ok(Value::Int(x * 2))
            },
        )
        .unwrap();

        // UDF 2: add
        db.register_scalar_udf(
            "add",
            Signature {
                params: vec![DataType::Int, DataType::Int],
                return_type: DataType::Int,
                is_variadic: false,
            },
            |args| {
                let x = args[0].as_i64()?;
                let y = args[1].as_i64()?;
                Ok(Value::Int(x + y))
            },
        )
        .unwrap();

        // Call both UDFs
        let r1 = db.call_udf("double", &[Value::Int(21)]).unwrap();
        let r2 = db
            .call_udf("add", &[Value::Int(10), Value::Int(32)])
            .unwrap();

        assert_eq!(r1.as_i64().unwrap(), 42);
        assert_eq!(r2.as_i64().unwrap(), 42);
    }

    /// `list_udfs` reports every registered UDF name (order not asserted).
    #[test]
    fn test_list_udfs() {
        let db = Database::open_in_memory().unwrap();

        db.register_scalar_udf(
            "func1",
            Signature {
                params: vec![DataType::Int],
                return_type: DataType::Int,
                is_variadic: false,
            },
            |args| Ok(args[0].clone()),
        )
        .unwrap();

        db.register_scalar_udf(
            "func2",
            Signature {
                params: vec![DataType::String],
                return_type: DataType::String,
                is_variadic: false,
            },
            |args| Ok(args[0].clone()),
        )
        .unwrap();

        let udfs = db.list_udfs().unwrap();
        assert_eq!(udfs.len(), 2);
        assert!(udfs.contains(&"func1".to_string()));
        assert!(udfs.contains(&"func2".to_string()));
    }
}