1pub use feldera_ir::SourcePosition;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3use std::cmp::Ordering;
4use std::collections::BTreeMap;
5use std::fmt::Display;
6use std::hash::{Hash, Hasher};
7use utoipa::ToSchema;
8
9#[cfg(feature = "testing")]
10use proptest::{collection::vec, prelude::any};
11
/// Returns the canonical form of an SQL identifier.
///
/// An identifier wrapped in double quotes is returned with the surrounding
/// quotes removed and its case preserved; anything else is lowercased.
/// A lone `"` is not treated as quoted.
pub fn canonical_identifier(id: &str) -> String {
    let unquoted = id
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));
    match unquoted {
        Some(inner) => inner.to_string(),
        None => id.to_lowercase(),
    }
}
26
/// An SQL identifier: a raw name plus a flag saying whether it is
/// case-sensitive.
///
/// The raw name is stored exactly as given; `name()` lowercases it on read
/// when `case_sensitive` is false.
#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SqlIdentifier {
    // Raw name, kept private so reads go through `name()` / `sql_name()`.
    // Under the `testing` feature, generated names are limited to three
    // fixed values.
    #[cfg_attr(feature = "testing", proptest(regex = "relation1|relation2|relation3"))]
    name: String,
    /// Whether the name is case-sensitive.
    pub case_sensitive: bool,
}
38
39impl SqlIdentifier {
40 pub fn new<S: AsRef<str>>(name: S, case_sensitive: bool) -> Self {
41 Self {
42 name: name.as_ref().to_string(),
43 case_sensitive,
44 }
45 }
46
47 pub fn name(&self) -> String {
57 if self.case_sensitive {
58 self.name.clone()
59 } else {
60 self.name.to_lowercase()
61 }
62 }
63
64 pub fn sql_name(&self) -> String {
75 if self.case_sensitive {
76 format!("\"{}\"", self.name)
77 } else {
78 self.name.clone()
79 }
80 }
81}
82
83impl Hash for SqlIdentifier {
84 fn hash<H: Hasher>(&self, state: &mut H) {
85 self.name().hash(state);
86 }
87}
88
89impl PartialEq for SqlIdentifier {
90 fn eq(&self, other: &Self) -> bool {
91 match (self.case_sensitive, other.case_sensitive) {
92 (true, true) => self.name == other.name,
93 (false, false) => self.name.to_lowercase() == other.name.to_lowercase(),
94 (true, false) => self.name == other.name,
95 (false, true) => self.name == other.name,
96 }
97 }
98}
99
100impl Ord for SqlIdentifier {
101 fn cmp(&self, other: &Self) -> Ordering {
102 self.name().cmp(&other.name())
103 }
104}
105
106impl PartialOrd for SqlIdentifier {
107 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108 Some(self.cmp(other))
109 }
110}
111
112impl<S: AsRef<str>> PartialEq<S> for SqlIdentifier {
113 fn eq(&self, other: &S) -> bool {
114 self == &SqlIdentifier::from(other.as_ref())
115 }
116}
117
// Marker impl so `SqlIdentifier` can be used where `Eq` is required (e.g. as
// a key in ordered or hashed collections).
// NOTE(review): the `PartialEq` impl compares mixed-sensitivity pairs by raw
// name, which looks non-transitive (`bAr` == `"bAr"`, `bAr` == `bar`, but
// `"bAr"` != `bar`) — confirm this is acceptable for an `Eq` type.
impl Eq for SqlIdentifier {}
119
120impl<S: AsRef<str>> From<S> for SqlIdentifier {
121 fn from(name: S) -> Self {
122 if name.as_ref().starts_with('"')
123 && name.as_ref().ends_with('"')
124 && name.as_ref().len() >= 2
125 {
126 Self {
127 name: name.as_ref()[1..name.as_ref().len() - 1].to_string(),
128 case_sensitive: true,
129 }
130 } else {
131 Self::new(name, false)
132 }
133 }
134}
135
136impl From<SqlIdentifier> for String {
137 fn from(id: SqlIdentifier) -> String {
138 id.name()
139 }
140}
141
142impl From<&SqlIdentifier> for String {
143 fn from(id: &SqlIdentifier) -> String {
144 id.name()
145 }
146}
147
148impl Display for SqlIdentifier {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 write!(f, "{}", self.name())
151 }
152}
153
/// Schema of a program: the relations it reads (`inputs`) and the relations
/// it produces (`outputs`).
#[derive(Default, Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ProgramSchema {
    /// Input relations.
    // Under the `testing` feature, generated schemas have at most one input.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub inputs: Vec<Relation>,
    /// Output relations.
    // Under the `testing` feature, generated schemas have at most one output.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub outputs: Vec<Relation>,
}
171
172impl ProgramSchema {
173 pub fn relations_with_lateness(&self) -> Vec<SqlIdentifier> {
174 self.inputs
175 .iter()
176 .chain(self.outputs.iter())
177 .filter(|rel| rel.has_lateness())
178 .map(|rel| rel.name.clone())
179 .collect()
180 }
181}
182
/// Deserialize-only view of a program schema that keeps, for every relation,
/// just what `RelationPropertiesOnly` declares (its name and properties).
///
/// Both lists default to empty when the corresponding key is absent.
#[derive(Debug, Deserialize)]
pub struct ProgramSchemaPropertiesOnly {
    /// Input relations; empty if the key is missing.
    #[serde(default)]
    pub inputs: Vec<RelationPropertiesOnly>,
    /// Output relations; empty if the key is missing.
    #[serde(default)]
    pub outputs: Vec<RelationPropertiesOnly>,
}
194
/// The value of a relation property, together with the source positions of
/// the property's key and value.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct PropertyValue {
    /// The property value as text.
    pub value: String,
    /// Source position of the property key.
    pub key_position: SourcePosition,
    /// Source position of the property value.
    pub value_position: SourcePosition,
}
202
/// A named relation with its fields, properties and optional primary key.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Relation {
    // Flattened: the identifier's own keys (`name`, `case_sensitive`) appear
    // directly on the relation object rather than under a nested key.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Fields (columns) of the relation.
    // Under the `testing` feature, generated relations have no fields.
    #[cfg_attr(feature = "testing", proptest(value = "Vec::new()"))]
    pub fields: Vec<Field>,
    /// Materialization flag; `false` when the key is absent.
    #[serde(default)]
    pub materialized: bool,
    /// Properties keyed by property name; empty when the key is absent.
    #[serde(default)]
    pub properties: BTreeMap<String, PropertyValue>,
    /// Names of the primary-key columns, if any.
    pub primary_key: Option<Vec<String>>,
}
219
220impl Relation {
221 pub fn empty() -> Self {
222 Self {
223 name: SqlIdentifier::from("".to_string()),
224 fields: Vec::new(),
225 materialized: false,
226 properties: BTreeMap::new(),
227 primary_key: None,
228 }
229 }
230
231 pub fn new(
232 name: SqlIdentifier,
233 fields: Vec<Field>,
234 materialized: bool,
235 properties: BTreeMap<String, PropertyValue>,
236 ) -> Self {
237 Self {
238 name,
239 fields,
240 materialized,
241 properties,
242 primary_key: None,
243 }
244 }
245
246 pub fn field(&self, name: &str) -> Option<&Field> {
248 let name = canonical_identifier(name);
249 self.fields.iter().find(|f| f.name == name)
250 }
251
252 pub fn has_lateness(&self) -> bool {
253 self.fields.iter().any(|f| f.lateness.is_some())
254 }
255
256 pub fn get_property(&self, name: &str) -> Option<&str> {
257 self.properties.get(name).map(|p| p.value.as_str())
258 }
259
260 pub fn with_primary_key<'a>(
261 mut self,
262 primary_key: impl IntoIterator<Item = &'a SqlIdentifier>,
263 ) -> Self {
264 self.primary_key = Some(primary_key.into_iter().map(|id| id.name()).collect());
265 self
266 }
267}
268
/// Deserialize-only view of a relation that keeps only its name and
/// properties.
#[derive(Debug, Deserialize)]
pub struct RelationPropertiesOnly {
    // Flattened: the identifier's keys (`name`, `case_sensitive`) are read
    // directly from the relation object.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Properties keyed by property name; empty when the key is absent.
    #[serde(default)]
    pub properties: BTreeMap<String, PropertyValue>,
}
279
/// A named field (column) and its type, plus optional annotations.
///
/// Only `Serialize` is derived; deserialization is implemented by hand in
/// this file so that several input layouts are accepted.
#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Field {
    // Flattened: `name` and `case_sensitive` are serialized directly on the
    // field object.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Type of the column.
    pub columntype: ColumnType,
    /// Lateness annotation, if any, kept as text.
    pub lateness: Option<String>,
    /// Default value, if any, kept as text.
    pub default: Option<String>,
    /// Flag carried through from the schema; `false` unless set explicitly.
    pub unused: bool,
    /// Watermark annotation, if any, kept as text.
    pub watermark: Option<String>,
}
294
295impl Field {
296 pub fn new(name: SqlIdentifier, columntype: ColumnType) -> Self {
297 Self {
298 name,
299 columntype,
300 lateness: None,
301 default: None,
302 unused: false,
303 watermark: None,
304 }
305 }
306
307 pub fn with_lateness(mut self, lateness: &str) -> Self {
308 self.lateness = Some(lateness.to_string());
309 self
310 }
311
312 pub fn with_unused(mut self, unused: bool) -> Self {
313 self.unused = unused;
314 self
315 }
316}
317
318impl<'de> Deserialize<'de> for Field {
323 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
324 where
325 D: Deserializer<'de>,
326 {
327 const fn default_is_struct() -> Option<SqlType> {
328 Some(SqlType::Struct)
329 }
330
331 #[derive(Debug, Clone, Deserialize)]
332 struct FieldHelper {
333 name: Option<String>,
334 #[serde(default)]
335 case_sensitive: bool,
336 columntype: Option<ColumnType>,
337 #[serde(rename = "type")]
338 #[serde(default = "default_is_struct")]
339 typ: Option<SqlType>,
340 nullable: Option<bool>,
341 precision: Option<i64>,
342 scale: Option<i64>,
343 component: Option<Box<ColumnType>>,
344 fields: Option<serde_json::Value>,
345 key: Option<Box<ColumnType>>,
346 value: Option<Box<ColumnType>>,
347 default: Option<String>,
348 #[serde(default)]
349 unused: bool,
350 lateness: Option<String>,
351 watermark: Option<String>,
352 }
353
354 fn helper_to_field(helper: FieldHelper) -> Field {
355 let columntype = if let Some(ctype) = helper.columntype {
356 ctype
357 } else if let Some(serde_json::Value::Array(fields)) = helper.fields {
358 let fields = fields
359 .into_iter()
360 .map(|field| {
361 let field: FieldHelper = serde_json::from_value(field).unwrap();
362 helper_to_field(field)
363 })
364 .collect::<Vec<Field>>();
365
366 ColumnType {
367 typ: helper.typ.unwrap_or(SqlType::Null),
368 nullable: helper.nullable.unwrap_or(false),
369 precision: helper.precision,
370 scale: helper.scale,
371 component: helper.component,
372 fields: Some(fields),
373 key: None,
374 value: None,
375 }
376 } else if let Some(serde_json::Value::Object(obj)) = helper.fields {
377 serde_json::from_value(serde_json::Value::Object(obj))
378 .expect("Failed to deserialize object")
379 } else {
380 ColumnType {
381 typ: helper.typ.unwrap_or(SqlType::Null),
382 nullable: helper.nullable.unwrap_or(false),
383 precision: helper.precision,
384 scale: helper.scale,
385 component: helper.component,
386 fields: None,
387 key: helper.key,
388 value: helper.value,
389 }
390 };
391
392 Field {
393 name: SqlIdentifier::new(helper.name.unwrap(), helper.case_sensitive),
394 columntype,
395 default: helper.default,
396 unused: helper.unused,
397 lateness: helper.lateness,
398 watermark: helper.watermark,
399 }
400 }
401
402 let helper = FieldHelper::deserialize(deserializer)?;
403 Ok(helper_to_field(helper))
404 }
405}
406
/// Unit (or unit range) of an SQL interval type.
///
/// This enum has no serde impls of its own; in this file it is only
/// (de)serialized as part of `SqlType::Interval`, using names of the form
/// `INTERVAL_<FROM>[_<TO>]`.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum IntervalUnit {
    /// Serialized as `INTERVAL_DAY`.
    Day,
    /// Serialized as `INTERVAL_DAY_HOUR`.
    DayToHour,
    /// Serialized as `INTERVAL_DAY_MINUTE`.
    DayToMinute,
    /// Serialized as `INTERVAL_DAY_SECOND`.
    DayToSecond,
    /// Serialized as `INTERVAL_HOUR`.
    Hour,
    /// Serialized as `INTERVAL_HOUR_MINUTE`.
    HourToMinute,
    /// Serialized as `INTERVAL_HOUR_SECOND`.
    HourToSecond,
    /// Serialized as `INTERVAL_MINUTE`.
    Minute,
    /// Serialized as `INTERVAL_MINUTE_SECOND`.
    MinuteToSecond,
    /// Serialized as `INTERVAL_MONTH`.
    Month,
    /// Serialized as `INTERVAL_SECOND`.
    Second,
    /// Serialized as `INTERVAL_YEAR`.
    Year,
    /// Serialized as `INTERVAL_YEAR_MONTH`.
    YearToMonth,
}
441
/// The base SQL type of a column.
///
/// Serialization and deserialization are implemented by hand in this file:
/// each variant is written as an upper-case name (shown per variant below)
/// and read back case-insensitively.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SqlType {
    /// `BOOLEAN`
    Boolean,
    /// `TINYINT`
    TinyInt,
    /// `SMALLINT`
    SmallInt,
    /// `INTEGER`
    Int,
    /// `BIGINT`
    BigInt,
    /// `UTINYINT`
    UTinyInt,
    /// `USMALLINT`
    USmallInt,
    /// `UINTEGER`
    UInt,
    /// `UBIGINT`
    UBigInt,
    /// `REAL`
    Real,
    /// `DOUBLE`
    Double,
    /// `DECIMAL`
    Decimal,
    /// `CHAR`
    Char,
    /// `VARCHAR`
    Varchar,
    /// `BINARY`
    Binary,
    /// `VARBINARY`
    Varbinary,
    /// `TIME`
    Time,
    /// `DATE`
    Date,
    /// `TIMESTAMP`
    Timestamp,
    /// `TIMESTAMP_TZ`
    TimestampTz,
    /// `INTERVAL_*`, with the suffix determined by the unit.
    Interval(IntervalUnit),
    /// `ARRAY`
    Array,
    /// `STRUCT`
    Struct,
    /// `MAP`
    Map,
    /// `NULL`
    Null,
    /// `UUID`
    Uuid,
    /// `VARIANT`
    Variant,
}
501
502impl Display for SqlType {
503 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504 f.write_str(&serde_json::to_string(self).unwrap())
505 }
506}
507
508impl<'de> Deserialize<'de> for SqlType {
509 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
510 where
511 D: Deserializer<'de>,
512 {
513 let value: String = Deserialize::deserialize(deserializer)?;
514 match value.to_lowercase().as_str() {
515 "interval_day" => Ok(SqlType::Interval(IntervalUnit::Day)),
516 "interval_day_hour" => Ok(SqlType::Interval(IntervalUnit::DayToHour)),
517 "interval_day_minute" => Ok(SqlType::Interval(IntervalUnit::DayToMinute)),
518 "interval_day_second" => Ok(SqlType::Interval(IntervalUnit::DayToSecond)),
519 "interval_hour" => Ok(SqlType::Interval(IntervalUnit::Hour)),
520 "interval_hour_minute" => Ok(SqlType::Interval(IntervalUnit::HourToMinute)),
521 "interval_hour_second" => Ok(SqlType::Interval(IntervalUnit::HourToSecond)),
522 "interval_minute" => Ok(SqlType::Interval(IntervalUnit::Minute)),
523 "interval_minute_second" => Ok(SqlType::Interval(IntervalUnit::MinuteToSecond)),
524 "interval_month" => Ok(SqlType::Interval(IntervalUnit::Month)),
525 "interval_second" => Ok(SqlType::Interval(IntervalUnit::Second)),
526 "interval_year" => Ok(SqlType::Interval(IntervalUnit::Year)),
527 "interval_year_month" => Ok(SqlType::Interval(IntervalUnit::YearToMonth)),
528 "boolean" => Ok(SqlType::Boolean),
529 "tinyint" => Ok(SqlType::TinyInt),
530 "smallint" => Ok(SqlType::SmallInt),
531 "integer" => Ok(SqlType::Int),
532 "bigint" => Ok(SqlType::BigInt),
533 "utinyint" => Ok(SqlType::UTinyInt),
534 "usmallint" => Ok(SqlType::USmallInt),
535 "uinteger" => Ok(SqlType::UInt),
536 "ubigint" => Ok(SqlType::UBigInt),
537 "real" => Ok(SqlType::Real),
538 "double" => Ok(SqlType::Double),
539 "decimal" => Ok(SqlType::Decimal),
540 "char" => Ok(SqlType::Char),
541 "varchar" => Ok(SqlType::Varchar),
542 "binary" => Ok(SqlType::Binary),
543 "varbinary" => Ok(SqlType::Varbinary),
544 "variant" => Ok(SqlType::Variant),
545 "time" => Ok(SqlType::Time),
546 "date" => Ok(SqlType::Date),
547 "timestamp" => Ok(SqlType::Timestamp),
548 "timestamp_tz" => Ok(SqlType::TimestampTz),
549 "array" => Ok(SqlType::Array),
550 "struct" => Ok(SqlType::Struct),
551 "map" => Ok(SqlType::Map),
552 "null" => Ok(SqlType::Null),
553 "uuid" => Ok(SqlType::Uuid),
554 _ => Err(serde::de::Error::custom(format!(
555 "Unknown SQL type: {}",
556 value
557 ))),
558 }
559 }
560}
561
562impl Serialize for SqlType {
563 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
564 where
565 S: Serializer,
566 {
567 let type_str = match self {
568 SqlType::Boolean => "BOOLEAN",
569 SqlType::TinyInt => "TINYINT",
570 SqlType::SmallInt => "SMALLINT",
571 SqlType::Int => "INTEGER",
572 SqlType::BigInt => "BIGINT",
573 SqlType::UTinyInt => "UTINYINT",
574 SqlType::USmallInt => "USMALLINT",
575 SqlType::UInt => "UINTEGER",
576 SqlType::UBigInt => "UBIGINT",
577 SqlType::Real => "REAL",
578 SqlType::Double => "DOUBLE",
579 SqlType::Decimal => "DECIMAL",
580 SqlType::Char => "CHAR",
581 SqlType::Varchar => "VARCHAR",
582 SqlType::Binary => "BINARY",
583 SqlType::Varbinary => "VARBINARY",
584 SqlType::Time => "TIME",
585 SqlType::Date => "DATE",
586 SqlType::Timestamp => "TIMESTAMP",
587 SqlType::TimestampTz => "TIMESTAMP_TZ",
588 SqlType::Interval(interval_unit) => match interval_unit {
589 IntervalUnit::Day => "INTERVAL_DAY",
590 IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
591 IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
592 IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
593 IntervalUnit::Hour => "INTERVAL_HOUR",
594 IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
595 IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
596 IntervalUnit::Minute => "INTERVAL_MINUTE",
597 IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
598 IntervalUnit::Month => "INTERVAL_MONTH",
599 IntervalUnit::Second => "INTERVAL_SECOND",
600 IntervalUnit::Year => "INTERVAL_YEAR",
601 IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
602 },
603 SqlType::Array => "ARRAY",
604 SqlType::Struct => "STRUCT",
605 SqlType::Uuid => "UUID",
606 SqlType::Map => "MAP",
607 SqlType::Null => "NULL",
608 SqlType::Variant => "VARIANT",
609 };
610 serializer.serialize_str(type_str)
611 }
612}
613
614impl SqlType {
615 pub fn is_string(&self) -> bool {
617 matches!(self, Self::Char | Self::Varchar)
618 }
619
620 pub fn is_varchar(&self) -> bool {
621 matches!(self, Self::Varchar)
622 }
623
624 pub fn is_varbinary(&self) -> bool {
625 matches!(self, Self::Varbinary)
626 }
627}
628
/// Serde default for `ColumnType::typ`: a column type whose `type` key is
/// absent is read as `STRUCT`.
const fn default_is_struct() -> SqlType {
    SqlType::Struct
}
634
/// The type of a column: a base SQL type plus nullability and the extra
/// attributes some base types need.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ColumnType {
    /// Base SQL type; stored under the key `type`, and read as `STRUCT`
    /// when that key is absent.
    #[serde(rename = "type")]
    #[serde(default = "default_is_struct")]
    pub typ: SqlType,
    /// Whether the column may hold NULL.
    pub nullable: bool,
    /// Precision, where the base type has one. The constructors in this file
    /// set it for `DECIMAL` and, as the width, for fixed-size `BINARY`.
    pub precision: Option<i64>,
    /// Scale, where the base type has one (set by the `DECIMAL` constructor).
    pub scale: Option<i64>,
    /// Element type of an `ARRAY`.
    // Under the `testing` feature, always generated as `None`.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub component: Option<Box<ColumnType>>,
    /// Member fields of a `STRUCT`.
    // Under the `testing` feature, always generated as an empty list.
    #[cfg_attr(feature = "testing", proptest(value = "Some(Vec::new())"))]
    pub fields: Option<Vec<Field>>,
    /// Key type of a `MAP`.
    // Under the `testing` feature, always generated as `None`.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub key: Option<Box<ColumnType>>,
    /// Value type of a `MAP`.
    // Under the `testing` feature, always generated as `None`.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub value: Option<Box<ColumnType>>,
}
699
700impl ColumnType {
701 pub fn boolean(nullable: bool) -> Self {
702 ColumnType {
703 typ: SqlType::Boolean,
704 nullable,
705 precision: None,
706 scale: None,
707 component: None,
708 fields: None,
709 key: None,
710 value: None,
711 }
712 }
713
714 pub fn uuid(nullable: bool) -> Self {
715 ColumnType {
716 typ: SqlType::Uuid,
717 nullable,
718 precision: None,
719 scale: None,
720 component: None,
721 fields: None,
722 key: None,
723 value: None,
724 }
725 }
726
727 pub fn tinyint(nullable: bool) -> Self {
728 ColumnType {
729 typ: SqlType::TinyInt,
730 nullable,
731 precision: None,
732 scale: None,
733 component: None,
734 fields: None,
735 key: None,
736 value: None,
737 }
738 }
739
740 pub fn smallint(nullable: bool) -> Self {
741 ColumnType {
742 typ: SqlType::SmallInt,
743 nullable,
744 precision: None,
745 scale: None,
746 component: None,
747 fields: None,
748 key: None,
749 value: None,
750 }
751 }
752
753 pub fn int(nullable: bool) -> Self {
754 ColumnType {
755 typ: SqlType::Int,
756 nullable,
757 precision: None,
758 scale: None,
759 component: None,
760 fields: None,
761 key: None,
762 value: None,
763 }
764 }
765
766 pub fn bigint(nullable: bool) -> Self {
767 ColumnType {
768 typ: SqlType::BigInt,
769 nullable,
770 precision: None,
771 scale: None,
772 component: None,
773 fields: None,
774 key: None,
775 value: None,
776 }
777 }
778
779 pub fn utinyint(nullable: bool) -> Self {
780 ColumnType {
781 typ: SqlType::UTinyInt,
782 nullable,
783 precision: None,
784 scale: None,
785 component: None,
786 fields: None,
787 key: None,
788 value: None,
789 }
790 }
791
792 pub fn usmallint(nullable: bool) -> Self {
793 ColumnType {
794 typ: SqlType::USmallInt,
795 nullable,
796 precision: None,
797 scale: None,
798 component: None,
799 fields: None,
800 key: None,
801 value: None,
802 }
803 }
804
805 pub fn uint(nullable: bool) -> Self {
806 ColumnType {
807 typ: SqlType::UInt,
808 nullable,
809 precision: None,
810 scale: None,
811 component: None,
812 fields: None,
813 key: None,
814 value: None,
815 }
816 }
817
818 pub fn ubigint(nullable: bool) -> Self {
819 ColumnType {
820 typ: SqlType::UBigInt,
821 nullable,
822 precision: None,
823 scale: None,
824 component: None,
825 fields: None,
826 key: None,
827 value: None,
828 }
829 }
830
831 pub fn double(nullable: bool) -> Self {
832 ColumnType {
833 typ: SqlType::Double,
834 nullable,
835 precision: None,
836 scale: None,
837 component: None,
838 fields: None,
839 key: None,
840 value: None,
841 }
842 }
843
844 pub fn real(nullable: bool) -> Self {
845 ColumnType {
846 typ: SqlType::Real,
847 nullable,
848 precision: None,
849 scale: None,
850 component: None,
851 fields: None,
852 key: None,
853 value: None,
854 }
855 }
856
857 pub fn decimal(precision: i64, scale: i64, nullable: bool) -> Self {
858 ColumnType {
859 typ: SqlType::Decimal,
860 nullable,
861 precision: Some(precision),
862 scale: Some(scale),
863 component: None,
864 fields: None,
865 key: None,
866 value: None,
867 }
868 }
869
870 pub fn varchar(nullable: bool) -> Self {
871 ColumnType {
872 typ: SqlType::Varchar,
873 nullable,
874 precision: None,
875 scale: None,
876 component: None,
877 fields: None,
878 key: None,
879 value: None,
880 }
881 }
882
883 pub fn varbinary(nullable: bool) -> Self {
884 ColumnType {
885 typ: SqlType::Varbinary,
886 nullable,
887 precision: None,
888 scale: None,
889 component: None,
890 fields: None,
891 key: None,
892 value: None,
893 }
894 }
895
896 pub fn fixed(width: i64, nullable: bool) -> Self {
897 ColumnType {
898 typ: SqlType::Binary,
899 nullable,
900 precision: Some(width),
901 scale: None,
902 component: None,
903 fields: None,
904 key: None,
905 value: None,
906 }
907 }
908
909 pub fn date(nullable: bool) -> Self {
910 ColumnType {
911 typ: SqlType::Date,
912 nullable,
913 precision: None,
914 scale: None,
915 component: None,
916 fields: None,
917 key: None,
918 value: None,
919 }
920 }
921
922 pub fn time(nullable: bool) -> Self {
923 ColumnType {
924 typ: SqlType::Time,
925 nullable,
926 precision: None,
927 scale: None,
928 component: None,
929 fields: None,
930 key: None,
931 value: None,
932 }
933 }
934
935 pub fn timestamp(nullable: bool) -> Self {
936 ColumnType {
937 typ: SqlType::Timestamp,
938 nullable,
939 precision: None,
940 scale: None,
941 component: None,
942 fields: None,
943 key: None,
944 value: None,
945 }
946 }
947
948 pub fn timestamp_tz(nullable: bool) -> Self {
949 ColumnType {
950 typ: SqlType::TimestampTz,
951 nullable,
952 precision: None,
953 scale: None,
954 component: None,
955 fields: None,
956 key: None,
957 value: None,
958 }
959 }
960
961 pub fn variant(nullable: bool) -> Self {
962 ColumnType {
963 typ: SqlType::Variant,
964 nullable,
965 precision: None,
966 scale: None,
967 component: None,
968 fields: None,
969 key: None,
970 value: None,
971 }
972 }
973
974 pub fn array(nullable: bool, element: ColumnType) -> Self {
975 ColumnType {
976 typ: SqlType::Array,
977 nullable,
978 precision: None,
979 scale: None,
980 component: Some(Box::new(element)),
981 fields: None,
982 key: None,
983 value: None,
984 }
985 }
986
987 pub fn structure(nullable: bool, fields: &[Field]) -> Self {
988 ColumnType {
989 typ: SqlType::Struct,
990 nullable,
991 precision: None,
992 scale: None,
993 component: None,
994 fields: Some(fields.to_vec()),
995 key: None,
996 value: None,
997 }
998 }
999
1000 pub fn map(nullable: bool, key: ColumnType, val: ColumnType) -> Self {
1001 ColumnType {
1002 typ: SqlType::Map,
1003 nullable,
1004 precision: None,
1005 scale: None,
1006 component: None,
1007 fields: None,
1008 key: Some(Box::new(key)),
1009 value: Some(Box::new(val)),
1010 }
1011 }
1012
1013 pub fn is_integral_type(&self) -> bool {
1014 matches!(
1015 &self.typ,
1016 SqlType::TinyInt
1017 | SqlType::SmallInt
1018 | SqlType::Int
1019 | SqlType::BigInt
1020 | SqlType::UTinyInt
1021 | SqlType::USmallInt
1022 | SqlType::UInt
1023 | SqlType::UBigInt
1024 )
1025 }
1026
1027 pub fn is_fp_type(&self) -> bool {
1028 matches!(&self.typ, SqlType::Double | SqlType::Real)
1029 }
1030
1031 pub fn is_decimal_type(&self) -> bool {
1032 matches!(&self.typ, SqlType::Decimal)
1033 }
1034
1035 pub fn is_numeric_type(&self) -> bool {
1036 self.is_integral_type() || self.is_fp_type() || self.is_decimal_type()
1037 }
1038}
1039
1040#[cfg(test)]
1041mod tests {
1042 use super::{IntervalUnit, SqlIdentifier};
1043 use crate::program_schema::{ColumnType, Field, SqlType};
1044
1045 #[test]
1046 fn serde_sql_type() {
1047 for (sql_str_base, expected_value) in [
1048 ("Boolean", SqlType::Boolean),
1049 ("Uuid", SqlType::Uuid),
1050 ("TinyInt", SqlType::TinyInt),
1051 ("SmallInt", SqlType::SmallInt),
1052 ("Integer", SqlType::Int),
1053 ("BigInt", SqlType::BigInt),
1054 ("UTinyInt", SqlType::UTinyInt),
1055 ("USmallInt", SqlType::USmallInt),
1056 ("UInteger", SqlType::UInt),
1057 ("UBigInt", SqlType::UBigInt),
1058 ("Real", SqlType::Real),
1059 ("Double", SqlType::Double),
1060 ("Decimal", SqlType::Decimal),
1061 ("Char", SqlType::Char),
1062 ("Varchar", SqlType::Varchar),
1063 ("Binary", SqlType::Binary),
1064 ("Varbinary", SqlType::Varbinary),
1065 ("Time", SqlType::Time),
1066 ("Date", SqlType::Date),
1067 ("Timestamp", SqlType::Timestamp),
1068 ("Timestamp_Tz", SqlType::TimestampTz),
1069 ("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
1070 (
1071 "Interval_Day_Hour",
1072 SqlType::Interval(IntervalUnit::DayToHour),
1073 ),
1074 (
1075 "Interval_Day_Minute",
1076 SqlType::Interval(IntervalUnit::DayToMinute),
1077 ),
1078 (
1079 "Interval_Day_Second",
1080 SqlType::Interval(IntervalUnit::DayToSecond),
1081 ),
1082 ("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
1083 (
1084 "Interval_Hour_Minute",
1085 SqlType::Interval(IntervalUnit::HourToMinute),
1086 ),
1087 (
1088 "Interval_Hour_Second",
1089 SqlType::Interval(IntervalUnit::HourToSecond),
1090 ),
1091 ("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
1092 (
1093 "Interval_Minute_Second",
1094 SqlType::Interval(IntervalUnit::MinuteToSecond),
1095 ),
1096 ("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
1097 ("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
1098 ("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
1099 (
1100 "Interval_Year_Month",
1101 SqlType::Interval(IntervalUnit::YearToMonth),
1102 ),
1103 ("Array", SqlType::Array),
1104 ("Struct", SqlType::Struct),
1105 ("Map", SqlType::Map),
1106 ("Null", SqlType::Null),
1107 ("Variant", SqlType::Variant),
1108 ] {
1109 for sql_str in [
1110 sql_str_base, &sql_str_base.to_lowercase(), &sql_str_base.to_uppercase(), ] {
1114 let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str))
1115 .unwrap_or_else(|e| {
1116 panic!(
1117 "\"{sql_str}\" should deserialize into its SQL type: {}",
1118 e.to_string()
1119 )
1120 });
1121 assert_eq!(value1, expected_value);
1122 let serialized_str =
1123 serde_json::to_string(&value1).expect("Value should serialize into JSON");
1124 let value2: SqlType = serde_json::from_str(&serialized_str).unwrap_or_else(|_| {
1125 panic!(
1126 "{} should deserialize back into its SQL type",
1127 serialized_str
1128 )
1129 });
1130 assert_eq!(value1, value2);
1131 }
1132 }
1133 }
1134
    /// Deserializes a full program schema and checks that the column types of
    /// the output relation come back as the expected `SqlType` values, one
    /// for each `INTERVAL_*` name. Only the outputs are inspected.
    #[test]
    fn deserialize_interval_types() {
        use super::IntervalUnit::*;
        use super::SqlType::*;

        let schema = r#"
{
 "inputs" : [ {
 "name" : "sales",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "sales_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "customer_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "age",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "UINTEGER",
 "nullable" : true
 }
 }, {
 "name" : "amount",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DECIMAL",
 "nullable" : true,
 "precision" : 10,
 "scale" : 2
 }
 }, {
 "name" : "sale_date",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DATE",
 "nullable" : true
 }
 } ],
 "primary_key" : [ "sales_id" ]
 } ],
 "outputs" : [ {
 "name" : "salessummary",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "customer_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "total_sales",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DECIMAL",
 "nullable" : true,
 "precision" : 38,
 "scale" : 2
 }
 }, {
 "name" : "interval_day",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_hour",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_HOUR",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour_to_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_minute_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MINUTE_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_month",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MONTH",
 "nullable" : false
 }
 }, {
 "name" : "interval_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_year",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_YEAR",
 "nullable" : false
 }
 }, {
 "name" : "interval_year_to_month",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_YEAR_MONTH",
 "nullable" : false
 }
 } ]
 } ]
}
"#;

        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        // Base types of every output column, in declaration order.
        let types = schema
            .outputs
            .iter()
            .flat_map(|r| r.fields.iter().map(|f| f.columntype.typ));
        let expected_types = [
            Int,
            Decimal,
            Interval(Day),
            Interval(DayToHour),
            Interval(DayToMinute),
            Interval(DayToSecond),
            Interval(Hour),
            Interval(HourToMinute),
            Interval(HourToSecond),
            Interval(Minute),
            Interval(MinuteToSecond),
            Interval(Month),
            Interval(Second),
            Interval(Year),
            Interval(YearToMonth),
        ];

        assert_eq!(types.collect::<Vec<_>>(), &expected_types);
    }
