1pub use feldera_ir::SourcePosition;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3use std::cmp::Ordering;
4use std::collections::BTreeMap;
5use std::fmt::Display;
6use std::hash::{Hash, Hasher};
7use utoipa::ToSchema;
8
9#[cfg(feature = "testing")]
10use proptest::{collection::vec, prelude::any};
11
/// Normalizes a SQL identifier into the form used for lookups.
///
/// An identifier enclosed in double quotes is treated as case-sensitive: the
/// surrounding quotes are dropped and the inner text is returned unchanged.
/// Every other identifier is converted to lowercase.
pub fn canonical_identifier(id: &str) -> String {
    // `strip_suffix` runs on what is left after removing the opening quote,
    // so a lone `"` is not mistaken for an empty quoted identifier.
    let unquoted = id
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));

    match unquoted {
        Some(inner) => inner.to_string(),
        None => id.to_lowercase(),
    }
}
26
/// A SQL identifier paired with a flag saying whether its case is significant.
///
/// Equality, ordering and hashing are implemented by hand for this type
/// rather than derived, so that case-insensitive identifiers can match
/// regardless of letter case.
#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SqlIdentifier {
    /// The identifier text exactly as it was supplied (without surrounding
    /// quotes when built through `From`). Case folding happens on access.
    #[cfg_attr(feature = "testing", proptest(regex = "relation1|relation2|relation3"))]
    name: String,
    /// `true` when the identifier must be matched with its exact case.
    pub case_sensitive: bool,
}
38
39impl SqlIdentifier {
40 pub fn new<S: AsRef<str>>(name: S, case_sensitive: bool) -> Self {
41 Self {
42 name: name.as_ref().to_string(),
43 case_sensitive,
44 }
45 }
46
47 pub fn name(&self) -> String {
57 if self.case_sensitive {
58 self.name.clone()
59 } else {
60 self.name.to_lowercase()
61 }
62 }
63
64 pub fn sql_name(&self) -> String {
75 if self.case_sensitive {
76 format!("\"{}\"", self.name)
77 } else {
78 self.name.clone()
79 }
80 }
81}
82
83impl Hash for SqlIdentifier {
84 fn hash<H: Hasher>(&self, state: &mut H) {
85 self.name().hash(state);
86 }
87}
88
89impl PartialEq for SqlIdentifier {
90 fn eq(&self, other: &Self) -> bool {
91 match (self.case_sensitive, other.case_sensitive) {
92 (true, true) => self.name == other.name,
93 (false, false) => self.name.to_lowercase() == other.name.to_lowercase(),
94 (true, false) => self.name == other.name,
95 (false, true) => self.name == other.name,
96 }
97 }
98}
99
100impl Ord for SqlIdentifier {
101 fn cmp(&self, other: &Self) -> Ordering {
102 self.name().cmp(&other.name())
103 }
104}
105
106impl PartialOrd for SqlIdentifier {
107 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108 Some(self.cmp(other))
109 }
110}
111
112impl<S: AsRef<str>> PartialEq<S> for SqlIdentifier {
113 fn eq(&self, other: &S) -> bool {
114 self == &SqlIdentifier::from(other.as_ref())
115 }
116}
117
// Marker impl: declares that the hand-written `PartialEq` above is intended
// to be a full equivalence, which is what lets `SqlIdentifier` be used as a
// key in ordered and hashed collections.
impl Eq for SqlIdentifier {}
119
120impl<S: AsRef<str>> From<S> for SqlIdentifier {
121 fn from(name: S) -> Self {
122 if name.as_ref().starts_with('"')
123 && name.as_ref().ends_with('"')
124 && name.as_ref().len() >= 2
125 {
126 Self {
127 name: name.as_ref()[1..name.as_ref().len() - 1].to_string(),
128 case_sensitive: true,
129 }
130 } else {
131 Self::new(name, false)
132 }
133 }
134}
135
136impl From<SqlIdentifier> for String {
137 fn from(id: SqlIdentifier) -> String {
138 id.name()
139 }
140}
141
142impl From<&SqlIdentifier> for String {
143 fn from(id: &SqlIdentifier) -> String {
144 id.name()
145 }
146}
147
148impl Display for SqlIdentifier {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 write!(f, "{}", self.name())
151 }
152}
153
/// The schema of a program, split into its input and output relations.
///
/// With the `testing` feature enabled, generated instances contain at most
/// one relation in each list.
#[derive(Default, Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ProgramSchema {
    /// Relations listed under `inputs`.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub inputs: Vec<Relation>,
    /// Relations listed under `outputs`.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub outputs: Vec<Relation>,
}
171
172impl ProgramSchema {
173 pub fn relations_with_lateness(&self) -> Vec<SqlIdentifier> {
174 self.inputs
175 .iter()
176 .chain(self.outputs.iter())
177 .filter(|rel| rel.has_lateness())
178 .map(|rel| rel.name.clone())
179 .collect()
180 }
181}
182
/// A deserialize-only view of a program schema that keeps just the name and
/// properties of each relation (see `RelationPropertiesOnly`).
///
/// Either list may be missing from the input, in which case it is empty.
#[derive(Debug, Deserialize)]
pub struct ProgramSchemaPropertiesOnly {
    /// Relations listed under `inputs`; empty when the key is absent.
    #[serde(default)]
    pub inputs: Vec<RelationPropertiesOnly>,
    /// Relations listed under `outputs`; empty when the key is absent.
    #[serde(default)]
    pub outputs: Vec<RelationPropertiesOnly>,
}
194
/// The value of a relation property together with two source positions.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct PropertyValue {
    /// The property value as text.
    pub value: String,
    /// Position associated with the property key.
    /// NOTE(review): presumably its location in the SQL source — confirm.
    pub key_position: SourcePosition,
    /// Position associated with the property value.
    /// NOTE(review): presumably its location in the SQL source — confirm.
    pub value_position: SourcePosition,
}
202
/// A named relation: its fields plus some relation-level metadata.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Relation {
    /// The relation's name. Flattened, so `name` and `case_sensitive` appear
    /// directly in the serialized relation object.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// The relation's fields, in order. Always empty in generated test data.
    #[cfg_attr(feature = "testing", proptest(value = "Vec::new()"))]
    pub fields: Vec<Field>,
    /// Materialization flag; `false` when absent from the input.
    #[serde(default)]
    pub materialized: bool,
    /// Properties keyed by name; empty when absent from the input.
    #[serde(default)]
    pub properties: BTreeMap<String, PropertyValue>,
    /// Names of the primary-key columns, if a primary key is recorded.
    pub primary_key: Option<Vec<String>>,
}
219
220impl Relation {
221 pub fn empty() -> Self {
222 Self {
223 name: SqlIdentifier::from("".to_string()),
224 fields: Vec::new(),
225 materialized: false,
226 properties: BTreeMap::new(),
227 primary_key: None,
228 }
229 }
230
231 pub fn new(
232 name: SqlIdentifier,
233 fields: Vec<Field>,
234 materialized: bool,
235 properties: BTreeMap<String, PropertyValue>,
236 ) -> Self {
237 Self {
238 name,
239 fields,
240 materialized,
241 properties,
242 primary_key: None,
243 }
244 }
245
246 pub fn field(&self, name: &str) -> Option<&Field> {
248 let name = canonical_identifier(name);
249 self.fields.iter().find(|f| f.name == name)
250 }
251
252 pub fn has_lateness(&self) -> bool {
253 self.fields.iter().any(|f| f.lateness.is_some())
254 }
255
256 pub fn get_property(&self, name: &str) -> Option<&str> {
257 self.properties.get(name).map(|p| p.value.as_str())
258 }
259
260 pub fn with_primary_key<'a>(
261 mut self,
262 primary_key: impl IntoIterator<Item = &'a SqlIdentifier>,
263 ) -> Self {
264 self.primary_key = Some(primary_key.into_iter().map(|id| id.name()).collect());
265 self
266 }
267}
268
/// A deserialize-only view of a relation that keeps just its name and
/// properties.
#[derive(Debug, Deserialize)]
pub struct RelationPropertiesOnly {
    /// The relation's name. Flattened, so `name` and `case_sensitive` are
    /// read directly from the relation object.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Properties keyed by name; empty when absent from the input.
    #[serde(default)]
    pub properties: BTreeMap<String, PropertyValue>,
}
279
/// A named, typed field of a relation or of a struct type.
///
/// Only `Serialize` is derived; deserialization is implemented by hand to
/// accept more than one input layout.
#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Field {
    /// The field's name. Flattened, so `name` and `case_sensitive` appear
    /// directly in the serialized field object.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// The field's type.
    pub columntype: ColumnType,
    /// Lateness setting as text, if any.
    pub lateness: Option<String>,
    /// Default value as text, if any.
    pub default: Option<String>,
    /// Flag marking the field as unused.
    pub unused: bool,
    /// Watermark setting as text, if any.
    pub watermark: Option<String>,
}
294
295impl Field {
296 pub fn new(name: SqlIdentifier, columntype: ColumnType) -> Self {
297 Self {
298 name,
299 columntype,
300 lateness: None,
301 default: None,
302 unused: false,
303 watermark: None,
304 }
305 }
306
307 pub fn with_lateness(mut self, lateness: &str) -> Self {
308 self.lateness = Some(lateness.to_string());
309 self
310 }
311
312 pub fn with_unused(mut self, unused: bool) -> Self {
313 self.unused = unused;
314 self
315 }
316}
317
318impl<'de> Deserialize<'de> for Field {
323 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
324 where
325 D: Deserializer<'de>,
326 {
327 const fn default_is_struct() -> Option<SqlType> {
328 Some(SqlType::Struct)
329 }
330
331 #[derive(Debug, Clone, Deserialize)]
332 struct FieldHelper {
333 name: Option<String>,
334 #[serde(default)]
335 case_sensitive: bool,
336 columntype: Option<ColumnType>,
337 #[serde(rename = "type")]
338 #[serde(default = "default_is_struct")]
339 typ: Option<SqlType>,
340 nullable: Option<bool>,
341 precision: Option<i64>,
342 scale: Option<i64>,
343 component: Option<Box<ColumnType>>,
344 fields: Option<serde_json::Value>,
345 key: Option<Box<ColumnType>>,
346 value: Option<Box<ColumnType>>,
347 default: Option<String>,
348 #[serde(default)]
349 unused: bool,
350 lateness: Option<String>,
351 watermark: Option<String>,
352 }
353
354 fn helper_to_field(helper: FieldHelper) -> Field {
355 let columntype = if let Some(ctype) = helper.columntype {
356 ctype
357 } else if let Some(serde_json::Value::Array(fields)) = helper.fields {
358 let fields = fields
359 .into_iter()
360 .map(|field| {
361 let field: FieldHelper = serde_json::from_value(field).unwrap();
362 helper_to_field(field)
363 })
364 .collect::<Vec<Field>>();
365
366 ColumnType {
367 typ: helper.typ.unwrap_or(SqlType::Null),
368 nullable: helper.nullable.unwrap_or(false),
369 precision: helper.precision,
370 scale: helper.scale,
371 component: helper.component,
372 fields: Some(fields),
373 key: None,
374 value: None,
375 }
376 } else if let Some(serde_json::Value::Object(obj)) = helper.fields {
377 serde_json::from_value(serde_json::Value::Object(obj))
378 .expect("Failed to deserialize object")
379 } else {
380 ColumnType {
381 typ: helper.typ.unwrap_or(SqlType::Null),
382 nullable: helper.nullable.unwrap_or(false),
383 precision: helper.precision,
384 scale: helper.scale,
385 component: helper.component,
386 fields: None,
387 key: helper.key,
388 value: helper.value,
389 }
390 };
391
392 Field {
393 name: SqlIdentifier::new(helper.name.unwrap(), helper.case_sensitive),
394 columntype,
395 default: helper.default,
396 unused: helper.unused,
397 lateness: helper.lateness,
398 watermark: helper.watermark,
399 }
400 }
401
402 let helper = FieldHelper::deserialize(deserializer)?;
403 Ok(helper_to_field(helper))
404 }
405}
406
/// The unit (or unit range) of a SQL interval type.
///
/// Used as the payload of `SqlType::Interval`; the names given below are the
/// strings that `SqlType` serializes to for each unit.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum IntervalUnit {
    /// `INTERVAL_DAY`
    Day,
    /// `INTERVAL_DAY_HOUR`
    DayToHour,
    /// `INTERVAL_DAY_MINUTE`
    DayToMinute,
    /// `INTERVAL_DAY_SECOND`
    DayToSecond,
    /// `INTERVAL_HOUR`
    Hour,
    /// `INTERVAL_HOUR_MINUTE`
    HourToMinute,
    /// `INTERVAL_HOUR_SECOND`
    HourToSecond,
    /// `INTERVAL_MINUTE`
    Minute,
    /// `INTERVAL_MINUTE_SECOND`
    MinuteToSecond,
    /// `INTERVAL_MONTH`
    Month,
    /// `INTERVAL_SECOND`
    Second,
    /// `INTERVAL_YEAR`
    Year,
    /// `INTERVAL_YEAR_MONTH`
    YearToMonth,
}
441
/// The kind of a SQL type.
///
/// Serialization and deserialization are implemented by hand: each variant
/// is written as the upper-case name shown below and is read back
/// case-insensitively.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SqlType {
    /// `BOOLEAN`
    Boolean,
    /// `TINYINT`
    TinyInt,
    /// `SMALLINT`
    SmallInt,
    /// `INTEGER`
    Int,
    /// `BIGINT`
    BigInt,
    /// `UTINYINT`
    UTinyInt,
    /// `USMALLINT`
    USmallInt,
    /// `UINTEGER`
    UInt,
    /// `UBIGINT`
    UBigInt,
    /// `REAL`
    Real,
    /// `DOUBLE`
    Double,
    /// `DECIMAL`
    Decimal,
    /// `CHAR`
    Char,
    /// `VARCHAR`
    Varchar,
    /// `BINARY`
    Binary,
    /// `VARBINARY`
    Varbinary,
    /// `TIME`
    Time,
    /// `DATE`
    Date,
    /// `TIMESTAMP`
    Timestamp,
    /// `INTERVAL_*`, with the suffix determined by the unit.
    Interval(IntervalUnit),
    /// `ARRAY`
    Array,
    /// `STRUCT`
    Struct,
    /// `MAP`
    Map,
    /// `NULL`
    Null,
    /// `UUID`
    Uuid,
    /// `VARIANT`
    Variant,
}
499
500impl Display for SqlType {
501 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
502 f.write_str(&serde_json::to_string(self).unwrap())
503 }
504}
505
506impl<'de> Deserialize<'de> for SqlType {
507 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
508 where
509 D: Deserializer<'de>,
510 {
511 let value: String = Deserialize::deserialize(deserializer)?;
512 match value.to_lowercase().as_str() {
513 "interval_day" => Ok(SqlType::Interval(IntervalUnit::Day)),
514 "interval_day_hour" => Ok(SqlType::Interval(IntervalUnit::DayToHour)),
515 "interval_day_minute" => Ok(SqlType::Interval(IntervalUnit::DayToMinute)),
516 "interval_day_second" => Ok(SqlType::Interval(IntervalUnit::DayToSecond)),
517 "interval_hour" => Ok(SqlType::Interval(IntervalUnit::Hour)),
518 "interval_hour_minute" => Ok(SqlType::Interval(IntervalUnit::HourToMinute)),
519 "interval_hour_second" => Ok(SqlType::Interval(IntervalUnit::HourToSecond)),
520 "interval_minute" => Ok(SqlType::Interval(IntervalUnit::Minute)),
521 "interval_minute_second" => Ok(SqlType::Interval(IntervalUnit::MinuteToSecond)),
522 "interval_month" => Ok(SqlType::Interval(IntervalUnit::Month)),
523 "interval_second" => Ok(SqlType::Interval(IntervalUnit::Second)),
524 "interval_year" => Ok(SqlType::Interval(IntervalUnit::Year)),
525 "interval_year_month" => Ok(SqlType::Interval(IntervalUnit::YearToMonth)),
526 "boolean" => Ok(SqlType::Boolean),
527 "tinyint" => Ok(SqlType::TinyInt),
528 "smallint" => Ok(SqlType::SmallInt),
529 "integer" => Ok(SqlType::Int),
530 "bigint" => Ok(SqlType::BigInt),
531 "utinyint" => Ok(SqlType::UTinyInt),
532 "usmallint" => Ok(SqlType::USmallInt),
533 "uinteger" => Ok(SqlType::UInt),
534 "ubigint" => Ok(SqlType::UBigInt),
535 "real" => Ok(SqlType::Real),
536 "double" => Ok(SqlType::Double),
537 "decimal" => Ok(SqlType::Decimal),
538 "char" => Ok(SqlType::Char),
539 "varchar" => Ok(SqlType::Varchar),
540 "binary" => Ok(SqlType::Binary),
541 "varbinary" => Ok(SqlType::Varbinary),
542 "variant" => Ok(SqlType::Variant),
543 "time" => Ok(SqlType::Time),
544 "date" => Ok(SqlType::Date),
545 "timestamp" => Ok(SqlType::Timestamp),
546 "array" => Ok(SqlType::Array),
547 "struct" => Ok(SqlType::Struct),
548 "map" => Ok(SqlType::Map),
549 "null" => Ok(SqlType::Null),
550 "uuid" => Ok(SqlType::Uuid),
551 _ => Err(serde::de::Error::custom(format!(
552 "Unknown SQL type: {}",
553 value
554 ))),
555 }
556 }
557}
558
559impl Serialize for SqlType {
560 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
561 where
562 S: Serializer,
563 {
564 let type_str = match self {
565 SqlType::Boolean => "BOOLEAN",
566 SqlType::TinyInt => "TINYINT",
567 SqlType::SmallInt => "SMALLINT",
568 SqlType::Int => "INTEGER",
569 SqlType::BigInt => "BIGINT",
570 SqlType::UTinyInt => "UTINYINT",
571 SqlType::USmallInt => "USMALLINT",
572 SqlType::UInt => "UINTEGER",
573 SqlType::UBigInt => "UBIGINT",
574 SqlType::Real => "REAL",
575 SqlType::Double => "DOUBLE",
576 SqlType::Decimal => "DECIMAL",
577 SqlType::Char => "CHAR",
578 SqlType::Varchar => "VARCHAR",
579 SqlType::Binary => "BINARY",
580 SqlType::Varbinary => "VARBINARY",
581 SqlType::Time => "TIME",
582 SqlType::Date => "DATE",
583 SqlType::Timestamp => "TIMESTAMP",
584 SqlType::Interval(interval_unit) => match interval_unit {
585 IntervalUnit::Day => "INTERVAL_DAY",
586 IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
587 IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
588 IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
589 IntervalUnit::Hour => "INTERVAL_HOUR",
590 IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
591 IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
592 IntervalUnit::Minute => "INTERVAL_MINUTE",
593 IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
594 IntervalUnit::Month => "INTERVAL_MONTH",
595 IntervalUnit::Second => "INTERVAL_SECOND",
596 IntervalUnit::Year => "INTERVAL_YEAR",
597 IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
598 },
599 SqlType::Array => "ARRAY",
600 SqlType::Struct => "STRUCT",
601 SqlType::Uuid => "UUID",
602 SqlType::Map => "MAP",
603 SqlType::Null => "NULL",
604 SqlType::Variant => "VARIANT",
605 };
606 serializer.serialize_str(type_str)
607 }
608}
609
610impl SqlType {
611 pub fn is_string(&self) -> bool {
613 matches!(self, Self::Char | Self::Varchar)
614 }
615
616 pub fn is_varchar(&self) -> bool {
617 matches!(self, Self::Varchar)
618 }
619
620 pub fn is_varbinary(&self) -> bool {
621 matches!(self, Self::Varbinary)
622 }
623}
624
/// Serde default for `ColumnType::typ`: a column type whose `type` key is
/// missing is treated as a `STRUCT`.
const fn default_is_struct() -> SqlType {
    SqlType::Struct
}
630
/// A SQL column type: a kind plus the attributes that refine it.
///
/// Which optional attributes are populated depends on the kind; see the
/// constructors on this type for the combinations built in this file.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ColumnType {
    /// The kind of type. Serialized under the key `type`; treated as
    /// `STRUCT` when that key is missing from the input.
    #[serde(rename = "type")]
    #[serde(default = "default_is_struct")]
    pub typ: SqlType,
    /// Whether the column accepts NULL values.
    pub nullable: bool,
    /// Precision, when one applies. Set by the `decimal` constructor, and by
    /// `fixed`, which stores the width here.
    pub precision: Option<i64>,
    /// Scale, when one applies. Set by the `decimal` constructor.
    pub scale: Option<i64>,
    /// Element type; set by the `array` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub component: Option<Box<ColumnType>>,
    /// Member fields; set by the `structure` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "Some(Vec::new())"))]
    pub fields: Option<Vec<Field>>,
    /// Key type; set by the `map` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub key: Option<Box<ColumnType>>,
    /// Value type; set by the `map` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub value: Option<Box<ColumnType>>,
}
695
696impl ColumnType {
697 pub fn boolean(nullable: bool) -> Self {
698 ColumnType {
699 typ: SqlType::Boolean,
700 nullable,
701 precision: None,
702 scale: None,
703 component: None,
704 fields: None,
705 key: None,
706 value: None,
707 }
708 }
709
710 pub fn uuid(nullable: bool) -> Self {
711 ColumnType {
712 typ: SqlType::Uuid,
713 nullable,
714 precision: None,
715 scale: None,
716 component: None,
717 fields: None,
718 key: None,
719 value: None,
720 }
721 }
722
723 pub fn tinyint(nullable: bool) -> Self {
724 ColumnType {
725 typ: SqlType::TinyInt,
726 nullable,
727 precision: None,
728 scale: None,
729 component: None,
730 fields: None,
731 key: None,
732 value: None,
733 }
734 }
735
736 pub fn smallint(nullable: bool) -> Self {
737 ColumnType {
738 typ: SqlType::SmallInt,
739 nullable,
740 precision: None,
741 scale: None,
742 component: None,
743 fields: None,
744 key: None,
745 value: None,
746 }
747 }
748
749 pub fn int(nullable: bool) -> Self {
750 ColumnType {
751 typ: SqlType::Int,
752 nullable,
753 precision: None,
754 scale: None,
755 component: None,
756 fields: None,
757 key: None,
758 value: None,
759 }
760 }
761
762 pub fn bigint(nullable: bool) -> Self {
763 ColumnType {
764 typ: SqlType::BigInt,
765 nullable,
766 precision: None,
767 scale: None,
768 component: None,
769 fields: None,
770 key: None,
771 value: None,
772 }
773 }
774
775 pub fn utinyint(nullable: bool) -> Self {
776 ColumnType {
777 typ: SqlType::UTinyInt,
778 nullable,
779 precision: None,
780 scale: None,
781 component: None,
782 fields: None,
783 key: None,
784 value: None,
785 }
786 }
787
788 pub fn usmallint(nullable: bool) -> Self {
789 ColumnType {
790 typ: SqlType::USmallInt,
791 nullable,
792 precision: None,
793 scale: None,
794 component: None,
795 fields: None,
796 key: None,
797 value: None,
798 }
799 }
800
801 pub fn uint(nullable: bool) -> Self {
802 ColumnType {
803 typ: SqlType::UInt,
804 nullable,
805 precision: None,
806 scale: None,
807 component: None,
808 fields: None,
809 key: None,
810 value: None,
811 }
812 }
813
814 pub fn ubigint(nullable: bool) -> Self {
815 ColumnType {
816 typ: SqlType::UBigInt,
817 nullable,
818 precision: None,
819 scale: None,
820 component: None,
821 fields: None,
822 key: None,
823 value: None,
824 }
825 }
826
827 pub fn double(nullable: bool) -> Self {
828 ColumnType {
829 typ: SqlType::Double,
830 nullable,
831 precision: None,
832 scale: None,
833 component: None,
834 fields: None,
835 key: None,
836 value: None,
837 }
838 }
839
840 pub fn real(nullable: bool) -> Self {
841 ColumnType {
842 typ: SqlType::Real,
843 nullable,
844 precision: None,
845 scale: None,
846 component: None,
847 fields: None,
848 key: None,
849 value: None,
850 }
851 }
852
853 pub fn decimal(precision: i64, scale: i64, nullable: bool) -> Self {
854 ColumnType {
855 typ: SqlType::Decimal,
856 nullable,
857 precision: Some(precision),
858 scale: Some(scale),
859 component: None,
860 fields: None,
861 key: None,
862 value: None,
863 }
864 }
865
866 pub fn varchar(nullable: bool) -> Self {
867 ColumnType {
868 typ: SqlType::Varchar,
869 nullable,
870 precision: None,
871 scale: None,
872 component: None,
873 fields: None,
874 key: None,
875 value: None,
876 }
877 }
878
879 pub fn varbinary(nullable: bool) -> Self {
880 ColumnType {
881 typ: SqlType::Varbinary,
882 nullable,
883 precision: None,
884 scale: None,
885 component: None,
886 fields: None,
887 key: None,
888 value: None,
889 }
890 }
891
892 pub fn fixed(width: i64, nullable: bool) -> Self {
893 ColumnType {
894 typ: SqlType::Binary,
895 nullable,
896 precision: Some(width),
897 scale: None,
898 component: None,
899 fields: None,
900 key: None,
901 value: None,
902 }
903 }
904
905 pub fn date(nullable: bool) -> Self {
906 ColumnType {
907 typ: SqlType::Date,
908 nullable,
909 precision: None,
910 scale: None,
911 component: None,
912 fields: None,
913 key: None,
914 value: None,
915 }
916 }
917
918 pub fn time(nullable: bool) -> Self {
919 ColumnType {
920 typ: SqlType::Time,
921 nullable,
922 precision: None,
923 scale: None,
924 component: None,
925 fields: None,
926 key: None,
927 value: None,
928 }
929 }
930
931 pub fn timestamp(nullable: bool) -> Self {
932 ColumnType {
933 typ: SqlType::Timestamp,
934 nullable,
935 precision: None,
936 scale: None,
937 component: None,
938 fields: None,
939 key: None,
940 value: None,
941 }
942 }
943
944 pub fn variant(nullable: bool) -> Self {
945 ColumnType {
946 typ: SqlType::Variant,
947 nullable,
948 precision: None,
949 scale: None,
950 component: None,
951 fields: None,
952 key: None,
953 value: None,
954 }
955 }
956
957 pub fn array(nullable: bool, element: ColumnType) -> Self {
958 ColumnType {
959 typ: SqlType::Array,
960 nullable,
961 precision: None,
962 scale: None,
963 component: Some(Box::new(element)),
964 fields: None,
965 key: None,
966 value: None,
967 }
968 }
969
970 pub fn structure(nullable: bool, fields: &[Field]) -> Self {
971 ColumnType {
972 typ: SqlType::Struct,
973 nullable,
974 precision: None,
975 scale: None,
976 component: None,
977 fields: Some(fields.to_vec()),
978 key: None,
979 value: None,
980 }
981 }
982
983 pub fn map(nullable: bool, key: ColumnType, val: ColumnType) -> Self {
984 ColumnType {
985 typ: SqlType::Map,
986 nullable,
987 precision: None,
988 scale: None,
989 component: None,
990 fields: None,
991 key: Some(Box::new(key)),
992 value: Some(Box::new(val)),
993 }
994 }
995
996 pub fn is_integral_type(&self) -> bool {
997 matches!(
998 &self.typ,
999 SqlType::TinyInt
1000 | SqlType::SmallInt
1001 | SqlType::Int
1002 | SqlType::BigInt
1003 | SqlType::UTinyInt
1004 | SqlType::USmallInt
1005 | SqlType::UInt
1006 | SqlType::UBigInt
1007 )
1008 }
1009
1010 pub fn is_fp_type(&self) -> bool {
1011 matches!(&self.typ, SqlType::Double | SqlType::Real)
1012 }
1013
1014 pub fn is_decimal_type(&self) -> bool {
1015 matches!(&self.typ, SqlType::Decimal)
1016 }
1017
1018 pub fn is_numeric_type(&self) -> bool {
1019 self.is_integral_type() || self.is_fp_type() || self.is_decimal_type()
1020 }
1021}
1022
1023#[cfg(test)]
1024mod tests {
1025 use super::{IntervalUnit, SqlIdentifier};
1026 use crate::program_schema::{ColumnType, Field, SqlType};
1027
    /// Every type name must deserialize regardless of letter case and must
    /// survive a serialize/deserialize round trip.
    #[test]
    fn serde_sql_type() {
        for (sql_str_base, expected_value) in [
            ("Boolean", SqlType::Boolean),
            ("Uuid", SqlType::Uuid),
            ("TinyInt", SqlType::TinyInt),
            ("SmallInt", SqlType::SmallInt),
            ("Integer", SqlType::Int),
            ("BigInt", SqlType::BigInt),
            ("UTinyInt", SqlType::UTinyInt),
            ("USmallInt", SqlType::USmallInt),
            ("UInteger", SqlType::UInt),
            ("UBigInt", SqlType::UBigInt),
            ("Real", SqlType::Real),
            ("Double", SqlType::Double),
            ("Decimal", SqlType::Decimal),
            ("Char", SqlType::Char),
            ("Varchar", SqlType::Varchar),
            ("Binary", SqlType::Binary),
            ("Varbinary", SqlType::Varbinary),
            ("Time", SqlType::Time),
            ("Date", SqlType::Date),
            ("Timestamp", SqlType::Timestamp),
            ("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
            (
                "Interval_Day_Hour",
                SqlType::Interval(IntervalUnit::DayToHour),
            ),
            (
                "Interval_Day_Minute",
                SqlType::Interval(IntervalUnit::DayToMinute),
            ),
            (
                "Interval_Day_Second",
                SqlType::Interval(IntervalUnit::DayToSecond),
            ),
            ("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
            (
                "Interval_Hour_Minute",
                SqlType::Interval(IntervalUnit::HourToMinute),
            ),
            (
                "Interval_Hour_Second",
                SqlType::Interval(IntervalUnit::HourToSecond),
            ),
            ("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
            (
                "Interval_Minute_Second",
                SqlType::Interval(IntervalUnit::MinuteToSecond),
            ),
            ("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
            ("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
            ("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
            (
                "Interval_Year_Month",
                SqlType::Interval(IntervalUnit::YearToMonth),
            ),
            ("Array", SqlType::Array),
            ("Struct", SqlType::Struct),
            ("Map", SqlType::Map),
            ("Null", SqlType::Null),
            ("Variant", SqlType::Variant),
        ] {
            // Try the name as written, all lowercase, and all uppercase.
            for sql_str in [
                sql_str_base,
                &sql_str_base.to_lowercase(),
                &sql_str_base.to_uppercase(),
            ] {
                let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str))
                    .unwrap_or_else(|_| {
                        panic!("\"{sql_str}\" should deserialize into its SQL type")
                    });
                assert_eq!(value1, expected_value);
                // Round trip: whatever we serialize must read back unchanged.
                let serialized_str =
                    serde_json::to_string(&value1).expect("Value should serialize into JSON");
                let value2: SqlType = serde_json::from_str(&serialized_str).unwrap_or_else(|_| {
                    panic!(
                        "{} should deserialize back into its SQL type",
                        serialized_str
                    )
                });
                assert_eq!(value1, value2);
            }
        }
    }
1113
    /// A full program schema with one output field per interval unit must
    /// deserialize, and each field must come back with the matching
    /// `SqlType::Interval` unit.
    #[test]
    fn deserialize_interval_types() {
        use super::IntervalUnit::*;
        use super::SqlType::*;

        let schema = r#"
{
  "inputs" : [ {
    "name" : "sales",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "sales_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "customer_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "age",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "UINTEGER",
        "nullable" : true
      }
    }, {
      "name" : "amount",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DECIMAL",
        "nullable" : true,
        "precision" : 10,
        "scale" : 2
      }
    }, {
      "name" : "sale_date",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DATE",
        "nullable" : true
      }
    } ],
    "primary_key" : [ "sales_id" ]
  } ],
  "outputs" : [ {
    "name" : "salessummary",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "customer_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "total_sales",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DECIMAL",
        "nullable" : true,
        "precision" : 38,
        "scale" : 2
      }
    }, {
      "name" : "interval_day",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_hour",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_HOUR",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour_to_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_minute_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MINUTE_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_month",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MONTH",
        "nullable" : false
      }
    }, {
      "name" : "interval_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_year",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_YEAR",
        "nullable" : false
      }
    }, {
      "name" : "interval_year_to_month",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_YEAR_MONTH",
        "nullable" : false
      }
    } ]
  } ]
}
"#;

        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        // Collect the kind of every output field, in declaration order.
        let types = schema
            .outputs
            .iter()
            .flat_map(|r| r.fields.iter().map(|f| f.columntype.typ));
        let expected_types = [
            Int,
            Decimal,
            Interval(Day),
            Interval(DayToHour),
            Interval(DayToMinute),
            Interval(DayToSecond),
            Interval(Hour),
            Interval(HourToMinute),
            Interval(HourToSecond),
            Interval(Minute),
            Interval(MinuteToSecond),
            Interval(Month),
            Interval(Second),
            Interval(Year),
            Interval(YearToMonth),
        ];

        assert_eq!(types.collect::<Vec<_>>(), &expected_types);
    }
