1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::cmp::Ordering;
3use std::collections::BTreeMap;
4use std::fmt::Display;
5use std::hash::{Hash, Hasher};
6use utoipa::ToSchema;
7
8#[cfg(feature = "testing")]
9use proptest::{collection::vec, prelude::any};
10
/// Returns the canonical form of a SQL identifier.
///
/// An identifier that both starts and ends with a double quote (and is at least
/// two characters long) is returned with the outer quotes removed and its case
/// untouched. Any other identifier is returned lowercased.
pub fn canonical_identifier(id: &str) -> String {
    // A lone `"` yields `None` here: after stripping the prefix nothing is left
    // for the suffix to match, so it falls through to the lowercase branch.
    let quoted_body = id
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));

    match quoted_body {
        Some(inner) => inner.to_string(),
        None => id.to_lowercase(),
    }
}
25
/// A SQL identifier: a raw name plus a flag saying whether its case is significant.
///
/// The raw `name` is private; callers read it through [`SqlIdentifier::name`]
/// (comparison form) or [`SqlIdentifier::sql_name`] (quoted form).
#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SqlIdentifier {
    /// Name exactly as stored, without surrounding quotes added by `sql_name`.
    #[cfg_attr(feature = "testing", proptest(regex = "relation1|relation2|relation3"))]
    name: String,
    /// When `true`, the name is compared and reported verbatim instead of lowercased.
    pub case_sensitive: bool,
}
37
38impl SqlIdentifier {
39 pub fn new<S: AsRef<str>>(name: S, case_sensitive: bool) -> Self {
40 Self {
41 name: name.as_ref().to_string(),
42 case_sensitive,
43 }
44 }
45
46 pub fn name(&self) -> String {
56 if self.case_sensitive {
57 self.name.clone()
58 } else {
59 self.name.to_lowercase()
60 }
61 }
62
63 pub fn sql_name(&self) -> String {
74 if self.case_sensitive {
75 format!("\"{}\"", self.name)
76 } else {
77 self.name.clone()
78 }
79 }
80}
81
82impl Hash for SqlIdentifier {
83 fn hash<H: Hasher>(&self, state: &mut H) {
84 self.name().hash(state);
85 }
86}
87
88impl PartialEq for SqlIdentifier {
89 fn eq(&self, other: &Self) -> bool {
90 match (self.case_sensitive, other.case_sensitive) {
91 (true, true) => self.name == other.name,
92 (false, false) => self.name.to_lowercase() == other.name.to_lowercase(),
93 (true, false) => self.name == other.name,
94 (false, true) => self.name == other.name,
95 }
96 }
97}
98
99impl Ord for SqlIdentifier {
100 fn cmp(&self, other: &Self) -> Ordering {
101 self.name().cmp(&other.name())
102 }
103}
104
impl PartialOrd for SqlIdentifier {
    /// Delegates to the total order defined by `Ord::cmp`, so the result is
    /// always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
110
111impl<S: AsRef<str>> PartialEq<S> for SqlIdentifier {
112 fn eq(&self, other: &S) -> bool {
113 self == &SqlIdentifier::from(other.as_ref())
114 }
115}
116
// Marker impl: required by `Ord` and by hashed/ordered collections keyed on identifiers.
impl Eq for SqlIdentifier {}
118
119impl<S: AsRef<str>> From<S> for SqlIdentifier {
120 fn from(name: S) -> Self {
121 if name.as_ref().starts_with('"')
122 && name.as_ref().ends_with('"')
123 && name.as_ref().len() >= 2
124 {
125 Self {
126 name: name.as_ref()[1..name.as_ref().len() - 1].to_string(),
127 case_sensitive: true,
128 }
129 } else {
130 Self::new(name, false)
131 }
132 }
133}
134
135impl From<SqlIdentifier> for String {
136 fn from(id: SqlIdentifier) -> String {
137 id.name()
138 }
139}
140
141impl From<&SqlIdentifier> for String {
142 fn from(id: &SqlIdentifier) -> String {
143 id.name()
144 }
145}
146
147impl Display for SqlIdentifier {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 write!(f, "{}", self.name())
150 }
151}
152
/// Schema of a program: the relations it takes as inputs and the relations it
/// produces as outputs.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ProgramSchema {
    /// Input relations. Property tests generate at most one entry here.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub inputs: Vec<Relation>,
    /// Output relations. Property tests generate at most one entry here.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub outputs: Vec<Relation>,
}
170
/// A span in source text, given as start and end line/column pairs.
// NOTE(review): whether lines/columns are 0- or 1-based is not visible here — confirm with the producer.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SourcePosition {
    /// Line on which the span starts.
    pub start_line_number: usize,
    /// Column at which the span starts.
    pub start_column: usize,
    /// Line on which the span ends.
    pub end_line_number: usize,
    /// Column at which the span ends.
    pub end_column: usize,
}
179
/// Value of a relation property together with the source positions of the
/// property's key and value.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct PropertyValue {
    /// The property value as text.
    pub value: String,
    /// Where the property key appears in the source.
    pub key_position: SourcePosition,
    /// Where the property value appears in the source.
    pub value_position: SourcePosition,
}
187
/// A named relation with its fields, materialization flag and properties.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Relation {
    /// Relation name. Flattened, so the identifier's `name` and
    /// `case_sensitive` keys appear directly on the serialized relation.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Fields of the relation. Property tests always use an empty list.
    #[cfg_attr(feature = "testing", proptest(value = "Vec::new()"))]
    pub fields: Vec<Field>,
    /// Materialization flag; `false` when absent from the input.
    #[serde(default)]
    pub materialized: bool,
    /// Properties keyed by name; empty when absent from the input.
    #[serde(default)]
    pub properties: BTreeMap<String, PropertyValue>,
}
204
205impl Relation {
206 pub fn empty() -> Self {
207 Self {
208 name: SqlIdentifier::from("".to_string()),
209 fields: Vec::new(),
210 materialized: false,
211 properties: BTreeMap::new(),
212 }
213 }
214
215 pub fn new(
216 name: SqlIdentifier,
217 fields: Vec<Field>,
218 materialized: bool,
219 properties: BTreeMap<String, PropertyValue>,
220 ) -> Self {
221 Self {
222 name,
223 fields,
224 materialized,
225 properties,
226 }
227 }
228
229 pub fn field(&self, name: &str) -> Option<&Field> {
231 let name = canonical_identifier(name);
232 self.fields.iter().find(|f| f.name == name)
233 }
234}
235
/// A named field (column) with its type and optional attributes.
///
/// Serialization is derived; deserialization is implemented by hand in the
/// `Deserialize` impl for this type.
#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Field {
    /// Field name. Flattened, so the identifier's `name` and `case_sensitive`
    /// keys appear directly on the serialized field.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Type of the field.
    pub columntype: ColumnType,
    /// Optional lateness annotation, kept as text.
    pub lateness: Option<String>,
    /// Optional default value, kept as text.
    pub default: Option<String>,
    /// Flag marking the field as unused.
    pub unused: bool,
    /// Optional watermark annotation, kept as text.
    pub watermark: Option<String>,
}
250
251impl Field {
252 pub fn new(name: SqlIdentifier, columntype: ColumnType) -> Self {
253 Self {
254 name,
255 columntype,
256 lateness: None,
257 default: None,
258 unused: false,
259 watermark: None,
260 }
261 }
262
263 pub fn with_lateness(mut self, lateness: &str) -> Self {
264 self.lateness = Some(lateness.to_string());
265 self
266 }
267
268 pub fn with_unused(mut self, unused: bool) -> Self {
269 self.unused = unused;
270 self
271 }
272}
273
274impl<'de> Deserialize<'de> for Field {
279 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
280 where
281 D: Deserializer<'de>,
282 {
283 const fn default_is_struct() -> Option<SqlType> {
284 Some(SqlType::Struct)
285 }
286
287 #[derive(Debug, Clone, Deserialize)]
288 struct FieldHelper {
289 name: Option<String>,
290 #[serde(default)]
291 case_sensitive: bool,
292 columntype: Option<ColumnType>,
293 #[serde(rename = "type")]
294 #[serde(default = "default_is_struct")]
295 typ: Option<SqlType>,
296 nullable: Option<bool>,
297 precision: Option<i64>,
298 scale: Option<i64>,
299 component: Option<Box<ColumnType>>,
300 fields: Option<serde_json::Value>,
301 key: Option<Box<ColumnType>>,
302 value: Option<Box<ColumnType>>,
303 default: Option<String>,
304 #[serde(default)]
305 unused: bool,
306 lateness: Option<String>,
307 watermark: Option<String>,
308 }
309
310 fn helper_to_field(helper: FieldHelper) -> Field {
311 let columntype = if let Some(ctype) = helper.columntype {
312 ctype
313 } else if let Some(serde_json::Value::Array(fields)) = helper.fields {
314 let fields = fields
315 .into_iter()
316 .map(|field| {
317 let field: FieldHelper = serde_json::from_value(field).unwrap();
318 helper_to_field(field)
319 })
320 .collect::<Vec<Field>>();
321
322 ColumnType {
323 typ: helper.typ.unwrap_or(SqlType::Null),
324 nullable: helper.nullable.unwrap_or(false),
325 precision: helper.precision,
326 scale: helper.scale,
327 component: helper.component,
328 fields: Some(fields),
329 key: None,
330 value: None,
331 }
332 } else if let Some(serde_json::Value::Object(obj)) = helper.fields {
333 serde_json::from_value(serde_json::Value::Object(obj))
334 .expect("Failed to deserialize object")
335 } else {
336 ColumnType {
337 typ: helper.typ.unwrap_or(SqlType::Null),
338 nullable: helper.nullable.unwrap_or(false),
339 precision: helper.precision,
340 scale: helper.scale,
341 component: helper.component,
342 fields: None,
343 key: helper.key,
344 value: helper.value,
345 }
346 };
347
348 Field {
349 name: SqlIdentifier::new(helper.name.unwrap(), helper.case_sensitive),
350 columntype,
351 default: helper.default,
352 unused: helper.unused,
353 lateness: helper.lateness,
354 watermark: helper.watermark,
355 }
356 }
357
358 let helper = FieldHelper::deserialize(deserializer)?;
359 Ok(helper_to_field(helper))
360 }
361}
362
/// Unit (or unit range) qualifying an interval SQL type.
///
/// Each variant corresponds to one `INTERVAL_*` name handled by the `SqlType`
/// serializer and deserializer.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum IntervalUnit {
    /// Serialized as `INTERVAL_DAY`.
    Day,
    /// Serialized as `INTERVAL_DAY_HOUR`.
    DayToHour,
    /// Serialized as `INTERVAL_DAY_MINUTE`.
    DayToMinute,
    /// Serialized as `INTERVAL_DAY_SECOND`.
    DayToSecond,
    /// Serialized as `INTERVAL_HOUR`.
    Hour,
    /// Serialized as `INTERVAL_HOUR_MINUTE`.
    HourToMinute,
    /// Serialized as `INTERVAL_HOUR_SECOND`.
    HourToSecond,
    /// Serialized as `INTERVAL_MINUTE`.
    Minute,
    /// Serialized as `INTERVAL_MINUTE_SECOND`.
    MinuteToSecond,
    /// Serialized as `INTERVAL_MONTH`.
    Month,
    /// Serialized as `INTERVAL_SECOND`.
    Second,
    /// Serialized as `INTERVAL_YEAR`.
    Year,
    /// Serialized as `INTERVAL_YEAR_MONTH`.
    YearToMonth,
}
397
/// The kind of a SQL type, without attributes such as nullability or precision.
///
/// Serialized as an upper-case name and parsed case-insensitively by the
/// hand-written `Serialize`/`Deserialize` impls for this type.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SqlType {
    /// Serialized as `BOOLEAN`.
    Boolean,
    /// Serialized as `TINYINT`.
    TinyInt,
    /// Serialized as `SMALLINT`.
    SmallInt,
    /// Serialized as `INTEGER`.
    Int,
    /// Serialized as `BIGINT`.
    BigInt,
    /// Serialized as `REAL`.
    Real,
    /// Serialized as `DOUBLE`.
    Double,
    /// Serialized as `DECIMAL`.
    Decimal,
    /// Serialized as `CHAR`.
    Char,
    /// Serialized as `VARCHAR`.
    Varchar,
    /// Serialized as `BINARY`.
    Binary,
    /// Serialized as `VARBINARY`.
    Varbinary,
    /// Serialized as `TIME`.
    Time,
    /// Serialized as `DATE`.
    Date,
    /// Serialized as `TIMESTAMP`.
    Timestamp,
    /// Serialized as `INTERVAL_*`, depending on the unit.
    Interval(IntervalUnit),
    /// Serialized as `ARRAY`.
    Array,
    /// Serialized as `STRUCT`.
    Struct,
    /// Serialized as `MAP`.
    Map,
    /// Serialized as `NULL`.
    Null,
    /// Serialized as `UUID`.
    Uuid,
    /// Serialized as `VARIANT`.
    Variant,
}
447
448impl Display for SqlType {
449 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
450 f.write_str(&serde_json::to_string(self).unwrap())
451 }
452}
453
454impl<'de> Deserialize<'de> for SqlType {
455 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
456 where
457 D: Deserializer<'de>,
458 {
459 let value: String = Deserialize::deserialize(deserializer)?;
460 match value.to_lowercase().as_str() {
461 "interval_day" => Ok(SqlType::Interval(IntervalUnit::Day)),
462 "interval_day_hour" => Ok(SqlType::Interval(IntervalUnit::DayToHour)),
463 "interval_day_minute" => Ok(SqlType::Interval(IntervalUnit::DayToMinute)),
464 "interval_day_second" => Ok(SqlType::Interval(IntervalUnit::DayToSecond)),
465 "interval_hour" => Ok(SqlType::Interval(IntervalUnit::Hour)),
466 "interval_hour_minute" => Ok(SqlType::Interval(IntervalUnit::HourToMinute)),
467 "interval_hour_second" => Ok(SqlType::Interval(IntervalUnit::HourToSecond)),
468 "interval_minute" => Ok(SqlType::Interval(IntervalUnit::Minute)),
469 "interval_minute_second" => Ok(SqlType::Interval(IntervalUnit::MinuteToSecond)),
470 "interval_month" => Ok(SqlType::Interval(IntervalUnit::Month)),
471 "interval_second" => Ok(SqlType::Interval(IntervalUnit::Second)),
472 "interval_year" => Ok(SqlType::Interval(IntervalUnit::Year)),
473 "interval_year_month" => Ok(SqlType::Interval(IntervalUnit::YearToMonth)),
474 "boolean" => Ok(SqlType::Boolean),
475 "tinyint" => Ok(SqlType::TinyInt),
476 "smallint" => Ok(SqlType::SmallInt),
477 "integer" => Ok(SqlType::Int),
478 "bigint" => Ok(SqlType::BigInt),
479 "real" => Ok(SqlType::Real),
480 "double" => Ok(SqlType::Double),
481 "decimal" => Ok(SqlType::Decimal),
482 "char" => Ok(SqlType::Char),
483 "varchar" => Ok(SqlType::Varchar),
484 "binary" => Ok(SqlType::Binary),
485 "varbinary" => Ok(SqlType::Varbinary),
486 "variant" => Ok(SqlType::Variant),
487 "time" => Ok(SqlType::Time),
488 "date" => Ok(SqlType::Date),
489 "timestamp" => Ok(SqlType::Timestamp),
490 "array" => Ok(SqlType::Array),
491 "struct" => Ok(SqlType::Struct),
492 "map" => Ok(SqlType::Map),
493 "null" => Ok(SqlType::Null),
494 "uuid" => Ok(SqlType::Uuid),
495 _ => Err(serde::de::Error::custom(format!(
496 "Unknown SQL type: {}",
497 value
498 ))),
499 }
500 }
501}
502
503impl Serialize for SqlType {
504 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
505 where
506 S: Serializer,
507 {
508 let type_str = match self {
509 SqlType::Boolean => "BOOLEAN",
510 SqlType::TinyInt => "TINYINT",
511 SqlType::SmallInt => "SMALLINT",
512 SqlType::Int => "INTEGER",
513 SqlType::BigInt => "BIGINT",
514 SqlType::Real => "REAL",
515 SqlType::Double => "DOUBLE",
516 SqlType::Decimal => "DECIMAL",
517 SqlType::Char => "CHAR",
518 SqlType::Varchar => "VARCHAR",
519 SqlType::Binary => "BINARY",
520 SqlType::Varbinary => "VARBINARY",
521 SqlType::Time => "TIME",
522 SqlType::Date => "DATE",
523 SqlType::Timestamp => "TIMESTAMP",
524 SqlType::Interval(interval_unit) => match interval_unit {
525 IntervalUnit::Day => "INTERVAL_DAY",
526 IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
527 IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
528 IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
529 IntervalUnit::Hour => "INTERVAL_HOUR",
530 IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
531 IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
532 IntervalUnit::Minute => "INTERVAL_MINUTE",
533 IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
534 IntervalUnit::Month => "INTERVAL_MONTH",
535 IntervalUnit::Second => "INTERVAL_SECOND",
536 IntervalUnit::Year => "INTERVAL_YEAR",
537 IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
538 },
539 SqlType::Array => "ARRAY",
540 SqlType::Struct => "STRUCT",
541 SqlType::Uuid => "UUID",
542 SqlType::Map => "MAP",
543 SqlType::Null => "NULL",
544 SqlType::Variant => "VARIANT",
545 };
546 serializer.serialize_str(type_str)
547 }
548}
549
550impl SqlType {
551 pub fn is_string(&self) -> bool {
553 matches!(self, Self::Char | Self::Varchar)
554 }
555
556 pub fn is_varchar(&self) -> bool {
557 matches!(self, Self::Varchar)
558 }
559
560 pub fn is_varbinary(&self) -> bool {
561 matches!(self, Self::Varbinary)
562 }
563}
564
/// Serde default for `ColumnType::typ`: a column type with no `type` key is
/// treated as a struct.
const fn default_is_struct() -> SqlType {
    SqlType::Struct
}
570
/// A SQL column type: the kind of type plus its attributes.
///
/// Which optional attributes are populated depends on the kind; see the
/// constructors on this type for the combinations they produce.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ColumnType {
    /// Kind of the type. Serialized under the key `type`; defaults to
    /// `SqlType::Struct` when the key is absent.
    #[serde(rename = "type")]
    #[serde(default = "default_is_struct")]
    pub typ: SqlType,
    /// Whether the column accepts NULL values.
    pub nullable: bool,
    /// Optional precision. The constructors set it for `decimal` (precision)
    /// and `fixed` (width).
    pub precision: Option<i64>,
    /// Optional scale. The constructors set it only for `decimal`.
    pub scale: Option<i64>,
    /// Element type; set by the `array` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub component: Option<Box<ColumnType>>,
    /// Member fields; set by the `structure` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "Some(Vec::new())"))]
    pub fields: Option<Vec<Field>>,
    /// Key type; set by the `map` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub key: Option<Box<ColumnType>>,
    /// Value type; set by the `map` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub value: Option<Box<ColumnType>>,
}
635
636impl ColumnType {
637 pub fn boolean(nullable: bool) -> Self {
638 ColumnType {
639 typ: SqlType::Boolean,
640 nullable,
641 precision: None,
642 scale: None,
643 component: None,
644 fields: None,
645 key: None,
646 value: None,
647 }
648 }
649
650 pub fn uuid(nullable: bool) -> Self {
651 ColumnType {
652 typ: SqlType::Uuid,
653 nullable,
654 precision: None,
655 scale: None,
656 component: None,
657 fields: None,
658 key: None,
659 value: None,
660 }
661 }
662
663 pub fn tinyint(nullable: bool) -> Self {
664 ColumnType {
665 typ: SqlType::TinyInt,
666 nullable,
667 precision: None,
668 scale: None,
669 component: None,
670 fields: None,
671 key: None,
672 value: None,
673 }
674 }
675
676 pub fn smallint(nullable: bool) -> Self {
677 ColumnType {
678 typ: SqlType::SmallInt,
679 nullable,
680 precision: None,
681 scale: None,
682 component: None,
683 fields: None,
684 key: None,
685 value: None,
686 }
687 }
688
689 pub fn int(nullable: bool) -> Self {
690 ColumnType {
691 typ: SqlType::Int,
692 nullable,
693 precision: None,
694 scale: None,
695 component: None,
696 fields: None,
697 key: None,
698 value: None,
699 }
700 }
701
702 pub fn bigint(nullable: bool) -> Self {
703 ColumnType {
704 typ: SqlType::BigInt,
705 nullable,
706 precision: None,
707 scale: None,
708 component: None,
709 fields: None,
710 key: None,
711 value: None,
712 }
713 }
714
715 pub fn double(nullable: bool) -> Self {
716 ColumnType {
717 typ: SqlType::Double,
718 nullable,
719 precision: None,
720 scale: None,
721 component: None,
722 fields: None,
723 key: None,
724 value: None,
725 }
726 }
727
728 pub fn real(nullable: bool) -> Self {
729 ColumnType {
730 typ: SqlType::Real,
731 nullable,
732 precision: None,
733 scale: None,
734 component: None,
735 fields: None,
736 key: None,
737 value: None,
738 }
739 }
740
741 pub fn decimal(precision: i64, scale: i64, nullable: bool) -> Self {
742 ColumnType {
743 typ: SqlType::Decimal,
744 nullable,
745 precision: Some(precision),
746 scale: Some(scale),
747 component: None,
748 fields: None,
749 key: None,
750 value: None,
751 }
752 }
753
754 pub fn varchar(nullable: bool) -> Self {
755 ColumnType {
756 typ: SqlType::Varchar,
757 nullable,
758 precision: None,
759 scale: None,
760 component: None,
761 fields: None,
762 key: None,
763 value: None,
764 }
765 }
766
767 pub fn varbinary(nullable: bool) -> Self {
768 ColumnType {
769 typ: SqlType::Varbinary,
770 nullable,
771 precision: None,
772 scale: None,
773 component: None,
774 fields: None,
775 key: None,
776 value: None,
777 }
778 }
779
780 pub fn fixed(width: i64, nullable: bool) -> Self {
781 ColumnType {
782 typ: SqlType::Binary,
783 nullable,
784 precision: Some(width),
785 scale: None,
786 component: None,
787 fields: None,
788 key: None,
789 value: None,
790 }
791 }
792
793 pub fn date(nullable: bool) -> Self {
794 ColumnType {
795 typ: SqlType::Date,
796 nullable,
797 precision: None,
798 scale: None,
799 component: None,
800 fields: None,
801 key: None,
802 value: None,
803 }
804 }
805
806 pub fn time(nullable: bool) -> Self {
807 ColumnType {
808 typ: SqlType::Time,
809 nullable,
810 precision: None,
811 scale: None,
812 component: None,
813 fields: None,
814 key: None,
815 value: None,
816 }
817 }
818
819 pub fn timestamp(nullable: bool) -> Self {
820 ColumnType {
821 typ: SqlType::Timestamp,
822 nullable,
823 precision: None,
824 scale: None,
825 component: None,
826 fields: None,
827 key: None,
828 value: None,
829 }
830 }
831
832 pub fn variant(nullable: bool) -> Self {
833 ColumnType {
834 typ: SqlType::Variant,
835 nullable,
836 precision: None,
837 scale: None,
838 component: None,
839 fields: None,
840 key: None,
841 value: None,
842 }
843 }
844
845 pub fn array(nullable: bool, element: ColumnType) -> Self {
846 ColumnType {
847 typ: SqlType::Array,
848 nullable,
849 precision: None,
850 scale: None,
851 component: Some(Box::new(element)),
852 fields: None,
853 key: None,
854 value: None,
855 }
856 }
857
858 pub fn structure(nullable: bool, fields: &[Field]) -> Self {
859 ColumnType {
860 typ: SqlType::Struct,
861 nullable,
862 precision: None,
863 scale: None,
864 component: None,
865 fields: Some(fields.to_vec()),
866 key: None,
867 value: None,
868 }
869 }
870
871 pub fn map(nullable: bool, key: ColumnType, val: ColumnType) -> Self {
872 ColumnType {
873 typ: SqlType::Map,
874 nullable,
875 precision: None,
876 scale: None,
877 component: None,
878 fields: None,
879 key: Some(Box::new(key)),
880 value: Some(Box::new(val)),
881 }
882 }
883
884 pub fn is_integral_type(&self) -> bool {
885 matches!(
886 &self.typ,
887 SqlType::TinyInt | SqlType::SmallInt | SqlType::Int | SqlType::BigInt
888 )
889 }
890
891 pub fn is_fp_type(&self) -> bool {
892 matches!(&self.typ, SqlType::Double | SqlType::Real)
893 }
894
895 pub fn is_decimal_type(&self) -> bool {
896 matches!(&self.typ, SqlType::Decimal)
897 }
898
899 pub fn is_numeric_type(&self) -> bool {
900 self.is_integral_type() || self.is_fp_type() || self.is_decimal_type()
901 }
902}
903
904#[cfg(test)]
905mod tests {
906 use super::{IntervalUnit, SqlIdentifier};
907 use crate::program_schema::{ColumnType, Field, SqlType};
908
    /// Every SQL type name deserializes in mixed, lower and upper case, and the
    /// serialized form deserializes back to the same value.
    #[test]
    fn serde_sql_type() {
        for (sql_str_base, expected_value) in [
            ("Boolean", SqlType::Boolean),
            ("Uuid", SqlType::Uuid),
            ("TinyInt", SqlType::TinyInt),
            ("SmallInt", SqlType::SmallInt),
            ("Integer", SqlType::Int),
            ("BigInt", SqlType::BigInt),
            ("Real", SqlType::Real),
            ("Double", SqlType::Double),
            ("Decimal", SqlType::Decimal),
            ("Char", SqlType::Char),
            ("Varchar", SqlType::Varchar),
            ("Binary", SqlType::Binary),
            ("Varbinary", SqlType::Varbinary),
            ("Time", SqlType::Time),
            ("Date", SqlType::Date),
            ("Timestamp", SqlType::Timestamp),
            ("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
            (
                "Interval_Day_Hour",
                SqlType::Interval(IntervalUnit::DayToHour),
            ),
            (
                "Interval_Day_Minute",
                SqlType::Interval(IntervalUnit::DayToMinute),
            ),
            (
                "Interval_Day_Second",
                SqlType::Interval(IntervalUnit::DayToSecond),
            ),
            ("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
            (
                "Interval_Hour_Minute",
                SqlType::Interval(IntervalUnit::HourToMinute),
            ),
            (
                "Interval_Hour_Second",
                SqlType::Interval(IntervalUnit::HourToSecond),
            ),
            ("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
            (
                "Interval_Minute_Second",
                SqlType::Interval(IntervalUnit::MinuteToSecond),
            ),
            ("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
            ("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
            ("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
            (
                "Interval_Year_Month",
                SqlType::Interval(IntervalUnit::YearToMonth),
            ),
            ("Array", SqlType::Array),
            ("Struct", SqlType::Struct),
            ("Map", SqlType::Map),
            ("Null", SqlType::Null),
            ("Variant", SqlType::Variant),
        ] {
            // Try the name as written, all lower case, and all upper case.
            for sql_str in [
                sql_str_base, &sql_str_base.to_lowercase(), &sql_str_base.to_uppercase(), ] {
                let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str))
                    .unwrap_or_else(|_| {
                        panic!("\"{sql_str}\" should deserialize into its SQL type")
                    });
                assert_eq!(value1, expected_value);
                // Round trip: serialize, then parse the serialized form again.
                let serialized_str =
                    serde_json::to_string(&value1).expect("Value should serialize into JSON");
                let value2: SqlType = serde_json::from_str(&serialized_str).unwrap_or_else(|_| {
                    panic!(
                        "{} should deserialize back into its SQL type",
                        serialized_str
                    )
                });
                assert_eq!(value1, value2);
            }
        }
    }
990
    /// A full program schema containing every interval type deserializes, and
    /// the output fields carry the expected `SqlType` values in order.
    #[test]
    fn deserialize_interval_types() {
        use super::IntervalUnit::*;
        use super::SqlType::*;

        let schema = r#"
{
  "inputs" : [ {
    "name" : "sales",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "sales_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "customer_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "amount",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DECIMAL",
        "nullable" : true,
        "precision" : 10,
        "scale" : 2
      }
    }, {
      "name" : "sale_date",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DATE",
        "nullable" : true
      }
    } ],
    "primary_key" : [ "sales_id" ]
  } ],
  "outputs" : [ {
    "name" : "salessummary",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "customer_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "total_sales",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DECIMAL",
        "nullable" : true,
        "precision" : 38,
        "scale" : 2
      }
    }, {
      "name" : "interval_day",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_hour",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_HOUR",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour_to_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_minute_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MINUTE_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_month",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MONTH",
        "nullable" : false
      }
    }, {
      "name" : "interval_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_year",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_YEAR",
        "nullable" : false
      }
    }, {
      "name" : "interval_year_to_month",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_YEAR_MONTH",
        "nullable" : false
      }
    } ]
  } ]
}
"#;

        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        // Collect the type kind of every output field, in declaration order.
        let types = schema
            .outputs
            .iter()
            .flat_map(|r| r.fields.iter().map(|f| f.columntype.typ));
        let expected_types = [
            Int,
            Decimal,
            Interval(Day),
            Interval(DayToHour),
            Interval(DayToMinute),
            Interval(DayToSecond),
            Interval(Hour),
            Interval(HourToMinute),
            Interval(HourToSecond),
            Interval(Minute),
            Interval(MinuteToSecond),
            Interval(Month),
            Interval(Second),
            Interval(Year),
            Interval(YearToMonth),
        ];

        assert_eq!(types.collect::<Vec<_>>(), &expected_types);
    }
