feldera_types/
program_schema.rs

1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::cmp::Ordering;
3use std::collections::BTreeMap;
4use std::fmt::Display;
5use std::hash::{Hash, Hasher};
6use utoipa::ToSchema;
7
8#[cfg(feature = "testing")]
9use proptest::{collection::vec, prelude::any};
10
/// Converts a SQL identifier, as written in a statement, to its canonical form.
///
/// - An identifier wrapped in double quotes is case-sensitive: the outer quotes
///   are removed and the inner text is returned untouched (in particular,
///   escaped quotes inside it are _not_ un-escaped).
/// - Any other identifier is case-insensitive and is folded to lowercase.
pub fn canonical_identifier(id: &str) -> String {
    // `Some(inner)` only when `id` has both an opening and a closing quote,
    // which also rules out the single-character string `"`.
    let unquoted = id
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));

    match unquoted {
        Some(inner) => inner.to_owned(),
        None => id.to_lowercase(),
    }
}
25
26/// An SQL identifier.
27///
28/// This struct is used to represent SQL identifiers in a canonical form.
29/// We store table names or field names as identifiers in the schema.
30#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
31#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
32pub struct SqlIdentifier {
33    #[cfg_attr(feature = "testing", proptest(regex = "relation1|relation2|relation3"))]
34    name: String,
35    pub case_sensitive: bool,
36}
37
38impl SqlIdentifier {
39    pub fn new<S: AsRef<str>>(name: S, case_sensitive: bool) -> Self {
40        Self {
41            name: name.as_ref().to_string(),
42            case_sensitive,
43        }
44    }
45
46    /// Return the name of the identifier in canonical form.
47    /// The result is the true case-sensitive identifying name of the table,
48    /// and can be used for example to detect duplicate table names.
49    ///
50    /// Example return values for this function:
51    /// - `CREATE TABLE t1` -> `t1`
52    /// - `CREATE TABLE T1` -> `t1`
53    /// - `CREATE TABLE "t1"` -> `t1`
54    /// - `CREATE TABLE "T1"` -> `T1`
55    pub fn name(&self) -> String {
56        if self.case_sensitive {
57            self.name.clone()
58        } else {
59            self.name.to_lowercase()
60        }
61    }
62
63    /// Return the name of the identifier as it appeared originally in SQL.
64    /// This method should only be used for log or error messages as it is what
65    /// the user originally wrote, however it should not be used for identification
66    /// or disambiguation (use `name()` for that instead).
67    ///
68    /// Example return values for this function:
69    /// - `CREATE TABLE t1` -> `t1`
70    /// - `CREATE TABLE T1` -> `T1`
71    /// - `CREATE TABLE "t1"` -> `"t1"`
72    /// - `CREATE TABLE "T1"` -> `"T1"`
73    pub fn sql_name(&self) -> String {
74        if self.case_sensitive {
75            format!("\"{}\"", self.name)
76        } else {
77            self.name.clone()
78        }
79    }
80}
81
82impl Hash for SqlIdentifier {
83    fn hash<H: Hasher>(&self, state: &mut H) {
84        self.name().hash(state);
85    }
86}
87
88impl PartialEq for SqlIdentifier {
89    fn eq(&self, other: &Self) -> bool {
90        match (self.case_sensitive, other.case_sensitive) {
91            (true, true) => self.name == other.name,
92            (false, false) => self.name.to_lowercase() == other.name.to_lowercase(),
93            (true, false) => self.name == other.name,
94            (false, true) => self.name == other.name,
95        }
96    }
97}
98
99impl Ord for SqlIdentifier {
100    fn cmp(&self, other: &Self) -> Ordering {
101        self.name().cmp(&other.name())
102    }
103}
104
105impl PartialOrd for SqlIdentifier {
106    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
107        Some(self.cmp(other))
108    }
109}
110
111impl<S: AsRef<str>> PartialEq<S> for SqlIdentifier {
112    fn eq(&self, other: &S) -> bool {
113        self == &SqlIdentifier::from(other.as_ref())
114    }
115}
116
117impl Eq for SqlIdentifier {}
118
119impl<S: AsRef<str>> From<S> for SqlIdentifier {
120    fn from(name: S) -> Self {
121        if name.as_ref().starts_with('"')
122            && name.as_ref().ends_with('"')
123            && name.as_ref().len() >= 2
124        {
125            Self {
126                name: name.as_ref()[1..name.as_ref().len() - 1].to_string(),
127                case_sensitive: true,
128            }
129        } else {
130            Self::new(name, false)
131        }
132    }
133}
134
135impl From<SqlIdentifier> for String {
136    fn from(id: SqlIdentifier) -> String {
137        id.name()
138    }
139}
140
141impl From<&SqlIdentifier> for String {
142    fn from(id: &SqlIdentifier) -> String {
143        id.name()
144    }
145}
146
147impl Display for SqlIdentifier {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        write!(f, "{}", self.name())
150    }
151}
152
153/// A struct containing the tables (inputs) and views for a program.
154///
155/// Parse from the JSON data-type of the DDL generated by the SQL compiler.
156#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
157#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
158pub struct ProgramSchema {
159    #[cfg_attr(
160        feature = "testing",
161        proptest(strategy = "vec(any::<Relation>(), 0..2)")
162    )]
163    pub inputs: Vec<Relation>,
164    #[cfg_attr(
165        feature = "testing",
166        proptest(strategy = "vec(any::<Relation>(), 0..2)")
167    )]
168    pub outputs: Vec<Relation>,
169}
170
171#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
172#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
173pub struct SourcePosition {
174    pub start_line_number: usize,
175    pub start_column: usize,
176    pub end_line_number: usize,
177    pub end_column: usize,
178}
179
180#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
181#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
182pub struct PropertyValue {
183    pub value: String,
184    pub key_position: SourcePosition,
185    pub value_position: SourcePosition,
186}
187
188/// A SQL table or view. It has a name and a list of fields.
189///
190/// Matches the Calcite JSON format.
191#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
192#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
193pub struct Relation {
194    // This field should only be accessed via the `name()` method.
195    #[serde(flatten)]
196    pub name: SqlIdentifier,
197    #[cfg_attr(feature = "testing", proptest(value = "Vec::new()"))]
198    pub fields: Vec<Field>,
199    #[serde(default)]
200    pub materialized: bool,
201    #[serde(default)]
202    pub properties: BTreeMap<String, PropertyValue>,
203}
204
205impl Relation {
206    pub fn empty() -> Self {
207        Self {
208            name: SqlIdentifier::from("".to_string()),
209            fields: Vec::new(),
210            materialized: false,
211            properties: BTreeMap::new(),
212        }
213    }
214
215    pub fn new(
216        name: SqlIdentifier,
217        fields: Vec<Field>,
218        materialized: bool,
219        properties: BTreeMap<String, PropertyValue>,
220    ) -> Self {
221        Self {
222            name,
223            fields,
224            materialized,
225            properties,
226        }
227    }
228
229    /// Lookup field by name.
230    pub fn field(&self, name: &str) -> Option<&Field> {
231        let name = canonical_identifier(name);
232        self.fields.iter().find(|f| f.name == name)
233    }
234}
235
236/// A SQL field.
237///
238/// Matches the SQL compiler JSON format.
239#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone)]
240#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
241pub struct Field {
242    #[serde(flatten)]
243    pub name: SqlIdentifier,
244    pub columntype: ColumnType,
245    pub lateness: Option<String>,
246    pub default: Option<String>,
247    pub unused: bool,
248    pub watermark: Option<String>,
249}
250
251impl Field {
252    pub fn new(name: SqlIdentifier, columntype: ColumnType) -> Self {
253        Self {
254            name,
255            columntype,
256            lateness: None,
257            default: None,
258            unused: false,
259            watermark: None,
260        }
261    }
262
263    pub fn with_lateness(mut self, lateness: &str) -> Self {
264        self.lateness = Some(lateness.to_string());
265        self
266    }
267
268    pub fn with_unused(mut self, unused: bool) -> Self {
269        self.unused = unused;
270        self
271    }
272}
273
274/// Thanks to the brain-dead Calcite schema, if we are deserializing a field, the type options
275/// end up inside the Field struct.
276///
277/// This helper struct is used to deserialize the Field struct.
278impl<'de> Deserialize<'de> for Field {
279    fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
280    where
281        D: Deserializer<'de>,
282    {
283        const fn default_is_struct() -> Option<SqlType> {
284            Some(SqlType::Struct)
285        }
286
287        #[derive(Debug, Clone, Deserialize)]
288        struct FieldHelper {
289            name: Option<String>,
290            #[serde(default)]
291            case_sensitive: bool,
292            columntype: Option<ColumnType>,
293            #[serde(rename = "type")]
294            #[serde(default = "default_is_struct")]
295            typ: Option<SqlType>,
296            nullable: Option<bool>,
297            precision: Option<i64>,
298            scale: Option<i64>,
299            component: Option<Box<ColumnType>>,
300            fields: Option<serde_json::Value>,
301            key: Option<Box<ColumnType>>,
302            value: Option<Box<ColumnType>>,
303            default: Option<String>,
304            #[serde(default)]
305            unused: bool,
306            lateness: Option<String>,
307            watermark: Option<String>,
308        }
309
310        fn helper_to_field(helper: FieldHelper) -> Field {
311            let columntype = if let Some(ctype) = helper.columntype {
312                ctype
313            } else if let Some(serde_json::Value::Array(fields)) = helper.fields {
314                let fields = fields
315                    .into_iter()
316                    .map(|field| {
317                        let field: FieldHelper = serde_json::from_value(field).unwrap();
318                        helper_to_field(field)
319                    })
320                    .collect::<Vec<Field>>();
321
322                ColumnType {
323                    typ: helper.typ.unwrap_or(SqlType::Null),
324                    nullable: helper.nullable.unwrap_or(false),
325                    precision: helper.precision,
326                    scale: helper.scale,
327                    component: helper.component,
328                    fields: Some(fields),
329                    key: None,
330                    value: None,
331                }
332            } else if let Some(serde_json::Value::Object(obj)) = helper.fields {
333                serde_json::from_value(serde_json::Value::Object(obj))
334                    .expect("Failed to deserialize object")
335            } else {
336                ColumnType {
337                    typ: helper.typ.unwrap_or(SqlType::Null),
338                    nullable: helper.nullable.unwrap_or(false),
339                    precision: helper.precision,
340                    scale: helper.scale,
341                    component: helper.component,
342                    fields: None,
343                    key: helper.key,
344                    value: helper.value,
345                }
346            };
347
348            Field {
349                name: SqlIdentifier::new(helper.name.unwrap(), helper.case_sensitive),
350                columntype,
351                default: helper.default,
352                unused: helper.unused,
353                lateness: helper.lateness,
354                watermark: helper.watermark,
355            }
356        }
357
358        let helper = FieldHelper::deserialize(deserializer)?;
359        Ok(helper_to_field(helper))
360    }
361}
362
363/// The specified units for SQL Interval types.
364///
365/// `INTERVAL 1 DAY`, `INTERVAL 1 DAY TO HOUR`, `INTERVAL 1 DAY TO MINUTE`,
366/// would yield `Day`, `DayToHour`, `DayToMinute`, as the `IntervalUnit` respectively.
367#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
368#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
369pub enum IntervalUnit {
370    /// Unit for `INTERVAL ... DAY`.
371    Day,
372    /// Unit for `INTERVAL ... DAY TO HOUR`.
373    DayToHour,
374    /// Unit for `INTERVAL ... DAY TO MINUTE`.
375    DayToMinute,
376    /// Unit for `INTERVAL ... DAY TO SECOND`.
377    DayToSecond,
378    /// Unit for `INTERVAL ... HOUR`.
379    Hour,
380    /// Unit for `INTERVAL ... HOUR TO MINUTE`.
381    HourToMinute,
382    /// Unit for `INTERVAL ... HOUR TO SECOND`.
383    HourToSecond,
384    /// Unit for `INTERVAL ... MINUTE`.
385    Minute,
386    /// Unit for `INTERVAL ... MINUTE TO SECOND`.
387    MinuteToSecond,
388    /// Unit for `INTERVAL ... MONTH`.
389    Month,
390    /// Unit for `INTERVAL ... SECOND`.
391    Second,
392    /// Unit for `INTERVAL ... YEAR`.
393    Year,
394    /// Unit for `INTERVAL ... YEAR TO MONTH`.
395    YearToMonth,
396}
397
398/// The available SQL types as specified in `CREATE` statements.
399#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
400#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
401pub enum SqlType {
402    /// SQL `BOOLEAN` type.
403    Boolean,
404    /// SQL `TINYINT` type.
405    TinyInt,
406    /// SQL `SMALLINT` or `INT2` type.
407    SmallInt,
408    /// SQL `INTEGER`, `INT`, `SIGNED`, `INT4` type.
409    Int,
410    /// SQL `BIGINT` or `INT64` type.
411    BigInt,
412    /// SQL `REAL` or `FLOAT4` or `FLOAT32` type.
413    Real,
414    /// SQL `DOUBLE` or `FLOAT8` or `FLOAT64` type.
415    Double,
416    /// SQL `DECIMAL` or `DEC` or `NUMERIC` type.
417    Decimal,
418    /// SQL `CHAR(n)` or `CHARACTER(n)` type.
419    Char,
420    /// SQL `VARCHAR`, `CHARACTER VARYING`, `TEXT`, or `STRING` type.
421    Varchar,
422    /// SQL `BINARY(n)` type.
423    Binary,
424    /// SQL `VARBINARY` or `BYTEA` type.
425    Varbinary,
426    /// SQL `TIME` type.
427    Time,
428    /// SQL `DATE` type.
429    Date,
430    /// SQL `TIMESTAMP` type.
431    Timestamp,
432    /// SQL `INTERVAL ... X` type where `X` is a unit.
433    Interval(IntervalUnit),
434    /// SQL `ARRAY` type.
435    Array,
436    /// A complex SQL struct type (`CREATE TYPE x ...`).
437    Struct,
438    /// SQL `MAP` type.
439    Map,
440    /// SQL `NULL` type.
441    Null,
442    /// SQL `UUID` type.
443    Uuid,
444    /// SQL `VARIANT` type.
445    Variant,
446}
447
448impl Display for SqlType {
449    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
450        f.write_str(&serde_json::to_string(self).unwrap())
451    }
452}
453
454impl<'de> Deserialize<'de> for SqlType {
455    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
456    where
457        D: Deserializer<'de>,
458    {
459        let value: String = Deserialize::deserialize(deserializer)?;
460        match value.to_lowercase().as_str() {
461            "interval_day" => Ok(SqlType::Interval(IntervalUnit::Day)),
462            "interval_day_hour" => Ok(SqlType::Interval(IntervalUnit::DayToHour)),
463            "interval_day_minute" => Ok(SqlType::Interval(IntervalUnit::DayToMinute)),
464            "interval_day_second" => Ok(SqlType::Interval(IntervalUnit::DayToSecond)),
465            "interval_hour" => Ok(SqlType::Interval(IntervalUnit::Hour)),
466            "interval_hour_minute" => Ok(SqlType::Interval(IntervalUnit::HourToMinute)),
467            "interval_hour_second" => Ok(SqlType::Interval(IntervalUnit::HourToSecond)),
468            "interval_minute" => Ok(SqlType::Interval(IntervalUnit::Minute)),
469            "interval_minute_second" => Ok(SqlType::Interval(IntervalUnit::MinuteToSecond)),
470            "interval_month" => Ok(SqlType::Interval(IntervalUnit::Month)),
471            "interval_second" => Ok(SqlType::Interval(IntervalUnit::Second)),
472            "interval_year" => Ok(SqlType::Interval(IntervalUnit::Year)),
473            "interval_year_month" => Ok(SqlType::Interval(IntervalUnit::YearToMonth)),
474            "boolean" => Ok(SqlType::Boolean),
475            "tinyint" => Ok(SqlType::TinyInt),
476            "smallint" => Ok(SqlType::SmallInt),
477            "integer" => Ok(SqlType::Int),
478            "bigint" => Ok(SqlType::BigInt),
479            "real" => Ok(SqlType::Real),
480            "double" => Ok(SqlType::Double),
481            "decimal" => Ok(SqlType::Decimal),
482            "char" => Ok(SqlType::Char),
483            "varchar" => Ok(SqlType::Varchar),
484            "binary" => Ok(SqlType::Binary),
485            "varbinary" => Ok(SqlType::Varbinary),
486            "variant" => Ok(SqlType::Variant),
487            "time" => Ok(SqlType::Time),
488            "date" => Ok(SqlType::Date),
489            "timestamp" => Ok(SqlType::Timestamp),
490            "array" => Ok(SqlType::Array),
491            "struct" => Ok(SqlType::Struct),
492            "map" => Ok(SqlType::Map),
493            "null" => Ok(SqlType::Null),
494            "uuid" => Ok(SqlType::Uuid),
495            _ => Err(serde::de::Error::custom(format!(
496                "Unknown SQL type: {}",
497                value
498            ))),
499        }
500    }
501}
502
503impl Serialize for SqlType {
504    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
505    where
506        S: Serializer,
507    {
508        let type_str = match self {
509            SqlType::Boolean => "BOOLEAN",
510            SqlType::TinyInt => "TINYINT",
511            SqlType::SmallInt => "SMALLINT",
512            SqlType::Int => "INTEGER",
513            SqlType::BigInt => "BIGINT",
514            SqlType::Real => "REAL",
515            SqlType::Double => "DOUBLE",
516            SqlType::Decimal => "DECIMAL",
517            SqlType::Char => "CHAR",
518            SqlType::Varchar => "VARCHAR",
519            SqlType::Binary => "BINARY",
520            SqlType::Varbinary => "VARBINARY",
521            SqlType::Time => "TIME",
522            SqlType::Date => "DATE",
523            SqlType::Timestamp => "TIMESTAMP",
524            SqlType::Interval(interval_unit) => match interval_unit {
525                IntervalUnit::Day => "INTERVAL_DAY",
526                IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
527                IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
528                IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
529                IntervalUnit::Hour => "INTERVAL_HOUR",
530                IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
531                IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
532                IntervalUnit::Minute => "INTERVAL_MINUTE",
533                IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
534                IntervalUnit::Month => "INTERVAL_MONTH",
535                IntervalUnit::Second => "INTERVAL_SECOND",
536                IntervalUnit::Year => "INTERVAL_YEAR",
537                IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
538            },
539            SqlType::Array => "ARRAY",
540            SqlType::Struct => "STRUCT",
541            SqlType::Uuid => "UUID",
542            SqlType::Map => "MAP",
543            SqlType::Null => "NULL",
544            SqlType::Variant => "VARIANT",
545        };
546        serializer.serialize_str(type_str)
547    }
548}
549
550impl SqlType {
551    /// Is this a string type?
552    pub fn is_string(&self) -> bool {
553        matches!(self, Self::Char | Self::Varchar)
554    }
555
556    pub fn is_varchar(&self) -> bool {
557        matches!(self, Self::Varchar)
558    }
559
560    pub fn is_varbinary(&self) -> bool {
561        matches!(self, Self::Varbinary)
562    }
563}
564
565/// It so happens that when the type field is missing in the Calcite schema, it's a struct,
566/// so we use it as the default.
567const fn default_is_struct() -> SqlType {
568    SqlType::Struct
569}
570
571/// A SQL column type description.
572///
573/// Matches the Calcite JSON format.
574#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
575#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
576pub struct ColumnType {
577    /// Identifier for the type (e.g., `VARCHAR`, `BIGINT`, `ARRAY` etc.)
578    #[serde(rename = "type")]
579    #[serde(default = "default_is_struct")]
580    pub typ: SqlType,
581    /// Does the type accept NULL values?
582    pub nullable: bool,
583    /// Precision of the type.
584    ///
585    /// # Examples
586    /// - `VARCHAR` sets precision to `-1`.
587    /// - `VARCHAR(255)` sets precision to `255`.
588    /// - `BIGINT`, `DATE`, `FLOAT`, `DOUBLE`, `GEOMETRY`, etc. sets precision
589    ///   to None
590    /// - `TIME`, `TIMESTAMP` set precision to `0`.
591    pub precision: Option<i64>,
592    /// The scale of the type.
593    ///
594    /// # Example
595    /// - `DECIMAL(1,2)` sets scale to `2`.
596    pub scale: Option<i64>,
597    /// A component of the type (if available).
598    ///
599    /// This is in a `Box` because it makes it a recursive types.
600    ///
601    /// For example, this would specify the `VARCHAR(20)` in the `VARCHAR(20)
602    /// ARRAY` type.
603    #[cfg_attr(feature = "testing", proptest(value = "None"))]
604    pub component: Option<Box<ColumnType>>,
605    /// The fields of the type (if available).
606    ///
607    /// For example this would specify the fields of a `CREATE TYPE` construct.
608    ///
609    /// ```sql
610    /// CREATE TYPE person_typ AS (
611    ///   firstname       VARCHAR(30),
612    ///   lastname        VARCHAR(30),
613    ///   address         ADDRESS_TYP
614    /// );
615    /// ```
616    ///
617    /// Would lead to the following `fields` value:
618    ///
619    /// ```sql
620    /// [
621    ///  ColumnType { name: "firstname, ... },
622    ///  ColumnType { name: "lastname", ... },
623    ///  ColumnType { name: "address", fields: [ ... ] }
624    /// ]
625    /// ```
626    #[cfg_attr(feature = "testing", proptest(value = "Some(Vec::new())"))]
627    pub fields: Option<Vec<Field>>,
628    /// Key type; must be set when `type == "MAP"`.
629    #[cfg_attr(feature = "testing", proptest(value = "None"))]
630    pub key: Option<Box<ColumnType>>,
631    /// Value type; must be set when `type == "MAP"`.
632    #[cfg_attr(feature = "testing", proptest(value = "None"))]
633    pub value: Option<Box<ColumnType>>,
634}
635
636impl ColumnType {
637    pub fn boolean(nullable: bool) -> Self {
638        ColumnType {
639            typ: SqlType::Boolean,
640            nullable,
641            precision: None,
642            scale: None,
643            component: None,
644            fields: None,
645            key: None,
646            value: None,
647        }
648    }
649
650    pub fn uuid(nullable: bool) -> Self {
651        ColumnType {
652            typ: SqlType::Uuid,
653            nullable,
654            precision: None,
655            scale: None,
656            component: None,
657            fields: None,
658            key: None,
659            value: None,
660        }
661    }
662
663    pub fn tinyint(nullable: bool) -> Self {
664        ColumnType {
665            typ: SqlType::TinyInt,
666            nullable,
667            precision: None,
668            scale: None,
669            component: None,
670            fields: None,
671            key: None,
672            value: None,
673        }
674    }
675
676    pub fn smallint(nullable: bool) -> Self {
677        ColumnType {
678            typ: SqlType::SmallInt,
679            nullable,
680            precision: None,
681            scale: None,
682            component: None,
683            fields: None,
684            key: None,
685            value: None,
686        }
687    }
688
689    pub fn int(nullable: bool) -> Self {
690        ColumnType {
691            typ: SqlType::Int,
692            nullable,
693            precision: None,
694            scale: None,
695            component: None,
696            fields: None,
697            key: None,
698            value: None,
699        }
700    }
701
702    pub fn bigint(nullable: bool) -> Self {
703        ColumnType {
704            typ: SqlType::BigInt,
705            nullable,
706            precision: None,
707            scale: None,
708            component: None,
709            fields: None,
710            key: None,
711            value: None,
712        }
713    }
714
715    pub fn double(nullable: bool) -> Self {
716        ColumnType {
717            typ: SqlType::Double,
718            nullable,
719            precision: None,
720            scale: None,
721            component: None,
722            fields: None,
723            key: None,
724            value: None,
725        }
726    }
727
728    pub fn real(nullable: bool) -> Self {
729        ColumnType {
730            typ: SqlType::Real,
731            nullable,
732            precision: None,
733            scale: None,
734            component: None,
735            fields: None,
736            key: None,
737            value: None,
738        }
739    }
740
741    pub fn decimal(precision: i64, scale: i64, nullable: bool) -> Self {
742        ColumnType {
743            typ: SqlType::Decimal,
744            nullable,
745            precision: Some(precision),
746            scale: Some(scale),
747            component: None,
748            fields: None,
749            key: None,
750            value: None,
751        }
752    }
753
754    pub fn varchar(nullable: bool) -> Self {
755        ColumnType {
756            typ: SqlType::Varchar,
757            nullable,
758            precision: None,
759            scale: None,
760            component: None,
761            fields: None,
762            key: None,
763            value: None,
764        }
765    }
766
767    pub fn varbinary(nullable: bool) -> Self {
768        ColumnType {
769            typ: SqlType::Varbinary,
770            nullable,
771            precision: None,
772            scale: None,
773            component: None,
774            fields: None,
775            key: None,
776            value: None,
777        }
778    }
779
780    pub fn fixed(width: i64, nullable: bool) -> Self {
781        ColumnType {
782            typ: SqlType::Binary,
783            nullable,
784            precision: Some(width),
785            scale: None,
786            component: None,
787            fields: None,
788            key: None,
789            value: None,
790        }
791    }
792
793    pub fn date(nullable: bool) -> Self {
794        ColumnType {
795            typ: SqlType::Date,
796            nullable,
797            precision: None,
798            scale: None,
799            component: None,
800            fields: None,
801            key: None,
802            value: None,
803        }
804    }
805
806    pub fn time(nullable: bool) -> Self {
807        ColumnType {
808            typ: SqlType::Time,
809            nullable,
810            precision: None,
811            scale: None,
812            component: None,
813            fields: None,
814            key: None,
815            value: None,
816        }
817    }
818
819    pub fn timestamp(nullable: bool) -> Self {
820        ColumnType {
821            typ: SqlType::Timestamp,
822            nullable,
823            precision: None,
824            scale: None,
825            component: None,
826            fields: None,
827            key: None,
828            value: None,
829        }
830    }
831
832    pub fn variant(nullable: bool) -> Self {
833        ColumnType {
834            typ: SqlType::Variant,
835            nullable,
836            precision: None,
837            scale: None,
838            component: None,
839            fields: None,
840            key: None,
841            value: None,
842        }
843    }
844
845    pub fn array(nullable: bool, element: ColumnType) -> Self {
846        ColumnType {
847            typ: SqlType::Array,
848            nullable,
849            precision: None,
850            scale: None,
851            component: Some(Box::new(element)),
852            fields: None,
853            key: None,
854            value: None,
855        }
856    }
857
858    pub fn structure(nullable: bool, fields: &[Field]) -> Self {
859        ColumnType {
860            typ: SqlType::Struct,
861            nullable,
862            precision: None,
863            scale: None,
864            component: None,
865            fields: Some(fields.to_vec()),
866            key: None,
867            value: None,
868        }
869    }
870
871    pub fn map(nullable: bool, key: ColumnType, val: ColumnType) -> Self {
872        ColumnType {
873            typ: SqlType::Map,
874            nullable,
875            precision: None,
876            scale: None,
877            component: None,
878            fields: None,
879            key: Some(Box::new(key)),
880            value: Some(Box::new(val)),
881        }
882    }
883
884    pub fn is_integral_type(&self) -> bool {
885        matches!(
886            &self.typ,
887            SqlType::TinyInt | SqlType::SmallInt | SqlType::Int | SqlType::BigInt
888        )
889    }
890
891    pub fn is_fp_type(&self) -> bool {
892        matches!(&self.typ, SqlType::Double | SqlType::Real)
893    }
894
895    pub fn is_decimal_type(&self) -> bool {
896        matches!(&self.typ, SqlType::Decimal)
897    }
898
899    pub fn is_numeric_type(&self) -> bool {
900        self.is_integral_type() || self.is_fp_type() || self.is_decimal_type()
901    }
902}
903
904#[cfg(test)]
905mod tests {
906    use super::{IntervalUnit, SqlIdentifier};
907    use crate::program_schema::{ColumnType, Field, SqlType};
908
909    #[test]
910    fn serde_sql_type() {
911        for (sql_str_base, expected_value) in [
912            ("Boolean", SqlType::Boolean),
913            ("Uuid", SqlType::Uuid),
914            ("TinyInt", SqlType::TinyInt),
915            ("SmallInt", SqlType::SmallInt),
916            ("Integer", SqlType::Int),
917            ("BigInt", SqlType::BigInt),
918            ("Real", SqlType::Real),
919            ("Double", SqlType::Double),
920            ("Decimal", SqlType::Decimal),
921            ("Char", SqlType::Char),
922            ("Varchar", SqlType::Varchar),
923            ("Binary", SqlType::Binary),
924            ("Varbinary", SqlType::Varbinary),
925            ("Time", SqlType::Time),
926            ("Date", SqlType::Date),
927            ("Timestamp", SqlType::Timestamp),
928            ("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
929            (
930                "Interval_Day_Hour",
931                SqlType::Interval(IntervalUnit::DayToHour),
932            ),
933            (
934                "Interval_Day_Minute",
935                SqlType::Interval(IntervalUnit::DayToMinute),
936            ),
937            (
938                "Interval_Day_Second",
939                SqlType::Interval(IntervalUnit::DayToSecond),
940            ),
941            ("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
942            (
943                "Interval_Hour_Minute",
944                SqlType::Interval(IntervalUnit::HourToMinute),
945            ),
946            (
947                "Interval_Hour_Second",
948                SqlType::Interval(IntervalUnit::HourToSecond),
949            ),
950            ("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
951            (
952                "Interval_Minute_Second",
953                SqlType::Interval(IntervalUnit::MinuteToSecond),
954            ),
955            ("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
956            ("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
957            ("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
958            (
959                "Interval_Year_Month",
960                SqlType::Interval(IntervalUnit::YearToMonth),
961            ),
962            ("Array", SqlType::Array),
963            ("Struct", SqlType::Struct),
964            ("Map", SqlType::Map),
965            ("Null", SqlType::Null),
966            ("Variant", SqlType::Variant),
967        ] {
968            for sql_str in [
969                sql_str_base,                 // Capitalized
970                &sql_str_base.to_lowercase(), // lowercase
971                &sql_str_base.to_uppercase(), // UPPERCASE
972            ] {
973                let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str))
974                    .unwrap_or_else(|_| {
975                        panic!("\"{sql_str}\" should deserialize into its SQL type")
976                    });
977                assert_eq!(value1, expected_value);
978                let serialized_str =
979                    serde_json::to_string(&value1).expect("Value should serialize into JSON");
980                let value2: SqlType = serde_json::from_str(&serialized_str).unwrap_or_else(|_| {
981                    panic!(
982                        "{} should deserialize back into its SQL type",
983                        serialized_str
984                    )
985                });
986                assert_eq!(value1, value2);
987            }
988        }
989    }
990
991    #[test]
992    fn deserialize_interval_types() {
993        use super::IntervalUnit::*;
994        use super::SqlType::*;
995
996        let schema = r#"
997{
998  "inputs" : [ {
999    "name" : "sales",
1000    "case_sensitive" : false,
1001    "fields" : [ {
1002      "name" : "sales_id",
1003      "case_sensitive" : false,
1004      "columntype" : {
1005        "type" : "INTEGER",
1006        "nullable" : true
1007      }
1008    }, {
1009      "name" : "customer_id",
1010      "case_sensitive" : false,
1011      "columntype" : {
1012        "type" : "INTEGER",
1013        "nullable" : true
1014      }
1015    }, {
1016      "name" : "amount",
1017      "case_sensitive" : false,
1018      "columntype" : {
1019        "type" : "DECIMAL",
1020        "nullable" : true,
1021        "precision" : 10,
1022        "scale" : 2
1023      }
1024    }, {
1025      "name" : "sale_date",
1026      "case_sensitive" : false,
1027      "columntype" : {
1028        "type" : "DATE",
1029        "nullable" : true
1030      }
1031    } ],
1032    "primary_key" : [ "sales_id" ]
1033  } ],
1034  "outputs" : [ {
1035    "name" : "salessummary",
1036    "case_sensitive" : false,
1037    "fields" : [ {
1038      "name" : "customer_id",
1039      "case_sensitive" : false,
1040      "columntype" : {
1041        "type" : "INTEGER",
1042        "nullable" : true
1043      }
1044    }, {
1045      "name" : "total_sales",
1046      "case_sensitive" : false,
1047      "columntype" : {
1048        "type" : "DECIMAL",
1049        "nullable" : true,
1050        "precision" : 38,
1051        "scale" : 2
1052      }
1053    }, {
1054      "name" : "interval_day",
1055      "case_sensitive" : false,
1056      "columntype" : {
1057        "type" : "INTERVAL_DAY",
1058        "nullable" : false,
1059        "precision" : 2,
1060        "scale" : 6
1061      }
1062    }, {
1063      "name" : "interval_day_to_hour",
1064      "case_sensitive" : false,
1065      "columntype" : {
1066        "type" : "INTERVAL_DAY_HOUR",
1067        "nullable" : false,
1068        "precision" : 2,
1069        "scale" : 6
1070      }
1071    }, {
1072      "name" : "interval_day_to_minute",
1073      "case_sensitive" : false,
1074      "columntype" : {
1075        "type" : "INTERVAL_DAY_MINUTE",
1076        "nullable" : false,
1077        "precision" : 2,
1078        "scale" : 6
1079      }
1080    }, {
1081      "name" : "interval_day_to_second",
1082      "case_sensitive" : false,
1083      "columntype" : {
1084        "type" : "INTERVAL_DAY_SECOND",
1085        "nullable" : false,
1086        "precision" : 2,
1087        "scale" : 6
1088      }
1089    }, {
1090      "name" : "interval_hour",
1091      "case_sensitive" : false,
1092      "columntype" : {
1093        "type" : "INTERVAL_HOUR",
1094        "nullable" : false,
1095        "precision" : 2,
1096        "scale" : 6
1097      }
1098    }, {
1099      "name" : "interval_hour_to_minute",
1100      "case_sensitive" : false,
1101      "columntype" : {
1102        "type" : "INTERVAL_HOUR_MINUTE",
1103        "nullable" : false,
1104        "precision" : 2,
1105        "scale" : 6
1106      }
1107    }, {
1108      "name" : "interval_hour_to_second",
1109      "case_sensitive" : false,
1110      "columntype" : {
1111        "type" : "INTERVAL_HOUR_SECOND",
1112        "nullable" : false,
1113        "precision" : 2,
1114        "scale" : 6
1115      }
1116    }, {
1117      "name" : "interval_minute",
1118      "case_sensitive" : false,
1119      "columntype" : {
1120        "type" : "INTERVAL_MINUTE",
1121        "nullable" : false,
1122        "precision" : 2,
1123        "scale" : 6
1124      }
1125    }, {
1126      "name" : "interval_minute_to_second",
1127      "case_sensitive" : false,
1128      "columntype" : {
1129        "type" : "INTERVAL_MINUTE_SECOND",
1130        "nullable" : false,
1131        "precision" : 2,
1132        "scale" : 6
1133      }
1134    }, {
1135      "name" : "interval_month",
1136      "case_sensitive" : false,
1137      "columntype" : {
1138        "type" : "INTERVAL_MONTH",
1139        "nullable" : false
1140      }
1141    }, {
1142      "name" : "interval_second",
1143      "case_sensitive" : false,
1144      "columntype" : {
1145        "type" : "INTERVAL_SECOND",
1146        "nullable" : false,
1147        "precision" : 2,
1148        "scale" : 6
1149      }
1150    }, {
1151      "name" : "interval_year",
1152      "case_sensitive" : false,
1153      "columntype" : {
1154        "type" : "INTERVAL_YEAR",
1155        "nullable" : false
1156      }
1157    }, {
1158      "name" : "interval_year_to_month",
1159      "case_sensitive" : false,
1160      "columntype" : {
1161        "type" : "INTERVAL_YEAR_MONTH",
1162        "nullable" : false
1163      }
1164    } ]
1165  } ]
1166}
1167"#;
1168
1169        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
1170        let types = schema
1171            .outputs
1172            .iter()
1173            .flat_map(|r| r.fields.iter().map(|f| f.columntype.typ));
1174        let expected_types = [
1175            Int,
1176            Decimal,
1177            Interval(Day),
1178            Interval(DayToHour),
1179            Interval(DayToMinute),
1180            Interval(DayToSecond),
1181            Interval(Hour),
1182            Interval(HourToMinute),
1183            Interval(HourToSecond),
1184            Interval(Minute),
1185            Interval(MinuteToSecond),
1186            Interval(Month),
1187            Interval(Second),
1188            Interval(Year),
1189            Interval(YearToMonth),
1190        ];
1191
1192        assert_eq!(types.collect::<Vec<_>>(), &expected_types);
1193    }
1194
1195    #[test]
1196    fn serialize_struct_schemas() {
1197        let schema = r#"{
1198  "inputs" : [ {
1199    "name" : "PERS",
1200    "case_sensitive" : false,
1201    "fields" : [ {
1202      "name" : "P0",
1203      "case_sensitive" : false,
1204      "columntype" : {
1205        "fields" : [ {
1206          "type" : "VARCHAR",
1207          "nullable" : true,
1208          "precision" : 30,
1209          "name" : "FIRSTNAME"
1210        }, {
1211          "type" : "VARCHAR",
1212          "nullable" : true,
1213          "precision" : 30,
1214          "name" : "LASTNAME"
1215        }, {
1216          "fields" : {
1217            "fields" : [ {
1218              "type" : "VARCHAR",
1219              "nullable" : true,
1220              "precision" : 30,
1221              "name" : "STREET"
1222            }, {
1223              "type" : "VARCHAR",
1224              "nullable" : true,
1225              "precision" : 30,
1226              "name" : "CITY"
1227            }, {
1228              "type" : "CHAR",
1229              "nullable" : true,
1230              "precision" : 2,
1231              "name" : "STATE"
1232            }, {
1233              "type" : "VARCHAR",
1234              "nullable" : true,
1235              "precision" : 6,
1236              "name" : "POSTAL_CODE"
1237            } ],
1238            "nullable" : false
1239          },
1240          "nullable" : false,
1241          "name" : "ADDRESS"
1242        } ],
1243        "nullable" : false
1244      }
1245    }]
1246  } ],
1247  "outputs" : [ ]
1248}
1249"#;
1250        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
1251        eprintln!("{:#?}", schema);
1252        let pers = schema.inputs.iter().find(|r| r.name == "PERS").unwrap();
1253        let p0 = pers.fields.iter().find(|f| f.name == "P0").unwrap();
1254        assert_eq!(p0.columntype.typ, SqlType::Struct);
1255        let p0_fields = p0.columntype.fields.as_ref().unwrap();
1256        assert_eq!(p0_fields[0].columntype.typ, SqlType::Varchar);
1257        assert_eq!(p0_fields[1].columntype.typ, SqlType::Varchar);
1258        assert_eq!(p0_fields[2].columntype.typ, SqlType::Struct);
1259        assert_eq!(p0_fields[2].name, "ADDRESS");
1260        let address = &p0_fields[2].columntype.fields.as_ref().unwrap();
1261        assert_eq!(address.len(), 4);
1262        assert_eq!(address[0].name, "STREET");
1263        assert_eq!(address[0].columntype.typ, SqlType::Varchar);
1264        assert_eq!(address[1].columntype.typ, SqlType::Varchar);
1265        assert_eq!(address[2].columntype.typ, SqlType::Char);
1266        assert_eq!(address[3].columntype.typ, SqlType::Varchar);
1267    }
1268
1269    #[test]
1270    fn sql_identifier_cmp() {
1271        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("foo"));
1272        assert_ne!(SqlIdentifier::from("foo"), SqlIdentifier::from("bar"));
1273        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("BAR"));
1274        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("\"foo\""));
1275        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("\"bar\""));
1276        assert_eq!(SqlIdentifier::from("bAr"), SqlIdentifier::from("\"bAr\""));
1277        assert_eq!(
1278            SqlIdentifier::new("bAr", true),
1279            SqlIdentifier::from("\"bAr\"")
1280        );
1281
1282        assert_eq!(SqlIdentifier::from("bAr"), "bar");
1283        assert_eq!(SqlIdentifier::from("bAr"), "bAr");
1284    }
1285
1286    #[test]
1287    fn sql_identifier_ord() {
1288        let mut btree = std::collections::BTreeSet::new();
1289        assert!(btree.insert(SqlIdentifier::from("foo")));
1290        assert!(btree.insert(SqlIdentifier::from("bar")));
1291        assert!(!btree.insert(SqlIdentifier::from("BAR")));
1292        assert!(!btree.insert(SqlIdentifier::from("\"foo\"")));
1293        assert!(!btree.insert(SqlIdentifier::from("\"bar\"")));
1294    }
1295
1296    #[test]
1297    fn sql_identifier_hash() {
1298        let mut hs = std::collections::HashSet::new();
1299        assert!(hs.insert(SqlIdentifier::from("foo")));
1300        assert!(hs.insert(SqlIdentifier::from("bar")));
1301        assert!(!hs.insert(SqlIdentifier::from("BAR")));
1302        assert!(!hs.insert(SqlIdentifier::from("\"foo\"")));
1303        assert!(!hs.insert(SqlIdentifier::from("\"bar\"")));
1304    }
1305
1306    #[test]
1307    fn sql_identifier_name() {
1308        assert_eq!(SqlIdentifier::from("foo").name(), "foo");
1309        assert_eq!(SqlIdentifier::from("bAr").name(), "bar");
1310        assert_eq!(SqlIdentifier::from("\"bAr\"").name(), "bAr");
1311        assert_eq!(SqlIdentifier::from("foo").sql_name(), "foo");
1312        assert_eq!(SqlIdentifier::from("bAr").sql_name(), "bAr");
1313        assert_eq!(SqlIdentifier::from("\"bAr\"").sql_name(), "\"bAr\"");
1314    }
1315
1316    #[test]
1317    fn issue3277() {
1318        let schema = r#"{
1319      "name" : "j",
1320      "case_sensitive" : false,
1321      "columntype" : {
1322        "fields" : [ {
1323          "key" : {
1324            "nullable" : false,
1325            "precision" : -1,
1326            "type" : "VARCHAR"
1327          },
1328          "name" : "s",
1329          "nullable" : true,
1330          "type" : "MAP",
1331          "value" : {
1332            "nullable" : true,
1333            "precision" : -1,
1334            "type" : "VARCHAR"
1335          }
1336        } ],
1337        "nullable" : true
1338      }
1339    }"#;
1340        let field: Field = serde_json::from_str(schema).unwrap();
1341        println!("field: {:#?}", field);
1342        assert_eq!(
1343            field,
1344            Field {
1345                name: SqlIdentifier {
1346                    name: "j".to_string(),
1347                    case_sensitive: false,
1348                },
1349                columntype: ColumnType {
1350                    typ: SqlType::Struct,
1351                    nullable: true,
1352                    precision: None,
1353                    scale: None,
1354                    component: None,
1355                    fields: Some(vec![Field {
1356                        name: SqlIdentifier {
1357                            name: "s".to_string(),
1358                            case_sensitive: false,
1359                        },
1360                        columntype: ColumnType {
1361                            typ: SqlType::Map,
1362                            nullable: true,
1363                            precision: None,
1364                            scale: None,
1365                            component: None,
1366                            fields: None,
1367                            key: Some(Box::new(ColumnType {
1368                                typ: SqlType::Varchar,
1369                                nullable: false,
1370                                precision: Some(-1),
1371                                scale: None,
1372                                component: None,
1373                                fields: None,
1374                                key: None,
1375                                value: None,
1376                            })),
1377                            value: Some(Box::new(ColumnType {
1378                                typ: SqlType::Varchar,
1379                                nullable: true,
1380                                precision: Some(-1),
1381                                scale: None,
1382                                component: None,
1383                                fields: None,
1384                                key: None,
1385                                value: None,
1386                            })),
1387                        },
1388                        lateness: None,
1389                        default: None,
1390                        unused: false,
1391                        watermark: None,
1392                    }]),
1393                    key: None,
1394                    value: None,
1395                },
1396                lateness: None,
1397                default: None,
1398                unused: false,
1399                watermark: None,
1400            }
1401        );
1402    }
1403}