feldera_types/
program_schema.rs

1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::cmp::Ordering;
3use std::collections::BTreeMap;
4use std::fmt::Display;
5use std::hash::{Hash, Hasher};
6use utoipa::ToSchema;
7
8#[cfg(feature = "testing")]
9use proptest::{collection::vec, prelude::any};
10
/// Converts a SQL identifier to its canonical spelling.
///
/// - A double-quoted identifier (`"Foo"`) is case-sensitive: the surrounding
///   quotes are removed and the inner text is returned unchanged. Escaped
///   quotes inside it are *not* un-escaped.
/// - Any other identifier (including a lone `"`) is case-insensitive and is
///   returned in lowercase.
///
/// # Examples
/// - `Foo` -> `foo`
/// - `"Foo"` -> `Foo`
pub fn canonical_identifier(id: &str) -> String {
    match id.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner.to_owned(),
        None => id.to_lowercase(),
    }
}
25
/// An SQL identifier.
///
/// This struct is used to represent SQL identifiers in a canonical form.
/// We store table names or field names as identifiers in the schema.
// Despite the wording above, `name` is stored as given (`new` keeps the text
// verbatim; `From` only strips surrounding double quotes) and is NOT lowercased.
// `case_sensitive` records whether the identifier was quoted. Use `name()` for the
// canonical form and `sql_name()` for the original SQL spelling.
// Equality, hashing and ordering are implemented by hand rather than derived.
#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SqlIdentifier {
    // Under the `testing` feature, proptest draws names from a small fixed set.
    #[cfg_attr(feature = "testing", proptest(regex = "relation1|relation2|relation3"))]
    name: String,
    pub case_sensitive: bool,
}
37
38impl SqlIdentifier {
39    pub fn new<S: AsRef<str>>(name: S, case_sensitive: bool) -> Self {
40        Self {
41            name: name.as_ref().to_string(),
42            case_sensitive,
43        }
44    }
45
46    /// Return the name of the identifier in canonical form.
47    /// The result is the true case-sensitive identifying name of the table,
48    /// and can be used for example to detect duplicate table names.
49    ///
50    /// Example return values for this function:
51    /// - `CREATE TABLE t1` -> `t1`
52    /// - `CREATE TABLE T1` -> `t1`
53    /// - `CREATE TABLE "t1"` -> `t1`
54    /// - `CREATE TABLE "T1"` -> `T1`
55    pub fn name(&self) -> String {
56        if self.case_sensitive {
57            self.name.clone()
58        } else {
59            self.name.to_lowercase()
60        }
61    }
62
63    /// Return the name of the identifier as it appeared originally in SQL.
64    /// This method should only be used for log or error messages as it is what
65    /// the user originally wrote, however it should not be used for identification
66    /// or disambiguation (use `name()` for that instead).
67    ///
68    /// Example return values for this function:
69    /// - `CREATE TABLE t1` -> `t1`
70    /// - `CREATE TABLE T1` -> `T1`
71    /// - `CREATE TABLE "t1"` -> `"t1"`
72    /// - `CREATE TABLE "T1"` -> `"T1"`
73    pub fn sql_name(&self) -> String {
74        if self.case_sensitive {
75            format!("\"{}\"", self.name)
76        } else {
77            self.name.clone()
78        }
79    }
80}
81
82impl Hash for SqlIdentifier {
83    fn hash<H: Hasher>(&self, state: &mut H) {
84        self.name().hash(state);
85    }
86}
87
88impl PartialEq for SqlIdentifier {
89    fn eq(&self, other: &Self) -> bool {
90        match (self.case_sensitive, other.case_sensitive) {
91            (true, true) => self.name == other.name,
92            (false, false) => self.name.to_lowercase() == other.name.to_lowercase(),
93            (true, false) => self.name == other.name,
94            (false, true) => self.name == other.name,
95        }
96    }
97}
98
99impl Ord for SqlIdentifier {
100    fn cmp(&self, other: &Self) -> Ordering {
101        self.name().cmp(&other.name())
102    }
103}
104
105impl PartialOrd for SqlIdentifier {
106    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
107        Some(self.cmp(other))
108    }
109}
110
111impl<S: AsRef<str>> PartialEq<S> for SqlIdentifier {
112    fn eq(&self, other: &S) -> bool {
113        self == &SqlIdentifier::from(other.as_ref())
114    }
115}
116
117impl Eq for SqlIdentifier {}
118
119impl<S: AsRef<str>> From<S> for SqlIdentifier {
120    fn from(name: S) -> Self {
121        if name.as_ref().starts_with('"')
122            && name.as_ref().ends_with('"')
123            && name.as_ref().len() >= 2
124        {
125            Self {
126                name: name.as_ref()[1..name.as_ref().len() - 1].to_string(),
127                case_sensitive: true,
128            }
129        } else {
130            Self::new(name, false)
131        }
132    }
133}
134
135impl From<SqlIdentifier> for String {
136    fn from(id: SqlIdentifier) -> String {
137        id.name()
138    }
139}
140
141impl From<&SqlIdentifier> for String {
142    fn from(id: &SqlIdentifier) -> String {
143        id.name()
144    }
145}
146
147impl Display for SqlIdentifier {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        write!(f, "{}", self.name())
150    }
151}
152
153/// A struct containing the tables (inputs) and views for a program.
154///
155/// Parse from the JSON data-type of the DDL generated by the SQL compiler.
156#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
157#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
158pub struct ProgramSchema {
159    #[cfg_attr(
160        feature = "testing",
161        proptest(strategy = "vec(any::<Relation>(), 0..2)")
162    )]
163    pub inputs: Vec<Relation>,
164    #[cfg_attr(
165        feature = "testing",
166        proptest(strategy = "vec(any::<Relation>(), 0..2)")
167    )]
168    pub outputs: Vec<Relation>,
169}
170
171#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
172#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
173pub struct SourcePosition {
174    pub start_line_number: usize,
175    pub start_column: usize,
176    pub end_line_number: usize,
177    pub end_column: usize,
178}
179
180#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
181#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
182pub struct PropertyValue {
183    pub value: String,
184    pub key_position: SourcePosition,
185    pub value_position: SourcePosition,
186}
187
188/// A SQL table or view. It has a name and a list of fields.
189///
190/// Matches the Calcite JSON format.
191#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
192#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
193pub struct Relation {
194    // This field should only be accessed via the `name()` method.
195    #[serde(flatten)]
196    pub name: SqlIdentifier,
197    #[cfg_attr(feature = "testing", proptest(value = "Vec::new()"))]
198    pub fields: Vec<Field>,
199    #[serde(default)]
200    pub materialized: bool,
201    #[serde(default)]
202    pub properties: BTreeMap<String, PropertyValue>,
203}
204
205impl Relation {
206    pub fn empty() -> Self {
207        Self {
208            name: SqlIdentifier::from("".to_string()),
209            fields: Vec::new(),
210            materialized: false,
211            properties: BTreeMap::new(),
212        }
213    }
214
215    pub fn new(
216        name: SqlIdentifier,
217        fields: Vec<Field>,
218        materialized: bool,
219        properties: BTreeMap<String, PropertyValue>,
220    ) -> Self {
221        Self {
222            name,
223            fields,
224            materialized,
225            properties,
226        }
227    }
228
229    /// Lookup field by name.
230    pub fn field(&self, name: &str) -> Option<&Field> {
231        let name = canonical_identifier(name);
232        self.fields.iter().find(|f| f.name == name)
233    }
234}
235
/// A SQL field.
///
/// Matches the SQL compiler JSON format.
// Only `Serialize` is derived. `Deserialize` is implemented by hand because the
// compiler's JSON may put the column type's options directly inside the field
// object rather than under `columntype`.
#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Field {
    // Flattened: `name` and `case_sensitive` are top-level keys of the field object.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    pub columntype: ColumnType,
    // NOTE(review): presumably the SQL text of the field's LATENESS, DEFAULT and
    // WATERMARK annotations when present — confirm against the compiler output.
    pub lateness: Option<String>,
    pub default: Option<String>,
    pub watermark: Option<String>,
}
249
250impl Field {
251    pub fn new(name: SqlIdentifier, columntype: ColumnType) -> Self {
252        Self {
253            name,
254            columntype,
255            lateness: None,
256            default: None,
257            watermark: None,
258        }
259    }
260
261    pub fn with_lateness(mut self, lateness: &str) -> Self {
262        self.lateness = Some(lateness.to_string());
263        self
264    }
265}
266
267/// Thanks to the brain-dead Calcite schema, if we are deserializing a field, the type options
268/// end up inside the Field struct.
269///
270/// This helper struct is used to deserialize the Field struct.
271impl<'de> Deserialize<'de> for Field {
272    fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
273    where
274        D: Deserializer<'de>,
275    {
276        const fn default_is_struct() -> Option<SqlType> {
277            Some(SqlType::Struct)
278        }
279
280        #[derive(Debug, Clone, Deserialize)]
281        struct FieldHelper {
282            name: Option<String>,
283            #[serde(default)]
284            case_sensitive: bool,
285            columntype: Option<ColumnType>,
286            #[serde(rename = "type")]
287            #[serde(default = "default_is_struct")]
288            typ: Option<SqlType>,
289            nullable: Option<bool>,
290            precision: Option<i64>,
291            scale: Option<i64>,
292            component: Option<Box<ColumnType>>,
293            fields: Option<serde_json::Value>,
294            key: Option<Box<ColumnType>>,
295            value: Option<Box<ColumnType>>,
296            default: Option<String>,
297            lateness: Option<String>,
298            watermark: Option<String>,
299        }
300
301        fn helper_to_field(helper: FieldHelper) -> Field {
302            let columntype = if let Some(ctype) = helper.columntype {
303                ctype
304            } else if let Some(serde_json::Value::Array(fields)) = helper.fields {
305                let fields = fields
306                    .into_iter()
307                    .map(|field| {
308                        let field: FieldHelper = serde_json::from_value(field).unwrap();
309                        helper_to_field(field)
310                    })
311                    .collect::<Vec<Field>>();
312
313                ColumnType {
314                    typ: helper.typ.unwrap_or(SqlType::Null),
315                    nullable: helper.nullable.unwrap_or(false),
316                    precision: helper.precision,
317                    scale: helper.scale,
318                    component: helper.component,
319                    fields: Some(fields),
320                    key: None,
321                    value: None,
322                }
323            } else if let Some(serde_json::Value::Object(obj)) = helper.fields {
324                serde_json::from_value(serde_json::Value::Object(obj))
325                    .expect("Failed to deserialize object")
326            } else {
327                ColumnType {
328                    typ: helper.typ.unwrap_or(SqlType::Null),
329                    nullable: helper.nullable.unwrap_or(false),
330                    precision: helper.precision,
331                    scale: helper.scale,
332                    component: helper.component,
333                    fields: None,
334                    key: helper.key,
335                    value: helper.value,
336                }
337            };
338
339            Field {
340                name: SqlIdentifier::new(helper.name.unwrap(), helper.case_sensitive),
341                columntype,
342                default: helper.default,
343                lateness: helper.lateness,
344                watermark: helper.watermark,
345            }
346        }
347
348        let helper = FieldHelper::deserialize(deserializer)?;
349        Ok(helper_to_field(helper))
350    }
351}
352
/// The specified units for SQL Interval types.
///
/// `INTERVAL 1 DAY`, `INTERVAL 1 DAY TO HOUR`, `INTERVAL 1 DAY TO MINUTE`,
/// would yield `Day`, `DayToHour`, `DayToMinute`, as the `IntervalUnit` respectively.
// Not serialized on its own: `SqlType`'s hand-written `Serialize`/`Deserialize`
// encode `SqlType::Interval(unit)` as strings such as `INTERVAL_DAY_SECOND`.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum IntervalUnit {
    /// Unit for `INTERVAL ... DAY`.
    Day,
    /// Unit for `INTERVAL ... DAY TO HOUR`.
    DayToHour,
    /// Unit for `INTERVAL ... DAY TO MINUTE`.
    DayToMinute,
    /// Unit for `INTERVAL ... DAY TO SECOND`.
    DayToSecond,
    /// Unit for `INTERVAL ... HOUR`.
    Hour,
    /// Unit for `INTERVAL ... HOUR TO MINUTE`.
    HourToMinute,
    /// Unit for `INTERVAL ... HOUR TO SECOND`.
    HourToSecond,
    /// Unit for `INTERVAL ... MINUTE`.
    Minute,
    /// Unit for `INTERVAL ... MINUTE TO SECOND`.
    MinuteToSecond,
    /// Unit for `INTERVAL ... MONTH`.
    Month,
    /// Unit for `INTERVAL ... SECOND`.
    Second,
    /// Unit for `INTERVAL ... YEAR`.
    Year,
    /// Unit for `INTERVAL ... YEAR TO MONTH`.
    YearToMonth,
}
387
/// The available SQL types as specified in `CREATE` statements.
// `Serialize`/`Deserialize` are implemented by hand. Serialization emits the
// upper-case name (e.g. `INTEGER`, `INTERVAL_DAY`); deserialization accepts those
// names in any letter case.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SqlType {
    /// SQL `BOOLEAN` type.
    Boolean,
    /// SQL `TINYINT` type.
    TinyInt,
    /// SQL `SMALLINT` or `INT2` type.
    SmallInt,
    /// SQL `INTEGER`, `INT`, `SIGNED`, `INT4` type.
    Int,
    /// SQL `BIGINT` or `INT64` type.
    BigInt,
    /// SQL `REAL` or `FLOAT4` or `FLOAT32` type.
    Real,
    /// SQL `DOUBLE` or `FLOAT8` or `FLOAT64` type.
    Double,
    /// SQL `DECIMAL` or `DEC` or `NUMERIC` type.
    Decimal,
    /// SQL `CHAR(n)` or `CHARACTER(n)` type.
    Char,
    /// SQL `VARCHAR`, `CHARACTER VARYING`, `TEXT`, or `STRING` type.
    Varchar,
    /// SQL `BINARY(n)` type.
    Binary,
    /// SQL `VARBINARY` or `BYTEA` type.
    Varbinary,
    /// SQL `TIME` type.
    Time,
    /// SQL `DATE` type.
    Date,
    /// SQL `TIMESTAMP` type.
    Timestamp,
    /// SQL `INTERVAL ... X` type where `X` is a unit.
    Interval(IntervalUnit),
    /// SQL `ARRAY` type.
    Array,
    /// A complex SQL struct type (`CREATE TYPE x ...`).
    Struct,
    /// SQL `MAP` type.
    Map,
    /// SQL `NULL` type.
    Null,
    /// SQL `UUID` type.
    Uuid,
    /// SQL `VARIANT` type.
    Variant,
}
437
438impl Display for SqlType {
439    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440        f.write_str(&serde_json::to_string(self).unwrap())
441    }
442}
443
444impl<'de> Deserialize<'de> for SqlType {
445    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
446    where
447        D: Deserializer<'de>,
448    {
449        let value: String = Deserialize::deserialize(deserializer)?;
450        match value.to_lowercase().as_str() {
451            "interval_day" => Ok(SqlType::Interval(IntervalUnit::Day)),
452            "interval_day_hour" => Ok(SqlType::Interval(IntervalUnit::DayToHour)),
453            "interval_day_minute" => Ok(SqlType::Interval(IntervalUnit::DayToMinute)),
454            "interval_day_second" => Ok(SqlType::Interval(IntervalUnit::DayToSecond)),
455            "interval_hour" => Ok(SqlType::Interval(IntervalUnit::Hour)),
456            "interval_hour_minute" => Ok(SqlType::Interval(IntervalUnit::HourToMinute)),
457            "interval_hour_second" => Ok(SqlType::Interval(IntervalUnit::HourToSecond)),
458            "interval_minute" => Ok(SqlType::Interval(IntervalUnit::Minute)),
459            "interval_minute_second" => Ok(SqlType::Interval(IntervalUnit::MinuteToSecond)),
460            "interval_month" => Ok(SqlType::Interval(IntervalUnit::Month)),
461            "interval_second" => Ok(SqlType::Interval(IntervalUnit::Second)),
462            "interval_year" => Ok(SqlType::Interval(IntervalUnit::Year)),
463            "interval_year_month" => Ok(SqlType::Interval(IntervalUnit::YearToMonth)),
464            "boolean" => Ok(SqlType::Boolean),
465            "tinyint" => Ok(SqlType::TinyInt),
466            "smallint" => Ok(SqlType::SmallInt),
467            "integer" => Ok(SqlType::Int),
468            "bigint" => Ok(SqlType::BigInt),
469            "real" => Ok(SqlType::Real),
470            "double" => Ok(SqlType::Double),
471            "decimal" => Ok(SqlType::Decimal),
472            "char" => Ok(SqlType::Char),
473            "varchar" => Ok(SqlType::Varchar),
474            "binary" => Ok(SqlType::Binary),
475            "varbinary" => Ok(SqlType::Varbinary),
476            "variant" => Ok(SqlType::Variant),
477            "time" => Ok(SqlType::Time),
478            "date" => Ok(SqlType::Date),
479            "timestamp" => Ok(SqlType::Timestamp),
480            "array" => Ok(SqlType::Array),
481            "struct" => Ok(SqlType::Struct),
482            "map" => Ok(SqlType::Map),
483            "null" => Ok(SqlType::Null),
484            "uuid" => Ok(SqlType::Uuid),
485            _ => Err(serde::de::Error::custom(format!(
486                "Unknown SQL type: {}",
487                value
488            ))),
489        }
490    }
491}
492
493impl Serialize for SqlType {
494    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
495    where
496        S: Serializer,
497    {
498        let type_str = match self {
499            SqlType::Boolean => "BOOLEAN",
500            SqlType::TinyInt => "TINYINT",
501            SqlType::SmallInt => "SMALLINT",
502            SqlType::Int => "INTEGER",
503            SqlType::BigInt => "BIGINT",
504            SqlType::Real => "REAL",
505            SqlType::Double => "DOUBLE",
506            SqlType::Decimal => "DECIMAL",
507            SqlType::Char => "CHAR",
508            SqlType::Varchar => "VARCHAR",
509            SqlType::Binary => "BINARY",
510            SqlType::Varbinary => "VARBINARY",
511            SqlType::Time => "TIME",
512            SqlType::Date => "DATE",
513            SqlType::Timestamp => "TIMESTAMP",
514            SqlType::Interval(interval_unit) => match interval_unit {
515                IntervalUnit::Day => "INTERVAL_DAY",
516                IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
517                IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
518                IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
519                IntervalUnit::Hour => "INTERVAL_HOUR",
520                IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
521                IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
522                IntervalUnit::Minute => "INTERVAL_MINUTE",
523                IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
524                IntervalUnit::Month => "INTERVAL_MONTH",
525                IntervalUnit::Second => "INTERVAL_SECOND",
526                IntervalUnit::Year => "INTERVAL_YEAR",
527                IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
528            },
529            SqlType::Array => "ARRAY",
530            SqlType::Struct => "STRUCT",
531            SqlType::Uuid => "UUID",
532            SqlType::Map => "MAP",
533            SqlType::Null => "NULL",
534            SqlType::Variant => "VARIANT",
535        };
536        serializer.serialize_str(type_str)
537    }
538}
539
540impl SqlType {
541    /// Is this a string type?
542    pub fn is_string(&self) -> bool {
543        matches!(self, Self::Char | Self::Varchar)
544    }
545
546    pub fn is_varchar(&self) -> bool {
547        matches!(self, Self::Varchar)
548    }
549
550    pub fn is_varbinary(&self) -> bool {
551        matches!(self, Self::Varbinary)
552    }
553}
554
/// Serde default for [`ColumnType::typ`].
///
/// When the `type` key is missing from a column type in the Calcite schema, the
/// type is a struct, so `SqlType::Struct` is used.
const fn default_is_struct() -> SqlType {
    SqlType::Struct
}
560
/// A SQL column type description.
///
/// Matches the Calcite JSON format.
// Serde: `typ` is read from the JSON key "type" and defaults to `SqlType::Struct`
// when that key is absent (see `default_is_struct`).
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ColumnType {
    /// Identifier for the type (e.g., `VARCHAR`, `BIGINT`, `ARRAY` etc.)
    #[serde(rename = "type")]
    #[serde(default = "default_is_struct")]
    pub typ: SqlType,
    /// Does the type accept NULL values?
    pub nullable: bool,
    /// Precision of the type.
    ///
    /// # Examples
    /// - `VARCHAR` sets precision to `-1`.
    /// - `VARCHAR(255)` sets precision to `255`.
    /// - `BIGINT`, `DATE`, `FLOAT`, `DOUBLE`, `GEOMETRY`, etc. sets precision
    ///   to None
    /// - `TIME`, `TIMESTAMP` set precision to `0`.
    pub precision: Option<i64>,
    /// The scale of the type.
    ///
    /// # Example
    /// - `DECIMAL(1,2)` sets scale to `2`.
    pub scale: Option<i64>,
    /// A component of the type (if available).
    ///
    /// This is in a `Box` because it makes it a recursive types.
    ///
    /// For example, this would specify the `VARCHAR(20)` in the `VARCHAR(20)
    /// ARRAY` type.
    // The `Box` is required because `ColumnType` contains itself here.
    // Under the `testing` feature, proptest always generates `None` for
    // `component`, `key` and `value`, and an empty list for `fields` —
    // presumably to keep generated values finite.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub component: Option<Box<ColumnType>>,
    /// The fields of the type (if available).
    ///
    /// For example this would specify the fields of a `CREATE TYPE` construct.
    ///
    /// ```sql
    /// CREATE TYPE person_typ AS (
    ///   firstname       VARCHAR(30),
    ///   lastname        VARCHAR(30),
    ///   address         ADDRESS_TYP
    /// );
    /// ```
    ///
    /// Would lead to the following `fields` value:
    ///
    /// ```sql
    /// [
    ///  ColumnType { name: "firstname, ... },
    ///  ColumnType { name: "lastname", ... },
    ///  ColumnType { name: "address", fields: [ ... ] }
    /// ]
    /// ```
    // NOTE: the example above is schematic. The entries are `Field` values (a name
    // plus a `columntype`), not `ColumnType`s.
    #[cfg_attr(feature = "testing", proptest(value = "Some(Vec::new())"))]
    pub fields: Option<Vec<Field>>,
    /// Key type; must be set when `type == "MAP"`.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub key: Option<Box<ColumnType>>,
    /// Value type; must be set when `type == "MAP"`.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub value: Option<Box<ColumnType>>,
}
625
626impl ColumnType {
627    pub fn boolean(nullable: bool) -> Self {
628        ColumnType {
629            typ: SqlType::Boolean,
630            nullable,
631            precision: None,
632            scale: None,
633            component: None,
634            fields: None,
635            key: None,
636            value: None,
637        }
638    }
639
640    pub fn uuid(nullable: bool) -> Self {
641        ColumnType {
642            typ: SqlType::Uuid,
643            nullable,
644            precision: None,
645            scale: None,
646            component: None,
647            fields: None,
648            key: None,
649            value: None,
650        }
651    }
652
653    pub fn tinyint(nullable: bool) -> Self {
654        ColumnType {
655            typ: SqlType::TinyInt,
656            nullable,
657            precision: None,
658            scale: None,
659            component: None,
660            fields: None,
661            key: None,
662            value: None,
663        }
664    }
665
666    pub fn smallint(nullable: bool) -> Self {
667        ColumnType {
668            typ: SqlType::SmallInt,
669            nullable,
670            precision: None,
671            scale: None,
672            component: None,
673            fields: None,
674            key: None,
675            value: None,
676        }
677    }
678
679    pub fn int(nullable: bool) -> Self {
680        ColumnType {
681            typ: SqlType::Int,
682            nullable,
683            precision: None,
684            scale: None,
685            component: None,
686            fields: None,
687            key: None,
688            value: None,
689        }
690    }
691
692    pub fn bigint(nullable: bool) -> Self {
693        ColumnType {
694            typ: SqlType::BigInt,
695            nullable,
696            precision: None,
697            scale: None,
698            component: None,
699            fields: None,
700            key: None,
701            value: None,
702        }
703    }
704
705    pub fn double(nullable: bool) -> Self {
706        ColumnType {
707            typ: SqlType::Double,
708            nullable,
709            precision: None,
710            scale: None,
711            component: None,
712            fields: None,
713            key: None,
714            value: None,
715        }
716    }
717
718    pub fn real(nullable: bool) -> Self {
719        ColumnType {
720            typ: SqlType::Real,
721            nullable,
722            precision: None,
723            scale: None,
724            component: None,
725            fields: None,
726            key: None,
727            value: None,
728        }
729    }
730
731    pub fn decimal(precision: i64, scale: i64, nullable: bool) -> Self {
732        ColumnType {
733            typ: SqlType::Decimal,
734            nullable,
735            precision: Some(precision),
736            scale: Some(scale),
737            component: None,
738            fields: None,
739            key: None,
740            value: None,
741        }
742    }
743
744    pub fn varchar(nullable: bool) -> Self {
745        ColumnType {
746            typ: SqlType::Varchar,
747            nullable,
748            precision: None,
749            scale: None,
750            component: None,
751            fields: None,
752            key: None,
753            value: None,
754        }
755    }
756
757    pub fn varbinary(nullable: bool) -> Self {
758        ColumnType {
759            typ: SqlType::Varbinary,
760            nullable,
761            precision: None,
762            scale: None,
763            component: None,
764            fields: None,
765            key: None,
766            value: None,
767        }
768    }
769
770    pub fn fixed(width: i64, nullable: bool) -> Self {
771        ColumnType {
772            typ: SqlType::Binary,
773            nullable,
774            precision: Some(width),
775            scale: None,
776            component: None,
777            fields: None,
778            key: None,
779            value: None,
780        }
781    }
782
783    pub fn date(nullable: bool) -> Self {
784        ColumnType {
785            typ: SqlType::Date,
786            nullable,
787            precision: None,
788            scale: None,
789            component: None,
790            fields: None,
791            key: None,
792            value: None,
793        }
794    }
795
796    pub fn time(nullable: bool) -> Self {
797        ColumnType {
798            typ: SqlType::Time,
799            nullable,
800            precision: None,
801            scale: None,
802            component: None,
803            fields: None,
804            key: None,
805            value: None,
806        }
807    }
808
809    pub fn timestamp(nullable: bool) -> Self {
810        ColumnType {
811            typ: SqlType::Timestamp,
812            nullable,
813            precision: None,
814            scale: None,
815            component: None,
816            fields: None,
817            key: None,
818            value: None,
819        }
820    }
821
822    pub fn variant(nullable: bool) -> Self {
823        ColumnType {
824            typ: SqlType::Variant,
825            nullable,
826            precision: None,
827            scale: None,
828            component: None,
829            fields: None,
830            key: None,
831            value: None,
832        }
833    }
834
835    pub fn array(nullable: bool, element: ColumnType) -> Self {
836        ColumnType {
837            typ: SqlType::Array,
838            nullable,
839            precision: None,
840            scale: None,
841            component: Some(Box::new(element)),
842            fields: None,
843            key: None,
844            value: None,
845        }
846    }
847
848    pub fn structure(nullable: bool, fields: &[Field]) -> Self {
849        ColumnType {
850            typ: SqlType::Struct,
851            nullable,
852            precision: None,
853            scale: None,
854            component: None,
855            fields: Some(fields.to_vec()),
856            key: None,
857            value: None,
858        }
859    }
860
861    pub fn map(nullable: bool, key: ColumnType, val: ColumnType) -> Self {
862        ColumnType {
863            typ: SqlType::Map,
864            nullable,
865            precision: None,
866            scale: None,
867            component: None,
868            fields: None,
869            key: Some(Box::new(key)),
870            value: Some(Box::new(val)),
871        }
872    }
873
874    pub fn is_integral_type(&self) -> bool {
875        matches!(
876            &self.typ,
877            SqlType::TinyInt | SqlType::SmallInt | SqlType::Int | SqlType::BigInt
878        )
879    }
880
881    pub fn is_fp_type(&self) -> bool {
882        matches!(&self.typ, SqlType::Double | SqlType::Real)
883    }
884
885    pub fn is_decimal_type(&self) -> bool {
886        matches!(&self.typ, SqlType::Decimal)
887    }
888
889    pub fn is_numeric_type(&self) -> bool {
890        self.is_integral_type() || self.is_fp_type() || self.is_decimal_type()
891    }
892}
893
894#[cfg(test)]
895mod tests {
896    use super::{IntervalUnit, SqlIdentifier};
897    use crate::program_schema::{ColumnType, Field, SqlType};
898
899    #[test]
900    fn serde_sql_type() {
901        for (sql_str_base, expected_value) in [
902            ("Boolean", SqlType::Boolean),
903            ("Uuid", SqlType::Uuid),
904            ("TinyInt", SqlType::TinyInt),
905            ("SmallInt", SqlType::SmallInt),
906            ("Integer", SqlType::Int),
907            ("BigInt", SqlType::BigInt),
908            ("Real", SqlType::Real),
909            ("Double", SqlType::Double),
910            ("Decimal", SqlType::Decimal),
911            ("Char", SqlType::Char),
912            ("Varchar", SqlType::Varchar),
913            ("Binary", SqlType::Binary),
914            ("Varbinary", SqlType::Varbinary),
915            ("Time", SqlType::Time),
916            ("Date", SqlType::Date),
917            ("Timestamp", SqlType::Timestamp),
918            ("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
919            (
920                "Interval_Day_Hour",
921                SqlType::Interval(IntervalUnit::DayToHour),
922            ),
923            (
924                "Interval_Day_Minute",
925                SqlType::Interval(IntervalUnit::DayToMinute),
926            ),
927            (
928                "Interval_Day_Second",
929                SqlType::Interval(IntervalUnit::DayToSecond),
930            ),
931            ("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
932            (
933                "Interval_Hour_Minute",
934                SqlType::Interval(IntervalUnit::HourToMinute),
935            ),
936            (
937                "Interval_Hour_Second",
938                SqlType::Interval(IntervalUnit::HourToSecond),
939            ),
940            ("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
941            (
942                "Interval_Minute_Second",
943                SqlType::Interval(IntervalUnit::MinuteToSecond),
944            ),
945            ("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
946            ("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
947            ("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
948            (
949                "Interval_Year_Month",
950                SqlType::Interval(IntervalUnit::YearToMonth),
951            ),
952            ("Array", SqlType::Array),
953            ("Struct", SqlType::Struct),
954            ("Map", SqlType::Map),
955            ("Null", SqlType::Null),
956            ("Variant", SqlType::Variant),
957        ] {
958            for sql_str in [
959                sql_str_base,                 // Capitalized
960                &sql_str_base.to_lowercase(), // lowercase
961                &sql_str_base.to_uppercase(), // UPPERCASE
962            ] {
963                let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str))
964                    .unwrap_or_else(|_| {
965                        panic!("\"{sql_str}\" should deserialize into its SQL type")
966                    });
967                assert_eq!(value1, expected_value);
968                let serialized_str =
969                    serde_json::to_string(&value1).expect("Value should serialize into JSON");
970                let value2: SqlType = serde_json::from_str(&serialized_str).unwrap_or_else(|_| {
971                    panic!(
972                        "{} should deserialize back into its SQL type",
973                        serialized_str
974                    )
975                });
976                assert_eq!(value1, value2);
977            }
978        }
979    }
980
981    #[test]
982    fn deserialize_interval_types() {
983        use super::IntervalUnit::*;
984        use super::SqlType::*;
985
986        let schema = r#"
987{
988  "inputs" : [ {
989    "name" : "sales",
990    "case_sensitive" : false,
991    "fields" : [ {
992      "name" : "sales_id",
993      "case_sensitive" : false,
994      "columntype" : {
995        "type" : "INTEGER",
996        "nullable" : true
997      }
998    }, {
999      "name" : "customer_id",
1000      "case_sensitive" : false,
1001      "columntype" : {
1002        "type" : "INTEGER",
1003        "nullable" : true
1004      }
1005    }, {
1006      "name" : "amount",
1007      "case_sensitive" : false,
1008      "columntype" : {
1009        "type" : "DECIMAL",
1010        "nullable" : true,
1011        "precision" : 10,
1012        "scale" : 2
1013      }
1014    }, {
1015      "name" : "sale_date",
1016      "case_sensitive" : false,
1017      "columntype" : {
1018        "type" : "DATE",
1019        "nullable" : true
1020      }
1021    } ],
1022    "primary_key" : [ "sales_id" ]
1023  } ],
1024  "outputs" : [ {
1025    "name" : "salessummary",
1026    "case_sensitive" : false,
1027    "fields" : [ {
1028      "name" : "customer_id",
1029      "case_sensitive" : false,
1030      "columntype" : {
1031        "type" : "INTEGER",
1032        "nullable" : true
1033      }
1034    }, {
1035      "name" : "total_sales",
1036      "case_sensitive" : false,
1037      "columntype" : {
1038        "type" : "DECIMAL",
1039        "nullable" : true,
1040        "precision" : 38,
1041        "scale" : 2
1042      }
1043    }, {
1044      "name" : "interval_day",
1045      "case_sensitive" : false,
1046      "columntype" : {
1047        "type" : "INTERVAL_DAY",
1048        "nullable" : false,
1049        "precision" : 2,
1050        "scale" : 6
1051      }
1052    }, {
1053      "name" : "interval_day_to_hour",
1054      "case_sensitive" : false,
1055      "columntype" : {
1056        "type" : "INTERVAL_DAY_HOUR",
1057        "nullable" : false,
1058        "precision" : 2,
1059        "scale" : 6
1060      }
1061    }, {
1062      "name" : "interval_day_to_minute",
1063      "case_sensitive" : false,
1064      "columntype" : {
1065        "type" : "INTERVAL_DAY_MINUTE",
1066        "nullable" : false,
1067        "precision" : 2,
1068        "scale" : 6
1069      }
1070    }, {
1071      "name" : "interval_day_to_second",
1072      "case_sensitive" : false,
1073      "columntype" : {
1074        "type" : "INTERVAL_DAY_SECOND",
1075        "nullable" : false,
1076        "precision" : 2,
1077        "scale" : 6
1078      }
1079    }, {
1080      "name" : "interval_hour",
1081      "case_sensitive" : false,
1082      "columntype" : {
1083        "type" : "INTERVAL_HOUR",
1084        "nullable" : false,
1085        "precision" : 2,
1086        "scale" : 6
1087      }
1088    }, {
1089      "name" : "interval_hour_to_minute",
1090      "case_sensitive" : false,
1091      "columntype" : {
1092        "type" : "INTERVAL_HOUR_MINUTE",
1093        "nullable" : false,
1094        "precision" : 2,
1095        "scale" : 6
1096      }
1097    }, {
1098      "name" : "interval_hour_to_second",
1099      "case_sensitive" : false,
1100      "columntype" : {
1101        "type" : "INTERVAL_HOUR_SECOND",
1102        "nullable" : false,
1103        "precision" : 2,
1104        "scale" : 6
1105      }
1106    }, {
1107      "name" : "interval_minute",
1108      "case_sensitive" : false,
1109      "columntype" : {
1110        "type" : "INTERVAL_MINUTE",
1111        "nullable" : false,
1112        "precision" : 2,
1113        "scale" : 6
1114      }
1115    }, {
1116      "name" : "interval_minute_to_second",
1117      "case_sensitive" : false,
1118      "columntype" : {
1119        "type" : "INTERVAL_MINUTE_SECOND",
1120        "nullable" : false,
1121        "precision" : 2,
1122        "scale" : 6
1123      }
1124    }, {
1125      "name" : "interval_month",
1126      "case_sensitive" : false,
1127      "columntype" : {
1128        "type" : "INTERVAL_MONTH",
1129        "nullable" : false
1130      }
1131    }, {
1132      "name" : "interval_second",
1133      "case_sensitive" : false,
1134      "columntype" : {
1135        "type" : "INTERVAL_SECOND",
1136        "nullable" : false,
1137        "precision" : 2,
1138        "scale" : 6
1139      }
1140    }, {
1141      "name" : "interval_year",
1142      "case_sensitive" : false,
1143      "columntype" : {
1144        "type" : "INTERVAL_YEAR",
1145        "nullable" : false
1146      }
1147    }, {
1148      "name" : "interval_year_to_month",
1149      "case_sensitive" : false,
1150      "columntype" : {
1151        "type" : "INTERVAL_YEAR_MONTH",
1152        "nullable" : false
1153      }
1154    } ]
1155  } ]
1156}
1157"#;
1158
1159        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
1160        let types = schema
1161            .outputs
1162            .iter()
1163            .flat_map(|r| r.fields.iter().map(|f| f.columntype.typ));
1164        let expected_types = [
1165            Int,
1166            Decimal,
1167            Interval(Day),
1168            Interval(DayToHour),
1169            Interval(DayToMinute),
1170            Interval(DayToSecond),
1171            Interval(Hour),
1172            Interval(HourToMinute),
1173            Interval(HourToSecond),
1174            Interval(Minute),
1175            Interval(MinuteToSecond),
1176            Interval(Month),
1177            Interval(Second),
1178            Interval(Year),
1179            Interval(YearToMonth),
1180        ];
1181
1182        assert_eq!(types.collect::<Vec<_>>(), &expected_types);
1183    }
1184
1185    #[test]
1186    fn serialize_struct_schemas() {
1187        let schema = r#"{
1188  "inputs" : [ {
1189    "name" : "PERS",
1190    "case_sensitive" : false,
1191    "fields" : [ {
1192      "name" : "P0",
1193      "case_sensitive" : false,
1194      "columntype" : {
1195        "fields" : [ {
1196          "type" : "VARCHAR",
1197          "nullable" : true,
1198          "precision" : 30,
1199          "name" : "FIRSTNAME"
1200        }, {
1201          "type" : "VARCHAR",
1202          "nullable" : true,
1203          "precision" : 30,
1204          "name" : "LASTNAME"
1205        }, {
1206          "fields" : {
1207            "fields" : [ {
1208              "type" : "VARCHAR",
1209              "nullable" : true,
1210              "precision" : 30,
1211              "name" : "STREET"
1212            }, {
1213              "type" : "VARCHAR",
1214              "nullable" : true,
1215              "precision" : 30,
1216              "name" : "CITY"
1217            }, {
1218              "type" : "CHAR",
1219              "nullable" : true,
1220              "precision" : 2,
1221              "name" : "STATE"
1222            }, {
1223              "type" : "VARCHAR",
1224              "nullable" : true,
1225              "precision" : 6,
1226              "name" : "POSTAL_CODE"
1227            } ],
1228            "nullable" : false
1229          },
1230          "nullable" : false,
1231          "name" : "ADDRESS"
1232        } ],
1233        "nullable" : false
1234      }
1235    }]
1236  } ],
1237  "outputs" : [ ]
1238}
1239"#;
1240        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
1241        eprintln!("{:#?}", schema);
1242        let pers = schema.inputs.iter().find(|r| r.name == "PERS").unwrap();
1243        let p0 = pers.fields.iter().find(|f| f.name == "P0").unwrap();
1244        assert_eq!(p0.columntype.typ, SqlType::Struct);
1245        let p0_fields = p0.columntype.fields.as_ref().unwrap();
1246        assert_eq!(p0_fields[0].columntype.typ, SqlType::Varchar);
1247        assert_eq!(p0_fields[1].columntype.typ, SqlType::Varchar);
1248        assert_eq!(p0_fields[2].columntype.typ, SqlType::Struct);
1249        assert_eq!(p0_fields[2].name, "ADDRESS");
1250        let address = &p0_fields[2].columntype.fields.as_ref().unwrap();
1251        assert_eq!(address.len(), 4);
1252        assert_eq!(address[0].name, "STREET");
1253        assert_eq!(address[0].columntype.typ, SqlType::Varchar);
1254        assert_eq!(address[1].columntype.typ, SqlType::Varchar);
1255        assert_eq!(address[2].columntype.typ, SqlType::Char);
1256        assert_eq!(address[3].columntype.typ, SqlType::Varchar);
1257    }
1258
1259    #[test]
1260    fn sql_identifier_cmp() {
1261        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("foo"));
1262        assert_ne!(SqlIdentifier::from("foo"), SqlIdentifier::from("bar"));
1263        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("BAR"));
1264        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("\"foo\""));
1265        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("\"bar\""));
1266        assert_eq!(SqlIdentifier::from("bAr"), SqlIdentifier::from("\"bAr\""));
1267        assert_eq!(
1268            SqlIdentifier::new("bAr", true),
1269            SqlIdentifier::from("\"bAr\"")
1270        );
1271
1272        assert_eq!(SqlIdentifier::from("bAr"), "bar");
1273        assert_eq!(SqlIdentifier::from("bAr"), "bAr");
1274    }
1275
1276    #[test]
1277    fn sql_identifier_ord() {
1278        let mut btree = std::collections::BTreeSet::new();
1279        assert!(btree.insert(SqlIdentifier::from("foo")));
1280        assert!(btree.insert(SqlIdentifier::from("bar")));
1281        assert!(!btree.insert(SqlIdentifier::from("BAR")));
1282        assert!(!btree.insert(SqlIdentifier::from("\"foo\"")));
1283        assert!(!btree.insert(SqlIdentifier::from("\"bar\"")));
1284    }
1285
1286    #[test]
1287    fn sql_identifier_hash() {
1288        let mut hs = std::collections::HashSet::new();
1289        assert!(hs.insert(SqlIdentifier::from("foo")));
1290        assert!(hs.insert(SqlIdentifier::from("bar")));
1291        assert!(!hs.insert(SqlIdentifier::from("BAR")));
1292        assert!(!hs.insert(SqlIdentifier::from("\"foo\"")));
1293        assert!(!hs.insert(SqlIdentifier::from("\"bar\"")));
1294    }
1295
1296    #[test]
1297    fn sql_identifier_name() {
1298        assert_eq!(SqlIdentifier::from("foo").name(), "foo");
1299        assert_eq!(SqlIdentifier::from("bAr").name(), "bar");
1300        assert_eq!(SqlIdentifier::from("\"bAr\"").name(), "bAr");
1301        assert_eq!(SqlIdentifier::from("foo").sql_name(), "foo");
1302        assert_eq!(SqlIdentifier::from("bAr").sql_name(), "bAr");
1303        assert_eq!(SqlIdentifier::from("\"bAr\"").sql_name(), "\"bAr\"");
1304    }
1305
1306    #[test]
1307    fn issue3277() {
1308        let schema = r#"{
1309      "name" : "j",
1310      "case_sensitive" : false,
1311      "columntype" : {
1312        "fields" : [ {
1313          "key" : {
1314            "nullable" : false,
1315            "precision" : -1,
1316            "type" : "VARCHAR"
1317          },
1318          "name" : "s",
1319          "nullable" : true,
1320          "type" : "MAP",
1321          "value" : {
1322            "nullable" : true,
1323            "precision" : -1,
1324            "type" : "VARCHAR"
1325          }
1326        } ],
1327        "nullable" : true
1328      }
1329    }"#;
1330        let field: Field = serde_json::from_str(schema).unwrap();
1331        println!("field: {:#?}", field);
1332        assert_eq!(
1333            field,
1334            Field {
1335                name: SqlIdentifier {
1336                    name: "j".to_string(),
1337                    case_sensitive: false,
1338                },
1339                columntype: ColumnType {
1340                    typ: SqlType::Struct,
1341                    nullable: true,
1342                    precision: None,
1343                    scale: None,
1344                    component: None,
1345                    fields: Some(vec![Field {
1346                        name: SqlIdentifier {
1347                            name: "s".to_string(),
1348                            case_sensitive: false,
1349                        },
1350                        columntype: ColumnType {
1351                            typ: SqlType::Map,
1352                            nullable: true,
1353                            precision: None,
1354                            scale: None,
1355                            component: None,
1356                            fields: None,
1357                            key: Some(Box::new(ColumnType {
1358                                typ: SqlType::Varchar,
1359                                nullable: false,
1360                                precision: Some(-1),
1361                                scale: None,
1362                                component: None,
1363                                fields: None,
1364                                key: None,
1365                                value: None,
1366                            })),
1367                            value: Some(Box::new(ColumnType {
1368                                typ: SqlType::Varchar,
1369                                nullable: true,
1370                                precision: Some(-1),
1371                                scale: None,
1372                                component: None,
1373                                fields: None,
1374                                key: None,
1375                                value: None,
1376                            })),
1377                        },
1378                        lateness: None,
1379                        default: None,
1380                        watermark: None,
1381                    }]),
1382                    key: None,
1383                    value: None,
1384                },
1385                lateness: None,
1386                default: None,
1387                watermark: None,
1388            }
1389        );
1390    }
1391}