1pub use feldera_ir::SourcePosition;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3use std::cmp::Ordering;
4use std::collections::BTreeMap;
5use std::fmt::Display;
6use std::hash::{Hash, Hasher};
7use utoipa::ToSchema;
8
9#[cfg(feature = "testing")]
10use proptest::{collection::vec, prelude::any};
11
/// Normalizes the text of a SQL identifier.
///
/// Text wrapped in double quotes (at least two characters, e.g. `"Foo"`) is
/// returned verbatim with the surrounding quotes removed.  Any other text,
/// including a lone `"`, is folded to lowercase.  Escaped quotes inside a
/// quoted identifier (`""`) are not unescaped.
pub fn canonical_identifier(id: &str) -> String {
    match id.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(quoted) => quoted.to_owned(),
        None => id.to_lowercase(),
    }
}
26
/// A SQL identifier together with its case-sensitivity flag.
///
/// Use [`SqlIdentifier::name`] for the canonical form (verbatim when case
/// sensitive, lowercased otherwise) and [`SqlIdentifier::sql_name`] for the
/// SQL spelling.  The `From<S: AsRef<str>>` impl marks double-quoted input as
/// case sensitive; see the `PartialEq` impl for how mixed comparisons behave.
#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SqlIdentifier {
    /// Identifier text as stored: verbatim when built with `new`, or with the
    /// surrounding double quotes removed when parsed via `From`.  Letter case
    /// is never folded here.
    #[cfg_attr(feature = "testing", proptest(regex = "relation1|relation2|relation3"))]
    name: String,
    /// When `true`, [`SqlIdentifier::name`] returns the text verbatim; when
    /// `false`, it returns the lowercased text.
    pub case_sensitive: bool,
}
38
39impl SqlIdentifier {
40 pub fn new<S: AsRef<str>>(name: S, case_sensitive: bool) -> Self {
41 Self {
42 name: name.as_ref().to_string(),
43 case_sensitive,
44 }
45 }
46
47 pub fn name(&self) -> String {
57 if self.case_sensitive {
58 self.name.clone()
59 } else {
60 self.name.to_lowercase()
61 }
62 }
63
64 pub fn sql_name(&self) -> String {
75 if self.case_sensitive {
76 format!("\"{}\"", self.name)
77 } else {
78 self.name.clone()
79 }
80 }
81}
82
83impl Hash for SqlIdentifier {
84 fn hash<H: Hasher>(&self, state: &mut H) {
85 self.name().hash(state);
86 }
87}
88
89impl PartialEq for SqlIdentifier {
90 fn eq(&self, other: &Self) -> bool {
91 match (self.case_sensitive, other.case_sensitive) {
92 (true, true) => self.name == other.name,
93 (false, false) => self.name.to_lowercase() == other.name.to_lowercase(),
94 (true, false) => self.name == other.name,
95 (false, true) => self.name == other.name,
96 }
97 }
98}
99
100impl Ord for SqlIdentifier {
101 fn cmp(&self, other: &Self) -> Ordering {
102 self.name().cmp(&other.name())
103 }
104}
105
106impl PartialOrd for SqlIdentifier {
107 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108 Some(self.cmp(other))
109 }
110}
111
112impl<S: AsRef<str>> PartialEq<S> for SqlIdentifier {
113 fn eq(&self, other: &S) -> bool {
114 self == &SqlIdentifier::from(other.as_ref())
115 }
116}
117
// Marks the custom `PartialEq` impl for `SqlIdentifier` as a full equivalence.
// NOTE(review): that `PartialEq` is not transitive when case-sensitive and
// case-insensitive identifiers are mixed: insensitive `foo` == insensitive `FOO`
// == sensitive `FOO`, yet insensitive `foo` != sensitive `FOO`.  The `Eq`
// contract therefore only fully holds among identifiers of the same sensitivity.
impl Eq for SqlIdentifier {}
119
120impl<S: AsRef<str>> From<S> for SqlIdentifier {
121 fn from(name: S) -> Self {
122 if name.as_ref().starts_with('"')
123 && name.as_ref().ends_with('"')
124 && name.as_ref().len() >= 2
125 {
126 Self {
127 name: name.as_ref()[1..name.as_ref().len() - 1].to_string(),
128 case_sensitive: true,
129 }
130 } else {
131 Self::new(name, false)
132 }
133 }
134}
135
136impl From<SqlIdentifier> for String {
137 fn from(id: SqlIdentifier) -> String {
138 id.name()
139 }
140}
141
142impl From<&SqlIdentifier> for String {
143 fn from(id: &SqlIdentifier) -> String {
144 id.name()
145 }
146}
147
148impl Display for SqlIdentifier {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 write!(f, "{}", self.name())
151 }
152}
153
/// Schema of a program: its input and output relations.
#[derive(Default, Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ProgramSchema {
    /// Input relations of the program.
    // Property tests generate zero or one relation here (range `0..2`).
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub inputs: Vec<Relation>,
    /// Output relations of the program.
    // Property tests generate zero or one relation here (range `0..2`).
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub outputs: Vec<Relation>,
}
171
172impl ProgramSchema {
173 pub fn relations_with_lateness(&self) -> Vec<SqlIdentifier> {
174 self.inputs
175 .iter()
176 .chain(self.outputs.iter())
177 .filter(|rel| rel.has_lateness())
178 .map(|rel| rel.name.clone())
179 .collect()
180 }
181}
182
/// Value of a relation property (see `Relation::properties`), together with
/// the source positions of its key and value.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct PropertyValue {
    /// The property value as text.
    pub value: String,
    /// Source position of the property key.
    pub key_position: SourcePosition,
    /// Source position of the property value.
    pub value_position: SourcePosition,
}
190
/// A relation of a program (an entry of `ProgramSchema::inputs` or
/// `ProgramSchema::outputs`): its name, columns, and properties.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Relation {
    /// Relation name.  Flattened, so its `name` and `case_sensitive` keys
    /// appear at the top level of the serialized relation.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Columns of the relation.
    // Property tests always generate an empty field list.
    #[cfg_attr(feature = "testing", proptest(value = "Vec::new()"))]
    pub fields: Vec<Field>,
    /// Whether the relation is materialized; `false` when the key is absent
    /// during deserialization.
    #[serde(default)]
    pub materialized: bool,
    /// Properties keyed by name; empty when the key is absent during
    /// deserialization.
    #[serde(default)]
    pub properties: BTreeMap<String, PropertyValue>,
}
206
207impl Relation {
208 pub fn empty() -> Self {
209 Self {
210 name: SqlIdentifier::from("".to_string()),
211 fields: Vec::new(),
212 materialized: false,
213 properties: BTreeMap::new(),
214 }
215 }
216
217 pub fn new(
218 name: SqlIdentifier,
219 fields: Vec<Field>,
220 materialized: bool,
221 properties: BTreeMap<String, PropertyValue>,
222 ) -> Self {
223 Self {
224 name,
225 fields,
226 materialized,
227 properties,
228 }
229 }
230
231 pub fn field(&self, name: &str) -> Option<&Field> {
233 let name = canonical_identifier(name);
234 self.fields.iter().find(|f| f.name == name)
235 }
236
237 pub fn has_lateness(&self) -> bool {
238 self.fields.iter().any(|f| f.lateness.is_some())
239 }
240
241 pub fn get_property(&self, name: &str) -> Option<&str> {
242 self.properties.get(name).map(|p| p.value.as_str())
243 }
244}
245
/// A column of a relation, or a member of a `STRUCT` column type
/// (see `ColumnType::fields`).
///
/// Only `Serialize` is derived.  `Deserialize` is implemented by hand so it
/// accepts both the `columntype` form and the inline struct-member form.
#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Field {
    /// Column name.  Flattened, so `name` and `case_sensitive` appear as
    /// top-level keys of the serialized field.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Type of the column.
    pub columntype: ColumnType,
    /// Lateness annotation text, if any (see `Relation::has_lateness`).
    pub lateness: Option<String>,
    /// Default value, as text, if any.
    pub default: Option<String>,
    /// `unused` flag; `false` unless set (e.g. via `Field::with_unused`).
    pub unused: bool,
    /// Watermark annotation text, if any.
    pub watermark: Option<String>,
}
260
261impl Field {
262 pub fn new(name: SqlIdentifier, columntype: ColumnType) -> Self {
263 Self {
264 name,
265 columntype,
266 lateness: None,
267 default: None,
268 unused: false,
269 watermark: None,
270 }
271 }
272
273 pub fn with_lateness(mut self, lateness: &str) -> Self {
274 self.lateness = Some(lateness.to_string());
275 self
276 }
277
278 pub fn with_unused(mut self, unused: bool) -> Self {
279 self.unused = unused;
280 self
281 }
282}
283
284impl<'de> Deserialize<'de> for Field {
289 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
290 where
291 D: Deserializer<'de>,
292 {
293 const fn default_is_struct() -> Option<SqlType> {
294 Some(SqlType::Struct)
295 }
296
297 #[derive(Debug, Clone, Deserialize)]
298 struct FieldHelper {
299 name: Option<String>,
300 #[serde(default)]
301 case_sensitive: bool,
302 columntype: Option<ColumnType>,
303 #[serde(rename = "type")]
304 #[serde(default = "default_is_struct")]
305 typ: Option<SqlType>,
306 nullable: Option<bool>,
307 precision: Option<i64>,
308 scale: Option<i64>,
309 component: Option<Box<ColumnType>>,
310 fields: Option<serde_json::Value>,
311 key: Option<Box<ColumnType>>,
312 value: Option<Box<ColumnType>>,
313 default: Option<String>,
314 #[serde(default)]
315 unused: bool,
316 lateness: Option<String>,
317 watermark: Option<String>,
318 }
319
320 fn helper_to_field(helper: FieldHelper) -> Field {
321 let columntype = if let Some(ctype) = helper.columntype {
322 ctype
323 } else if let Some(serde_json::Value::Array(fields)) = helper.fields {
324 let fields = fields
325 .into_iter()
326 .map(|field| {
327 let field: FieldHelper = serde_json::from_value(field).unwrap();
328 helper_to_field(field)
329 })
330 .collect::<Vec<Field>>();
331
332 ColumnType {
333 typ: helper.typ.unwrap_or(SqlType::Null),
334 nullable: helper.nullable.unwrap_or(false),
335 precision: helper.precision,
336 scale: helper.scale,
337 component: helper.component,
338 fields: Some(fields),
339 key: None,
340 value: None,
341 }
342 } else if let Some(serde_json::Value::Object(obj)) = helper.fields {
343 serde_json::from_value(serde_json::Value::Object(obj))
344 .expect("Failed to deserialize object")
345 } else {
346 ColumnType {
347 typ: helper.typ.unwrap_or(SqlType::Null),
348 nullable: helper.nullable.unwrap_or(false),
349 precision: helper.precision,
350 scale: helper.scale,
351 component: helper.component,
352 fields: None,
353 key: helper.key,
354 value: helper.value,
355 }
356 };
357
358 Field {
359 name: SqlIdentifier::new(helper.name.unwrap(), helper.case_sensitive),
360 columntype,
361 default: helper.default,
362 unused: helper.unused,
363 lateness: helper.lateness,
364 watermark: helper.watermark,
365 }
366 }
367
368 let helper = FieldHelper::deserialize(deserializer)?;
369 Ok(helper_to_field(helper))
370 }
371}
372
/// Unit, or unit range, of a SQL interval type (see `SqlType::Interval`).
///
/// Each variant's doc gives the name that `SqlType::Interval(variant)`
/// serializes to.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum IntervalUnit {
    /// `INTERVAL_DAY`
    Day,
    /// `INTERVAL_DAY_HOUR`
    DayToHour,
    /// `INTERVAL_DAY_MINUTE`
    DayToMinute,
    /// `INTERVAL_DAY_SECOND`
    DayToSecond,
    /// `INTERVAL_HOUR`
    Hour,
    /// `INTERVAL_HOUR_MINUTE`
    HourToMinute,
    /// `INTERVAL_HOUR_SECOND`
    HourToSecond,
    /// `INTERVAL_MINUTE`
    Minute,
    /// `INTERVAL_MINUTE_SECOND`
    MinuteToSecond,
    /// `INTERVAL_MONTH`
    Month,
    /// `INTERVAL_SECOND`
    Second,
    /// `INTERVAL_YEAR`
    Year,
    /// `INTERVAL_YEAR_MONTH`
    YearToMonth,
}
407
/// A SQL type.
///
/// Serialized as the upper-case name given on each variant, and deserialized
/// from that name in any letter case.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SqlType {
    /// `BOOLEAN`
    Boolean,
    /// `TINYINT`
    TinyInt,
    /// `SMALLINT`
    SmallInt,
    /// `INTEGER`
    Int,
    /// `BIGINT`
    BigInt,
    /// `UTINYINT`
    UTinyInt,
    /// `USMALLINT`
    USmallInt,
    /// `UINTEGER`
    UInt,
    /// `UBIGINT`
    UBigInt,
    /// `REAL`
    Real,
    /// `DOUBLE`
    Double,
    /// `DECIMAL`; precision and scale live in `ColumnType`.
    Decimal,
    /// `CHAR`
    Char,
    /// `VARCHAR`
    Varchar,
    /// `BINARY`; the width is stored as `ColumnType::precision`.
    Binary,
    /// `VARBINARY`
    Varbinary,
    /// `TIME`
    Time,
    /// `DATE`
    Date,
    /// `TIMESTAMP`
    Timestamp,
    /// `INTERVAL_*`, with the suffix chosen by the [`IntervalUnit`].
    Interval(IntervalUnit),
    /// `ARRAY`; the element type is `ColumnType::component`.
    Array,
    /// `STRUCT`; the members are `ColumnType::fields`.  Also the default when a
    /// column type omits `type`.
    Struct,
    /// `MAP`; key and value types are `ColumnType::key` / `ColumnType::value`.
    Map,
    /// `NULL`
    Null,
    /// `UUID`
    Uuid,
    /// `VARIANT`
    Variant,
}
465
466impl Display for SqlType {
467 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
468 f.write_str(&serde_json::to_string(self).unwrap())
469 }
470}
471
472impl<'de> Deserialize<'de> for SqlType {
473 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
474 where
475 D: Deserializer<'de>,
476 {
477 let value: String = Deserialize::deserialize(deserializer)?;
478 match value.to_lowercase().as_str() {
479 "interval_day" => Ok(SqlType::Interval(IntervalUnit::Day)),
480 "interval_day_hour" => Ok(SqlType::Interval(IntervalUnit::DayToHour)),
481 "interval_day_minute" => Ok(SqlType::Interval(IntervalUnit::DayToMinute)),
482 "interval_day_second" => Ok(SqlType::Interval(IntervalUnit::DayToSecond)),
483 "interval_hour" => Ok(SqlType::Interval(IntervalUnit::Hour)),
484 "interval_hour_minute" => Ok(SqlType::Interval(IntervalUnit::HourToMinute)),
485 "interval_hour_second" => Ok(SqlType::Interval(IntervalUnit::HourToSecond)),
486 "interval_minute" => Ok(SqlType::Interval(IntervalUnit::Minute)),
487 "interval_minute_second" => Ok(SqlType::Interval(IntervalUnit::MinuteToSecond)),
488 "interval_month" => Ok(SqlType::Interval(IntervalUnit::Month)),
489 "interval_second" => Ok(SqlType::Interval(IntervalUnit::Second)),
490 "interval_year" => Ok(SqlType::Interval(IntervalUnit::Year)),
491 "interval_year_month" => Ok(SqlType::Interval(IntervalUnit::YearToMonth)),
492 "boolean" => Ok(SqlType::Boolean),
493 "tinyint" => Ok(SqlType::TinyInt),
494 "smallint" => Ok(SqlType::SmallInt),
495 "integer" => Ok(SqlType::Int),
496 "bigint" => Ok(SqlType::BigInt),
497 "utinyint" => Ok(SqlType::UTinyInt),
498 "usmallint" => Ok(SqlType::USmallInt),
499 "uinteger" => Ok(SqlType::UInt),
500 "ubigint" => Ok(SqlType::UBigInt),
501 "real" => Ok(SqlType::Real),
502 "double" => Ok(SqlType::Double),
503 "decimal" => Ok(SqlType::Decimal),
504 "char" => Ok(SqlType::Char),
505 "varchar" => Ok(SqlType::Varchar),
506 "binary" => Ok(SqlType::Binary),
507 "varbinary" => Ok(SqlType::Varbinary),
508 "variant" => Ok(SqlType::Variant),
509 "time" => Ok(SqlType::Time),
510 "date" => Ok(SqlType::Date),
511 "timestamp" => Ok(SqlType::Timestamp),
512 "array" => Ok(SqlType::Array),
513 "struct" => Ok(SqlType::Struct),
514 "map" => Ok(SqlType::Map),
515 "null" => Ok(SqlType::Null),
516 "uuid" => Ok(SqlType::Uuid),
517 _ => Err(serde::de::Error::custom(format!(
518 "Unknown SQL type: {}",
519 value
520 ))),
521 }
522 }
523}
524
525impl Serialize for SqlType {
526 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
527 where
528 S: Serializer,
529 {
530 let type_str = match self {
531 SqlType::Boolean => "BOOLEAN",
532 SqlType::TinyInt => "TINYINT",
533 SqlType::SmallInt => "SMALLINT",
534 SqlType::Int => "INTEGER",
535 SqlType::BigInt => "BIGINT",
536 SqlType::UTinyInt => "UTINYINT",
537 SqlType::USmallInt => "USMALLINT",
538 SqlType::UInt => "UINTEGER",
539 SqlType::UBigInt => "UBIGINT",
540 SqlType::Real => "REAL",
541 SqlType::Double => "DOUBLE",
542 SqlType::Decimal => "DECIMAL",
543 SqlType::Char => "CHAR",
544 SqlType::Varchar => "VARCHAR",
545 SqlType::Binary => "BINARY",
546 SqlType::Varbinary => "VARBINARY",
547 SqlType::Time => "TIME",
548 SqlType::Date => "DATE",
549 SqlType::Timestamp => "TIMESTAMP",
550 SqlType::Interval(interval_unit) => match interval_unit {
551 IntervalUnit::Day => "INTERVAL_DAY",
552 IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
553 IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
554 IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
555 IntervalUnit::Hour => "INTERVAL_HOUR",
556 IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
557 IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
558 IntervalUnit::Minute => "INTERVAL_MINUTE",
559 IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
560 IntervalUnit::Month => "INTERVAL_MONTH",
561 IntervalUnit::Second => "INTERVAL_SECOND",
562 IntervalUnit::Year => "INTERVAL_YEAR",
563 IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
564 },
565 SqlType::Array => "ARRAY",
566 SqlType::Struct => "STRUCT",
567 SqlType::Uuid => "UUID",
568 SqlType::Map => "MAP",
569 SqlType::Null => "NULL",
570 SqlType::Variant => "VARIANT",
571 };
572 serializer.serialize_str(type_str)
573 }
574}
575
576impl SqlType {
577 pub fn is_string(&self) -> bool {
579 matches!(self, Self::Char | Self::Varchar)
580 }
581
582 pub fn is_varchar(&self) -> bool {
583 matches!(self, Self::Varchar)
584 }
585
586 pub fn is_varbinary(&self) -> bool {
587 matches!(self, Self::Varbinary)
588 }
589}
590
/// Serde default for `ColumnType::typ`: a column type that omits `type` is
/// treated as `STRUCT`.
const fn default_is_struct() -> SqlType {
    SqlType::Struct
}
596
/// The type of a column: a [`SqlType`] plus the attributes some types need.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ColumnType {
    /// The SQL type.  Serialized under the key `type`; defaults to `STRUCT`
    /// when the key is absent.
    #[serde(rename = "type")]
    #[serde(default = "default_is_struct")]
    pub typ: SqlType,
    /// Whether the column is nullable.
    pub nullable: bool,
    /// Precision, where applicable: e.g. the digits of a `DECIMAL`
    /// (`ColumnType::decimal`) or the width of a `BINARY` (`ColumnType::fixed`).
    /// NOTE(review): test data uses `-1` with `VARCHAR`, presumably meaning
    /// "unbounded"; confirm.
    pub precision: Option<i64>,
    /// Scale of a `DECIMAL` (`ColumnType::decimal`).
    pub scale: Option<i64>,
    /// Element type of an `ARRAY` (`ColumnType::array`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub component: Option<Box<ColumnType>>,
    /// Member fields of a `STRUCT` (`ColumnType::structure`).
    #[cfg_attr(feature = "testing", proptest(value = "Some(Vec::new())"))]
    pub fields: Option<Vec<Field>>,
    /// Key type of a `MAP` (`ColumnType::map`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub key: Option<Box<ColumnType>>,
    /// Value type of a `MAP` (`ColumnType::map`).
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub value: Option<Box<ColumnType>>,
}
661
662impl ColumnType {
663 pub fn boolean(nullable: bool) -> Self {
664 ColumnType {
665 typ: SqlType::Boolean,
666 nullable,
667 precision: None,
668 scale: None,
669 component: None,
670 fields: None,
671 key: None,
672 value: None,
673 }
674 }
675
676 pub fn uuid(nullable: bool) -> Self {
677 ColumnType {
678 typ: SqlType::Uuid,
679 nullable,
680 precision: None,
681 scale: None,
682 component: None,
683 fields: None,
684 key: None,
685 value: None,
686 }
687 }
688
689 pub fn tinyint(nullable: bool) -> Self {
690 ColumnType {
691 typ: SqlType::TinyInt,
692 nullable,
693 precision: None,
694 scale: None,
695 component: None,
696 fields: None,
697 key: None,
698 value: None,
699 }
700 }
701
702 pub fn smallint(nullable: bool) -> Self {
703 ColumnType {
704 typ: SqlType::SmallInt,
705 nullable,
706 precision: None,
707 scale: None,
708 component: None,
709 fields: None,
710 key: None,
711 value: None,
712 }
713 }
714
715 pub fn int(nullable: bool) -> Self {
716 ColumnType {
717 typ: SqlType::Int,
718 nullable,
719 precision: None,
720 scale: None,
721 component: None,
722 fields: None,
723 key: None,
724 value: None,
725 }
726 }
727
728 pub fn bigint(nullable: bool) -> Self {
729 ColumnType {
730 typ: SqlType::BigInt,
731 nullable,
732 precision: None,
733 scale: None,
734 component: None,
735 fields: None,
736 key: None,
737 value: None,
738 }
739 }
740
741 pub fn utinyint(nullable: bool) -> Self {
742 ColumnType {
743 typ: SqlType::UTinyInt,
744 nullable,
745 precision: None,
746 scale: None,
747 component: None,
748 fields: None,
749 key: None,
750 value: None,
751 }
752 }
753
754 pub fn usmallint(nullable: bool) -> Self {
755 ColumnType {
756 typ: SqlType::USmallInt,
757 nullable,
758 precision: None,
759 scale: None,
760 component: None,
761 fields: None,
762 key: None,
763 value: None,
764 }
765 }
766
767 pub fn uint(nullable: bool) -> Self {
768 ColumnType {
769 typ: SqlType::UInt,
770 nullable,
771 precision: None,
772 scale: None,
773 component: None,
774 fields: None,
775 key: None,
776 value: None,
777 }
778 }
779
780 pub fn ubigint(nullable: bool) -> Self {
781 ColumnType {
782 typ: SqlType::UBigInt,
783 nullable,
784 precision: None,
785 scale: None,
786 component: None,
787 fields: None,
788 key: None,
789 value: None,
790 }
791 }
792
793 pub fn double(nullable: bool) -> Self {
794 ColumnType {
795 typ: SqlType::Double,
796 nullable,
797 precision: None,
798 scale: None,
799 component: None,
800 fields: None,
801 key: None,
802 value: None,
803 }
804 }
805
806 pub fn real(nullable: bool) -> Self {
807 ColumnType {
808 typ: SqlType::Real,
809 nullable,
810 precision: None,
811 scale: None,
812 component: None,
813 fields: None,
814 key: None,
815 value: None,
816 }
817 }
818
819 pub fn decimal(precision: i64, scale: i64, nullable: bool) -> Self {
820 ColumnType {
821 typ: SqlType::Decimal,
822 nullable,
823 precision: Some(precision),
824 scale: Some(scale),
825 component: None,
826 fields: None,
827 key: None,
828 value: None,
829 }
830 }
831
832 pub fn varchar(nullable: bool) -> Self {
833 ColumnType {
834 typ: SqlType::Varchar,
835 nullable,
836 precision: None,
837 scale: None,
838 component: None,
839 fields: None,
840 key: None,
841 value: None,
842 }
843 }
844
845 pub fn varbinary(nullable: bool) -> Self {
846 ColumnType {
847 typ: SqlType::Varbinary,
848 nullable,
849 precision: None,
850 scale: None,
851 component: None,
852 fields: None,
853 key: None,
854 value: None,
855 }
856 }
857
858 pub fn fixed(width: i64, nullable: bool) -> Self {
859 ColumnType {
860 typ: SqlType::Binary,
861 nullable,
862 precision: Some(width),
863 scale: None,
864 component: None,
865 fields: None,
866 key: None,
867 value: None,
868 }
869 }
870
871 pub fn date(nullable: bool) -> Self {
872 ColumnType {
873 typ: SqlType::Date,
874 nullable,
875 precision: None,
876 scale: None,
877 component: None,
878 fields: None,
879 key: None,
880 value: None,
881 }
882 }
883
884 pub fn time(nullable: bool) -> Self {
885 ColumnType {
886 typ: SqlType::Time,
887 nullable,
888 precision: None,
889 scale: None,
890 component: None,
891 fields: None,
892 key: None,
893 value: None,
894 }
895 }
896
897 pub fn timestamp(nullable: bool) -> Self {
898 ColumnType {
899 typ: SqlType::Timestamp,
900 nullable,
901 precision: None,
902 scale: None,
903 component: None,
904 fields: None,
905 key: None,
906 value: None,
907 }
908 }
909
910 pub fn variant(nullable: bool) -> Self {
911 ColumnType {
912 typ: SqlType::Variant,
913 nullable,
914 precision: None,
915 scale: None,
916 component: None,
917 fields: None,
918 key: None,
919 value: None,
920 }
921 }
922
923 pub fn array(nullable: bool, element: ColumnType) -> Self {
924 ColumnType {
925 typ: SqlType::Array,
926 nullable,
927 precision: None,
928 scale: None,
929 component: Some(Box::new(element)),
930 fields: None,
931 key: None,
932 value: None,
933 }
934 }
935
936 pub fn structure(nullable: bool, fields: &[Field]) -> Self {
937 ColumnType {
938 typ: SqlType::Struct,
939 nullable,
940 precision: None,
941 scale: None,
942 component: None,
943 fields: Some(fields.to_vec()),
944 key: None,
945 value: None,
946 }
947 }
948
949 pub fn map(nullable: bool, key: ColumnType, val: ColumnType) -> Self {
950 ColumnType {
951 typ: SqlType::Map,
952 nullable,
953 precision: None,
954 scale: None,
955 component: None,
956 fields: None,
957 key: Some(Box::new(key)),
958 value: Some(Box::new(val)),
959 }
960 }
961
962 pub fn is_integral_type(&self) -> bool {
963 matches!(
964 &self.typ,
965 SqlType::TinyInt
966 | SqlType::SmallInt
967 | SqlType::Int
968 | SqlType::BigInt
969 | SqlType::UTinyInt
970 | SqlType::USmallInt
971 | SqlType::UInt
972 | SqlType::UBigInt
973 )
974 }
975
976 pub fn is_fp_type(&self) -> bool {
977 matches!(&self.typ, SqlType::Double | SqlType::Real)
978 }
979
980 pub fn is_decimal_type(&self) -> bool {
981 matches!(&self.typ, SqlType::Decimal)
982 }
983
984 pub fn is_numeric_type(&self) -> bool {
985 self.is_integral_type() || self.is_fp_type() || self.is_decimal_type()
986 }
987}
988
// Unit tests for SQL identifier handling and (de)serialization of the schema types.
#[cfg(test)]
mod tests {
    use super::{IntervalUnit, SqlIdentifier};
    use crate::program_schema::{ColumnType, Field, SqlType};