1194
    /// Struct-typed columns deserialize when nested fields inline their type
    /// attributes, including a nested struct given as an object under `fields`.
    #[test]
    fn serialize_struct_schemas() {
        let schema = r#"{
  "inputs" : [ {
    "name" : "PERS",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "P0",
      "case_sensitive" : false,
      "columntype" : {
        "fields" : [ {
          "type" : "VARCHAR",
          "nullable" : true,
          "precision" : 30,
          "name" : "FIRSTNAME"
        }, {
          "type" : "VARCHAR",
          "nullable" : true,
          "precision" : 30,
          "name" : "LASTNAME"
        }, {
          "fields" : {
            "fields" : [ {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 30,
              "name" : "STREET"
            }, {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 30,
              "name" : "CITY"
            }, {
              "type" : "CHAR",
              "nullable" : true,
              "precision" : 2,
              "name" : "STATE"
            }, {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 6,
              "name" : "POSTAL_CODE"
            } ],
            "nullable" : false
          },
          "nullable" : false,
          "name" : "ADDRESS"
        } ],
        "nullable" : false
      }
    }]
  } ],
  "outputs" : [ ]
}
"#;
        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        eprintln!("{:#?}", schema);
        let pers = schema.inputs.iter().find(|r| r.name == "PERS").unwrap();
        let p0 = pers.fields.iter().find(|f| f.name == "P0").unwrap();
        // No `type` key on the column type: it defaults to a struct.
        assert_eq!(p0.columntype.typ, SqlType::Struct);
        let p0_fields = p0.columntype.fields.as_ref().unwrap();
        assert_eq!(p0_fields[0].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[1].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[2].columntype.typ, SqlType::Struct);
        assert_eq!(p0_fields[2].name, "ADDRESS");
        let address = &p0_fields[2].columntype.fields.as_ref().unwrap();
        assert_eq!(address.len(), 4);
        assert_eq!(address[0].name, "STREET");
        assert_eq!(address[0].columntype.typ, SqlType::Varchar);
        assert_eq!(address[1].columntype.typ, SqlType::Varchar);
        assert_eq!(address[2].columntype.typ, SqlType::Char);
        assert_eq!(address[3].columntype.typ, SqlType::Varchar);
    }
