1pub use feldera_ir::SourcePosition;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3use std::cmp::Ordering;
4use std::collections::BTreeMap;
5use std::fmt::Display;
6use std::hash::{Hash, Hasher};
7use utoipa::ToSchema;
8
9#[cfg(feature = "testing")]
10use proptest::{collection::vec, prelude::any};
11
/// Returns the canonical form of a SQL identifier.
///
/// * A quoted identifier (`"Foo"`) is case-sensitive: the surrounding double
///   quotes are stripped and the inner text is returned unchanged. Embedded
///   `""` escapes are not processed.
/// * Anything else is treated as case-insensitive and returned lower-cased.
///
/// A lone `"` is not considered quoted and is returned as-is (lower-cased).
pub fn canonical_identifier(id: &str) -> String {
    match id.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner.to_string(),
        None => id.to_lowercase(),
    }
}
26
/// A SQL identifier such as a relation or column name.
///
/// The identifier is stored without surrounding double quotes. `case_sensitive`
/// records whether it must be matched exactly (`SqlIdentifier::from` sets it
/// for quoted input). Use [`SqlIdentifier::name`] for the canonical form and
/// [`SqlIdentifier::sql_name`] for the spelling to embed in SQL text.
#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SqlIdentifier {
    /// Identifier text as written, without surrounding quotes and without case folding.
    #[cfg_attr(feature = "testing", proptest(regex = "relation1|relation2|relation3"))]
    name: String,
    /// `true` if the identifier is case-sensitive (e.g. it was quoted in SQL).
    pub case_sensitive: bool,
}
38
39impl SqlIdentifier {
40 pub fn new<S: AsRef<str>>(name: S, case_sensitive: bool) -> Self {
41 Self {
42 name: name.as_ref().to_string(),
43 case_sensitive,
44 }
45 }
46
47 pub fn name(&self) -> String {
57 if self.case_sensitive {
58 self.name.clone()
59 } else {
60 self.name.to_lowercase()
61 }
62 }
63
64 pub fn sql_name(&self) -> String {
75 if self.case_sensitive {
76 format!("\"{}\"", self.name)
77 } else {
78 self.name.clone()
79 }
80 }
81}
82
83impl Hash for SqlIdentifier {
84 fn hash<H: Hasher>(&self, state: &mut H) {
85 self.name().hash(state);
86 }
87}
88
89impl PartialEq for SqlIdentifier {
90 fn eq(&self, other: &Self) -> bool {
91 match (self.case_sensitive, other.case_sensitive) {
92 (true, true) => self.name == other.name,
93 (false, false) => self.name.to_lowercase() == other.name.to_lowercase(),
94 (true, false) => self.name == other.name,
95 (false, true) => self.name == other.name,
96 }
97 }
98}
99
100impl Ord for SqlIdentifier {
101 fn cmp(&self, other: &Self) -> Ordering {
102 self.name().cmp(&other.name())
103 }
104}
105
106impl PartialOrd for SqlIdentifier {
107 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108 Some(self.cmp(other))
109 }
110}
111
112impl<S: AsRef<str>> PartialEq<S> for SqlIdentifier {
113 fn eq(&self, other: &S) -> bool {
114 self == &SqlIdentifier::from(other.as_ref())
115 }
116}
117
// Marker impl: declares `==` on identifiers to be an equivalence relation.
// NOTE(review): `PartialEq` compares stored names exactly whenever one side is
// case-sensitive, but case-insensitively when neither is. It is therefore not
// transitive (unquoted `BAR` == `bar` == quoted `"bar"`, yet `BAR` != `"bar"`),
// so this impl over-promises. Confirm before relying on `Eq` semantics.
impl Eq for SqlIdentifier {}
119
120impl<S: AsRef<str>> From<S> for SqlIdentifier {
121 fn from(name: S) -> Self {
122 if name.as_ref().starts_with('"')
123 && name.as_ref().ends_with('"')
124 && name.as_ref().len() >= 2
125 {
126 Self {
127 name: name.as_ref()[1..name.as_ref().len() - 1].to_string(),
128 case_sensitive: true,
129 }
130 } else {
131 Self::new(name, false)
132 }
133 }
134}
135
136impl From<SqlIdentifier> for String {
137 fn from(id: SqlIdentifier) -> String {
138 id.name()
139 }
140}
141
142impl From<&SqlIdentifier> for String {
143 fn from(id: &SqlIdentifier) -> String {
144 id.name()
145 }
146}
147
148impl Display for SqlIdentifier {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 write!(f, "{}", self.name())
151 }
152}
153
/// Schema of a SQL program: its input relations and its output relations.
#[derive(Default, Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ProgramSchema {
    /// Input relations (presumably tables — TODO confirm).
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub inputs: Vec<Relation>,
    /// Output relations (presumably views — TODO confirm).
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub outputs: Vec<Relation>,
}
171
172impl ProgramSchema {
173 pub fn relations_with_lateness(&self) -> Vec<SqlIdentifier> {
174 self.inputs
175 .iter()
176 .chain(self.outputs.iter())
177 .filter(|rel| rel.has_lateness())
178 .map(|rel| rel.name.clone())
179 .collect()
180 }
181}
182
/// Value of a relation property, together with the source positions of the
/// property's key and value (presumably positions in the SQL program text —
/// TODO confirm).
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct PropertyValue {
    /// The property value as text.
    pub value: String,
    /// Where the property key appears.
    pub key_position: SourcePosition,
    /// Where the property value appears.
    pub value_position: SourcePosition,
}
190
/// A relation (input or output) in the program schema.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Relation {
    /// Relation name; flattened, so it appears as the `name` and
    /// `case_sensitive` keys of the relation object.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Columns of the relation, in declaration order.
    #[cfg_attr(feature = "testing", proptest(value = "Vec::new()"))]
    pub fields: Vec<Field>,
    /// Whether the relation is materialized; `false` when absent from the JSON.
    #[serde(default)]
    pub materialized: bool,
    /// Relation properties keyed by property name; empty when absent.
    #[serde(default)]
    pub properties: BTreeMap<String, PropertyValue>,
    /// Optional primary key as a list of column names (`with_primary_key`
    /// stores canonical names).
    pub primary_key: Option<Vec<String>>,
}
207
208impl Relation {
209 pub fn empty() -> Self {
210 Self {
211 name: SqlIdentifier::from("".to_string()),
212 fields: Vec::new(),
213 materialized: false,
214 properties: BTreeMap::new(),
215 primary_key: None,
216 }
217 }
218
219 pub fn new(
220 name: SqlIdentifier,
221 fields: Vec<Field>,
222 materialized: bool,
223 properties: BTreeMap<String, PropertyValue>,
224 ) -> Self {
225 Self {
226 name,
227 fields,
228 materialized,
229 properties,
230 primary_key: None,
231 }
232 }
233
234 pub fn field(&self, name: &str) -> Option<&Field> {
236 let name = canonical_identifier(name);
237 self.fields.iter().find(|f| f.name == name)
238 }
239
240 pub fn has_lateness(&self) -> bool {
241 self.fields.iter().any(|f| f.lateness.is_some())
242 }
243
244 pub fn get_property(&self, name: &str) -> Option<&str> {
245 self.properties.get(name).map(|p| p.value.as_str())
246 }
247
248 pub fn with_primary_key<'a>(
249 mut self,
250 primary_key: impl IntoIterator<Item = &'a SqlIdentifier>,
251 ) -> Self {
252 self.primary_key = Some(primary_key.into_iter().map(|id| id.name()).collect());
253 self
254 }
255}
256
/// A column of a relation, or a member of a STRUCT column type.
///
/// Serialization is derived. Deserialization is implemented by hand (see the
/// `Deserialize` impl for `Field`) to accept the several JSON shapes that
/// appear in program schemas.
#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Field {
    /// Field name; flattened into the `name` and `case_sensitive` keys.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Type of the field.
    pub columntype: ColumnType,
    /// Optional lateness, kept as text (presumably a SQL expression — TODO confirm).
    pub lateness: Option<String>,
    /// Optional default value, kept as text (presumably a SQL expression — TODO confirm).
    pub default: Option<String>,
    /// Whether the field is marked as unused.
    pub unused: bool,
    /// Optional watermark, kept as text (presumably a SQL expression — TODO confirm).
    pub watermark: Option<String>,
}
271
272impl Field {
273 pub fn new(name: SqlIdentifier, columntype: ColumnType) -> Self {
274 Self {
275 name,
276 columntype,
277 lateness: None,
278 default: None,
279 unused: false,
280 watermark: None,
281 }
282 }
283
284 pub fn with_lateness(mut self, lateness: &str) -> Self {
285 self.lateness = Some(lateness.to_string());
286 self
287 }
288
289 pub fn with_unused(mut self, unused: bool) -> Self {
290 self.unused = unused;
291 self
292 }
293}
294
295impl<'de> Deserialize<'de> for Field {
300 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
301 where
302 D: Deserializer<'de>,
303 {
304 const fn default_is_struct() -> Option<SqlType> {
305 Some(SqlType::Struct)
306 }
307
308 #[derive(Debug, Clone, Deserialize)]
309 struct FieldHelper {
310 name: Option<String>,
311 #[serde(default)]
312 case_sensitive: bool,
313 columntype: Option<ColumnType>,
314 #[serde(rename = "type")]
315 #[serde(default = "default_is_struct")]
316 typ: Option<SqlType>,
317 nullable: Option<bool>,
318 precision: Option<i64>,
319 scale: Option<i64>,
320 component: Option<Box<ColumnType>>,
321 fields: Option<serde_json::Value>,
322 key: Option<Box<ColumnType>>,
323 value: Option<Box<ColumnType>>,
324 default: Option<String>,
325 #[serde(default)]
326 unused: bool,
327 lateness: Option<String>,
328 watermark: Option<String>,
329 }
330
331 fn helper_to_field(helper: FieldHelper) -> Field {
332 let columntype = if let Some(ctype) = helper.columntype {
333 ctype
334 } else if let Some(serde_json::Value::Array(fields)) = helper.fields {
335 let fields = fields
336 .into_iter()
337 .map(|field| {
338 let field: FieldHelper = serde_json::from_value(field).unwrap();
339 helper_to_field(field)
340 })
341 .collect::<Vec<Field>>();
342
343 ColumnType {
344 typ: helper.typ.unwrap_or(SqlType::Null),
345 nullable: helper.nullable.unwrap_or(false),
346 precision: helper.precision,
347 scale: helper.scale,
348 component: helper.component,
349 fields: Some(fields),
350 key: None,
351 value: None,
352 }
353 } else if let Some(serde_json::Value::Object(obj)) = helper.fields {
354 serde_json::from_value(serde_json::Value::Object(obj))
355 .expect("Failed to deserialize object")
356 } else {
357 ColumnType {
358 typ: helper.typ.unwrap_or(SqlType::Null),
359 nullable: helper.nullable.unwrap_or(false),
360 precision: helper.precision,
361 scale: helper.scale,
362 component: helper.component,
363 fields: None,
364 key: helper.key,
365 value: helper.value,
366 }
367 };
368
369 Field {
370 name: SqlIdentifier::new(helper.name.unwrap(), helper.case_sensitive),
371 columntype,
372 default: helper.default,
373 unused: helper.unused,
374 lateness: helper.lateness,
375 watermark: helper.watermark,
376 }
377 }
378
379 let helper = FieldHelper::deserialize(deserializer)?;
380 Ok(helper_to_field(helper))
381 }
382}
383
/// Unit, or range of units, of a SQL `INTERVAL` type.
///
/// Appears as `SqlType::Interval(unit)` and is spelled `INTERVAL_<UNIT>` in
/// schema JSON (e.g. `INTERVAL_DAY_SECOND`).
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum IntervalUnit {
    /// `INTERVAL DAY` (`INTERVAL_DAY`).
    Day,
    /// `INTERVAL DAY TO HOUR` (`INTERVAL_DAY_HOUR`).
    DayToHour,
    /// `INTERVAL DAY TO MINUTE` (`INTERVAL_DAY_MINUTE`).
    DayToMinute,
    /// `INTERVAL DAY TO SECOND` (`INTERVAL_DAY_SECOND`).
    DayToSecond,
    /// `INTERVAL HOUR` (`INTERVAL_HOUR`).
    Hour,
    /// `INTERVAL HOUR TO MINUTE` (`INTERVAL_HOUR_MINUTE`).
    HourToMinute,
    /// `INTERVAL HOUR TO SECOND` (`INTERVAL_HOUR_SECOND`).
    HourToSecond,
    /// `INTERVAL MINUTE` (`INTERVAL_MINUTE`).
    Minute,
    /// `INTERVAL MINUTE TO SECOND` (`INTERVAL_MINUTE_SECOND`).
    MinuteToSecond,
    /// `INTERVAL MONTH` (`INTERVAL_MONTH`).
    Month,
    /// `INTERVAL SECOND` (`INTERVAL_SECOND`).
    Second,
    /// `INTERVAL YEAR` (`INTERVAL_YEAR`).
    Year,
    /// `INTERVAL YEAR TO MONTH` (`INTERVAL_YEAR_MONTH`).
    YearToMonth,
}
418
/// A SQL type as it appears in the program schema.
///
/// Serialized as an upper-case type name. Deserialized from that name in any
/// letter case (see the `Serialize`/`Deserialize` impls for `SqlType`).
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SqlType {
    /// `BOOLEAN`.
    Boolean,
    /// `TINYINT`.
    TinyInt,
    /// `SMALLINT`.
    SmallInt,
    /// `INTEGER`.
    Int,
    /// `BIGINT`.
    BigInt,
    /// `UTINYINT` (unsigned).
    UTinyInt,
    /// `USMALLINT` (unsigned).
    USmallInt,
    /// `UINTEGER` (unsigned).
    UInt,
    /// `UBIGINT` (unsigned).
    UBigInt,
    /// `REAL`.
    Real,
    /// `DOUBLE`.
    Double,
    /// `DECIMAL`; precision and scale are carried by `ColumnType`.
    Decimal,
    /// `CHAR`.
    Char,
    /// `VARCHAR`.
    Varchar,
    /// `BINARY`.
    Binary,
    /// `VARBINARY`.
    Varbinary,
    /// `TIME`.
    Time,
    /// `DATE`.
    Date,
    /// `TIMESTAMP`.
    Timestamp,
    /// An interval type with the given unit (`INTERVAL_<UNIT>`).
    Interval(IntervalUnit),
    /// `ARRAY`; the element type is `ColumnType::component`.
    Array,
    /// `STRUCT`; the members are `ColumnType::fields`.
    Struct,
    /// `MAP`; key and value types are `ColumnType::key` / `ColumnType::value`.
    Map,
    /// `NULL`.
    Null,
    /// `UUID`.
    Uuid,
    /// `VARIANT`.
    Variant,
}
476
477impl Display for SqlType {
478 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
479 f.write_str(&serde_json::to_string(self).unwrap())
480 }
481}
482
483impl<'de> Deserialize<'de> for SqlType {
484 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
485 where
486 D: Deserializer<'de>,
487 {
488 let value: String = Deserialize::deserialize(deserializer)?;
489 match value.to_lowercase().as_str() {
490 "interval_day" => Ok(SqlType::Interval(IntervalUnit::Day)),
491 "interval_day_hour" => Ok(SqlType::Interval(IntervalUnit::DayToHour)),
492 "interval_day_minute" => Ok(SqlType::Interval(IntervalUnit::DayToMinute)),
493 "interval_day_second" => Ok(SqlType::Interval(IntervalUnit::DayToSecond)),
494 "interval_hour" => Ok(SqlType::Interval(IntervalUnit::Hour)),
495 "interval_hour_minute" => Ok(SqlType::Interval(IntervalUnit::HourToMinute)),
496 "interval_hour_second" => Ok(SqlType::Interval(IntervalUnit::HourToSecond)),
497 "interval_minute" => Ok(SqlType::Interval(IntervalUnit::Minute)),
498 "interval_minute_second" => Ok(SqlType::Interval(IntervalUnit::MinuteToSecond)),
499 "interval_month" => Ok(SqlType::Interval(IntervalUnit::Month)),
500 "interval_second" => Ok(SqlType::Interval(IntervalUnit::Second)),
501 "interval_year" => Ok(SqlType::Interval(IntervalUnit::Year)),
502 "interval_year_month" => Ok(SqlType::Interval(IntervalUnit::YearToMonth)),
503 "boolean" => Ok(SqlType::Boolean),
504 "tinyint" => Ok(SqlType::TinyInt),
505 "smallint" => Ok(SqlType::SmallInt),
506 "integer" => Ok(SqlType::Int),
507 "bigint" => Ok(SqlType::BigInt),
508 "utinyint" => Ok(SqlType::UTinyInt),
509 "usmallint" => Ok(SqlType::USmallInt),
510 "uinteger" => Ok(SqlType::UInt),
511 "ubigint" => Ok(SqlType::UBigInt),
512 "real" => Ok(SqlType::Real),
513 "double" => Ok(SqlType::Double),
514 "decimal" => Ok(SqlType::Decimal),
515 "char" => Ok(SqlType::Char),
516 "varchar" => Ok(SqlType::Varchar),
517 "binary" => Ok(SqlType::Binary),
518 "varbinary" => Ok(SqlType::Varbinary),
519 "variant" => Ok(SqlType::Variant),
520 "time" => Ok(SqlType::Time),
521 "date" => Ok(SqlType::Date),
522 "timestamp" => Ok(SqlType::Timestamp),
523 "array" => Ok(SqlType::Array),
524 "struct" => Ok(SqlType::Struct),
525 "map" => Ok(SqlType::Map),
526 "null" => Ok(SqlType::Null),
527 "uuid" => Ok(SqlType::Uuid),
528 _ => Err(serde::de::Error::custom(format!(
529 "Unknown SQL type: {}",
530 value
531 ))),
532 }
533 }
534}
535
536impl Serialize for SqlType {
537 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
538 where
539 S: Serializer,
540 {
541 let type_str = match self {
542 SqlType::Boolean => "BOOLEAN",
543 SqlType::TinyInt => "TINYINT",
544 SqlType::SmallInt => "SMALLINT",
545 SqlType::Int => "INTEGER",
546 SqlType::BigInt => "BIGINT",
547 SqlType::UTinyInt => "UTINYINT",
548 SqlType::USmallInt => "USMALLINT",
549 SqlType::UInt => "UINTEGER",
550 SqlType::UBigInt => "UBIGINT",
551 SqlType::Real => "REAL",
552 SqlType::Double => "DOUBLE",
553 SqlType::Decimal => "DECIMAL",
554 SqlType::Char => "CHAR",
555 SqlType::Varchar => "VARCHAR",
556 SqlType::Binary => "BINARY",
557 SqlType::Varbinary => "VARBINARY",
558 SqlType::Time => "TIME",
559 SqlType::Date => "DATE",
560 SqlType::Timestamp => "TIMESTAMP",
561 SqlType::Interval(interval_unit) => match interval_unit {
562 IntervalUnit::Day => "INTERVAL_DAY",
563 IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
564 IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
565 IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
566 IntervalUnit::Hour => "INTERVAL_HOUR",
567 IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
568 IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
569 IntervalUnit::Minute => "INTERVAL_MINUTE",
570 IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
571 IntervalUnit::Month => "INTERVAL_MONTH",
572 IntervalUnit::Second => "INTERVAL_SECOND",
573 IntervalUnit::Year => "INTERVAL_YEAR",
574 IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
575 },
576 SqlType::Array => "ARRAY",
577 SqlType::Struct => "STRUCT",
578 SqlType::Uuid => "UUID",
579 SqlType::Map => "MAP",
580 SqlType::Null => "NULL",
581 SqlType::Variant => "VARIANT",
582 };
583 serializer.serialize_str(type_str)
584 }
585}
586
587impl SqlType {
588 pub fn is_string(&self) -> bool {
590 matches!(self, Self::Char | Self::Varchar)
591 }
592
593 pub fn is_varchar(&self) -> bool {
594 matches!(self, Self::Varchar)
595 }
596
597 pub fn is_varbinary(&self) -> bool {
598 matches!(self, Self::Varbinary)
599 }
600}
601
/// Serde default for `ColumnType::typ`: a column-type object without a `"type"`
/// key is treated as a `STRUCT` (as in nested struct definitions).
const fn default_is_struct() -> SqlType {
    SqlType::Struct
}
607
/// Type of a column, or of a struct member, array element or map key/value.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ColumnType {
    /// The SQL type. It is stored under the JSON key `type` and defaults to
    /// `STRUCT` when that key is absent.
    #[serde(rename = "type")]
    #[serde(default = "default_is_struct")]
    pub typ: SqlType,
    /// Whether the column accepts `NULL`.
    pub nullable: bool,
    /// Optional precision. `ColumnType::decimal` stores the DECIMAL precision
    /// here, and `ColumnType::fixed` stores the BINARY width. Schema JSON also
    /// carries it for other types (e.g. intervals, and `-1` for VARCHAR —
    /// presumably "unbounded", TODO confirm).
    pub precision: Option<i64>,
    /// Optional scale; `ColumnType::decimal` stores the DECIMAL scale here.
    pub scale: Option<i64>,
    /// Element type of an `ARRAY` (see `ColumnType::array`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub component: Option<Box<ColumnType>>,
    /// Member fields of a `STRUCT` (see `ColumnType::structure`).
    #[cfg_attr(feature = "testing", proptest(value = "Some(Vec::new())"))]
    pub fields: Option<Vec<Field>>,
    /// Key type of a `MAP` (see `ColumnType::map`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub key: Option<Box<ColumnType>>,
    /// Value type of a `MAP` (see `ColumnType::map`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub value: Option<Box<ColumnType>>,
}
672
673impl ColumnType {
674 pub fn boolean(nullable: bool) -> Self {
675 ColumnType {
676 typ: SqlType::Boolean,
677 nullable,
678 precision: None,
679 scale: None,
680 component: None,
681 fields: None,
682 key: None,
683 value: None,
684 }
685 }
686
687 pub fn uuid(nullable: bool) -> Self {
688 ColumnType {
689 typ: SqlType::Uuid,
690 nullable,
691 precision: None,
692 scale: None,
693 component: None,
694 fields: None,
695 key: None,
696 value: None,
697 }
698 }
699
700 pub fn tinyint(nullable: bool) -> Self {
701 ColumnType {
702 typ: SqlType::TinyInt,
703 nullable,
704 precision: None,
705 scale: None,
706 component: None,
707 fields: None,
708 key: None,
709 value: None,
710 }
711 }
712
713 pub fn smallint(nullable: bool) -> Self {
714 ColumnType {
715 typ: SqlType::SmallInt,
716 nullable,
717 precision: None,
718 scale: None,
719 component: None,
720 fields: None,
721 key: None,
722 value: None,
723 }
724 }
725
726 pub fn int(nullable: bool) -> Self {
727 ColumnType {
728 typ: SqlType::Int,
729 nullable,
730 precision: None,
731 scale: None,
732 component: None,
733 fields: None,
734 key: None,
735 value: None,
736 }
737 }
738
739 pub fn bigint(nullable: bool) -> Self {
740 ColumnType {
741 typ: SqlType::BigInt,
742 nullable,
743 precision: None,
744 scale: None,
745 component: None,
746 fields: None,
747 key: None,
748 value: None,
749 }
750 }
751
752 pub fn utinyint(nullable: bool) -> Self {
753 ColumnType {
754 typ: SqlType::UTinyInt,
755 nullable,
756 precision: None,
757 scale: None,
758 component: None,
759 fields: None,
760 key: None,
761 value: None,
762 }
763 }
764
765 pub fn usmallint(nullable: bool) -> Self {
766 ColumnType {
767 typ: SqlType::USmallInt,
768 nullable,
769 precision: None,
770 scale: None,
771 component: None,
772 fields: None,
773 key: None,
774 value: None,
775 }
776 }
777
778 pub fn uint(nullable: bool) -> Self {
779 ColumnType {
780 typ: SqlType::UInt,
781 nullable,
782 precision: None,
783 scale: None,
784 component: None,
785 fields: None,
786 key: None,
787 value: None,
788 }
789 }
790
791 pub fn ubigint(nullable: bool) -> Self {
792 ColumnType {
793 typ: SqlType::UBigInt,
794 nullable,
795 precision: None,
796 scale: None,
797 component: None,
798 fields: None,
799 key: None,
800 value: None,
801 }
802 }
803
804 pub fn double(nullable: bool) -> Self {
805 ColumnType {
806 typ: SqlType::Double,
807 nullable,
808 precision: None,
809 scale: None,
810 component: None,
811 fields: None,
812 key: None,
813 value: None,
814 }
815 }
816
817 pub fn real(nullable: bool) -> Self {
818 ColumnType {
819 typ: SqlType::Real,
820 nullable,
821 precision: None,
822 scale: None,
823 component: None,
824 fields: None,
825 key: None,
826 value: None,
827 }
828 }
829
830 pub fn decimal(precision: i64, scale: i64, nullable: bool) -> Self {
831 ColumnType {
832 typ: SqlType::Decimal,
833 nullable,
834 precision: Some(precision),
835 scale: Some(scale),
836 component: None,
837 fields: None,
838 key: None,
839 value: None,
840 }
841 }
842
843 pub fn varchar(nullable: bool) -> Self {
844 ColumnType {
845 typ: SqlType::Varchar,
846 nullable,
847 precision: None,
848 scale: None,
849 component: None,
850 fields: None,
851 key: None,
852 value: None,
853 }
854 }
855
856 pub fn varbinary(nullable: bool) -> Self {
857 ColumnType {
858 typ: SqlType::Varbinary,
859 nullable,
860 precision: None,
861 scale: None,
862 component: None,
863 fields: None,
864 key: None,
865 value: None,
866 }
867 }
868
869 pub fn fixed(width: i64, nullable: bool) -> Self {
870 ColumnType {
871 typ: SqlType::Binary,
872 nullable,
873 precision: Some(width),
874 scale: None,
875 component: None,
876 fields: None,
877 key: None,
878 value: None,
879 }
880 }
881
882 pub fn date(nullable: bool) -> Self {
883 ColumnType {
884 typ: SqlType::Date,
885 nullable,
886 precision: None,
887 scale: None,
888 component: None,
889 fields: None,
890 key: None,
891 value: None,
892 }
893 }
894
895 pub fn time(nullable: bool) -> Self {
896 ColumnType {
897 typ: SqlType::Time,
898 nullable,
899 precision: None,
900 scale: None,
901 component: None,
902 fields: None,
903 key: None,
904 value: None,
905 }
906 }
907
908 pub fn timestamp(nullable: bool) -> Self {
909 ColumnType {
910 typ: SqlType::Timestamp,
911 nullable,
912 precision: None,
913 scale: None,
914 component: None,
915 fields: None,
916 key: None,
917 value: None,
918 }
919 }
920
921 pub fn variant(nullable: bool) -> Self {
922 ColumnType {
923 typ: SqlType::Variant,
924 nullable,
925 precision: None,
926 scale: None,
927 component: None,
928 fields: None,
929 key: None,
930 value: None,
931 }
932 }
933
934 pub fn array(nullable: bool, element: ColumnType) -> Self {
935 ColumnType {
936 typ: SqlType::Array,
937 nullable,
938 precision: None,
939 scale: None,
940 component: Some(Box::new(element)),
941 fields: None,
942 key: None,
943 value: None,
944 }
945 }
946
947 pub fn structure(nullable: bool, fields: &[Field]) -> Self {
948 ColumnType {
949 typ: SqlType::Struct,
950 nullable,
951 precision: None,
952 scale: None,
953 component: None,
954 fields: Some(fields.to_vec()),
955 key: None,
956 value: None,
957 }
958 }
959
960 pub fn map(nullable: bool, key: ColumnType, val: ColumnType) -> Self {
961 ColumnType {
962 typ: SqlType::Map,
963 nullable,
964 precision: None,
965 scale: None,
966 component: None,
967 fields: None,
968 key: Some(Box::new(key)),
969 value: Some(Box::new(val)),
970 }
971 }
972
973 pub fn is_integral_type(&self) -> bool {
974 matches!(
975 &self.typ,
976 SqlType::TinyInt
977 | SqlType::SmallInt
978 | SqlType::Int
979 | SqlType::BigInt
980 | SqlType::UTinyInt
981 | SqlType::USmallInt
982 | SqlType::UInt
983 | SqlType::UBigInt
984 )
985 }
986
987 pub fn is_fp_type(&self) -> bool {
988 matches!(&self.typ, SqlType::Double | SqlType::Real)
989 }
990
991 pub fn is_decimal_type(&self) -> bool {
992 matches!(&self.typ, SqlType::Decimal)
993 }
994
995 pub fn is_numeric_type(&self) -> bool {
996 self.is_integral_type() || self.is_fp_type() || self.is_decimal_type()
997 }
998}
999
#[cfg(test)]
mod tests {
    use super::{IntervalUnit, SqlIdentifier};
    use crate::program_schema::{ColumnType, Field, SqlType};