1345
    /// Despite its name, this test deserializes: it reads a schema whose
    /// column `P0` is a struct with a nested struct member (`ADDRESS`) and
    /// checks the resulting field names and base types.
    ///
    /// The fixture leaves out `type` on the struct columns and gives
    /// `ADDRESS` a `fields` object rather than an array, so both are read
    /// through the fallbacks of the hand-written `Field` deserializer.
    #[test]
    fn serialize_struct_schemas() {
        let schema = r#"{
 "inputs" : [ {
 "name" : "PERS",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "P0",
 "case_sensitive" : false,
 "columntype" : {
 "fields" : [ {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "FIRSTNAME"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "LASTNAME"
 }, {
 "type" : "UINTEGER",
 "nullable" : true,
 "name" : "AGE"
 }, {
 "fields" : {
 "fields" : [ {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "STREET"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "CITY"
 }, {
 "type" : "CHAR",
 "nullable" : true,
 "precision" : 2,
 "name" : "STATE"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 6,
 "name" : "POSTAL_CODE"
 } ],
 "nullable" : false
 },
 "nullable" : false,
 "name" : "ADDRESS"
 } ],
 "nullable" : false
 }
 }]
 } ],
 "outputs" : [ ]
}
"#;
        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        eprintln!("{:#?}", schema);
        // Identifier-to-string comparison goes through `SqlIdentifier::from`.
        let pers = schema.inputs.iter().find(|r| r.name == "PERS").unwrap();
        let p0 = pers.fields.iter().find(|f| f.name == "P0").unwrap();
        // `type` is absent on P0's column type, so it falls back to STRUCT.
        assert_eq!(p0.columntype.typ, SqlType::Struct);
        let p0_fields = p0.columntype.fields.as_ref().unwrap();
        assert_eq!(p0_fields[0].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[1].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[2].columntype.typ, SqlType::UInt);
        assert_eq!(p0_fields[3].columntype.typ, SqlType::Struct);
        assert_eq!(p0_fields[3].name, "ADDRESS");
        let address = &p0_fields[3].columntype.fields.as_ref().unwrap();
        assert_eq!(address.len(), 4);
        assert_eq!(address[0].name, "STREET");
        assert_eq!(address[0].columntype.typ, SqlType::Varchar);
        assert_eq!(address[1].columntype.typ, SqlType::Varchar);
        assert_eq!(address[2].columntype.typ, SqlType::Char);
        assert_eq!(address[3].columntype.typ, SqlType::Varchar);
    }
