clickhouse_datafusion/udfs/
apply.rs1use std::collections::HashMap;
8use std::str::FromStr;
9
10use datafusion::arrow::datatypes::{DataType, FieldRef};
11use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
12use datafusion::common::{Column, not_impl_err, plan_datafusion_err, plan_err};
13use datafusion::error::Result;
14use datafusion::logical_expr::expr::{Placeholder, ScalarFunction};
15use datafusion::logical_expr::{
16 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
17 Volatility,
18};
19use datafusion::prelude::Expr;
20use datafusion::scalar::ScalarValue;
21use datafusion::sql::sqlparser::ast;
22use datafusion::sql::unparser::Unparser;
23
24use super::udf_field_from_fields;
25
/// Names under which the ClickHouse apply/lambda UDF can be called.
///
/// The first entry (`"apply"`) is the canonical UDF name returned by `ScalarUDFImpl::name`;
/// the whole list (including that entry) is exposed as the UDF's aliases.
pub const CLICKHOUSE_APPLY_ALIASES: [&str; 7] = [
    "apply",
    "lambda",
    "clickhouse_apply",
    "clickhouse_lambda",
    "clickhouse_map",
    "clickhouse_fmap",
    "clickhouse_hof",
];
35
36pub fn clickhouse_apply_udf() -> ScalarUDF { ScalarUDF::new_from_impl(ClickHouseApplyUDF::new()) }
37
38#[derive(Debug)]
39pub struct ClickHouseApplyUDF {
40 signature: Signature,
41 aliases: Vec<String>,
42}
43
44impl Default for ClickHouseApplyUDF {
45 fn default() -> Self {
46 Self {
47 signature: Signature::variadic_any(Volatility::Immutable),
48 aliases: CLICKHOUSE_APPLY_ALIASES.iter().map(ToString::to_string).collect(),
49 }
50 }
51}
52
53impl ClickHouseApplyUDF {
54 pub fn new() -> Self { Self::default() }
55}
56
57impl ScalarUDFImpl for ClickHouseApplyUDF {
58 fn as_any(&self) -> &dyn std::any::Any { self }
59
60 fn name(&self) -> &str { CLICKHOUSE_APPLY_ALIASES[0] }
61
62 fn aliases(&self) -> &[String] { &self.aliases }
63
64 fn signature(&self) -> &Signature { &self.signature }
65
66 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
67 arg_types
68 .last()
69 .cloned()
70 .ok_or(plan_datafusion_err!("ClickHouseApplyUDF requires at least one argument"))
71 }
72
73 fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
74 if let Ok(ret) = super::extract_return_field_from_args(self.name(), &args) {
75 Ok(ret)
76 } else {
77 let data_types =
78 args.arg_fields.iter().map(|f| f.data_type()).cloned().collect::<Vec<_>>();
79 let return_type = self.return_type(&data_types)?;
80 Ok(udf_field_from_fields(self.name(), return_type, args.arg_fields))
81 }
82 }
83
84 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
85 not_impl_err!(
86 "ClickHouseApplyUDF is for planning only - lambda functions are pushed down to \
87 ClickHouse"
88 )
89 }
90
91 fn short_circuits(&self) -> bool { true }
94}
95
96pub(crate) struct ClickHouseApplyRewriter {
97 pub name: String,
98 pub body: Expr,
99 pub param_map: HashMap<Placeholder, Column>,
100}
101
102impl ClickHouseApplyRewriter {
103 pub(crate) fn try_new(expr: &Expr) -> Result<Self> {
104 if !is_clickhouse_lambda(expr) {
105 return plan_err!("Unknown function passed to ClickHouseApplyRewriter");
106 }
107
108 let Expr::ScalarFunction(ScalarFunction { func, args }) = expr else {
109 unreachable!();
111 };
112
113 let (name, mut args) = if CLICKHOUSE_APPLY_ALIASES.contains(&func.name()) {
115 let Some(Expr::ScalarFunction(ScalarFunction { func, args })) = args.first() else {
116 return plan_err!("ClickHouseApplyUDF must be higher order function");
117 };
118 (func.name().to_string(), args.clone())
119 } else {
120 (func.name().to_string(), args.clone())
121 };
122
123 let _data_type = args
125 .pop_if(|expr| matches!(expr, Expr::Literal(_, _)))
126 .map(|expr| match expr.as_literal() {
127 Some(
128 ScalarValue::Utf8(Some(ret))
129 | ScalarValue::Utf8View(Some(ret))
130 | ScalarValue::LargeUtf8(Some(ret)),
131 ) => DataType::from_str(ret.as_str())
132 .map_err(|e| plan_datafusion_err!("Invalid return type: {e}"))
133 .map(Some),
134 _ => Ok(None),
135 })
136 .transpose()?
137 .flatten();
138
139 let (param_map, body) = extract_apply_args(args)?;
140 Ok(Self { name, body, param_map })
141 }
142
143 pub(crate) fn rewrite_to_ast(self, unparser: &Unparser<'_>) -> Result<ast::Expr> {
144 let Self { name, body, param_map, .. } = self;
145
146 let transformed_body = body
148 .transform(|expr| {
149 if let Expr::Placeholder(ref placeholder) = expr
150 && let Some((param_name, _)) =
151 param_map.iter().find(|(p, _)| p.id == placeholder.id)
152 {
153 let variable = param_name.id.trim_start_matches('$');
154 return Ok(Transformed::new(
156 Expr::Column(Column::new_unqualified(variable)),
157 true,
158 TreeNodeRecursion::Jump,
159 ));
160 }
161 Ok(Transformed::no(expr))
162 })
163 .unwrap()
164 .data;
165
166 let body_sql = unparser.expr_to_sql(&transformed_body)?;
168
169 let (mut params, mut columns): (Vec<_>, Vec<_>) = param_map
171 .into_iter()
172 .map(|(p, c)| (p.id.trim_start_matches('$').to_string(), c))
173 .unzip();
174
175 let lambda_params = if params.len() == 1 {
177 ast::OneOrManyWithParens::One(ast::Ident::new(params.remove(0)))
178 } else {
179 ast::OneOrManyWithParens::Many(params.into_iter().map(ast::Ident::new).collect())
180 };
181
182 let column_params = if columns.len() == 1 {
183 let col = columns.remove(0);
184 vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
185 unparser
186 .expr_to_sql(&Expr::Column(col.clone()))
187 .unwrap_or_else(|_| ast::Expr::Identifier(ast::Ident::new(&col.name))),
188 ))]
189 } else {
190 columns
191 .into_iter()
192 .map(|c| {
193 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
194 unparser
195 .expr_to_sql(&Expr::Column(c.clone()))
196 .unwrap_or_else(|_| ast::Expr::Identifier(ast::Ident::new(&c.name))),
197 ))
198 })
199 .collect::<Vec<_>>()
200 };
201
202 let lambda_expr = ast::Expr::Lambda(ast::LambdaFunction {
204 params: lambda_params,
205 body: Box::new(body_sql),
206 });
207
208 let hof_args: Vec<ast::FunctionArg> = std::iter::once(
211 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(lambda_expr)),
213 )
214 .chain(column_params)
215 .collect();
216
217 Ok(ast::Expr::Function(ast::Function {
218 name: ast::ObjectName(vec![ast::ObjectNamePart::Identifier(
219 ast::Ident::new(name),
220 )]),
221 args: ast::FunctionArguments::List(ast::FunctionArgumentList {
222 duplicate_treatment: None,
223 args: hof_args,
224 clauses: vec![],
225 }),
226 filter: None,
227 null_treatment: None,
228 over: None,
229 within_group: vec![],
230 parameters: ast::FunctionArguments::None,
231 uses_odbc_syntax: false,
232 }))
233 }
234}
235
236pub(crate) fn is_clickhouse_lambda(expr: &Expr) -> bool {
237 let Expr::ScalarFunction(ScalarFunction { func, args }) = expr else {
238 return false;
239 };
240 CLICKHOUSE_APPLY_ALIASES.contains(&func.name())
241 || args.first().is_some_and(|a| matches!(a, Expr::Placeholder(_)))
242}
243
244pub(crate) fn extract_apply_args(
245 mut args: Vec<Expr>,
246) -> Result<(HashMap<Placeholder, Column>, Expr)> {
247 if args.len() < 3 {
248 return plan_err!(
249 "ClickHouseApplyUDF requires at least 3 arguments: placeholders, body, and column \
250 references"
251 );
252 }
253
254 let mut columns = Vec::with_capacity(args.len());
255
256 let body = loop {
258 match args.pop() {
259 Some(Expr::Column(col)) => columns.push(col),
260 Some(e) => break e,
261 None => {
262 return plan_err!("ClickHouseApplyUDF missing body expression");
263 }
264 }
265 };
266
267 let placeholders = args
269 .into_iter()
270 .map(
271 |e| if let Expr::Placeholder(p) = e { Ok(p) } else { plan_err!("Invalid placeholder") },
272 )
273 .collect::<Result<Vec<_>>>()?;
274
275 if columns.len() != placeholders.len() {
276 return plan_err!("Number of placeholders and columns must match");
277 }
278
279 let param_map = placeholders.into_iter().zip(columns).collect::<HashMap<_, _>>();
280
281 Ok((param_map, body))
282}
283
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::*;
    use datafusion::common::ScalarValue;
    use datafusion::logical_expr::{BinaryExpr, Operator, ReturnFieldArgs, ScalarFunctionArgs};
    use datafusion::prelude::lit;
    use datafusion::sql::TableReference;

    use super::*;
    use crate::udfs::placeholder::{PlaceholderUDF, placeholder_udf_from_placeholder};

    /// Wraps `args` in a call to the `apply` UDF.
    fn apply_call(args: Vec<Expr>) -> Expr {
        Expr::ScalarFunction(ScalarFunction { func: Arc::new(clickhouse_apply_udf()), args })
    }

    #[test]
    fn test_apply_udf() {
        let udf = clickhouse_apply_udf();
        assert!(udf.short_circuits());

        // Return type declared explicitly through a string literal argument.
        let syntax_field = Arc::new(Field::new("syntax", DataType::Utf8, false));
        let type_field = Arc::new(Field::new("type", DataType::Utf8, false));
        let syntax_value = ScalarValue::Utf8(Some("count()".to_string()));
        let type_value = ScalarValue::Utf8(Some("Int64".to_string()));
        let return_args = ReturnFieldArgs {
            arg_fields: &[syntax_field, type_field],
            scalar_arguments: &[Some(&syntax_value), Some(&type_value)],
        };

        let field = udf
            .return_field_from_args(return_args)
            .expect("return field should resolve from the declared type");
        assert_eq!(field.name(), CLICKHOUSE_APPLY_ALIASES[0]);
        assert_eq!(field.data_type(), &DataType::Int64);

        // Execution is unsupported: the UDF exists only for planning.
        let invoke_args = ScalarFunctionArgs {
            args: vec![],
            arg_fields: vec![],
            number_rows: 1,
            return_field: Arc::new(Field::new("", DataType::Int32, false)),
        };
        let err = udf.invoke_with_args(invoke_args).expect_err("invocation must fail");
        assert!(err.to_string().contains("planning only"));
    }

    #[test]
    fn test_apply_rewriter() {
        let placeholder = Placeholder::new("$x".to_string(), None);
        let column = |name: &str| Expr::Column(Column::from_name(name));

        assert!(
            extract_apply_args(vec![Expr::Placeholder(placeholder.clone())]).is_err(),
            "Apply expects at least 3 args"
        );
        assert!(
            extract_apply_args(vec![column("test")]).is_err(),
            "Apply expects a body arg before columns"
        );
        let mismatched = vec![
            Expr::Placeholder(placeholder.clone()),
            lit("1"),
            column("test1"),
            column("test2"),
        ];
        assert!(
            extract_apply_args(mismatched).is_err(),
            "Placeholder count must match column count"
        );

        let hof_args = vec![
            Expr::Placeholder(placeholder.clone()),
            Expr::BinaryExpr(BinaryExpr {
                left: Box::new(Expr::Placeholder(placeholder)),
                op: Operator::Plus,
                right: Box::new(lit(1)),
            }),
            Expr::Column(Column::new(None::<TableReference>, "test_col")),
            lit("Int64"),
        ];

        // `apply` given the lambda arguments directly (no inner function) is rejected.
        assert!(
            ClickHouseApplyRewriter::try_new(&apply_call(hof_args.clone())).is_err(),
            "Apply/Lambda must be a higher order function"
        );

        // `apply(arrayMap(...))` is accepted.
        let array_map = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(placeholder_udf_from_placeholder(PlaceholderUDF::new("arrayMap"))),
            args: hof_args,
        });
        assert!(
            ClickHouseApplyRewriter::try_new(&apply_call(vec![array_map])).is_ok(),
            "Apply/Lambda expected to be higher order function"
        );
    }
}