clickhouse_datafusion/udfs/
eval.rs

1use std::any::Any;
2use std::str::FromStr;
3use std::sync::LazyLock;
4
5use datafusion::arrow::datatypes::{DataType, FieldRef};
6use datafusion::common::{ScalarValue, internal_err, not_impl_err, plan_datafusion_err, plan_err};
7use datafusion::error::Result;
8use datafusion::logical_expr::{
9    ColumnarValue, DocSection, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
10    ScalarUDFImpl, Signature, Volatility,
11};
12
13use super::udf_field_from_fields;
14
15static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
16    Documentation::builder(DocSection::default(), "Add one to an int32", "add_one(2)")
17        .with_argument("arg1", "The string representation of the ClickHouse function")
18        .with_argument("arg2", "The string representation of the expected DataType")
19        .build()
20});
21
22pub const CLICKHOUSE_EVAL_UDF_ALIASES: &[&str] = &["clickhouse_eval"];
23
24pub fn clickhouse_eval_udf() -> ScalarUDF { ScalarUDF::from(ClickHouseEval::new()) }
25
26fn get_doc() -> &'static Documentation { &DOCUMENTATION }
27
28// TODO: Docs - explain how this can be used if the full custom ClickHouseQueryPlanner CANNOT be
29// used. This provides an alternative syntax for specifying clickhouse functions
30//
31/// [`ClickHouseEval`] is an escape hatch to pass syntax that `DataFusion` does not support directly
32/// to `ClickHouse` using the string representation only.
33#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct ClickHouseEval {
35    signature: Signature,
36    aliases:   Vec<String>,
37}
38
39impl Default for ClickHouseEval {
40    fn default() -> Self { Self::new() }
41}
42
43impl ClickHouseEval {
44    pub const ARG_LEN: usize = 2;
45
46    pub fn new() -> Self {
47        Self {
48            signature: Signature::uniform(
49                2,
50                vec![DataType::Utf8, DataType::Utf8View, DataType::LargeUtf8],
51                Volatility::Volatile,
52            ),
53            aliases:   CLICKHOUSE_EVAL_UDF_ALIASES.iter().map(ToString::to_string).collect(),
54        }
55    }
56}
57
58impl ScalarUDFImpl for ClickHouseEval {
59    fn as_any(&self) -> &dyn Any { self }
60
61    fn name(&self) -> &'static str { CLICKHOUSE_EVAL_UDF_ALIASES[0] }
62
63    fn aliases(&self) -> &[String] { &self.aliases }
64
65    fn signature(&self) -> &Signature { &self.signature }
66
67    /// # Errors
68    /// Returns an error if the arguments are invalid or the data type cannot be parsed.
69    ///
70    /// # Panics
71    /// Unwrap is used but it's guarded by a bounds check.
72    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
73        if arg_types.len() != 2 {
74            return plan_err!(
75                "Expected two string arguments, syntax and datatype, received fields {:?}",
76                arg_types
77            );
78        }
79
80        // Length is confirmed above, ok to unwrap
81        Ok(arg_types.get(1).cloned().unwrap())
82    }
83
84    /// # Errors
85    /// Returns an error if the arguments are invalid or the data type cannot be parsed.
86    ///
87    /// # Panics
88    /// Unwrap is used but it's guarded by a bounds check.
89    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
90        if args.arg_fields.len() != 2 || args.scalar_arguments.len() != 2 {
91            return plan_err!(
92                "Expected two string arguments, syntax and datatype, received fields {:?}",
93                args.arg_fields
94            );
95        }
96
97        // Length is confirmed above, ok to unwrap
98        let syntax_arg = args
99            .scalar_arguments
100            .first()
101            .unwrap()
102            .ok_or(plan_datafusion_err!("First argument (syntax) missing"))?;
103        let type_arg = args
104            .scalar_arguments
105            .get(1)
106            .unwrap()
107            .ok_or(plan_datafusion_err!("Second argument (data type) missing"))?;
108
109        if let (
110            ScalarValue::Utf8(syntax)
111            | ScalarValue::Utf8View(syntax)
112            | ScalarValue::LargeUtf8(syntax),
113            ScalarValue::Utf8(data_type)
114            | ScalarValue::Utf8View(data_type)
115            | ScalarValue::LargeUtf8(data_type),
116        ) = (syntax_arg, type_arg)
117        {
118            // Extract syntax string from first argument
119            if syntax.is_none() {
120                return internal_err!("Missing syntax argument");
121            }
122
123            // Extract type string from second argument
124            let Some(type_str) = data_type else {
125                return internal_err!("Missing data type argument");
126            };
127
128            // Parse type string to DataType
129            let data_type = DataType::from_str(type_str)
130                .map_err(|e| plan_datafusion_err!("Invalid type string: {e}"))?;
131            Ok(udf_field_from_fields(self.name(), data_type, args.arg_fields))
132        } else {
133            internal_err!("clickhouse_func expects string arguments")
134        }
135    }
136
137    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
138        not_impl_err!("UDFs are evaluated after data has been fetched.")
139    }
140
141    fn documentation(&self) -> Option<&Documentation> { Some(get_doc()) }
142}
143
144#[cfg(all(test, feature = "test-utils"))]
145mod tests {
146    use std::sync::Arc;
147
148    use datafusion::arrow;
149    use datafusion::arrow::datatypes::*;
150    use datafusion::common::ScalarValue;
151    use datafusion::config::ConfigOptions;
152    use datafusion::logical_expr::{ReturnFieldArgs, ScalarUDFImpl};
153    use datafusion::prelude::SessionContext;
154
155    use super::*;
156
157    #[test]
158    fn test_clickhouse_eval_new() {
159        let func = ClickHouseEval::new();
160        assert_eq!(func.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
161        assert_eq!(func.aliases(), CLICKHOUSE_EVAL_UDF_ALIASES);
162    }
163
164    #[test]
165    fn test_clickhouse_eval_default() {
166        let func = ClickHouseEval::default();
167        assert_eq!(func.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
168    }
169
170    #[test]
171    fn test_clickhouse_func_constants() {
172        assert_eq!(ClickHouseEval::ARG_LEN, 2);
173    }
174
175    #[test]
176    fn test_clickhouse_eval_udf_creation() {
177        let udf = clickhouse_eval_udf();
178        assert_eq!(udf.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
179    }
180
181    #[test]
182    fn test_return_type_valid_args() {
183        let func = ClickHouseEval::new();
184        let arg_types = vec![DataType::Utf8, DataType::Int32];
185        let result = func.return_type(&arg_types);
186        assert!(result.is_ok());
187        assert_eq!(result.unwrap(), DataType::Int32);
188    }
189
190    #[test]
191    fn test_return_type_valid_args_utf8_view() {
192        let func = ClickHouseEval::new();
193        let arg_types = vec![DataType::Utf8View, DataType::Float64];
194        let result = func.return_type(&arg_types);
195        assert!(result.is_ok());
196        assert_eq!(result.unwrap(), DataType::Float64);
197    }
198
199    #[test]
200    fn test_return_type_valid_args_large_utf8() {
201        let func = ClickHouseEval::new();
202        let arg_types = vec![DataType::LargeUtf8, DataType::Boolean];
203        let result = func.return_type(&arg_types);
204        assert!(result.is_ok());
205        assert_eq!(result.unwrap(), DataType::Boolean);
206    }
207
208    #[test]
209    fn test_return_type_wrong_arg_count() {
210        let func = ClickHouseEval::new();
211
212        // Too few arguments
213        let arg_types = vec![DataType::Utf8];
214        let result = func.return_type(&arg_types);
215        assert!(result.is_err());
216        assert!(result.unwrap_err().to_string().contains("Expected two string arguments"));
217
218        // Too many arguments
219        let arg_types = vec![DataType::Utf8, DataType::Int32, DataType::Float64];
220        let result = func.return_type(&arg_types);
221        assert!(result.is_err());
222        assert!(result.unwrap_err().to_string().contains("Expected two string arguments"));
223    }
224
225    #[test]
226    fn test_return_field_from_args_valid() {
227        let func = ClickHouseEval::new();
228        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
229        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
230        let scalar = [
231            Some(ScalarValue::Utf8(Some("count()".to_string()))),
232            Some(ScalarValue::Utf8(Some("Int64".to_string()))),
233        ];
234        let args = ReturnFieldArgs {
235            arg_fields:       &[field1, field2],
236            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
237        };
238
239        let result = func.return_field_from_args(args);
240        assert!(result.is_ok());
241        let field = result.unwrap();
242        assert_eq!(field.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
243        assert_eq!(field.data_type(), &DataType::Int64);
244        assert!(!field.is_nullable(), "Expect non-nullable - no nullable input fields");
245    }
246
247    #[test]
248    fn test_return_field_from_args_utf8_view() {
249        let func = ClickHouseEval::new();
250        let field1 = Arc::new(Field::new("syntax", DataType::Utf8View, false));
251        let field2 = Arc::new(Field::new("type", DataType::Utf8View, false));
252        let scalar = [
253            Some(ScalarValue::Utf8View(Some("sum(x)".to_string()))),
254            Some(ScalarValue::Utf8View(Some("Float64".to_string()))),
255        ];
256        let args = ReturnFieldArgs {
257            arg_fields:       &[field1, field2],
258            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
259        };
260
261        let result = func.return_field_from_args(args);
262        assert!(result.is_ok());
263        let field = result.unwrap();
264        assert_eq!(field.data_type(), &DataType::Float64);
265    }
266
267    #[test]
268    fn test_return_field_from_args_large_utf8() {
269        let func = ClickHouseEval::new();
270        let field1 = Arc::new(Field::new("syntax", DataType::LargeUtf8, false));
271        let field2 = Arc::new(Field::new("type", DataType::LargeUtf8, false));
272
273        let scalar = [
274            Some(ScalarValue::LargeUtf8(Some("avg(y)".to_string()))),
275            Some(ScalarValue::LargeUtf8(Some("Boolean".to_string()))),
276        ];
277        let args = ReturnFieldArgs {
278            arg_fields:       &[field1, field2],
279            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
280        };
281
282        let result = func.return_field_from_args(args);
283        assert!(result.is_ok());
284        let field = result.unwrap();
285        assert_eq!(field.data_type(), &DataType::Boolean);
286    }
287
288    #[test]
289    fn test_return_field_from_args_wrong_field_count() {
290        let func = ClickHouseEval::new();
291        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
292        let scalar = [Some(ScalarValue::Utf8(Some("count()".to_string())))];
293        let args = ReturnFieldArgs {
294            arg_fields:       &[field1],
295            scalar_arguments: &[scalar[0].as_ref()],
296        };
297
298        let result = func.return_field_from_args(args);
299        assert!(result.is_err());
300        assert!(result.unwrap_err().to_string().contains("Expected two string arguments"));
301    }
302
303    #[test]
304    fn test_return_field_from_args_wrong_scalar_count() {
305        let func = ClickHouseEval::new();
306        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
307        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
308        let scalar = [Some(ScalarValue::Utf8(Some("count()".to_string())))];
309        let args = ReturnFieldArgs {
310            arg_fields:       &[field1, field2],
311            scalar_arguments: &[scalar[0].as_ref()],
312        };
313
314        let result = func.return_field_from_args(args);
315        assert!(result.is_err());
316        assert!(result.unwrap_err().to_string().contains("Expected two string arguments"));
317    }
318
319    #[test]
320    fn test_return_field_from_args_missing_syntax() {
321        let func = ClickHouseEval::new();
322        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
323        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
324        let scalar = [None, Some(ScalarValue::Utf8(Some("Int64".to_string())))];
325        let args = ReturnFieldArgs {
326            arg_fields:       &[field1, field2],
327            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
328        };
329
330        let result = func.return_field_from_args(args);
331        assert!(result.is_err());
332        assert!(result.unwrap_err().to_string().contains("First argument (syntax) missing"));
333    }
334
335    #[test]
336    fn test_return_field_from_args_missing_type() {
337        let func = ClickHouseEval::new();
338        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
339        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
340        let scalar = [Some(ScalarValue::Utf8(Some("count()".to_string()))), None];
341        let args = ReturnFieldArgs {
342            arg_fields:       &[field1, field2],
343            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
344        };
345
346        let result = func.return_field_from_args(args);
347        assert!(result.is_err());
348        assert!(result.unwrap_err().to_string().contains("Second argument (data type) missing"));
349    }
350
351    #[test]
352    fn test_return_field_from_args_null_syntax() {
353        let func = ClickHouseEval::new();
354        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
355        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
356        let scalar =
357            [Some(ScalarValue::Utf8(None)), Some(ScalarValue::Utf8(Some("Int64".to_string())))];
358        let args = ReturnFieldArgs {
359            arg_fields:       &[field1, field2],
360            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
361        };
362
363        let result = func.return_field_from_args(args);
364        assert!(result.is_err());
365        assert!(result.unwrap_err().to_string().contains("Missing syntax argument"));
366    }
367
368    #[test]
369    fn test_return_field_from_args_null_type() {
370        let func = ClickHouseEval::new();
371        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
372        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
373        let scalar =
374            [Some(ScalarValue::Utf8(Some("count()".to_string()))), Some(ScalarValue::Utf8(None))];
375        let args = ReturnFieldArgs {
376            arg_fields:       &[field1, field2],
377            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
378        };
379
380        let result = func.return_field_from_args(args);
381        assert!(result.is_err());
382        assert!(result.unwrap_err().to_string().contains("Missing data type argument"));
383    }
384
385    #[test]
386    fn test_return_field_from_args_invalid_type_string() {
387        let func = ClickHouseEval::new();
388        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
389        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
390        let scalar = [
391            Some(ScalarValue::Utf8(Some("count()".to_string()))),
392            Some(ScalarValue::Utf8(Some("InvalidType".to_string()))),
393        ];
394        let args = ReturnFieldArgs {
395            arg_fields:       &[field1, field2],
396            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
397        };
398
399        let result = func.return_field_from_args(args);
400        assert!(result.is_err());
401        assert!(result.unwrap_err().to_string().contains("Invalid type string"));
402    }
403
404    #[test]
405    fn test_return_field_from_args_non_string_arguments() {
406        let func = ClickHouseEval::new();
407        let field1 = Arc::new(Field::new("syntax", DataType::Int32, false));
408        let field2 = Arc::new(Field::new("type", DataType::Int32, false));
409        let scalar = [Some(ScalarValue::Int32(Some(42))), Some(ScalarValue::Int32(Some(24)))];
410        let args = ReturnFieldArgs {
411            arg_fields:       &[field1, field2],
412            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
413        };
414
415        let result = func.return_field_from_args(args);
416        assert!(result.is_err());
417        assert!(
418            result.unwrap_err().to_string().contains("clickhouse_func expects string arguments")
419        );
420    }
421
422    #[test]
423    fn test_invoke_with_args_not_implemented() {
424        let func = ClickHouseEval::new();
425        let args = ScalarFunctionArgs {
426            args:           vec![],
427            arg_fields:     vec![],
428            number_rows:    1,
429            return_field:   Arc::new(Field::new("", DataType::Int32, false)),
430            config_options: Arc::new(ConfigOptions::default()),
431        };
432        let result = func.invoke_with_args(args);
433        assert!(result.is_err());
434        assert!(
435            result
436                .unwrap_err()
437                .to_string()
438                .contains("UDFs are evaluated after data has been fetched")
439        );
440    }
441
442    #[test]
443    fn test_documentation() {
444        let func = ClickHouseEval::new();
445        let doc = func.documentation();
446        assert!(doc.is_some());
447
448        let documentation = get_doc();
449        assert!(documentation.description.contains("Add one to an int32"));
450    }
451
452    #[test]
453    fn test_as_any() {
454        let func = ClickHouseEval::new();
455        let any_ref = func.as_any();
456        assert!(any_ref.downcast_ref::<ClickHouseEval>().is_some());
457    }
458
459    #[tokio::test]
460    async fn test_clickhouse_udf() -> Result<(), Box<dyn std::error::Error>> {
461        let ctx = SessionContext::new();
462        ctx.register_udf(clickhouse_eval_udf());
463
464        let schema = SchemaRef::new(Schema::new(vec![
465            Field::new("id", DataType::Int32, false),
466            Field::new("names", DataType::Utf8, false),
467        ]));
468
469        let provider =
470            Arc::new(datafusion::datasource::MemTable::try_new(Arc::clone(&schema), vec![vec![
471                arrow::record_batch::RecordBatch::try_new(schema, vec![
472                    Arc::new(arrow::array::Int32Array::from(vec![1])),
473                    Arc::new(arrow::array::StringArray::from(vec!["John,Jon,J"])),
474                ])?,
475            ]])?);
476        drop(ctx.register_table("people", provider)?);
477        let sql =
478            "SELECT id, clickhouse_eval('splitByChar('','', names)', 'List(Utf8)') FROM people";
479        let df = ctx.sql(&format!("EXPLAIN {sql}")).await?;
480        let results = df.collect().await?;
481        println!("EXPLAIN: {results:?}");
482        Ok(())
483    }
484}