1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::cmp::Ordering;
3use std::collections::BTreeMap;
4use std::fmt::Display;
5use std::hash::{Hash, Hasher};
6use utoipa::ToSchema;
7
8#[cfg(feature = "testing")]
9use proptest::{collection::vec, prelude::any};
10
/// Returns the canonical form of a SQL identifier as written in SQL text.
///
/// A double-quoted identifier (e.g. `"Foo"`) is case-sensitive: the surrounding
/// quotes are removed and the inner text is returned unchanged. Anything else
/// (including a lone `"`) is case-insensitive and is returned lowercased.
pub fn canonical_identifier(id: &str) -> String {
    match id.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner.to_owned(),
        None => id.to_lowercase(),
    }
}
25
/// A SQL identifier (for example a table or field name) together with its
/// case-sensitivity.
///
/// `name` holds the identifier text as given; when built via `From`, a
/// double-quoted input has its quotes removed and is marked case-sensitive.
/// Hashing and ordering use the canonical form returned by
/// [`SqlIdentifier::name`].
#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SqlIdentifier {
    // Under the `testing` feature, proptest only generates one of these fixed names.
    #[cfg_attr(feature = "testing", proptest(regex = "relation1|relation2|relation3"))]
    name: String,
    /// Whether the identifier is case-sensitive (e.g. it was double-quoted in SQL).
    pub case_sensitive: bool,
}
37
38impl SqlIdentifier {
39 pub fn new<S: AsRef<str>>(name: S, case_sensitive: bool) -> Self {
40 Self {
41 name: name.as_ref().to_string(),
42 case_sensitive,
43 }
44 }
45
46 pub fn name(&self) -> String {
56 if self.case_sensitive {
57 self.name.clone()
58 } else {
59 self.name.to_lowercase()
60 }
61 }
62
63 pub fn sql_name(&self) -> String {
74 if self.case_sensitive {
75 format!("\"{}\"", self.name)
76 } else {
77 self.name.clone()
78 }
79 }
80}
81
82impl Hash for SqlIdentifier {
83 fn hash<H: Hasher>(&self, state: &mut H) {
84 self.name().hash(state);
85 }
86}
87
88impl PartialEq for SqlIdentifier {
89 fn eq(&self, other: &Self) -> bool {
90 match (self.case_sensitive, other.case_sensitive) {
91 (true, true) => self.name == other.name,
92 (false, false) => self.name.to_lowercase() == other.name.to_lowercase(),
93 (true, false) => self.name == other.name,
94 (false, true) => self.name == other.name,
95 }
96 }
97}
98
99impl Ord for SqlIdentifier {
100 fn cmp(&self, other: &Self) -> Ordering {
101 self.name().cmp(&other.name())
102 }
103}
104
105impl PartialOrd for SqlIdentifier {
106 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
107 Some(self.cmp(other))
108 }
109}
110
111impl<S: AsRef<str>> PartialEq<S> for SqlIdentifier {
112 fn eq(&self, other: &S) -> bool {
113 self == &SqlIdentifier::from(other.as_ref())
114 }
115}
116
117impl Eq for SqlIdentifier {}
118
119impl<S: AsRef<str>> From<S> for SqlIdentifier {
120 fn from(name: S) -> Self {
121 if name.as_ref().starts_with('"')
122 && name.as_ref().ends_with('"')
123 && name.as_ref().len() >= 2
124 {
125 Self {
126 name: name.as_ref()[1..name.as_ref().len() - 1].to_string(),
127 case_sensitive: true,
128 }
129 } else {
130 Self::new(name, false)
131 }
132 }
133}
134
135impl From<SqlIdentifier> for String {
136 fn from(id: SqlIdentifier) -> String {
137 id.name()
138 }
139}
140
141impl From<&SqlIdentifier> for String {
142 fn from(id: &SqlIdentifier) -> String {
143 id.name()
144 }
145}
146
147impl Display for SqlIdentifier {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 write!(f, "{}", self.name())
150 }
151}
152
/// The relations a SQL program reads from and writes to.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ProgramSchema {
    /// Input relations of the program.
    // Under the `testing` feature, proptest generates 0 or 1 relations (range 0..2).
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub inputs: Vec<Relation>,
    /// Output relations of the program.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub outputs: Vec<Relation>,
}

/// A span in SQL source text given by its start and end line/column.
// NOTE(review): whether lines/columns are 0- or 1-based is not visible here — confirm.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SourcePosition {
    pub start_line_number: usize,
    pub start_column: usize,
    pub end_line_number: usize,
    pub end_column: usize,
}

