1pub use feldera_ir::SourcePosition;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3use std::cmp::Ordering;
4use std::collections::BTreeMap;
5use std::fmt::Display;
6use std::hash::{Hash, Hasher};
7use utoipa::ToSchema;
8
9#[cfg(feature = "testing")]
10use proptest::{collection::vec, prelude::any};
11
/// Normalizes a SQL identifier for name lookups.
///
/// Text wrapped in double quotes (at least two characters, starting and
/// ending with `"`) is returned with the quotes removed and its case kept.
/// Anything else is folded to lowercase. Embedded quotes are not unescaped.
pub fn canonical_identifier(id: &str) -> String {
    match id.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(quoted) => quoted.to_owned(),
        None => id.to_lowercase(),
    }
}
26
/// A SQL identifier together with its case-sensitivity flag.
///
/// Two case-insensitive identifiers compare by their lowercased spelling;
/// when at least one side is case-sensitive the exact spelling is compared
/// (see the `PartialEq` impl). `From<S>` marks double-quoted text as
/// case-sensitive and strips the quotes.
#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SqlIdentifier {
    /// Identifier text as stored (without surrounding quotes); read it via
    /// `name()` or `sql_name()`.
    #[cfg_attr(feature = "testing", proptest(regex = "relation1|relation2|relation3"))]
    name: String,
    /// `true` if the identifier must be matched with its exact spelling.
    pub case_sensitive: bool,
}
38
39impl SqlIdentifier {
40 pub fn new<S: AsRef<str>>(name: S, case_sensitive: bool) -> Self {
41 Self {
42 name: name.as_ref().to_string(),
43 case_sensitive,
44 }
45 }
46
47 pub fn name(&self) -> String {
57 if self.case_sensitive {
58 self.name.clone()
59 } else {
60 self.name.to_lowercase()
61 }
62 }
63
64 pub fn sql_name(&self) -> String {
75 if self.case_sensitive {
76 format!("\"{}\"", self.name)
77 } else {
78 self.name.clone()
79 }
80 }
81}
82
83impl Hash for SqlIdentifier {
84 fn hash<H: Hasher>(&self, state: &mut H) {
85 self.name().hash(state);
86 }
87}
88
89impl PartialEq for SqlIdentifier {
90 fn eq(&self, other: &Self) -> bool {
91 match (self.case_sensitive, other.case_sensitive) {
92 (true, true) => self.name == other.name,
93 (false, false) => self.name.to_lowercase() == other.name.to_lowercase(),
94 (true, false) => self.name == other.name,
95 (false, true) => self.name == other.name,
96 }
97 }
98}
99
100impl Ord for SqlIdentifier {
101 fn cmp(&self, other: &Self) -> Ordering {
102 self.name().cmp(&other.name())
103 }
104}
105
106impl PartialOrd for SqlIdentifier {
107 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108 Some(self.cmp(other))
109 }
110}
111
112impl<S: AsRef<str>> PartialEq<S> for SqlIdentifier {
113 fn eq(&self, other: &S) -> bool {
114 self == &SqlIdentifier::from(other.as_ref())
115 }
116}
117
// Declares `PartialEq` to be an equivalence relation so identifiers can be
// used as `HashSet`/`BTreeSet` keys.
// NOTE(review): `PartialEq` is not transitive when case-sensitive and
// case-insensitive identifiers are mixed (`foo` == `FOO` and `FOO` == `"FOO"`,
// but `foo` != `"FOO"`), so this guarantee does not strictly hold in that case.
impl Eq for SqlIdentifier {}
119
120impl<S: AsRef<str>> From<S> for SqlIdentifier {
121 fn from(name: S) -> Self {
122 if name.as_ref().starts_with('"')
123 && name.as_ref().ends_with('"')
124 && name.as_ref().len() >= 2
125 {
126 Self {
127 name: name.as_ref()[1..name.as_ref().len() - 1].to_string(),
128 case_sensitive: true,
129 }
130 } else {
131 Self::new(name, false)
132 }
133 }
134}
135
136impl From<SqlIdentifier> for String {
137 fn from(id: SqlIdentifier) -> String {
138 id.name()
139 }
140}
141
142impl From<&SqlIdentifier> for String {
143 fn from(id: &SqlIdentifier) -> String {
144 id.name()
145 }
146}
147
148impl Display for SqlIdentifier {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 write!(f, "{}", self.name())
151 }
152}
153
/// Schema of a program: its input and output relations.
#[derive(Default, Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ProgramSchema {
    /// Input relations (presumably the program's tables — confirm against the
    /// compiler output). Property tests generate 0 or 1 of them.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub inputs: Vec<Relation>,
    /// Output relations (presumably the program's views — confirm against the
    /// compiler output). Property tests generate 0 or 1 of them.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub outputs: Vec<Relation>,
}
171
172impl ProgramSchema {
173 pub fn relations_with_lateness(&self) -> Vec<SqlIdentifier> {
174 self.inputs
175 .iter()
176 .chain(self.outputs.iter())
177 .filter(|rel| rel.has_lateness())
178 .map(|rel| rel.name.clone())
179 .collect()
180 }
181}
182
/// The value of a relation property (stored in `Relation::properties`, keyed
/// by property name) together with source positions of its key and value.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct PropertyValue {
    /// Property value as text.
    pub value: String,
    /// Where the property key appears in the source (presumably for error
    /// reporting — confirm).
    pub key_position: SourcePosition,
    /// Where the property value appears in the source.
    pub value_position: SourcePosition,
}
190
/// A named relation of a program (an entry of `ProgramSchema::inputs` or
/// `ProgramSchema::outputs`) and its columns.
///
/// Unknown JSON keys (e.g. `primary_key` in the tests) are ignored on input.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Relation {
    /// Relation name. Flattened, so `name` and `case_sensitive` appear as
    /// top-level keys of the relation's JSON object.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Columns of the relation, in the order they appear in the schema.
    #[cfg_attr(feature = "testing", proptest(value = "Vec::new()"))]
    pub fields: Vec<Field>,
    /// Whether the relation is materialized; `false` when absent from the input.
    #[serde(default)]
    pub materialized: bool,
    /// Relation properties keyed by property name; empty when absent.
    #[serde(default)]
    pub properties: BTreeMap<String, PropertyValue>,
}
207
208impl Relation {
209 pub fn empty() -> Self {
210 Self {
211 name: SqlIdentifier::from("".to_string()),
212 fields: Vec::new(),
213 materialized: false,
214 properties: BTreeMap::new(),
215 }
216 }
217
218 pub fn new(
219 name: SqlIdentifier,
220 fields: Vec<Field>,
221 materialized: bool,
222 properties: BTreeMap<String, PropertyValue>,
223 ) -> Self {
224 Self {
225 name,
226 fields,
227 materialized,
228 properties,
229 }
230 }
231
232 pub fn field(&self, name: &str) -> Option<&Field> {
234 let name = canonical_identifier(name);
235 self.fields.iter().find(|f| f.name == name)
236 }
237
238 pub fn has_lateness(&self) -> bool {
239 self.fields.iter().any(|f| f.lateness.is_some())
240 }
241
242 pub fn get_property(&self, name: &str) -> Option<&str> {
243 self.properties.get(name).map(|p| p.value.as_str())
244 }
245}
246
/// A column of a relation, or a member of a `STRUCT` column type.
///
/// Serialization is derived; deserialization is hand-written (see the
/// `Deserialize` impl for `Field`) so that the column type may be given
/// either under `columntype` or flattened next to the name.
#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Field {
    /// Column name; flattened so `name`/`case_sensitive` are top-level JSON keys.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Column type.
    pub columntype: ColumnType,
    /// Lateness annotation as text, if any (checked by `Relation::has_lateness`).
    pub lateness: Option<String>,
    /// Default value as text, if any (presumably the SQL `DEFAULT`
    /// expression — confirm).
    pub default: Option<String>,
    /// Whether the column is flagged as unused.
    pub unused: bool,
    /// Watermark annotation as text, if any.
    pub watermark: Option<String>,
}
261
262impl Field {
263 pub fn new(name: SqlIdentifier, columntype: ColumnType) -> Self {
264 Self {
265 name,
266 columntype,
267 lateness: None,
268 default: None,
269 unused: false,
270 watermark: None,
271 }
272 }
273
274 pub fn with_lateness(mut self, lateness: &str) -> Self {
275 self.lateness = Some(lateness.to_string());
276 self
277 }
278
279 pub fn with_unused(mut self, unused: bool) -> Self {
280 self.unused = unused;
281 self
282 }
283}
284
285impl<'de> Deserialize<'de> for Field {
290 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
291 where
292 D: Deserializer<'de>,
293 {
294 const fn default_is_struct() -> Option<SqlType> {
295 Some(SqlType::Struct)
296 }
297
298 #[derive(Debug, Clone, Deserialize)]
299 struct FieldHelper {
300 name: Option<String>,
301 #[serde(default)]
302 case_sensitive: bool,
303 columntype: Option<ColumnType>,
304 #[serde(rename = "type")]
305 #[serde(default = "default_is_struct")]
306 typ: Option<SqlType>,
307 nullable: Option<bool>,
308 precision: Option<i64>,
309 scale: Option<i64>,
310 component: Option<Box<ColumnType>>,
311 fields: Option<serde_json::Value>,
312 key: Option<Box<ColumnType>>,
313 value: Option<Box<ColumnType>>,
314 default: Option<String>,
315 #[serde(default)]
316 unused: bool,
317 lateness: Option<String>,
318 watermark: Option<String>,
319 }
320
321 fn helper_to_field(helper: FieldHelper) -> Field {
322 let columntype = if let Some(ctype) = helper.columntype {
323 ctype
324 } else if let Some(serde_json::Value::Array(fields)) = helper.fields {
325 let fields = fields
326 .into_iter()
327 .map(|field| {
328 let field: FieldHelper = serde_json::from_value(field).unwrap();
329 helper_to_field(field)
330 })
331 .collect::<Vec<Field>>();
332
333 ColumnType {
334 typ: helper.typ.unwrap_or(SqlType::Null),
335 nullable: helper.nullable.unwrap_or(false),
336 precision: helper.precision,
337 scale: helper.scale,
338 component: helper.component,
339 fields: Some(fields),
340 key: None,
341 value: None,
342 }
343 } else if let Some(serde_json::Value::Object(obj)) = helper.fields {
344 serde_json::from_value(serde_json::Value::Object(obj))
345 .expect("Failed to deserialize object")
346 } else {
347 ColumnType {
348 typ: helper.typ.unwrap_or(SqlType::Null),
349 nullable: helper.nullable.unwrap_or(false),
350 precision: helper.precision,
351 scale: helper.scale,
352 component: helper.component,
353 fields: None,
354 key: helper.key,
355 value: helper.value,
356 }
357 };
358
359 Field {
360 name: SqlIdentifier::new(helper.name.unwrap(), helper.case_sensitive),
361 columntype,
362 default: helper.default,
363 unused: helper.unused,
364 lateness: helper.lateness,
365 watermark: helper.watermark,
366 }
367 }
368
369 let helper = FieldHelper::deserialize(deserializer)?;
370 Ok(helper_to_field(helper))
371 }
372}
373
/// Unit(s) of a SQL interval type, carried by `SqlType::Interval`.
///
/// The unit is encoded in the serialized type name, e.g. `DayToSecond`
/// serializes as `INTERVAL_DAY_SECOND`.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum IntervalUnit {
    /// `INTERVAL_DAY`
    Day,
    /// `INTERVAL_DAY_HOUR`
    DayToHour,
    /// `INTERVAL_DAY_MINUTE`
    DayToMinute,
    /// `INTERVAL_DAY_SECOND`
    DayToSecond,
    /// `INTERVAL_HOUR`
    Hour,
    /// `INTERVAL_HOUR_MINUTE`
    HourToMinute,
    /// `INTERVAL_HOUR_SECOND`
    HourToSecond,
    /// `INTERVAL_MINUTE`
    Minute,
    /// `INTERVAL_MINUTE_SECOND`
    MinuteToSecond,
    /// `INTERVAL_MONTH`
    Month,
    /// `INTERVAL_SECOND`
    Second,
    /// `INTERVAL_YEAR`
    Year,
    /// `INTERVAL_YEAR_MONTH`
    YearToMonth,
}
408
/// A SQL data type.
///
/// Serialized as its upper-case SQL name (shown per variant below) by the
/// hand-written `Serialize` impl; the `Deserialize` impl accepts the same
/// names in any letter case.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SqlType {
    /// `BOOLEAN`
    Boolean,
    /// `TINYINT`
    TinyInt,
    /// `SMALLINT`
    SmallInt,
    /// `INTEGER`
    Int,
    /// `BIGINT`
    BigInt,
    /// `UTINYINT`
    UTinyInt,
    /// `USMALLINT`
    USmallInt,
    /// `UINTEGER`
    UInt,
    /// `UBIGINT`
    UBigInt,
    /// `REAL`
    Real,
    /// `DOUBLE`
    Double,
    /// `DECIMAL`; precision and scale live in `ColumnType`.
    Decimal,
    /// `CHAR`
    Char,
    /// `VARCHAR`
    Varchar,
    /// `BINARY`; the width lives in `ColumnType::precision` (see `ColumnType::fixed`).
    Binary,
    /// `VARBINARY`
    Varbinary,
    /// `TIME`
    Time,
    /// `DATE`
    Date,
    /// `TIMESTAMP`
    Timestamp,
    /// `INTERVAL_<UNIT>`, e.g. `INTERVAL_DAY_SECOND`.
    Interval(IntervalUnit),
    /// `ARRAY`; the element type lives in `ColumnType::component`.
    Array,
    /// `STRUCT`; the members live in `ColumnType::fields`.
    Struct,
    /// `MAP`; key and value types live in `ColumnType::key` / `ColumnType::value`.
    Map,
    /// `NULL`
    Null,
    /// `UUID`
    Uuid,
    /// `VARIANT`
    Variant,
}
466
467impl Display for SqlType {
468 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469 f.write_str(&serde_json::to_string(self).unwrap())
470 }
471}
472
473impl<'de> Deserialize<'de> for SqlType {
474 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
475 where
476 D: Deserializer<'de>,
477 {
478 let value: String = Deserialize::deserialize(deserializer)?;
479 match value.to_lowercase().as_str() {
480 "interval_day" => Ok(SqlType::Interval(IntervalUnit::Day)),
481 "interval_day_hour" => Ok(SqlType::Interval(IntervalUnit::DayToHour)),
482 "interval_day_minute" => Ok(SqlType::Interval(IntervalUnit::DayToMinute)),
483 "interval_day_second" => Ok(SqlType::Interval(IntervalUnit::DayToSecond)),
484 "interval_hour" => Ok(SqlType::Interval(IntervalUnit::Hour)),
485 "interval_hour_minute" => Ok(SqlType::Interval(IntervalUnit::HourToMinute)),
486 "interval_hour_second" => Ok(SqlType::Interval(IntervalUnit::HourToSecond)),
487 "interval_minute" => Ok(SqlType::Interval(IntervalUnit::Minute)),
488 "interval_minute_second" => Ok(SqlType::Interval(IntervalUnit::MinuteToSecond)),
489 "interval_month" => Ok(SqlType::Interval(IntervalUnit::Month)),
490 "interval_second" => Ok(SqlType::Interval(IntervalUnit::Second)),
491 "interval_year" => Ok(SqlType::Interval(IntervalUnit::Year)),
492 "interval_year_month" => Ok(SqlType::Interval(IntervalUnit::YearToMonth)),
493 "boolean" => Ok(SqlType::Boolean),
494 "tinyint" => Ok(SqlType::TinyInt),
495 "smallint" => Ok(SqlType::SmallInt),
496 "integer" => Ok(SqlType::Int),
497 "bigint" => Ok(SqlType::BigInt),
498 "utinyint" => Ok(SqlType::UTinyInt),
499 "usmallint" => Ok(SqlType::USmallInt),
500 "uinteger" => Ok(SqlType::UInt),
501 "ubigint" => Ok(SqlType::UBigInt),
502 "real" => Ok(SqlType::Real),
503 "double" => Ok(SqlType::Double),
504 "decimal" => Ok(SqlType::Decimal),
505 "char" => Ok(SqlType::Char),
506 "varchar" => Ok(SqlType::Varchar),
507 "binary" => Ok(SqlType::Binary),
508 "varbinary" => Ok(SqlType::Varbinary),
509 "variant" => Ok(SqlType::Variant),
510 "time" => Ok(SqlType::Time),
511 "date" => Ok(SqlType::Date),
512 "timestamp" => Ok(SqlType::Timestamp),
513 "array" => Ok(SqlType::Array),
514 "struct" => Ok(SqlType::Struct),
515 "map" => Ok(SqlType::Map),
516 "null" => Ok(SqlType::Null),
517 "uuid" => Ok(SqlType::Uuid),
518 _ => Err(serde::de::Error::custom(format!(
519 "Unknown SQL type: {}",
520 value
521 ))),
522 }
523 }
524}
525
526impl Serialize for SqlType {
527 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
528 where
529 S: Serializer,
530 {
531 let type_str = match self {
532 SqlType::Boolean => "BOOLEAN",
533 SqlType::TinyInt => "TINYINT",
534 SqlType::SmallInt => "SMALLINT",
535 SqlType::Int => "INTEGER",
536 SqlType::BigInt => "BIGINT",
537 SqlType::UTinyInt => "UTINYINT",
538 SqlType::USmallInt => "USMALLINT",
539 SqlType::UInt => "UINTEGER",
540 SqlType::UBigInt => "UBIGINT",
541 SqlType::Real => "REAL",
542 SqlType::Double => "DOUBLE",
543 SqlType::Decimal => "DECIMAL",
544 SqlType::Char => "CHAR",
545 SqlType::Varchar => "VARCHAR",
546 SqlType::Binary => "BINARY",
547 SqlType::Varbinary => "VARBINARY",
548 SqlType::Time => "TIME",
549 SqlType::Date => "DATE",
550 SqlType::Timestamp => "TIMESTAMP",
551 SqlType::Interval(interval_unit) => match interval_unit {
552 IntervalUnit::Day => "INTERVAL_DAY",
553 IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
554 IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
555 IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
556 IntervalUnit::Hour => "INTERVAL_HOUR",
557 IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
558 IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
559 IntervalUnit::Minute => "INTERVAL_MINUTE",
560 IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
561 IntervalUnit::Month => "INTERVAL_MONTH",
562 IntervalUnit::Second => "INTERVAL_SECOND",
563 IntervalUnit::Year => "INTERVAL_YEAR",
564 IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
565 },
566 SqlType::Array => "ARRAY",
567 SqlType::Struct => "STRUCT",
568 SqlType::Uuid => "UUID",
569 SqlType::Map => "MAP",
570 SqlType::Null => "NULL",
571 SqlType::Variant => "VARIANT",
572 };
573 serializer.serialize_str(type_str)
574 }
575}
576
577impl SqlType {
578 pub fn is_string(&self) -> bool {
580 matches!(self, Self::Char | Self::Varchar)
581 }
582
583 pub fn is_varchar(&self) -> bool {
584 matches!(self, Self::Varchar)
585 }
586
587 pub fn is_varbinary(&self) -> bool {
588 matches!(self, Self::Varbinary)
589 }
590}
591
/// Serde default for `ColumnType::typ`: a column type whose `type` key is
/// absent is treated as `STRUCT`.
const fn default_is_struct() -> SqlType {
    SqlType::Struct
}

