Skip to main content

hematite/parser/
types.rs

1//! Parser-owned SQL literal and type names.
2
3use std::cmp::Ordering;
4
/// SQL type names as represented by the parser.
///
/// Each variant's SQL spelling is the one emitted by [`SqlTypeName::to_sql`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum SqlTypeName {
    /// Spelled `INT8`.
    Int8,
    /// Spelled `INT16`.
    Int16,
    /// Spelled `INT`.
    Int,
    /// Spelled `INT64`.
    Int64,
    /// Spelled `INT128`.
    Int128,
    /// Spelled `UINT8`.
    UInt8,
    /// Spelled `UINT16`.
    UInt16,
    /// Spelled `UINT`.
    UInt,
    /// Spelled `UINT64`.
    UInt64,
    /// Spelled `UINT128`.
    UInt128,
    /// Spelled `TEXT`.
    Text,
    /// Spelled `CHAR(n)`; the payload is the `n` printed in the parentheses.
    Char(u32),
    /// Spelled `VARCHAR(n)`; the payload is the `n` printed in the parentheses.
    VarChar(u32),
    /// Spelled `BINARY(n)`; the payload is the `n` printed in the parentheses.
    Binary(u32),
    /// Spelled `VARBINARY(n)`; the payload is the `n` printed in the parentheses.
    VarBinary(u32),
    /// Spelled `ENUM('a', 'b', ...)`; the payload holds the listed values in order.
    Enum(Vec<String>),
    /// Spelled `BOOLEAN`.
    Boolean,
    /// Spelled `FLOAT32`.
    Float32,
    /// Spelled `FLOAT`.
    Float,
    /// Spelled `DECIMAL`, `DECIMAL(p)` or `DECIMAL(p, s)` depending on which
    /// of the optional fields are present.
    Decimal {
        /// Optional precision argument.
        precision: Option<u32>,
        /// Optional scale argument.
        scale: Option<u32>,
    },
    /// Spelled `BLOB`.
    Blob,
    /// Spelled `DATE`.
    Date,
    /// Spelled `TIME`.
    Time,
    /// Spelled `DATETIME`.
    DateTime,
    /// Spelled `TIME WITH TIME ZONE`.
    TimeWithTimeZone,
    /// Spelled `INTERVAL YEAR TO MONTH`.
    IntervalYearMonth,
    /// Spelled `INTERVAL DAY TO SECOND`.
    IntervalDaySecond,
}
38
39impl SqlTypeName {
40    pub fn to_sql(&self) -> String {
41        match self {
42            SqlTypeName::Int8 => "INT8".to_string(),
43            SqlTypeName::Int16 => "INT16".to_string(),
44            SqlTypeName::Int => "INT".to_string(),
45            SqlTypeName::Int64 => "INT64".to_string(),
46            SqlTypeName::Int128 => "INT128".to_string(),
47            SqlTypeName::UInt8 => "UINT8".to_string(),
48            SqlTypeName::UInt16 => "UINT16".to_string(),
49            SqlTypeName::UInt => "UINT".to_string(),
50            SqlTypeName::UInt64 => "UINT64".to_string(),
51            SqlTypeName::UInt128 => "UINT128".to_string(),
52            SqlTypeName::Text => "TEXT".to_string(),
53            SqlTypeName::Char(length) => format!("CHAR({length})"),
54            SqlTypeName::VarChar(length) => format!("VARCHAR({length})"),
55            SqlTypeName::Binary(length) => format!("BINARY({length})"),
56            SqlTypeName::VarBinary(length) => format!("VARBINARY({length})"),
57            SqlTypeName::Enum(values) => format!(
58                "ENUM({})",
59                values
60                    .iter()
61                    .map(|value| format!("'{}'", value.replace('\'', "''")))
62                    .collect::<Vec<_>>()
63                    .join(", ")
64            ),
65            SqlTypeName::Boolean => "BOOLEAN".to_string(),
66            SqlTypeName::Float32 => "FLOAT32".to_string(),
67            SqlTypeName::Float => "FLOAT".to_string(),
68            SqlTypeName::Decimal { precision, scale } => {
69                format_numeric_type("DECIMAL", *precision, *scale)
70            }
71            SqlTypeName::Blob => "BLOB".to_string(),
72            SqlTypeName::Date => "DATE".to_string(),
73            SqlTypeName::Time => "TIME".to_string(),
74            SqlTypeName::DateTime => "DATETIME".to_string(),
75            SqlTypeName::TimeWithTimeZone => "TIME WITH TIME ZONE".to_string(),
76            SqlTypeName::IntervalYearMonth => "INTERVAL YEAR TO MONTH".to_string(),
77            SqlTypeName::IntervalDaySecond => "INTERVAL DAY TO SECOND".to_string(),
78        }
79    }
80
81    pub fn supports_text_metadata(&self) -> bool {
82        matches!(
83            self,
84            SqlTypeName::Text | SqlTypeName::Char(_) | SqlTypeName::VarChar(_)
85        )
86    }
87}
88
/// Formats a numeric type name with its optional `(precision[, scale])` suffix.
///
/// * precision and scale present -> `NAME(p, s)`
/// * only precision present      -> `NAME(p)`
/// * precision absent            -> `NAME` (any scale is ignored)
fn format_numeric_type(name: &str, precision: Option<u32>, scale: Option<u32>) -> String {
    match precision {
        // Without a precision there is nothing to parenthesise, even if a scale was given.
        None => name.to_string(),
        Some(precision) => match scale {
            Some(scale) => format!("{name}({precision}, {scale})"),
            None => format!("{name}({precision})"),
        },
    }
}
96
/// A literal value as it appears in parsed SQL.
///
/// Equality and ordering are provided by hand-written `PartialEq` /
/// `PartialOrd` impls rather than derives.
#[derive(Clone, Debug)]
pub enum LiteralValue {
    /// Integer literal.
    Integer(i128),
    /// Text literal.
    Text(String),
    /// Binary literal as raw bytes.
    Blob(Vec<u8>),
    /// Boolean literal.
    Boolean(bool),
    /// Float literal kept in textual form.
    // NOTE(review): comparisons appear to assume the text was produced by
    // `normalize_float_literal` — confirm at the construction sites.
    Float(String),
    /// SQL `NULL`.
    Null,
}
106
107impl LiteralValue {
108    pub fn data_type(&self) -> SqlTypeName {
109        match self {
110            LiteralValue::Integer(_) => SqlTypeName::Int,
111            LiteralValue::Text(_) => SqlTypeName::Text,
112            LiteralValue::Blob(_) => SqlTypeName::Blob,
113            LiteralValue::Boolean(_) => SqlTypeName::Boolean,
114            LiteralValue::Float(_) => SqlTypeName::Float,
115            LiteralValue::Null => SqlTypeName::Text,
116        }
117    }
118
119    pub fn is_compatible_with(&self, data_type: SqlTypeName) -> bool {
120        match (self, data_type) {
121            (LiteralValue::Integer(_), SqlTypeName::Int8) => true,
122            (LiteralValue::Integer(_), SqlTypeName::Int16) => true,
123            (LiteralValue::Integer(_), SqlTypeName::Int) => true,
124            (LiteralValue::Integer(_), SqlTypeName::Int64) => true,
125            (LiteralValue::Integer(_), SqlTypeName::Int128) => true,
126            (LiteralValue::Integer(value), SqlTypeName::UInt8) => *value >= 0,
127            (LiteralValue::Integer(value), SqlTypeName::UInt16) => *value >= 0,
128            (LiteralValue::Integer(value), SqlTypeName::UInt) => *value >= 0,
129            (LiteralValue::Integer(value), SqlTypeName::UInt64) => *value >= 0,
130            (LiteralValue::Integer(value), SqlTypeName::UInt128) => *value >= 0,
131            (LiteralValue::Integer(_), SqlTypeName::Decimal { .. }) => true,
132            (LiteralValue::Float(_), SqlTypeName::Float32) => true,
133            (LiteralValue::Float(_), SqlTypeName::Float) => true,
134            (LiteralValue::Float(_), SqlTypeName::Decimal { .. }) => true,
135            (LiteralValue::Text(_), SqlTypeName::Text) => true,
136            (LiteralValue::Text(_), SqlTypeName::Char(_)) => true,
137            (LiteralValue::Text(_), SqlTypeName::VarChar(_)) => true,
138            (LiteralValue::Text(_), SqlTypeName::Binary(_)) => true,
139            (LiteralValue::Text(_), SqlTypeName::VarBinary(_)) => true,
140            (LiteralValue::Text(_), SqlTypeName::Enum(_)) => true,
141            (LiteralValue::Text(_), SqlTypeName::Blob) => true,
142            (LiteralValue::Text(_), SqlTypeName::Date) => true,
143            (LiteralValue::Text(_), SqlTypeName::Time) => true,
144            (LiteralValue::Text(_), SqlTypeName::DateTime) => true,
145            (LiteralValue::Text(_), SqlTypeName::TimeWithTimeZone) => true,
146            (LiteralValue::Text(_), SqlTypeName::Decimal { .. }) => true,
147            (LiteralValue::Blob(_), SqlTypeName::Binary(_)) => true,
148            (LiteralValue::Blob(_), SqlTypeName::VarBinary(_)) => true,
149            (LiteralValue::Blob(_), SqlTypeName::Blob) => true,
150            (LiteralValue::Boolean(_), SqlTypeName::Boolean) => true,
151            (LiteralValue::Null, _) => true,
152            _ => false,
153        }
154    }
155
156    pub fn is_null(&self) -> bool {
157        matches!(self, LiteralValue::Null)
158    }
159}
160
161impl PartialEq for LiteralValue {
162    fn eq(&self, other: &Self) -> bool {
163        match (self, other) {
164            (LiteralValue::Integer(a), LiteralValue::Integer(b)) => a == b,
165            (LiteralValue::Text(a), LiteralValue::Text(b)) => a == b,
166            (LiteralValue::Blob(a), LiteralValue::Blob(b)) => a == b,
167            (LiteralValue::Boolean(a), LiteralValue::Boolean(b)) => a == b,
168            (LiteralValue::Float(a), LiteralValue::Float(b)) => a == b,
169            (LiteralValue::Null, LiteralValue::Null) => true,
170            _ => false,
171        }
172    }
173}
174
175impl Eq for LiteralValue {}
176
177impl PartialOrd for LiteralValue {
178    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
179        match (self, other) {
180            (LiteralValue::Integer(a), LiteralValue::Integer(b)) => a.partial_cmp(b),
181            (LiteralValue::Text(a), LiteralValue::Text(b)) => a.partial_cmp(b),
182            (LiteralValue::Blob(a), LiteralValue::Blob(b)) => a.partial_cmp(b),
183            (LiteralValue::Boolean(a), LiteralValue::Boolean(b)) => a.partial_cmp(b),
184            (LiteralValue::Float(a), LiteralValue::Float(b)) => {
185                compare_normalized_float_literals(a, b)
186            }
187            (LiteralValue::Null, _) => Some(Ordering::Less),
188            (_, LiteralValue::Null) => Some(Ordering::Greater),
189            _ => None,
190        }
191    }
192}
193
/// Rewrites a decimal literal into a canonical textual form.
///
/// Steps, in order:
/// 1. surrounding whitespace is trimmed;
/// 2. one leading `+` or `-` is consumed;
/// 3. the text before the first `.` loses its leading zeros (becoming `0` if
///    nothing is left);
/// 4. the text between the first and second `.` loses its trailing zeros, and
///    anything after a second `.` is dropped;
/// 5. the `.` is omitted when no fraction digits remain;
/// 6. a value that reduces to `0` never carries a sign.
///
/// The input is not validated; non-digit characters are passed through as-is.
pub fn normalize_float_literal(input: &str) -> String {
    let text = input.trim();
    let (is_negative, unsigned) = if let Some(rest) = text.strip_prefix('-') {
        (true, rest)
    } else if let Some(rest) = text.strip_prefix('+') {
        (false, rest)
    } else {
        (false, text)
    };

    let mut pieces = unsigned.split('.');
    let whole = match pieces.next().unwrap_or_default().trim_start_matches('0') {
        "" => "0",
        digits => digits,
    };
    let fraction = pieces.next().unwrap_or_default().trim_end_matches('0');

    // Zero is emitted unsigned so "-0.0" and "0" normalise to the same text.
    let is_zero = whole == "0" && fraction.is_empty();

    let mut normalized = String::with_capacity(whole.len() + fraction.len() + 2);
    if is_negative && !is_zero {
        normalized.push('-');
    }
    normalized.push_str(whole);
    if !fraction.is_empty() {
        normalized.push('.');
        normalized.push_str(fraction);
    }
    normalized
}
224
225fn compare_normalized_float_literals(left: &str, right: &str) -> Option<Ordering> {
226    let (left_negative, left_integer, left_fraction) = split_float_literal(left)?;
227    let (right_negative, right_integer, right_fraction) = split_float_literal(right)?;
228
229    if left_negative != right_negative {
230        return Some(if left_negative {
231            Ordering::Less
232        } else {
233            Ordering::Greater
234        });
235    }
236
237    let ordering = left_integer
238        .len()
239        .cmp(&right_integer.len())
240        .then_with(|| left_integer.cmp(right_integer))
241        .then_with(|| {
242            let len = left_fraction.len().max(right_fraction.len());
243            let left_padded = format!("{left_fraction:0<len$}");
244            let right_padded = format!("{right_fraction:0<len$}");
245            left_padded.cmp(&right_padded)
246        });
247
248    Some(if left_negative {
249        ordering.reverse()
250    } else {
251        ordering
252    })
253}
254
/// Splits a decimal literal into `(is_negative, integer_text, fraction_text)`.
///
/// Only a leading `-` is treated as a sign; a leading `+` stays in the integer
/// text. The integer text is everything before the first `.`, the fraction
/// text is what lies between the first and second `.` (empty if there is no
/// `.`), and anything after a second `.` is ignored.
fn split_float_literal(input: &str) -> Option<(bool, &str, &str)> {
    let (is_negative, unsigned) = match input.strip_prefix('-') {
        Some(rest) => (true, rest),
        None => (false, input),
    };

    let mut pieces = unsigned.split('.');
    let integer = pieces.next()?;
    let fraction = pieces.next().unwrap_or("");
    Some((is_negative, integer, fraction))
}