#![allow(dead_code)]
use crate::expressions::{Expression, Scalar};
use crate::schema::{DataType, PrimitiveType};
use crate::{DeltaResult, Error};
pub(crate) fn parse_sql(sql: &str, data_type: &DataType) -> DeltaResult<Expression> {
let trimmed = sql.trim();
if trimmed.is_empty() {
return Err(Error::generic("empty SQL literal"));
}
if trimmed.eq_ignore_ascii_case("null") {
return Ok(Expression::literal(Scalar::Null(data_type.clone())));
}
parse_literal(trimmed, data_type, sql)
}
/// Dispatches a trimmed, non-NULL literal to the parser for the primitive `data_type`.
///
/// `sql` is the untrimmed original input, forwarded so the parsers can quote it in error messages.
/// String, binary, date, timestamp and floating-point types have dedicated parsers. All other
/// primitives (integers, booleans, decimals, ...) go through `PrimitiveType::parse_scalar`.
fn parse_literal(trimmed: &str, data_type: &DataType, sql: &str) -> DeltaResult<Expression> {
    let primitive = match data_type {
        DataType::Primitive(inner) => inner,
        _ => {
            return Err(Error::generic(format!(
                "SQL literal parsing only supports primitive types, got {data_type:?}"
            )))
        }
    };
    let scalar = match primitive {
        PrimitiveType::String => parse_string_literal(trimmed),
        PrimitiveType::Binary => parse_binary_literal(trimmed),
        PrimitiveType::Date => parse_date_literal(trimmed, sql),
        PrimitiveType::TimestampNtz => parse_timestamp_ntz_literal(trimmed, sql),
        PrimitiveType::Timestamp => parse_timestamp_ltz_literal(trimmed, sql),
        PrimitiveType::Float | PrimitiveType::Double => {
            parse_double_or_float(primitive, trimmed, sql)
        }
        other => other.parse_scalar(trimmed),
    }?;
    Ok(Expression::literal(scalar))
}
/// Parses a single-quoted SQL string literal into a `Scalar::String`.
fn parse_string_literal(trimmed: &str) -> DeltaResult<Scalar> {
    unquote_string(trimmed).map(Scalar::String)
}
/// Parses an `X'..'` hex literal into a `Scalar::Binary`.
fn parse_binary_literal(trimmed: &str) -> DeltaResult<Scalar> {
    decode_binary_literal(trimmed).map(Scalar::Binary)
}
/// Parses `'yyyy-mm-dd'` or `DATE 'yyyy-mm-dd'` (keyword in any ASCII case) into a date scalar.
fn parse_date_literal(trimmed: &str, sql: &str) -> DeltaResult<Scalar> {
    const DATE_KEYWORDS: &[&str] = &["DATE"];
    let target = PrimitiveType::Date;
    let body = unwrap_quoted_body(trimmed, DATE_KEYWORDS, &target, sql)?;
    target.parse_scalar(&body)
}
/// Parses a quoted timestamp, optionally prefixed with `TIMESTAMP_NTZ`, into a timestamp-ntz
/// scalar. Only the `TIMESTAMP_NTZ` keyword is accepted here; a bare `TIMESTAMP` prefix is not.
fn parse_timestamp_ntz_literal(trimmed: &str, sql: &str) -> DeltaResult<Scalar> {
    const NTZ_KEYWORDS: &[&str] = &["TIMESTAMP_NTZ"];
    let target = PrimitiveType::TimestampNtz;
    let body = unwrap_quoted_body(trimmed, NTZ_KEYWORDS, &target, sql)?;
    target.parse_scalar(&body)
}
/// Parses a quoted timestamp, optionally prefixed with `TIMESTAMP` or `TIMESTAMP_LTZ`, into a
/// timestamp scalar. The body must pass `require_utc_z_suffix` (an explicit uppercase `Z`) before
/// it is handed to `PrimitiveType::parse_scalar`.
fn parse_timestamp_ltz_literal(trimmed: &str, sql: &str) -> DeltaResult<Scalar> {
    // Keywords are tried in this order; `TIMESTAMP` cannot swallow `TIMESTAMP_LTZ` because the
    // prefix must be followed by a quote or whitespace to count as a match.
    const LTZ_KEYWORDS: &[&str] = &["TIMESTAMP", "TIMESTAMP_LTZ"];
    let target = PrimitiveType::Timestamp;
    let body = unwrap_quoted_body(trimmed, LTZ_KEYWORDS, &target, sql)?;
    require_utc_z_suffix(&body, sql).and_then(|()| target.parse_scalar(&body))
}
fn unwrap_quoted_body(
trimmed: &str,
keywords: &[&str],
primitive: &PrimitiveType,
sql: &str,
) -> DeltaResult<String> {
let body = strip_typed_prefix_and_unquote(trimmed, keywords)?;
let body = body.trim();
if body.is_empty() {
return Err(Error::generic(format!(
"empty {primitive:?} literal: {sql}"
)));
}
Ok(body.to_string())
}
fn require_utc_z_suffix(raw: &str, sql: &str) -> DeltaResult<()> {
if raw.contains(['t', 'z']) {
return Err(Error::generic(
"TIMESTAMP literal must use uppercase 'T' and or 'Z'",
));
}
if raw.ends_with('Z') {
return Ok(());
}
let has_offset = raw
.split_once(['T', ' '])
.is_some_and(|(_, time)| time.contains(['+', '-']));
Err(if has_offset {
Error::generic(format!(
"TIMESTAMP literal with an explicit offset is not yet supported; use 'Z' (UTC): {sql}"
))
} else {
Error::generic(
"zoneless TIMESTAMP literal is not yet supported; use an explicit 'Z' (UTC) suffix",
)
})
}
fn parse_double_or_float(primitive: &PrimitiveType, raw: &str, sql: &str) -> DeltaResult<Scalar> {
let has_exponent = raw.contains(['e', 'E']);
if !has_exponent && exceeds_decimal_precision(raw) {
return Err(Error::generic(format!(
"numeric literal exceeds maximum DECIMAL precision 38: {sql}"
)));
}
let scalar = if *primitive == PrimitiveType::Float && has_exponent {
let value: f64 = raw
.parse()
.map_err(|_| Error::generic(format!("invalid FLOAT literal: {sql}")))?;
Scalar::Float(value as f32)
} else {
primitive.parse_scalar(raw)?
};
let normalize_neg_zero = !has_exponent;
let non_finite_error = || Error::generic("non-finite float literals are not supported");
Ok(match scalar {
Scalar::Float(f) if !f.is_finite() => return Err(non_finite_error()),
Scalar::Double(d) if !d.is_finite() => return Err(non_finite_error()),
Scalar::Float(f) if normalize_neg_zero => Scalar::Float(f + 0.0),
Scalar::Double(d) if normalize_neg_zero => Scalar::Double(d + 0.0),
other => other,
})
}
/// Returns `true` when `raw`, read as a plain decimal literal, needs more than 38 digits.
///
/// Precision is `max(significant digits, scale)`:
/// - significant digits are all ASCII digits after leading zeros are dropped;
/// - scale is the number of digits after the first `.`.
///
/// A single leading `+`/`-` is ignored. Non-digit characters are simply not counted.
fn exceeds_decimal_precision(raw: &str) -> bool {
    const MAX_DECIMAL_PRECISION: usize = 38;
    let unsigned = raw.strip_prefix(['+', '-']).unwrap_or(raw);
    let (whole, fraction) = unsigned.split_once('.').unwrap_or((unsigned, ""));
    let scale = fraction.bytes().filter(u8::is_ascii_digit).count();
    let significant_digits = whole
        .chars()
        .chain(fraction.chars())
        .filter(char::is_ascii_digit)
        .skip_while(|&digit| digit == '0')
        .count();
    significant_digits.max(scale) > MAX_DECIMAL_PRECISION
}
/// Removes the surrounding single quotes from a SQL string literal and collapses each doubled
/// quote (`''`) into one `'`.
///
/// # Errors
///
/// Returns an error if `input`:
/// - does not start with `'`;
/// - contains a backslash before the closing quote;
/// - has characters after the closing quote;
/// - is never closed.
fn unquote_string(input: &str) -> DeltaResult<String> {
    let Some(body) = input.strip_prefix('\'') else {
        return Err(Error::generic(format!(
            "expected a single-quoted SQL string, got: {input}"
        )));
    };
    let mut unquoted = String::with_capacity(body.len());
    let mut remaining = body.chars();
    loop {
        match remaining.next() {
            None => {
                return Err(Error::generic(format!(
                    "unterminated SQL string literal: {input}"
                )))
            }
            Some('\\') => {
                return Err(Error::generic(format!(
                    "backslash escapes in SQL string literals are not yet supported: {input}"
                )))
            }
            // A quote either starts an escaped `''` pair or closes the literal. A closing quote
            // is only valid as the very last character.
            Some('\'') => match remaining.next() {
                None => return Ok(unquoted),
                Some('\'') => unquoted.push('\''),
                Some(_) => {
                    return Err(Error::generic(format!(
                        "unexpected characters after closing quote in SQL string literal: {input}"
                    )))
                }
            },
            Some(ordinary) => unquoted.push(ordinary),
        }
    }
}
/// Drops an optional leading type keyword and then unquotes the remaining SQL string.
///
/// Each entry of `keywords` is tried in order. A keyword matches when `input` starts with it
/// (ASCII case-insensitive) and the keyword is immediately followed by a `'` or whitespace. The
/// first match wins, and the keyword plus any whitespace after it are removed. If nothing
/// matches, the whole `input` is unquoted.
fn strip_typed_prefix_and_unquote(input: &str, keywords: &[&str]) -> DeltaResult<String> {
    let mut quoted = input;
    for keyword in keywords {
        let Some(head) = input.get(..keyword.len()) else {
            continue;
        };
        let tail = &input[keyword.len()..];
        // Requiring a separator rejects look-alikes such as `DATEX'...'`.
        let separated = tail.starts_with('\'') || tail.starts_with(char::is_whitespace);
        if separated && head.eq_ignore_ascii_case(keyword) {
            quoted = tail.trim_start();
            break;
        }
    }
    unquote_string(quoted)
}
fn decode_binary_literal(input: &str) -> DeltaResult<Vec<u8>> {
let err = || {
Error::generic(format!(
"expected a SQL binary literal like X'..', got: {input}"
))
};
let hex = input
.strip_prefix(['x', 'X'])
.and_then(|rest| rest.strip_prefix('\''))
.and_then(|rest| rest.strip_suffix('\''))
.ok_or_else(err)?;
if !hex.len().is_multiple_of(2) {
return Err(Error::generic(format!(
"binary literal must contain an even number of hex digits: {input}"
)));
}
hex.as_bytes()
.chunks_exact(2)
.map(|pair| {
let hi = (pair[0] as char)
.to_digit(16)
.ok_or_else(|| Error::generic(format!("invalid hex digit in {input}")))?;
let lo = (pair[1] as char)
.to_digit(16)
.ok_or_else(|| Error::generic(format!("invalid hex digit in {input}")))?;
Ok((hi << 4 | lo) as u8)
})
.collect()
}
#[cfg(test)]
mod tests {
    use chrono::{DateTime, NaiveDate, NaiveDateTime, TimeZone, Utc};
    use rstest::rstest;

