Skip to main content

krishiv_sql/
create_function_ddl.rs

1//! Pre-processor for `CREATE FUNCTION … RETURNS TABLE` DDL.
2//!
3//! DataFusion does not natively understand the Krishiv-extended
4//! `CREATE FUNCTION … RETURNS TABLE (col TYPE, …) LANGUAGE … AS '…'` syntax.
5//! This module intercepts such statements before they reach DataFusion and
6//! registers a [`TableUdf`][krishiv_plan::udf::TableUdf] backed by either:
7//!
8//! * A SQL body (`LANGUAGE sql AS '…'`) — executed via the session context.
9//! * A runtime-provided Rust closure — registered via `SqlEngine::register_table_udf_fn`.
10//!
11//! Unsupported DDL languages are rejected before any registry mutation.
12
13use std::panic::{AssertUnwindSafe, catch_unwind};
14use std::sync::Arc;
15
16use arrow::datatypes::{DataType, Schema};
17use arrow::record_batch::RecordBatch;
18use regex::Regex;
19use std::sync::LazyLock;
20
21static CREATE_FUNCTION_RE: LazyLock<Option<Regex>> = LazyLock::new(|| {
22    Regex::new(
23        r"(?is)^\s*CREATE\s+(?:OR\s+REPLACE\s+)?FUNCTION\s+(\w+)\s*\(([^)]*)\)\s*RETURNS\s+TABLE\s*\(([^)]*)\)(?:\s+LANGUAGE\s+(\w+))?(?:\s+AS\s+'((?:[^']|'')*)')?\s*;?\s*$",
24    )
25    .ok()
26});
27
28use krishiv_plan::udf::{ScalarValue, TableUdf, UdfError};
29
30// ────────────────────────────────────────────────────────────────────────────
31// Parsed CREATE FUNCTION descriptor
32// ────────────────────────────────────────────────────────────────────────────
33
34/// A column definition extracted from the `RETURNS TABLE (…)` clause.
35#[derive(Debug, Clone, PartialEq, Eq)]
36pub struct ColumnDef {
37    pub name: String,
38    pub data_type: DataType,
39}
40
41/// A typed function argument extracted from the function signature.
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub struct FunctionArgDef {
44    pub name: String,
45    pub data_type: DataType,
46}
47
48/// Parsed descriptor produced by [`parse_create_function`].
49#[derive(Debug, Clone)]
50pub struct CreateFunctionDdl {
51    /// Function name as written in the SQL statement.
52    pub function_name: String,
53    /// Typed arguments declared in the function signature.
54    pub arguments: Vec<FunctionArgDef>,
55    /// Output columns declared in `RETURNS TABLE (…)`.
56    pub return_columns: Vec<ColumnDef>,
57    /// Language string (e.g. `RUST`, `PYTHON`), lower-cased.
58    pub language: Option<String>,
59    /// Raw function body from the `AS '…'` clause, if any.
60    pub body: Option<String>,
61}
62
63// ────────────────────────────────────────────────────────────────────────────
64// Parsing
65// ────────────────────────────────────────────────────────────────────────────
66
/// Return `true` if `sql` looks like a `CREATE FUNCTION … RETURNS TABLE …`
/// statement.
///
/// The check is case-insensitive and tolerant of arbitrary whitespace
/// (spaces, tabs, newlines, repeated blanks) around and between the keywords,
/// so it accepts the same keyword spacing as the `\s+` separators used by
/// [`parse_create_function`]. Both `CREATE FUNCTION` and
/// `CREATE OR REPLACE FUNCTION` are recognised.
///
/// This is only a cheap pre-filter: a `true` result does not guarantee that
/// the statement parses.
pub fn is_create_function_returns_table(sql: &str) -> bool {
    // Collapse every whitespace run to a single space so that
    // "CREATE\n\tFUNCTION" and "RETURNS  TABLE" match the literal keyword
    // pairs below; `split_whitespace` also drops leading/trailing blanks.
    let normalized = sql
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
        .to_ascii_uppercase();

    let is_create_function = normalized.starts_with("CREATE FUNCTION")
        || normalized.starts_with("CREATE OR REPLACE FUNCTION");

    is_create_function && normalized.contains("RETURNS TABLE")
}
76
77/// Parse a `CREATE FUNCTION … RETURNS TABLE (…)` statement and return a
78/// [`CreateFunctionDdl`] descriptor.
79///
80/// Returns an error string if the statement cannot be recognised.
81pub fn parse_create_function(sql: &str) -> Result<CreateFunctionDdl, String> {
82    // Regex: capture function name, RETURNS TABLE column list, optional
83    // LANGUAGE clause, and optional AS body.
84    //
85    // Pattern (case-insensitive):
86    //   CREATE [OR REPLACE] FUNCTION  <name> ( <args> )
87    //   RETURNS TABLE ( <col_defs> )
88    //   [LANGUAGE <lang>]
89    //   [AS '<body>']
90    let caps = CREATE_FUNCTION_RE
91        .as_ref()
92        .ok_or_else(|| "CREATE FUNCTION regex failed to compile".to_string())?
93        .captures(sql)
94        .ok_or_else(|| "SQL does not match CREATE FUNCTION … RETURNS TABLE pattern".to_string())?;
95
96    let function_name = caps
97        .get(1)
98        .map(|m| m.as_str().to_string())
99        .ok_or("could not extract function name")?;
100
101    let arg_list = caps.get(2).map(|m| m.as_str()).unwrap_or("");
102    let arguments = parse_argument_list(arg_list)?;
103
104    let col_list = caps.get(3).map(|m| m.as_str()).unwrap_or("");
105    let return_columns = parse_column_list(col_list)?;
106
107    let language = caps.get(4).map(|m| m.as_str().to_ascii_lowercase());
108    let body = caps.get(5).map(|m| m.as_str().replace("''", "'"));
109
110    Ok(CreateFunctionDdl {
111        function_name,
112        arguments,
113        return_columns,
114        language,
115        body,
116    })
117}
118
119fn parse_argument_list(list: &str) -> Result<Vec<FunctionArgDef>, String> {
120    parse_named_type_list(list, "argument")?
121        .into_iter()
122        .map(|(name, data_type)| Ok(FunctionArgDef { name, data_type }))
123        .collect()
124}
125
126/// Parse a comma-separated `name TYPE` column list as it appears inside
127/// `RETURNS TABLE (…)`.
128fn parse_column_list(list: &str) -> Result<Vec<ColumnDef>, String> {
129    parse_named_type_list(list, "column")?
130        .into_iter()
131        .map(|(name, data_type)| Ok(ColumnDef { name, data_type }))
132        .collect()
133}
134
135fn parse_named_type_list(list: &str, item_kind: &str) -> Result<Vec<(String, DataType)>, String> {
136    let list = list.trim();
137    if list.is_empty() {
138        return Ok(Vec::new());
139    }
140    let mut parsed = Vec::new();
141    let mut names = std::collections::HashSet::new();
142    for item in list.split(',') {
143        let parts: Vec<&str> = item.split_whitespace().collect();
144        if parts.len() < 2 {
145            return Err(format!("invalid {item_kind} definition: '{item}'"));
146        }
147        let name = parts.first().copied().unwrap_or("").to_string();
148        if !names.insert(name.to_ascii_lowercase()) {
149            return Err(format!("duplicate {item_kind} name '{name}'"));
150        }
151        let type_str = parts.get(1..).unwrap_or(&[]).join(" ");
152        let data_type = sql_type_to_arrow(&type_str)?;
153        parsed.push((name, data_type));
154    }
155    Ok(parsed)
156}
157
158/// Map a SQL type keyword (as used in DDL) to an Arrow [`DataType`].
159///
160/// Only the types commonly seen in `RETURNS TABLE` declarations are mapped.
161/// Unknown types fall back to `DataType::Utf8`.
162fn sql_type_to_arrow(type_str: &str) -> Result<DataType, String> {
163    match type_str.trim().to_ascii_uppercase().as_str() {
164        "BOOLEAN" | "BOOL" => Ok(DataType::Boolean),
165        "TINYINT" | "INT8" => Ok(DataType::Int8),
166        "SMALLINT" | "INT16" => Ok(DataType::Int16),
167        "INT" | "INTEGER" | "INT32" => Ok(DataType::Int32),
168        "BIGINT" | "INT64" | "LONG" => Ok(DataType::Int64),
169        "FLOAT" | "FLOAT32" | "REAL" => Ok(DataType::Float32),
170        "DOUBLE" | "FLOAT64" | "DOUBLE PRECISION" => Ok(DataType::Float64),
171        "TEXT" | "VARCHAR" | "STRING" | "CHARACTER VARYING" => Ok(DataType::Utf8),
172        "BYTEA" | "BYTES" | "BINARY" | "BLOB" => Ok(DataType::Binary),
173        "DATE" => Ok(DataType::Date32),
174        "TIMESTAMP" | "DATETIME" => Ok(DataType::Timestamp(
175            arrow::datatypes::TimeUnit::Microsecond,
176            None,
177        )),
178        _ => Err(format!(
179            "unsupported SQL type '{type_str}' in CREATE FUNCTION DDL"
180        )),
181    }
182}
183
184// ────────────────────────────────────────────────────────────────────────────
185// UDTF implementations
186// ────────────────────────────────────────────────────────────────────────────
187
188/// Body-function type alias for runtime-registered UDTFs.
189pub type UdtfBodyFn = Arc<dyn Fn(&[ScalarValue]) -> Result<RecordBatch, UdfError> + Send + Sync>;
190
191/// A [`TableUdf`] backed by a runtime-provided Rust closure.
192#[derive(Clone)]
193pub struct ClosureTableUdf {
194    pub(crate) name: String,
195    pub(crate) schema: Schema,
196    body_fn: UdtfBodyFn,
197}
198
199impl std::fmt::Debug for ClosureTableUdf {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        f.debug_struct("ClosureTableUdf")
202            .field("name", &self.name)
203            .field("schema", &self.schema)
204            .finish()
205    }
206}
207
208impl ClosureTableUdf {
209    /// Create a closure-backed UDTF with an explicit output schema.
210    pub fn try_new(
211        name: impl Into<String>,
212        schema: Schema,
213        body_fn: UdtfBodyFn,
214    ) -> Result<Self, UdfError> {
215        let name = name.into();
216        validate_udtf_definition(&name, &schema)?;
217        Ok(Self {
218            name,
219            schema,
220            body_fn,
221        })
222    }
223}
224
225impl TableUdf for ClosureTableUdf {
226    fn name(&self) -> &str {
227        &self.name
228    }
229
230    fn output_schema(&self) -> &Schema {
231        &self.schema
232    }
233
234    fn call(&self, args: &[ScalarValue]) -> Result<RecordBatch, UdfError> {
235        let batch =
236            catch_unwind(AssertUnwindSafe(|| (self.body_fn)(args))).map_err(|payload| {
237                let message = payload
238                    .downcast_ref::<&str>()
239                    .copied()
240                    .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
241                    .unwrap_or("unknown panic");
242                UdfError::Panic(format!("UDTF '{}': {message}", self.name))
243            })??;
244        if !schema_contract_matches(batch.schema().as_ref(), &self.schema) {
245            return Err(UdfError::Execution {
246                message: format!(
247                    "UDTF '{}' returned schema {:?}, expected {:?}",
248                    self.name,
249                    batch.schema(),
250                    self.schema
251                ),
252            });
253        }
254        Ok(batch)
255    }
256}
257
258fn validate_udtf_definition(name: &str, schema: &Schema) -> Result<(), UdfError> {
259    if name.trim().is_empty() {
260        return Err(UdfError::InvalidArgument {
261            message: String::from("UDTF name must not be empty"),
262        });
263    }
264    if schema.fields().is_empty() {
265        return Err(UdfError::InvalidArgument {
266            message: format!("UDTF '{name}' must declare at least one output column"),
267        });
268    }
269    let mut names = std::collections::HashSet::with_capacity(schema.fields().len());
270    for field in schema.fields() {
271        if field.name().trim().is_empty() {
272            return Err(UdfError::InvalidArgument {
273                message: format!("UDTF '{name}' contains an empty output column name"),
274            });
275        }
276        if !names.insert(field.name()) {
277            return Err(UdfError::InvalidArgument {
278                message: format!(
279                    "UDTF '{name}' contains duplicate output column '{}'",
280                    field.name()
281                ),
282            });
283        }
284    }
285    Ok(())
286}
287
288fn schema_contract_matches(actual: &Schema, expected: &Schema) -> bool {
289    actual.fields().len() == expected.fields().len()
290        && actual
291            .fields()
292            .iter()
293            .zip(expected.fields())
294            .all(|(actual, expected)| {
295                actual.name() == expected.name() && actual.data_type() == expected.data_type()
296            })
297}
298
299/// A [`TableUdf`] whose body is a SQL query executed via a DataFusion session.
300///
301/// Created by `SqlEngine` when `CREATE FUNCTION … LANGUAGE sql AS '…'` is
302/// processed.  Uses `block_in_place` so the sync `TableFunctionImpl::call()`
303/// can safely block on async SQL execution without deadlocking the runtime.
304#[derive(Clone)]
305pub struct SqlBodyTableUdf {
306    pub(crate) name: String,
307    pub(crate) schema: Schema,
308    body_sql: String,
309    argument_count: usize,
310    ctx: Arc<datafusion::prelude::SessionContext>,
311}
312
313impl std::fmt::Debug for SqlBodyTableUdf {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        f.debug_struct("SqlBodyTableUdf")
316            .field("name", &self.name)
317            .field("body_sql", &self.body_sql)
318            .finish()
319    }
320}
321
322impl SqlBodyTableUdf {
323    pub fn try_new(
324        name: impl Into<String>,
325        schema: Schema,
326        body_sql: impl Into<String>,
327        argument_count: usize,
328        ctx: Arc<datafusion::prelude::SessionContext>,
329    ) -> Result<Self, UdfError> {
330        let name = name.into();
331        validate_udtf_definition(&name, &schema)?;
332        let body_sql = body_sql.into();
333        if body_sql.trim().is_empty() {
334            return Err(UdfError::InvalidArgument {
335                message: format!("SQL UDTF '{name}' body must not be empty"),
336            });
337        }
338        let placeholder_args = vec![ScalarValue::Null; argument_count];
339        bind_sql_body_args(&body_sql, &placeholder_args)?;
340        Ok(Self {
341            name,
342            schema,
343            body_sql,
344            argument_count,
345            ctx,
346        })
347    }
348}
349
350impl TableUdf for SqlBodyTableUdf {
351    fn name(&self) -> &str {
352        &self.name
353    }
354
355    fn output_schema(&self) -> &Schema {
356        &self.schema
357    }
358
359    fn call(&self, args: &[ScalarValue]) -> Result<RecordBatch, UdfError> {
360        if args.len() != self.argument_count {
361            return Err(UdfError::InvalidArgument {
362                message: format!(
363                    "UDTF '{}' expects {} arguments, got {}",
364                    self.name,
365                    self.argument_count,
366                    args.len()
367                ),
368            });
369        }
370
371        // Execute the SQL body synchronously using block_in_place so this
372        // sync call-site can safely await without deadlocking the executor.
373        let ctx = Arc::clone(&self.ctx);
374        let sql = bind_sql_body_args(&self.body_sql, args)?;
375        let schema = Arc::new(self.schema.clone());
376        let handle =
377            tokio::runtime::Handle::try_current().map_err(|error| UdfError::Execution {
378                message: format!(
379                    "SQL UDTF '{}' requires an active Tokio runtime: {error}",
380                    self.name
381                ),
382            })?;
383        if !matches!(
384            handle.runtime_flavor(),
385            tokio::runtime::RuntimeFlavor::MultiThread
386        ) {
387            return Err(UdfError::Execution {
388                message: format!(
389                    "SQL UDTF '{}' requires a multi-thread Tokio runtime",
390                    self.name
391                ),
392            });
393        }
394        catch_unwind(AssertUnwindSafe(|| {
395            tokio::task::block_in_place(|| {
396                handle.block_on(async {
397                    let df = ctx.sql(&sql).await.map_err(|e| UdfError::Execution {
398                        message: e.to_string(),
399                    })?;
400                    let batches = df.collect().await.map_err(|e| UdfError::Execution {
401                        message: e.to_string(),
402                    })?;
403                    if batches.is_empty() {
404                        return Ok(RecordBatch::new_empty(schema));
405                    }
406                    let batch = arrow::compute::concat_batches(
407                        &batches
408                            .first()
409                            .ok_or_else(|| UdfError::Execution {
410                                message: "empty batch list".into(),
411                            })?
412                            .schema(),
413                        &batches,
414                    )
415                    .map_err(|e| UdfError::Arrow(e.to_string()))?;
416                    if !schema_contract_matches(batch.schema().as_ref(), schema.as_ref()) {
417                        return Err(UdfError::Execution {
418                            message: format!(
419                                "SQL UDTF '{}' returned schema {:?}, expected {:?}",
420                                self.name,
421                                batch.schema(),
422                                schema
423                            ),
424                        });
425                    }
426                    Ok(batch)
427                })
428            })
429        }))
430        .map_err(|payload| {
431            let message = payload
432                .downcast_ref::<&str>()
433                .copied()
434                .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
435                .unwrap_or("unknown panic");
436            UdfError::Panic(format!("SQL UDTF '{}': {message}", self.name))
437        })?
438    }
439}
440
441fn bind_sql_body_args(sql: &str, args: &[ScalarValue]) -> Result<String, UdfError> {
442    let bytes = sql.as_bytes();
443    let mut output = String::with_capacity(sql.len());
444    let mut index = 0;
445
446    while index < bytes.len() {
447        let Some(&byte) = bytes.get(index) else {
448            break;
449        };
450        match byte {
451            b'\'' | b'"' | b'`' => {
452                index = copy_quoted_segment(sql, index, byte, &mut output)?;
453            }
454            b'-' if bytes.get(index + 1) == Some(&b'-') => {
455                let end = sql[index..]
456                    .find('\n')
457                    .map_or(bytes.len(), |offset| index + offset + 1);
458                output.push_str(&sql[index..end]);
459                index = end;
460            }
461            b'/' if bytes.get(index + 1) == Some(&b'*') => {
462                index = copy_block_comment(sql, index, &mut output)?;
463            }
464            b'$' => {
465                if let Some((delimiter, end)) = dollar_quote_delimiter(sql, index) {
466                    let body_start = end;
467                    let close_offset = sql[body_start..].find(delimiter).ok_or_else(|| {
468                        UdfError::InvalidArgument {
469                            message: "unterminated dollar-quoted SQL body".to_owned(),
470                        }
471                    })?;
472                    let segment_end = body_start + close_offset + delimiter.len();
473                    output.push_str(&sql[index..segment_end]);
474                    index = segment_end;
475                    continue;
476                }
477
478                let digit_start = index + 1;
479                let mut end = digit_start;
480                while bytes.get(end).is_some_and(u8::is_ascii_digit) {
481                    end += 1;
482                }
483                if end == digit_start {
484                    output.push('$');
485                    index += 1;
486                    continue;
487                }
488
489                let placeholder = sql[digit_start..end].parse::<usize>().map_err(|error| {
490                    UdfError::InvalidArgument {
491                        message: format!(
492                            "invalid SQL UDTF placeholder '{}': {error}",
493                            &sql[index..end]
494                        ),
495                    }
496                })?;
497                if placeholder == 0 {
498                    return Err(UdfError::InvalidArgument {
499                        message: "SQL UDTF placeholders are 1-based; $0 is invalid".to_owned(),
500                    });
501                }
502                let value = args.get(placeholder - 1).ok_or_else(|| UdfError::InvalidArgument {
503                    message: format!(
504                        "SQL UDTF placeholder ${placeholder} has no matching argument; got {} arguments",
505                        args.len()
506                    ),
507                })?;
508                output.push_str(&scalar_to_sql_literal(value)?);
509                index = end;
510            }
511            _ => {
512                let ch = sql[index..]
513                    .chars()
514                    .next()
515                    .ok_or_else(|| UdfError::InvalidArgument {
516                        message: "unexpected end of SQL string".to_owned(),
517                    })?;
518                output.push(ch);
519                index += ch.len_utf8();
520            }
521        }
522    }
523
524    Ok(output)
525}
526
527fn copy_quoted_segment(
528    sql: &str,
529    start: usize,
530    quote: u8,
531    output: &mut String,
532) -> Result<usize, UdfError> {
533    let bytes = sql.as_bytes();
534    let mut index = start + 1;
535    while index < bytes.len() {
536        let Some(&b) = bytes.get(index) else {
537            break;
538        };
539        if b == quote {
540            index += 1;
541            if bytes.get(index) == Some(&quote) {
542                index += 1;
543                continue;
544            }
545            output.push_str(&sql[start..index]);
546            return Ok(index);
547        }
548        let ch = sql[index..]
549            .chars()
550            .next()
551            .ok_or_else(|| UdfError::InvalidArgument {
552                message: "unexpected end of SQL string".to_owned(),
553            })?;
554        index += ch.len_utf8();
555    }
556    Err(UdfError::InvalidArgument {
557        message: "unterminated quoted SQL segment".to_owned(),
558    })
559}
560
561fn copy_block_comment(sql: &str, start: usize, output: &mut String) -> Result<usize, UdfError> {
562    let bytes = sql.as_bytes();
563    let mut index = start + 2;
564    let mut depth = 1usize;
565    while index < bytes.len() {
566        if bytes.get(index) == Some(&b'/') && bytes.get(index + 1) == Some(&b'*') {
567            depth += 1;
568            index += 2;
569        } else if bytes.get(index) == Some(&b'*') && bytes.get(index + 1) == Some(&b'/') {
570            depth -= 1;
571            index += 2;
572            if depth == 0 {
573                output.push_str(&sql[start..index]);
574                return Ok(index);
575            }
576        } else {
577            let ch = sql[index..]
578                .chars()
579                .next()
580                .ok_or_else(|| UdfError::InvalidArgument {
581                    message: "unexpected end of SQL string".to_owned(),
582                })?;
583            index += ch.len_utf8();
584        }
585    }
586    Err(UdfError::InvalidArgument {
587        message: "unterminated SQL block comment".to_owned(),
588    })
589}
590
/// Recognise a dollar-quote opening delimiter at byte offset `start`.
///
/// Accepted forms are `$$` and `$tag$`, where `tag` starts with an ASCII
/// letter or `_` and continues with ASCII letters, digits or `_`. On
/// success returns the delimiter text (both `$` included) and the byte
/// offset just past it. Returns `None` for anything else, including `$1`
/// style placeholders and a tag with no closing `$`.
fn dollar_quote_delimiter(sql: &str, start: usize) -> Option<(&str, usize)> {
    let bytes = sql.as_bytes();
    if bytes.get(start).copied() != Some(b'$') {
        return None;
    }

    let tag_start = start + 1;
    match bytes.get(tag_start).copied()? {
        // Empty tag: `$$`.
        b'$' => return Some((&sql[start..=tag_start], tag_start + 1)),
        first if first.is_ascii_alphabetic() || first == b'_' => {}
        _ => return None,
    }

    // Extend over the rest of the tag, then require the closing `$`.
    let tag_tail = bytes.get(tag_start + 1..).unwrap_or(&[]);
    let closing = tag_start
        + 1
        + tag_tail
            .iter()
            .take_while(|byte| byte.is_ascii_alphanumeric() || **byte == b'_')
            .count();

    (bytes.get(closing).copied() == Some(b'$')).then(|| (&sql[start..=closing], closing + 1))
}
617
618fn scalar_to_sql_literal(value: &ScalarValue) -> Result<String, UdfError> {
619    match value {
620        ScalarValue::Null => Ok("NULL".to_owned()),
621        ScalarValue::Int64(value) => Ok(value.to_string()),
622        ScalarValue::Float64(value) if value.is_finite() => Ok(value.to_string()),
623        ScalarValue::Float64(value) => Err(UdfError::InvalidArgument {
624            message: format!("non-finite floating-point UDTF argument {value} is not supported"),
625        }),
626        ScalarValue::Utf8(value) => Ok(format!("'{}'", value.replace('\'', "''"))),
627        ScalarValue::Boolean(value) => Ok(if *value { "TRUE" } else { "FALSE" }.to_owned()),
628        ScalarValue::Bytes(_) => Err(UdfError::InvalidArgument {
629            message: "binary UDTF arguments are not supported in SQL bodies".to_owned(),
630        }),
631    }
632}
633
634// ────────────────────────────────────────────────────────────────────────────
635// Tests
636// ────────────────────────────────────────────────────────────────────────────
637
#[cfg(test)]
// Tests are expected to abort loudly on unexpected failures.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int64Array};
    use arrow::datatypes::{DataType, Field};

    /// Statement with every optional clause present, shared by the parser
    /// tests.
    const BASIC_DDL: &str = "
        CREATE FUNCTION my_udtf(arg1 INT)
        RETURNS TABLE (col1 TEXT, col2 BIGINT)
        LANGUAGE RUST
        AS 'fn my_udtf(arg1: i64) -> Vec<Row> { vec![] }'
    ";

    /// Parse [`BASIC_DDL`], failing the test if it is rejected.
    fn parse_basic() -> CreateFunctionDdl {
        parse_create_function(BASIC_DDL).expect("should parse")
    }

    /// Schema with a single non-nullable `Int64` column.
    fn int64_schema(column: &str) -> Schema {
        Schema::new(vec![Field::new(column, DataType::Int64, false)])
    }

    /// One-column `Int64` batch carrying `values` under `schema`.
    fn int64_batch(schema: Arc<Schema>, values: Vec<i64>) -> Result<RecordBatch, UdfError> {
        let column: ArrayRef = Arc::new(Int64Array::from(values));
        RecordBatch::try_new(schema, vec![column]).map_err(UdfError::from)
    }

    #[test]
    fn detects_create_function_returns_table() {
        // Full statement and the CREATE OR REPLACE variant.
        let accepted = [
            BASIC_DDL,
            "CREATE OR REPLACE FUNCTION g(x INT) RETURNS TABLE (v TEXT)",
        ];
        // A plain SELECT, and a function that returns a scalar, not a table.
        let rejected = [
            "SELECT 1",
            "CREATE FUNCTION f(x INT) RETURNS INT LANGUAGE SQL AS 'SELECT x'",
        ];

        for sql in accepted {
            assert!(is_create_function_returns_table(sql));
        }
        for sql in rejected {
            assert!(!is_create_function_returns_table(sql));
        }
    }

    #[test]
    fn parses_function_name() {
        assert_eq!(parse_basic().function_name, "my_udtf");
    }

    #[test]
    fn parses_typed_arguments() {
        let ddl = parse_create_function(
            "CREATE FUNCTION typed_args(count BIGINT, label TEXT, enabled BOOLEAN) \
             RETURNS TABLE (value TEXT) LANGUAGE SQL AS 'SELECT $2 AS value'",
        )
        .expect("should parse");

        let expected: Vec<FunctionArgDef> = [
            ("count", DataType::Int64),
            ("label", DataType::Utf8),
            ("enabled", DataType::Boolean),
        ]
        .into_iter()
        .map(|(name, data_type)| FunctionArgDef {
            name: name.to_owned(),
            data_type,
        })
        .collect();
        assert_eq!(ddl.arguments, expected);
    }

    #[test]
    fn parses_return_columns() {
        let columns = parse_basic().return_columns;
        assert_eq!(columns.len(), 2);

        let (first, second) = (&columns[0], &columns[1]);
        assert_eq!(first.name, "col1");
        assert_eq!(first.data_type, DataType::Utf8);
        assert_eq!(second.name, "col2");
        assert_eq!(second.data_type, DataType::Int64);
    }

    #[test]
    fn parses_language_and_body() {
        let ddl = parse_basic();
        assert_eq!(ddl.language.as_deref(), Some("rust"));
        assert!(ddl.body.is_some());
    }

    #[test]
    fn parses_without_language_and_body() {
        let ddl = parse_create_function("CREATE FUNCTION simple(x INT) RETURNS TABLE (val BIGINT)")
            .expect("should parse");
        assert_eq!(ddl.function_name, "simple");
        assert_eq!(ddl.return_columns.len(), 1);
        assert_eq!(ddl.language, None);
        assert_eq!(ddl.body, None);
    }

    #[test]
    fn parses_or_replace_variant() {
        let ddl = parse_create_function(
            "CREATE OR REPLACE FUNCTION f(x INT) RETURNS TABLE (a TEXT, b INT)",
        )
        .expect("should parse");
        assert_eq!(ddl.function_name, "f");
        assert_eq!(ddl.return_columns.len(), 2);
    }

    #[test]
    fn parser_rejects_trailing_unparsed_sql() {
        let sql = format!("{BASIC_DDL} SELECT 1");
        let error = parse_create_function(&sql).expect_err("trailing SQL must not be ignored");
        assert!(error.contains("does not match"));
    }

    #[test]
    fn parser_rejects_duplicate_argument_and_output_names() {
        let duplicate_arg = parse_create_function(
            "CREATE FUNCTION f(value INT, VALUE BIGINT) \
             RETURNS TABLE (result BIGINT) LANGUAGE SQL AS 'SELECT 1 AS result'",
        )
        .expect_err("argument names are case-insensitively unique");
        assert!(duplicate_arg.contains("duplicate argument"));

        let duplicate_output = parse_create_function(
            "CREATE FUNCTION f() RETURNS TABLE (value INT, VALUE BIGINT) \
             LANGUAGE SQL AS 'SELECT 1 AS value, 2 AS VALUE'",
        )
        .expect_err("output names are case-insensitively unique");
        assert!(duplicate_output.contains("duplicate column"));
    }

    #[test]
    fn closure_table_udf_executes_and_validates_output_schema() {
        // Closure that honours the declared schema.
        let schema = int64_schema("value");
        let shared = Arc::new(schema.clone());
        let conforming: UdtfBodyFn =
            Arc::new(move |_: &[ScalarValue]| int64_batch(Arc::clone(&shared), vec![1, 2]));
        let udf = ClosureTableUdf::try_new("values", schema, conforming).unwrap();
        assert_eq!(udf.call(&[]).unwrap().num_rows(), 2);

        // Closure whose batch uses a different column name than declared.
        let mismatched: UdtfBodyFn = Arc::new(|_: &[ScalarValue]| {
            int64_batch(Arc::new(int64_schema("actual")), vec![1])
        });
        let wrong_schema =
            ClosureTableUdf::try_new("wrong", int64_schema("expected"), mismatched).unwrap();
        let outcome = wrong_schema.call(&[]);
        assert!(matches!(outcome, Err(UdfError::Execution { .. })));
    }

    #[test]
    fn closure_table_udf_contains_panics() {
        let exploding: UdtfBodyFn =
            Arc::new(|_: &[ScalarValue]| -> Result<RecordBatch, UdfError> { panic!("boom") });
        let udf =
            ClosureTableUdf::try_new("panic_udtf", int64_schema("value"), exploding).unwrap();

        let outcome = udf.call(&[]);
        assert!(matches!(outcome, Err(UdfError::Panic(_))));
    }

    #[test]
    fn sql_body_udtf_without_runtime_returns_typed_error() {
        // A plain `#[test]` runs outside any Tokio runtime.
        let session = Arc::new(datafusion::prelude::SessionContext::new());
        let udf = SqlBodyTableUdf::try_new(
            "runtime_required",
            int64_schema("value"),
            "SELECT 1 AS value",
            0,
            session,
        )
        .unwrap();

        let error = udf
            .call(&[])
            .expect_err("missing Tokio runtime must not panic");
        assert!(matches!(error, UdfError::Execution { .. }));
    }

    #[test]
    fn sql_body_binding_replaces_only_unquoted_placeholders() {
        let args = [
            ScalarValue::Int64(42),
            ScalarValue::Utf8("O'Reilly".to_owned()),
        ];
        let bound = bind_sql_body_args(
            "SELECT $1 AS n, '$1' AS literal, \"$2\" AS quoted, /* $2 */ $2 AS text",
            &args,
        )
        .expect("binding should succeed");
        assert_eq!(
            bound,
            "SELECT 42 AS n, '$1' AS literal, \"$2\" AS quoted, /* $2 */ 'O''Reilly' AS text"
        );
    }

    #[test]
    fn sql_body_binding_preserves_comments_and_dollar_quoted_segments() {
        let bound = bind_sql_body_args(
            "SELECT $$body $1$$ AS body, -- $1\n$1 AS value",
            &[ScalarValue::Boolean(true)],
        )
        .expect("binding should succeed");
        assert_eq!(bound, "SELECT $$body $1$$ AS body, -- $1\nTRUE AS value");
    }

    #[test]
    fn sql_body_binding_rejects_invalid_placeholders_and_values() {
        let one_int = [ScalarValue::Int64(1)];

        let zero = bind_sql_body_args("SELECT $0", &one_int).expect_err("$0 must be rejected");
        assert!(zero.to_string().contains("1-based"));

        let missing = bind_sql_body_args("SELECT $2", &one_int)
            .expect_err("missing arguments must be rejected");
        assert!(missing.to_string().contains("no matching argument"));

        let binary = bind_sql_body_args("SELECT $1", &[ScalarValue::Bytes(vec![1, 2])])
            .expect_err("binary SQL literals must be rejected");
        assert!(binary.to_string().contains("binary"));
    }

    #[test]
    fn rejects_non_matching_sql() {
        assert!(parse_create_function("SELECT 1").is_err());
    }

    #[test]
    fn all_supported_types_map() {
        let ddl = parse_create_function(
            "CREATE FUNCTION typed(x INT) RETURNS TABLE (
                a BOOLEAN,
                b TINYINT,
                c SMALLINT,
                d INT,
                e BIGINT,
                f FLOAT,
                g DOUBLE,
                h TEXT,
                i BYTEA,
                j DATE,
                k TIMESTAMP
            )",
        )
        .expect("should parse");

        // Expected Arrow types, in the declaration order above.
        let expected = vec![
            DataType::Boolean,
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::Float32,
            DataType::Float64,
            DataType::Utf8,
            DataType::Binary,
            DataType::Date32,
            DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None),
        ];
        let actual: Vec<DataType> = ddl
            .return_columns
            .iter()
            .map(|column| column.data_type.clone())
            .collect();
        assert_eq!(actual, expected);
    }
}