/// The value of a relation property, with the source positions of both its key
/// and its value.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct PropertyValue {
    /// The property value as a string.
    pub value: String,
    /// Where the property key appears in the source.
    pub key_position: SourcePosition,
    /// Where the property value appears in the source.
    pub value_position: SourcePosition,
}
187
/// A relation (e.g. a table or view) of a SQL program: its name, fields and
/// properties.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Relation {
    /// Name of the relation. Flattened, so `name` and `case_sensitive` appear as
    /// top-level keys of the relation's JSON object.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// The relation's fields, in order.
    // Under the `testing` feature, proptest always generates an empty field list.
    #[cfg_attr(feature = "testing", proptest(value = "Vec::new()"))]
    pub fields: Vec<Field>,
    /// Materialization flag; `false` when absent from the JSON.
    #[serde(default)]
    pub materialized: bool,
    /// Properties keyed by property name; empty when absent from the JSON.
    #[serde(default)]
    pub properties: BTreeMap<String, PropertyValue>,
}
204
205impl Relation {
206 pub fn empty() -> Self {
207 Self {
208 name: SqlIdentifier::from("".to_string()),
209 fields: Vec::new(),
210 materialized: false,
211 properties: BTreeMap::new(),
212 }
213 }
214
215 pub fn new(
216 name: SqlIdentifier,
217 fields: Vec<Field>,
218 materialized: bool,
219 properties: BTreeMap<String, PropertyValue>,
220 ) -> Self {
221 Self {
222 name,
223 fields,
224 materialized,
225 properties,
226 }
227 }
228
229 pub fn field(&self, name: &str) -> Option<&Field> {
231 let name = canonical_identifier(name);
232 self.fields.iter().find(|f| f.name == name)
233 }
234}
235
/// A named, typed field (column) of a relation or a member of a struct type.
///
/// `Serialize` is derived; `Deserialize` is implemented by hand below so that
/// both a nested `columntype` object and flattened type attributes are accepted.
#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Field {
    /// Field name; flattened, so `name` and `case_sensitive` are top-level JSON keys.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// The field's column type.
    pub columntype: ColumnType,
    /// Optional lateness annotation, kept as an unparsed string.
    pub lateness: Option<String>,
    /// Optional default, kept as an unparsed string.
    // NOTE(review): presumably a SQL expression text — confirm.
    pub default: Option<String>,
    /// Optional watermark annotation, kept as an unparsed string.
    pub watermark: Option<String>,
}
249
250impl Field {
251 pub fn new(name: SqlIdentifier, columntype: ColumnType) -> Self {
252 Self {
253 name,
254 columntype,
255 lateness: None,
256 default: None,
257 watermark: None,
258 }
259 }
260
261 pub fn with_lateness(mut self, lateness: &str) -> Self {
262 self.lateness = Some(lateness.to_string());
263 self
264 }
265}
266
267impl<'de> Deserialize<'de> for Field {
272 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
273 where
274 D: Deserializer<'de>,
275 {
276 const fn default_is_struct() -> Option<SqlType> {
277 Some(SqlType::Struct)
278 }
279
280 #[derive(Debug, Clone, Deserialize)]
281 struct FieldHelper {
282 name: Option<String>,
283 #[serde(default)]
284 case_sensitive: bool,
285 columntype: Option<ColumnType>,
286 #[serde(rename = "type")]
287 #[serde(default = "default_is_struct")]
288 typ: Option<SqlType>,
289 nullable: Option<bool>,
290 precision: Option<i64>,
291 scale: Option<i64>,
292 component: Option<Box<ColumnType>>,
293 fields: Option<serde_json::Value>,
294 key: Option<Box<ColumnType>>,
295 value: Option<Box<ColumnType>>,
296 default: Option<String>,
297 lateness: Option<String>,
298 watermark: Option<String>,
299 }
300
301 fn helper_to_field(helper: FieldHelper) -> Field {
302 let columntype = if let Some(ctype) = helper.columntype {
303 ctype
304 } else if let Some(serde_json::Value::Array(fields)) = helper.fields {
305 let fields = fields
306 .into_iter()
307 .map(|field| {
308 let field: FieldHelper = serde_json::from_value(field).unwrap();
309 helper_to_field(field)
310 })
311 .collect::<Vec<Field>>();
312
313 ColumnType {
314 typ: helper.typ.unwrap_or(SqlType::Null),
315 nullable: helper.nullable.unwrap_or(false),
316 precision: helper.precision,
317 scale: helper.scale,
318 component: helper.component,
319 fields: Some(fields),
320 key: None,
321 value: None,
322 }
323 } else if let Some(serde_json::Value::Object(obj)) = helper.fields {
324 serde_json::from_value(serde_json::Value::Object(obj))
325 .expect("Failed to deserialize object")
326 } else {
327 ColumnType {
328 typ: helper.typ.unwrap_or(SqlType::Null),
329 nullable: helper.nullable.unwrap_or(false),
330 precision: helper.precision,
331 scale: helper.scale,
332 component: helper.component,
333 fields: None,
334 key: helper.key,
335 value: helper.value,
336 }
337 };
338
339 Field {
340 name: SqlIdentifier::new(helper.name.unwrap(), helper.case_sensitive),
341 columntype,
342 default: helper.default,
343 lateness: helper.lateness,
344 watermark: helper.watermark,
345 }
346 }
347
348 let helper = FieldHelper::deserialize(deserializer)?;
349 Ok(helper_to_field(helper))
350 }
351}
352
/// The unit (or unit range) of a SQL `INTERVAL` type.
///
/// Serialized as part of [`SqlType::Interval`], e.g. `INTERVAL_DAY_SECOND`.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum IntervalUnit {
    /// `INTERVAL_DAY`
    Day,
    /// `INTERVAL_DAY_HOUR`
    DayToHour,
    /// `INTERVAL_DAY_MINUTE`
    DayToMinute,
    /// `INTERVAL_DAY_SECOND`
    DayToSecond,
    /// `INTERVAL_HOUR`
    Hour,
    /// `INTERVAL_HOUR_MINUTE`
    HourToMinute,
    /// `INTERVAL_HOUR_SECOND`
    HourToSecond,
    /// `INTERVAL_MINUTE`
    Minute,
    /// `INTERVAL_MINUTE_SECOND`
    MinuteToSecond,
    /// `INTERVAL_MONTH`
    Month,
    /// `INTERVAL_SECOND`
    Second,
    /// `INTERVAL_YEAR`
    Year,
    /// `INTERVAL_YEAR_MONTH`
    YearToMonth,
}
387
/// A SQL type kind.
///
/// Serialized as an upper-case name (e.g. `INTEGER`) and deserialized
/// case-insensitively by the hand-written serde impls for this type.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SqlType {
    /// `BOOLEAN`
    Boolean,
    /// `TINYINT`
    TinyInt,
    /// `SMALLINT`
    SmallInt,
    /// `INTEGER`
    Int,
    /// `BIGINT`
    BigInt,
    /// `REAL`
    Real,
    /// `DOUBLE`
    Double,
    /// `DECIMAL`
    Decimal,
    /// `CHAR`
    Char,
    /// `VARCHAR`
    Varchar,
    /// `BINARY`
    Binary,
    /// `VARBINARY`
    Varbinary,
    /// `TIME`
    Time,
    /// `DATE`
    Date,
    /// `TIMESTAMP`
    Timestamp,
    /// `INTERVAL_*`, parameterized by its unit.
    Interval(IntervalUnit),
    /// `ARRAY`
    Array,
    /// `STRUCT`
    Struct,
    /// `MAP`
    Map,
    /// `NULL`
    Null,
    /// `UUID`
    Uuid,
    /// `VARIANT`
    Variant,
}
437
438impl Display for SqlType {
439 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440 f.write_str(&serde_json::to_string(self).unwrap())
441 }
442}
443
444impl<'de> Deserialize<'de> for SqlType {
445 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
446 where
447 D: Deserializer<'de>,
448 {
449 let value: String = Deserialize::deserialize(deserializer)?;
450 match value.to_lowercase().as_str() {
451 "interval_day" => Ok(SqlType::Interval(IntervalUnit::Day)),
452 "interval_day_hour" => Ok(SqlType::Interval(IntervalUnit::DayToHour)),
453 "interval_day_minute" => Ok(SqlType::Interval(IntervalUnit::DayToMinute)),
454 "interval_day_second" => Ok(SqlType::Interval(IntervalUnit::DayToSecond)),
455 "interval_hour" => Ok(SqlType::Interval(IntervalUnit::Hour)),
456 "interval_hour_minute" => Ok(SqlType::Interval(IntervalUnit::HourToMinute)),
457 "interval_hour_second" => Ok(SqlType::Interval(IntervalUnit::HourToSecond)),
458 "interval_minute" => Ok(SqlType::Interval(IntervalUnit::Minute)),
459 "interval_minute_second" => Ok(SqlType::Interval(IntervalUnit::MinuteToSecond)),
460 "interval_month" => Ok(SqlType::Interval(IntervalUnit::Month)),
461 "interval_second" => Ok(SqlType::Interval(IntervalUnit::Second)),
462 "interval_year" => Ok(SqlType::Interval(IntervalUnit::Year)),
463 "interval_year_month" => Ok(SqlType::Interval(IntervalUnit::YearToMonth)),
464 "boolean" => Ok(SqlType::Boolean),
465 "tinyint" => Ok(SqlType::TinyInt),
466 "smallint" => Ok(SqlType::SmallInt),
467 "integer" => Ok(SqlType::Int),
468 "bigint" => Ok(SqlType::BigInt),
469 "real" => Ok(SqlType::Real),
470 "double" => Ok(SqlType::Double),
471 "decimal" => Ok(SqlType::Decimal),
472 "char" => Ok(SqlType::Char),
473 "varchar" => Ok(SqlType::Varchar),
474 "binary" => Ok(SqlType::Binary),
475 "varbinary" => Ok(SqlType::Varbinary),
476 "variant" => Ok(SqlType::Variant),
477 "time" => Ok(SqlType::Time),
478 "date" => Ok(SqlType::Date),
479 "timestamp" => Ok(SqlType::Timestamp),
480 "array" => Ok(SqlType::Array),
481 "struct" => Ok(SqlType::Struct),
482 "map" => Ok(SqlType::Map),
483 "null" => Ok(SqlType::Null),
484 "uuid" => Ok(SqlType::Uuid),
485 _ => Err(serde::de::Error::custom(format!(
486 "Unknown SQL type: {}",
487 value
488 ))),
489 }
490 }
491}
492
493impl Serialize for SqlType {
494 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
495 where
496 S: Serializer,
497 {
498 let type_str = match self {
499 SqlType::Boolean => "BOOLEAN",
500 SqlType::TinyInt => "TINYINT",
501 SqlType::SmallInt => "SMALLINT",
502 SqlType::Int => "INTEGER",
503 SqlType::BigInt => "BIGINT",
504 SqlType::Real => "REAL",
505 SqlType::Double => "DOUBLE",
506 SqlType::Decimal => "DECIMAL",
507 SqlType::Char => "CHAR",
508 SqlType::Varchar => "VARCHAR",
509 SqlType::Binary => "BINARY",
510 SqlType::Varbinary => "VARBINARY",
511 SqlType::Time => "TIME",
512 SqlType::Date => "DATE",
513 SqlType::Timestamp => "TIMESTAMP",
514 SqlType::Interval(interval_unit) => match interval_unit {
515 IntervalUnit::Day => "INTERVAL_DAY",
516 IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
517 IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
518 IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
519 IntervalUnit::Hour => "INTERVAL_HOUR",
520 IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
521 IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
522 IntervalUnit::Minute => "INTERVAL_MINUTE",
523 IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
524 IntervalUnit::Month => "INTERVAL_MONTH",
525 IntervalUnit::Second => "INTERVAL_SECOND",
526 IntervalUnit::Year => "INTERVAL_YEAR",
527 IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
528 },
529 SqlType::Array => "ARRAY",
530 SqlType::Struct => "STRUCT",
531 SqlType::Uuid => "UUID",
532 SqlType::Map => "MAP",
533 SqlType::Null => "NULL",
534 SqlType::Variant => "VARIANT",
535 };
536 serializer.serialize_str(type_str)
537 }
538}
539
540impl SqlType {
541 pub fn is_string(&self) -> bool {
543 matches!(self, Self::Char | Self::Varchar)
544 }
545
546 pub fn is_varchar(&self) -> bool {
547 matches!(self, Self::Varchar)
548 }
549
550 pub fn is_varbinary(&self) -> bool {
551 matches!(self, Self::Varbinary)
552 }
553}
554
/// Serde default for `ColumnType::typ`: a column type whose JSON omits `type`
/// is treated as a STRUCT.
const fn default_is_struct() -> SqlType {
    SqlType::Struct
}