1268
    /// Equality between identifiers, and between an identifier and a string.
    #[test]
    fn sql_identifier_cmp() {
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("foo"));
        assert_ne!(SqlIdentifier::from("foo"), SqlIdentifier::from("bar"));
        // Unquoted identifiers ignore case.
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("BAR"));
        // A quoted identifier equals an unquoted one when the raw names match exactly.
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("\"foo\""));
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("\"bar\""));
        assert_eq!(SqlIdentifier::from("bAr"), SqlIdentifier::from("\"bAr\""));
        assert_eq!(
            SqlIdentifier::new("bAr", true),
            SqlIdentifier::from("\"bAr\"")
        );

        // Comparison against plain strings goes through `SqlIdentifier::from`.
        assert_eq!(SqlIdentifier::from("bAr"), "bar");
        assert_eq!(SqlIdentifier::from("bAr"), "bAr");
    }
1285
    /// Identifiers that compare equal are treated as duplicates by an ordered set.
    #[test]
    fn sql_identifier_ord() {
        let mut btree = std::collections::BTreeSet::new();
        assert!(btree.insert(SqlIdentifier::from("foo")));
        assert!(btree.insert(SqlIdentifier::from("bar")));
        // Already present under a different spelling, so `insert` returns false.
        assert!(!btree.insert(SqlIdentifier::from("BAR")));
        assert!(!btree.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!btree.insert(SqlIdentifier::from("\"bar\"")));
    }