    use super::*;
    use crate::expressions::{DecimalData, Expression};
    use crate::schema::{ArrayType, DataType, DecimalType, MapType, StructField};

    /// Days between the Unix epoch and midnight UTC of the given calendar date.
    fn date_days(year: i32, month: u32, day: u32) -> i32 {
        let nd = NaiveDate::from_ymd_opt(year, month, day)
            .unwrap()
            .and_hms_opt(0, 0, 0)
            .unwrap();
        Utc.from_utc_datetime(&nd)
            .signed_duration_since(DateTime::UNIX_EPOCH)
            .num_days() as i32
    }

    /// Microseconds since the Unix epoch for `s` (`%Y-%m-%d %H:%M:%S%.f`), interpreted as UTC.
    fn ts_micros(s: &str) -> i64 {
        let ndt = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").unwrap();
        Utc.from_utc_datetime(&ndt)
            .signed_duration_since(DateTime::UNIX_EPOCH)
            .num_microseconds()
            .unwrap()
    }

    /// `DECIMAL(precision, scale)` as a `DataType`.
    fn decimal_type(precision: u8, scale: u8) -> DataType {
        DataType::Primitive(PrimitiveType::Decimal(
            DecimalType::try_new(precision, scale).unwrap(),
        ))
    }

    /// Numeric, boolean, string and decimal literals parse to the expected scalar.
    #[rstest]
    #[case("42", DataType::INTEGER, Scalar::Integer(42))]
    #[case(" -7 ", DataType::INTEGER, Scalar::Integer(-7))]
    #[case("+5", DataType::INTEGER, Scalar::Integer(5))]
    #[case("127", DataType::BYTE, Scalar::Byte(127))]
    #[case("-32768", DataType::SHORT, Scalar::Short(i16::MIN))]
    #[case("9223372036854775807", DataType::LONG, Scalar::Long(i64::MAX))]
    #[case("2.5", DataType::DOUBLE, Scalar::Double(2.5))]
    #[case("0.5", DataType::FLOAT, Scalar::Float(0.5))]
    #[case("1.5E3", DataType::DOUBLE, Scalar::Double(1500.0))]
    #[case("1.5E3", DataType::FLOAT, Scalar::Float(1500.0))]
    #[case("TRUE", DataType::BOOLEAN, Scalar::Boolean(true))]
    #[case("false", DataType::BOOLEAN, Scalar::Boolean(false))]
    #[case("'hello'", DataType::STRING, Scalar::String("hello".into()))]
    // Whitespace inside the quotes is preserved.
    #[case("' hi '", DataType::STRING, Scalar::String(" hi ".into()))]
    #[case("''", DataType::STRING, Scalar::String(String::new()))]
    #[case("'it''s'", DataType::STRING, Scalar::String("it's".into()))]
    #[case("'a''b''c'", DataType::STRING, Scalar::String("a'b'c".into()))]
    #[case("'''hello'", DataType::STRING, Scalar::String("'hello".into()))]
    #[case("'hello'''", DataType::STRING, Scalar::String("hello'".into()))]
    #[case("'''bad'''", DataType::STRING, Scalar::String("'bad'".into()))]
    // A quoted NULL is an ordinary string, not a null literal.
    #[case("'NULL'", DataType::STRING, Scalar::String("NULL".into()))]
    #[case(
        "1.23",
        decimal_type(5, 2),
        Scalar::Decimal(DecimalData::try_new(123, DecimalType::try_new(5, 2).unwrap()).unwrap()),
    )]
    #[case(
        "-12345.67",
        decimal_type(10, 2),
        Scalar::Decimal(
            DecimalData::try_new(-1234567, DecimalType::try_new(10, 2).unwrap()).unwrap()
        ),
    )]
    fn parses_basic_literals(#[case] sql: &str, #[case] ty: DataType, #[case] expected: Scalar) {
        let got = parse_sql(sql, &ty).unwrap();
        assert_eq!(got, Expression::literal(expected));
    }

    /// Dates parse with or without the `DATE` keyword (any case, with or without a space before
    /// the quote), and whitespace inside the quotes is trimmed.
    #[rstest]
    #[case("'2024-01-01'", date_days(2024, 1, 1))]
    #[case("DATE '2024-01-01'", date_days(2024, 1, 1))]
    #[case("DATE'2024-01-01'", date_days(2024, 1, 1))]
    #[case("date '1970-01-02'", date_days(1970, 1, 2))]
    #[case("' 2024-01-01 '", date_days(2024, 1, 1))]
    #[case("DATE ' 2024-01-01 '", date_days(2024, 1, 1))]
    fn parses_date_literals(#[case] sql: &str, #[case] expected_days: i32) {
        let got = parse_sql(sql, &DataType::DATE).unwrap();
        assert_eq!(got, Expression::literal(Scalar::Date(expected_days)));
    }

    /// TIMESTAMP_NTZ accepts zoneless timestamps with an optional `TIMESTAMP_NTZ` keyword.
    #[rstest]
    #[case("'2024-01-01 12:34:56'", "2024-01-01 12:34:56")]
    #[case("TIMESTAMP_NTZ '2024-01-01 12:34:56'", "2024-01-01 12:34:56")]
    #[case("TIMESTAMP_NTZ'2024-01-01 12:34:56'", "2024-01-01 12:34:56")]
    #[case("timestamp_ntz '2024-01-01 12:34:56.789'", "2024-01-01 12:34:56.789")]
    #[case("' 2024-01-01 12:34:56 '", "2024-01-01 12:34:56")]
    fn parses_zoneless_timestamp_ntz_literals(#[case] sql: &str, #[case] equivalent: &str) {
        let got = parse_sql(sql, &DataType::TIMESTAMP_NTZ).unwrap();
        assert_eq!(
            got,
            Expression::literal(Scalar::TimestampNtz(ts_micros(equivalent)))
        );
    }

    /// TIMESTAMP (LTZ) rejects both zoneless values and explicit (non-`Z`) offsets.
    #[rstest]
    #[case("'2024-01-01 12:34:56'")]
    #[case("TIMESTAMP '2024-01-01 12:34:56.789'")]
    #[case("TIMESTAMP'2024-01-01 12:34:56'")]
    #[case("TIMESTAMP_LTZ '2024-01-01 12:34:56'")]
    #[case("' 2024-01-01 12:34:56 '")]
    #[case("'2024-06-15T14:30:00+05:00'")]
    #[case("TIMESTAMP '2024-06-15T14:30:00-05:00'")]
    #[case("'2024-06-15T14:30:00+00:00'")]
    fn rejects_zoneless_and_offset_timestamp_ltz(#[case] sql: &str) {
        let result = parse_sql(sql, &DataType::TIMESTAMP);
        assert!(
            result.is_err(),
            "expected error for zoneless/offset TIMESTAMP {sql:?}, got {result:?}"
        );
    }

    /// The rejection message says whether the literal was zoneless or carried an offset.
    #[rstest]
    #[case("'2024-01-01 12:34:56'", "zoneless")]
    #[case("TIMESTAMP '2024-01-01 12:34:56'", "zoneless")]
    #[case("'2024-06-15T14:30:00+05:00'", "offset")]
    #[case("'2024-06-15T14:30:00+00:00'", "offset")]
    fn timestamp_ltz_rejection_distinguishes_zoneless_from_offset(
        #[case] sql: &str,
        #[case] needle: &str,
    ) {
        let err = parse_sql(sql, &DataType::TIMESTAMP)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains(needle),
            "{sql:?} message missing {needle:?}: {err}"
        );
    }

    /// Lowercase `t` separators and `z` zone designators are rejected.
    #[rstest]
    #[case("'2024-01-01t12:00:00Z'")]
    #[case("'2024-01-01T12:00:00z'")]
    #[case("'2024-01-01t12:00:00z'")]
    #[case("TIMESTAMP '2024-01-01t12:00:00Z'")]
    #[case("TIMESTAMP_LTZ '2024-01-01T12:00:00z'")]
    fn rejects_lowercase_timestamp_separator_and_zone(#[case] sql: &str) {
        let err = parse_sql(sql, &DataType::TIMESTAMP)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("uppercase"),
            "{sql:?} should be rejected for lowercase 't'/'z': {err}"
        );
    }

    /// A timestamp keyword that does not match the target type is rejected.
    #[rstest]
    #[case("TIMESTAMP_NTZ '1970-01-01T00:00:00Z'", DataType::TIMESTAMP)]
    #[case("TIMESTAMP '2024-01-01 12:34:56'", DataType::TIMESTAMP_NTZ)]
    #[case("TIMESTAMP_LTZ '1970-01-01T00:00:00Z'", DataType::TIMESTAMP_NTZ)]
    fn rejects_mismatched_timestamp_keyword(#[case] sql: &str, #[case] ty: DataType) {
        let result = parse_sql(sql, &ty);
        assert!(
            result.is_err(),
            "expected error for mismatched timestamp keyword {sql:?} as {ty:?}, got {result:?}"
        );
    }

    /// ISO-8601 `...T...Z` instants parse as TIMESTAMP but are rejected as TIMESTAMP_NTZ.
    #[rstest]
    #[case("'1970-01-01T00:00:00.123Z'", "1970-01-01 00:00:00.123")]
    #[case("'2024-06-15T14:30:00Z'", "2024-06-15 14:30:00")]
    #[case("TIMESTAMP '2024-06-15T14:30:00.456Z'", "2024-06-15 14:30:00.456")]
    #[case("TIMESTAMP_LTZ '2024-06-15T14:30:00Z'", "2024-06-15 14:30:00")]
    #[case("TIMESTAMP_LTZ'1970-01-01T00:00:00.123Z'", "1970-01-01 00:00:00.123")]
    fn iso_8601_form_accepted_only_for_timestamp(#[case] sql: &str, #[case] equivalent: &str) {
        let got = parse_sql(sql, &DataType::TIMESTAMP).unwrap();
        assert_eq!(
            got,
            Expression::literal(Scalar::Timestamp(ts_micros(equivalent)))
        );
        parse_sql(sql, &DataType::TIMESTAMP_NTZ).unwrap_err();
    }

    /// `X'..'` hex literals decode to bytes, case-insensitively.
    #[rstest]
    #[case("X''", vec![])]
    #[case("X'00'", vec![0x00])]
    #[case("X'DeAdBeEf'", vec![0xde, 0xad, 0xbe, 0xef])]
    #[case("x'01ff'", vec![0x01, 0xff])]
    fn parses_binary_literals(#[case] sql: &str, #[case] expected: Vec<u8>) {
        let got = parse_sql(sql, &DataType::BINARY).unwrap();
        assert_eq!(got, Expression::literal(Scalar::Binary(expected)));
    }

    /// `NULL` in any case, with surrounding whitespace, yields a typed null.
    #[rstest]
    #[case(DataType::INTEGER)]
    #[case(DataType::STRING)]
    #[case(DataType::BOOLEAN)]
    #[case(DataType::DATE)]
    #[case(DataType::BINARY)]
    fn null_is_accepted_for_any_primitive(#[case] ty: DataType) {
        let got = parse_sql("NULL", &ty).unwrap();
        assert_eq!(got, Expression::literal(Scalar::Null(ty.clone())));
        let got_lower = parse_sql(" null ", &ty).unwrap();
        assert_eq!(got_lower, Expression::literal(Scalar::Null(ty)));
    }

    /// Valid Spark SQL that this parser intentionally does not handle yet.
    #[rstest]
    #[case("1L", DataType::LONG)]
    #[case("1.23BD", decimal_type(5, 2))]
    #[case("1.5F", DataType::FLOAT)]
    #[case("CAST('2024-01-01' AS DATE)", DataType::DATE)]
    #[case("CAST(NULL AS INT)", DataType::INTEGER)]
    #[case("current_date()", DataType::DATE)]
    #[case("current_timestamp()", DataType::TIMESTAMP)]
    #[case("now()", DataType::TIMESTAMP)]
    #[case("1 + 1", DataType::INTEGER)]
    #[case("concat('a', 'b')", DataType::STRING)]
    #[case("0", decimal_type(10, 2))]
    #[case("1.2", decimal_type(5, 2))]
    fn currently_unsupported_valid_sql(#[case] sql: &str, #[case] ty: DataType) {
        let result = parse_sql(sql, &ty);
        assert!(
            result.is_err(),
            "expected error for currently-unsupported SQL {sql:?} as {ty:?}, got {result:?}"
        );
    }

    /// Malformed or type-mismatched input is rejected.
    #[rstest]
    #[case("", DataType::INTEGER)]
    #[case(" ", DataType::INTEGER)]
    #[case("'42'", DataType::INTEGER)]
    #[case("+", DataType::INTEGER)]
    #[case("UnknownFn()", DataType::INTEGER)]
    #[case("42", DataType::STRING)]
    #[case("foo", DataType::STRING)]
    #[case("'unterminated", DataType::STRING)]
    #[case("'bad'quote'", DataType::STRING)]
    #[case("'''", DataType::STRING)]
    #[case("'''''", DataType::STRING)]
    #[case("'ab''", DataType::STRING)]
    #[case("nope", DataType::BOOLEAN)]
    #[case("'TRUE'", DataType::BOOLEAN)]
    #[case("'2024-13-01'", DataType::DATE)]
    #[case("not-a-date", DataType::DATE)]
    #[case("''", DataType::DATE)]
    #[case("DATE ''", DataType::DATE)]
    #[case("' '", DataType::DATE)]
    #[case("DATE ' '", DataType::DATE)]
    #[case("DATEX '2024-01-01'", DataType::DATE)]
    #[case("DATEX'2024-01-01'", DataType::DATE)]
    #[case("''", DataType::TIMESTAMP)]
    #[case("' '", DataType::TIMESTAMP)]
    #[case("TIMESTAMP ''", DataType::TIMESTAMP)]
    #[case("TIMESTAMPX '2024-01-01 12:34:56'", DataType::TIMESTAMP)]
    #[case("TIMESTAMP_NTZ ''", DataType::TIMESTAMP_NTZ)]
    #[case("timestamp_ntza'2024-01-01 12:34:56'", DataType::TIMESTAMP_NTZ)]
    #[case(" now() ", DataType::TIMESTAMP)]
    #[case("X'0'", DataType::BINARY)]
    #[case("X'gg'", DataType::BINARY)]
    #[case("X'ab", DataType::BINARY)]
    #[case("Y'00'", DataType::BINARY)]
    #[case("'deadbeef'", DataType::BINARY)]
    #[case("128", DataType::BYTE)]
    #[case("2147483648", DataType::INTEGER)]
    fn rejects_invalid_input(#[case] sql: &str, #[case] ty: DataType) {
        let result = parse_sql(sql, &ty);
        assert!(
            result.is_err(),
            "expected error for {sql:?} as {ty:?}, got {result:?}"
        );
    }

    /// NaN, infinities and overflowing exponents are rejected for FLOAT and DOUBLE.
    #[rstest]
    fn rejects_bare_non_finite_floats(
        #[values(
            "NaN",
            "nan",
            "Infinity",
            "infinity",
            "inf",
            "-inf",
            "+inf",
            "-Infinity",
            "1e999", // overflows to infinity
            "-1e999"
        )]
        sql: &str,
        #[values(DataType::FLOAT, DataType::DOUBLE)] ty: DataType,
    ) {
        let result = parse_sql(sql, &ty);
        assert!(
            result.is_err(),
            "expected error for bare non-finite literal {sql:?} as {ty:?}, got {result:?}"
        );
    }

    /// Without an exponent, negative zero is normalized to positive zero.
    #[rstest]
    #[case("-0.0")]
    #[case("-0")]
    #[case("-0.00")]
    fn normalizes_negative_zero_to_positive(#[case] sql: &str) {
        let Expression::Literal(Scalar::Double(d)) = parse_sql(sql, &DataType::DOUBLE).unwrap()
        else {
            panic!("expected a Double literal for {sql:?}");
        };
        assert!(
            d == 0.0 && d.is_sign_positive(),
            "DOUBLE {sql:?} kept the sign: {d}"
        );
        let Expression::Literal(Scalar::Float(f)) = parse_sql(sql, &DataType::FLOAT).unwrap()
        else {
            panic!("expected a Float literal for {sql:?}");
        };
        assert!(
            f == 0.0 && f.is_sign_positive(),
            "FLOAT {sql:?} kept the sign: {f}"
        );
    }

    /// With an exponent, the sign of zero is preserved.
    #[rstest]
    #[case("-0.0E0")]
    #[case("-0E0")]
    #[case("-0.0e10")]
    fn preserves_negative_zero_with_exponent(#[case] sql: &str) {
        let Expression::Literal(Scalar::Double(d)) = parse_sql(sql, &DataType::DOUBLE).unwrap()
        else {
            panic!("expected a Double literal for {sql:?}");
        };
        assert!(
            d == 0.0 && d.is_sign_negative(),
            "DOUBLE {sql:?} should keep the negative sign: {d}"
        );
        let Expression::Literal(Scalar::Float(f)) = parse_sql(sql, &DataType::FLOAT).unwrap()
        else {
            panic!("expected a Float literal for {sql:?}");
        };
        assert!(
            f == 0.0 && f.is_sign_negative(),
            "FLOAT {sql:?} should keep the negative sign: {f}"
        );
    }

    /// FLOAT literals with an exponent are parsed via `f64` and then narrowed. This double rounding
    /// differs from parsing directly as `f32` for this input.
    #[test]
    fn float_exponent_literal_double_rounds_to_match_spark() {
        let got = parse_sql("7.038531E-26", &DataType::FLOAT).unwrap();
        let spark = "7.038531E-26".parse::<f64>().unwrap() as f32;
        assert_eq!(got, Expression::literal(Scalar::Float(spark)));
        assert_eq!(spark.to_bits(), 0x15ae_43fe);
        assert_ne!(
            spark.to_bits(),
            "7.038531E-26".parse::<f32>().unwrap().to_bits()
        );
    }

    /// Exponent-free FLOAT/DOUBLE literals needing more than 38 digits are rejected.
    #[rstest]
    #[case(DataType::DOUBLE)]
    #[case(DataType::FLOAT)]
    fn rejects_numeric_literal_over_decimal_precision_38(#[case] ty: DataType) {
        let thirty_nine_fraction = format!("0.{}", "1".repeat(39));
        let forty_digit_integer = "1".repeat(40);
        let scale_39_one_sig = format!("0.{}1", "0".repeat(38));
        for sql in [
            &thirty_nine_fraction,
            &forty_digit_integer,
            &scale_39_one_sig,
        ] {
            let err = parse_sql(sql, &ty).unwrap_err().to_string();
            assert!(
                err.contains("precision"),
                "{sql:?} as {ty:?} not rejected: {err}"
            );
        }
    }

    /// Literals at exactly 38 digits of precision or scale are still accepted.
    #[rstest]
    #[case(DataType::DOUBLE)]
    #[case(DataType::FLOAT)]
    fn accepts_numeric_literal_at_decimal_precision_boundary(#[case] ty: DataType) {
        let thirty_eight_fraction = format!("0.{}", "1".repeat(38));
        let thirty_eight_integer = "1".repeat(38);
        for sql in [
            thirty_eight_fraction.as_str(),
            thirty_eight_integer.as_str(),
            ".10101010101",
            "1.230",
        ] {
            parse_sql(sql, &ty).unwrap_or_else(|e| panic!("{sql:?} as {ty:?} rejected: {e}"));
        }
    }

    /// Backslash escapes are not supported yet, so any backslash is rejected.
    #[rstest]
    #[case(r"'a\nb'")]
    #[case(r"'c:\temp'")]
    #[case(r"'\\'")]
    fn rejects_backslash_in_string_literal(#[case] sql: &str) {
        let result = parse_sql(sql, &DataType::STRING);
        assert!(
            result.is_err(),
            "expected error for backslash in string literal {sql:?}, got {result:?}"
        );
    }

    /// Only single-quoted strings are string literals.
    #[rstest]
    #[case(r#""foo""#)]
    #[case(r#""it's""#)]
    fn rejects_double_quoted_string(#[case] sql: &str) {
        let result = parse_sql(sql, &DataType::STRING);
        assert!(
            result.is_err(),
            "expected error for double-quoted string {sql:?}, got {result:?}"
        );
    }

    fn struct_ty() -> DataType {
        DataType::try_struct_type([StructField::nullable("a", DataType::INTEGER)]).unwrap()
    }

    fn array_ty() -> DataType {
        DataType::Array(Box::new(ArrayType::new(DataType::INTEGER, true)))
    }

    fn map_ty() -> DataType {
        DataType::Map(Box::new(MapType::new(
            DataType::STRING,
            DataType::INTEGER,
            true,
        )))
    }

    /// Non-NULL literals for struct/array/map targets are rejected.
    #[rstest]
    #[case::struct_target(struct_ty())]
    #[case::array_target(array_ty())]
    #[case::map_target(map_ty())]
    fn rejects_non_primitive_target(#[case] ty: DataType) {
        assert!(parse_sql("'foo'", &ty).is_err());
    }

    /// `NULL` is accepted even for non-primitive targets.
    #[rstest]
    #[case::struct_target(struct_ty())]
    #[case::array_target(array_ty())]
    #[case::map_target(map_ty())]
    fn null_is_accepted_for_non_primitive_target(#[case] ty: DataType) {
        let got = parse_sql("NULL", &ty).unwrap();
        assert_eq!(got, Expression::literal(Scalar::Null(ty)));
    }
}