dbx_core/automation/udf/
table.rs1use crate::automation::callable::{Callable, ExecutionContext, Signature, Value};
6use crate::error::DbxResult;
7
8type TableFn = Box<dyn Fn(&ExecutionContext, &[Value]) -> DbxResult<Vec<Vec<Value>>> + Send + Sync>;
10
11pub struct TableUDF {
13 name: String,
14 signature: Signature,
15 func: TableFn,
16}
17
18impl TableUDF {
19 pub fn new<F>(name: impl Into<String>, signature: Signature, func: F) -> Self
21 where
22 F: Fn(&ExecutionContext, &[Value]) -> DbxResult<Vec<Vec<Value>>> + Send + Sync + 'static,
23 {
24 Self {
25 name: name.into(),
26 signature,
27 func: Box::new(func),
28 }
29 }
30}
31
32impl Callable for TableUDF {
33 fn call(&self, ctx: &ExecutionContext, args: &[Value]) -> DbxResult<Value> {
34 let rows = (self.func)(ctx, args)?;
35 Ok(Value::Table(rows))
36 }
37
38 fn name(&self) -> &str {
39 &self.name
40 }
41
42 fn signature(&self) -> &Signature {
43 &self.signature
44 }
45}
46
47impl TableUDF {
48 pub fn execute(&self, ctx: &ExecutionContext, args: &[Value]) -> DbxResult<Vec<Vec<Value>>> {
50 (self.func)(ctx, args)
51 }
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automation::callable::DataType;
    use crate::engine::Database;
    use std::sync::Arc;

    /// `generate_series(start, end)`: one single-column row per integer in the
    /// inclusive range `start..=end`; no rows when `start > end`.
    fn generate_series_udf() -> TableUDF {
        TableUDF::new(
            "generate_series",
            Signature {
                params: vec![DataType::Int, DataType::Int],
                return_type: DataType::Int,
                is_variadic: false,
            },
            |_ctx, args| {
                let start = args[0].as_i64()?;
                let end = args[1].as_i64()?;
                Ok((start..=end).map(|i| vec![Value::Int(i)]).collect())
            },
        )
    }

    #[test]
    fn test_table_udf_basic() {
        let table_udf = generate_series_udf();

        let db = Database::open_in_memory().unwrap();
        let ctx = ExecutionContext::new(Arc::new(db));

        let rows = table_udf
            .execute(&ctx, &[Value::Int(1), Value::Int(5)])
            .unwrap();

        assert_eq!(rows.len(), 5);
        assert_eq!(rows[0][0].as_i64().unwrap(), 1);
        assert_eq!(rows[4][0].as_i64().unwrap(), 5);
    }

    #[test]
    fn test_table_udf_empty_range() {
        let table_udf = generate_series_udf();

        let db = Database::open_in_memory().unwrap();
        let ctx = ExecutionContext::new(Arc::new(db));

        // start > end must yield an empty table, not an error.
        let rows = table_udf
            .execute(&ctx, &[Value::Int(5), Value::Int(1)])
            .unwrap();

        assert!(rows.is_empty());
    }

    #[test]
    fn test_table_udf_call_wraps_rows_in_table() {
        let table_udf = generate_series_udf();

        let db = Database::open_in_memory().unwrap();
        let ctx = ExecutionContext::new(Arc::new(db));

        assert_eq!(table_udf.name(), "generate_series");

        // `Callable::call` must return the same rows as `execute`, wrapped in a table value.
        let value = table_udf
            .call(&ctx, &[Value::Int(1), Value::Int(3)])
            .unwrap();
        let table = value.as_table().unwrap();

        assert_eq!(table.len(), 3);
        assert_eq!(table[2][0].as_i64().unwrap(), 3);
    }

    #[test]
    fn test_table_udf_multi_column() {
        let table_udf = TableUDF::new(
            "user_data",
            Signature {
                params: vec![],
                return_type: DataType::String,
                is_variadic: false,
            },
            |_ctx, _args| {
                Ok(vec![
                    vec![Value::Int(1), Value::String("Alice".to_string())],
                    vec![Value::Int(2), Value::String("Bob".to_string())],
                    vec![Value::Int(3), Value::String("Charlie".to_string())],
                ])
            },
        );

        let db = Database::open_in_memory().unwrap();
        let ctx = ExecutionContext::new(Arc::new(db));

        let rows = table_udf.execute(&ctx, &[]).unwrap();

        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0].len(), 2);
        assert_eq!(rows[1][1].as_str().unwrap(), "Bob");
    }

    #[test]
    fn test_table_udf_with_filter() {
        let table_udf = TableUDF::new(
            "filtered_range",
            Signature {
                params: vec![DataType::Int, DataType::Int, DataType::Int],
                return_type: DataType::Int,
                is_variadic: false,
            },
            |_ctx, args| {
                let start = args[0].as_i64()?;
                let end = args[1].as_i64()?;
                let step = args[2].as_i64()?;

                // A non-positive step never reaches `end`; fail loudly instead of hanging.
                assert!(step > 0, "filtered_range requires a positive step");

                // `checked_add` ends the sequence instead of overflowing near i64::MAX.
                Ok(std::iter::successors(Some(start), |&cur| cur.checked_add(step))
                    .take_while(|&cur| cur <= end)
                    .map(|cur| vec![Value::Int(cur)])
                    .collect())
            },
        );

        let db = Database::open_in_memory().unwrap();
        let ctx = ExecutionContext::new(Arc::new(db));

        let rows = table_udf
            .execute(&ctx, &[Value::Int(0), Value::Int(10), Value::Int(2)])
            .unwrap();

        assert_eq!(rows.len(), 6);
        assert_eq!(rows[3][0].as_i64().unwrap(), 6);
    }

    #[test]
    fn test_table_udf_with_engine() {
        use crate::automation::ExecutionEngine;

        let engine = ExecutionEngine::new();

        let table_udf = Arc::new(TableUDF::new(
            "range",
            Signature {
                params: vec![DataType::Int],
                return_type: DataType::Int,
                is_variadic: false,
            },
            |_ctx, args| {
                let n = args[0].as_i64()?;
                Ok((0..n).map(|i| vec![Value::Int(i)]).collect())
            },
        ));

        engine.register(table_udf).unwrap();

        let db = Database::open_in_memory().unwrap();
        let ctx = ExecutionContext::new(Arc::new(db));

        let result = engine.execute("range", &ctx, &[Value::Int(3)]).unwrap();

        let table = result.as_table().unwrap();
        assert_eq!(table.len(), 3);
        assert_eq!(table[0][0].as_i64().unwrap(), 0);
    }
}