/// The type of a column, or of a nested element, struct member, map key or map
/// value.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ColumnType {
    /// The SQL type. Uses the JSON key `type`; defaults to `SqlType::Struct`
    /// when absent.
    #[serde(rename = "type")]
    #[serde(default = "default_is_struct")]
    pub typ: SqlType,
    /// Whether the column admits NULL values.
    pub nullable: bool,
    /// Precision/length parameter where the type has one: DECIMAL precision
    /// (`ColumnType::decimal`) or BINARY width (`ColumnType::fixed`).
    // NOTE(review): test inputs use `-1` for VARCHAR; presumably "unbounded" — confirm.
    pub precision: Option<i64>,
    /// Scale, e.g. of a DECIMAL (`ColumnType::decimal`).
    pub scale: Option<i64>,
    /// Element type of an ARRAY (`ColumnType::array`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub component: Option<Box<ColumnType>>,
    /// Member fields of a STRUCT (`ColumnType::structure`).
    #[cfg_attr(feature = "testing", proptest(value = "Some(Vec::new())"))]
    pub fields: Option<Vec<Field>>,
    /// Key type of a MAP (`ColumnType::map`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub key: Option<Box<ColumnType>>,
    /// Value type of a MAP (`ColumnType::map`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub value: Option<Box<ColumnType>>,
}
625
626impl ColumnType {
627 pub fn boolean(nullable: bool) -> Self {
628 ColumnType {
629 typ: SqlType::Boolean,
630 nullable,
631 precision: None,
632 scale: None,
633 component: None,
634 fields: None,
635 key: None,
636 value: None,
637 }
638 }
639
640 pub fn uuid(nullable: bool) -> Self {
641 ColumnType {
642 typ: SqlType::Uuid,
643 nullable,
644 precision: None,
645 scale: None,
646 component: None,
647 fields: None,
648 key: None,
649 value: None,
650 }
651 }
652
653 pub fn tinyint(nullable: bool) -> Self {
654 ColumnType {
655 typ: SqlType::TinyInt,
656 nullable,
657 precision: None,
658 scale: None,
659 component: None,
660 fields: None,
661 key: None,
662 value: None,
663 }
664 }
665
666 pub fn smallint(nullable: bool) -> Self {
667 ColumnType {
668 typ: SqlType::SmallInt,
669 nullable,
670 precision: None,
671 scale: None,
672 component: None,
673 fields: None,
674 key: None,
675 value: None,
676 }
677 }
678
679 pub fn int(nullable: bool) -> Self {
680 ColumnType {
681 typ: SqlType::Int,
682 nullable,
683 precision: None,
684 scale: None,
685 component: None,
686 fields: None,
687 key: None,
688 value: None,
689 }
690 }
691
692 pub fn bigint(nullable: bool) -> Self {
693 ColumnType {
694 typ: SqlType::BigInt,
695 nullable,
696 precision: None,
697 scale: None,
698 component: None,
699 fields: None,
700 key: None,
701 value: None,
702 }
703 }
704
705 pub fn double(nullable: bool) -> Self {
706 ColumnType {
707 typ: SqlType::Double,
708 nullable,
709 precision: None,
710 scale: None,
711 component: None,
712 fields: None,
713 key: None,
714 value: None,
715 }
716 }
717
718 pub fn real(nullable: bool) -> Self {
719 ColumnType {
720 typ: SqlType::Real,
721 nullable,
722 precision: None,
723 scale: None,
724 component: None,
725 fields: None,
726 key: None,
727 value: None,
728 }
729 }
730
731 pub fn decimal(precision: i64, scale: i64, nullable: bool) -> Self {
732 ColumnType {
733 typ: SqlType::Decimal,
734 nullable,
735 precision: Some(precision),
736 scale: Some(scale),
737 component: None,
738 fields: None,
739 key: None,
740 value: None,
741 }
742 }
743
744 pub fn varchar(nullable: bool) -> Self {
745 ColumnType {
746 typ: SqlType::Varchar,
747 nullable,
748 precision: None,
749 scale: None,
750 component: None,
751 fields: None,
752 key: None,
753 value: None,
754 }
755 }
756
757 pub fn varbinary(nullable: bool) -> Self {
758 ColumnType {
759 typ: SqlType::Varbinary,
760 nullable,
761 precision: None,
762 scale: None,
763 component: None,
764 fields: None,
765 key: None,
766 value: None,
767 }
768 }
769
770 pub fn fixed(width: i64, nullable: bool) -> Self {
771 ColumnType {
772 typ: SqlType::Binary,
773 nullable,
774 precision: Some(width),
775 scale: None,
776 component: None,
777 fields: None,
778 key: None,
779 value: None,
780 }
781 }
782
783 pub fn date(nullable: bool) -> Self {
784 ColumnType {
785 typ: SqlType::Date,
786 nullable,
787 precision: None,
788 scale: None,
789 component: None,
790 fields: None,
791 key: None,
792 value: None,
793 }
794 }
795
796 pub fn time(nullable: bool) -> Self {
797 ColumnType {
798 typ: SqlType::Time,
799 nullable,
800 precision: None,
801 scale: None,
802 component: None,
803 fields: None,
804 key: None,
805 value: None,
806 }
807 }
808
809 pub fn timestamp(nullable: bool) -> Self {
810 ColumnType {
811 typ: SqlType::Timestamp,
812 nullable,
813 precision: None,
814 scale: None,
815 component: None,
816 fields: None,
817 key: None,
818 value: None,
819 }
820 }
821
822 pub fn variant(nullable: bool) -> Self {
823 ColumnType {
824 typ: SqlType::Variant,
825 nullable,
826 precision: None,
827 scale: None,
828 component: None,
829 fields: None,
830 key: None,
831 value: None,
832 }
833 }
834
835 pub fn array(nullable: bool, element: ColumnType) -> Self {
836 ColumnType {
837 typ: SqlType::Array,
838 nullable,
839 precision: None,
840 scale: None,
841 component: Some(Box::new(element)),
842 fields: None,
843 key: None,
844 value: None,
845 }
846 }
847
848 pub fn structure(nullable: bool, fields: &[Field]) -> Self {
849 ColumnType {
850 typ: SqlType::Struct,
851 nullable,
852 precision: None,
853 scale: None,
854 component: None,
855 fields: Some(fields.to_vec()),
856 key: None,
857 value: None,
858 }
859 }
860
861 pub fn map(nullable: bool, key: ColumnType, val: ColumnType) -> Self {
862 ColumnType {
863 typ: SqlType::Map,
864 nullable,
865 precision: None,
866 scale: None,
867 component: None,
868 fields: None,
869 key: Some(Box::new(key)),
870 value: Some(Box::new(val)),
871 }
872 }
873
874 pub fn is_integral_type(&self) -> bool {
875 matches!(
876 &self.typ,
877 SqlType::TinyInt | SqlType::SmallInt | SqlType::Int | SqlType::BigInt
878 )
879 }
880
881 pub fn is_fp_type(&self) -> bool {
882 matches!(&self.typ, SqlType::Double | SqlType::Real)
883 }
884
885 pub fn is_decimal_type(&self) -> bool {
886 matches!(&self.typ, SqlType::Decimal)
887 }
888
889 pub fn is_numeric_type(&self) -> bool {
890 self.is_integral_type() || self.is_fp_type() || self.is_decimal_type()
891 }
892}
893
#[cfg(test)]
mod tests {
    use super::{IntervalUnit, SqlIdentifier};
    use crate::program_schema::{ColumnType, Field, SqlType};

