clickhouse_datafusion/udfs/
apply.rs1use std::collections::HashMap;
8use std::str::FromStr;
9
10use datafusion::arrow::datatypes::{DataType, FieldRef};
11use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
12use datafusion::common::{Column, not_impl_err, plan_datafusion_err, plan_err};
13use datafusion::error::Result;
14use datafusion::logical_expr::expr::{Placeholder, ScalarFunction};
15use datafusion::logical_expr::{
16 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
17 Volatility,
18};
19use datafusion::prelude::Expr;
20use datafusion::scalar::ScalarValue;
21use datafusion::sql::sqlparser::ast;
22use datafusion::sql::unparser::Unparser;
23
24use super::udf_field_from_fields;
25
/// Every name under which the ClickHouse "apply" (lambda) UDF can be referenced.
///
/// The first entry doubles as the UDF's canonical name; the full list is also what
/// [`unwrap_clickhouse_lambda`] checks a function name against to recognise the wrapper.
pub const CLICKHOUSE_APPLY_ALIASES: [&str; 7] = [
    // Short, generic spellings.
    "apply",
    "lambda",
    // Explicitly ClickHouse-prefixed spellings.
    "clickhouse_apply",
    "clickhouse_lambda",
    "clickhouse_map",
    "clickhouse_fmap",
    "clickhouse_hof",
];
35
36pub fn clickhouse_apply_udf() -> ScalarUDF { ScalarUDF::new_from_impl(ClickHouseApplyUDF::new()) }
37
38#[derive(Debug, PartialEq, Eq, Hash)]
39pub struct ClickHouseApplyUDF {
40 signature: Signature,
41 aliases: Vec<String>,
42}
43
44impl Default for ClickHouseApplyUDF {
45 fn default() -> Self {
46 Self {
47 signature: Signature::variadic_any(Volatility::Immutable),
48 aliases: CLICKHOUSE_APPLY_ALIASES.iter().map(ToString::to_string).collect(),
49 }
50 }
51}
52
53impl ClickHouseApplyUDF {
54 pub fn new() -> Self { Self::default() }
55}
56
57impl ScalarUDFImpl for ClickHouseApplyUDF {
58 fn as_any(&self) -> &dyn std::any::Any { self }
59
60 fn name(&self) -> &str { CLICKHOUSE_APPLY_ALIASES[0] }
61
62 fn aliases(&self) -> &[String] { &self.aliases }
63
64 fn signature(&self) -> &Signature { &self.signature }
65
66 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
67 arg_types
68 .last()
69 .cloned()
70 .ok_or(plan_datafusion_err!("ClickHouseApplyUDF requires at least one argument"))
71 }
72
73 fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
74 if let Ok(ret) = super::extract_return_field_from_args(self.name(), &args) {
75 Ok(ret)
76 } else {
77 let data_types =
78 args.arg_fields.iter().map(|f| f.data_type()).cloned().collect::<Vec<_>>();
79 let return_type = self.return_type(&data_types)?;
80 Ok(udf_field_from_fields(self.name(), return_type, args.arg_fields))
81 }
82 }
83
84 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
85 not_impl_err!(
86 "ClickHouseApplyUDF is for planning only - lambda functions are pushed down to \
87 ClickHouse"
88 )
89 }
90
91 fn short_circuits(&self) -> bool { true }
94}
95
96pub(crate) struct ClickHouseApplyRewriter {
97 pub name: String,
98 pub body: Expr,
99 pub param_map: HashMap<Placeholder, Column>,
100}
101
102impl ClickHouseApplyRewriter {
103 pub(crate) fn try_new(expr: &Expr) -> Result<Self> {
104 let (name, mut args) = unwrap_clickhouse_lambda(expr)?;
105
106 let _data_type = args
108 .pop_if(|expr| matches!(expr, Expr::Literal(_, _)))
109 .map(|expr| match expr.as_literal() {
110 Some(
111 ScalarValue::Utf8(Some(ret))
112 | ScalarValue::Utf8View(Some(ret))
113 | ScalarValue::LargeUtf8(Some(ret)),
114 ) => DataType::from_str(ret.as_str())
115 .map_err(|e| plan_datafusion_err!("Invalid return type: {e}"))
116 .map(Some),
117 _ => Ok(None),
118 })
119 .transpose()?
120 .flatten();
121
122 let (param_map, body) = extract_apply_args(args)?;
123 Ok(Self { name, body, param_map })
124 }
125
126 pub(crate) fn rewrite_to_ast(self, unparser: &Unparser<'_>) -> Result<ast::Expr> {
127 let Self { name, body, param_map, .. } = self;
128
129 let transformed_body = body
131 .transform(|expr| {
132 if let Expr::Placeholder(ref placeholder) = expr
133 && let Some((param_name, _)) =
134 param_map.iter().find(|(p, _)| p.id == placeholder.id)
135 {
136 let variable = param_name.id.trim_start_matches('$');
137 return Ok(Transformed::new(
139 Expr::Column(Column::new_unqualified(variable)),
140 true,
141 TreeNodeRecursion::Jump,
142 ));
143 }
144 Ok(Transformed::no(expr))
145 })
146 .unwrap()
147 .data;
148
149 let body_sql = unparser.expr_to_sql(&transformed_body)?;
151
152 let (mut params, mut columns): (Vec<_>, Vec<_>) = param_map
154 .into_iter()
155 .map(|(p, c)| (p.id.trim_start_matches('$').to_string(), c))
156 .unzip();
157
158 let lambda_params = if params.len() == 1 {
160 ast::OneOrManyWithParens::One(ast::Ident::new(params.remove(0)))
161 } else {
162 ast::OneOrManyWithParens::Many(params.into_iter().map(ast::Ident::new).collect())
163 };
164
165 let column_params = if columns.len() == 1 {
166 let col = columns.remove(0);
167 vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
168 unparser
169 .expr_to_sql(&Expr::Column(col.clone()))
170 .unwrap_or_else(|_| ast::Expr::Identifier(ast::Ident::new(&col.name))),
171 ))]
172 } else {
173 columns
174 .into_iter()
175 .map(|c| {
176 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
177 unparser
178 .expr_to_sql(&Expr::Column(c.clone()))
179 .unwrap_or_else(|_| ast::Expr::Identifier(ast::Ident::new(&c.name))),
180 ))
181 })
182 .collect::<Vec<_>>()
183 };
184
185 let lambda_expr = ast::Expr::Lambda(ast::LambdaFunction {
187 params: lambda_params,
188 body: Box::new(body_sql),
189 });
190
191 let hof_args: Vec<ast::FunctionArg> = std::iter::once(
194 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(lambda_expr)),
196 )
197 .chain(column_params)
198 .collect();
199
200 Ok(ast::Expr::Function(ast::Function {
201 name: ast::ObjectName(vec![ast::ObjectNamePart::Identifier(
202 ast::Ident::new(name),
203 )]),
204 args: ast::FunctionArguments::List(ast::FunctionArgumentList {
205 duplicate_treatment: None,
206 args: hof_args,
207 clauses: vec![],
208 }),
209 filter: None,
210 null_treatment: None,
211 over: None,
212 within_group: vec![],
213 parameters: ast::FunctionArguments::None,
214 uses_odbc_syntax: false,
215 }))
216 }
217}
218
219pub(crate) fn unwrap_clickhouse_lambda(expr: &Expr) -> Result<(String, Vec<Expr>)> {
220 let inner_expr = if let Expr::Alias(e) = expr { &e.expr } else { expr };
221
222 let Expr::ScalarFunction(ScalarFunction { func, args }) = inner_expr else {
224 return plan_err!("Unknown expression passed to ClickHouseApplyRewriter");
225 };
226
227 Ok(if CLICKHOUSE_APPLY_ALIASES.contains(&func.name()) {
229 let Some(Expr::ScalarFunction(ScalarFunction { func: inner_func, args: inner_args })) =
230 args.first()
231 else {
232 return plan_err!("ClickHouseApplyUDF must be higher order function");
233 };
234
235 (inner_func.name().to_string(), inner_args.clone())
236 } else if args.first().is_some_and(|a| matches!(a, Expr::Placeholder(_))) {
237 (func.name().to_string(), args.clone())
240 } else {
241 return plan_err!("Unknown function passed to ClickHouseApplyRewriter");
242 })
243}
244
245pub(crate) fn extract_apply_args(
246 mut args: Vec<Expr>,
247) -> Result<(HashMap<Placeholder, Column>, Expr)> {
248 if args.len() < 3 {
249 return plan_err!(
250 "ClickHouseApplyUDF requires at least 3 arguments: placeholders, body, and column \
251 references"
252 );
253 }
254
255 let mut columns = Vec::with_capacity(args.len());
256
257 let body = loop {
259 match args.pop() {
260 Some(Expr::Column(col)) => columns.push(col),
261 Some(e) => break e,
262 None => {
263 return plan_err!("ClickHouseApplyUDF missing body expression");
264 }
265 }
266 };
267
268 let placeholders = args
270 .into_iter()
271 .map(
272 |e| if let Expr::Placeholder(p) = e { Ok(p) } else { plan_err!("Invalid placeholder") },
273 )
274 .collect::<Result<Vec<_>>>()?;
275
276 if columns.len() != placeholders.len() {
277 return plan_err!("Number of placeholders and columns must match");
278 }
279
280 let param_map = placeholders.into_iter().zip(columns).collect::<HashMap<_, _>>();
281
282 Ok((param_map, body))
283}
284
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::*;
    use datafusion::common::ScalarValue;
    use datafusion::config::ConfigOptions;
    use datafusion::logical_expr::{BinaryExpr, Operator, ReturnFieldArgs, ScalarFunctionArgs};
    use datafusion::prelude::lit;
    use datafusion::sql::TableReference;

    use super::*;
    use crate::udfs::placeholder::{PlaceholderUDF, placeholder_udf_from_placeholder};

    /// Covers the UDF surface: short-circuit flag, return-field resolution and the
    /// planning-only `invoke_with_args`.
    #[test]
    fn test_apply_udf() {
        let udf = clickhouse_apply_udf();

        assert!(udf.short_circuits());

        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
        let scalar = [
            Some(ScalarValue::Utf8(Some("count()".to_string()))),
            Some(ScalarValue::Utf8(Some("Int64".to_string()))),
        ];
        let args = ReturnFieldArgs {
            arg_fields: &[field1, field2],
            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
        };

        let result = udf.return_field_from_args(args);
        assert!(result.is_ok());
        let field = result.unwrap();
        assert_eq!(field.name(), CLICKHOUSE_APPLY_ALIASES[0]);
        assert_eq!(field.data_type(), &DataType::Int64);

        // Execution must be rejected: the UDF only exists for planning.
        let args = ScalarFunctionArgs {
            args: vec![],
            arg_fields: vec![],
            number_rows: 1,
            return_field: Arc::new(Field::new("", DataType::Int32, false)),
            config_options: Arc::new(ConfigOptions::default()),
        };
        let result = udf.invoke_with_args(args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("planning only"));
    }

    /// Covers argument extraction failures and the higher-order-function requirement of
    /// `ClickHouseApplyRewriter::try_new`.
    #[test]
    fn test_apply_rewriter() {
        let placeholder = Placeholder::new_with_field("$x".to_string(), None);

        // Too few arguments, regardless of their kind.
        let result = extract_apply_args(vec![Expr::Placeholder(placeholder.clone())]);
        assert!(result.is_err(), "Apply expects at least 3 args");
        let result = extract_apply_args(vec![Expr::Column(Column::from_name("test"))]);
        assert!(result.is_err(), "Apply expects at least 3 args");

        // Enough arguments to pass the arity guard, but all of them are columns, so the
        // body lookup itself must fail.
        let only_columns = vec![
            Expr::Column(Column::from_name("a")),
            Expr::Column(Column::from_name("b")),
            Expr::Column(Column::from_name("c")),
        ];
        let err = extract_apply_args(only_columns).unwrap_err();
        assert!(
            err.to_string().contains("missing body expression"),
            "Apply expects a body arg before columns, got: {err}"
        );

        let exprs_fail = vec![
            Expr::Placeholder(placeholder.clone()),
            lit("1"),
            Expr::Column(Column::from_name("test1")),
            Expr::Column(Column::from_name("test2")),
        ];
        let err = extract_apply_args(exprs_fail).unwrap_err();
        assert!(
            err.to_string().contains("must match"),
            "Placeholder count must match column count, got: {err}"
        );

        let common_args = vec![
            Expr::Placeholder(placeholder.clone()),
            Expr::BinaryExpr(BinaryExpr {
                left: Box::new(Expr::Placeholder(placeholder)),
                op: Operator::Plus,
                right: Box::new(lit(1)),
            }),
            Expr::Column(Column::new(None::<TableReference>, "test_col")),
            lit("Int64"),
        ];

        // The apply UDF called directly with lambda arguments (no inner function).
        let expr = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(clickhouse_apply_udf()),
            args: common_args.clone(),
        });

        let result = ClickHouseApplyRewriter::try_new(&expr);
        assert!(result.is_err(), "Apply/Lambda must be a higher order function");

        // The apply UDF wrapping a higher order function that carries the lambda.
        let expr = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(clickhouse_apply_udf()),
            args: vec![Expr::ScalarFunction(ScalarFunction {
                func: Arc::new(placeholder_udf_from_placeholder(PlaceholderUDF::new("arrayMap"))),
                args: common_args.clone(),
            })],
        });

        let result = ClickHouseApplyRewriter::try_new(&expr);
        assert!(result.is_ok(), "Apply/Lambda expected to be higher order function");
    }
}