    /// Every SQL type name deserializes in its mixed-, lower-, and upper-case
    /// spellings, and serializing then deserializing again yields the same value.
    #[test]
    fn serde_sql_type() {
        for (sql_str_base, expected_value) in [
            ("Boolean", SqlType::Boolean),
            ("Uuid", SqlType::Uuid),
            ("TinyInt", SqlType::TinyInt),
            ("SmallInt", SqlType::SmallInt),
            ("Integer", SqlType::Int),
            ("BigInt", SqlType::BigInt),
            ("UTinyInt", SqlType::UTinyInt),
            ("USmallInt", SqlType::USmallInt),
            ("UInteger", SqlType::UInt),
            ("UBigInt", SqlType::UBigInt),
            ("Real", SqlType::Real),
            ("Double", SqlType::Double),
            ("Decimal", SqlType::Decimal),
            ("Char", SqlType::Char),
            ("Varchar", SqlType::Varchar),
            ("Binary", SqlType::Binary),
            ("Varbinary", SqlType::Varbinary),
            ("Time", SqlType::Time),
            ("Date", SqlType::Date),
            ("Timestamp", SqlType::Timestamp),
            ("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
            (
                "Interval_Day_Hour",
                SqlType::Interval(IntervalUnit::DayToHour),
            ),
            (
                "Interval_Day_Minute",
                SqlType::Interval(IntervalUnit::DayToMinute),
            ),
            (
                "Interval_Day_Second",
                SqlType::Interval(IntervalUnit::DayToSecond),
            ),
            ("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
            (
                "Interval_Hour_Minute",
                SqlType::Interval(IntervalUnit::HourToMinute),
            ),
            (
                "Interval_Hour_Second",
                SqlType::Interval(IntervalUnit::HourToSecond),
            ),
            ("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
            (
                "Interval_Minute_Second",
                SqlType::Interval(IntervalUnit::MinuteToSecond),
            ),
            ("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
            ("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
            ("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
            (
                "Interval_Year_Month",
                SqlType::Interval(IntervalUnit::YearToMonth),
            ),
            ("Array", SqlType::Array),
            ("Struct", SqlType::Struct),
            ("Map", SqlType::Map),
            ("Null", SqlType::Null),
            ("Variant", SqlType::Variant),
        ] {
            // The spelling as written, plus its all-lowercase and all-uppercase forms.
            for sql_str in [
                sql_str_base, &sql_str_base.to_lowercase(), &sql_str_base.to_uppercase(), ] {
                let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str))
                    .unwrap_or_else(|_| {
                        panic!("\"{sql_str}\" should deserialize into its SQL type")
                    });
                assert_eq!(value1, expected_value);
                // Round trip: serialize, then deserialize the result again.
                let serialized_str =
                    serde_json::to_string(&value1).expect("Value should serialize into JSON");
                let value2: SqlType = serde_json::from_str(&serialized_str).unwrap_or_else(|_| {
                    panic!(
                        "{} should deserialize back into its SQL type",
                        serialized_str
                    )
                });
                assert_eq!(value1, value2);
            }
        }
    }

    /// A program schema whose output relation uses every interval type
    /// deserializes, and the output columns keep their types in order.
    #[test]
    fn deserialize_interval_types() {
        use super::IntervalUnit::*;
        use super::SqlType::*;

        let schema = r#"
{
 "inputs" : [ {
 "name" : "sales",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "sales_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "customer_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "age",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "UINTEGER",
 "nullable" : true
 }
 }, {
 "name" : "amount",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DECIMAL",
 "nullable" : true,
 "precision" : 10,
 "scale" : 2
 }
 }, {
 "name" : "sale_date",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DATE",
 "nullable" : true
 }
 } ],
 "primary_key" : [ "sales_id" ]
 } ],
 "outputs" : [ {
 "name" : "salessummary",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "customer_id",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTEGER",
 "nullable" : true
 }
 }, {
 "name" : "total_sales",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "DECIMAL",
 "nullable" : true,
 "precision" : 38,
 "scale" : 2
 }
 }, {
 "name" : "interval_day",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_hour",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_HOUR",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_day_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_DAY_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour_to_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_hour_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_HOUR_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_minute",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MINUTE",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_minute_to_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MINUTE_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_month",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_MONTH",
 "nullable" : false
 }
 }, {
 "name" : "interval_second",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_SECOND",
 "nullable" : false,
 "precision" : 2,
 "scale" : 6
 }
 }, {
 "name" : "interval_year",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_YEAR",
 "nullable" : false
 }
 }, {
 "name" : "interval_year_to_month",
 "case_sensitive" : false,
 "columntype" : {
 "type" : "INTERVAL_YEAR_MONTH",
 "nullable" : false
 }
 } ]
 } ]
}
"#;

        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        // Types of the output relation's columns, in order.
        let types = schema
            .outputs
            .iter()
            .flat_map(|r| r.fields.iter().map(|f| f.columntype.typ));
        let expected_types = [
            Int,
            Decimal,
            Interval(Day),
            Interval(DayToHour),
            Interval(DayToMinute),
            Interval(DayToSecond),
            Interval(Hour),
            Interval(HourToMinute),
            Interval(HourToSecond),
            Interval(Minute),
            Interval(MinuteToSecond),
            Interval(Month),
            Interval(Second),
            Interval(Year),
            Interval(YearToMonth),
        ];

        assert_eq!(types.collect::<Vec<_>>(), &expected_types);
    }

    /// Despite its name, this test exercises *deserialization* of nested
    /// struct columns.  The members of `P0` use the inline struct-member form,
    /// and `ADDRESS` uses an object-valued `fields`, which is parsed as a
    /// whole column type (with `type` defaulting to `STRUCT`).
    #[test]
    fn serialize_struct_schemas() {
        let schema = r#"{
 "inputs" : [ {
 "name" : "PERS",
 "case_sensitive" : false,
 "fields" : [ {
 "name" : "P0",
 "case_sensitive" : false,
 "columntype" : {
 "fields" : [ {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "FIRSTNAME"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "LASTNAME"
 }, {
 "type" : "UINTEGER",
 "nullable" : true,
 "name" : "AGE"
 }, {
 "fields" : {
 "fields" : [ {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "STREET"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 30,
 "name" : "CITY"
 }, {
 "type" : "CHAR",
 "nullable" : true,
 "precision" : 2,
 "name" : "STATE"
 }, {
 "type" : "VARCHAR",
 "nullable" : true,
 "precision" : 6,
 "name" : "POSTAL_CODE"
 } ],
 "nullable" : false
 },
 "nullable" : false,
 "name" : "ADDRESS"
 } ],
 "nullable" : false
 }
 }]
 } ],
 "outputs" : [ ]
}
"#;
        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        eprintln!("{:#?}", schema);
        let pers = schema.inputs.iter().find(|r| r.name == "PERS").unwrap();
        let p0 = pers.fields.iter().find(|f| f.name == "P0").unwrap();
        assert_eq!(p0.columntype.typ, SqlType::Struct);
        let p0_fields = p0.columntype.fields.as_ref().unwrap();
        assert_eq!(p0_fields[0].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[1].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[2].columntype.typ, SqlType::UInt);
        assert_eq!(p0_fields[3].columntype.typ, SqlType::Struct);
        assert_eq!(p0_fields[3].name, "ADDRESS");
        let address = &p0_fields[3].columntype.fields.as_ref().unwrap();
        assert_eq!(address.len(), 4);
        assert_eq!(address[0].name, "STREET");
        assert_eq!(address[0].columntype.typ, SqlType::Varchar);
        assert_eq!(address[1].columntype.typ, SqlType::Varchar);
        assert_eq!(address[2].columntype.typ, SqlType::Char);
        assert_eq!(address[3].columntype.typ, SqlType::Varchar);
    }

    /// Equality between identifiers, and between identifiers and plain
    /// strings, including mixed case sensitivity.
    #[test]
    fn sql_identifier_cmp() {
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("foo"));
        assert_ne!(SqlIdentifier::from("foo"), SqlIdentifier::from("bar"));
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("BAR"));
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("\"foo\""));
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("\"bar\""));
        assert_eq!(SqlIdentifier::from("bAr"), SqlIdentifier::from("\"bAr\""));
        assert_eq!(
            SqlIdentifier::new("bAr", true),
            SqlIdentifier::from("\"bAr\"")
        );

        assert_eq!(SqlIdentifier::from("bAr"), "bar");
        assert_eq!(SqlIdentifier::from("bAr"), "bAr");
    }

    /// Identifiers that order as equal collapse to a single `BTreeSet` entry.
    #[test]
    fn sql_identifier_ord() {
        let mut btree = std::collections::BTreeSet::new();
        assert!(btree.insert(SqlIdentifier::from("foo")));
        assert!(btree.insert(SqlIdentifier::from("bar")));
        assert!(!btree.insert(SqlIdentifier::from("BAR")));
        assert!(!btree.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!btree.insert(SqlIdentifier::from("\"bar\"")));
    }

    /// Equal identifiers collapse to a single `HashSet` entry.
    #[test]
    fn sql_identifier_hash() {
        let mut hs = std::collections::HashSet::new();
        assert!(hs.insert(SqlIdentifier::from("foo")));
        assert!(hs.insert(SqlIdentifier::from("bar")));
        assert!(!hs.insert(SqlIdentifier::from("BAR")));
        assert!(!hs.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!hs.insert(SqlIdentifier::from("\"bar\"")));
    }

    /// `name()` lowercases only case-insensitive identifiers; `sql_name()`
    /// re-quotes case-sensitive ones and leaves the others unchanged.
    #[test]
    fn sql_identifier_name() {
        assert_eq!(SqlIdentifier::from("foo").name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").name(), "bar");
        assert_eq!(SqlIdentifier::from("\"bAr\"").name(), "bAr");
        assert_eq!(SqlIdentifier::from("foo").sql_name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").sql_name(), "bAr");
        assert_eq!(SqlIdentifier::from("\"bAr\"").sql_name(), "\"bAr\"");
    }

    /// Regression test for issue 3277: a `MAP` member of an inline struct
    /// keeps its `key` and `value` types.
    #[test]
    fn issue3277() {
        let schema = r#"{
 "name" : "j",
 "case_sensitive" : false,
 "columntype" : {
 "fields" : [ {
 "key" : {
 "nullable" : false,
 "precision" : -1,
 "type" : "VARCHAR"
 },
 "name" : "s",
 "nullable" : true,
 "type" : "MAP",
 "value" : {
 "nullable" : true,
 "precision" : -1,
 "type" : "VARCHAR"
 }
 } ],
 "nullable" : true
 }
 }"#;
        let field: Field = serde_json::from_str(schema).unwrap();
        println!("field: {:#?}", field);
        assert_eq!(
            field,
            Field {
                name: SqlIdentifier {
                    name: "j".to_string(),
                    case_sensitive: false,
                },
                columntype: ColumnType {
                    typ: SqlType::Struct,
                    nullable: true,
                    precision: None,
                    scale: None,
                    component: None,
                    fields: Some(vec![Field {
                        name: SqlIdentifier {
                            name: "s".to_string(),
                            case_sensitive: false,
                        },
                        columntype: ColumnType {
                            typ: SqlType::Map,
                            nullable: true,
                            precision: None,
                            scale: None,
                            component: None,
                            fields: None,
                            key: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: false,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                            value: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: true,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                        },
                        lateness: None,
                        default: None,
                        unused: false,
                        watermark: None,
                    }]),
                    key: None,
                    value: None,
                },
                lateness: None,
                default: None,
                unused: false,
                watermark: None,
            }
        );
    }
}