    /// Every SQL type name round-trips through serde in mixed, lower and upper case.
    #[test]
    fn serde_sql_type() {
        for (sql_str_base, expected_value) in [
            ("Boolean", SqlType::Boolean),
            ("Uuid", SqlType::Uuid),
            ("TinyInt", SqlType::TinyInt),
            ("SmallInt", SqlType::SmallInt),
            ("Integer", SqlType::Int),
            ("BigInt", SqlType::BigInt),
            ("Real", SqlType::Real),
            ("Double", SqlType::Double),
            ("Decimal", SqlType::Decimal),
            ("Char", SqlType::Char),
            ("Varchar", SqlType::Varchar),
            ("Binary", SqlType::Binary),
            ("Varbinary", SqlType::Varbinary),
            ("Time", SqlType::Time),
            ("Date", SqlType::Date),
            ("Timestamp", SqlType::Timestamp),
            ("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
            (
                "Interval_Day_Hour",
                SqlType::Interval(IntervalUnit::DayToHour),
            ),
            (
                "Interval_Day_Minute",
                SqlType::Interval(IntervalUnit::DayToMinute),
            ),
            (
                "Interval_Day_Second",
                SqlType::Interval(IntervalUnit::DayToSecond),
            ),
            ("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
            (
                "Interval_Hour_Minute",
                SqlType::Interval(IntervalUnit::HourToMinute),
            ),
            (
                "Interval_Hour_Second",
                SqlType::Interval(IntervalUnit::HourToSecond),
            ),
            ("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
            (
                "Interval_Minute_Second",
                SqlType::Interval(IntervalUnit::MinuteToSecond),
            ),
            ("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
            ("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
            ("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
            (
                "Interval_Year_Month",
                SqlType::Interval(IntervalUnit::YearToMonth),
            ),
            ("Array", SqlType::Array),
            ("Struct", SqlType::Struct),
            ("Map", SqlType::Map),
            ("Null", SqlType::Null),
            ("Variant", SqlType::Variant),
        ] {
            for sql_str in [
                sql_str_base,
                &sql_str_base.to_lowercase(),
                &sql_str_base.to_uppercase(),
            ] {
                let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str))
                    .unwrap_or_else(|_| {
                        panic!("\"{sql_str}\" should deserialize into its SQL type")
                    });
                assert_eq!(value1, expected_value);
                let serialized_str =
                    serde_json::to_string(&value1).expect("Value should serialize into JSON");
                let value2: SqlType = serde_json::from_str(&serialized_str).unwrap_or_else(|_| {
                    panic!(
                        "{} should deserialize back into its SQL type",
                        serialized_str
                    )
                });
                assert_eq!(value1, value2);
            }
        }
    }

    /// A full program schema with every interval type deserializes to the expected types.
    #[test]
    fn deserialize_interval_types() {
        use super::IntervalUnit::*;
        use super::SqlType::*;

        let schema = r#"
{
  "inputs" : [ {
    "name" : "sales",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "sales_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "customer_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "amount",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DECIMAL",
        "nullable" : true,
        "precision" : 10,
        "scale" : 2
      }
    }, {
      "name" : "sale_date",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DATE",
        "nullable" : true
      }
    } ],
    "primary_key" : [ "sales_id" ]
  } ],
  "outputs" : [ {
    "name" : "salessummary",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "customer_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "total_sales",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DECIMAL",
        "nullable" : true,
        "precision" : 38,
        "scale" : 2
      }
    }, {
      "name" : "interval_day",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_hour",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_HOUR",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour_to_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_minute_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MINUTE_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_month",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MONTH",
        "nullable" : false
      }
    }, {
      "name" : "interval_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_year",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_YEAR",
        "nullable" : false
      }
    }, {
      "name" : "interval_year_to_month",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_YEAR_MONTH",
        "nullable" : false
      }
    } ]
  } ]
}
"#;

        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        let types = schema
            .outputs
            .iter()
            .flat_map(|r| r.fields.iter().map(|f| f.columntype.typ));
        let expected_types = [
            Int,
            Decimal,
            Interval(Day),
            Interval(DayToHour),
            Interval(DayToMinute),
            Interval(DayToSecond),
            Interval(Hour),
            Interval(HourToMinute),
            Interval(HourToSecond),
            Interval(Minute),
            Interval(MinuteToSecond),
            Interval(Month),
            Interval(Second),
            Interval(Year),
            Interval(YearToMonth),
        ];

        assert_eq!(types.collect::<Vec<_>>(), &expected_types);
    }

    /// Nested struct types (array- and object-shaped `fields`) deserialize correctly.
    #[test]
    fn serialize_struct_schemas() {
        let schema = r#"{
  "inputs" : [ {
    "name" : "PERS",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "P0",
      "case_sensitive" : false,
      "columntype" : {
        "fields" : [ {
          "type" : "VARCHAR",
          "nullable" : true,
          "precision" : 30,
          "name" : "FIRSTNAME"
        }, {
          "type" : "VARCHAR",
          "nullable" : true,
          "precision" : 30,
          "name" : "LASTNAME"
        }, {
          "fields" : {
            "fields" : [ {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 30,
              "name" : "STREET"
            }, {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 30,
              "name" : "CITY"
            }, {
              "type" : "CHAR",
              "nullable" : true,
              "precision" : 2,
              "name" : "STATE"
            }, {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 6,
              "name" : "POSTAL_CODE"
            } ],
            "nullable" : false
          },
          "nullable" : false,
          "name" : "ADDRESS"
        } ],
        "nullable" : false
      }
    }]
  } ],
  "outputs" : [ ]
}
"#;
        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        eprintln!("{:#?}", schema);
        let pers = schema.inputs.iter().find(|r| r.name == "PERS").unwrap();
        let p0 = pers.fields.iter().find(|f| f.name == "P0").unwrap();
        assert_eq!(p0.columntype.typ, SqlType::Struct);
        let p0_fields = p0.columntype.fields.as_ref().unwrap();
        assert_eq!(p0_fields[0].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[1].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[2].columntype.typ, SqlType::Struct);
        assert_eq!(p0_fields[2].name, "ADDRESS");
        let address = &p0_fields[2].columntype.fields.as_ref().unwrap();
        assert_eq!(address.len(), 4);
        assert_eq!(address[0].name, "STREET");
        assert_eq!(address[0].columntype.typ, SqlType::Varchar);
        assert_eq!(address[1].columntype.typ, SqlType::Varchar);
        assert_eq!(address[2].columntype.typ, SqlType::Char);
        assert_eq!(address[3].columntype.typ, SqlType::Varchar);
    }

    #[test]
    fn sql_identifier_cmp() {
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("foo"));
        assert_ne!(SqlIdentifier::from("foo"), SqlIdentifier::from("bar"));
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("BAR"));
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("\"foo\""));
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("\"bar\""));
        // Unquoted `bAr` folds to `bar`, while quoted `"bAr"` keeps its case, so
        // they differ. Treating them as equal would make equality inconsistent
        // with `Hash`/`Ord` and non-transitive.
        assert_ne!(SqlIdentifier::from("bAr"), SqlIdentifier::from("\"bAr\""));
        assert_eq!(
            SqlIdentifier::new("bAr", true),
            SqlIdentifier::from("\"bAr\"")
        );

        assert_eq!(SqlIdentifier::from("bAr"), "bar");
        assert_eq!(SqlIdentifier::from("bAr"), "bAr");
    }

    /// `==`, `Hash` and `Ord` agree for every pair of identifiers, in any mix of
    /// case sensitivity.
    #[test]
    fn sql_identifier_eq_consistent_with_hash_and_ord() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        fn hash_of(id: &SqlIdentifier) -> u64 {
            let mut hasher = DefaultHasher::new();
            id.hash(&mut hasher);
            hasher.finish()
        }

        let ids = [
            SqlIdentifier::from("bar"),
            SqlIdentifier::from("BAR"),
            SqlIdentifier::from("bAr"),
            SqlIdentifier::from("\"bar\""),
            SqlIdentifier::from("\"bAr\""),
            SqlIdentifier::from("\"BAR\""),
        ];
        for a in &ids {
            for b in &ids {
                assert_eq!(
                    a == b,
                    a.cmp(b) == std::cmp::Ordering::Equal,
                    "{a:?} vs {b:?}"
                );
                if a == b {
                    assert_eq!(hash_of(a), hash_of(b), "{a:?} vs {b:?}");
                }
            }
        }
    }

    #[test]
    fn sql_identifier_ord() {
        let mut btree = std::collections::BTreeSet::new();
        assert!(btree.insert(SqlIdentifier::from("foo")));
        assert!(btree.insert(SqlIdentifier::from("bar")));
        assert!(!btree.insert(SqlIdentifier::from("BAR")));
        assert!(!btree.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!btree.insert(SqlIdentifier::from("\"bar\"")));
    }

    #[test]
    fn sql_identifier_hash() {
        let mut hs = std::collections::HashSet::new();
        assert!(hs.insert(SqlIdentifier::from("foo")));
        assert!(hs.insert(SqlIdentifier::from("bar")));
        assert!(!hs.insert(SqlIdentifier::from("BAR")));
        assert!(!hs.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!hs.insert(SqlIdentifier::from("\"bar\"")));
    }

    #[test]
    fn sql_identifier_name() {
        assert_eq!(SqlIdentifier::from("foo").name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").name(), "bar");
        assert_eq!(SqlIdentifier::from("\"bAr\"").name(), "bAr");
        assert_eq!(SqlIdentifier::from("foo").sql_name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").sql_name(), "bAr");
        assert_eq!(SqlIdentifier::from("\"bAr\"").sql_name(), "\"bAr\"");
    }

    /// Regression test: a struct field whose member is a flattened MAP with
    /// `key`/`value` column types.
    #[test]
    fn issue3277() {
        let schema = r#"{
      "name" : "j",
      "case_sensitive" : false,
      "columntype" : {
        "fields" : [ {
          "key" : {
            "nullable" : false,
            "precision" : -1,
            "type" : "VARCHAR"
          },
          "name" : "s",
          "nullable" : true,
          "type" : "MAP",
          "value" : {
            "nullable" : true,
            "precision" : -1,
            "type" : "VARCHAR"
          }
        } ],
        "nullable" : true
      }
    }"#;
        let field: Field = serde_json::from_str(schema).unwrap();
        println!("field: {:#?}", field);
        assert_eq!(
            field,
            Field {
                name: SqlIdentifier {
                    name: "j".to_string(),
                    case_sensitive: false,
                },
                columntype: ColumnType {
                    typ: SqlType::Struct,
                    nullable: true,
                    precision: None,
                    scale: None,
                    component: None,
                    fields: Some(vec![Field {
                        name: SqlIdentifier {
                            name: "s".to_string(),
                            case_sensitive: false,
                        },
                        columntype: ColumnType {
                            typ: SqlType::Map,
                            nullable: true,
                            precision: None,
                            scale: None,
                            component: None,
                            fields: None,
                            key: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: false,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                            value: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: true,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                        },
                        lateness: None,
                        default: None,
                        watermark: None,
                    }]),
                    key: None,
                    value: None,
                },
                lateness: None,
                default: None,
                watermark: None,
            }
        );
    }
}