1use std::any::Any;
2use std::str::FromStr;
3use std::sync::LazyLock;
4
5use datafusion::arrow::datatypes::{DataType, FieldRef};
6use datafusion::common::{ScalarValue, internal_err, not_impl_err, plan_datafusion_err, plan_err};
7use datafusion::error::Result;
8use datafusion::logical_expr::{
9 ColumnarValue, DocSection, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
10 ScalarUDFImpl, Signature, Volatility,
11};
12
13use super::udf_field_from_fields;
14
15static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
16 Documentation::builder(DocSection::default(), "Add one to an int32", "add_one(2)")
17 .with_argument("arg1", "The string representation of the ClickHouse function")
18 .with_argument("arg2", "The string representation of the expected DataType")
19 .build()
20});
21
/// Names the UDF answers to. The first entry is the primary name returned by
/// `ClickHouseEval::name`; the whole list is copied into the alias vector.
pub const CLICKHOUSE_EVAL_UDF_ALIASES: &[&str] = &["clickhouse_eval"];
23
24pub fn clickhouse_eval_udf() -> ScalarUDF { ScalarUDF::from(ClickHouseEval::new()) }
25
26fn get_doc() -> &'static Documentation { &DOCUMENTATION }
27
/// Scalar UDF implementation exposed under the names in
/// [`CLICKHOUSE_EVAL_UDF_ALIASES`].
///
/// Both fields are populated by [`ClickHouseEval::new`].
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ClickHouseEval {
    /// Argument signature handed out by `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Owned copies of the alias names, handed out by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
38
39impl Default for ClickHouseEval {
40 fn default() -> Self { Self::new() }
41}
42
43impl ClickHouseEval {
44 pub const ARG_LEN: usize = 2;
45
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::uniform(
49 2,
50 vec![DataType::Utf8, DataType::Utf8View, DataType::LargeUtf8],
51 Volatility::Volatile,
52 ),
53 aliases: CLICKHOUSE_EVAL_UDF_ALIASES.iter().map(ToString::to_string).collect(),
54 }
55 }
56}
57
58impl ScalarUDFImpl for ClickHouseEval {
59 fn as_any(&self) -> &dyn Any { self }
60
61 fn name(&self) -> &'static str { CLICKHOUSE_EVAL_UDF_ALIASES[0] }
62
63 fn aliases(&self) -> &[String] { &self.aliases }
64
65 fn signature(&self) -> &Signature { &self.signature }
66
67 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
73 if arg_types.len() != 2 {
74 return plan_err!(
75 "Expected two string arguments, syntax and datatype, received fields {:?}",
76 arg_types
77 );
78 }
79
80 Ok(arg_types.get(1).cloned().unwrap())
82 }
83
84 fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
90 if args.arg_fields.len() != 2 || args.scalar_arguments.len() != 2 {
91 return plan_err!(
92 "Expected two string arguments, syntax and datatype, received fields {:?}",
93 args.arg_fields
94 );
95 }
96
97 let syntax_arg = args
99 .scalar_arguments
100 .first()
101 .unwrap()
102 .ok_or(plan_datafusion_err!("First argument (syntax) missing"))?;
103 let type_arg = args
104 .scalar_arguments
105 .get(1)
106 .unwrap()
107 .ok_or(plan_datafusion_err!("Second argument (data type) missing"))?;
108
109 if let (
110 ScalarValue::Utf8(syntax)
111 | ScalarValue::Utf8View(syntax)
112 | ScalarValue::LargeUtf8(syntax),
113 ScalarValue::Utf8(data_type)
114 | ScalarValue::Utf8View(data_type)
115 | ScalarValue::LargeUtf8(data_type),
116 ) = (syntax_arg, type_arg)
117 {
118 if syntax.is_none() {
120 return internal_err!("Missing syntax argument");
121 }
122
123 let Some(type_str) = data_type else {
125 return internal_err!("Missing data type argument");
126 };
127
128 let data_type = DataType::from_str(type_str)
130 .map_err(|e| plan_datafusion_err!("Invalid type string: {e}"))?;
131 Ok(udf_field_from_fields(self.name(), data_type, args.arg_fields))
132 } else {
133 internal_err!("clickhouse_func expects string arguments")
134 }
135 }
136
137 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
138 not_impl_err!("UDFs are evaluated after data has been fetched.")
139 }
140
141 fn documentation(&self) -> Option<&Documentation> { Some(get_doc()) }
142}
143
// Unit tests for `ClickHouseEval`; compiled only for test builds with the
// `test-utils` feature enabled.
#[cfg(all(test, feature = "test-utils"))]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow;
    use datafusion::arrow::datatypes::*;
    use datafusion::common::ScalarValue;
    use datafusion::config::ConfigOptions;
    use datafusion::logical_expr::{ReturnFieldArgs, ScalarUDFImpl};
    use datafusion::prelude::SessionContext;

    use super::*;

    // `new` wires the name and aliases to the shared alias constant.
    #[test]
    fn test_clickhouse_eval_new() {
        let func = ClickHouseEval::new();
        assert_eq!(func.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
        assert_eq!(func.aliases(), CLICKHOUSE_EVAL_UDF_ALIASES);
    }

    // `default` yields the same primary name as `new`.
    #[test]
    fn test_clickhouse_eval_default() {
        let func = ClickHouseEval::default();
        assert_eq!(func.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
    }

    // Pins the public argument-count constant.
    #[test]
    fn test_clickhouse_func_constants() {
        assert_eq!(ClickHouseEval::ARG_LEN, 2);
    }

    // The `ScalarUDF` wrapper reports the primary name.
    #[test]
    fn test_clickhouse_eval_udf_creation() {
        let udf = clickhouse_eval_udf();
        assert_eq!(udf.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
    }

    // `return_type` echoes the second argument type when two are given.
    #[test]
    fn test_return_type_valid_args() {
        let func = ClickHouseEval::new();
        let arg_types = vec![DataType::Utf8, DataType::Int32];
        let result = func.return_type(&arg_types);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), DataType::Int32);
    }

    // Same as above with a `Utf8View` first argument.
    #[test]
    fn test_return_type_valid_args_utf8_view() {
        let func = ClickHouseEval::new();
        let arg_types = vec![DataType::Utf8View, DataType::Float64];
        let result = func.return_type(&arg_types);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), DataType::Float64);
    }

    // Same as above with a `LargeUtf8` first argument.
    #[test]
    fn test_return_type_valid_args_large_utf8() {
        let func = ClickHouseEval::new();
        let arg_types = vec![DataType::LargeUtf8, DataType::Boolean];
        let result = func.return_type(&arg_types);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), DataType::Boolean);
    }

    // Any argument count other than two is rejected.
    #[test]
    fn test_return_type_wrong_arg_count() {
        let func = ClickHouseEval::new();

        // Too few arguments.
        let arg_types = vec![DataType::Utf8];
        let result = func.return_type(&arg_types);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Expected two string arguments"));

        // Too many arguments.
        let arg_types = vec![DataType::Utf8, DataType::Int32, DataType::Float64];
        let result = func.return_type(&arg_types);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Expected two string arguments"));
    }

    // Happy path: the type string in the second literal decides the field type.
    #[test]
    fn test_return_field_from_args_valid() {
        let func = ClickHouseEval::new();
        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
        let scalar = [
            Some(ScalarValue::Utf8(Some("count()".to_string()))),
            Some(ScalarValue::Utf8(Some("Int64".to_string()))),
        ];
        let args = ReturnFieldArgs {
            arg_fields: &[field1, field2],
            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
        };

        let result = func.return_field_from_args(args);
        assert!(result.is_ok());
        let field = result.unwrap();
        assert_eq!(field.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
        assert_eq!(field.data_type(), &DataType::Int64);
        assert!(!field.is_nullable(), "Expect non-nullable - no nullable input fields");
    }

    // `Utf8View` literals are accepted as well.
    #[test]
    fn test_return_field_from_args_utf8_view() {
        let func = ClickHouseEval::new();
        let field1 = Arc::new(Field::new("syntax", DataType::Utf8View, false));
        let field2 = Arc::new(Field::new("type", DataType::Utf8View, false));
        let scalar = [
            Some(ScalarValue::Utf8View(Some("sum(x)".to_string()))),
            Some(ScalarValue::Utf8View(Some("Float64".to_string()))),
        ];
        let args = ReturnFieldArgs {
            arg_fields: &[field1, field2],
            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
        };

        let result = func.return_field_from_args(args);
        assert!(result.is_ok());
        let field = result.unwrap();
        assert_eq!(field.data_type(), &DataType::Float64);
    }

    // `LargeUtf8` literals are accepted as well.
    #[test]
    fn test_return_field_from_args_large_utf8() {
        let func = ClickHouseEval::new();
        let field1 = Arc::new(Field::new("syntax", DataType::LargeUtf8, false));
        let field2 = Arc::new(Field::new("type", DataType::LargeUtf8, false));

        let scalar = [
            Some(ScalarValue::LargeUtf8(Some("avg(y)".to_string()))),
            Some(ScalarValue::LargeUtf8(Some("Boolean".to_string()))),
        ];
        let args = ReturnFieldArgs {
            arg_fields: &[field1, field2],
            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
        };

        let result = func.return_field_from_args(args);
        assert!(result.is_ok());
        let field = result.unwrap();
        assert_eq!(field.data_type(), &DataType::Boolean);
    }

    // A single field and a single scalar are rejected.
    #[test]
    fn test_return_field_from_args_wrong_field_count() {
        let func = ClickHouseEval::new();
        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
        let scalar = [Some(ScalarValue::Utf8(Some("count()".to_string())))];
        let args = ReturnFieldArgs {
            arg_fields: &[field1],
            scalar_arguments: &[scalar[0].as_ref()],
        };

        let result = func.return_field_from_args(args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Expected two string arguments"));
    }

    // Two fields but only one scalar slot are rejected with the same message.
    #[test]
    fn test_return_field_from_args_wrong_scalar_count() {
        let func = ClickHouseEval::new();
        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
        let scalar = [Some(ScalarValue::Utf8(Some("count()".to_string())))];
        let args = ReturnFieldArgs {
            arg_fields: &[field1, field2],
            scalar_arguments: &[scalar[0].as_ref()],
        };

        let result = func.return_field_from_args(args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Expected two string arguments"));
    }

    // The first scalar slot is `None` (no scalar available for the argument).
    #[test]
    fn test_return_field_from_args_missing_syntax() {
        let func = ClickHouseEval::new();
        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
        let scalar = [None, Some(ScalarValue::Utf8(Some("Int64".to_string())))];
        let args = ReturnFieldArgs {
            arg_fields: &[field1, field2],
            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
        };

        let result = func.return_field_from_args(args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("First argument (syntax) missing"));
    }

    // The second scalar slot is `None` (no scalar available for the argument).
    #[test]
    fn test_return_field_from_args_missing_type() {
        let func = ClickHouseEval::new();
        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
        let scalar = [Some(ScalarValue::Utf8(Some("count()".to_string()))), None];
        let args = ReturnFieldArgs {
            arg_fields: &[field1, field2],
            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
        };

        let result = func.return_field_from_args(args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Second argument (data type) missing"));
    }

    // The first scalar is present but holds a null string.
    #[test]
    fn test_return_field_from_args_null_syntax() {
        let func = ClickHouseEval::new();
        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
        let scalar =
            [Some(ScalarValue::Utf8(None)), Some(ScalarValue::Utf8(Some("Int64".to_string())))];
        let args = ReturnFieldArgs {
            arg_fields: &[field1, field2],
            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
        };

        let result = func.return_field_from_args(args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Missing syntax argument"));
    }

    // The second scalar is present but holds a null string.
    #[test]
    fn test_return_field_from_args_null_type() {
        let func = ClickHouseEval::new();
        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
        let scalar =
            [Some(ScalarValue::Utf8(Some("count()".to_string()))), Some(ScalarValue::Utf8(None))];
        let args = ReturnFieldArgs {
            arg_fields: &[field1, field2],
            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
        };

        let result = func.return_field_from_args(args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Missing data type argument"));
    }

    // A type string that does not parse as a `DataType` is reported as such.
    #[test]
    fn test_return_field_from_args_invalid_type_string() {
        let func = ClickHouseEval::new();
        let field1 = Arc::new(Field::new("syntax", DataType::Utf8, false));
        let field2 = Arc::new(Field::new("type", DataType::Utf8, false));
        let scalar = [
            Some(ScalarValue::Utf8(Some("count()".to_string()))),
            Some(ScalarValue::Utf8(Some("InvalidType".to_string()))),
        ];
        let args = ReturnFieldArgs {
            arg_fields: &[field1, field2],
            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
        };

        let result = func.return_field_from_args(args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid type string"));
    }

    // Non-string scalars fall through to the catch-all error.
    #[test]
    fn test_return_field_from_args_non_string_arguments() {
        let func = ClickHouseEval::new();
        let field1 = Arc::new(Field::new("syntax", DataType::Int32, false));
        let field2 = Arc::new(Field::new("type", DataType::Int32, false));
        let scalar = [Some(ScalarValue::Int32(Some(42))), Some(ScalarValue::Int32(Some(24)))];
        let args = ReturnFieldArgs {
            arg_fields: &[field1, field2],
            scalar_arguments: &[scalar[0].as_ref(), scalar[1].as_ref()],
        };

        let result = func.return_field_from_args(args);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("clickhouse_func expects string arguments")
        );
    }

    // Direct invocation always reports "not implemented".
    #[test]
    fn test_invoke_with_args_not_implemented() {
        let func = ClickHouseEval::new();
        let args = ScalarFunctionArgs {
            args: vec![],
            arg_fields: vec![],
            number_rows: 1,
            return_field: Arc::new(Field::new("", DataType::Int32, false)),
            config_options: Arc::new(ConfigOptions::default()),
        };
        let result = func.invoke_with_args(args);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("UDFs are evaluated after data has been fetched")
        );
    }

    // Documentation is present; the assertion pins the current description text.
    #[test]
    fn test_documentation() {
        let func = ClickHouseEval::new();
        let doc = func.documentation();
        assert!(doc.is_some());

        let documentation = get_doc();
        assert!(documentation.description.contains("Add one to an int32"));
    }

    // `as_any` allows downcasting back to the concrete type.
    #[test]
    fn test_as_any() {
        let func = ClickHouseEval::new();
        let any_ref = func.as_any();
        assert!(any_ref.downcast_ref::<ClickHouseEval>().is_some());
    }

    // Smoke test: registers the UDF and a one-row in-memory table, then runs an
    // EXPLAIN of a query that calls the UDF. The result is only printed; there
    // are no assertions beyond the `?` propagation.
    #[tokio::test]
    async fn test_clickhouse_udf() -> Result<(), Box<dyn std::error::Error>> {
        let ctx = SessionContext::new();
        ctx.register_udf(clickhouse_eval_udf());

        let schema = SchemaRef::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("names", DataType::Utf8, false),
        ]));

        let provider =
            Arc::new(datafusion::datasource::MemTable::try_new(Arc::clone(&schema), vec![vec![
                arrow::record_batch::RecordBatch::try_new(schema, vec![
                    Arc::new(arrow::array::Int32Array::from(vec![1])),
                    Arc::new(arrow::array::StringArray::from(vec!["John,Jon,J"])),
                ])?,
            ]])?);
        drop(ctx.register_table("people", provider)?);
        // The doubled single quotes are SQL escapes for a quote inside the
        // string literal passed as the first argument.
        let sql =
            "SELECT id, clickhouse_eval('splitByChar('','', names)', 'List(Utf8)') FROM people";
        let df = ctx.sql(&format!("EXPLAIN {sql}")).await?;
        let results = df.collect().await?;
        println!("EXPLAIN: {results:?}");
        Ok(())
    }
}