clickhouse_datafusion/udfs/
eval.rs

1use std::any::Any;
2use std::str::FromStr;
3use std::sync::LazyLock;
4
5use datafusion::arrow::datatypes::{DataType, FieldRef};
6use datafusion::common::{ScalarValue, internal_err, not_impl_err, plan_datafusion_err, plan_err};
7use datafusion::error::Result;
8use datafusion::logical_expr::{
9    ColumnarValue, DocSection, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
10    ScalarUDFImpl, Signature, Volatility,
11};
12
13use super::udf_field_from_fields;
14
15static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
16    Documentation::builder(DocSection::default(), "Add one to an int32", "add_one(2)")
17        .with_argument("arg1", "The string representation of the ClickHouse function")
18        .with_argument("arg2", "The string representation of the expected DataType")
19        .build()
20});
21
/// Names under which the eval UDF is exposed.
///
/// The first entry is returned by `ClickHouseEval::name`; the whole list is copied into the
/// `aliases` field by `ClickHouseEval::new`.
pub const CLICKHOUSE_EVAL_UDF_ALIASES: &[&str] = &["clickhouse_eval"];
23
24pub fn clickhouse_eval_udf() -> ScalarUDF { ScalarUDF::from(ClickHouseEval::new()) }
25
26fn get_doc() -> &'static Documentation { &DOCUMENTATION }
27
28// TODO: Docs - explain how this can be used if the full custom ClickHouseQueryPlanner CANNOT be
29// used. This provides an alternative syntax for specifying clickhouse functions
30//
/// [`ClickHouseEval`] is an escape hatch to pass syntax that `DataFusion` does not support directly
/// to `ClickHouse` using the string representation only.
///
/// The function takes two string arguments: the `ClickHouse` expression text and the textual
/// name of the data type the expression is expected to produce.
#[derive(Debug)]
pub struct ClickHouseEval {
    /// Accepted argument signature, built in [`ClickHouseEval::new`].
    signature: Signature,
    /// Owned copies of the names listed in `CLICKHOUSE_EVAL_UDF_ALIASES`.
    aliases:   Vec<String>,
}
38
39impl Default for ClickHouseEval {
40    fn default() -> Self { Self::new() }
41}
42
43impl ClickHouseEval {
44    pub const ARG_LEN: usize = 2;
45
46    pub fn new() -> Self {
47        Self {
48            signature: Signature::uniform(
49                2,
50                vec![DataType::Utf8, DataType::Utf8View, DataType::LargeUtf8],
51                Volatility::Volatile,
52            ),
53            aliases:   CLICKHOUSE_EVAL_UDF_ALIASES.iter().map(ToString::to_string).collect(),
54        }
55    }
56}
57
58impl ScalarUDFImpl for ClickHouseEval {
59    fn as_any(&self) -> &dyn Any { self }
60
61    fn name(&self) -> &'static str { CLICKHOUSE_EVAL_UDF_ALIASES[0] }
62
63    fn aliases(&self) -> &[String] { &self.aliases }
64
65    fn signature(&self) -> &Signature { &self.signature }
66
67    /// # Errors
68    /// Returns an error if the arguments are invalid or the data type cannot be parsed.
69    ///
70    /// # Panics
71    /// Unwrap is used but it's guarded by a bounds check.
72    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
73        if arg_types.len() != 2 {
74            return plan_err!(
75                "Expected two string arguments, syntax and datatype, received fields {:?}",
76                arg_types
77            );
78        }
79
80        // Length is confirmed above, ok to unwrap
81        Ok(arg_types.get(1).cloned().unwrap())
82    }
83
84    /// # Errors
85    /// Returns an error if the arguments are invalid or the data type cannot be parsed.
86    ///
87    /// # Panics
88    /// Unwrap is used but it's guarded by a bounds check.
89    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
90        if args.arg_fields.len() != 2 || args.scalar_arguments.len() != 2 {
91            return plan_err!(
92                "Expected two string arguments, syntax and datatype, received fields {:?}",
93                args.arg_fields
94            );
95        }
96
97        // Length is confirmed above, ok to unwrap
98        let syntax_arg = args
99            .scalar_arguments
100            .first()
101            .unwrap()
102            .ok_or(plan_datafusion_err!("First argument (syntax) missing"))?;
103        let type_arg = args
104            .scalar_arguments
105            .get(1)
106            .unwrap()
107            .ok_or(plan_datafusion_err!("Second argument (data type) missing"))?;
108
109        if let (
110            ScalarValue::Utf8(syntax)
111            | ScalarValue::Utf8View(syntax)
112            | ScalarValue::LargeUtf8(syntax),
113            ScalarValue::Utf8(data_type)
114            | ScalarValue::Utf8View(data_type)
115            | ScalarValue::LargeUtf8(data_type),
116        ) = (syntax_arg, type_arg)
117        {
118            // Extract syntax string from first argument
119            if syntax.is_none() {
120                return internal_err!("Missing syntax argument");
121            }
122
123            // Extract type string from second argument
124            let Some(type_str) = data_type else {
125                return internal_err!("Missing data type argument");
126            };
127
128            // Parse type string to DataType
129            let data_type = DataType::from_str(type_str)
130                .map_err(|e| plan_datafusion_err!("Invalid type string: {e}"))?;
131            Ok(udf_field_from_fields(self.name(), data_type, args.arg_fields))
132        } else {
133            internal_err!("clickhouse_func expects string arguments")
134        }
135    }
136
137    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
138        not_impl_err!("UDFs are evaluated after data has been fetched.")
139    }
140
141    fn documentation(&self) -> Option<&Documentation> { Some(get_doc()) }
142}
143
#[cfg(all(test, feature = "test-utils"))]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow;
    use datafusion::arrow::datatypes::*;
    use datafusion::common::ScalarValue;
    use datafusion::logical_expr::{ReturnFieldArgs, ScalarUDFImpl};
    use datafusion::prelude::SessionContext;

    use super::*;

    /// Two non-nullable argument fields, "syntax" and "type", both of `data_type`.
    fn string_arg_fields(data_type: &DataType) -> [FieldRef; 2] {
        [
            Arc::new(Field::new("syntax", data_type.clone(), false)),
            Arc::new(Field::new("type", data_type.clone(), false)),
        ]
    }

    /// A present, non-null `Utf8` scalar argument.
    fn utf8(value: &str) -> Option<ScalarValue> {
        Some(ScalarValue::Utf8(Some(value.to_string())))
    }

    /// Runs `return_field_from_args` on a fresh `ClickHouseEval` with the given inputs.
    fn resolve_return_field(
        fields: &[FieldRef],
        scalars: &[Option<ScalarValue>],
    ) -> Result<FieldRef> {
        let scalar_refs: Vec<Option<&ScalarValue>> = scalars.iter().map(Option::as_ref).collect();
        let args = ReturnFieldArgs { arg_fields: fields, scalar_arguments: &scalar_refs };
        ClickHouseEval::new().return_field_from_args(args)
    }

    /// Asserts that `result` is an error whose rendered message contains `needle`.
    fn assert_error_contains<T: std::fmt::Debug>(result: Result<T>, needle: &str) {
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains(needle), "error `{message}` does not contain `{needle}`");
    }

    #[test]
    fn test_clickhouse_eval_new() {
        let udf_impl = ClickHouseEval::new();
        assert_eq!(udf_impl.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
        assert_eq!(udf_impl.aliases(), CLICKHOUSE_EVAL_UDF_ALIASES);
    }

    #[test]
    fn test_clickhouse_eval_default() {
        assert_eq!(ClickHouseEval::default().name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
    }

    #[test]
    fn test_clickhouse_func_constants() {
        assert_eq!(ClickHouseEval::ARG_LEN, 2);
    }

    #[test]
    fn test_clickhouse_eval_udf_creation() {
        assert_eq!(clickhouse_eval_udf().name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
    }

    #[test]
    fn test_return_type_valid_args() {
        let resolved = ClickHouseEval::new().return_type(&[DataType::Utf8, DataType::Int32]);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap(), DataType::Int32);
    }

    #[test]
    fn test_return_type_valid_args_utf8_view() {
        let resolved = ClickHouseEval::new().return_type(&[DataType::Utf8View, DataType::Float64]);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap(), DataType::Float64);
    }

    #[test]
    fn test_return_type_valid_args_large_utf8() {
        let resolved =
            ClickHouseEval::new().return_type(&[DataType::LargeUtf8, DataType::Boolean]);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap(), DataType::Boolean);
    }

    #[test]
    fn test_return_type_wrong_arg_count() {
        let udf_impl = ClickHouseEval::new();

        // One case with too few arguments, one with too many.
        let bad_arities =
            [vec![DataType::Utf8], vec![DataType::Utf8, DataType::Int32, DataType::Float64]];
        for arg_types in bad_arities {
            assert_error_contains(
                udf_impl.return_type(&arg_types),
                "Expected two string arguments",
            );
        }
    }

    #[test]
    fn test_return_field_from_args_valid() {
        let fields = string_arg_fields(&DataType::Utf8);
        let resolved = resolve_return_field(&fields, &[utf8("count()"), utf8("Int64")]);

        assert!(resolved.is_ok());
        let field = resolved.unwrap();
        assert_eq!(field.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
        assert_eq!(field.data_type(), &DataType::Int64);
        assert!(!field.is_nullable(), "Expect non-nullable - no nullable input fields");
    }

    #[test]
    fn test_return_field_from_args_utf8_view() {
        let fields = string_arg_fields(&DataType::Utf8View);
        let scalars = [
            Some(ScalarValue::Utf8View(Some("sum(x)".to_string()))),
            Some(ScalarValue::Utf8View(Some("Float64".to_string()))),
        ];

        let resolved = resolve_return_field(&fields, &scalars);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap().data_type(), &DataType::Float64);
    }

    #[test]
    fn test_return_field_from_args_large_utf8() {
        let fields = string_arg_fields(&DataType::LargeUtf8);
        let scalars = [
            Some(ScalarValue::LargeUtf8(Some("avg(y)".to_string()))),
            Some(ScalarValue::LargeUtf8(Some("Boolean".to_string()))),
        ];

        let resolved = resolve_return_field(&fields, &scalars);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap().data_type(), &DataType::Boolean);
    }

    #[test]
    fn test_return_field_from_args_wrong_field_count() {
        let fields = string_arg_fields(&DataType::Utf8);
        // Only the "syntax" field and a single scalar are supplied.
        let resolved = resolve_return_field(&fields[..1], &[utf8("count()")]);
        assert_error_contains(resolved, "Expected two string arguments");
    }

    #[test]
    fn test_return_field_from_args_wrong_scalar_count() {
        let fields = string_arg_fields(&DataType::Utf8);
        let resolved = resolve_return_field(&fields, &[utf8("count()")]);
        assert_error_contains(resolved, "Expected two string arguments");
    }

    #[test]
    fn test_return_field_from_args_missing_syntax() {
        let fields = string_arg_fields(&DataType::Utf8);
        let resolved = resolve_return_field(&fields, &[None, utf8("Int64")]);
        assert_error_contains(resolved, "First argument (syntax) missing");
    }

    #[test]
    fn test_return_field_from_args_missing_type() {
        let fields = string_arg_fields(&DataType::Utf8);
        let resolved = resolve_return_field(&fields, &[utf8("count()"), None]);
        assert_error_contains(resolved, "Second argument (data type) missing");
    }

    #[test]
    fn test_return_field_from_args_null_syntax() {
        let fields = string_arg_fields(&DataType::Utf8);
        let resolved =
            resolve_return_field(&fields, &[Some(ScalarValue::Utf8(None)), utf8("Int64")]);
        assert_error_contains(resolved, "Missing syntax argument");
    }

    #[test]
    fn test_return_field_from_args_null_type() {
        let fields = string_arg_fields(&DataType::Utf8);
        let resolved =
            resolve_return_field(&fields, &[utf8("count()"), Some(ScalarValue::Utf8(None))]);
        assert_error_contains(resolved, "Missing data type argument");
    }

    #[test]
    fn test_return_field_from_args_invalid_type_string() {
        let fields = string_arg_fields(&DataType::Utf8);
        let resolved = resolve_return_field(&fields, &[utf8("count()"), utf8("InvalidType")]);
        assert_error_contains(resolved, "Invalid type string");
    }

    #[test]
    fn test_return_field_from_args_non_string_arguments() {
        let fields = string_arg_fields(&DataType::Int32);
        let scalars = [Some(ScalarValue::Int32(Some(42))), Some(ScalarValue::Int32(Some(24)))];

        let resolved = resolve_return_field(&fields, &scalars);
        assert_error_contains(resolved, "clickhouse_func expects string arguments");
    }

    #[test]
    fn test_invoke_with_args_not_implemented() {
        let invocation = ScalarFunctionArgs {
            args:         vec![],
            arg_fields:   vec![],
            number_rows:  1,
            return_field: Arc::new(Field::new("", DataType::Int32, false)),
        };

        let outcome = ClickHouseEval::new().invoke_with_args(invocation);
        assert_error_contains(outcome, "UDFs are evaluated after data has been fetched");
    }

    #[test]
    fn test_documentation() {
        let udf_impl = ClickHouseEval::new();
        assert!(udf_impl.documentation().is_some());

        assert!(get_doc().description.contains("Add one to an int32"));
    }

    #[test]
    fn test_as_any() {
        let udf_impl = ClickHouseEval::new();
        assert!(udf_impl.as_any().downcast_ref::<ClickHouseEval>().is_some());
    }

    #[tokio::test]
    async fn test_clickhouse_udf() -> Result<(), Box<dyn std::error::Error>> {
        let ctx = SessionContext::new();
        ctx.register_udf(clickhouse_eval_udf());

        let schema = SchemaRef::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("names", DataType::Utf8, false),
        ]));

        // A single-row in-memory table to plan the query against.
        let batch = arrow::record_batch::RecordBatch::try_new(Arc::clone(&schema), vec![
            Arc::new(arrow::array::Int32Array::from(vec![1])),
            Arc::new(arrow::array::StringArray::from(vec!["John,Jon,J"])),
        ])?;
        let provider =
            Arc::new(datafusion::datasource::MemTable::try_new(schema, vec![vec![batch]])?);
        drop(ctx.register_table("people", provider)?);

        // Only the plan is requested (EXPLAIN); the UDF itself is never invoked.
        let sql =
            "SELECT id, clickhouse_eval('splitByChar('','', names)', 'List(Utf8)') FROM people";
        let explain = ctx.sql(&format!("EXPLAIN {sql}")).await?;
        let results = explain.collect().await?;
        println!("EXPLAIN: {results:?}");
        Ok(())
    }
}