1324
    /// A struct-typed column whose members are given as a `fields` array
    /// (with one member nesting a further `fields` object) must deserialize
    /// into nested `STRUCT` column types.
    #[test]
    fn serialize_struct_schemas() {
        let schema = r#"{
  "inputs" : [ {
    "name" : "PERS",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "P0",
      "case_sensitive" : false,
      "columntype" : {
        "fields" : [ {
          "type" : "VARCHAR",
          "nullable" : true,
          "precision" : 30,
          "name" : "FIRSTNAME"
        }, {
          "type" : "VARCHAR",
          "nullable" : true,
          "precision" : 30,
          "name" : "LASTNAME"
        }, {
          "type" : "UINTEGER",
          "nullable" : true,
          "name" : "AGE"
        }, {
          "fields" : {
            "fields" : [ {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 30,
              "name" : "STREET"
            }, {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 30,
              "name" : "CITY"
            }, {
              "type" : "CHAR",
              "nullable" : true,
              "precision" : 2,
              "name" : "STATE"
            }, {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 6,
              "name" : "POSTAL_CODE"
            } ],
            "nullable" : false
          },
          "nullable" : false,
          "name" : "ADDRESS"
        } ],
        "nullable" : false
      }
    }]
  } ],
  "outputs" : [ ]
}
"#;
        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        eprintln!("{:#?}", schema);
        let pers = schema.inputs.iter().find(|r| r.name == "PERS").unwrap();
        let p0 = pers.fields.iter().find(|f| f.name == "P0").unwrap();
        // No explicit `type` on P0's column type, so it defaults to STRUCT.
        assert_eq!(p0.columntype.typ, SqlType::Struct);
        let p0_fields = p0.columntype.fields.as_ref().unwrap();
        assert_eq!(p0_fields[0].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[1].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[2].columntype.typ, SqlType::UInt);
        assert_eq!(p0_fields[3].columntype.typ, SqlType::Struct);
        assert_eq!(p0_fields[3].name, "ADDRESS");
        // ADDRESS is itself a struct described by a nested `fields` object.
        let address = &p0_fields[3].columntype.fields.as_ref().unwrap();
        assert_eq!(address.len(), 4);
        assert_eq!(address[0].name, "STREET");
        assert_eq!(address[0].columntype.typ, SqlType::Varchar);
        assert_eq!(address[1].columntype.typ, SqlType::Varchar);
        assert_eq!(address[2].columntype.typ, SqlType::Char);
        assert_eq!(address[3].columntype.typ, SqlType::Varchar);
    }
