1pub use feldera_ir::SourcePosition;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3use std::cmp::Ordering;
4use std::collections::BTreeMap;
5use std::fmt::Display;
6use std::hash::{Hash, Hasher};
7use utoipa::ToSchema;
8
9#[cfg(feature = "testing")]
10use proptest::{collection::vec, prelude::any};
11
/// Returns the canonical form of a SQL identifier.
///
/// An identifier that both starts and ends with a double quote (and is at
/// least two characters long) is returned with the surrounding quotes removed
/// and its case preserved. Any other identifier is converted to lowercase.
pub fn canonical_identifier(id: &str) -> String {
    // `strip_suffix` runs on what is left after removing the leading quote, so
    // a lone `"` is not treated as a quoted identifier.
    let unquoted = id
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));

    match unquoted {
        Some(inner) => inner.to_string(),
        None => id.to_lowercase(),
    }
}
26
/// A SQL identifier (for example a relation or field name) paired with a flag
/// that records whether the identifier is case-sensitive.
///
/// Equality, ordering and hashing are implemented by hand for this type rather
/// than derived.
#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct SqlIdentifier {
    /// The identifier text exactly as stored. Use [`SqlIdentifier::name`] to
    /// obtain the form that is lowercased for case-insensitive identifiers.
    #[cfg_attr(feature = "testing", proptest(regex = "relation1|relation2|relation3"))]
    name: String,
    /// `true` when the case of `name` is significant.
    pub case_sensitive: bool,
}
38
39impl SqlIdentifier {
40 pub fn new<S: AsRef<str>>(name: S, case_sensitive: bool) -> Self {
41 Self {
42 name: name.as_ref().to_string(),
43 case_sensitive,
44 }
45 }
46
47 pub fn name(&self) -> String {
57 if self.case_sensitive {
58 self.name.clone()
59 } else {
60 self.name.to_lowercase()
61 }
62 }
63
64 pub fn sql_name(&self) -> String {
75 if self.case_sensitive {
76 format!("\"{}\"", self.name)
77 } else {
78 self.name.clone()
79 }
80 }
81}
82
83impl Hash for SqlIdentifier {
84 fn hash<H: Hasher>(&self, state: &mut H) {
85 self.name().hash(state);
86 }
87}
88
89impl PartialEq for SqlIdentifier {
90 fn eq(&self, other: &Self) -> bool {
91 match (self.case_sensitive, other.case_sensitive) {
92 (true, true) => self.name == other.name,
93 (false, false) => self.name.to_lowercase() == other.name.to_lowercase(),
94 (true, false) => self.name == other.name,
95 (false, true) => self.name == other.name,
96 }
97 }
98}
99
100impl Ord for SqlIdentifier {
101 fn cmp(&self, other: &Self) -> Ordering {
102 self.name().cmp(&other.name())
103 }
104}
105
106impl PartialOrd for SqlIdentifier {
107 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108 Some(self.cmp(other))
109 }
110}
111
112impl<S: AsRef<str>> PartialEq<S> for SqlIdentifier {
113 fn eq(&self, other: &S) -> bool {
114 self == &SqlIdentifier::from(other.as_ref())
115 }
116}
117
// Marker impl: declares the hand-written `PartialEq` above to be a full
// equivalence relation, which `Ord` requires.
impl Eq for SqlIdentifier {}
119
120impl<S: AsRef<str>> From<S> for SqlIdentifier {
121 fn from(name: S) -> Self {
122 if name.as_ref().starts_with('"')
123 && name.as_ref().ends_with('"')
124 && name.as_ref().len() >= 2
125 {
126 Self {
127 name: name.as_ref()[1..name.as_ref().len() - 1].to_string(),
128 case_sensitive: true,
129 }
130 } else {
131 Self::new(name, false)
132 }
133 }
134}
135
136impl From<SqlIdentifier> for String {
137 fn from(id: SqlIdentifier) -> String {
138 id.name()
139 }
140}
141
142impl From<&SqlIdentifier> for String {
143 fn from(id: &SqlIdentifier) -> String {
144 id.name()
145 }
146}
147
148impl Display for SqlIdentifier {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 write!(f, "{}", self.name())
151 }
152}
153
/// The schema of a program: the relations it takes as inputs and the relations
/// it produces as outputs.
#[derive(Default, Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ProgramSchema {
    /// Input relations.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub inputs: Vec<Relation>,
    /// Output relations.
    #[cfg_attr(
        feature = "testing",
        proptest(strategy = "vec(any::<Relation>(), 0..2)")
    )]
    pub outputs: Vec<Relation>,
}
171
172impl ProgramSchema {
173 pub fn relations_with_lateness(&self) -> Vec<SqlIdentifier> {
174 self.inputs
175 .iter()
176 .chain(self.outputs.iter())
177 .filter(|rel| rel.has_lateness())
178 .map(|rel| rel.name.clone())
179 .collect()
180 }
181}
182
/// The value of a relation property together with the source positions of
/// the property's key and value.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct PropertyValue {
    /// The property value as a string.
    pub value: String,
    /// Source position associated with the property key.
    pub key_position: SourcePosition,
    /// Source position associated with the property value.
    pub value_position: SourcePosition,
}
190
/// A named relation of a program schema, described by its fields and
/// properties.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Relation {
    /// Name of the relation. Flattened, so the identifier's own keys
    /// (`name`, `case_sensitive`) appear directly in the relation's
    /// serialized object.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Fields of the relation.
    #[cfg_attr(feature = "testing", proptest(value = "Vec::new()"))]
    pub fields: Vec<Field>,
    /// Whether the relation is materialized; `false` when the key is absent
    /// from the serialized form.
    #[serde(default)]
    pub materialized: bool,
    /// Properties attached to the relation, keyed by property name; empty
    /// when the key is absent from the serialized form.
    #[serde(default)]
    pub properties: BTreeMap<String, PropertyValue>,
}
207
208impl Relation {
209 pub fn empty() -> Self {
210 Self {
211 name: SqlIdentifier::from("".to_string()),
212 fields: Vec::new(),
213 materialized: false,
214 properties: BTreeMap::new(),
215 }
216 }
217
218 pub fn new(
219 name: SqlIdentifier,
220 fields: Vec<Field>,
221 materialized: bool,
222 properties: BTreeMap<String, PropertyValue>,
223 ) -> Self {
224 Self {
225 name,
226 fields,
227 materialized,
228 properties,
229 }
230 }
231
232 pub fn field(&self, name: &str) -> Option<&Field> {
234 let name = canonical_identifier(name);
235 self.fields.iter().find(|f| f.name == name)
236 }
237
238 pub fn has_lateness(&self) -> bool {
239 self.fields.iter().any(|f| f.lateness.is_some())
240 }
241}
242
/// A named field of a relation or of a struct column type.
///
/// `Deserialize` is implemented by hand for this type (it is not derived).
#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct Field {
    /// Name of the field. Flattened, so the identifier's own keys
    /// (`name`, `case_sensitive`) appear directly in the field's serialized
    /// object.
    #[serde(flatten)]
    pub name: SqlIdentifier,
    /// Type of the field.
    pub columntype: ColumnType,
    /// Optional lateness annotation, kept as a string.
    pub lateness: Option<String>,
    /// Optional default value, kept as a string.
    pub default: Option<String>,
    /// Flag marking the field as unused.
    pub unused: bool,
    /// Optional watermark annotation, kept as a string.
    pub watermark: Option<String>,
}
257
258impl Field {
259 pub fn new(name: SqlIdentifier, columntype: ColumnType) -> Self {
260 Self {
261 name,
262 columntype,
263 lateness: None,
264 default: None,
265 unused: false,
266 watermark: None,
267 }
268 }
269
270 pub fn with_lateness(mut self, lateness: &str) -> Self {
271 self.lateness = Some(lateness.to_string());
272 self
273 }
274
275 pub fn with_unused(mut self, unused: bool) -> Self {
276 self.unused = unused;
277 self
278 }
279}
280
281impl<'de> Deserialize<'de> for Field {
286 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
287 where
288 D: Deserializer<'de>,
289 {
290 const fn default_is_struct() -> Option<SqlType> {
291 Some(SqlType::Struct)
292 }
293
294 #[derive(Debug, Clone, Deserialize)]
295 struct FieldHelper {
296 name: Option<String>,
297 #[serde(default)]
298 case_sensitive: bool,
299 columntype: Option<ColumnType>,
300 #[serde(rename = "type")]
301 #[serde(default = "default_is_struct")]
302 typ: Option<SqlType>,
303 nullable: Option<bool>,
304 precision: Option<i64>,
305 scale: Option<i64>,
306 component: Option<Box<ColumnType>>,
307 fields: Option<serde_json::Value>,
308 key: Option<Box<ColumnType>>,
309 value: Option<Box<ColumnType>>,
310 default: Option<String>,
311 #[serde(default)]
312 unused: bool,
313 lateness: Option<String>,
314 watermark: Option<String>,
315 }
316
317 fn helper_to_field(helper: FieldHelper) -> Field {
318 let columntype = if let Some(ctype) = helper.columntype {
319 ctype
320 } else if let Some(serde_json::Value::Array(fields)) = helper.fields {
321 let fields = fields
322 .into_iter()
323 .map(|field| {
324 let field: FieldHelper = serde_json::from_value(field).unwrap();
325 helper_to_field(field)
326 })
327 .collect::<Vec<Field>>();
328
329 ColumnType {
330 typ: helper.typ.unwrap_or(SqlType::Null),
331 nullable: helper.nullable.unwrap_or(false),
332 precision: helper.precision,
333 scale: helper.scale,
334 component: helper.component,
335 fields: Some(fields),
336 key: None,
337 value: None,
338 }
339 } else if let Some(serde_json::Value::Object(obj)) = helper.fields {
340 serde_json::from_value(serde_json::Value::Object(obj))
341 .expect("Failed to deserialize object")
342 } else {
343 ColumnType {
344 typ: helper.typ.unwrap_or(SqlType::Null),
345 nullable: helper.nullable.unwrap_or(false),
346 precision: helper.precision,
347 scale: helper.scale,
348 component: helper.component,
349 fields: None,
350 key: helper.key,
351 value: helper.value,
352 }
353 };
354
355 Field {
356 name: SqlIdentifier::new(helper.name.unwrap(), helper.case_sensitive),
357 columntype,
358 default: helper.default,
359 unused: helper.unused,
360 lateness: helper.lateness,
361 watermark: helper.watermark,
362 }
363 }
364
365 let helper = FieldHelper::deserialize(deserializer)?;
366 Ok(helper_to_field(helper))
367 }
368}
369
/// The unit carried by `SqlType::Interval`.
///
/// The names quoted below are the strings `SqlType` uses when an interval
/// with that unit is serialized.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum IntervalUnit {
    /// Serialized as `INTERVAL_DAY`.
    Day,
    /// Serialized as `INTERVAL_DAY_HOUR`.
    DayToHour,
    /// Serialized as `INTERVAL_DAY_MINUTE`.
    DayToMinute,
    /// Serialized as `INTERVAL_DAY_SECOND`.
    DayToSecond,
    /// Serialized as `INTERVAL_HOUR`.
    Hour,
    /// Serialized as `INTERVAL_HOUR_MINUTE`.
    HourToMinute,
    /// Serialized as `INTERVAL_HOUR_SECOND`.
    HourToSecond,
    /// Serialized as `INTERVAL_MINUTE`.
    Minute,
    /// Serialized as `INTERVAL_MINUTE_SECOND`.
    MinuteToSecond,
    /// Serialized as `INTERVAL_MONTH`.
    Month,
    /// Serialized as `INTERVAL_SECOND`.
    Second,
    /// Serialized as `INTERVAL_YEAR`.
    Year,
    /// Serialized as `INTERVAL_YEAR_MONTH`.
    YearToMonth,
}
404
/// The kind of a SQL column type.
///
/// `Serialize` and `Deserialize` are implemented by hand: each variant is
/// written as the uppercase string quoted below and is read back
/// case-insensitively.
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SqlType {
    /// Serialized as `BOOLEAN`.
    Boolean,
    /// Serialized as `TINYINT`.
    TinyInt,
    /// Serialized as `SMALLINT`.
    SmallInt,
    /// Serialized as `INTEGER`.
    Int,
    /// Serialized as `BIGINT`.
    BigInt,
    /// Serialized as `UTINYINT`.
    UTinyInt,
    /// Serialized as `USMALLINT`.
    USmallInt,
    /// Serialized as `UINTEGER`.
    UInt,
    /// Serialized as `UBIGINT`.
    UBigInt,
    /// Serialized as `REAL`.
    Real,
    /// Serialized as `DOUBLE`.
    Double,
    /// Serialized as `DECIMAL`.
    Decimal,
    /// Serialized as `CHAR`.
    Char,
    /// Serialized as `VARCHAR`.
    Varchar,
    /// Serialized as `BINARY`.
    Binary,
    /// Serialized as `VARBINARY`.
    Varbinary,
    /// Serialized as `TIME`.
    Time,
    /// Serialized as `DATE`.
    Date,
    /// Serialized as `TIMESTAMP`.
    Timestamp,
    /// Serialized as `INTERVAL_` followed by the unit, e.g. `INTERVAL_DAY`.
    Interval(IntervalUnit),
    /// Serialized as `ARRAY`.
    Array,
    /// Serialized as `STRUCT`.
    Struct,
    /// Serialized as `MAP`.
    Map,
    /// Serialized as `NULL`.
    Null,
    /// Serialized as `UUID`.
    Uuid,
    /// Serialized as `VARIANT`.
    Variant,
}
462
463impl Display for SqlType {
464 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
465 f.write_str(&serde_json::to_string(self).unwrap())
466 }
467}
468
469impl<'de> Deserialize<'de> for SqlType {
470 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
471 where
472 D: Deserializer<'de>,
473 {
474 let value: String = Deserialize::deserialize(deserializer)?;
475 match value.to_lowercase().as_str() {
476 "interval_day" => Ok(SqlType::Interval(IntervalUnit::Day)),
477 "interval_day_hour" => Ok(SqlType::Interval(IntervalUnit::DayToHour)),
478 "interval_day_minute" => Ok(SqlType::Interval(IntervalUnit::DayToMinute)),
479 "interval_day_second" => Ok(SqlType::Interval(IntervalUnit::DayToSecond)),
480 "interval_hour" => Ok(SqlType::Interval(IntervalUnit::Hour)),
481 "interval_hour_minute" => Ok(SqlType::Interval(IntervalUnit::HourToMinute)),
482 "interval_hour_second" => Ok(SqlType::Interval(IntervalUnit::HourToSecond)),
483 "interval_minute" => Ok(SqlType::Interval(IntervalUnit::Minute)),
484 "interval_minute_second" => Ok(SqlType::Interval(IntervalUnit::MinuteToSecond)),
485 "interval_month" => Ok(SqlType::Interval(IntervalUnit::Month)),
486 "interval_second" => Ok(SqlType::Interval(IntervalUnit::Second)),
487 "interval_year" => Ok(SqlType::Interval(IntervalUnit::Year)),
488 "interval_year_month" => Ok(SqlType::Interval(IntervalUnit::YearToMonth)),
489 "boolean" => Ok(SqlType::Boolean),
490 "tinyint" => Ok(SqlType::TinyInt),
491 "smallint" => Ok(SqlType::SmallInt),
492 "integer" => Ok(SqlType::Int),
493 "bigint" => Ok(SqlType::BigInt),
494 "utinyint" => Ok(SqlType::UTinyInt),
495 "usmallint" => Ok(SqlType::USmallInt),
496 "uinteger" => Ok(SqlType::UInt),
497 "ubigint" => Ok(SqlType::UBigInt),
498 "real" => Ok(SqlType::Real),
499 "double" => Ok(SqlType::Double),
500 "decimal" => Ok(SqlType::Decimal),
501 "char" => Ok(SqlType::Char),
502 "varchar" => Ok(SqlType::Varchar),
503 "binary" => Ok(SqlType::Binary),
504 "varbinary" => Ok(SqlType::Varbinary),
505 "variant" => Ok(SqlType::Variant),
506 "time" => Ok(SqlType::Time),
507 "date" => Ok(SqlType::Date),
508 "timestamp" => Ok(SqlType::Timestamp),
509 "array" => Ok(SqlType::Array),
510 "struct" => Ok(SqlType::Struct),
511 "map" => Ok(SqlType::Map),
512 "null" => Ok(SqlType::Null),
513 "uuid" => Ok(SqlType::Uuid),
514 _ => Err(serde::de::Error::custom(format!(
515 "Unknown SQL type: {}",
516 value
517 ))),
518 }
519 }
520}
521
522impl Serialize for SqlType {
523 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
524 where
525 S: Serializer,
526 {
527 let type_str = match self {
528 SqlType::Boolean => "BOOLEAN",
529 SqlType::TinyInt => "TINYINT",
530 SqlType::SmallInt => "SMALLINT",
531 SqlType::Int => "INTEGER",
532 SqlType::BigInt => "BIGINT",
533 SqlType::UTinyInt => "UTINYINT",
534 SqlType::USmallInt => "USMALLINT",
535 SqlType::UInt => "UINTEGER",
536 SqlType::UBigInt => "UBIGINT",
537 SqlType::Real => "REAL",
538 SqlType::Double => "DOUBLE",
539 SqlType::Decimal => "DECIMAL",
540 SqlType::Char => "CHAR",
541 SqlType::Varchar => "VARCHAR",
542 SqlType::Binary => "BINARY",
543 SqlType::Varbinary => "VARBINARY",
544 SqlType::Time => "TIME",
545 SqlType::Date => "DATE",
546 SqlType::Timestamp => "TIMESTAMP",
547 SqlType::Interval(interval_unit) => match interval_unit {
548 IntervalUnit::Day => "INTERVAL_DAY",
549 IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
550 IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
551 IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
552 IntervalUnit::Hour => "INTERVAL_HOUR",
553 IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
554 IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
555 IntervalUnit::Minute => "INTERVAL_MINUTE",
556 IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
557 IntervalUnit::Month => "INTERVAL_MONTH",
558 IntervalUnit::Second => "INTERVAL_SECOND",
559 IntervalUnit::Year => "INTERVAL_YEAR",
560 IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
561 },
562 SqlType::Array => "ARRAY",
563 SqlType::Struct => "STRUCT",
564 SqlType::Uuid => "UUID",
565 SqlType::Map => "MAP",
566 SqlType::Null => "NULL",
567 SqlType::Variant => "VARIANT",
568 };
569 serializer.serialize_str(type_str)
570 }
571}
572
573impl SqlType {
574 pub fn is_string(&self) -> bool {
576 matches!(self, Self::Char | Self::Varchar)
577 }
578
579 pub fn is_varchar(&self) -> bool {
580 matches!(self, Self::Varchar)
581 }
582
583 pub fn is_varbinary(&self) -> bool {
584 matches!(self, Self::Varbinary)
585 }
586}
587
/// Serde default for `ColumnType::typ`: a column type whose serialized form
/// has no `type` key is read as `SqlType::Struct`.
const fn default_is_struct() -> SqlType {
    SqlType::Struct
}
593
/// A SQL column type: its kind, nullability, and the attributes that only
/// some kinds use.
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub struct ColumnType {
    /// Kind of the type. Serialized under the key `type`; defaults to
    /// `SqlType::Struct` when that key is absent.
    #[serde(rename = "type")]
    #[serde(default = "default_is_struct")]
    pub typ: SqlType,
    /// Whether the column may hold NULL.
    pub nullable: bool,
    /// Precision of the type, when it has one. The constructors in this file
    /// set it for `decimal` (the precision) and `fixed` (the width).
    pub precision: Option<i64>,
    /// Scale of the type, when it has one. The constructors in this file set
    /// it only for `decimal`.
    pub scale: Option<i64>,
    /// Element type; set by the `array` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub component: Option<Box<ColumnType>>,
    /// Member fields; set by the `structure` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "Some(Vec::new())"))]
    pub fields: Option<Vec<Field>>,
    /// Key type; set by the `map` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub key: Option<Box<ColumnType>>,
    /// Value type; set by the `map` constructor.
    #[cfg_attr(feature = "testing", proptest(value = "None"))]
    pub value: Option<Box<ColumnType>>,
}
658
659impl ColumnType {
660 pub fn boolean(nullable: bool) -> Self {
661 ColumnType {
662 typ: SqlType::Boolean,
663 nullable,
664 precision: None,
665 scale: None,
666 component: None,
667 fields: None,
668 key: None,
669 value: None,
670 }
671 }
672
673 pub fn uuid(nullable: bool) -> Self {
674 ColumnType {
675 typ: SqlType::Uuid,
676 nullable,
677 precision: None,
678 scale: None,
679 component: None,
680 fields: None,
681 key: None,
682 value: None,
683 }
684 }
685
686 pub fn tinyint(nullable: bool) -> Self {
687 ColumnType {
688 typ: SqlType::TinyInt,
689 nullable,
690 precision: None,
691 scale: None,
692 component: None,
693 fields: None,
694 key: None,
695 value: None,
696 }
697 }
698
699 pub fn smallint(nullable: bool) -> Self {
700 ColumnType {
701 typ: SqlType::SmallInt,
702 nullable,
703 precision: None,
704 scale: None,
705 component: None,
706 fields: None,
707 key: None,
708 value: None,
709 }
710 }
711
712 pub fn int(nullable: bool) -> Self {
713 ColumnType {
714 typ: SqlType::Int,
715 nullable,
716 precision: None,
717 scale: None,
718 component: None,
719 fields: None,
720 key: None,
721 value: None,
722 }
723 }
724
725 pub fn bigint(nullable: bool) -> Self {
726 ColumnType {
727 typ: SqlType::BigInt,
728 nullable,
729 precision: None,
730 scale: None,
731 component: None,
732 fields: None,
733 key: None,
734 value: None,
735 }
736 }
737
738 pub fn utinyint(nullable: bool) -> Self {
739 ColumnType {
740 typ: SqlType::UTinyInt,
741 nullable,
742 precision: None,
743 scale: None,
744 component: None,
745 fields: None,
746 key: None,
747 value: None,
748 }
749 }
750
751 pub fn usmallint(nullable: bool) -> Self {
752 ColumnType {
753 typ: SqlType::USmallInt,
754 nullable,
755 precision: None,
756 scale: None,
757 component: None,
758 fields: None,
759 key: None,
760 value: None,
761 }
762 }
763
764 pub fn uint(nullable: bool) -> Self {
765 ColumnType {
766 typ: SqlType::UInt,
767 nullable,
768 precision: None,
769 scale: None,
770 component: None,
771 fields: None,
772 key: None,
773 value: None,
774 }
775 }
776
777 pub fn ubigint(nullable: bool) -> Self {
778 ColumnType {
779 typ: SqlType::UBigInt,
780 nullable,
781 precision: None,
782 scale: None,
783 component: None,
784 fields: None,
785 key: None,
786 value: None,
787 }
788 }
789
790 pub fn double(nullable: bool) -> Self {
791 ColumnType {
792 typ: SqlType::Double,
793 nullable,
794 precision: None,
795 scale: None,
796 component: None,
797 fields: None,
798 key: None,
799 value: None,
800 }
801 }
802
803 pub fn real(nullable: bool) -> Self {
804 ColumnType {
805 typ: SqlType::Real,
806 nullable,
807 precision: None,
808 scale: None,
809 component: None,
810 fields: None,
811 key: None,
812 value: None,
813 }
814 }
815
816 pub fn decimal(precision: i64, scale: i64, nullable: bool) -> Self {
817 ColumnType {
818 typ: SqlType::Decimal,
819 nullable,
820 precision: Some(precision),
821 scale: Some(scale),
822 component: None,
823 fields: None,
824 key: None,
825 value: None,
826 }
827 }
828
829 pub fn varchar(nullable: bool) -> Self {
830 ColumnType {
831 typ: SqlType::Varchar,
832 nullable,
833 precision: None,
834 scale: None,
835 component: None,
836 fields: None,
837 key: None,
838 value: None,
839 }
840 }
841
842 pub fn varbinary(nullable: bool) -> Self {
843 ColumnType {
844 typ: SqlType::Varbinary,
845 nullable,
846 precision: None,
847 scale: None,
848 component: None,
849 fields: None,
850 key: None,
851 value: None,
852 }
853 }
854
855 pub fn fixed(width: i64, nullable: bool) -> Self {
856 ColumnType {
857 typ: SqlType::Binary,
858 nullable,
859 precision: Some(width),
860 scale: None,
861 component: None,
862 fields: None,
863 key: None,
864 value: None,
865 }
866 }
867
868 pub fn date(nullable: bool) -> Self {
869 ColumnType {
870 typ: SqlType::Date,
871 nullable,
872 precision: None,
873 scale: None,
874 component: None,
875 fields: None,
876 key: None,
877 value: None,
878 }
879 }
880
881 pub fn time(nullable: bool) -> Self {
882 ColumnType {
883 typ: SqlType::Time,
884 nullable,
885 precision: None,
886 scale: None,
887 component: None,
888 fields: None,
889 key: None,
890 value: None,
891 }
892 }
893
894 pub fn timestamp(nullable: bool) -> Self {
895 ColumnType {
896 typ: SqlType::Timestamp,
897 nullable,
898 precision: None,
899 scale: None,
900 component: None,
901 fields: None,
902 key: None,
903 value: None,
904 }
905 }
906
907 pub fn variant(nullable: bool) -> Self {
908 ColumnType {
909 typ: SqlType::Variant,
910 nullable,
911 precision: None,
912 scale: None,
913 component: None,
914 fields: None,
915 key: None,
916 value: None,
917 }
918 }
919
920 pub fn array(nullable: bool, element: ColumnType) -> Self {
921 ColumnType {
922 typ: SqlType::Array,
923 nullable,
924 precision: None,
925 scale: None,
926 component: Some(Box::new(element)),
927 fields: None,
928 key: None,
929 value: None,
930 }
931 }
932
933 pub fn structure(nullable: bool, fields: &[Field]) -> Self {
934 ColumnType {
935 typ: SqlType::Struct,
936 nullable,
937 precision: None,
938 scale: None,
939 component: None,
940 fields: Some(fields.to_vec()),
941 key: None,
942 value: None,
943 }
944 }
945
946 pub fn map(nullable: bool, key: ColumnType, val: ColumnType) -> Self {
947 ColumnType {
948 typ: SqlType::Map,
949 nullable,
950 precision: None,
951 scale: None,
952 component: None,
953 fields: None,
954 key: Some(Box::new(key)),
955 value: Some(Box::new(val)),
956 }
957 }
958
959 pub fn is_integral_type(&self) -> bool {
960 matches!(
961 &self.typ,
962 SqlType::TinyInt
963 | SqlType::SmallInt
964 | SqlType::Int
965 | SqlType::BigInt
966 | SqlType::UTinyInt
967 | SqlType::USmallInt
968 | SqlType::UInt
969 | SqlType::UBigInt
970 )
971 }
972
973 pub fn is_fp_type(&self) -> bool {
974 matches!(&self.typ, SqlType::Double | SqlType::Real)
975 }
976
977 pub fn is_decimal_type(&self) -> bool {
978 matches!(&self.typ, SqlType::Decimal)
979 }
980
981 pub fn is_numeric_type(&self) -> bool {
982 self.is_integral_type() || self.is_fp_type() || self.is_decimal_type()
983 }
984}
985
#[cfg(test)]
mod tests {
    use super::{IntervalUnit, SqlIdentifier};
    use crate::program_schema::{ColumnType, Field, SqlType};