1424
1425 #[test]
1426 fn sql_identifier_cmp() {
1427 assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("foo"));
1428 assert_ne!(SqlIdentifier::from("foo"), SqlIdentifier::from("bar"));
1429 assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("BAR"));
1430 assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("\"foo\""));
1431 assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("\"bar\""));
1432 assert_eq!(SqlIdentifier::from("bAr"), SqlIdentifier::from("\"bAr\""));
1433 assert_eq!(
1434 SqlIdentifier::new("bAr", true),
1435 SqlIdentifier::from("\"bAr\"")
1436 );
1437
1438 assert_eq!(SqlIdentifier::from("bAr"), "bar");
1439 assert_eq!(SqlIdentifier::from("bAr"), "bAr");
1440 }
1441
1442 #[test]
1443 fn sql_identifier_ord() {
1444 let mut btree = std::collections::BTreeSet::new();
1445 assert!(btree.insert(SqlIdentifier::from("foo")));
1446 assert!(btree.insert(SqlIdentifier::from("bar")));
1447 assert!(!btree.insert(SqlIdentifier::from("BAR")));
1448 assert!(!btree.insert(SqlIdentifier::from("\"foo\"")));
1449 assert!(!btree.insert(SqlIdentifier::from("\"bar\"")));
1450 }
1451
1452 #[test]
1453 fn sql_identifier_hash() {
1454 let mut hs = std::collections::HashSet::new();
1455 assert!(hs.insert(SqlIdentifier::from("foo")));
1456 assert!(hs.insert(SqlIdentifier::from("bar")));
1457 assert!(!hs.insert(SqlIdentifier::from("BAR")));
1458 assert!(!hs.insert(SqlIdentifier::from("\"foo\"")));
1459 assert!(!hs.insert(SqlIdentifier::from("\"bar\"")));
1460 }
1461
1462 #[test]
1463 fn sql_identifier_name() {
1464 assert_eq!(SqlIdentifier::from("foo").name(), "foo");
1465 assert_eq!(SqlIdentifier::from("bAr").name(), "bar");
1466 assert_eq!(SqlIdentifier::from("\"bAr\"").name(), "bAr");
1467 assert_eq!(SqlIdentifier::from("foo").sql_name(), "foo");
1468 assert_eq!(SqlIdentifier::from("bAr").sql_name(), "bAr");
1469 assert_eq!(SqlIdentifier::from("\"bAr\"").sql_name(), "\"bAr\"");
1470 }
1471
1472 #[test]
1473 fn issue3277() {
1474 let schema = r#"{
1475 "name" : "j",
1476 "case_sensitive" : false,
1477 "columntype" : {
1478 "fields" : [ {
1479 "key" : {
1480 "nullable" : false,
1481 "precision" : -1,
1482 "type" : "VARCHAR"
1483 },
1484 "name" : "s",
1485 "nullable" : true,
1486 "type" : "MAP",
1487 "value" : {
1488 "nullable" : true,
1489 "precision" : -1,
1490 "type" : "VARCHAR"
1491 }
1492 } ],
1493 "nullable" : true
1494 }
1495 }"#;
1496 let field: Field = serde_json::from_str(schema).unwrap();
1497 println!("field: {:#?}", field);
1498 assert_eq!(
1499 field,
1500 Field {
1501 name: SqlIdentifier {
1502 name: "j".to_string(),
1503 case_sensitive: false,
1504 },
1505 columntype: ColumnType {
1506 typ: SqlType::Struct,
1507 nullable: true,
1508 precision: None,
1509 scale: None,
1510 component: None,
1511 fields: Some(vec![Field {
1512 name: SqlIdentifier {
1513 name: "s".to_string(),
1514 case_sensitive: false,
1515 },
1516 columntype: ColumnType {
1517 typ: SqlType::Map,
1518 nullable: true,
1519 precision: None,
1520 scale: None,
1521 component: None,
1522 fields: None,
1523 key: Some(Box::new(ColumnType {
1524 typ: SqlType::Varchar,
1525 nullable: false,
1526 precision: Some(-1),
1527 scale: None,
1528 component: None,
1529 fields: None,
1530 key: None,
1531 value: None,
1532 })),
1533 value: Some(Box::new(ColumnType {
1534 typ: SqlType::Varchar,
1535 nullable: true,
1536 precision: Some(-1),
1537 scale: None,
1538 component: None,
1539 fields: None,
1540 key: None,
1541 value: None,
1542 })),
1543 },
1544 lateness: None,
1545 default: None,
1546 unused: false,
1547 watermark: None,
1548 }]),
1549 key: None,
1550 value: None,
1551 },
1552 lateness: None,
1553 default: None,
1554 unused: false,
1555 watermark: None,
1556 }
1557 );
1558 }
1559}