1403
1404 #[test]
1405 fn sql_identifier_cmp() {
1406 assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("foo"));
1407 assert_ne!(SqlIdentifier::from("foo"), SqlIdentifier::from("bar"));
1408 assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("BAR"));
1409 assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("\"foo\""));
1410 assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("\"bar\""));
1411 assert_eq!(SqlIdentifier::from("bAr"), SqlIdentifier::from("\"bAr\""));
1412 assert_eq!(
1413 SqlIdentifier::new("bAr", true),
1414 SqlIdentifier::from("\"bAr\"")
1415 );
1416
1417 assert_eq!(SqlIdentifier::from("bAr"), "bar");
1418 assert_eq!(SqlIdentifier::from("bAr"), "bAr");
1419 }
1420
1421 #[test]
1422 fn sql_identifier_ord() {
1423 let mut btree = std::collections::BTreeSet::new();
1424 assert!(btree.insert(SqlIdentifier::from("foo")));
1425 assert!(btree.insert(SqlIdentifier::from("bar")));
1426 assert!(!btree.insert(SqlIdentifier::from("BAR")));
1427 assert!(!btree.insert(SqlIdentifier::from("\"foo\"")));
1428 assert!(!btree.insert(SqlIdentifier::from("\"bar\"")));
1429 }
1430
1431 #[test]
1432 fn sql_identifier_hash() {
1433 let mut hs = std::collections::HashSet::new();
1434 assert!(hs.insert(SqlIdentifier::from("foo")));
1435 assert!(hs.insert(SqlIdentifier::from("bar")));
1436 assert!(!hs.insert(SqlIdentifier::from("BAR")));
1437 assert!(!hs.insert(SqlIdentifier::from("\"foo\"")));
1438 assert!(!hs.insert(SqlIdentifier::from("\"bar\"")));
1439 }
1440
1441 #[test]
1442 fn sql_identifier_name() {
1443 assert_eq!(SqlIdentifier::from("foo").name(), "foo");
1444 assert_eq!(SqlIdentifier::from("bAr").name(), "bar");
1445 assert_eq!(SqlIdentifier::from("\"bAr\"").name(), "bAr");
1446 assert_eq!(SqlIdentifier::from("foo").sql_name(), "foo");
1447 assert_eq!(SqlIdentifier::from("bAr").sql_name(), "bAr");
1448 assert_eq!(SqlIdentifier::from("\"bAr\"").sql_name(), "\"bAr\"");
1449 }
1450
    /// Regression test: a struct member of `MAP` type given in the flat
    /// layout must keep its `key` and `value` types after deserialization.
    #[test]
    fn issue3277() {
        let schema = r#"{
          "name" : "j",
          "case_sensitive" : false,
          "columntype" : {
            "fields" : [ {
              "key" : {
                "nullable" : false,
                "precision" : -1,
                "type" : "VARCHAR"
              },
              "name" : "s",
              "nullable" : true,
              "type" : "MAP",
              "value" : {
                "nullable" : true,
                "precision" : -1,
                "type" : "VARCHAR"
              }
            } ],
            "nullable" : true
          }
        }"#;
        let field: Field = serde_json::from_str(schema).unwrap();
        println!("field: {:#?}", field);
        assert_eq!(
            field,
            Field {
                name: SqlIdentifier {
                    name: "j".to_string(),
                    case_sensitive: false,
                },
                columntype: ColumnType {
                    // The outer column type has no `type` key, so it is a STRUCT.
                    typ: SqlType::Struct,
                    nullable: true,
                    precision: None,
                    scale: None,
                    component: None,
                    fields: Some(vec![Field {
                        name: SqlIdentifier {
                            name: "s".to_string(),
                            case_sensitive: false,
                        },
                        columntype: ColumnType {
                            typ: SqlType::Map,
                            nullable: true,
                            precision: None,
                            scale: None,
                            component: None,
                            fields: None,
                            key: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: false,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                            value: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: true,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                        },
                        lateness: None,
                        default: None,
                        unused: false,
                        watermark: None,
                    }]),
                    key: None,
                    value: None,
                },
                lateness: None,
                default: None,
                unused: false,
                watermark: None,
            }
        );
    }
1538}