1pub mod apply;
6pub mod clickhouse;
7pub mod eval;
8pub mod placeholder;
9
10use std::str::FromStr;
11
12use datafusion::arrow::datatypes::{DataType, Field, FieldRef};
13use datafusion::common::{plan_datafusion_err, plan_err};
14use datafusion::error::Result;
15use datafusion::logical_expr::ReturnFieldArgs;
16use datafusion::prelude::SessionContext;
17use datafusion::scalar::ScalarValue;
18
19pub fn register_clickhouse_functions(ctx: &SessionContext) {
24 ctx.register_udf(eval::clickhouse_eval_udf());
25 ctx.register_udf(clickhouse::clickhouse_udf());
26 ctx.register_udf(apply::clickhouse_apply_udf());
27}
28
29fn extract_return_field_from_args(name: &str, args: &ReturnFieldArgs<'_>) -> Result<FieldRef> {
31 if let Some(Some(
32 ScalarValue::Utf8(Some(return_type_str))
33 | ScalarValue::Utf8View(Some(return_type_str))
34 | ScalarValue::LargeUtf8(Some(return_type_str)),
35 )) = &args.scalar_arguments.last()
36 {
37 let dt = DataType::from_str(return_type_str.as_str())
38 .map_err(|e| plan_datafusion_err!("Invalid return type for {name}: {e}"))?;
39 Ok(udf_field_from_fields(name, dt, args.arg_fields))
40 } else {
41 plan_err!("Expected return type literal in scalar arguments for {name}")
42 }
43}
44
45fn udf_field_from_fields(name: &str, dt: DataType, fields: &[FieldRef]) -> FieldRef {
46 let mut placeholder_nullable = false;
52 if fields.len() >= 3 && fields.first().is_some_and(|f| f.name().starts_with('$')) {
53 let rev_fields = fields.iter().rev();
54 for (pl, col) in fields.iter().zip(rev_fields) {
55 if pl.name().starts_with('$') {
56 placeholder_nullable |= col.is_nullable();
57 } else {
58 break;
59 }
60 }
61 return Field::new(name, dt, placeholder_nullable).into();
62 }
63
64 let nullable = fields.iter().any(|a| {
68 !matches!(
69 a.data_type(),
70 &DataType::List(_) | &DataType::ListView(_) | &DataType::LargeList(_)
71 ) && a.is_nullable()
72 });
73 Field::new(name, dt, nullable).into()
74}
75
76pub mod functions {
77 use datafusion::common::Column;
79 use datafusion::logical_expr::expr::Placeholder;
80 use datafusion::prelude::Expr;
81 use datafusion::scalar::ScalarValue;
82
83 pub fn clickhouse_eval(expr: impl Into<String>, return_type: &str) -> Expr {
85 super::eval::clickhouse_eval_udf().call(vec![
86 Expr::Literal(ScalarValue::Utf8(Some(expr.into())), None),
87 Expr::Literal(ScalarValue::Utf8(Some(return_type.to_string())), None),
88 ])
89 }
90
91 pub fn clickhouse(expr: Expr, return_type: &str) -> Expr {
93 super::clickhouse::clickhouse_udf()
94 .call(vec![expr, Expr::Literal(ScalarValue::Utf8(Some(return_type.to_string())), None)])
95 }
96
97 pub fn apply<C: IntoIterator<Item = Column>>(
100 expr: Expr,
101 columns: C,
102 return_type: &str,
103 ) -> Expr {
104 let (mut args, columns): (Vec<_>, Vec<_>) = columns
105 .into_iter()
106 .enumerate()
107 .map(|(i, c)| {
108 (
109 Expr::Placeholder(Placeholder { id: format!("x{i}"), data_type: None }),
110 Expr::Column(c),
111 )
112 })
113 .unzip();
114 args.push(expr);
115 args.extend(columns);
116 let apply_udf = super::apply::clickhouse_apply_udf().call(args);
117 clickhouse(apply_udf, return_type)
118 }
119
120 pub fn lambda<C: IntoIterator<Item = Column>>(
122 expr: Expr,
123 columns: C,
124 return_type: &str,
125 ) -> Expr {
126 apply(expr, columns, return_type)
127 }
128
129 pub fn clickhouse_apply<C: IntoIterator<Item = Column>>(
131 expr: Expr,
132 columns: C,
133 return_type: &str,
134 ) -> Expr {
135 apply(expr, columns, return_type)
136 }
137
138 pub fn clickhouse_lambda<C: IntoIterator<Item = Column>>(
140 expr: Expr,
141 columns: C,
142 return_type: &str,
143 ) -> Expr {
144 apply(expr, columns, return_type)
145 }
146
147 pub fn clickhouse_map<C: IntoIterator<Item = Column>>(
149 expr: Expr,
150 columns: C,
151 return_type: &str,
152 ) -> Expr {
153 apply(expr, columns, return_type)
154 }
155
156 #[cfg(test)]
157 mod tests {
158 use std::sync::Arc;
159
160 use datafusion::common::ScalarValue;
161 use datafusion::logical_expr::expr::ScalarFunction;
162 use datafusion::prelude::{Expr, lit};
163
164 use super::*;
165 use crate::prelude::clickhouse_eval_udf;
166 use crate::udfs::apply::clickhouse_apply_udf;
167 use crate::udfs::clickhouse::clickhouse_udf;
168 use crate::udfs::functions::clickhouse_eval;
169
170 #[test]
171 fn test_create_simple_udf() {
172 assert_eq!(
173 clickhouse_eval("count(*)", "UInt64"),
174 Expr::ScalarFunction(ScalarFunction {
175 func: Arc::new(clickhouse_eval_udf()),
176 args: vec![
177 Expr::Literal(ScalarValue::Utf8(Some("count(*)".to_string())), None),
178 Expr::Literal(ScalarValue::Utf8(Some("UInt64".to_string())), None),
179 ],
180 })
181 );
182 }
183
184 #[test]
185 fn test_clickhouse_udf() {
186 assert_eq!(
187 clickhouse(
188 Expr::Literal(ScalarValue::Utf8(Some("count(*)".to_string())), None),
189 "UInt64"
190 ),
191 Expr::ScalarFunction(ScalarFunction {
192 func: Arc::new(clickhouse_udf()),
193 args: vec![
194 Expr::Literal(ScalarValue::Utf8(Some("count(*)".to_string())), None),
195 Expr::Literal(ScalarValue::Utf8(Some("UInt64".to_string())), None),
196 ],
197 })
198 );
199 }
200
201 #[test]
202 fn test_clickhouse_apply_udf() {
203 let expr = Expr::Column(Column::from_name("id")) + lit(5);
204 let columns = vec![Column::from_name("id")];
205 let return_type = "UInt64";
206 let apply_expr = apply(expr.clone(), columns.clone(), return_type);
207 assert_eq!(
208 apply_expr,
209 Expr::ScalarFunction(ScalarFunction {
210 func: Arc::new(clickhouse_udf()),
211 args: vec![
212 Expr::ScalarFunction(ScalarFunction {
214 func: Arc::new(clickhouse_apply_udf()),
215 args: vec![
216 Expr::Placeholder(Placeholder {
217 id: "x0".to_string(),
218 data_type: None,
219 }),
220 Expr::Column(Column::from_name("id")) + lit(5),
221 Expr::Column(Column::from_name("id")),
222 ],
223 }),
224 Expr::Literal(ScalarValue::Utf8(Some("UInt64".to_string())), None),
225 ],
226 })
227 );
228
229 let lambda_expr = lambda(expr.clone(), columns.clone(), return_type);
230 let ch_apply_expr = clickhouse_apply(expr.clone(), columns.clone(), return_type);
231 let ch_lambda_expr = clickhouse_lambda(expr.clone(), columns.clone(), return_type);
232 let ch_map_expr = clickhouse_map(expr.clone(), columns.clone(), return_type);
233 assert_eq!(apply_expr, lambda_expr);
234 assert_eq!(apply_expr, ch_apply_expr);
235 assert_eq!(apply_expr, ch_lambda_expr);
236 assert_eq!(apply_expr, ch_map_expr);
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field};
    use datafusion::common::ScalarValue;
    use datafusion::logical_expr::ReturnFieldArgs;
    use datafusion::prelude::SessionContext;

    use super::*;

    /// UDF name handed to `extract_return_field_from_args` in every test.
    const UDF_NAME: &str = "test_func";

    /// Two non-nullable argument fields named `syntax` and `type` with the given data types.
    fn arg_fields(syntax_type: DataType, return_type_type: DataType) -> [Arc<Field>; 2] {
        [
            Arc::new(Field::new("syntax", syntax_type, false)),
            Arc::new(Field::new("type", return_type_type, false)),
        ]
    }

    /// A non-null `Utf8` scalar argument.
    fn utf8(value: &str) -> Option<ScalarValue> {
        Some(ScalarValue::Utf8(Some(value.to_string())))
    }

    /// Assembles `ReturnFieldArgs` from owned inputs and runs the function under test.
    fn extract(
        fields: &[Arc<Field>],
        scalars: &[Option<ScalarValue>],
    ) -> datafusion::error::Result<Arc<Field>> {
        let scalar_refs: Vec<Option<&ScalarValue>> = scalars.iter().map(|s| s.as_ref()).collect();
        let args = ReturnFieldArgs { arg_fields: fields, scalar_arguments: &scalar_refs };
        extract_return_field_from_args(UDF_NAME, &args)
    }

    /// Asserts that `result` is an error whose message contains `needle`.
    fn assert_error_contains(result: datafusion::error::Result<Arc<Field>>, needle: &str) {
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains(needle));
    }

    #[test]
    fn test_register_clickhouse_functions() {
        let ctx = SessionContext::new();
        register_clickhouse_functions(&ctx);

        let state = ctx.state();
        let functions = state.scalar_functions();
        for udf_name in ["clickhouse_eval", "clickhouse", "apply"] {
            assert!(functions.contains_key(udf_name));
        }
    }

    #[test]
    fn test_extract_return_field_from_args_utf8() {
        let fields = arg_fields(DataType::Utf8, DataType::Utf8);
        let scalars = [utf8("count()"), utf8("Int64")];

        let result = extract(&fields, &scalars);
        assert!(result.is_ok());
        let field = result.unwrap();
        assert_eq!(field.name(), "test_func");
        assert_eq!(field.data_type(), &DataType::Int64);
        assert!(!field.is_nullable());
    }

    #[test]
    fn test_extract_return_field_from_args_utf8_view() {
        let fields = arg_fields(DataType::Utf8View, DataType::Utf8View);
        let scalars = [
            Some(ScalarValue::Utf8View(Some("sum(x)".to_string()))),
            Some(ScalarValue::Utf8View(Some("Float64".to_string()))),
        ];

        let result = extract(&fields, &scalars);
        assert!(result.is_ok());
        let field = result.unwrap();
        assert_eq!(field.data_type(), &DataType::Float64);
    }

    #[test]
    fn test_extract_return_field_from_args_large_utf8() {
        let fields = arg_fields(DataType::LargeUtf8, DataType::LargeUtf8);
        let scalars = [
            Some(ScalarValue::LargeUtf8(Some("avg(y)".to_string()))),
            Some(ScalarValue::LargeUtf8(Some("Boolean".to_string()))),
        ];

        let result = extract(&fields, &scalars);
        assert!(result.is_ok());
        let field = result.unwrap();
        assert_eq!(field.data_type(), &DataType::Boolean);
    }

    #[test]
    fn test_extract_return_field_from_args_invalid_type() {
        let fields = arg_fields(DataType::Utf8, DataType::Utf8);
        let scalars = [utf8("count()"), utf8("InvalidDataType")];

        assert_error_contains(extract(&fields, &scalars), "Invalid return type");
    }

    #[test]
    fn test_extract_return_field_from_args_no_last_arg() {
        let fields = [Arc::new(Field::new("syntax", DataType::Utf8, false))];

        assert_error_contains(extract(&fields, &[]), "Expected return type");
    }

    #[test]
    fn test_extract_return_field_from_args_null_last_arg() {
        let fields = arg_fields(DataType::Utf8, DataType::Utf8);
        // `None` models a last argument that is not a literal.
        let scalars = [utf8("count()"), None];

        assert_error_contains(extract(&fields, &scalars), "Expected return type");
    }

    #[test]
    fn test_extract_return_field_from_args_non_string_last_arg() {
        let fields = arg_fields(DataType::Utf8, DataType::Int32);
        let scalars = [utf8("count()"), Some(ScalarValue::Int32(Some(42)))];

        assert_error_contains(extract(&fields, &scalars), "Expected return type");
    }

    #[test]
    fn test_extract_return_field_from_args_empty_string_last_arg() {
        let fields = arg_fields(DataType::Utf8, DataType::Utf8);
        let scalars = [utf8("count()"), Some(ScalarValue::Utf8(Some(String::new())))];

        assert_error_contains(extract(&fields, &scalars), "Invalid return type");
    }

    #[test]
    fn test_extract_return_field_from_args_null_string_last_arg() {
        let fields = arg_fields(DataType::Utf8, DataType::Utf8);
        let scalars = [utf8("count()"), Some(ScalarValue::Utf8(None))];

        assert_error_contains(extract(&fields, &scalars), "Expected return type");
    }
}