/// The type of a column: an `SqlType` plus nullability and the extra
/// attributes that some types need (see the constructors in `impl ColumnType`).
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ColumnType {
    /// The SQL type; serialized under the key `type`, defaulting to `STRUCT`
    /// when missing.
    #[serde(rename = "type")]
    #[serde(default = "default_is_struct")]
    pub typ: SqlType,
    /// Whether the column is nullable.
    pub nullable: bool,
    /// Precision/width, when the type has one: set by `decimal` (precision)
    /// and `fixed` (binary width); schemas also carry it for e.g. `VARCHAR`,
    /// where the tests use `-1` (presumably "unbounded" — confirm).
    pub precision: Option<i64>,
    /// Scale of a `DECIMAL` (set by `decimal`).
    pub scale: Option<i64>,
    /// Element type of an `ARRAY` (set by `array`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub component: Option<Box<ColumnType>>,
    /// Member fields of a `STRUCT` (set by `structure`).
    #[cfg_attr(feature = "testing", proptest(value = "Some(Vec::new())"))]
    pub fields: Option<Vec<Field>>,
    /// Key type of a `MAP` (set by `map`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub key: Option<Box<ColumnType>>,
    /// Value type of a `MAP` (set by `map`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub value: Option<Box<ColumnType>>,
}
662
663impl ColumnType {
664 pub fn boolean(nullable: bool) -> Self {
665 ColumnType {
666 typ: SqlType::Boolean,
667 nullable,
668 precision: None,
669 scale: None,
670 component: None,
671 fields: None,
672 key: None,
673 value: None,
674 }
675 }
676
677 pub fn uuid(nullable: bool) -> Self {
678 ColumnType {
679 typ: SqlType::Uuid,
680 nullable,
681 precision: None,
682 scale: None,
683 component: None,
684 fields: None,
685 key: None,
686 value: None,
687 }
688 }
689
690 pub fn tinyint(nullable: bool) -> Self {
691 ColumnType {
692 typ: SqlType::TinyInt,
693 nullable,
694 precision: None,
695 scale: None,
696 component: None,
697 fields: None,
698 key: None,
699 value: None,
700 }
701 }
702
703 pub fn smallint(nullable: bool) -> Self {
704 ColumnType {
705 typ: SqlType::SmallInt,
706 nullable,
707 precision: None,
708 scale: None,
709 component: None,
710 fields: None,
711 key: None,
712 value: None,
713 }
714 }
715
716 pub fn int(nullable: bool) -> Self {
717 ColumnType {
718 typ: SqlType::Int,
719 nullable,
720 precision: None,
721 scale: None,
722 component: None,
723 fields: None,
724 key: None,
725 value: None,
726 }
727 }
728
729 pub fn bigint(nullable: bool) -> Self {
730 ColumnType {
731 typ: SqlType::BigInt,
732 nullable,
733 precision: None,
734 scale: None,
735 component: None,
736 fields: None,
737 key: None,
738 value: None,
739 }
740 }
741
742 pub fn utinyint(nullable: bool) -> Self {
743 ColumnType {
744 typ: SqlType::UTinyInt,
745 nullable,
746 precision: None,
747 scale: None,
748 component: None,
749 fields: None,
750 key: None,
751 value: None,
752 }
753 }
754
755 pub fn usmallint(nullable: bool) -> Self {
756 ColumnType {
757 typ: SqlType::USmallInt,
758 nullable,
759 precision: None,
760 scale: None,
761 component: None,
762 fields: None,
763 key: None,
764 value: None,
765 }
766 }
767
768 pub fn uint(nullable: bool) -> Self {
769 ColumnType {
770 typ: SqlType::UInt,
771 nullable,
772 precision: None,
773 scale: None,
774 component: None,
775 fields: None,
776 key: None,
777 value: None,
778 }
779 }
780
781 pub fn ubigint(nullable: bool) -> Self {
782 ColumnType {
783 typ: SqlType::UBigInt,
784 nullable,
785 precision: None,
786 scale: None,
787 component: None,
788 fields: None,
789 key: None,
790 value: None,
791 }
792 }
793
794 pub fn double(nullable: bool) -> Self {
795 ColumnType {
796 typ: SqlType::Double,
797 nullable,
798 precision: None,
799 scale: None,
800 component: None,
801 fields: None,
802 key: None,
803 value: None,
804 }
805 }
806
807 pub fn real(nullable: bool) -> Self {
808 ColumnType {
809 typ: SqlType::Real,
810 nullable,
811 precision: None,
812 scale: None,
813 component: None,
814 fields: None,
815 key: None,
816 value: None,
817 }
818 }
819
820 pub fn decimal(precision: i64, scale: i64, nullable: bool) -> Self {
821 ColumnType {
822 typ: SqlType::Decimal,
823 nullable,
824 precision: Some(precision),
825 scale: Some(scale),
826 component: None,
827 fields: None,
828 key: None,
829 value: None,
830 }
831 }
832
833 pub fn varchar(nullable: bool) -> Self {
834 ColumnType {
835 typ: SqlType::Varchar,
836 nullable,
837 precision: None,
838 scale: None,
839 component: None,
840 fields: None,
841 key: None,
842 value: None,
843 }
844 }
845
846 pub fn varbinary(nullable: bool) -> Self {
847 ColumnType {
848 typ: SqlType::Varbinary,
849 nullable,
850 precision: None,
851 scale: None,
852 component: None,
853 fields: None,
854 key: None,
855 value: None,
856 }
857 }
858
859 pub fn fixed(width: i64, nullable: bool) -> Self {
860 ColumnType {
861 typ: SqlType::Binary,
862 nullable,
863 precision: Some(width),
864 scale: None,
865 component: None,
866 fields: None,
867 key: None,
868 value: None,
869 }
870 }
871
872 pub fn date(nullable: bool) -> Self {
873 ColumnType {
874 typ: SqlType::Date,
875 nullable,
876 precision: None,
877 scale: None,
878 component: None,
879 fields: None,
880 key: None,
881 value: None,
882 }
883 }
884
885 pub fn time(nullable: bool) -> Self {
886 ColumnType {
887 typ: SqlType::Time,
888 nullable,
889 precision: None,
890 scale: None,
891 component: None,
892 fields: None,
893 key: None,
894 value: None,
895 }
896 }
897
898 pub fn timestamp(nullable: bool) -> Self {
899 ColumnType {
900 typ: SqlType::Timestamp,
901 nullable,
902 precision: None,
903 scale: None,
904 component: None,
905 fields: None,
906 key: None,
907 value: None,
908 }
909 }
910
911 pub fn variant(nullable: bool) -> Self {
912 ColumnType {
913 typ: SqlType::Variant,
914 nullable,
915 precision: None,
916 scale: None,
917 component: None,
918 fields: None,
919 key: None,
920 value: None,
921 }
922 }
923
924 pub fn array(nullable: bool, element: ColumnType) -> Self {
925 ColumnType {
926 typ: SqlType::Array,
927 nullable,
928 precision: None,
929 scale: None,
930 component: Some(Box::new(element)),
931 fields: None,
932 key: None,
933 value: None,
934 }
935 }
936
937 pub fn structure(nullable: bool, fields: &[Field]) -> Self {
938 ColumnType {
939 typ: SqlType::Struct,
940 nullable,
941 precision: None,
942 scale: None,
943 component: None,
944 fields: Some(fields.to_vec()),
945 key: None,
946 value: None,
947 }
948 }
949
950 pub fn map(nullable: bool, key: ColumnType, val: ColumnType) -> Self {
951 ColumnType {
952 typ: SqlType::Map,
953 nullable,
954 precision: None,
955 scale: None,
956 component: None,
957 fields: None,
958 key: Some(Box::new(key)),
959 value: Some(Box::new(val)),
960 }
961 }
962
963 pub fn is_integral_type(&self) -> bool {
964 matches!(
965 &self.typ,
966 SqlType::TinyInt
967 | SqlType::SmallInt
968 | SqlType::Int
969 | SqlType::BigInt
970 | SqlType::UTinyInt
971 | SqlType::USmallInt
972 | SqlType::UInt
973 | SqlType::UBigInt
974 )
975 }
976
977 pub fn is_fp_type(&self) -> bool {
978 matches!(&self.typ, SqlType::Double | SqlType::Real)
979 }
980
981 pub fn is_decimal_type(&self) -> bool {
982 matches!(&self.typ, SqlType::Decimal)
983 }
984
985 pub fn is_numeric_type(&self) -> bool {
986 self.is_integral_type() || self.is_fp_type() || self.is_decimal_type()
987 }
988}
989
// Unit tests for SQL type (de)serialization, schema deserialization, and
// `SqlIdentifier` comparison/hashing/naming semantics.
#[cfg(test)]
mod tests {
    use super::{IntervalUnit, SqlIdentifier};
    use crate::program_schema::{ColumnType, Field, SqlType};

