clickhouse_datafusion/udfs/
apply.rs

1// TODO: Remove - Important! explain how `apply` and `lambda` are only used when nested
//! `ScalarUDFImpl` for [`ClickHouseApplyUDF`].
//!
//! Currently this provides little value over passing a `ClickHouse` lambda function directly to
//! [`super::clickhouse::ClickHouseUDF`], since both are parsed the same way. This UDF will be
//! expanded so that it can also be used directly, similar to the `clickhouse` function.
7use std::collections::HashMap;
8use std::str::FromStr;
9
10use datafusion::arrow::datatypes::{DataType, FieldRef};
11use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
12use datafusion::common::{Column, not_impl_err, plan_datafusion_err, plan_err};
13use datafusion::error::Result;
14use datafusion::logical_expr::expr::{Placeholder, ScalarFunction};
15use datafusion::logical_expr::{
16    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
17    Volatility,
18};
19use datafusion::prelude::Expr;
20use datafusion::scalar::ScalarValue;
21use datafusion::sql::sqlparser::ast;
22use datafusion::sql::unparser::Unparser;
23
24use super::udf_field_from_fields;
25
/// Every name under which [`ClickHouseApplyUDF`] can be called.
///
/// The first entry (`"apply"`) is the UDF's canonical name. The whole list, including that first
/// entry, is reported as the UDF's aliases. It is also what `ClickHouseApplyRewriter::try_new`
/// and `is_clickhouse_lambda` check to recognise an explicitly wrapped lambda.
pub const CLICKHOUSE_APPLY_ALIASES: [&str; 7] = [
    // Short forms.
    "apply",
    "lambda",
    // Namespaced forms.
    "clickhouse_apply",
    "clickhouse_lambda",
    "clickhouse_map",
    "clickhouse_fmap",
    "clickhouse_hof",
];
35
36pub fn clickhouse_apply_udf() -> ScalarUDF { ScalarUDF::new_from_impl(ClickHouseApplyUDF::new()) }
37
38#[derive(Debug)]
39pub struct ClickHouseApplyUDF {
40    signature: Signature,
41    aliases:   Vec<String>,
42}
43
44impl Default for ClickHouseApplyUDF {
45    fn default() -> Self {
46        Self {
47            signature: Signature::variadic_any(Volatility::Immutable),
48            aliases:   CLICKHOUSE_APPLY_ALIASES.iter().map(ToString::to_string).collect(),
49        }
50    }
51}
52
53impl ClickHouseApplyUDF {
54    pub fn new() -> Self { Self::default() }
55}
56
57impl ScalarUDFImpl for ClickHouseApplyUDF {
58    fn as_any(&self) -> &dyn std::any::Any { self }
59
60    fn name(&self) -> &str { CLICKHOUSE_APPLY_ALIASES[0] }
61
62    fn aliases(&self) -> &[String] { &self.aliases }
63
64    fn signature(&self) -> &Signature { &self.signature }
65
66    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
67        arg_types
68            .last()
69            .cloned()
70            .ok_or(plan_datafusion_err!("ClickHouseApplyUDF requires at least one argument"))
71    }
72
73    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
74        if let Ok(ret) = super::extract_return_field_from_args(self.name(), &args) {
75            Ok(ret)
76        } else {
77            let data_types =
78                args.arg_fields.iter().map(|f| f.data_type()).cloned().collect::<Vec<_>>();
79            let return_type = self.return_type(&data_types)?;
80            Ok(udf_field_from_fields(self.name(), return_type, args.arg_fields))
81        }
82    }
83
84    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
85        not_impl_err!(
86            "ClickHouseApplyUDF is for planning only - lambda functions are pushed down to \
87             ClickHouse"
88        )
89    }
90
91    /// Set to true to prevent optimizations. There is no way to know what the function will
92    /// produce, so these settings must be conservative.
93    fn short_circuits(&self) -> bool { true }
94}
95
96pub(crate) struct ClickHouseApplyRewriter {
97    pub name:      String,
98    pub body:      Expr,
99    pub param_map: HashMap<Placeholder, Column>,
100}
101
102impl ClickHouseApplyRewriter {
103    pub(crate) fn try_new(expr: &Expr) -> Result<Self> {
104        if !is_clickhouse_lambda(expr) {
105            return plan_err!("Unknown function passed to ClickHouseApplyRewriter");
106        }
107
108        let Expr::ScalarFunction(ScalarFunction { func, args }) = expr else {
109            // Guaranteed by `is_clickhouse_lambda`
110            unreachable!();
111        };
112
113        // Unwrap the aliased function if it's wrapped
114        let (name, mut args) = if CLICKHOUSE_APPLY_ALIASES.contains(&func.name()) {
115            let Some(Expr::ScalarFunction(ScalarFunction { func, args })) = args.first() else {
116                return plan_err!("ClickHouseApplyUDF must be higher order function");
117            };
118            (func.name().to_string(), args.clone())
119        } else {
120            (func.name().to_string(), args.clone())
121        };
122
123        // Attempt to extract data type from the last argument, unused
124        let _data_type = args
125            .pop_if(|expr| matches!(expr, Expr::Literal(_, _)))
126            .map(|expr| match expr.as_literal() {
127                Some(
128                    ScalarValue::Utf8(Some(ret))
129                    | ScalarValue::Utf8View(Some(ret))
130                    | ScalarValue::LargeUtf8(Some(ret)),
131                ) => DataType::from_str(ret.as_str())
132                    .map_err(|e| plan_datafusion_err!("Invalid return type: {e}"))
133                    .map(Some),
134                _ => Ok(None),
135            })
136            .transpose()?
137            .flatten();
138
139        let (param_map, body) = extract_apply_args(args)?;
140        Ok(Self { name, body, param_map })
141    }
142
143    pub(crate) fn rewrite_to_ast(self, unparser: &Unparser<'_>) -> Result<ast::Expr> {
144        let Self { name, body, param_map, .. } = self;
145
146        // Transform the body expression, replacing columns with parameters
147        let transformed_body = body
148            .transform(|expr| {
149                if let Expr::Placeholder(ref placeholder) = expr
150                    && let Some((param_name, _)) =
151                        param_map.iter().find(|(p, _)| p.id == placeholder.id)
152                {
153                    let variable = param_name.id.trim_start_matches('$');
154                    // Use unqualified column which should unparse without quotes
155                    return Ok(Transformed::new(
156                        Expr::Column(Column::new_unqualified(variable)),
157                        true,
158                        TreeNodeRecursion::Jump,
159                    ));
160                }
161                Ok(Transformed::no(expr))
162            })
163            .unwrap()
164            .data;
165
166        // Convert body to SQL
167        let body_sql = unparser.expr_to_sql(&transformed_body)?;
168
169        // Strip all '$' from param names
170        let (mut params, mut columns): (Vec<_>, Vec<_>) = param_map
171            .into_iter()
172            .map(|(p, c)| (p.id.trim_start_matches('$').to_string(), c))
173            .unzip();
174
175        // Create lambda function parameters
176        let lambda_params = if params.len() == 1 {
177            ast::OneOrManyWithParens::One(ast::Ident::new(params.remove(0)))
178        } else {
179            ast::OneOrManyWithParens::Many(params.into_iter().map(ast::Ident::new).collect())
180        };
181
182        let column_params = if columns.len() == 1 {
183            let col = columns.remove(0);
184            vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
185                unparser
186                    .expr_to_sql(&Expr::Column(col.clone()))
187                    .unwrap_or_else(|_| ast::Expr::Identifier(ast::Ident::new(&col.name))),
188            ))]
189        } else {
190            columns
191                .into_iter()
192                .map(|c| {
193                    ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
194                        unparser
195                            .expr_to_sql(&Expr::Column(c.clone()))
196                            .unwrap_or_else(|_| ast::Expr::Identifier(ast::Ident::new(&c.name))),
197                    ))
198                })
199                .collect::<Vec<_>>()
200        };
201
202        // Create the lambda expression
203        let lambda_expr = ast::Expr::Lambda(ast::LambdaFunction {
204            params: lambda_params,
205            body:   Box::new(body_sql),
206        });
207
208        // Now create the higher-order function call with the lambda and original
209        // columns
210        let hof_args: Vec<ast::FunctionArg> = std::iter::once(
211            // First arg is the lambda
212            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(lambda_expr)),
213        )
214        .chain(column_params)
215        .collect();
216
217        Ok(ast::Expr::Function(ast::Function {
218            name:             ast::ObjectName(vec![ast::ObjectNamePart::Identifier(
219                ast::Ident::new(name),
220            )]),
221            args:             ast::FunctionArguments::List(ast::FunctionArgumentList {
222                duplicate_treatment: None,
223                args:                hof_args,
224                clauses:             vec![],
225            }),
226            filter:           None,
227            null_treatment:   None,
228            over:             None,
229            within_group:     vec![],
230            parameters:       ast::FunctionArguments::None,
231            uses_odbc_syntax: false,
232        }))
233    }
234}
235
236pub(crate) fn is_clickhouse_lambda(expr: &Expr) -> bool {
237    let Expr::ScalarFunction(ScalarFunction { func, args }) = expr else {
238        return false;
239    };
240    CLICKHOUSE_APPLY_ALIASES.contains(&func.name())
241        || args.first().is_some_and(|a| matches!(a, Expr::Placeholder(_)))
242}
243
244pub(crate) fn extract_apply_args(
245    mut args: Vec<Expr>,
246) -> Result<(HashMap<Placeholder, Column>, Expr)> {
247    if args.len() < 3 {
248        return plan_err!(
249            "ClickHouseApplyUDF requires at least 3 arguments: placeholders, body, and column \
250             references"
251        );
252    }
253
254    let mut columns = Vec::with_capacity(args.len());
255
256    // Pull out columns and body
257    let body = loop {
258        match args.pop() {
259            Some(Expr::Column(col)) => columns.push(col),
260            Some(e) => break e,
261            None => {
262                return plan_err!("ClickHouseApplyUDF missing body expression");
263            }
264        }
265    };
266
267    // Finally confirm placeholders
268    let placeholders = args
269        .into_iter()
270        .map(
271            |e| if let Expr::Placeholder(p) = e { Ok(p) } else { plan_err!("Invalid placeholder") },
272        )
273        .collect::<Result<Vec<_>>>()?;
274
275    if columns.len() != placeholders.len() {
276        return plan_err!("Number of placeholders and columns must match");
277    }
278
279    let param_map = placeholders.into_iter().zip(columns).collect::<HashMap<_, _>>();
280
281    Ok((param_map, body))
282}
283
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::*;
    use datafusion::common::ScalarValue;
    use datafusion::logical_expr::{BinaryExpr, Operator, ReturnFieldArgs, ScalarFunctionArgs};
    use datafusion::prelude::lit;
    use datafusion::sql::TableReference;

    use super::*;
    use crate::udfs::placeholder::{PlaceholderUDF, placeholder_udf_from_placeholder};

    /// Builds an expression that calls `func` with `args`.
    fn call(func: ScalarUDF, args: Vec<Expr>) -> Expr {
        Expr::ScalarFunction(ScalarFunction { func: Arc::new(func), args })
    }

    #[test]
    fn test_apply_udf() {
        let udf = clickhouse_apply_udf();

        // Optimizations must stay conservative.
        assert!(udf.short_circuits());

        // An explicit return type passed as a string literal drives the return field.
        let arg_fields = [
            Arc::new(Field::new("syntax", DataType::Utf8, false)),
            Arc::new(Field::new("type", DataType::Utf8, false)),
        ];
        let syntax_value = ScalarValue::Utf8(Some("count()".to_string()));
        let type_value = ScalarValue::Utf8(Some("Int64".to_string()));
        let return_field = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields:       &arg_fields,
            scalar_arguments: &[Some(&syntax_value), Some(&type_value)],
        });
        assert!(return_field.is_ok());
        let return_field = return_field.unwrap();
        assert_eq!(return_field.name(), CLICKHOUSE_APPLY_ALIASES[0]);
        assert_eq!(return_field.data_type(), &DataType::Int64);

        // Execution is unsupported: the UDF only exists for planning.
        let outcome = udf.invoke_with_args(ScalarFunctionArgs {
            args:         vec![],
            arg_fields:   vec![],
            number_rows:  1,
            return_field: Arc::new(Field::new("", DataType::Int32, false)),
        });
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("planning only"));
    }

    #[test]
    fn test_apply_rewriter() {
        let param = Placeholder::new("$x".to_string(), None);
        let column = |name: &str| Expr::Column(Column::from_name(name));

        // `extract_apply_args` argument validation.
        assert!(
            extract_apply_args(vec![Expr::Placeholder(param.clone())]).is_err(),
            "Apply expects at least 3 args"
        );
        assert!(
            extract_apply_args(vec![column("test")]).is_err(),
            "Apply expects a body arg before columns"
        );
        let mismatched = vec![
            Expr::Placeholder(param.clone()),
            lit("1"),
            column("test1"),
            column("test2"),
        ];
        assert!(
            extract_apply_args(mismatched).is_err(),
            "Placeholder count must match column count"
        );

        // `$x`, `$x + 1`, `test_col`, followed by a trailing return-type literal.
        let hof_args = vec![
            Expr::Placeholder(param.clone()),
            Expr::BinaryExpr(BinaryExpr {
                left:  Box::new(Expr::Placeholder(param)),
                op:    Operator::Plus,
                right: Box::new(lit(1)),
            }),
            Expr::Column(Column::new(None::<TableReference>, "test_col")),
            lit("Int64"),
        ];

        // `apply` must wrap a function call, not the raw lambda arguments.
        let unwrapped = call(clickhouse_apply_udf(), hof_args.clone());
        assert!(
            ClickHouseApplyRewriter::try_new(&unwrapped).is_err(),
            "Apply/Lambda must be a higher order function"
        );

        // Wrapping a higher-order function call is accepted.
        let array_map =
            call(placeholder_udf_from_placeholder(PlaceholderUDF::new("arrayMap")), hof_args);
        let wrapped = call(clickhouse_apply_udf(), vec![array_map]);
        assert!(
            ClickHouseApplyRewriter::try_new(&wrapped).is_ok(),
            "Apply/Lambda expected to be higher order function"
        );
    }
}