1use std::panic::{AssertUnwindSafe, catch_unwind};
14use std::sync::Arc;
15
16use arrow::datatypes::{DataType, Schema};
17use arrow::record_batch::RecordBatch;
18use regex::Regex;
19use std::sync::LazyLock;
20
21static CREATE_FUNCTION_RE: LazyLock<Option<Regex>> = LazyLock::new(|| {
22 Regex::new(
23 r"(?is)^\s*CREATE\s+(?:OR\s+REPLACE\s+)?FUNCTION\s+(\w+)\s*\(([^)]*)\)\s*RETURNS\s+TABLE\s*\(([^)]*)\)(?:\s+LANGUAGE\s+(\w+))?(?:\s+AS\s+'((?:[^']|'')*)')?\s*;?\s*$",
24 )
25 .ok()
26});
27
28use krishiv_plan::udf::{ScalarValue, TableUdf, UdfError};
29
/// One output column declared in the `RETURNS TABLE (...)` clause of a `CREATE FUNCTION` DDL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnDef {
    /// Column name exactly as written in the DDL (case is preserved).
    pub name: String,
    /// Arrow type mapped from the declared SQL type name.
    pub data_type: DataType,
}
40
/// One argument declared in the parenthesised parameter list of a `CREATE FUNCTION` DDL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionArgDef {
    /// Argument name exactly as written in the DDL (case is preserved).
    pub name: String,
    /// Arrow type mapped from the declared SQL type name.
    pub data_type: DataType,
}
47
/// Parsed form of a `CREATE [OR REPLACE] FUNCTION … RETURNS TABLE (…)` statement, as
/// produced by [`parse_create_function`].
#[derive(Debug, Clone)]
pub struct CreateFunctionDdl {
    /// Function name (a single `\w+` identifier).
    pub function_name: String,
    /// Declared arguments, in declaration order.
    pub arguments: Vec<FunctionArgDef>,
    /// Declared output columns of the returned table, in declaration order.
    pub return_columns: Vec<ColumnDef>,
    /// `LANGUAGE` identifier, lower-cased (e.g. `"sql"`), or `None` when the clause is absent.
    pub language: Option<String>,
    /// Body of the single-quoted `AS '…'` clause with `''` unescaped to `'`, or `None` when
    /// the clause is absent.
    pub body: Option<String>,
}
62
/// Cheap, case-insensitive check for whether `sql` looks like a
/// `CREATE [OR REPLACE] FUNCTION … RETURNS TABLE …` statement.
///
/// This only routes DDL; [`parse_create_function`] does the real parse. Runs of whitespace
/// (spaces, tabs, newlines) are collapsed to a single space before matching. The parser accepts
/// any whitespace between keywords (`\s+`), so multi-line DDL such as `CREATE\nFUNCTION` or
/// `RETURNS\n  TABLE` must be detected here as well.
pub fn is_create_function_returns_table(sql: &str) -> bool {
    let normalized = sql
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
        .to_ascii_uppercase();
    (normalized.starts_with("CREATE FUNCTION")
        || normalized.starts_with("CREATE OR REPLACE FUNCTION"))
        && normalized.contains("RETURNS TABLE")
}
76
77pub fn parse_create_function(sql: &str) -> Result<CreateFunctionDdl, String> {
82 let caps = CREATE_FUNCTION_RE
91 .as_ref()
92 .ok_or_else(|| "CREATE FUNCTION regex failed to compile".to_string())?
93 .captures(sql)
94 .ok_or_else(|| "SQL does not match CREATE FUNCTION … RETURNS TABLE pattern".to_string())?;
95
96 let function_name = caps
97 .get(1)
98 .map(|m| m.as_str().to_string())
99 .ok_or("could not extract function name")?;
100
101 let arg_list = caps.get(2).map(|m| m.as_str()).unwrap_or("");
102 let arguments = parse_argument_list(arg_list)?;
103
104 let col_list = caps.get(3).map(|m| m.as_str()).unwrap_or("");
105 let return_columns = parse_column_list(col_list)?;
106
107 let language = caps.get(4).map(|m| m.as_str().to_ascii_lowercase());
108 let body = caps.get(5).map(|m| m.as_str().replace("''", "'"));
109
110 Ok(CreateFunctionDdl {
111 function_name,
112 arguments,
113 return_columns,
114 language,
115 body,
116 })
117}
118
119fn parse_argument_list(list: &str) -> Result<Vec<FunctionArgDef>, String> {
120 parse_named_type_list(list, "argument")?
121 .into_iter()
122 .map(|(name, data_type)| Ok(FunctionArgDef { name, data_type }))
123 .collect()
124}
125
126fn parse_column_list(list: &str) -> Result<Vec<ColumnDef>, String> {
129 parse_named_type_list(list, "column")?
130 .into_iter()
131 .map(|(name, data_type)| Ok(ColumnDef { name, data_type }))
132 .collect()
133}
134
135fn parse_named_type_list(list: &str, item_kind: &str) -> Result<Vec<(String, DataType)>, String> {
136 let list = list.trim();
137 if list.is_empty() {
138 return Ok(Vec::new());
139 }
140 let mut parsed = Vec::new();
141 let mut names = std::collections::HashSet::new();
142 for item in list.split(',') {
143 let parts: Vec<&str> = item.split_whitespace().collect();
144 if parts.len() < 2 {
145 return Err(format!("invalid {item_kind} definition: '{item}'"));
146 }
147 let name = parts.first().copied().unwrap_or("").to_string();
148 if !names.insert(name.to_ascii_lowercase()) {
149 return Err(format!("duplicate {item_kind} name '{name}'"));
150 }
151 let type_str = parts.get(1..).unwrap_or(&[]).join(" ");
152 let data_type = sql_type_to_arrow(&type_str)?;
153 parsed.push((name, data_type));
154 }
155 Ok(parsed)
156}
157
158fn sql_type_to_arrow(type_str: &str) -> Result<DataType, String> {
163 match type_str.trim().to_ascii_uppercase().as_str() {
164 "BOOLEAN" | "BOOL" => Ok(DataType::Boolean),
165 "TINYINT" | "INT8" => Ok(DataType::Int8),
166 "SMALLINT" | "INT16" => Ok(DataType::Int16),
167 "INT" | "INTEGER" | "INT32" => Ok(DataType::Int32),
168 "BIGINT" | "INT64" | "LONG" => Ok(DataType::Int64),
169 "FLOAT" | "FLOAT32" | "REAL" => Ok(DataType::Float32),
170 "DOUBLE" | "FLOAT64" | "DOUBLE PRECISION" => Ok(DataType::Float64),
171 "TEXT" | "VARCHAR" | "STRING" | "CHARACTER VARYING" => Ok(DataType::Utf8),
172 "BYTEA" | "BYTES" | "BINARY" | "BLOB" => Ok(DataType::Binary),
173 "DATE" => Ok(DataType::Date32),
174 "TIMESTAMP" | "DATETIME" => Ok(DataType::Timestamp(
175 arrow::datatypes::TimeUnit::Microsecond,
176 None,
177 )),
178 _ => Err(format!(
179 "unsupported SQL type '{type_str}' in CREATE FUNCTION DDL"
180 )),
181 }
182}
183
/// Shareable (`Arc`), `Send + Sync` closure that produces the output batch of a
/// [`ClosureTableUdf`] from the call's scalar arguments.
pub type UdtfBodyFn = Arc<dyn Fn(&[ScalarValue]) -> Result<RecordBatch, UdfError> + Send + Sync>;
190
/// A [`TableUdf`] whose rows are produced by a Rust closure.
///
/// Build it with [`ClosureTableUdf::try_new`], which validates the name and output schema.
#[derive(Clone)]
pub struct ClosureTableUdf {
    /// Function name; checked to be non-blank by `try_new`.
    pub(crate) name: String,
    /// Declared output schema. Every batch returned by `call` must match it by column name and
    /// data type.
    pub(crate) schema: Schema,
    /// Closure invoked on every call.
    body_fn: UdtfBodyFn,
}
198
199impl std::fmt::Debug for ClosureTableUdf {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 f.debug_struct("ClosureTableUdf")
202 .field("name", &self.name)
203 .field("schema", &self.schema)
204 .finish()
205 }
206}
207
208impl ClosureTableUdf {
209 pub fn try_new(
211 name: impl Into<String>,
212 schema: Schema,
213 body_fn: UdtfBodyFn,
214 ) -> Result<Self, UdfError> {
215 let name = name.into();
216 validate_udtf_definition(&name, &schema)?;
217 Ok(Self {
218 name,
219 schema,
220 body_fn,
221 })
222 }
223}
224
225impl TableUdf for ClosureTableUdf {
226 fn name(&self) -> &str {
227 &self.name
228 }
229
230 fn output_schema(&self) -> &Schema {
231 &self.schema
232 }
233
234 fn call(&self, args: &[ScalarValue]) -> Result<RecordBatch, UdfError> {
235 let batch =
236 catch_unwind(AssertUnwindSafe(|| (self.body_fn)(args))).map_err(|payload| {
237 let message = payload
238 .downcast_ref::<&str>()
239 .copied()
240 .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
241 .unwrap_or("unknown panic");
242 UdfError::Panic(format!("UDTF '{}': {message}", self.name))
243 })??;
244 if !schema_contract_matches(batch.schema().as_ref(), &self.schema) {
245 return Err(UdfError::Execution {
246 message: format!(
247 "UDTF '{}' returned schema {:?}, expected {:?}",
248 self.name,
249 batch.schema(),
250 self.schema
251 ),
252 });
253 }
254 Ok(batch)
255 }
256}
257
258fn validate_udtf_definition(name: &str, schema: &Schema) -> Result<(), UdfError> {
259 if name.trim().is_empty() {
260 return Err(UdfError::InvalidArgument {
261 message: String::from("UDTF name must not be empty"),
262 });
263 }
264 if schema.fields().is_empty() {
265 return Err(UdfError::InvalidArgument {
266 message: format!("UDTF '{name}' must declare at least one output column"),
267 });
268 }
269 let mut names = std::collections::HashSet::with_capacity(schema.fields().len());
270 for field in schema.fields() {
271 if field.name().trim().is_empty() {
272 return Err(UdfError::InvalidArgument {
273 message: format!("UDTF '{name}' contains an empty output column name"),
274 });
275 }
276 if !names.insert(field.name()) {
277 return Err(UdfError::InvalidArgument {
278 message: format!(
279 "UDTF '{name}' contains duplicate output column '{}'",
280 field.name()
281 ),
282 });
283 }
284 }
285 Ok(())
286}
287
288fn schema_contract_matches(actual: &Schema, expected: &Schema) -> bool {
289 actual.fields().len() == expected.fields().len()
290 && actual
291 .fields()
292 .iter()
293 .zip(expected.fields())
294 .all(|(actual, expected)| {
295 actual.name() == expected.name() && actual.data_type() == expected.data_type()
296 })
297}
298
/// A [`TableUdf`] whose rows come from executing a SQL body on a DataFusion session.
///
/// Placeholders `$1`, `$2`, … in the body are replaced with SQL literals of the call arguments
/// before execution. Build it with [`SqlBodyTableUdf::try_new`].
#[derive(Clone)]
pub struct SqlBodyTableUdf {
    /// Function name; checked to be non-blank by `try_new`.
    pub(crate) name: String,
    /// Declared output schema. The query result must match it by column name and data type.
    pub(crate) schema: Schema,
    /// SQL text containing `$N` placeholders.
    body_sql: String,
    /// Exact number of arguments `call` accepts.
    argument_count: usize,
    /// Session the bound SQL is executed against.
    ctx: Arc<datafusion::prelude::SessionContext>,
}
312
313impl std::fmt::Debug for SqlBodyTableUdf {
314 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 f.debug_struct("SqlBodyTableUdf")
316 .field("name", &self.name)
317 .field("body_sql", &self.body_sql)
318 .finish()
319 }
320}
321
322impl SqlBodyTableUdf {
323 pub fn try_new(
324 name: impl Into<String>,
325 schema: Schema,
326 body_sql: impl Into<String>,
327 argument_count: usize,
328 ctx: Arc<datafusion::prelude::SessionContext>,
329 ) -> Result<Self, UdfError> {
330 let name = name.into();
331 validate_udtf_definition(&name, &schema)?;
332 let body_sql = body_sql.into();
333 if body_sql.trim().is_empty() {
334 return Err(UdfError::InvalidArgument {
335 message: format!("SQL UDTF '{name}' body must not be empty"),
336 });
337 }
338 let placeholder_args = vec![ScalarValue::Null; argument_count];
339 bind_sql_body_args(&body_sql, &placeholder_args)?;
340 Ok(Self {
341 name,
342 schema,
343 body_sql,
344 argument_count,
345 ctx,
346 })
347 }
348}
349
350impl TableUdf for SqlBodyTableUdf {
351 fn name(&self) -> &str {
352 &self.name
353 }
354
355 fn output_schema(&self) -> &Schema {
356 &self.schema
357 }
358
359 fn call(&self, args: &[ScalarValue]) -> Result<RecordBatch, UdfError> {
360 if args.len() != self.argument_count {
361 return Err(UdfError::InvalidArgument {
362 message: format!(
363 "UDTF '{}' expects {} arguments, got {}",
364 self.name,
365 self.argument_count,
366 args.len()
367 ),
368 });
369 }
370
371 let ctx = Arc::clone(&self.ctx);
374 let sql = bind_sql_body_args(&self.body_sql, args)?;
375 let schema = Arc::new(self.schema.clone());
376 let handle =
377 tokio::runtime::Handle::try_current().map_err(|error| UdfError::Execution {
378 message: format!(
379 "SQL UDTF '{}' requires an active Tokio runtime: {error}",
380 self.name
381 ),
382 })?;
383 if !matches!(
384 handle.runtime_flavor(),
385 tokio::runtime::RuntimeFlavor::MultiThread
386 ) {
387 return Err(UdfError::Execution {
388 message: format!(
389 "SQL UDTF '{}' requires a multi-thread Tokio runtime",
390 self.name
391 ),
392 });
393 }
394 catch_unwind(AssertUnwindSafe(|| {
395 tokio::task::block_in_place(|| {
396 handle.block_on(async {
397 let df = ctx.sql(&sql).await.map_err(|e| UdfError::Execution {
398 message: e.to_string(),
399 })?;
400 let batches = df.collect().await.map_err(|e| UdfError::Execution {
401 message: e.to_string(),
402 })?;
403 if batches.is_empty() {
404 return Ok(RecordBatch::new_empty(schema));
405 }
406 let batch = arrow::compute::concat_batches(
407 &batches
408 .first()
409 .ok_or_else(|| UdfError::Execution {
410 message: "empty batch list".into(),
411 })?
412 .schema(),
413 &batches,
414 )
415 .map_err(|e| UdfError::Arrow(e.to_string()))?;
416 if !schema_contract_matches(batch.schema().as_ref(), schema.as_ref()) {
417 return Err(UdfError::Execution {
418 message: format!(
419 "SQL UDTF '{}' returned schema {:?}, expected {:?}",
420 self.name,
421 batch.schema(),
422 schema
423 ),
424 });
425 }
426 Ok(batch)
427 })
428 })
429 }))
430 .map_err(|payload| {
431 let message = payload
432 .downcast_ref::<&str>()
433 .copied()
434 .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
435 .unwrap_or("unknown panic");
436 UdfError::Panic(format!("SQL UDTF '{}': {message}", self.name))
437 })?
438 }
439}
440
441fn bind_sql_body_args(sql: &str, args: &[ScalarValue]) -> Result<String, UdfError> {
442 let bytes = sql.as_bytes();
443 let mut output = String::with_capacity(sql.len());
444 let mut index = 0;
445
446 while index < bytes.len() {
447 let Some(&byte) = bytes.get(index) else {
448 break;
449 };
450 match byte {
451 b'\'' | b'"' | b'`' => {
452 index = copy_quoted_segment(sql, index, byte, &mut output)?;
453 }
454 b'-' if bytes.get(index + 1) == Some(&b'-') => {
455 let end = sql[index..]
456 .find('\n')
457 .map_or(bytes.len(), |offset| index + offset + 1);
458 output.push_str(&sql[index..end]);
459 index = end;
460 }
461 b'/' if bytes.get(index + 1) == Some(&b'*') => {
462 index = copy_block_comment(sql, index, &mut output)?;
463 }
464 b'$' => {
465 if let Some((delimiter, end)) = dollar_quote_delimiter(sql, index) {
466 let body_start = end;
467 let close_offset = sql[body_start..].find(delimiter).ok_or_else(|| {
468 UdfError::InvalidArgument {
469 message: "unterminated dollar-quoted SQL body".to_owned(),
470 }
471 })?;
472 let segment_end = body_start + close_offset + delimiter.len();
473 output.push_str(&sql[index..segment_end]);
474 index = segment_end;
475 continue;
476 }
477
478 let digit_start = index + 1;
479 let mut end = digit_start;
480 while bytes.get(end).is_some_and(u8::is_ascii_digit) {
481 end += 1;
482 }
483 if end == digit_start {
484 output.push('$');
485 index += 1;
486 continue;
487 }
488
489 let placeholder = sql[digit_start..end].parse::<usize>().map_err(|error| {
490 UdfError::InvalidArgument {
491 message: format!(
492 "invalid SQL UDTF placeholder '{}': {error}",
493 &sql[index..end]
494 ),
495 }
496 })?;
497 if placeholder == 0 {
498 return Err(UdfError::InvalidArgument {
499 message: "SQL UDTF placeholders are 1-based; $0 is invalid".to_owned(),
500 });
501 }
502 let value = args.get(placeholder - 1).ok_or_else(|| UdfError::InvalidArgument {
503 message: format!(
504 "SQL UDTF placeholder ${placeholder} has no matching argument; got {} arguments",
505 args.len()
506 ),
507 })?;
508 output.push_str(&scalar_to_sql_literal(value)?);
509 index = end;
510 }
511 _ => {
512 let ch = sql[index..]
513 .chars()
514 .next()
515 .ok_or_else(|| UdfError::InvalidArgument {
516 message: "unexpected end of SQL string".to_owned(),
517 })?;
518 output.push(ch);
519 index += ch.len_utf8();
520 }
521 }
522 }
523
524 Ok(output)
525}
526
527fn copy_quoted_segment(
528 sql: &str,
529 start: usize,
530 quote: u8,
531 output: &mut String,
532) -> Result<usize, UdfError> {
533 let bytes = sql.as_bytes();
534 let mut index = start + 1;
535 while index < bytes.len() {
536 let Some(&b) = bytes.get(index) else {
537 break;
538 };
539 if b == quote {
540 index += 1;
541 if bytes.get(index) == Some("e) {
542 index += 1;
543 continue;
544 }
545 output.push_str(&sql[start..index]);
546 return Ok(index);
547 }
548 let ch = sql[index..]
549 .chars()
550 .next()
551 .ok_or_else(|| UdfError::InvalidArgument {
552 message: "unexpected end of SQL string".to_owned(),
553 })?;
554 index += ch.len_utf8();
555 }
556 Err(UdfError::InvalidArgument {
557 message: "unterminated quoted SQL segment".to_owned(),
558 })
559}
560
561fn copy_block_comment(sql: &str, start: usize, output: &mut String) -> Result<usize, UdfError> {
562 let bytes = sql.as_bytes();
563 let mut index = start + 2;
564 let mut depth = 1usize;
565 while index < bytes.len() {
566 if bytes.get(index) == Some(&b'/') && bytes.get(index + 1) == Some(&b'*') {
567 depth += 1;
568 index += 2;
569 } else if bytes.get(index) == Some(&b'*') && bytes.get(index + 1) == Some(&b'/') {
570 depth -= 1;
571 index += 2;
572 if depth == 0 {
573 output.push_str(&sql[start..index]);
574 return Ok(index);
575 }
576 } else {
577 let ch = sql[index..]
578 .chars()
579 .next()
580 .ok_or_else(|| UdfError::InvalidArgument {
581 message: "unexpected end of SQL string".to_owned(),
582 })?;
583 index += ch.len_utf8();
584 }
585 }
586 Err(UdfError::InvalidArgument {
587 message: "unterminated SQL block comment".to_owned(),
588 })
589}
590
/// Recognises a dollar-quote opening delimiter at byte `start` of `sql`.
///
/// Accepted forms are `$$` and `$tag$`, where `tag` starts with an ASCII letter or `_` and
/// continues with ASCII alphanumerics or `_`. On a match, returns the delimiter text and the byte
/// index just past it. Otherwise (including `$1`-style placeholders) returns `None`.
fn dollar_quote_delimiter(sql: &str, start: usize) -> Option<(&str, usize)> {
    let tail = sql.as_bytes().get(start..)?;
    if tail.first() != Some(&b'$') {
        return None;
    }

    let tag_len = match tail.get(1) {
        Some(&b'$') => 0,
        Some(&first) if first.is_ascii_alphabetic() || first == b'_' => tail[1..]
            .iter()
            .take_while(|byte| byte.is_ascii_alphanumeric() || **byte == b'_')
            .count(),
        _ => return None,
    };

    let closing = 1 + tag_len;
    (tail.get(closing) == Some(&b'$')).then(|| {
        let end = start + closing + 1;
        (&sql[start..end], end)
    })
}
617
618fn scalar_to_sql_literal(value: &ScalarValue) -> Result<String, UdfError> {
619 match value {
620 ScalarValue::Null => Ok("NULL".to_owned()),
621 ScalarValue::Int64(value) => Ok(value.to_string()),
622 ScalarValue::Float64(value) if value.is_finite() => Ok(value.to_string()),
623 ScalarValue::Float64(value) => Err(UdfError::InvalidArgument {
624 message: format!("non-finite floating-point UDTF argument {value} is not supported"),
625 }),
626 ScalarValue::Utf8(value) => Ok(format!("'{}'", value.replace('\'', "''"))),
627 ScalarValue::Boolean(value) => Ok(if *value { "TRUE" } else { "FALSE" }.to_owned()),
628 ScalarValue::Bytes(_) => Err(UdfError::InvalidArgument {
629 message: "binary UDTF arguments are not supported in SQL bodies".to_owned(),
630 }),
631 }
632}
633
/// Unit tests for DDL detection/parsing, closure- and SQL-bodied UDTFs, and placeholder binding.
#[cfg(test)]
// Tests unwrap/expect freely: a panic is the intended failure signal here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int64Array};
    use arrow::datatypes::{DataType, Field};

    // Multi-line DDL exercising every optional clause (LANGUAGE and AS body).
    const BASIC_DDL: &str = "
        CREATE FUNCTION my_udtf(arg1 INT)
        RETURNS TABLE (col1 TEXT, col2 BIGINT)
        LANGUAGE RUST
        AS 'fn my_udtf(arg1: i64) -> Vec<Row> { vec![] }'
    ";

    #[test]
    fn detects_create_function_returns_table() {
        assert!(is_create_function_returns_table(BASIC_DDL));
        // The OR REPLACE variant is detected too.
        assert!(is_create_function_returns_table(
            "CREATE OR REPLACE FUNCTION g(x INT) RETURNS TABLE (v TEXT)"
        ));
        assert!(!is_create_function_returns_table("SELECT 1"));
        // A scalar-returning function is not a table function.
        assert!(!is_create_function_returns_table(
            "CREATE FUNCTION f(x INT) RETURNS INT LANGUAGE SQL AS 'SELECT x'"
        ));
    }

    #[test]
    fn parses_function_name() {
        let ddl = parse_create_function(BASIC_DDL).expect("should parse");
        assert_eq!(ddl.function_name, "my_udtf");
    }

    #[test]
    fn parses_typed_arguments() {
        let ddl = parse_create_function(
            "CREATE FUNCTION typed_args(count BIGINT, label TEXT, enabled BOOLEAN) \
             RETURNS TABLE (value TEXT) LANGUAGE SQL AS 'SELECT $2 AS value'",
        )
        .expect("should parse");
        assert_eq!(
            ddl.arguments,
            vec![
                FunctionArgDef {
                    name: "count".to_owned(),
                    data_type: DataType::Int64,
                },
                FunctionArgDef {
                    name: "label".to_owned(),
                    data_type: DataType::Utf8,
                },
                FunctionArgDef {
                    name: "enabled".to_owned(),
                    data_type: DataType::Boolean,
                },
            ]
        );
    }

    #[test]
    fn parses_return_columns() {
        let ddl = parse_create_function(BASIC_DDL).expect("should parse");
        assert_eq!(ddl.return_columns.len(), 2);
        assert_eq!(ddl.return_columns[0].name, "col1");
        assert_eq!(ddl.return_columns[0].data_type, DataType::Utf8);
        assert_eq!(ddl.return_columns[1].name, "col2");
        assert_eq!(ddl.return_columns[1].data_type, DataType::Int64);
    }

    #[test]
    fn parses_language_and_body() {
        let ddl = parse_create_function(BASIC_DDL).expect("should parse");
        // The language identifier is lower-cased by the parser.
        assert_eq!(ddl.language.as_deref(), Some("rust"));
        assert!(ddl.body.is_some());
    }

    #[test]
    fn parses_without_language_and_body() {
        let sql = "CREATE FUNCTION simple(x INT) RETURNS TABLE (val BIGINT)";
        let ddl = parse_create_function(sql).expect("should parse");
        assert_eq!(ddl.function_name, "simple");
        assert_eq!(ddl.return_columns.len(), 1);
        assert_eq!(ddl.language, None);
        assert_eq!(ddl.body, None);
    }

    #[test]
    fn parses_or_replace_variant() {
        let sql = "CREATE OR REPLACE FUNCTION f(x INT) RETURNS TABLE (a TEXT, b INT)";
        let ddl = parse_create_function(sql).expect("should parse");
        assert_eq!(ddl.function_name, "f");
        assert_eq!(ddl.return_columns.len(), 2);
    }

    #[test]
    fn parser_rejects_trailing_unparsed_sql() {
        // The pattern is anchored at the end, so a trailing statement must fail the match.
        let error = parse_create_function(&format!("{BASIC_DDL} SELECT 1"))
            .expect_err("trailing SQL must not be ignored");
        assert!(error.contains("does not match"));
    }

    #[test]
    fn parser_rejects_duplicate_argument_and_output_names() {
        let duplicate_arg = parse_create_function(
            "CREATE FUNCTION f(value INT, VALUE BIGINT) \
             RETURNS TABLE (result BIGINT) LANGUAGE SQL AS 'SELECT 1 AS result'",
        )
        .expect_err("argument names are case-insensitively unique");
        assert!(duplicate_arg.contains("duplicate argument"));

        let duplicate_output = parse_create_function(
            "CREATE FUNCTION f() RETURNS TABLE (value INT, VALUE BIGINT) \
             LANGUAGE SQL AS 'SELECT 1 AS value, 2 AS VALUE'",
        )
        .expect_err("output names are case-insensitively unique");
        assert!(duplicate_output.contains("duplicate column"));
    }

    #[test]
    fn closure_table_udf_executes_and_validates_output_schema() {
        let schema = Schema::new(vec![Field::new("value", DataType::Int64, false)]);
        let udf = ClosureTableUdf::try_new(
            "values",
            schema.clone(),
            Arc::new({
                let schema = Arc::new(schema);
                move |_| {
                    RecordBatch::try_new(
                        Arc::clone(&schema),
                        vec![Arc::new(Int64Array::from(vec![1_i64, 2])) as ArrayRef],
                    )
                    .map_err(UdfError::from)
                }
            }),
        )
        .unwrap();

        let batch = udf.call(&[]).unwrap();
        assert_eq!(batch.num_rows(), 2);

        // Same type but a different column name must be rejected by the schema contract.
        let wrong_schema = ClosureTableUdf::try_new(
            "wrong",
            Schema::new(vec![Field::new("expected", DataType::Int64, false)]),
            Arc::new(|_| {
                RecordBatch::try_new(
                    Arc::new(Schema::new(vec![Field::new(
                        "actual",
                        DataType::Int64,
                        false,
                    )])),
                    vec![Arc::new(Int64Array::from(vec![1_i64])) as ArrayRef],
                )
                .map_err(UdfError::from)
            }),
        )
        .unwrap();
        assert!(matches!(
            wrong_schema.call(&[]),
            Err(UdfError::Execution { .. })
        ));
    }

    #[test]
    fn closure_table_udf_contains_panics() {
        let udf = ClosureTableUdf::try_new(
            "panic_udtf",
            Schema::new(vec![Field::new("value", DataType::Int64, false)]),
            Arc::new(|_| -> Result<RecordBatch, UdfError> { panic!("boom") }),
        )
        .unwrap();

        assert!(matches!(udf.call(&[]), Err(UdfError::Panic(_))));
    }

    #[test]
    fn sql_body_udtf_without_runtime_returns_typed_error() {
        // Plain #[test] has no Tokio runtime, so `call` must fail with a typed error.
        let udf = SqlBodyTableUdf::try_new(
            "runtime_required",
            Schema::new(vec![Field::new("value", DataType::Int64, false)]),
            "SELECT 1 AS value",
            0,
            Arc::new(datafusion::prelude::SessionContext::new()),
        )
        .unwrap();

        let error = udf
            .call(&[])
            .expect_err("missing Tokio runtime must not panic");
        assert!(matches!(error, UdfError::Execution { .. }));
    }

    #[test]
    fn sql_body_binding_replaces_only_unquoted_placeholders() {
        let sql = "SELECT $1 AS n, '$1' AS literal, \"$2\" AS quoted, /* $2 */ $2 AS text";
        let bound = bind_sql_body_args(
            sql,
            &[
                ScalarValue::Int64(42),
                ScalarValue::Utf8("O'Reilly".to_owned()),
            ],
        )
        .expect("binding should succeed");
        assert_eq!(
            bound,
            "SELECT 42 AS n, '$1' AS literal, \"$2\" AS quoted, /* $2 */ 'O''Reilly' AS text"
        );
    }

    #[test]
    fn sql_body_binding_preserves_comments_and_dollar_quoted_segments() {
        let sql = "SELECT $$body $1$$ AS body, -- $1\n$1 AS value";
        let bound =
            bind_sql_body_args(sql, &[ScalarValue::Boolean(true)]).expect("binding should succeed");
        assert_eq!(bound, "SELECT $$body $1$$ AS body, -- $1\nTRUE AS value");
    }

    #[test]
    fn sql_body_binding_rejects_invalid_placeholders_and_values() {
        let zero = bind_sql_body_args("SELECT $0", &[ScalarValue::Int64(1)])
            .expect_err("$0 must be rejected");
        assert!(zero.to_string().contains("1-based"));

        let missing = bind_sql_body_args("SELECT $2", &[ScalarValue::Int64(1)])
            .expect_err("missing arguments must be rejected");
        assert!(missing.to_string().contains("no matching argument"));

        let binary = bind_sql_body_args("SELECT $1", &[ScalarValue::Bytes(vec![1, 2])])
            .expect_err("binary SQL literals must be rejected");
        assert!(binary.to_string().contains("binary"));
    }

    #[test]
    fn rejects_non_matching_sql() {
        let result = parse_create_function("SELECT 1");
        assert!(result.is_err());
    }

    #[test]
    fn all_supported_types_map() {
        let ddl = parse_create_function(
            "CREATE FUNCTION typed(x INT) RETURNS TABLE (
                a BOOLEAN,
                b TINYINT,
                c SMALLINT,
                d INT,
                e BIGINT,
                f FLOAT,
                g DOUBLE,
                h TEXT,
                i BYTEA,
                j DATE,
                k TIMESTAMP
            )",
        )
        .expect("should parse");
        assert_eq!(ddl.return_columns[0].data_type, DataType::Boolean);
        assert_eq!(ddl.return_columns[1].data_type, DataType::Int8);
        assert_eq!(ddl.return_columns[2].data_type, DataType::Int16);
        assert_eq!(ddl.return_columns[3].data_type, DataType::Int32);
        assert_eq!(ddl.return_columns[4].data_type, DataType::Int64);
        assert_eq!(ddl.return_columns[5].data_type, DataType::Float32);
        assert_eq!(ddl.return_columns[6].data_type, DataType::Float64);
        assert_eq!(ddl.return_columns[7].data_type, DataType::Utf8);
        assert_eq!(ddl.return_columns[8].data_type, DataType::Binary);
        assert_eq!(ddl.return_columns[9].data_type, DataType::Date32);
        assert_eq!(
            ddl.return_columns[10].data_type,
            DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None)
        );
    }
}