1use llkv_result::{Error, Result};
10use sqlparser::ast::{DataType, ExactNumberInfo};
11
/// Coarse family that [`classify_sql_data_type`] assigns to a SQL column type.
///
/// Each variant stands for a group of SQL type spellings; the exact mapping from
/// `sqlparser` data types to families lives in [`classify_sql_data_type`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SqlTypeFamily {
    /// Character, text, UUID and JSON types.
    String,
    /// Signed and unsigned integer types.
    Integer,
    /// Exact numeric types (`DECIMAL`, `NUMERIC`, `DEC`, ...).
    Decimal {
        /// The declared scale; `0` when the SQL type does not specify one.
        scale: i8,
    },
    /// `DATE` and `DATE32`.
    Date32,
    /// Binary, varbinary, blob and bytes types.
    Binary,
}
27
28pub fn classify_sql_data_type(data_type: &DataType) -> Result<SqlTypeFamily> {
35 use DataType::*;
36
37 let family = match data_type {
38 Character(_)
39 | Char(_)
40 | CharacterVarying(_)
41 | CharVarying(_)
42 | Varchar(_)
43 | Nvarchar(_)
44 | CharacterLargeObject(_)
45 | CharLargeObject(_)
46 | Clob(_)
47 | Text
48 | String(_)
49 | Uuid
50 | JSON
51 | JSONB => SqlTypeFamily::String,
52 Binary(_) | Varbinary(_) | Blob(_) | TinyBlob | MediumBlob | LongBlob | Bytes(_) => {
53 SqlTypeFamily::Binary
54 }
55 Date | Date32 => SqlTypeFamily::Date32,
56 Decimal(info)
57 | DecimalUnsigned(info)
58 | Numeric(info)
59 | Dec(info)
60 | DecUnsigned(info)
61 | BigDecimal(info)
62 | BigNumeric(info) => SqlTypeFamily::Decimal {
63 scale: decimal_scale(info)?,
64 },
65 TinyInt(_) | TinyIntUnsigned(_) | UTinyInt | Int2(_) | Int2Unsigned(_) | SmallInt(_)
66 | SmallIntUnsigned(_) | USmallInt | MediumInt(_) | MediumIntUnsigned(_) | Int(_)
67 | Int4(_) | Int8(_) | Int16 | Int32 | Int64 | Int128 | Int256 | Integer(_)
68 | IntUnsigned(_) | Int4Unsigned(_) | IntegerUnsigned(_) | HugeInt | UHugeInt | UInt8
69 | UInt16 | UInt32 | UInt64 | UInt128 | UInt256 | BigInt(_) | BigIntUnsigned(_)
70 | UBigInt | Int8Unsigned(_) | Signed | SignedInteger | Unsigned | UnsignedInteger => {
71 SqlTypeFamily::Integer
72 }
73 other => {
74 return Err(Error::InvalidArgumentError(format!(
75 "unsupported SQL data type '{other}' for classification"
76 )));
77 }
78 };
79
80 Ok(family)
81}
82
83fn decimal_scale(info: &ExactNumberInfo) -> Result<i8> {
84 let raw_scale = match info {
85 ExactNumberInfo::None | ExactNumberInfo::Precision(_) => 0,
86 ExactNumberInfo::PrecisionAndScale(_, scale) => *scale,
87 };
88
89 if raw_scale < i64::from(i8::MIN) || raw_scale > i64::from(i8::MAX) {
90 return Err(Error::InvalidArgumentError(format!(
91 "decimal scale {raw_scale} exceeds i8 range"
92 )));
93 }
94
95 Ok(raw_scale as i8)
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decimal_scale_defaults_to_zero() {
        let family = classify_sql_data_type(&DataType::Decimal(ExactNumberInfo::None)).unwrap();
        assert_eq!(family, SqlTypeFamily::Decimal { scale: 0 });
    }

    #[test]
    fn decimal_scale_uses_declared_scale() {
        let family = classify_sql_data_type(&DataType::Numeric(
            ExactNumberInfo::PrecisionAndScale(10, 2),
        ))
        .unwrap();
        assert_eq!(family, SqlTypeFamily::Decimal { scale: 2 });
    }

    #[test]
    fn decimal_scale_outside_i8_is_rejected() {
        // 200 does not fit in the `i8` carried by `SqlTypeFamily::Decimal`.
        let err = classify_sql_data_type(&DataType::Decimal(
            ExactNumberInfo::PrecisionAndScale(38, 200),
        ))
        .unwrap_err();
        assert!(matches!(err, Error::InvalidArgumentError(_)));
    }

    #[test]
    fn classifies_representative_types() {
        let cases = [
            (DataType::Text, SqlTypeFamily::String),
            (DataType::Blob(None), SqlTypeFamily::Binary),
            (DataType::Date, SqlTypeFamily::Date32),
            (DataType::BigInt(None), SqlTypeFamily::Integer),
        ];
        for (data_type, expected) in cases {
            let family = classify_sql_data_type(&data_type).unwrap();
            assert_eq!(family, expected, "classifying {data_type}");
        }
    }

    #[test]
    fn catches_unsupported_types() {
        let err = classify_sql_data_type(&DataType::Boolean).unwrap_err();
        assert!(matches!(err, Error::InvalidArgumentError(_)));
    }
}