    // Every SQL type name must deserialize in mixed, lower and upper case, and
    // must survive a serialize -> deserialize round trip.
    #[test]
    fn serde_sql_type() {
        for (sql_str_base, expected_value) in [
            ("Boolean", SqlType::Boolean),
            ("Uuid", SqlType::Uuid),
            ("TinyInt", SqlType::TinyInt),
            ("SmallInt", SqlType::SmallInt),
            ("Integer", SqlType::Int),
            ("BigInt", SqlType::BigInt),
            ("UTinyInt", SqlType::UTinyInt),
            ("USmallInt", SqlType::USmallInt),
            ("UInteger", SqlType::UInt),
            ("UBigInt", SqlType::UBigInt),
            ("Real", SqlType::Real),
            ("Double", SqlType::Double),
            ("Decimal", SqlType::Decimal),
            ("Char", SqlType::Char),
            ("Varchar", SqlType::Varchar),
            ("Binary", SqlType::Binary),
            ("Varbinary", SqlType::Varbinary),
            ("Time", SqlType::Time),
            ("Date", SqlType::Date),
            ("Timestamp", SqlType::Timestamp),
            ("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
            (
                "Interval_Day_Hour",
                SqlType::Interval(IntervalUnit::DayToHour),
            ),
            (
                "Interval_Day_Minute",
                SqlType::Interval(IntervalUnit::DayToMinute),
            ),
            (
                "Interval_Day_Second",
                SqlType::Interval(IntervalUnit::DayToSecond),
            ),
            ("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
            (
                "Interval_Hour_Minute",
                SqlType::Interval(IntervalUnit::HourToMinute),
            ),
            (
                "Interval_Hour_Second",
                SqlType::Interval(IntervalUnit::HourToSecond),
            ),
            ("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
            (
                "Interval_Minute_Second",
                SqlType::Interval(IntervalUnit::MinuteToSecond),
            ),
            ("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
            ("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
            ("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
            (
                "Interval_Year_Month",
                SqlType::Interval(IntervalUnit::YearToMonth),
            ),
            ("Array", SqlType::Array),
            ("Struct", SqlType::Struct),
            ("Map", SqlType::Map),
            ("Null", SqlType::Null),
            ("Variant", SqlType::Variant),
        ] {
            // Mixed case as listed, then all-lower and all-upper variants.
            for sql_str in [
                sql_str_base,
                &sql_str_base.to_lowercase(),
                &sql_str_base.to_uppercase(),
            ] {
                let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str))
                    .unwrap_or_else(|_| {
                        panic!("\"{sql_str}\" should deserialize into its SQL type")
                    });
                assert_eq!(value1, expected_value);
                let serialized_str =
                    serde_json::to_string(&value1).expect("Value should serialize into JSON");
                let value2: SqlType = serde_json::from_str(&serialized_str).unwrap_or_else(|_| {
                    panic!(
                        "{} should deserialize back into its SQL type",
                        serialized_str
                    )
                });
                assert_eq!(value1, value2);
            }
        }
    }

    // A full program schema using every interval type must deserialize, and
    // the output columns must carry the expected types in order.
    #[test]
    fn deserialize_interval_types() {
        use super::IntervalUnit::*;
        use super::SqlType::*;

        let schema = r#"
{
  "inputs" : [ {
    "name" : "sales",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "sales_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "customer_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "age",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "UINTEGER",
        "nullable" : true
      }
    }, {
      "name" : "amount",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DECIMAL",
        "nullable" : true,
        "precision" : 10,
        "scale" : 2
      }
    }, {
      "name" : "sale_date",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DATE",
        "nullable" : true
      }
    } ],
    "primary_key" : [ "sales_id" ]
  } ],
  "outputs" : [ {
    "name" : "salessummary",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "customer_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "total_sales",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DECIMAL",
        "nullable" : true,
        "precision" : 38,
        "scale" : 2
      }
    }, {
      "name" : "interval_day",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_hour",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_HOUR",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour_to_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_minute_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MINUTE_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_month",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MONTH",
        "nullable" : false
      }
    }, {
      "name" : "interval_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_year",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_YEAR",
        "nullable" : false
      }
    }, {
      "name" : "interval_year_to_month",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_YEAR_MONTH",
        "nullable" : false
      }
    } ]
  } ]
}
"#;

        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        // Only the output relation's column types are checked.
        let types = schema
            .outputs
            .iter()
            .flat_map(|r| r.fields.iter().map(|f| f.columntype.typ));
        let expected_types = [
            Int,
            Decimal,
            Interval(Day),
            Interval(DayToHour),
            Interval(DayToMinute),
            Interval(DayToSecond),
            Interval(Hour),
            Interval(HourToMinute),
            Interval(HourToSecond),
            Interval(Minute),
            Interval(MinuteToSecond),
            Interval(Month),
            Interval(Second),
            Interval(Year),
            Interval(YearToMonth),
        ];

        assert_eq!(types.collect::<Vec<_>>(), &expected_types);
    }

    // Nested struct columns use the "flattened" field shape: members are listed
    // under `fields`, either as an array or as a full column-type object (the
    // ADDRESS member below). Both must produce STRUCT column types.
    #[test]
    fn serialize_struct_schemas() {
        let schema = r#"{
  "inputs" : [ {
    "name" : "PERS",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "P0",
      "case_sensitive" : false,
      "columntype" : {
        "fields" : [ {
          "type" : "VARCHAR",
          "nullable" : true,
          "precision" : 30,
          "name" : "FIRSTNAME"
        }, {
          "type" : "VARCHAR",
          "nullable" : true,
          "precision" : 30,
          "name" : "LASTNAME"
        }, {
          "type" : "UINTEGER",
          "nullable" : true,
          "name" : "AGE"
        }, {
          "fields" : {
            "fields" : [ {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 30,
              "name" : "STREET"
            }, {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 30,
              "name" : "CITY"
            }, {
              "type" : "CHAR",
              "nullable" : true,
              "precision" : 2,
              "name" : "STATE"
            }, {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 6,
              "name" : "POSTAL_CODE"
            } ],
            "nullable" : false
          },
          "nullable" : false,
          "name" : "ADDRESS"
        } ],
        "nullable" : false
      }
    }]
  } ],
  "outputs" : [ ]
}
"#;
        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        eprintln!("{:#?}", schema);
        let pers = schema.inputs.iter().find(|r| r.name == "PERS").unwrap();
        let p0 = pers.fields.iter().find(|f| f.name == "P0").unwrap();
        assert_eq!(p0.columntype.typ, SqlType::Struct);
        let p0_fields = p0.columntype.fields.as_ref().unwrap();
        assert_eq!(p0_fields[0].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[1].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[2].columntype.typ, SqlType::UInt);
        assert_eq!(p0_fields[3].columntype.typ, SqlType::Struct);
        assert_eq!(p0_fields[3].name, "ADDRESS");
        let address = &p0_fields[3].columntype.fields.as_ref().unwrap();
        assert_eq!(address.len(), 4);
        assert_eq!(address[0].name, "STREET");
        assert_eq!(address[0].columntype.typ, SqlType::Varchar);
        assert_eq!(address[1].columntype.typ, SqlType::Varchar);
        assert_eq!(address[2].columntype.typ, SqlType::Char);
        assert_eq!(address[3].columntype.typ, SqlType::Varchar);
    }

    // Equality between identifiers and between an identifier and SQL text.
    #[test]
    fn sql_identifier_cmp() {
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("foo"));
        assert_ne!(SqlIdentifier::from("foo"), SqlIdentifier::from("bar"));
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("BAR"));
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("\"foo\""));
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("\"bar\""));
        // With mixed case-sensitivity the stored names are compared exactly, so
        // unquoted `bAr` matches quoted `"bAr"`.
        assert_eq!(SqlIdentifier::from("bAr"), SqlIdentifier::from("\"bAr\""));
        assert_eq!(
            SqlIdentifier::new("bAr", true),
            SqlIdentifier::from("\"bAr\"")
        );

        // Comparison with plain strings parses them as SQL identifier text.
        assert_eq!(SqlIdentifier::from("bAr"), "bar");
        assert_eq!(SqlIdentifier::from("bAr"), "bAr");
    }

    // Ordered collections deduplicate identifiers by canonical name.
    #[test]
    fn sql_identifier_ord() {
        let mut btree = std::collections::BTreeSet::new();
        assert!(btree.insert(SqlIdentifier::from("foo")));
        assert!(btree.insert(SqlIdentifier::from("bar")));
        assert!(!btree.insert(SqlIdentifier::from("BAR")));
        assert!(!btree.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!btree.insert(SqlIdentifier::from("\"bar\"")));
    }

    // Hashed collections deduplicate identifiers that compare equal.
    #[test]
    fn sql_identifier_hash() {
        let mut hs = std::collections::HashSet::new();
        assert!(hs.insert(SqlIdentifier::from("foo")));
        assert!(hs.insert(SqlIdentifier::from("bar")));
        assert!(!hs.insert(SqlIdentifier::from("BAR")));
        assert!(!hs.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!hs.insert(SqlIdentifier::from("\"bar\"")));
    }

    // `name()` returns the canonical form; `sql_name()` the SQL spelling.
    #[test]
    fn sql_identifier_name() {
        assert_eq!(SqlIdentifier::from("foo").name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").name(), "bar");
        assert_eq!(SqlIdentifier::from("\"bAr\"").name(), "bAr");
        assert_eq!(SqlIdentifier::from("foo").sql_name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").sql_name(), "bAr");
        assert_eq!(SqlIdentifier::from("\"bAr\"").sql_name(), "\"bAr\"");
    }

    // Regression test for issue 3277: a struct member of MAP type written in
    // the flattened shape must keep its `key` and `value` types.
    #[test]
    fn issue3277() {
        let schema = r#"{
      "name" : "j",
      "case_sensitive" : false,
      "columntype" : {
        "fields" : [ {
          "key" : {
            "nullable" : false,
            "precision" : -1,
            "type" : "VARCHAR"
          },
          "name" : "s",
          "nullable" : true,
          "type" : "MAP",
          "value" : {
            "nullable" : true,
            "precision" : -1,
            "type" : "VARCHAR"
          }
        } ],
        "nullable" : true
      }
    }"#;
        let field: Field = serde_json::from_str(schema).unwrap();
        println!("field: {:#?}", field);
        assert_eq!(
            field,
            Field {
                name: SqlIdentifier {
                    name: "j".to_string(),
                    case_sensitive: false,
                },
                columntype: ColumnType {
                    typ: SqlType::Struct,
                    nullable: true,
                    precision: None,
                    scale: None,
                    component: None,
                    fields: Some(vec![Field {
                        name: SqlIdentifier {
                            name: "s".to_string(),
                            case_sensitive: false,
                        },
                        columntype: ColumnType {
                            typ: SqlType::Map,
                            nullable: true,
                            precision: None,
                            scale: None,
                            component: None,
                            fields: None,
                            key: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: false,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                            value: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: true,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                        },
                        lateness: None,
                        default: None,
                        unused: false,
                        watermark: None,
                    }]),
                    key: None,
                    value: None,
                },
                lateness: None,
                default: None,
                unused: false,
                watermark: None,
            }
        );
    }
}