1295
    /// Identifiers that compare equal are treated as duplicates by a hash set.
    #[test]
    fn sql_identifier_hash() {
        let mut hs = std::collections::HashSet::new();
        assert!(hs.insert(SqlIdentifier::from("foo")));
        assert!(hs.insert(SqlIdentifier::from("bar")));
        // Already present under a different spelling, so `insert` returns false.
        assert!(!hs.insert(SqlIdentifier::from("BAR")));
        assert!(!hs.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!hs.insert(SqlIdentifier::from("\"bar\"")));
    }
1305
    /// `name()` lowercases unquoted identifiers; `sql_name()` re-adds quotes
    /// for quoted ones and leaves unquoted names as written.
    #[test]
    fn sql_identifier_name() {
        assert_eq!(SqlIdentifier::from("foo").name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").name(), "bar");
        assert_eq!(SqlIdentifier::from("\"bAr\"").name(), "bAr");
        assert_eq!(SqlIdentifier::from("foo").sql_name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").sql_name(), "bAr");
        assert_eq!(SqlIdentifier::from("\"bAr\"").sql_name(), "\"bAr\"");
    }
1315
    /// Regression test: a struct field whose member is a MAP with inlined
    /// `key`/`value` types keeps those key and value types after deserialization.
    #[test]
    fn issue3277() {
        let schema = r#"{
      "name" : "j",
      "case_sensitive" : false,
      "columntype" : {
        "fields" : [ {
          "key" : {
            "nullable" : false,
            "precision" : -1,
            "type" : "VARCHAR"
          },
          "name" : "s",
          "nullable" : true,
          "type" : "MAP",
          "value" : {
            "nullable" : true,
            "precision" : -1,
            "type" : "VARCHAR"
          }
        } ],
        "nullable" : true
      }
    }"#;
        let field: Field = serde_json::from_str(schema).unwrap();
        println!("field: {:#?}", field);
        assert_eq!(
            field,
            Field {
                name: SqlIdentifier {
                    name: "j".to_string(),
                    case_sensitive: false,
                },
                columntype: ColumnType {
                    // No `type` key on the outer column type: defaults to a struct.
                    typ: SqlType::Struct,
                    nullable: true,
                    precision: None,
                    scale: None,
                    component: None,
                    fields: Some(vec![Field {
                        name: SqlIdentifier {
                            name: "s".to_string(),
                            case_sensitive: false,
                        },
                        columntype: ColumnType {
                            typ: SqlType::Map,
                            nullable: true,
                            precision: None,
                            scale: None,
                            component: None,
                            fields: None,
                            key: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: false,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                            value: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: true,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                        },
                        lateness: None,
                        default: None,
                        unused: false,
                        watermark: None,
                    }]),
                    key: None,
                    value: None,
                },
                lateness: None,
                default: None,
                unused: false,
                watermark: None,
            }
        );
    }
1403}