// nodedb_sql/resolver/expr.rs

//! Convert sqlparser AST expressions into our `SqlExpr` IR.
2
3use sqlparser::ast::{self, Expr, UnaryOperator, Value};
4
5use crate::error::{Result, SqlError};
6use crate::parser::normalize::normalize_ident;
7use crate::types::*;
8
9/// Convert a sqlparser `Expr` to our `SqlExpr`.
10pub fn convert_expr(expr: &Expr) -> Result<SqlExpr> {
11    match expr {
12        Expr::Identifier(ident) => Ok(SqlExpr::Column {
13            table: None,
14            name: normalize_ident(ident),
15        }),
16        Expr::CompoundIdentifier(parts) if parts.len() == 2 => Ok(SqlExpr::Column {
17            table: Some(normalize_ident(&parts[0])),
18            name: normalize_ident(&parts[1]),
19        }),
20        Expr::Value(val) => Ok(SqlExpr::Literal(convert_value(&val.value)?)),
21        Expr::BinaryOp { left, op, right } => Ok(SqlExpr::BinaryOp {
22            left: Box::new(convert_expr(left)?),
23            op: convert_binary_op(op)?,
24            right: Box::new(convert_expr(right)?),
25        }),
26        Expr::UnaryOp { op, expr } => Ok(SqlExpr::UnaryOp {
27            op: convert_unary_op(op)?,
28            expr: Box::new(convert_expr(expr)?),
29        }),
30        Expr::Function(func) => convert_function(func),
31        Expr::Nested(inner) => convert_expr(inner),
32        Expr::IsNull(inner) => Ok(SqlExpr::IsNull {
33            expr: Box::new(convert_expr(inner)?),
34            negated: false,
35        }),
36        Expr::IsNotNull(inner) => Ok(SqlExpr::IsNull {
37            expr: Box::new(convert_expr(inner)?),
38            negated: true,
39        }),
40        Expr::InList {
41            expr,
42            list,
43            negated,
44        } => Ok(SqlExpr::InList {
45            expr: Box::new(convert_expr(expr)?),
46            list: list.iter().map(convert_expr).collect::<Result<_>>()?,
47            negated: *negated,
48        }),
49        Expr::Between {
50            expr,
51            low,
52            high,
53            negated,
54        } => Ok(SqlExpr::Between {
55            expr: Box::new(convert_expr(expr)?),
56            low: Box::new(convert_expr(low)?),
57            high: Box::new(convert_expr(high)?),
58            negated: *negated,
59        }),
60        Expr::Like {
61            expr,
62            pattern,
63            negated,
64            ..
65        } => Ok(SqlExpr::Like {
66            expr: Box::new(convert_expr(expr)?),
67            pattern: Box::new(convert_expr(pattern)?),
68            negated: *negated,
69        }),
70        Expr::ILike {
71            expr,
72            pattern,
73            negated,
74            ..
75        } => Ok(SqlExpr::Like {
76            expr: Box::new(convert_expr(expr)?),
77            pattern: Box::new(convert_expr(pattern)?),
78            negated: *negated,
79        }),
80        Expr::Case {
81            operand,
82            conditions,
83            else_result,
84            ..
85        } => {
86            let when_then = conditions
87                .iter()
88                .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
89                .collect::<Result<Vec<_>>>()?;
90            Ok(SqlExpr::Case {
91                operand: operand
92                    .as_ref()
93                    .map(|e| convert_expr(e).map(Box::new))
94                    .transpose()?,
95                when_then,
96                else_expr: else_result
97                    .as_ref()
98                    .map(|e| convert_expr(e).map(Box::new))
99                    .transpose()?,
100            })
101        }
102        Expr::Cast {
103            expr, data_type, ..
104        } => Ok(SqlExpr::Cast {
105            expr: Box::new(convert_expr(expr)?),
106            to_type: format!("{data_type}"),
107        }),
108        Expr::Array(ast::Array { elem, .. }) => {
109            let elems = elem.iter().map(convert_expr).collect::<Result<_>>()?;
110            Ok(SqlExpr::ArrayLiteral(elems))
111        }
112        Expr::Wildcard(_) => Ok(SqlExpr::Wildcard),
113        // TRIM([BOTH|LEADING|TRAILING] [what FROM] expr)
114        Expr::Trim { expr, .. } => Ok(SqlExpr::Function {
115            name: "trim".into(),
116            args: vec![convert_expr(expr)?],
117            distinct: false,
118        }),
119        // CEIL(expr) / FLOOR(expr)
120        Expr::Ceil { expr, .. } => Ok(SqlExpr::Function {
121            name: "ceil".into(),
122            args: vec![convert_expr(expr)?],
123            distinct: false,
124        }),
125        Expr::Floor { expr, .. } => Ok(SqlExpr::Function {
126            name: "floor".into(),
127            args: vec![convert_expr(expr)?],
128            distinct: false,
129        }),
130        // SUBSTRING(expr FROM start FOR len)
131        Expr::Substring {
132            expr,
133            substring_from,
134            substring_for,
135            ..
136        } => {
137            let mut args = vec![convert_expr(expr)?];
138            if let Some(from) = substring_from {
139                args.push(convert_expr(from)?);
140            }
141            if let Some(len) = substring_for {
142                args.push(convert_expr(len)?);
143            }
144            Ok(SqlExpr::Function {
145                name: "substring".into(),
146                args,
147                distinct: false,
148            })
149        }
150        Expr::Interval(interval) => {
151            // INTERVAL '1 hour' → microseconds as i64 literal.
152            // The interval value is typically a string literal.
153            let interval_str = match interval.value.as_ref() {
154                Expr::Value(v) => match &v.value {
155                    Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => s.clone(),
156                    Value::Number(n, _) => {
157                        // INTERVAL 5 HOUR → combine number with leading_field.
158                        if let Some(ref field) = interval.leading_field {
159                            format!("{n} {field}")
160                        } else {
161                            n.clone()
162                        }
163                    }
164                    _ => {
165                        return Err(SqlError::Unsupported {
166                            detail: format!("INTERVAL value: {}", interval.value),
167                        });
168                    }
169                },
170                _ => {
171                    return Err(SqlError::Unsupported {
172                        detail: format!("INTERVAL expression: {}", interval.value),
173                    });
174                }
175            };
176
177            // If leading_field is specified, append it: INTERVAL '5' HOUR → "5 HOUR"
178            let full_str = if interval_str.chars().all(|c| c.is_ascii_digit())
179                && let Some(ref field) = interval.leading_field
180            {
181                format!("{interval_str} {field}")
182            } else {
183                interval_str
184            };
185
186            let micros = parse_interval_to_micros(&full_str).ok_or_else(|| SqlError::Parse {
187                detail: format!("cannot parse INTERVAL '{full_str}'"),
188            })?;
189
190            Ok(SqlExpr::Literal(SqlValue::Int(micros)))
191        }
192        _ => Err(SqlError::Unsupported {
193            detail: format!("expression: {expr}"),
194        }),
195    }
196}
197
198/// Parse an interval string to microseconds.
199///
200/// Delegates to `nodedb_types::kv_parsing::parse_interval_to_ms` (ms → μs)
201/// and `NdbDuration::parse` for compound shorthand forms.
202fn parse_interval_to_micros(s: &str) -> Option<i64> {
203    let s = s.trim();
204    if s.is_empty() {
205        return None;
206    }
207
208    // Try NdbDuration::parse first (handles compound "1h30m", "500ms", "2d").
209    if let Some(dur) = nodedb_types::NdbDuration::parse(s) {
210        return Some(dur.micros);
211    }
212
213    // Delegate to shared interval parser (handles all forms including compound).
214    if let Ok(ms) = nodedb_types::kv_parsing::parse_interval_to_ms(s) {
215        return Some(ms as i64 * 1000); // ms → μs
216    }
217
218    None
219}
220
221/// Convert a sqlparser `Value` to our `SqlValue`.
222pub fn convert_value(val: &Value) -> Result<SqlValue> {
223    match val {
224        Value::Number(n, _) => {
225            if let Ok(i) = n.parse::<i64>() {
226                Ok(SqlValue::Int(i))
227            } else if let Ok(f) = n.parse::<f64>() {
228                Ok(SqlValue::Float(f))
229            } else {
230                Ok(SqlValue::String(n.clone()))
231            }
232        }
233        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => {
234            Ok(SqlValue::String(s.clone()))
235        }
236        Value::Boolean(b) => Ok(SqlValue::Bool(*b)),
237        Value::Null => Ok(SqlValue::Null),
238        _ => Err(SqlError::Unsupported {
239            detail: format!("value literal: {val}"),
240        }),
241    }
242}
243
244fn convert_function(func: &ast::Function) -> Result<SqlExpr> {
245    let name = func
246        .name
247        .0
248        .iter()
249        .map(|p| match p {
250            ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
251            _ => String::new(),
252        })
253        .collect::<Vec<_>>()
254        .join(".");
255
256    let args = match &func.args {
257        ast::FunctionArguments::None => Vec::new(),
258        ast::FunctionArguments::Subquery(_) => {
259            return Err(SqlError::Unsupported {
260                detail: "subquery in function args".into(),
261            });
262        }
263        ast::FunctionArguments::List(arg_list) => arg_list
264            .args
265            .iter()
266            .filter_map(|a| match a {
267                ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => Some(convert_expr(e)),
268                ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard) => {
269                    Some(Ok(SqlExpr::Wildcard))
270                }
271                ast::FunctionArg::Named {
272                    arg: ast::FunctionArgExpr::Expr(e),
273                    ..
274                } => Some(convert_expr(e)),
275                _ => None,
276            })
277            .collect::<Result<Vec<_>>>()?,
278    };
279
280    let distinct = match &func.args {
281        ast::FunctionArguments::List(arg_list) => {
282            matches!(
283                arg_list.duplicate_treatment,
284                Some(ast::DuplicateTreatment::Distinct)
285            )
286        }
287        _ => false,
288    };
289
290    Ok(SqlExpr::Function {
291        name,
292        args,
293        distinct,
294    })
295}
296
297fn convert_binary_op(op: &ast::BinaryOperator) -> Result<BinaryOp> {
298    match op {
299        ast::BinaryOperator::Plus => Ok(BinaryOp::Add),
300        ast::BinaryOperator::Minus => Ok(BinaryOp::Sub),
301        ast::BinaryOperator::Multiply => Ok(BinaryOp::Mul),
302        ast::BinaryOperator::Divide => Ok(BinaryOp::Div),
303        ast::BinaryOperator::Modulo => Ok(BinaryOp::Mod),
304        ast::BinaryOperator::Eq => Ok(BinaryOp::Eq),
305        ast::BinaryOperator::NotEq => Ok(BinaryOp::Ne),
306        ast::BinaryOperator::Gt => Ok(BinaryOp::Gt),
307        ast::BinaryOperator::GtEq => Ok(BinaryOp::Ge),
308        ast::BinaryOperator::Lt => Ok(BinaryOp::Lt),
309        ast::BinaryOperator::LtEq => Ok(BinaryOp::Le),
310        ast::BinaryOperator::And => Ok(BinaryOp::And),
311        ast::BinaryOperator::Or => Ok(BinaryOp::Or),
312        ast::BinaryOperator::StringConcat => Ok(BinaryOp::Concat),
313        _ => Err(SqlError::Unsupported {
314            detail: format!("binary operator: {op}"),
315        }),
316    }
317}
318
319fn convert_unary_op(op: &UnaryOperator) -> Result<UnaryOp> {
320    match op {
321        UnaryOperator::Minus => Ok(UnaryOp::Neg),
322        UnaryOperator::Not => Ok(UnaryOp::Not),
323        _ => Err(SqlError::Unsupported {
324            detail: format!("unary operator: {op}"),
325        }),
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    // Microsecond unit constants used to spell out expected values.
    const MS: i64 = 1_000;
    const SEC: i64 = 1_000 * MS;
    const MIN: i64 = 60 * SEC;
    const HOUR: i64 = 60 * MIN;
    const DAY: i64 = 24 * HOUR;
    const WEEK: i64 = 7 * DAY;

    /// Assert `parse_interval_to_micros` on every `(input, expected)` pair,
    /// naming the failing input in the panic message.
    fn check(cases: &[(&str, Option<i64>)]) {
        for &(input, expected) in cases {
            assert_eq!(parse_interval_to_micros(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn parse_interval_sql_word_forms() {
        check(&[
            ("1 hour", Some(HOUR)),
            ("5 days", Some(5 * DAY)),
            ("30 minutes", Some(30 * MIN)),
            ("2 hours 30 minutes", Some(2 * HOUR + 30 * MIN)),
            ("1 week", Some(WEEK)),
            ("100 milliseconds", Some(100 * MS)),
        ]);
    }

    #[test]
    fn parse_interval_shorthand() {
        check(&[
            ("1h", Some(HOUR)),
            ("30m", Some(30 * MIN)),
            ("1h30m", Some(HOUR + 30 * MIN)),
            ("500ms", Some(500 * MS)),
        ]);
    }

    #[test]
    fn parse_interval_invalid() {
        check(&[("", None), ("abc", None)]);
    }
}