    // Every SQL type name must deserialize regardless of letter case, and the
    // serialized form must deserialize back to the same value.
    #[test]
    fn serde_sql_type() {
        for (sql_str_base, expected_value) in [
            ("Boolean", SqlType::Boolean),
            ("Uuid", SqlType::Uuid),
            ("TinyInt", SqlType::TinyInt),
            ("SmallInt", SqlType::SmallInt),
            ("Integer", SqlType::Int),
            ("BigInt", SqlType::BigInt),
            ("UTinyInt", SqlType::UTinyInt),
            ("USmallInt", SqlType::USmallInt),
            ("UInteger", SqlType::UInt),
            ("UBigInt", SqlType::UBigInt),
            ("Real", SqlType::Real),
            ("Double", SqlType::Double),
            ("Decimal", SqlType::Decimal),
            ("Char", SqlType::Char),
            ("Varchar", SqlType::Varchar),
            ("Binary", SqlType::Binary),
            ("Varbinary", SqlType::Varbinary),
            ("Time", SqlType::Time),
            ("Date", SqlType::Date),
            ("Timestamp", SqlType::Timestamp),
            ("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
            (
                "Interval_Day_Hour",
                SqlType::Interval(IntervalUnit::DayToHour),
            ),
            (
                "Interval_Day_Minute",
                SqlType::Interval(IntervalUnit::DayToMinute),
            ),
            (
                "Interval_Day_Second",
                SqlType::Interval(IntervalUnit::DayToSecond),
            ),
            ("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
            (
                "Interval_Hour_Minute",
                SqlType::Interval(IntervalUnit::HourToMinute),
            ),
            (
                "Interval_Hour_Second",
                SqlType::Interval(IntervalUnit::HourToSecond),
            ),
            ("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
            (
                "Interval_Minute_Second",
                SqlType::Interval(IntervalUnit::MinuteToSecond),
            ),
            ("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
            ("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
            ("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
            (
                "Interval_Year_Month",
                SqlType::Interval(IntervalUnit::YearToMonth),
            ),
            ("Array", SqlType::Array),
            ("Struct", SqlType::Struct),
            ("Map", SqlType::Map),
            ("Null", SqlType::Null),
            ("Variant", SqlType::Variant),
        ] {
            // Mixed-case, lower-case, and upper-case spellings must all parse.
            for sql_str in [
                sql_str_base,
                &sql_str_base.to_lowercase(),
                &sql_str_base.to_uppercase(),
            ] {
                let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str))
                    .unwrap_or_else(|_| {
                        panic!("\"{sql_str}\" should deserialize into its SQL type")
                    });
                assert_eq!(value1, expected_value);
                // Round trip through the serialized form.
                let serialized_str =
                    serde_json::to_string(&value1).expect("Value should serialize into JSON");
                let value2: SqlType = serde_json::from_str(&serialized_str).unwrap_or_else(|_| {
                    panic!(
                        "{} should deserialize back into its SQL type",
                        serialized_str
                    )
                });
                assert_eq!(value1, value2);
            }
        }
    }

    // A full schema using every interval type deserializes to the expected
    // `SqlType`s. Unknown keys such as `primary_key` are ignored.
    #[test]
    fn deserialize_interval_types() {
        use super::IntervalUnit::*;
        use super::SqlType::*;

        let schema = r#"
{
 "inputs" : [ {
 "name" : "sales",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "sales_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "customer_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "age",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "UINTEGER",
 "nullable" : true
 }
 }, {
 "name" : "amount",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DECIMAL",
 "nullable" : true,
 "precision" : 10,
 "scale" : 2
 }
 }, {
 "name" : "sale_date",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DATE",
 "nullable" : true
 }
 } ],
 "primary_key" : [ "sales_id" ]
 } ],
 "outputs" : [ {
 "name" : "salessummary",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "customer_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "total_sales",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DECIMAL",
 "nullable" : true,
 "precision" : 38,
 "scale" : 2
 }
 }, {
 "name" : "interval_day",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_hour",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_HOUR",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour_to_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_minute_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MINUTE_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_month",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MONTH",
 "nullable" : false
 }
 }, {
 "name" : "interval_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_year",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_YEAR",
 "nullable" : false
 }
 }, {
 "name" : "interval_year_to_month",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_YEAR_MONTH",
 "nullable" : false
 }
 } ]
 } ]
}
"#;

        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        // Only the output relation's column types are checked.
        let types = schema
            .outputs
            .iter()
            .flat_map(|r| r.fields.iter().map(|f| f.columntype.typ));
        let expected_types = [
            Int,
            Decimal,
            Interval(Day),
            Interval(DayToHour),
            Interval(DayToMinute),
            Interval(DayToSecond),
            Interval(Hour),
            Interval(HourToMinute),
            Interval(HourToSecond),
            Interval(Minute),
            Interval(MinuteToSecond),
            Interval(Month),
            Interval(Second),
            Interval(Year),
            Interval(YearToMonth),
        ];

        assert_eq!(types.collect::<Vec<_>>(), &expected_types);
    }

    // Despite its name, this test checks *deserialization* of nested STRUCT
    // columns in the flattened field layout, including a nested struct
    // (`ADDRESS`) whose `fields` key holds a whole column-type object.
    #[test]
    fn serialize_struct_schemas() {
        let schema = r#"{
 "inputs" : [ {
 "name" : "PERS",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "P0",
 "case_sensitive" : false,
 "columntype" : {
 "fields" : [ {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "FIRSTNAME"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "LASTNAME"
 }, {
 "type" : "UINTEGER",
 "nullable" : true,
 "name" : "AGE"
 }, {
 "fields" : {
 "fields" : [ {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "STREET"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "CITY"
 }, {
 "type" : "CHAR",
 "nullable" : true,
 "precision" : 2,
 "name" : "STATE"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 6,
 "name" : "POSTAL_CODE"
 } ],
 "nullable" : false
 },
 "nullable" : false,
 "name" : "ADDRESS"
 } ],
 "nullable" : false
 }
 }]
 } ],
 "outputs" : [ ]
}
"#;
        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        eprintln!("{:#?}", schema);
        let pers = schema.inputs.iter().find(|r| r.name == "PERS").unwrap();
        let p0 = pers.fields.iter().find(|f| f.name == "P0").unwrap();
        // A column type without `type` defaults to STRUCT.
        assert_eq!(p0.columntype.typ, SqlType::Struct);
        let p0_fields = p0.columntype.fields.as_ref().unwrap();
        assert_eq!(p0_fields[0].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[1].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[2].columntype.typ, SqlType::UInt);
        assert_eq!(p0_fields[3].columntype.typ, SqlType::Struct);
        assert_eq!(p0_fields[3].name, "ADDRESS");
        let address = &p0_fields[3].columntype.fields.as_ref().unwrap();
        assert_eq!(address.len(), 4);
        assert_eq!(address[0].name, "STREET");
        assert_eq!(address[0].columntype.typ, SqlType::Varchar);
        assert_eq!(address[1].columntype.typ, SqlType::Varchar);
        assert_eq!(address[2].columntype.typ, SqlType::Char);
        assert_eq!(address[3].columntype.typ, SqlType::Varchar);
    }

    // Equality: case-insensitive identifiers compare by lowercased spelling;
    // if either side is case-sensitive the exact spelling is compared. Also
    // checks comparison against plain strings via `PartialEq<S>`.
    #[test]
    fn sql_identifier_cmp() {
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("foo"));
        assert_ne!(SqlIdentifier::from("foo"), SqlIdentifier::from("bar"));
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("BAR"));
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("\"foo\""));
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("\"bar\""));
        assert_eq!(SqlIdentifier::from("bAr"), SqlIdentifier::from("\"bAr\""));
        assert_eq!(
            SqlIdentifier::new("bAr", true),
            SqlIdentifier::from("\"bAr\"")
        );

        assert_eq!(SqlIdentifier::from("bAr"), "bar");
        assert_eq!(SqlIdentifier::from("bAr"), "bAr");
    }

    // `BTreeSet` (via `Ord`) treats differently-cased unquoted spellings and
    // quoted lowercase spellings as duplicates.
    #[test]
    fn sql_identifier_ord() {
        let mut btree = std::collections::BTreeSet::new();
        assert!(btree.insert(SqlIdentifier::from("foo")));
        assert!(btree.insert(SqlIdentifier::from("bar")));
        assert!(!btree.insert(SqlIdentifier::from("BAR")));
        assert!(!btree.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!btree.insert(SqlIdentifier::from("\"bar\"")));
    }

    // Same duplicate detection as above, via `Hash` + `Eq` in a `HashSet`.
    #[test]
    fn sql_identifier_hash() {
        let mut hs = std::collections::HashSet::new();
        assert!(hs.insert(SqlIdentifier::from("foo")));
        assert!(hs.insert(SqlIdentifier::from("bar")));
        assert!(!hs.insert(SqlIdentifier::from("BAR")));
        assert!(!hs.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!hs.insert(SqlIdentifier::from("\"bar\"")));
    }

    // `name()` lowercases case-insensitive identifiers; `sql_name()` keeps the
    // original spelling and re-quotes case-sensitive ones.
    #[test]
    fn sql_identifier_name() {
        assert_eq!(SqlIdentifier::from("foo").name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").name(), "bar");
        assert_eq!(SqlIdentifier::from("\"bAr\"").name(), "bAr");
        assert_eq!(SqlIdentifier::from("foo").sql_name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").sql_name(), "bAr");
        assert_eq!(SqlIdentifier::from("\"bAr\"").sql_name(), "\"bAr\"");
    }

    // Regression test (issue 3277): a MAP member of a struct, given in the
    // flattened layout, keeps its key and value types when deserialized.
    #[test]
    fn issue3277() {
        let schema = r#"{
 "name" : "j",
 "case_sensitive" : false,
 "columntype" : {
 "fields" : [ {
 "key" : {
 "nullable" : false,
 "precision" : -1,
 "type" : "VARCHAR"
 },
 "name" : "s",
 "nullable" : true,
 "type" : "MAP",
 "value" : {
 "nullable" : true,
 "precision" : -1,
 "type" : "VARCHAR"
 }
 } ],
 "nullable" : true
 }
 }"#;
        let field: Field = serde_json::from_str(schema).unwrap();
        println!("field: {:#?}", field);
        assert_eq!(
            field,
            Field {
                name: SqlIdentifier {
                    name: "j".to_string(),
                    case_sensitive: false,
                },
                columntype: ColumnType {
                    typ: SqlType::Struct,
                    nullable: true,
                    precision: None,
                    scale: None,
                    component: None,
                    fields: Some(vec![Field {
                        name: SqlIdentifier {
                            name: "s".to_string(),
                            case_sensitive: false,
                        },
                        columntype: ColumnType {
                            typ: SqlType::Map,
                            nullable: true,
                            precision: None,
                            scale: None,
                            component: None,
                            fields: None,
                            key: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: false,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                            value: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: true,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                        },
                        lateness: None,
                        default: None,
                        unused: false,
                        watermark: None,
                    }]),
                    key: None,
                    value: None,
                },
                lateness: None,
                default: None,
                unused: false,
                watermark: None,
            }
        );
    }
}