1use std::any::Any;
2use std::str::FromStr;
3use std::sync::LazyLock;
4
5use datafusion::arrow::datatypes::{DataType, FieldRef};
6use datafusion::common::{ScalarValue, internal_err, not_impl_err, plan_datafusion_err, plan_err};
7use datafusion::error::Result;
8use datafusion::logical_expr::{
9 ColumnarValue, DocSection, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
10 ScalarUDFImpl, Signature, Volatility,
11};
12
13use super::udf_field_from_fields;
14
15static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
16 Documentation::builder(DocSection::default(), "Add one to an int32", "add_one(2)")
17 .with_argument("arg1", "The string representation of the ClickHouse function")
18 .with_argument("arg2", "The string representation of the expected DataType")
19 .build()
20});
21
/// Names the UDF answers to.
///
/// The first entry is what `ClickHouseEval::name` returns; the whole list is copied into the
/// UDF's alias list by `ClickHouseEval::new`.
pub const CLICKHOUSE_EVAL_UDF_ALIASES: &[&str] = &["clickhouse_eval"];
23
24pub fn clickhouse_eval_udf() -> ScalarUDF { ScalarUDF::from(ClickHouseEval::new()) }
25
26fn get_doc() -> &'static Documentation { &DOCUMENTATION }
27
/// Scalar UDF implementation registered under the names in [`CLICKHOUSE_EVAL_UDF_ALIASES`].
///
/// It takes two string arguments: a ClickHouse expression and the name of the result data type.
/// The implementation only resolves the return field; `invoke_with_args` always reports
/// "not implemented".
#[derive(Debug)]
pub struct ClickHouseEval {
    /// Accepted argument signature, built in [`ClickHouseEval::new`].
    signature: Signature,
    /// Owned copies of [`CLICKHOUSE_EVAL_UDF_ALIASES`], returned by `aliases()`.
    aliases: Vec<String>,
}
38
39impl Default for ClickHouseEval {
40 fn default() -> Self { Self::new() }
41}
42
43impl ClickHouseEval {
44 pub const ARG_LEN: usize = 2;
45
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::uniform(
49 2,
50 vec![DataType::Utf8, DataType::Utf8View, DataType::LargeUtf8],
51 Volatility::Volatile,
52 ),
53 aliases: CLICKHOUSE_EVAL_UDF_ALIASES.iter().map(ToString::to_string).collect(),
54 }
55 }
56}
57
58impl ScalarUDFImpl for ClickHouseEval {
59 fn as_any(&self) -> &dyn Any { self }
60
61 fn name(&self) -> &'static str { CLICKHOUSE_EVAL_UDF_ALIASES[0] }
62
63 fn aliases(&self) -> &[String] { &self.aliases }
64
65 fn signature(&self) -> &Signature { &self.signature }
66
67 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
73 if arg_types.len() != 2 {
74 return plan_err!(
75 "Expected two string arguments, syntax and datatype, received fields {:?}",
76 arg_types
77 );
78 }
79
80 Ok(arg_types.get(1).cloned().unwrap())
82 }
83
84 fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
90 if args.arg_fields.len() != 2 || args.scalar_arguments.len() != 2 {
91 return plan_err!(
92 "Expected two string arguments, syntax and datatype, received fields {:?}",
93 args.arg_fields
94 );
95 }
96
97 let syntax_arg = args
99 .scalar_arguments
100 .first()
101 .unwrap()
102 .ok_or(plan_datafusion_err!("First argument (syntax) missing"))?;
103 let type_arg = args
104 .scalar_arguments
105 .get(1)
106 .unwrap()
107 .ok_or(plan_datafusion_err!("Second argument (data type) missing"))?;
108
109 if let (
110 ScalarValue::Utf8(syntax)
111 | ScalarValue::Utf8View(syntax)
112 | ScalarValue::LargeUtf8(syntax),
113 ScalarValue::Utf8(data_type)
114 | ScalarValue::Utf8View(data_type)
115 | ScalarValue::LargeUtf8(data_type),
116 ) = (syntax_arg, type_arg)
117 {
118 if syntax.is_none() {
120 return internal_err!("Missing syntax argument");
121 }
122
123 let Some(type_str) = data_type else {
125 return internal_err!("Missing data type argument");
126 };
127
128 let data_type = DataType::from_str(type_str)
130 .map_err(|e| plan_datafusion_err!("Invalid type string: {e}"))?;
131 Ok(udf_field_from_fields(self.name(), data_type, args.arg_fields))
132 } else {
133 internal_err!("clickhouse_func expects string arguments")
134 }
135 }
136
137 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
138 not_impl_err!("UDFs are evaluated after data has been fetched.")
139 }
140
141 fn documentation(&self) -> Option<&Documentation> { Some(get_doc()) }
142}
143
#[cfg(all(test, feature = "test-utils"))]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow;
    use datafusion::arrow::datatypes::*;
    use datafusion::common::ScalarValue;
    use datafusion::logical_expr::{ReturnFieldArgs, ScalarUDFImpl};
    use datafusion::prelude::SessionContext;

    use super::*;

    /// Substring expected in every arity-related error.
    const ARITY_ERROR: &str = "Expected two string arguments";

    /// Non-nullable `syntax` and `type` argument fields, both of `data_type`.
    fn fields_of(data_type: &DataType) -> Vec<FieldRef> {
        ["syntax", "type"]
            .into_iter()
            .map(|name| Arc::new(Field::new(name, data_type.clone(), false)))
            .collect()
    }

    /// Wraps `text` in the given string-scalar variant as a present planner literal.
    fn literal(
        variant: impl FnOnce(Option<String>) -> ScalarValue,
        text: &str,
    ) -> Option<ScalarValue> {
        Some(variant(Some(text.to_string())))
    }

    /// Shorthand for a present `Utf8` literal.
    fn utf8(text: &str) -> Option<ScalarValue> { literal(ScalarValue::Utf8, text) }

    /// Calls `return_field_from_args` on a fresh UDF with the given fields and literals.
    fn resolve_field(
        fields: &[FieldRef],
        literals: &[Option<ScalarValue>],
    ) -> datafusion::error::Result<FieldRef> {
        let literal_refs: Vec<Option<&ScalarValue>> =
            literals.iter().map(Option::as_ref).collect();
        ClickHouseEval::new().return_field_from_args(ReturnFieldArgs {
            arg_fields: fields,
            scalar_arguments: &literal_refs,
        })
    }

    /// Asserts that `result` is an error whose rendered text contains `needle`.
    fn assert_err_contains<T>(result: datafusion::error::Result<T>, needle: &str) {
        let message = result.err().expect("expected an error").to_string();
        assert!(message.contains(needle), "unexpected error message: {message}");
    }

    #[test]
    fn test_clickhouse_eval_new() {
        let udf_impl = ClickHouseEval::new();
        assert_eq!(udf_impl.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
        assert_eq!(udf_impl.aliases(), CLICKHOUSE_EVAL_UDF_ALIASES);
    }

    #[test]
    fn test_clickhouse_eval_default() {
        assert_eq!(ClickHouseEval::default().name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
    }

    #[test]
    fn test_clickhouse_func_constants() {
        assert_eq!(ClickHouseEval::ARG_LEN, 2);
    }

    #[test]
    fn test_clickhouse_eval_udf_creation() {
        assert_eq!(clickhouse_eval_udf().name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
    }

    #[test]
    fn test_return_type_valid_args() {
        let resolved = ClickHouseEval::new().return_type(&[DataType::Utf8, DataType::Int32]);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap(), DataType::Int32);
    }

    #[test]
    fn test_return_type_valid_args_utf8_view() {
        let resolved =
            ClickHouseEval::new().return_type(&[DataType::Utf8View, DataType::Float64]);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap(), DataType::Float64);
    }

    #[test]
    fn test_return_type_valid_args_large_utf8() {
        let resolved =
            ClickHouseEval::new().return_type(&[DataType::LargeUtf8, DataType::Boolean]);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap(), DataType::Boolean);
    }

    #[test]
    fn test_return_type_wrong_arg_count() {
        let udf_impl = ClickHouseEval::new();

        // One argument too few, then one too many.
        let wrong_arities = [
            vec![DataType::Utf8],
            vec![DataType::Utf8, DataType::Int32, DataType::Float64],
        ];
        for arg_types in wrong_arities {
            assert_err_contains(udf_impl.return_type(&arg_types), ARITY_ERROR);
        }
    }

    #[test]
    fn test_return_field_from_args_valid() {
        let fields = fields_of(&DataType::Utf8);
        let resolved = resolve_field(&fields, &[utf8("count()"), utf8("Int64")]);
        assert!(resolved.is_ok());

        let field = resolved.unwrap();
        assert_eq!(field.name(), CLICKHOUSE_EVAL_UDF_ALIASES[0]);
        assert_eq!(field.data_type(), &DataType::Int64);
        assert!(!field.is_nullable(), "Expect non-nullable - no nullable input fields");
    }

    #[test]
    fn test_return_field_from_args_utf8_view() {
        let fields = fields_of(&DataType::Utf8View);
        let literals =
            [literal(ScalarValue::Utf8View, "sum(x)"), literal(ScalarValue::Utf8View, "Float64")];

        let resolved = resolve_field(&fields, &literals);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap().data_type(), &DataType::Float64);
    }

    #[test]
    fn test_return_field_from_args_large_utf8() {
        let fields = fields_of(&DataType::LargeUtf8);
        let literals = [
            literal(ScalarValue::LargeUtf8, "avg(y)"),
            literal(ScalarValue::LargeUtf8, "Boolean"),
        ];

        let resolved = resolve_field(&fields, &literals);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap().data_type(), &DataType::Boolean);
    }

    #[test]
    fn test_return_field_from_args_wrong_field_count() {
        let fields = fields_of(&DataType::Utf8);
        // Only the `syntax` field and its literal are supplied.
        let resolved = resolve_field(&fields[..1], &[utf8("count()")]);
        assert_err_contains(resolved, ARITY_ERROR);
    }

    #[test]
    fn test_return_field_from_args_wrong_scalar_count() {
        let fields = fields_of(&DataType::Utf8);
        // Two fields but a single literal.
        let resolved = resolve_field(&fields, &[utf8("count()")]);
        assert_err_contains(resolved, ARITY_ERROR);
    }

    #[test]
    fn test_return_field_from_args_missing_syntax() {
        let fields = fields_of(&DataType::Utf8);
        let resolved = resolve_field(&fields, &[None, utf8("Int64")]);
        assert_err_contains(resolved, "First argument (syntax) missing");
    }

    #[test]
    fn test_return_field_from_args_missing_type() {
        let fields = fields_of(&DataType::Utf8);
        let resolved = resolve_field(&fields, &[utf8("count()"), None]);
        assert_err_contains(resolved, "Second argument (data type) missing");
    }

    #[test]
    fn test_return_field_from_args_null_syntax() {
        let fields = fields_of(&DataType::Utf8);
        let resolved = resolve_field(&fields, &[Some(ScalarValue::Utf8(None)), utf8("Int64")]);
        assert_err_contains(resolved, "Missing syntax argument");
    }

    #[test]
    fn test_return_field_from_args_null_type() {
        let fields = fields_of(&DataType::Utf8);
        let resolved = resolve_field(&fields, &[utf8("count()"), Some(ScalarValue::Utf8(None))]);
        assert_err_contains(resolved, "Missing data type argument");
    }

    #[test]
    fn test_return_field_from_args_invalid_type_string() {
        let fields = fields_of(&DataType::Utf8);
        let resolved = resolve_field(&fields, &[utf8("count()"), utf8("InvalidType")]);
        assert_err_contains(resolved, "Invalid type string");
    }

    #[test]
    fn test_return_field_from_args_non_string_arguments() {
        let fields = fields_of(&DataType::Int32);
        let literals = [Some(ScalarValue::Int32(Some(42))), Some(ScalarValue::Int32(Some(24)))];
        let resolved = resolve_field(&fields, &literals);
        assert_err_contains(resolved, "clickhouse_func expects string arguments");
    }

    #[test]
    fn test_invoke_with_args_not_implemented() {
        let call_args = ScalarFunctionArgs {
            args: vec![],
            arg_fields: vec![],
            number_rows: 1,
            return_field: Arc::new(Field::new("", DataType::Int32, false)),
        };
        let outcome = ClickHouseEval::new().invoke_with_args(call_args);
        assert_err_contains(outcome, "UDFs are evaluated after data has been fetched");
    }

    #[test]
    fn test_documentation() {
        assert!(ClickHouseEval::new().documentation().is_some());
        assert!(get_doc().description.contains("Add one to an int32"));
    }

    #[test]
    fn test_as_any() {
        let udf_impl = ClickHouseEval::new();
        assert!(udf_impl.as_any().downcast_ref::<ClickHouseEval>().is_some());
    }

    /// Registers the UDF and plans (via `EXPLAIN`) a query that calls it.
    #[tokio::test]
    async fn test_clickhouse_udf() -> Result<(), Box<dyn std::error::Error>> {
        let ctx = SessionContext::new();
        ctx.register_udf(clickhouse_eval_udf());

        let schema = SchemaRef::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("names", DataType::Utf8, false),
        ]));

        let batch = arrow::record_batch::RecordBatch::try_new(Arc::clone(&schema), vec![
            Arc::new(arrow::array::Int32Array::from(vec![1])),
            Arc::new(arrow::array::StringArray::from(vec!["John,Jon,J"])),
        ])?;
        let provider =
            Arc::new(datafusion::datasource::MemTable::try_new(schema, vec![vec![batch]])?);
        drop(ctx.register_table("people", provider)?);

        // The doubled single quotes are SQL escapes; the literal is `splitByChar(',', names)`.
        let sql =
            "SELECT id, clickhouse_eval('splitByChar('','', names)', 'List(Utf8)') FROM people";
        let explained = ctx.sql(&format!("EXPLAIN {sql}")).await?;
        let results = explained.collect().await?;
        println!("EXPLAIN: {results:?}");
        Ok(())
    }
}