    /// Every SQL type name deserializes regardless of case and survives a
    /// serialize/deserialize round trip.
    #[test]
    fn serde_sql_type() {
        for (sql_str_base, expected_value) in [
            ("Boolean", SqlType::Boolean),
            ("Uuid", SqlType::Uuid),
            ("TinyInt", SqlType::TinyInt),
            ("SmallInt", SqlType::SmallInt),
            ("Integer", SqlType::Int),
            ("BigInt", SqlType::BigInt),
            ("UTinyInt", SqlType::UTinyInt),
            ("USmallInt", SqlType::USmallInt),
            ("UInteger", SqlType::UInt),
            ("UBigInt", SqlType::UBigInt),
            ("Real", SqlType::Real),
            ("Double", SqlType::Double),
            ("Decimal", SqlType::Decimal),
            ("Char", SqlType::Char),
            ("Varchar", SqlType::Varchar),
            ("Binary", SqlType::Binary),
            ("Varbinary", SqlType::Varbinary),
            ("Time", SqlType::Time),
            ("Date", SqlType::Date),
            ("Timestamp", SqlType::Timestamp),
            ("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
            (
                "Interval_Day_Hour",
                SqlType::Interval(IntervalUnit::DayToHour),
            ),
            (
                "Interval_Day_Minute",
                SqlType::Interval(IntervalUnit::DayToMinute),
            ),
            (
                "Interval_Day_Second",
                SqlType::Interval(IntervalUnit::DayToSecond),
            ),
            ("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
            (
                "Interval_Hour_Minute",
                SqlType::Interval(IntervalUnit::HourToMinute),
            ),
            (
                "Interval_Hour_Second",
                SqlType::Interval(IntervalUnit::HourToSecond),
            ),
            ("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
            (
                "Interval_Minute_Second",
                SqlType::Interval(IntervalUnit::MinuteToSecond),
            ),
            ("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
            ("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
            ("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
            (
                "Interval_Year_Month",
                SqlType::Interval(IntervalUnit::YearToMonth),
            ),
            ("Array", SqlType::Array),
            ("Struct", SqlType::Struct),
            ("Map", SqlType::Map),
            ("Null", SqlType::Null),
            ("Variant", SqlType::Variant),
        ] {
            // Try the name as written, all lowercase, and all uppercase.
            for sql_str in [
                sql_str_base,
                &sql_str_base.to_lowercase(),
                &sql_str_base.to_uppercase(),
            ] {
                let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str))
                    .unwrap_or_else(|_| {
                        panic!("\"{sql_str}\" should deserialize into its SQL type")
                    });
                assert_eq!(value1, expected_value);
                let serialized_str =
                    serde_json::to_string(&value1).expect("Value should serialize into JSON");
                let value2: SqlType = serde_json::from_str(&serialized_str).unwrap_or_else(|_| {
                    panic!(
                        "{} should deserialize back into its SQL type",
                        serialized_str
                    )
                });
                assert_eq!(value1, value2);
            }
        }
    }

    /// A full program schema containing every interval type deserializes and
    /// yields the expected output column types, in order.
    #[test]
    fn deserialize_interval_types() {
        use super::IntervalUnit::*;
        use super::SqlType::*;

        let schema = r#"
{
  "inputs" : [ {
    "name" : "sales",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "sales_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "customer_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "age",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "UINTEGER",
        "nullable" : true
      }
    }, {
      "name" : "amount",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DECIMAL",
        "nullable" : true,
        "precision" : 10,
        "scale" : 2
      }
    }, {
      "name" : "sale_date",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DATE",
        "nullable" : true
      }
    } ],
    "primary_key" : [ "sales_id" ]
  } ],
  "outputs" : [ {
    "name" : "salessummary",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "customer_id",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTEGER",
        "nullable" : true
      }
    }, {
      "name" : "total_sales",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "DECIMAL",
        "nullable" : true,
        "precision" : 38,
        "scale" : 2
      }
    }, {
      "name" : "interval_day",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_hour",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_HOUR",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_day_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_DAY_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour_to_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_hour_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_HOUR_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_minute",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MINUTE",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_minute_to_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MINUTE_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_month",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_MONTH",
        "nullable" : false
      }
    }, {
      "name" : "interval_second",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_SECOND",
        "nullable" : false,
        "precision" : 2,
        "scale" : 6
      }
    }, {
      "name" : "interval_year",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_YEAR",
        "nullable" : false
      }
    }, {
      "name" : "interval_year_to_month",
      "case_sensitive" : false,
      "columntype" : {
        "type" : "INTERVAL_YEAR_MONTH",
        "nullable" : false
      }
    } ]
  } ]
}
"#;

        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        // Collect the type of every output field, in declaration order.
        let types = schema
            .outputs
            .iter()
            .flat_map(|r| r.fields.iter().map(|f| f.columntype.typ));
        let expected_types = [
            Int,
            Decimal,
            Interval(Day),
            Interval(DayToHour),
            Interval(DayToMinute),
            Interval(DayToSecond),
            Interval(Hour),
            Interval(HourToMinute),
            Interval(HourToSecond),
            Interval(Minute),
            Interval(MinuteToSecond),
            Interval(Month),
            Interval(Second),
            Interval(Year),
            Interval(YearToMonth),
        ];

        assert_eq!(types.collect::<Vec<_>>(), &expected_types);
    }

    /// Struct-typed columns deserialize when member fields inline their type
    /// attributes, including a nested struct given via an object-valued
    /// `fields` key.
    #[test]
    fn serialize_struct_schemas() {
        let schema = r#"{
  "inputs" : [ {
    "name" : "PERS",
    "case_sensitive" : false,
    "fields" : [ {
      "name" : "P0",
      "case_sensitive" : false,
      "columntype" : {
        "fields" : [ {
          "type" : "VARCHAR",
          "nullable" : true,
          "precision" : 30,
          "name" : "FIRSTNAME"
        }, {
          "type" : "VARCHAR",
          "nullable" : true,
          "precision" : 30,
          "name" : "LASTNAME"
        }, {
          "type" : "UINTEGER",
          "nullable" : true,
          "name" : "AGE"
        }, {
          "fields" : {
            "fields" : [ {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 30,
              "name" : "STREET"
            }, {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 30,
              "name" : "CITY"
            }, {
              "type" : "CHAR",
              "nullable" : true,
              "precision" : 2,
              "name" : "STATE"
            }, {
              "type" : "VARCHAR",
              "nullable" : true,
              "precision" : 6,
              "name" : "POSTAL_CODE"
            } ],
            "nullable" : false
          },
          "nullable" : false,
          "name" : "ADDRESS"
        } ],
        "nullable" : false
      }
    }]
  } ],
  "outputs" : [ ]
}
"#;
        let schema: super::ProgramSchema = serde_json::from_str(schema).unwrap();
        eprintln!("{:#?}", schema);
        let pers = schema.inputs.iter().find(|r| r.name == "PERS").unwrap();
        let p0 = pers.fields.iter().find(|f| f.name == "P0").unwrap();
        // A column type without a "type" key is read as a struct.
        assert_eq!(p0.columntype.typ, SqlType::Struct);
        let p0_fields = p0.columntype.fields.as_ref().unwrap();
        assert_eq!(p0_fields[0].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[1].columntype.typ, SqlType::Varchar);
        assert_eq!(p0_fields[2].columntype.typ, SqlType::UInt);
        assert_eq!(p0_fields[3].columntype.typ, SqlType::Struct);
        assert_eq!(p0_fields[3].name, "ADDRESS");
        let address = &p0_fields[3].columntype.fields.as_ref().unwrap();
        assert_eq!(address.len(), 4);
        assert_eq!(address[0].name, "STREET");
        assert_eq!(address[0].columntype.typ, SqlType::Varchar);
        assert_eq!(address[1].columntype.typ, SqlType::Varchar);
        assert_eq!(address[2].columntype.typ, SqlType::Char);
        assert_eq!(address[3].columntype.typ, SqlType::Varchar);
    }

    /// Equality between identifiers, and between an identifier and a string.
    #[test]
    fn sql_identifier_cmp() {
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("foo"));
        assert_ne!(SqlIdentifier::from("foo"), SqlIdentifier::from("bar"));
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("BAR"));
        assert_eq!(SqlIdentifier::from("foo"), SqlIdentifier::from("\"foo\""));
        assert_eq!(SqlIdentifier::from("bar"), SqlIdentifier::from("\"bar\""));
        assert_eq!(SqlIdentifier::from("bAr"), SqlIdentifier::from("\"bAr\""));
        assert_eq!(
            SqlIdentifier::new("bAr", true),
            SqlIdentifier::from("\"bAr\"")
        );

        assert_eq!(SqlIdentifier::from("bAr"), "bar");
        assert_eq!(SqlIdentifier::from("bAr"), "bAr");
    }

    /// Identifiers that compare equal are rejected as duplicates by an
    /// ordered set.
    #[test]
    fn sql_identifier_ord() {
        let mut btree = std::collections::BTreeSet::new();
        assert!(btree.insert(SqlIdentifier::from("foo")));
        assert!(btree.insert(SqlIdentifier::from("bar")));
        assert!(!btree.insert(SqlIdentifier::from("BAR")));
        assert!(!btree.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!btree.insert(SqlIdentifier::from("\"bar\"")));
    }

    /// Identifiers that compare equal are rejected as duplicates by a hash
    /// set.
    #[test]
    fn sql_identifier_hash() {
        let mut hs = std::collections::HashSet::new();
        assert!(hs.insert(SqlIdentifier::from("foo")));
        assert!(hs.insert(SqlIdentifier::from("bar")));
        assert!(!hs.insert(SqlIdentifier::from("BAR")));
        assert!(!hs.insert(SqlIdentifier::from("\"foo\"")));
        assert!(!hs.insert(SqlIdentifier::from("\"bar\"")));
    }

    /// `name()` lowercases unquoted identifiers; `sql_name()` re-adds quotes
    /// for case-sensitive ones.
    #[test]
    fn sql_identifier_name() {
        assert_eq!(SqlIdentifier::from("foo").name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").name(), "bar");
        assert_eq!(SqlIdentifier::from("\"bAr\"").name(), "bAr");
        assert_eq!(SqlIdentifier::from("foo").sql_name(), "foo");
        assert_eq!(SqlIdentifier::from("bAr").sql_name(), "bAr");
        assert_eq!(SqlIdentifier::from("\"bAr\"").sql_name(), "\"bAr\"");
    }

    /// Regression test: a struct member that is a MAP keeps its key and
    /// value types when deserialized.
    #[test]
    fn issue3277() {
        let schema = r#"{
      "name" : "j",
      "case_sensitive" : false,
      "columntype" : {
        "fields" : [ {
          "key" : {
            "nullable" : false,
            "precision" : -1,
            "type" : "VARCHAR"
          },
          "name" : "s",
          "nullable" : true,
          "type" : "MAP",
          "value" : {
            "nullable" : true,
            "precision" : -1,
            "type" : "VARCHAR"
          }
        } ],
        "nullable" : true
      }
    }"#;
        let field: Field = serde_json::from_str(schema).unwrap();
        println!("field: {:#?}", field);
        assert_eq!(
            field,
            Field {
                name: SqlIdentifier {
                    name: "j".to_string(),
                    case_sensitive: false,
                },
                columntype: ColumnType {
                    typ: SqlType::Struct,
                    nullable: true,
                    precision: None,
                    scale: None,
                    component: None,
                    fields: Some(vec![Field {
                        name: SqlIdentifier {
                            name: "s".to_string(),
                            case_sensitive: false,
                        },
                        columntype: ColumnType {
                            typ: SqlType::Map,
                            nullable: true,
                            precision: None,
                            scale: None,
                            component: None,
                            fields: None,
                            key: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: false,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                            value: Some(Box::new(ColumnType {
                                typ: SqlType::Varchar,
                                nullable: true,
                                precision: Some(-1),
                                scale: None,
                                component: None,
                                fields: None,
                                key: None,
                                value: None,
                            })),
                        },
                        lateness: None,
                        default: None,
                        unused: false,
                        watermark: None,
                    }]),
                    key: None,
                    value: None,
                },
                lateness: None,
                default: None,
                unused: false,
                watermark: None,
            }
        );
    }
}