clickhouse_datafusion/udfs/
apply.rs

1// TODO: Remove - Important! explain how `apply` and `lambda` are only used when nested
2//! `ScalarUDFImpl` for [`ClickHouseApplyUDF`]
3//!
4//! Currently this provides little value over using a `ClickHouse` lambda function directly in
5//! [`super::clickhouse::ClickHouseUDF`] since both will be parsed the same. This UDF will be
6//! expanded to allow using it directly similarly to the `clickhouse` function.
7use std::collections::HashMap;
8use std::str::FromStr;
9
10use datafusion::arrow::datatypes::{DataType, FieldRef};
11use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
12use datafusion::common::{Column, not_impl_err, plan_datafusion_err, plan_err};
13use datafusion::error::Result;
14use datafusion::logical_expr::expr::{Placeholder, ScalarFunction};
15use datafusion::logical_expr::{
16    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
17    Volatility,
18};
19use datafusion::prelude::Expr;
20use datafusion::scalar::ScalarValue;
21use datafusion::sql::sqlparser::ast;
22use datafusion::sql::unparser::Unparser;
23
24use super::udf_field_from_fields;
25
/// Function names under which the apply UDF can be referenced.
///
/// The first entry is the primary name returned by `ClickHouseApplyUDF::name`. The complete list
/// (including that first entry) is what `ClickHouseApplyUDF::default` stores as its aliases, and
/// `unwrap_clickhouse_lambda` checks function names against it to recognise the wrapper call.
pub const CLICKHOUSE_APPLY_ALIASES: [&str; 7] = [
    "apply",
    "lambda",
    "clickhouse_apply",
    "clickhouse_lambda",
    "clickhouse_map",
    "clickhouse_fmap",
    "clickhouse_hof",
];
35
36pub fn clickhouse_apply_udf() -> ScalarUDF { ScalarUDF::new_from_impl(ClickHouseApplyUDF::new()) }
37
/// Scalar UDF that marks a `ClickHouse` lambda / higher-order function call during planning.
///
/// It cannot be executed locally: `invoke_with_args` always returns a "not implemented" error.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ClickHouseApplyUDF {
    /// Signature handed back from `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alias names handed back from `ScalarUDFImpl::aliases`.
    aliases:   Vec<String>,
}
43
44impl Default for ClickHouseApplyUDF {
45    fn default() -> Self {
46        Self {
47            signature: Signature::variadic_any(Volatility::Immutable),
48            aliases:   CLICKHOUSE_APPLY_ALIASES.iter().map(ToString::to_string).collect(),
49        }
50    }
51}
52
53impl ClickHouseApplyUDF {
54    pub fn new() -> Self { Self::default() }
55}
56
57impl ScalarUDFImpl for ClickHouseApplyUDF {
58    fn as_any(&self) -> &dyn std::any::Any { self }
59
60    fn name(&self) -> &str { CLICKHOUSE_APPLY_ALIASES[0] }
61
62    fn aliases(&self) -> &[String] { &self.aliases }
63
64    fn signature(&self) -> &Signature { &self.signature }
65
66    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
67        arg_types
68            .last()
69            .cloned()
70            .ok_or(plan_datafusion_err!("ClickHouseApplyUDF requires at least one argument"))
71    }
72
73    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
74        if let Ok(ret) = super::extract_return_field_from_args(self.name(), &args) {
75            Ok(ret)
76        } else {
77            let data_types =
78                args.arg_fields.iter().map(|f| f.data_type()).cloned().collect::<Vec<_>>();
79            let return_type = self.return_type(&data_types)?;
80            Ok(udf_field_from_fields(self.name(), return_type, args.arg_fields))
81        }
82    }
83
84    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
85        not_impl_err!(
86            "ClickHouseApplyUDF is for planning only - lambda functions are pushed down to \
87             ClickHouse"
88        )
89    }
90
91    /// Set to true to prevent optimizations. There is no way to know what the function will
92    /// produce, so these settings must be conservative.
93    fn short_circuits(&self) -> bool { true }
94}
95
/// Parsed form of a lambda / higher-order function call, ready to be turned into a SQL AST node
/// by `rewrite_to_ast`.
pub(crate) struct ClickHouseApplyRewriter {
    /// Name of the function emitted in the generated SQL call.
    pub name:      String,
    /// Lambda body; it still contains the placeholder expressions.
    pub body:      Expr,
    /// Lambda placeholders and the column each one is paired with.
    pub param_map: HashMap<Placeholder, Column>,
}
101
102impl ClickHouseApplyRewriter {
103    pub(crate) fn try_new(expr: &Expr) -> Result<Self> {
104        let (name, mut args) = unwrap_clickhouse_lambda(expr)?;
105
106        // Attempt to extract data type from the last argument, unused
107        let _data_type = args
108            .pop_if(|expr| matches!(expr, Expr::Literal(_, _)))
109            .map(|expr| match expr.as_literal() {
110                Some(
111                    ScalarValue::Utf8(Some(ret))
112                    | ScalarValue::Utf8View(Some(ret))
113                    | ScalarValue::LargeUtf8(Some(ret)),
114                ) => DataType::from_str(ret.as_str())
115                    .map_err(|e| plan_datafusion_err!("Invalid return type: {e}"))
116                    .map(Some),
117                _ => Ok(None),
118            })
119            .transpose()?
120            .flatten();
121
122        let (param_map, body) = extract_apply_args(args)?;
123        Ok(Self { name, body, param_map })
124    }
125
126    pub(crate) fn rewrite_to_ast(self, unparser: &Unparser<'_>) -> Result<ast::Expr> {
127        let Self { name, body, param_map, .. } = self;
128
129        // Transform the body expression, replacing columns with parameters
130        let transformed_body = body
131            .transform(|expr| {
132                if let Expr::Placeholder(ref placeholder) = expr
133                    && let Some((param_name, _)) =
134                        param_map.iter().find(|(p, _)| p.id == placeholder.id)
135                {
136                    let variable = param_name.id.trim_start_matches('$');
137                    // Use unqualified column which should unparse without quotes
138                    return Ok(Transformed::new(
139                        Expr::Column(Column::new_unqualified(variable)),
140                        true,
141                        TreeNodeRecursion::Jump,
142                    ));
143                }
144                Ok(Transformed::no(expr))
145            })
146            .unwrap()
147            .data;
148
149        // Convert body to SQL
150        let body_sql = unparser.expr_to_sql(&transformed_body)?;
151
152        // Strip all '$' from param names
153        let (mut params, mut columns): (Vec<_>, Vec<_>) = param_map
154            .into_iter()
155            .map(|(p, c)| (p.id.trim_start_matches('$').to_string(), c))
156            .unzip();
157
158        // Create lambda function parameters
159        let lambda_params = if params.len() == 1 {
160            ast::OneOrManyWithParens::One(ast::Ident::new(params.remove(0)))
161        } else {
162            ast::OneOrManyWithParens::Many(params.into_iter().map(ast::Ident::new).collect())
163        };
164
165        let column_params = if columns.len() == 1 {
166            let col = columns.remove(0);
167            vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
168                unparser
169                    .expr_to_sql(&Expr::Column(col.clone()))
170                    .unwrap_or_else(|_| ast::Expr::Identifier(ast::Ident::new(&col.name))),
171            ))]
172        } else {
173            columns
174                .into_iter()
175                .map(|c| {
176                    ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
177                        unparser
178                            .expr_to_sql(&Expr::Column(c.clone()))
179                            .unwrap_or_else(|_| ast::Expr::Identifier(ast::Ident::new(&c.name))),
180                    ))
181                })
182                .collect::<Vec<_>>()
183        };
184
185        // Create the lambda expression
186        let lambda_expr = ast::Expr::Lambda(ast::LambdaFunction {
187            params: lambda_params,
188            body:   Box::new(body_sql),
189        });
190
191        // Now create the higher-order function call with the lambda and original
192        // columns
193        let hof_args: Vec<ast::FunctionArg> = std::iter::once(
194            // First arg is the lambda
195            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(lambda_expr)),
196        )
197        .chain(column_params)
198        .collect();
199
200        Ok(ast::Expr::Function(ast::Function {
201            name:             ast::ObjectName(vec![ast::ObjectNamePart::Identifier(
202                ast::Ident::new(name),
203            )]),
204            args:             ast::FunctionArguments::List(ast::FunctionArgumentList {
205                duplicate_treatment: None,
206                args:                hof_args,
207                clauses:             vec![],
208            }),
209            filter:           None,
210            null_treatment:   None,
211            over:             None,
212            within_group:     vec![],
213            parameters:       ast::FunctionArguments::None,
214            uses_odbc_syntax: false,
215        }))
216    }
217}
218
219pub(crate) fn unwrap_clickhouse_lambda(expr: &Expr) -> Result<(String, Vec<Expr>)> {
220    let inner_expr = if let Expr::Alias(e) = expr { &e.expr } else { expr };
221
222    // Must be a scalar function
223    let Expr::ScalarFunction(ScalarFunction { func, args }) = inner_expr else {
224        return plan_err!("Unknown expression passed to ClickHouseApplyRewriter");
225    };
226
227    // May be nested within an apply alias or may be a function with placeholder directly
228    Ok(if CLICKHOUSE_APPLY_ALIASES.contains(&func.name()) {
229        let Some(Expr::ScalarFunction(ScalarFunction { func: inner_func, args: inner_args })) =
230            args.first()
231        else {
232            return plan_err!("ClickHouseApplyUDF must be higher order function");
233        };
234
235        (inner_func.name().to_string(), inner_args.clone())
236    } else if args.first().is_some_and(|a| matches!(a, Expr::Placeholder(_))) {
237        // Unwrap Alias if present (DataFusion 51+ wraps UDF calls in Alias when
238        // called via an alias name like `lambda` instead of primary name `apply`)
239        (func.name().to_string(), args.clone())
240    } else {
241        return plan_err!("Unknown function passed to ClickHouseApplyRewriter");
242    })
243}
244
245pub(crate) fn extract_apply_args(
246    mut args: Vec<Expr>,
247) -> Result<(HashMap<Placeholder, Column>, Expr)> {
248    if args.len() < 3 {
249        return plan_err!(
250            "ClickHouseApplyUDF requires at least 3 arguments: placeholders, body, and column \
251             references"
252        );
253    }
254
255    let mut columns = Vec::with_capacity(args.len());
256
257    // Pull out columns and body
258    let body = loop {
259        match args.pop() {
260            Some(Expr::Column(col)) => columns.push(col),
261            Some(e) => break e,
262            None => {
263                return plan_err!("ClickHouseApplyUDF missing body expression");
264            }
265        }
266    };
267
268    // Finally confirm placeholders
269    let placeholders = args
270        .into_iter()
271        .map(
272            |e| if let Expr::Placeholder(p) = e { Ok(p) } else { plan_err!("Invalid placeholder") },
273        )
274        .collect::<Result<Vec<_>>>()?;
275
276    if columns.len() != placeholders.len() {
277        return plan_err!("Number of placeholders and columns must match");
278    }
279
280    let param_map = placeholders.into_iter().zip(columns).collect::<HashMap<_, _>>();
281
282    Ok((param_map, body))
283}
284
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::*;
    use datafusion::common::ScalarValue;
    use datafusion::config::ConfigOptions;
    use datafusion::logical_expr::{BinaryExpr, Operator, ReturnFieldArgs, ScalarFunctionArgs};
    use datafusion::prelude::lit;
    use datafusion::sql::TableReference;

    use super::*;
    use crate::udfs::placeholder::{PlaceholderUDF, placeholder_udf_from_placeholder};

    /// Covers the UDF surface: short-circuit flag, return field resolution and the refusal to
    /// execute.
    #[test]
    fn test_apply_udf() {
        let udf = clickhouse_apply_udf();

        // The UDF must report that it short circuits
        assert!(udf.short_circuits());

        // A data type passed as the trailing string argument becomes the return field type
        let syntax_field = Arc::new(Field::new("syntax", DataType::Utf8, false));
        let type_field = Arc::new(Field::new("type", DataType::Utf8, false));
        let syntax_value = ScalarValue::Utf8(Some("count()".to_string()));
        let type_value = ScalarValue::Utf8(Some("Int64".to_string()));
        let return_args = ReturnFieldArgs {
            arg_fields:       &[syntax_field, type_field],
            scalar_arguments: &[Some(&syntax_value), Some(&type_value)],
        };

        let field = udf.return_field_from_args(return_args).unwrap();
        assert_eq!(field.name(), CLICKHOUSE_APPLY_ALIASES[0]);
        assert_eq!(field.data_type(), &DataType::Int64);

        // Invocation is rejected because the UDF is planning-only
        let invoke_args = ScalarFunctionArgs {
            args:           vec![],
            arg_fields:     vec![],
            number_rows:    1,
            return_field:   Arc::new(Field::new("", DataType::Int32, false)),
            config_options: Arc::new(ConfigOptions::default()),
        };
        let error = udf.invoke_with_args(invoke_args).unwrap_err();
        assert!(error.to_string().contains("planning only"));
    }

    /// Covers argument extraction failures and the requirement that `apply` wraps a
    /// higher-order function.
    #[test]
    fn test_apply_rewriter() {
        let placeholder = Placeholder::new_with_field("$x".to_string(), None);
        let placeholder_expr = || Expr::Placeholder(placeholder.clone());
        let column_expr = |name: &str| Expr::Column(Column::from_name(name));

        // `extract_apply_args` rejects malformed argument lists
        assert!(
            extract_apply_args(vec![placeholder_expr()]).is_err(),
            "Apply expects at least 3 args"
        );
        assert!(
            extract_apply_args(vec![column_expr("test")]).is_err(),
            "Apply expects a body arg before columns"
        );
        let mismatched =
            vec![placeholder_expr(), lit("1"), column_expr("test1"), column_expr("test2")];
        assert!(
            extract_apply_args(mismatched).is_err(),
            "Placeholder count must match column count"
        );

        let common_args = vec![
            placeholder_expr(),
            Expr::BinaryExpr(BinaryExpr {
                left:  Box::new(placeholder_expr()),
                op:    Operator::Plus,
                right: Box::new(lit(1)),
            }),
            Expr::Column(Column::new(None::<TableReference>, "test_col")),
            lit("Int64"),
        ];

        // Wraps the given arguments in a call to the apply UDF
        let apply_over = |args: Vec<Expr>| {
            Expr::ScalarFunction(ScalarFunction { func: Arc::new(clickhouse_apply_udf()), args })
        };

        // Lambda arguments passed straight to `apply` are rejected
        let direct = apply_over(common_args.clone());
        assert!(
            ClickHouseApplyRewriter::try_new(&direct).is_err(),
            "Apply/Lambda must be a higher order function"
        );

        // The same arguments nested inside a higher-order function are accepted
        let nested = apply_over(vec![Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(placeholder_udf_from_placeholder(PlaceholderUDF::new("arrayMap"))),
            args: common_args.clone(),
        })]);
        assert!(
            ClickHouseApplyRewriter::try_new(&nested).is_ok(),
            "Apply/Lambda expected to be higher order function"
        );
    }
}
387}