1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::cmp::Ordering;
3use std::collections::BTreeMap;
4use std::fmt::Display;
5use std::hash::{Hash, Hasher};
6use utoipa::ToSchema;
7
8#[cfg(feature = "testing")]
9use proptest::{collection::vec, prelude::any};
10
/// Converts a SQL identifier, as written by the user, to its canonical form.
///
/// An identifier enclosed in double quotes is returned without the quotes and
/// with its case untouched; any other identifier is lowercased.
pub fn canonical_identifier(id: &str) -> String {
    // A lone `"` does not count as quoted: once the leading quote is removed
    // there is no trailing quote left to strip.
    let quoted_body = id
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));

    match quoted_body {
        Some(body) => body.to_owned(),
        None => id.to_lowercase(),
    }
}
25
/// A SQL identifier together with a flag recording whether it is
/// case-sensitive.
///
/// The text itself is private and is read back through the accessor methods
/// of this type, which apply the case rules.
#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SqlIdentifier {
    /// Identifier text as stored (not lowercased, no surrounding quotes added).
    #[cfg_attr(feature = "testing", proptest(regex = "relation1|relation2|relation3"))]
    name: String,
    /// `true` when the identifier must be matched with its exact case.
    pub case_sensitive: bool,
}
37
38impl SqlIdentifier {
39 pub fn new<S: AsRef<str>>(name: S, case_sensitive: bool) -> Self {
40 Self {
41 name: name.as_ref().to_string(),
42 case_sensitive,
43 }
44 }
45
46 pub fn name(&self) -> String {
56 if self.case_sensitive {
57 self.name.clone()
58 } else {
59 self.name.to_lowercase()
60 }
61 }
62
63 pub fn sql_name(&self) -> String {
74 if self.case_sensitive {
75 format!("\"{}\"", self.name)
76 } else {
77 self.name.clone()
78 }
79 }
80}
81
82impl Hash for SqlIdentifier {
83 fn hash<H: Hasher>(&self, state: &mut H) {
84 self.name().hash(state);
85 }
86}
87
88impl PartialEq for SqlIdentifier {
89 fn eq(&self, other: &Self) -> bool {
90 match (self.case_sensitive, other.case_sensitive) {
91 (true, true) => self.name == other.name,
92 (false, false) => self.name.to_lowercase() == other.name.to_lowercase(),
93 (true, false) => self.name == other.name,
94 (false, true) => self.name == other.name,
95 }
96 }
97}
98
99impl Ord for SqlIdentifier {
100 fn cmp(&self, other: &Self) -> Ordering {
101 self.name().cmp(&other.name())
102 }
103}
104
105impl PartialOrd for SqlIdentifier {
106 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
107 Some(self.cmp(other))
108 }
109}
110
111impl<S: AsRef<str>> PartialEq<S> for SqlIdentifier {
112 fn eq(&self, other: &S) -> bool {
113 self == &SqlIdentifier::from(other.as_ref())
114 }
115}
116
// Marker impl: needed by the `Ord` impl and by set/map key bounds.
// NOTE(review): the hand-written `PartialEq` treats a case-sensitive and a
// case-insensitive identifier with identical text as equal, so equality does
// not look transitive (`"bAr"` == `bAr` and `bAr` == `BAR`, yet
// `"bAr"` != `BAR`) — confirm this is intended.
impl Eq for SqlIdentifier {}
118
119impl<S: AsRef<str>> From<S> for SqlIdentifier {
120 fn from(name: S) -> Self {
121 if name.as_ref().starts_with('"')
122 && name.as_ref().ends_with('"')
123 && name.as_ref().len() >= 2
124 {
125 Self {
126 name: name.as_ref()[1..name.as_ref().len() - 1].to_string(),
127 case_sensitive: true,
128 }
129 } else {
130 Self::new(name, false)
131 }
132 }
133}
134
135impl From<SqlIdentifier> for String {
136 fn from(id: SqlIdentifier) -> String {
137 id.name()
138 }
139}
140
141impl From<&SqlIdentifier> for String {
142 fn from(id: &SqlIdentifier) -> String {
143 id.name()
144 }
145}
146
147impl Display for SqlIdentifier {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 write!(f, "{}", self.name())
150 }
151}
152
/// The relations of a program, split into an input list and an output list.
// NOTE(review): presumably inputs are tables and outputs are views — this
// struct only records two lists of `Relation`; confirm with the producer of
// the schema.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ProgramSchema {
    /// Relations listed under `inputs`.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub inputs: Vec<Relation>,
    /// Relations listed under `outputs`.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub outputs: Vec<Relation>,
}
170
/// A span in source text, given as a start and an end line/column pair.
// NOTE(review): whether lines and columns are 0- or 1-based, and whether the
// end position is inclusive, is not visible here — confirm with the producer.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SourcePosition {
    /// Line on which the span starts.
    pub start_line_number: usize,
    /// Column at which the span starts.
    pub start_column: usize,
    /// Line on which the span ends.
    pub end_line_number: usize,
    /// Column at which the span ends.
    pub end_column: usize,
}
179
/// The value of a property, together with the source positions recorded for
/// its key and for the value itself.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct PropertyValue {
    /// Property value, kept as text.
    pub value: String,
    /// Source span of the property key.
    pub key_position: SourcePosition,
    /// Source span of the property value.
    pub value_position: SourcePosition,
}
187
/// A named relation: its identifier, its fields and its properties.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Relation {
    /// Relation name. Flattened, so the identifier's `name` and
    /// `case_sensitive` keys sit directly among the relation's own keys.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Fields of the relation.
    #[cfg_attr(feature = "testing", proptest(value = "Vec::new()"))]
    pub fields: Vec<Field>,
    /// Materialization flag; `false` when the key is absent from the input.
    #[serde(default)]
    pub materialized: bool,
    /// Properties keyed by name; empty when the key is absent from the input.
    #[serde(default)]
    pub properties: BTreeMap<String, PropertyValue>,
}
204
205impl Relation {
206 pub fn empty() -> Self {
207 Self {
208 name: SqlIdentifier::from("".to_string()),
209 fields: Vec::new(),
210 materialized: false,
211 properties: BTreeMap::new(),
212 }
213 }
214
215 pub fn new(
216 name: SqlIdentifier,
217 fields: Vec<Field>,
218 materialized: bool,
219 properties: BTreeMap<String, PropertyValue>,
220 ) -> Self {
221 Self {
222 name,
223 fields,
224 materialized,
225 properties,
226 }
227 }
228
229 pub fn field(&self, name: &str) -> Option<&Field> {
231 let name = canonical_identifier(name);
232 self.fields.iter().find(|f| f.name == name)
233 }
234}
235
/// A named, typed field of a relation or of a struct column type.
///
/// Only `Serialize` is derived; `Deserialize` is implemented by hand elsewhere
/// in this file.
#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Field {
    /// Field name. Flattened, so the identifier's `name` and `case_sensitive`
    /// keys sit directly among the field's own keys when serialized.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Type of the field.
    pub columntype: ColumnType,
    /// Optional lateness annotation, kept as text.
    pub lateness: Option<String>,
    /// Optional default-value annotation, kept as text.
    pub default: Option<String>,
    /// Flag marking the field as unused.
    pub unused: bool,
    /// Optional watermark annotation, kept as text.
    pub watermark: Option<String>,
}
250
251impl Field {
252 pub fn new(name: SqlIdentifier, columntype: ColumnType) -> Self {
253 Self {
254 name,
255 columntype,
256 lateness: None,
257 default: None,
258 unused: false,
259 watermark: None,
260 }
261 }
262
263 pub fn with_lateness(mut self, lateness: &str) -> Self {
264 self.lateness = Some(lateness.to_string());
265 self
266 }
267
268 pub fn with_unused(mut self, unused: bool) -> Self {
269 self.unused = unused;
270 self
271 }
272}
273
274impl<'de> Deserialize<'de> for Field {
279 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
280 where
281 D: Deserializer<'de>,
282 {
283 const fn default_is_struct() -> Option<SqlType> {
284 Some(SqlType::Struct)
285 }
286
287 #[derive(Debug, Clone, Deserialize)]
288 struct FieldHelper {
289 name: Option<String>,
290 #[serde(default)]
291 case_sensitive: bool,
292 columntype: Option<ColumnType>,
293 #[serde(rename = "type")]
294 #[serde(default = "default_is_struct")]
295 typ: Option<SqlType>,
296 nullable: Option<bool>,
297 precision: Option<i64>,
298 scale: Option<i64>,
299 component: Option<Box<ColumnType>>,
300 fields: Option<serde_json::Value>,
301 key: Option<Box<ColumnType>>,
302 value: Option<Box<ColumnType>>,
303 default: Option<String>,
304 #[serde(default)]
305 unused: bool,
306 lateness: Option<String>,
307 watermark: Option<String>,
308 }
309
310 fn helper_to_field(helper: FieldHelper) -> Field {
311 let columntype = if let Some(ctype) = helper.columntype {
312 ctype
313 } else if let Some(serde_json::Value::Array(fields)) = helper.fields {
314 let fields = fields
315 .into_iter()
316 .map(|field| {
317 let field: FieldHelper = serde_json::from_value(field).unwrap();
318 helper_to_field(field)
319 })
320 .collect::<Vec<Field>>();
321
322 ColumnType {
323 typ: helper.typ.unwrap_or(SqlType::Null),
324 nullable: helper.nullable.unwrap_or(false),
325 precision: helper.precision,
326 scale: helper.scale,
327 component: helper.component,
328 fields: Some(fields),
329 key: None,
330 value: None,
331 }
332 } else if let Some(serde_json::Value::Object(obj)) = helper.fields {
333 serde_json::from_value(serde_json::Value::Object(obj))
334 .expect("Failed to deserialize object")
335 } else {
336 ColumnType {
337 typ: helper.typ.unwrap_or(SqlType::Null),
338 nullable: helper.nullable.unwrap_or(false),
339 precision: helper.precision,
340 scale: helper.scale,
341 component: helper.component,
342 fields: None,
343 key: helper.key,
344 value: helper.value,
345 }
346 };
347
348 Field {
349 name: SqlIdentifier::new(helper.name.unwrap(), helper.case_sensitive),
350 columntype,
351 default: helper.default,
352 unused: helper.unused,
353 lateness: helper.lateness,
354 watermark: helper.watermark,
355 }
356 }
357
358 let helper = FieldHelper::deserialize(deserializer)?;
359 Ok(helper_to_field(helper))
360 }
361}
362
/// The unit (or unit range) of a SQL `INTERVAL` type.
///
/// Each variant corresponds to one `INTERVAL_*` keyword in the serialized
/// form of [`SqlType`].
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum IntervalUnit {
    /// Serialized as `INTERVAL_DAY`.
    Day,
    /// Serialized as `INTERVAL_DAY_HOUR`.
    DayToHour,
    /// Serialized as `INTERVAL_DAY_MINUTE`.
    DayToMinute,
    /// Serialized as `INTERVAL_DAY_SECOND`.
    DayToSecond,
    /// Serialized as `INTERVAL_HOUR`.
    Hour,
    /// Serialized as `INTERVAL_HOUR_MINUTE`.
    HourToMinute,
    /// Serialized as `INTERVAL_HOUR_SECOND`.
    HourToSecond,
    /// Serialized as `INTERVAL_MINUTE`.
    Minute,
    /// Serialized as `INTERVAL_MINUTE_SECOND`.
    MinuteToSecond,
    /// Serialized as `INTERVAL_MONTH`.
    Month,
    /// Serialized as `INTERVAL_SECOND`.
    Second,
    /// Serialized as `INTERVAL_YEAR`.
    Year,
    /// Serialized as `INTERVAL_YEAR_MONTH`.
    YearToMonth,
}
397
/// The base SQL type of a column.
///
/// `Serialize` and `Deserialize` are implemented by hand in this file: a type
/// is written as an uppercase keyword and read back case-insensitively.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SqlType {
    /// Serialized as `BOOLEAN`.
    Boolean,
    /// Serialized as `TINYINT`.
    TinyInt,
    /// Serialized as `SMALLINT`.
    SmallInt,
    /// Serialized as `INTEGER`.
    Int,
    /// Serialized as `BIGINT`.
    BigInt,
    /// Serialized as `UTINYINT`.
    UTinyInt,
    /// Serialized as `USMALLINT`.
    USmallInt,
    /// Serialized as `UINTEGER`.
    UInt,
    /// Serialized as `UBIGINT`.
    UBigInt,
    /// Serialized as `REAL`.
    Real,
    /// Serialized as `DOUBLE`.
    Double,
    /// Serialized as `DECIMAL`.
    Decimal,
    /// Serialized as `CHAR`.
    Char,
    /// Serialized as `VARCHAR`.
    Varchar,
    /// Serialized as `BINARY`.
    Binary,
    /// Serialized as `VARBINARY`.
    Varbinary,
    /// Serialized as `TIME`.
    Time,
    /// Serialized as `DATE`.
    Date,
    /// Serialized as `TIMESTAMP`.
    Timestamp,
    /// Serialized as one of the `INTERVAL_*` keywords, chosen by the unit.
    Interval(IntervalUnit),
    /// Serialized as `ARRAY`.
    Array,
    /// Serialized as `STRUCT`.
    Struct,
    /// Serialized as `MAP`.
    Map,
    /// Serialized as `NULL`.
    Null,
    /// Serialized as `UUID`.
    Uuid,
    /// Serialized as `VARIANT`.
    Variant,
}
455
456impl Display for SqlType {
457 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
458 f.write_str(&serde_json::to_string(self).unwrap())
459 }
460}
461
462impl<'de> Deserialize<'de> for SqlType {
463 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
464 where
465 D: Deserializer<'de>,
466 {
467 let value: String = Deserialize::deserialize(deserializer)?;
468 match value.to_lowercase().as_str() {
469 "interval_day" => Ok(SqlType::Interval(IntervalUnit::Day)),
470 "interval_day_hour" => Ok(SqlType::Interval(IntervalUnit::DayToHour)),
471 "interval_day_minute" => Ok(SqlType::Interval(IntervalUnit::DayToMinute)),
472 "interval_day_second" => Ok(SqlType::Interval(IntervalUnit::DayToSecond)),
473 "interval_hour" => Ok(SqlType::Interval(IntervalUnit::Hour)),
474 "interval_hour_minute" => Ok(SqlType::Interval(IntervalUnit::HourToMinute)),
475 "interval_hour_second" => Ok(SqlType::Interval(IntervalUnit::HourToSecond)),
476 "interval_minute" => Ok(SqlType::Interval(IntervalUnit::Minute)),
477 "interval_minute_second" => Ok(SqlType::Interval(IntervalUnit::MinuteToSecond)),
478 "interval_month" => Ok(SqlType::Interval(IntervalUnit::Month)),
479 "interval_second" => Ok(SqlType::Interval(IntervalUnit::Second)),
480 "interval_year" => Ok(SqlType::Interval(IntervalUnit::Year)),
481 "interval_year_month" => Ok(SqlType::Interval(IntervalUnit::YearToMonth)),
482 "boolean" => Ok(SqlType::Boolean),
483 "tinyint" => Ok(SqlType::TinyInt),
484 "smallint" => Ok(SqlType::SmallInt),
485 "integer" => Ok(SqlType::Int),
486 "bigint" => Ok(SqlType::BigInt),
487 "utinyint" => Ok(SqlType::UTinyInt),
488 "usmallint" => Ok(SqlType::USmallInt),
489 "uinteger" => Ok(SqlType::UInt),
490 "ubigint" => Ok(SqlType::UBigInt),
491 "real" => Ok(SqlType::Real),
492 "double" => Ok(SqlType::Double),
493 "decimal" => Ok(SqlType::Decimal),
494 "char" => Ok(SqlType::Char),
495 "varchar" => Ok(SqlType::Varchar),
496 "binary" => Ok(SqlType::Binary),
497 "varbinary" => Ok(SqlType::Varbinary),
498 "variant" => Ok(SqlType::Variant),
499 "time" => Ok(SqlType::Time),
500 "date" => Ok(SqlType::Date),
501 "timestamp" => Ok(SqlType::Timestamp),
502 "array" => Ok(SqlType::Array),
503 "struct" => Ok(SqlType::Struct),
504 "map" => Ok(SqlType::Map),
505 "null" => Ok(SqlType::Null),
506 "uuid" => Ok(SqlType::Uuid),
507 _ => Err(serde::de::Error::custom(format!(
508 "Unknown SQL type: {}",
509 value
510 ))),
511 }
512 }
513}
514
515impl Serialize for SqlType {
516 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
517 where
518 S: Serializer,
519 {
520 let type_str = match self {
521 SqlType::Boolean => "BOOLEAN",
522 SqlType::TinyInt => "TINYINT",
523 SqlType::SmallInt => "SMALLINT",
524 SqlType::Int => "INTEGER",
525 SqlType::BigInt => "BIGINT",
526 SqlType::UTinyInt => "UTINYINT",
527 SqlType::USmallInt => "USMALLINT",
528 SqlType::UInt => "UINTEGER",
529 SqlType::UBigInt => "UBIGINT",
530 SqlType::Real => "REAL",
531 SqlType::Double => "DOUBLE",
532 SqlType::Decimal => "DECIMAL",
533 SqlType::Char => "CHAR",
534 SqlType::Varchar => "VARCHAR",
535 SqlType::Binary => "BINARY",
536 SqlType::Varbinary => "VARBINARY",
537 SqlType::Time => "TIME",
538 SqlType::Date => "DATE",
539 SqlType::Timestamp => "TIMESTAMP",
540 SqlType::Interval(interval_unit) => match interval_unit {
541 IntervalUnit::Day => "INTERVAL_DAY",
542 IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
543 IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
544 IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
545 IntervalUnit::Hour => "INTERVAL_HOUR",
546 IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
547 IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
548 IntervalUnit::Minute => "INTERVAL_MINUTE",
549 IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
550 IntervalUnit::Month => "INTERVAL_MONTH",
551 IntervalUnit::Second => "INTERVAL_SECOND",
552 IntervalUnit::Year => "INTERVAL_YEAR",
553 IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
554 },
555 SqlType::Array => "ARRAY",
556 SqlType::Struct => "STRUCT",
557 SqlType::Uuid => "UUID",
558 SqlType::Map => "MAP",
559 SqlType::Null => "NULL",
560 SqlType::Variant => "VARIANT",
561 };
562 serializer.serialize_str(type_str)
563 }
564}
565
566impl SqlType {
567 pub fn is_string(&self) -> bool {
569 matches!(self, Self::Char | Self::Varchar)
570 }
571
572 pub fn is_varchar(&self) -> bool {
573 matches!(self, Self::Varchar)
574 }
575
576 pub fn is_varbinary(&self) -> bool {
577 matches!(self, Self::Varbinary)
578 }
579}
580
/// Serde default for `ColumnType::typ`: a column type whose `type` key is
/// absent deserializes as `STRUCT`.
const fn default_is_struct() -> SqlType {
    SqlType::Struct
}
586
/// A SQL column type: the base [`SqlType`] plus the attributes that only some
/// types use (precision and scale, array component, struct fields, map key
/// and value).
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ColumnType {
    /// Base SQL type. (De)serialized under the key `type`; when the key is
    /// absent, deserialization falls back to `default_is_struct()`.
    #[serde(rename = "type")]
    #[serde(default = "default_is_struct")]
    pub typ: SqlType,
    /// Whether the column may hold NULL.
    pub nullable: bool,
    /// Optional precision. The `decimal` constructor stores the decimal
    /// precision here and the `fixed` constructor stores the binary width.
    pub precision: Option<i64>,
    /// Optional scale; set by the `decimal` constructor.
    pub scale: Option<i64>,
    /// Element type; set by the `array` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub component: Option<Box<ColumnType>>,
    /// Member fields; set by the `structure` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "Some(Vec::new())"))]
    pub fields: Option<Vec<Field>>,
    /// Key type; set by the `map` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub key: Option<Box<ColumnType>>,
    /// Value type; set by the `map` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub value: Option<Box<ColumnType>>,
}
651
652impl ColumnType {
653 pub fn boolean(nullable: bool) -> Self {
654 ColumnType {
655 typ: SqlType::Boolean,
656 nullable,
657 precision: None,
658 scale: None,
659 component: None,
660 fields: None,
661 key: None,
662 value: None,
663 }
664 }
665
666 pub fn uuid(nullable: bool) -> Self {
667 ColumnType {
668 typ: SqlType::Uuid,
669 nullable,
670 precision: None,
671 scale: None,
672 component: None,
673 fields: None,
674 key: None,
675 value: None,
676 }
677 }
678
679 pub fn tinyint(nullable: bool) -> Self {
680 ColumnType {
681 typ: SqlType::TinyInt,
682 nullable,
683 precision: None,
684 scale: None,
685 component: None,
686 fields: None,
687 key: None,
688 value: None,
689 }
690 }
691
692 pub fn smallint(nullable: bool) -> Self {
693 ColumnType {
694 typ: SqlType::SmallInt,
695 nullable,
696 precision: None,
697 scale: None,
698 component: None,
699 fields: None,
700 key: None,
701 value: None,
702 }
703 }
704
705 pub fn int(nullable: bool) -> Self {
706 ColumnType {
707 typ: SqlType::Int,
708 nullable,
709 precision: None,
710 scale: None,
711 component: None,
712 fields: None,
713 key: None,
714 value: None,
715 }
716 }
717
718 pub fn bigint(nullable: bool) -> Self {
719 ColumnType {
720 typ: SqlType::BigInt,
721 nullable,
722 precision: None,
723 scale: None,
724 component: None,
725 fields: None,
726 key: None,
727 value: None,
728 }
729 }
730
731 pub fn utinyint(nullable: bool) -> Self {
732 ColumnType {
733 typ: SqlType::UTinyInt,
734 nullable,
735 precision: None,
736 scale: None,
737 component: None,
738 fields: None,
739 key: None,
740 value: None,
741 }
742 }
743
744 pub fn usmallint(nullable: bool) -> Self {
745 ColumnType {
746 typ: SqlType::USmallInt,
747 nullable,
748 precision: None,
749 scale: None,
750 component: None,
751 fields: None,
752 key: None,
753 value: None,
754 }
755 }
756
757 pub fn uint(nullable: bool) -> Self {
758 ColumnType {
759 typ: SqlType::UInt,
760 nullable,
761 precision: None,
762 scale: None,
763 component: None,
764 fields: None,
765 key: None,
766 value: None,
767 }
768 }
769
770 pub fn ubigint(nullable: bool) -> Self {
771 ColumnType {
772 typ: SqlType::UBigInt,
773 nullable,
774 precision: None,
775 scale: None,
776 component: None,
777 fields: None,
778 key: None,
779 value: None,
780 }
781 }
782
783 pub fn double(nullable: bool) -> Self {
784 ColumnType {
785 typ: SqlType::Double,
786 nullable,
787 precision: None,
788 scale: None,
789 component: None,
790 fields: None,
791 key: None,
792 value: None,
793 }
794 }
795
796 pub fn real(nullable: bool) -> Self {
797 ColumnType {
798 typ: SqlType::Real,
799 nullable,
800 precision: None,
801 scale: None,
802 component: None,
803 fields: None,
804 key: None,
805 value: None,
806 }
807 }
808
809 pub fn decimal(precision: i64, scale: i64, nullable: bool) -> Self {
810 ColumnType {
811 typ: SqlType::Decimal,
812 nullable,
813 precision: Some(precision),
814 scale: Some(scale),
815 component: None,
816 fields: None,
817 key: None,
818 value: None,
819 }
820 }
821
822 pub fn varchar(nullable: bool) -> Self {
823 ColumnType {
824 typ: SqlType::Varchar,
825 nullable,
826 precision: None,
827 scale: None,
828 component: None,
829 fields: None,
830 key: None,
831 value: None,
832 }
833 }
834
835 pub fn varbinary(nullable: bool) -> Self {
836 ColumnType {
837 typ: SqlType::Varbinary,
838 nullable,
839 precision: None,
840 scale: None,
841 component: None,
842 fields: None,
843 key: None,
844 value: None,
845 }
846 }
847
848 pub fn fixed(width: i64, nullable: bool) -> Self {
849 ColumnType {
850 typ: SqlType::Binary,
851 nullable,
852 precision: Some(width),
853 scale: None,
854 component: None,
855 fields: None,
856 key: None,
857 value: None,
858 }
859 }
860
861 pub fn date(nullable: bool) -> Self {
862 ColumnType {
863 typ: SqlType::Date,
864 nullable,
865 precision: None,
866 scale: None,
867 component: None,
868 fields: None,
869 key: None,
870 value: None,
871 }
872 }
873
874 pub fn time(nullable: bool) -> Self {
875 ColumnType {
876 typ: SqlType::Time,
877 nullable,
878 precision: None,
879 scale: None,
880 component: None,
881 fields: None,
882 key: None,
883 value: None,
884 }
885 }
886
887 pub fn timestamp(nullable: bool) -> Self {
888 ColumnType {
889 typ: SqlType::Timestamp,
890 nullable,
891 precision: None,
892 scale: None,
893 component: None,
894 fields: None,
895 key: None,
896 value: None,
897 }
898 }
899
900 pub fn variant(nullable: bool) -> Self {
901 ColumnType {
902 typ: SqlType::Variant,
903 nullable,
904 precision: None,
905 scale: None,
906 component: None,
907 fields: None,
908 key: None,
909 value: None,
910 }
911 }
912
913 pub fn array(nullable: bool, element: ColumnType) -> Self {
914 ColumnType {
915 typ: SqlType::Array,
916 nullable,
917 precision: None,
918 scale: None,
919 component: Some(Box::new(element)),
920 fields: None,
921 key: None,
922 value: None,
923 }
924 }
925
926 pub fn structure(nullable: bool, fields: &[Field]) -> Self {
927 ColumnType {
928 typ: SqlType::Struct,
929 nullable,
930 precision: None,
931 scale: None,
932 component: None,
933 fields: Some(fields.to_vec()),
934 key: None,
935 value: None,
936 }
937 }
938
939 pub fn map(nullable: bool, key: ColumnType, val: ColumnType) -> Self {
940 ColumnType {
941 typ: SqlType::Map,
942 nullable,
943 precision: None,
944 scale: None,
945 component: None,
946 fields: None,
947 key: Some(Box::new(key)),
948 value: Some(Box::new(val)),
949 }
950 }
951
952 pub fn is_integral_type(&self) -> bool {
953 matches!(
954 &self.typ,
955 SqlType::TinyInt
956 | SqlType::SmallInt
957 | SqlType::Int
958 | SqlType::BigInt
959 | SqlType::UTinyInt
960 | SqlType::USmallInt
961 | SqlType::UInt
962 | SqlType::UBigInt
963 )
964 }
965
966 pub fn is_fp_type(&self) -> bool {
967 matches!(&self.typ, SqlType::Double | SqlType::Real)
968 }
969
970 pub fn is_decimal_type(&self) -> bool {
971 matches!(&self.typ, SqlType::Decimal)
972 }
973
974 pub fn is_numeric_type(&self) -> bool {
975 self.is_integral_type() || self.is_fp_type() || self.is_decimal_type()
976 }
977}
978
979#[cfg(test)]
980mod tests {
981 use super::{IntervalUnit, SqlIdentifier};
982 use crate::program_schema::{ColumnType, Field, SqlType};
983
// Checks that every SQL type keyword deserializes from its mixed-case,
// lowercase and uppercase spelling, and that the resulting value survives a
// serialize/deserialize round trip.
#[test]
fn serde_sql_type() {
    // (keyword, expected variant) pairs; each keyword is tried in three
    // spellings by the inner loop.
    for (sql_str_base, expected_value) in [
        ("Boolean", SqlType::Boolean),
        ("Uuid", SqlType::Uuid),
        ("TinyInt", SqlType::TinyInt),
        ("SmallInt", SqlType::SmallInt),
        ("Integer", SqlType::Int),
        ("BigInt", SqlType::BigInt),
        ("UTinyInt", SqlType::UTinyInt),
        ("USmallInt", SqlType::USmallInt),
        ("UInteger", SqlType::UInt),
        ("UBigInt", SqlType::UBigInt),
        ("Real", SqlType::Real),
        ("Double", SqlType::Double),
        ("Decimal", SqlType::Decimal),
        ("Char", SqlType::Char),
        ("Varchar", SqlType::Varchar),
        ("Binary", SqlType::Binary),
        ("Varbinary", SqlType::Varbinary),
        ("Time", SqlType::Time),
        ("Date", SqlType::Date),
        ("Timestamp", SqlType::Timestamp),
        ("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
        (
            "Interval_Day_Hour",
            SqlType::Interval(IntervalUnit::DayToHour),
        ),
        (
            "Interval_Day_Minute",
            SqlType::Interval(IntervalUnit::DayToMinute),
        ),
        (
            "Interval_Day_Second",
            SqlType::Interval(IntervalUnit::DayToSecond),
        ),
        ("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
        (
            "Interval_Hour_Minute",
            SqlType::Interval(IntervalUnit::HourToMinute),
        ),
        (
            "Interval_Hour_Second",
            SqlType::Interval(IntervalUnit::HourToSecond),
        ),
        ("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
        (
            "Interval_Minute_Second",
            SqlType::Interval(IntervalUnit::MinuteToSecond),
        ),
        ("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
        ("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
        ("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
        (
            "Interval_Year_Month",
            SqlType::Interval(IntervalUnit::YearToMonth),
        ),
        ("Array", SqlType::Array),
        ("Struct", SqlType::Struct),
        ("Map", SqlType::Map),
        ("Null", SqlType::Null),
        ("Variant", SqlType::Variant),
    ] {
        // Spellings tried: as listed, all lowercase, all uppercase.
        for sql_str in [
            sql_str_base, &sql_str_base.to_lowercase(), &sql_str_base.to_uppercase(), ] {
            let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str))
                .unwrap_or_else(|_| {
                    panic!("\"{sql_str}\" should deserialize into its SQL type")
                });
            assert_eq!(value1, expected_value);
            // Round trip: what was read must serialize and read back equal.
            let serialized_str =
                serde_json::to_string(&value1).expect("Value should serialize into JSON");
            let value2: SqlType = serde_json::from_str(&serialized_str).unwrap_or_else(|_| {
                panic!(
                    "{} should deserialize back into its SQL type",
                    serialized_str
                )
            });
            assert_eq!(value1, value2);
        }
    }
}
1069
// Parses a complete program schema from JSON and checks the SQL types of the
// fields of its output relation, which include one field per interval unit.
#[test]
fn deserialize_interval_types() {
    use super::IntervalUnit::*;
    use super::SqlType::*;

    let schema = r#"
{
 "inputs" : [ {
 "name" : "sales",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "sales_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "customer_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "age",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "UINTEGER",
 "nullable" : true
 }
 }, {
 "name" : "amount",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DECIMAL",
 "nullable" : true,
 "precision" : 10,
 "scale" : 2
 }
 }, {
 "name" : "sale_date",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DATE",
 "nullable" : true
 }
 } ],
 "primary_key" : [ "sales_id" ]
 } ],
 "outputs" : [ {
 "name" : "salessummary",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "customer_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "total_sales",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DECIMAL",
 "nullable" : true,
 "precision" : 38,
 "scale" : 2
 }
 }, {
 "name" : "interval_day",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_hour",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_HOUR",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour_to_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_minute_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MINUTE_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_month",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MONTH",
 "nullable" : false
 }
 }, {
 "name" : "interval_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_year",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_YEAR",
 "nullable" : false
 }
 }, {
 "name" : "interval_year_to_month",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_YEAR_MONTH",
 "nullable" : false
 }
 } ]
 } ]
}
"#;

    let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
    // Collect the base type of every output field, in document order.
    let types = schema
        .outputs
        .iter()
        .flat_map(|r| r.fields.iter().map(|f| f.columntype.typ));
    let expected_types = [
        Int,
        Decimal,
        Interval(Day),
        Interval(DayToHour),
        Interval(DayToMinute),
        Interval(DayToSecond),
        Interval(Hour),
        Interval(HourToMinute),
        Interval(HourToSecond),
        Interval(Minute),
        Interval(MinuteToSecond),
        Interval(Month),
        Interval(Second),
        Interval(Year),
        Interval(YearToMonth),
    ];

    assert_eq!(types.collect::<Vec<_>>(), &expected_types);
}
1280
// Deserializes a schema whose column is a struct with inlined member fields,
// one of which (ADDRESS) is itself a struct given as a nested `fields`
// object, and checks the resulting types and names.
// Despite the name, this test only exercises deserialization.
#[test]
fn serialize_struct_schemas() {
    let schema = r#"{
 "inputs" : [ {
 "name" : "PERS",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "P0",
 "case_sensitive" : false,
 "columntype" : {
 "fields" : [ {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "FIRSTNAME"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "LASTNAME"
 }, {
 "type" : "UINTEGER",
 "nullable" : true,
 "name" : "AGE"
 }, {
 "fields" : {
 "fields" : [ {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "STREET"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "CITY"
 }, {
 "type" : "CHAR",
 "nullable" : true,
 "precision" : 2,
 "name" : "STATE"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 6,
 "name" : "POSTAL_CODE"
 } ],
 "nullable" : false
 },
 "nullable" : false,
 "name" : "ADDRESS"
 } ],
 "nullable" : false
 }
 }]
 } ],
 "outputs" : [ ]
}
"#;
    let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
    eprintln!("{:#?}", schema);
    let pers = schema.inputs.iter().find(|r| r.name == "PERS").unwrap();
    let p0 = pers.fields.iter().find(|f| f.name == "P0").unwrap();
    // No `type` key on P0's column type, so it falls back to STRUCT.
    assert_eq!(p0.columntype.typ, SqlType::Struct);
    let p0_fields = p0.columntype.fields.as_ref().unwrap();
    assert_eq!(p0_fields[0].columntype.typ, SqlType::Varchar);
    assert_eq!(p0_fields[1].columntype.typ, SqlType::Varchar);
    assert_eq!(p0_fields[2].columntype.typ, SqlType::UInt);
    assert_eq!(p0_fields[3].columntype.typ, SqlType::Struct);
    assert_eq!(p0_fields[3].name, "ADDRESS");
    // The nested struct keeps all four of its members, in order.
    let address = &p0_fields[3].columntype.fields.as_ref().unwrap();
    assert_eq!(address.len(), 4);
    assert_eq!(address[0].name, "STREET");
    assert_eq!(address[0].columntype.typ, SqlType::Varchar);
    assert_eq!(address[1].columntype.typ, SqlType::Varchar);
    assert_eq!(address[2].columntype.typ, SqlType::Char);
    assert_eq!(address[3].columntype.typ, SqlType::Varchar);
}
1359
// Equality between identifiers, and between an identifier and a plain string.
#[test]
fn sql_identifier_cmp() {
    let parse = |text: &str| SqlIdentifier::from(text);

    assert_eq!(parse("foo"), parse("foo"));
    assert_ne!(parse("foo"), parse("bar"));
    assert_eq!(parse("bar"), parse("BAR"));

    // The unquoted and the quoted spelling of the same text compare equal.
    for text in ["foo", "bar", "bAr"] {
        let quoted = format!("\"{text}\"");
        assert_eq!(parse(text), parse(quoted.as_str()));
    }
    assert_eq!(SqlIdentifier::new("bAr", true), parse("\"bAr\""));

    // Comparison against string slices parses the string first.
    assert_eq!(parse("bAr"), "bar");
    assert_eq!(parse("bAr"), "bAr");
}
1376
// Spellings that differ only in case or quoting must collapse to a single
// entry in an ordered set.
#[test]
fn sql_identifier_ord() {
    let mut btree = std::collections::BTreeSet::new();
    // (spelling, whether the insert is expected to add a new entry)
    let insertions = [
        ("foo", true),
        ("bar", true),
        ("BAR", false),
        ("\"foo\"", false),
        ("\"bar\"", false),
    ];
    for (spelling, expect_new) in insertions {
        assert_eq!(btree.insert(SqlIdentifier::from(spelling)), expect_new);
    }
}
1386
// Spellings that differ only in case or quoting must collapse to a single
// entry in a hash set.
#[test]
fn sql_identifier_hash() {
    let mut hs = std::collections::HashSet::new();
    // (spelling, whether the insert is expected to add a new entry)
    let insertions = [
        ("foo", true),
        ("bar", true),
        ("BAR", false),
        ("\"foo\"", false),
        ("\"bar\"", false),
    ];
    for (spelling, expect_new) in insertions {
        assert_eq!(hs.insert(SqlIdentifier::from(spelling)), expect_new);
    }
}
1396
// `name()` yields the canonical form and `sql_name()` the SQL spelling, for
// unquoted and quoted identifiers alike.
#[test]
fn sql_identifier_name() {
    // (spelling, expected name(), expected sql_name())
    let cases = [
        ("foo", "foo", "foo"),
        ("bAr", "bar", "bAr"),
        ("\"bAr\"", "bAr", "\"bAr\""),
    ];
    for (spelling, canonical, sql) in cases {
        let id = SqlIdentifier::from(spelling);
        assert_eq!(id.name(), canonical);
        assert_eq!(id.sql_name(), sql);
    }
}
1406
// Regression test: a struct column whose member is a MAP, with the member's
// type attributes inlined next to its name, must keep the map's key and value
// types after deserialization.
// NOTE(review): presumably named after issue 3277 in the project tracker —
// confirm.
#[test]
fn issue3277() {
    let schema = r#"{
 "name" : "j",
 "case_sensitive" : false,
 "columntype" : {
 "fields" : [ {
 "key" : {
 "nullable" : false,
 "precision" : -1,
 "type" : "VARCHAR"
 },
 "name" : "s",
 "nullable" : true,
 "type" : "MAP",
 "value" : {
 "nullable" : true,
 "precision" : -1,
 "type" : "VARCHAR"
 }
 } ],
 "nullable" : true
 }
 }"#;
    let field: Field = serde_json::from_str(schema).unwrap();
    println!("field: {:#?}", field);
    assert_eq!(
        field,
        Field {
            name: SqlIdentifier {
                name: "j".to_string(),
                case_sensitive: false,
            },
            // The outer column type has no `type` key, hence STRUCT.
            columntype: ColumnType {
                typ: SqlType::Struct,
                nullable: true,
                precision: None,
                scale: None,
                component: None,
                fields: Some(vec![Field {
                    name: SqlIdentifier {
                        name: "s".to_string(),
                        case_sensitive: false,
                    },
                    columntype: ColumnType {
                        typ: SqlType::Map,
                        nullable: true,
                        precision: None,
                        scale: None,
                        component: None,
                        fields: None,
                        key: Some(Box::new(ColumnType {
                            typ: SqlType::Varchar,
                            nullable: false,
                            precision: Some(-1),
                            scale: None,
                            component: None,
                            fields: None,
                            key: None,
                            value: None,
                        })),
                        value: Some(Box::new(ColumnType {
                            typ: SqlType::Varchar,
                            nullable: true,
                            precision: Some(-1),
                            scale: None,
                            component: None,
                            fields: None,
                            key: None,
                            value: None,
                        })),
                    },
                    lateness: None,
                    default: None,
                    unused: false,
                    watermark: None,
                }]),
                key: None,
                value: None,
            },
            lateness: None,
            default: None,
            unused: false,
            watermark: None,
        }
    );
}
1494}