clickhouse_datafusion/
udfs.rs

//! Various UDFs that extend `DataFusion`'s SQL parsing with `ClickHouse`-specific functionality.
2//!
3//! [`self::eval::ClickHouseEval`] is a sort of 'escape-hatch' to allow passing syntax directly
4//! to `ClickHouse` as SQL.
5pub mod apply;
6pub mod clickhouse;
7pub mod eval;
8pub mod placeholder;
9
10use std::str::FromStr;
11
12use datafusion::arrow::datatypes::{DataType, Field, FieldRef};
13use datafusion::common::{plan_datafusion_err, plan_err};
14use datafusion::error::Result;
15use datafusion::logical_expr::ReturnFieldArgs;
16use datafusion::prelude::SessionContext;
17use datafusion::scalar::ScalarValue;
18
19// TODO: Docs - explain how this registers the best-effort UDF that can be used when the full
20// `ClickHouseQueryPlanner` is not available.
21//
22/// Registers `ClickHouse`-specific UDFs with the provided [`SessionContext`].
23pub fn register_clickhouse_functions(ctx: &SessionContext) {
24    ctx.register_udf(eval::clickhouse_eval_udf());
25    ctx.register_udf(clickhouse::clickhouse_udf());
26    ctx.register_udf(apply::clickhouse_apply_udf());
27}
28
29/// Helper function to extract return [`DataType`] from second UDF arg
30fn extract_return_field_from_args(name: &str, args: &ReturnFieldArgs<'_>) -> Result<FieldRef> {
31    if let Some(Some(
32        ScalarValue::Utf8(Some(return_type_str))
33        | ScalarValue::Utf8View(Some(return_type_str))
34        | ScalarValue::LargeUtf8(Some(return_type_str)),
35    )) = &args.scalar_arguments.last()
36    {
37        let dt = DataType::from_str(return_type_str.as_str())
38            .map_err(|e| plan_datafusion_err!("Invalid return type for {name}: {e}"))?;
39        Ok(udf_field_from_fields(name, dt, args.arg_fields))
40    } else {
41        plan_err!("Expected return type literal in scalar arguments for {name}")
42    }
43}
44
45fn udf_field_from_fields(name: &str, dt: DataType, fields: &[FieldRef]) -> FieldRef {
46    // Apply/lambda will be indicated by placeholder fields. These are flagged as nullable by
47    // DataFusion during expr creation. But, the return field should only be nullable if any of the
48    // columns the placeholders refer to are nullable.
49    //
50    // Apply functions have at least 3 args.
51    let mut placeholder_nullable = false;
52    if fields.len() >= 3 && fields.first().is_some_and(|f| f.name().starts_with('$')) {
53        let rev_fields = fields.iter().rev();
54        for (pl, col) in fields.iter().zip(rev_fields) {
55            if pl.name().starts_with('$') {
56                placeholder_nullable |= col.is_nullable();
57            } else {
58                break;
59            }
60        }
61        return Field::new(name, dt, placeholder_nullable).into();
62    }
63
64    // Otherwise determine from the provided args
65
66    // Treat array returns differently. ClickHouse doesn't support nullable arrays.
67    let nullable = fields.iter().any(|a| {
68        !matches!(
69            a.data_type(),
70            &DataType::List(_) | &DataType::ListView(_) | &DataType::LargeList(_)
71        ) && a.is_nullable()
72    });
73    Field::new(name, dt, nullable).into()
74}
75
76pub mod functions {
77    //! List of functions to create the various `ClickHouse`-specific UDFs.
78    use datafusion::common::Column;
79    use datafusion::logical_expr::expr::Placeholder;
80    use datafusion::prelude::Expr;
81    use datafusion::scalar::ScalarValue;
82
83    /// Create a `ClickHouse` UDF to be 'evaluated' on the `ClickHouse` server.
84    pub fn clickhouse_eval(expr: impl Into<String>, return_type: &str) -> Expr {
85        super::eval::clickhouse_eval_udf().call(vec![
86            Expr::Literal(ScalarValue::Utf8(Some(expr.into())), None),
87            Expr::Literal(ScalarValue::Utf8(Some(return_type.to_string())), None),
88        ])
89    }
90
91    /// Create a `ClickHouse` UDF that will be executed on the `ClickHouse` server.
92    pub fn clickhouse(expr: Expr, return_type: &str) -> Expr {
93        super::clickhouse::clickhouse_udf()
94            .call(vec![expr, Expr::Literal(ScalarValue::Utf8(Some(return_type.to_string())), None)])
95    }
96
97    /// Create a `ClickHouse` Higher Order Function UDF that will be executed on the `ClickHouse`
98    /// server.
99    pub fn apply<C: IntoIterator<Item = Column>>(
100        expr: Expr,
101        columns: C,
102        return_type: &str,
103    ) -> Expr {
104        let (mut args, columns): (Vec<_>, Vec<_>) = columns
105            .into_iter()
106            .enumerate()
107            .map(|(i, c)| {
108                (
109                    Expr::Placeholder(Placeholder { id: format!("x{i}"), data_type: None }),
110                    Expr::Column(c),
111                )
112            })
113            .unzip();
114        args.push(expr);
115        args.extend(columns);
116        let apply_udf = super::apply::clickhouse_apply_udf().call(args);
117        clickhouse(apply_udf, return_type)
118    }
119
120    /// Alias for [`self::apply`]
121    pub fn lambda<C: IntoIterator<Item = Column>>(
122        expr: Expr,
123        columns: C,
124        return_type: &str,
125    ) -> Expr {
126        apply(expr, columns, return_type)
127    }
128
129    /// Alias for [`self::apply`]
130    pub fn clickhouse_apply<C: IntoIterator<Item = Column>>(
131        expr: Expr,
132        columns: C,
133        return_type: &str,
134    ) -> Expr {
135        apply(expr, columns, return_type)
136    }
137
138    /// Alias for [`self::apply`]
139    pub fn clickhouse_lambda<C: IntoIterator<Item = Column>>(
140        expr: Expr,
141        columns: C,
142        return_type: &str,
143    ) -> Expr {
144        apply(expr, columns, return_type)
145    }
146
147    /// Alias for [`self::apply`]
148    pub fn clickhouse_map<C: IntoIterator<Item = Column>>(
149        expr: Expr,
150        columns: C,
151        return_type: &str,
152    ) -> Expr {
153        apply(expr, columns, return_type)
154    }
155
156    #[cfg(test)]
157    mod tests {
158        use std::sync::Arc;
159
160        use datafusion::common::ScalarValue;
161        use datafusion::logical_expr::expr::ScalarFunction;
162        use datafusion::prelude::{Expr, lit};
163
164        use super::*;
165        use crate::prelude::clickhouse_eval_udf;
166        use crate::udfs::apply::clickhouse_apply_udf;
167        use crate::udfs::clickhouse::clickhouse_udf;
168        use crate::udfs::functions::clickhouse_eval;
169
170        #[test]
171        fn test_create_simple_udf() {
172            assert_eq!(
173                clickhouse_eval("count(*)", "UInt64"),
174                Expr::ScalarFunction(ScalarFunction {
175                    func: Arc::new(clickhouse_eval_udf()),
176                    args: vec![
177                        Expr::Literal(ScalarValue::Utf8(Some("count(*)".to_string())), None),
178                        Expr::Literal(ScalarValue::Utf8(Some("UInt64".to_string())), None),
179                    ],
180                })
181            );
182        }
183
184        #[test]
185        fn test_clickhouse_udf() {
186            assert_eq!(
187                clickhouse(
188                    Expr::Literal(ScalarValue::Utf8(Some("count(*)".to_string())), None),
189                    "UInt64"
190                ),
191                Expr::ScalarFunction(ScalarFunction {
192                    func: Arc::new(clickhouse_udf()),
193                    args: vec![
194                        Expr::Literal(ScalarValue::Utf8(Some("count(*)".to_string())), None),
195                        Expr::Literal(ScalarValue::Utf8(Some("UInt64".to_string())), None),
196                    ],
197                })
198            );
199        }
200
201        #[test]
202        fn test_clickhouse_apply_udf() {
203            let expr = Expr::Column(Column::from_name("id")) + lit(5);
204            let columns = vec![Column::from_name("id")];
205            let return_type = "UInt64";
206            let apply_expr = apply(expr.clone(), columns.clone(), return_type);
207            assert_eq!(
208                apply_expr,
209                Expr::ScalarFunction(ScalarFunction {
210                    func: Arc::new(clickhouse_udf()),
211                    args: vec![
212                        // apply is a HOF
213                        Expr::ScalarFunction(ScalarFunction {
214                            func: Arc::new(clickhouse_apply_udf()),
215                            args: vec![
216                                Expr::Placeholder(Placeholder {
217                                    id:        "x0".to_string(),
218                                    data_type: None,
219                                }),
220                                Expr::Column(Column::from_name("id")) + lit(5),
221                                Expr::Column(Column::from_name("id")),
222                            ],
223                        }),
224                        Expr::Literal(ScalarValue::Utf8(Some("UInt64".to_string())), None),
225                    ],
226                })
227            );
228
229            let lambda_expr = lambda(expr.clone(), columns.clone(), return_type);
230            let ch_apply_expr = clickhouse_apply(expr.clone(), columns.clone(), return_type);
231            let ch_lambda_expr = clickhouse_lambda(expr.clone(), columns.clone(), return_type);
232            let ch_map_expr = clickhouse_map(expr.clone(), columns.clone(), return_type);
233            assert_eq!(apply_expr, lambda_expr);
234            assert_eq!(apply_expr, ch_apply_expr);
235            assert_eq!(apply_expr, ch_lambda_expr);
236            assert_eq!(apply_expr, ch_map_expr);
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field};
    use datafusion::common::ScalarValue;
    use datafusion::logical_expr::ReturnFieldArgs;
    use datafusion::prelude::SessionContext;

    use super::*;

    /// Result type produced by [`extract_return_field_from_args`].
    type FieldResult = datafusion::error::Result<Arc<Field>>;

    /// Non-nullable `syntax` and `type` argument fields with the given data types.
    fn arg_fields(syntax_type: DataType, type_type: DataType) -> Vec<Arc<Field>> {
        vec![
            Arc::new(Field::new("syntax", syntax_type, false)),
            Arc::new(Field::new("type", type_type, false)),
        ]
    }

    /// Shorthand for a present, non-null `Utf8` scalar argument.
    fn utf8(value: &str) -> Option<ScalarValue> {
        Some(ScalarValue::Utf8(Some(value.to_string())))
    }

    /// Calls [`extract_return_field_from_args`] as `test_func` with the given fields and scalars.
    fn extract(fields: &[Arc<Field>], scalars: &[Option<ScalarValue>]) -> FieldResult {
        let scalar_refs: Vec<Option<&ScalarValue>> = scalars.iter().map(Option::as_ref).collect();
        let args = ReturnFieldArgs { arg_fields: fields, scalar_arguments: &scalar_refs };
        extract_return_field_from_args("test_func", &args)
    }

    /// Asserts that `result` is an error whose message contains `expected`.
    fn assert_error_contains(result: FieldResult, expected: &str) {
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains(expected));
    }

    #[test]
    fn test_register_clickhouse_functions() {
        let ctx = SessionContext::new();
        register_clickhouse_functions(&ctx);

        // Check that each clickhouse function was registered
        let state = ctx.state();
        let functions = state.scalar_functions();
        for udf_name in ["clickhouse_eval", "clickhouse", "apply"] {
            assert!(functions.contains_key(udf_name));
        }
    }

    #[test]
    fn test_extract_return_field_from_args_utf8() {
        let fields = arg_fields(DataType::Utf8, DataType::Utf8);
        let result = extract(&fields, &[utf8("count()"), utf8("Int64")]);

        assert!(result.is_ok());
        let field = result.unwrap();
        assert_eq!(field.name(), "test_func");
        assert_eq!(field.data_type(), &DataType::Int64);
        assert!(!field.is_nullable());
    }

    #[test]
    fn test_extract_return_field_from_args_utf8_view() {
        let fields = arg_fields(DataType::Utf8View, DataType::Utf8View);
        let scalars = [
            Some(ScalarValue::Utf8View(Some("sum(x)".to_string()))),
            Some(ScalarValue::Utf8View(Some("Float64".to_string()))),
        ];
        let result = extract(&fields, &scalars);

        assert!(result.is_ok());
        let field = result.unwrap();
        assert_eq!(field.data_type(), &DataType::Float64);
    }

    #[test]
    fn test_extract_return_field_from_args_large_utf8() {
        let fields = arg_fields(DataType::LargeUtf8, DataType::LargeUtf8);
        let scalars = [
            Some(ScalarValue::LargeUtf8(Some("avg(y)".to_string()))),
            Some(ScalarValue::LargeUtf8(Some("Boolean".to_string()))),
        ];
        let result = extract(&fields, &scalars);

        assert!(result.is_ok());
        let field = result.unwrap();
        assert_eq!(field.data_type(), &DataType::Boolean);
    }

    #[test]
    fn test_extract_return_field_from_args_invalid_type() {
        let fields = arg_fields(DataType::Utf8, DataType::Utf8);
        let result = extract(&fields, &[utf8("count()"), utf8("InvalidDataType")]);

        assert_error_contains(result, "Invalid return type");
    }

    #[test]
    fn test_extract_return_field_from_args_no_last_arg() {
        // A single argument field and no scalar arguments at all.
        let fields = [Arc::new(Field::new("syntax", DataType::Utf8, false))];
        let result = extract(&fields, &[]);

        assert_error_contains(result, "Expected return type");
    }

    #[test]
    fn test_extract_return_field_from_args_null_last_arg() {
        // The last argument is not a literal (no scalar value available).
        let fields = arg_fields(DataType::Utf8, DataType::Utf8);
        let result = extract(&fields, &[utf8("count()"), None]);

        assert_error_contains(result, "Expected return type");
    }

    #[test]
    fn test_extract_return_field_from_args_non_string_last_arg() {
        let fields = arg_fields(DataType::Utf8, DataType::Int32);
        let scalars = [utf8("count()"), Some(ScalarValue::Int32(Some(42)))];
        let result = extract(&fields, &scalars);

        assert_error_contains(result, "Expected return type");
    }

    #[test]
    fn test_extract_return_field_from_args_empty_string_last_arg() {
        let fields = arg_fields(DataType::Utf8, DataType::Utf8);
        let scalars = [utf8("count()"), Some(ScalarValue::Utf8(Some(String::new())))];
        let result = extract(&fields, &scalars);

        assert_error_contains(result, "Invalid return type");
    }

    #[test]
    fn test_extract_return_field_from_args_null_string_last_arg() {
        // The last argument is a string literal, but a null one.
        let fields = arg_fields(DataType::Utf8, DataType::Utf8);
        let scalars = [utf8("count()"), Some(ScalarValue::Utf8(None))];
        let result = extract(&fields, &scalars);

        assert_error_contains(result, "Expected return type");
    }
}