Skip to main content

nodedb_sql/resolver/
expr.rs

1//! Convert sqlparser AST expressions to our SqlExpr IR.
2
3use sqlparser::ast::{self, Expr, UnaryOperator, Value};
4
5use crate::error::{Result, SqlError};
6use crate::parser::normalize::normalize_ident;
7use crate::types::*;
8
9/// Maximum AST nesting depth accepted by `convert_expr`.
10/// Exceeding this limit returns `Err` instead of overflowing the stack.
11const MAX_CONVERT_DEPTH: usize = 128;
12
13/// Convert a sqlparser `Expr` to our `SqlExpr`.
14pub fn convert_expr(expr: &Expr) -> Result<SqlExpr> {
15    convert_expr_depth(expr, &mut 0)
16}
17
18/// Internal recursive helper that carries a depth counter to enforce
19/// `MAX_CONVERT_DEPTH` and prevent stack overflow on malformed ASTs.
20fn convert_expr_depth(expr: &Expr, depth: &mut usize) -> Result<SqlExpr> {
21    *depth += 1;
22    if *depth > MAX_CONVERT_DEPTH {
23        return Err(SqlError::Unsupported {
24            detail: format!("expression nesting depth exceeds maximum of {MAX_CONVERT_DEPTH}"),
25        });
26    }
27    let result = convert_expr_inner(expr, depth);
28    *depth -= 1;
29    result
30}
31
32fn convert_expr_inner(expr: &Expr, depth: &mut usize) -> Result<SqlExpr> {
33    match expr {
34        Expr::Identifier(ident) => Ok(SqlExpr::Column {
35            table: None,
36            name: normalize_ident(ident),
37        }),
38        Expr::CompoundIdentifier(parts) if parts.len() == 2 => Ok(SqlExpr::Column {
39            table: Some(normalize_ident(&parts[0])),
40            name: normalize_ident(&parts[1]),
41        }),
42        Expr::Value(val) => Ok(SqlExpr::Literal(convert_value(&val.value)?)),
43        Expr::BinaryOp { left, op, right } => Ok(SqlExpr::BinaryOp {
44            left: Box::new(convert_expr_depth(left, depth)?),
45            op: convert_binary_op(op)?,
46            right: Box::new(convert_expr_depth(right, depth)?),
47        }),
48        Expr::UnaryOp { op, expr } => Ok(SqlExpr::UnaryOp {
49            op: convert_unary_op(op)?,
50            expr: Box::new(convert_expr_depth(expr, depth)?),
51        }),
52        Expr::Function(func) => convert_function_depth(func, depth),
53        Expr::Nested(inner) => convert_expr_depth(inner, depth),
54        Expr::IsNull(inner) => Ok(SqlExpr::IsNull {
55            expr: Box::new(convert_expr_depth(inner, depth)?),
56            negated: false,
57        }),
58        Expr::IsNotNull(inner) => Ok(SqlExpr::IsNull {
59            expr: Box::new(convert_expr_depth(inner, depth)?),
60            negated: true,
61        }),
62        Expr::InList {
63            expr,
64            list,
65            negated,
66        } => Ok(SqlExpr::InList {
67            expr: Box::new(convert_expr_depth(expr, depth)?),
68            list: list
69                .iter()
70                .map(|e| convert_expr_depth(e, depth))
71                .collect::<Result<_>>()?,
72            negated: *negated,
73        }),
74        Expr::Between {
75            expr,
76            low,
77            high,
78            negated,
79        } => Ok(SqlExpr::Between {
80            expr: Box::new(convert_expr_depth(expr, depth)?),
81            low: Box::new(convert_expr_depth(low, depth)?),
82            high: Box::new(convert_expr_depth(high, depth)?),
83            negated: *negated,
84        }),
85        Expr::Like {
86            expr,
87            pattern,
88            negated,
89            ..
90        } => Ok(SqlExpr::Like {
91            expr: Box::new(convert_expr_depth(expr, depth)?),
92            pattern: Box::new(convert_expr_depth(pattern, depth)?),
93            negated: *negated,
94        }),
95        Expr::ILike {
96            expr,
97            pattern,
98            negated,
99            ..
100        } => Ok(SqlExpr::Like {
101            expr: Box::new(convert_expr_depth(expr, depth)?),
102            pattern: Box::new(convert_expr_depth(pattern, depth)?),
103            negated: *negated,
104        }),
105        Expr::Case {
106            operand,
107            conditions,
108            else_result,
109            ..
110        } => {
111            let when_then = conditions
112                .iter()
113                .map(|cw| {
114                    Ok((
115                        convert_expr_depth(&cw.condition, depth)?,
116                        convert_expr_depth(&cw.result, depth)?,
117                    ))
118                })
119                .collect::<Result<Vec<_>>>()?;
120            Ok(SqlExpr::Case {
121                operand: operand
122                    .as_ref()
123                    .map(|e| convert_expr_depth(e, depth).map(Box::new))
124                    .transpose()?,
125                when_then,
126                else_expr: else_result
127                    .as_ref()
128                    .map(|e| convert_expr_depth(e, depth).map(Box::new))
129                    .transpose()?,
130            })
131        }
132        Expr::Cast {
133            expr, data_type, ..
134        } => Ok(SqlExpr::Cast {
135            expr: Box::new(convert_expr_depth(expr, depth)?),
136            to_type: format!("{data_type}"),
137        }),
138        Expr::Array(ast::Array { elem, .. }) => {
139            let elems = elem
140                .iter()
141                .map(|e| convert_expr_depth(e, depth))
142                .collect::<Result<_>>()?;
143            Ok(SqlExpr::ArrayLiteral(elems))
144        }
145        Expr::Wildcard(_) => Ok(SqlExpr::Wildcard),
146        // TRIM([BOTH|LEADING|TRAILING] [what FROM] expr)
147        Expr::Trim { expr, .. } => Ok(SqlExpr::Function {
148            name: "trim".into(),
149            args: vec![convert_expr_depth(expr, depth)?],
150            distinct: false,
151        }),
152        // CEIL(expr) / FLOOR(expr)
153        Expr::Ceil { expr, .. } => Ok(SqlExpr::Function {
154            name: "ceil".into(),
155            args: vec![convert_expr_depth(expr, depth)?],
156            distinct: false,
157        }),
158        Expr::Floor { expr, .. } => Ok(SqlExpr::Function {
159            name: "floor".into(),
160            args: vec![convert_expr_depth(expr, depth)?],
161            distinct: false,
162        }),
163        // SUBSTRING(expr FROM start FOR len)
164        Expr::Substring {
165            expr,
166            substring_from,
167            substring_for,
168            ..
169        } => {
170            let mut args = vec![convert_expr_depth(expr, depth)?];
171            if let Some(from) = substring_from {
172                args.push(convert_expr_depth(from, depth)?);
173            }
174            if let Some(len) = substring_for {
175                args.push(convert_expr_depth(len, depth)?);
176            }
177            Ok(SqlExpr::Function {
178                name: "substring".into(),
179                args,
180                distinct: false,
181            })
182        }
183        Expr::Interval(interval) => {
184            // INTERVAL '1 hour' → microseconds as i64 literal.
185            // The interval value is typically a string literal.
186            let interval_str = match interval.value.as_ref() {
187                Expr::Value(v) => match &v.value {
188                    Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => s.clone(),
189                    Value::Number(n, _) => {
190                        // INTERVAL 5 HOUR → combine number with leading_field.
191                        if let Some(ref field) = interval.leading_field {
192                            format!("{n} {field}")
193                        } else {
194                            n.clone()
195                        }
196                    }
197                    _ => {
198                        return Err(SqlError::Unsupported {
199                            detail: format!("INTERVAL value: {}", interval.value),
200                        });
201                    }
202                },
203                _ => {
204                    return Err(SqlError::Unsupported {
205                        detail: format!("INTERVAL expression: {}", interval.value),
206                    });
207                }
208            };
209
210            // If leading_field is specified, append it: INTERVAL '5' HOUR → "5 HOUR"
211            let full_str = if interval_str.chars().all(|c| c.is_ascii_digit())
212                && let Some(ref field) = interval.leading_field
213            {
214                format!("{interval_str} {field}")
215            } else {
216                interval_str
217            };
218
219            let micros = parse_interval_to_micros(&full_str).ok_or_else(|| SqlError::Parse {
220                detail: format!("cannot parse INTERVAL '{full_str}'"),
221            })?;
222
223            Ok(SqlExpr::Literal(SqlValue::Int(micros)))
224        }
225        _ => Err(SqlError::Unsupported {
226            detail: format!("expression: {expr}"),
227        }),
228    }
229}
230
231/// Parse an interval string to microseconds.
232///
233/// Delegates to `nodedb_types::kv_parsing::parse_interval_to_ms` (ms → μs)
234/// and `NdbDuration::parse` for compound shorthand forms.
235fn parse_interval_to_micros(s: &str) -> Option<i64> {
236    let s = s.trim();
237    if s.is_empty() {
238        return None;
239    }
240
241    // Try NdbDuration::parse first (handles compound "1h30m", "500ms", "2d").
242    if let Some(dur) = nodedb_types::NdbDuration::parse(s) {
243        return Some(dur.micros);
244    }
245
246    // Delegate to shared interval parser (handles all forms including compound).
247    if let Ok(ms) = nodedb_types::kv_parsing::parse_interval_to_ms(s) {
248        return Some(ms as i64 * 1000); // ms → μs
249    }
250
251    None
252}
253
254/// Convert a sqlparser `Value` to our `SqlValue`.
255pub fn convert_value(val: &Value) -> Result<SqlValue> {
256    match val {
257        Value::Number(n, _) => {
258            if let Ok(i) = n.parse::<i64>() {
259                Ok(SqlValue::Int(i))
260            } else if let Ok(f) = n.parse::<f64>() {
261                Ok(SqlValue::Float(f))
262            } else {
263                Ok(SqlValue::String(n.clone()))
264            }
265        }
266        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => {
267            Ok(SqlValue::String(s.clone()))
268        }
269        Value::Boolean(b) => Ok(SqlValue::Bool(*b)),
270        Value::Null => Ok(SqlValue::Null),
271        _ => Err(SqlError::Unsupported {
272            detail: format!("value literal: {val}"),
273        }),
274    }
275}
276
277fn convert_function_depth(func: &ast::Function, depth: &mut usize) -> Result<SqlExpr> {
278    let name = func
279        .name
280        .0
281        .iter()
282        .map(|p| match p {
283            ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
284            _ => String::new(),
285        })
286        .collect::<Vec<_>>()
287        .join(".");
288
289    let args = match &func.args {
290        ast::FunctionArguments::None => Vec::new(),
291        ast::FunctionArguments::Subquery(_) => {
292            return Err(SqlError::Unsupported {
293                detail: "subquery in function args".into(),
294            });
295        }
296        ast::FunctionArguments::List(arg_list) => arg_list
297            .args
298            .iter()
299            .filter_map(|a| match a {
300                ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => {
301                    Some(convert_expr_depth(e, depth))
302                }
303                ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard) => {
304                    Some(Ok(SqlExpr::Wildcard))
305                }
306                ast::FunctionArg::Named {
307                    arg: ast::FunctionArgExpr::Expr(e),
308                    ..
309                } => Some(convert_expr_depth(e, depth)),
310                _ => None,
311            })
312            .collect::<Result<Vec<_>>>()?,
313    };
314
315    let distinct = match &func.args {
316        ast::FunctionArguments::List(arg_list) => {
317            matches!(
318                arg_list.duplicate_treatment,
319                Some(ast::DuplicateTreatment::Distinct)
320            )
321        }
322        _ => false,
323    };
324
325    Ok(SqlExpr::Function {
326        name,
327        args,
328        distinct,
329    })
330}
331
332fn convert_binary_op(op: &ast::BinaryOperator) -> Result<BinaryOp> {
333    match op {
334        ast::BinaryOperator::Plus => Ok(BinaryOp::Add),
335        ast::BinaryOperator::Minus => Ok(BinaryOp::Sub),
336        ast::BinaryOperator::Multiply => Ok(BinaryOp::Mul),
337        ast::BinaryOperator::Divide => Ok(BinaryOp::Div),
338        ast::BinaryOperator::Modulo => Ok(BinaryOp::Mod),
339        ast::BinaryOperator::Eq => Ok(BinaryOp::Eq),
340        ast::BinaryOperator::NotEq => Ok(BinaryOp::Ne),
341        ast::BinaryOperator::Gt => Ok(BinaryOp::Gt),
342        ast::BinaryOperator::GtEq => Ok(BinaryOp::Ge),
343        ast::BinaryOperator::Lt => Ok(BinaryOp::Lt),
344        ast::BinaryOperator::LtEq => Ok(BinaryOp::Le),
345        ast::BinaryOperator::And => Ok(BinaryOp::And),
346        ast::BinaryOperator::Or => Ok(BinaryOp::Or),
347        ast::BinaryOperator::StringConcat => Ok(BinaryOp::Concat),
348        _ => Err(SqlError::Unsupported {
349            detail: format!("binary operator: {op}"),
350        }),
351    }
352}
353
354fn convert_unary_op(op: &UnaryOperator) -> Result<UnaryOp> {
355    match op {
356        UnaryOperator::Minus => Ok(UnaryOp::Neg),
357        UnaryOperator::Not => Ok(UnaryOp::Not),
358        _ => Err(SqlError::Unsupported {
359            detail: format!("unary operator: {op}"),
360        }),
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every `(input, expected)` pair parses to `expected`.
    fn check(cases: &[(&str, Option<i64>)]) {
        for &(input, expected) in cases {
            assert_eq!(
                parse_interval_to_micros(input),
                expected,
                "input: {input:?}"
            );
        }
    }

    #[test]
    fn parse_interval_sql_word_forms() {
        check(&[
            ("1 hour", Some(3_600_000_000)),
            ("5 days", Some(5 * 86_400_000_000)),
            ("30 minutes", Some(30 * 60_000_000)),
            ("2 hours 30 minutes", Some(9_000_000_000)),
            ("1 week", Some(604_800_000_000)),
            ("100 milliseconds", Some(100_000)),
        ]);
    }

    #[test]
    fn parse_interval_shorthand() {
        check(&[
            ("1h", Some(3_600_000_000)),
            ("30m", Some(30 * 60_000_000)),
            ("1h30m", Some(5_400_000_000)),
            ("500ms", Some(500_000)),
        ]);
    }

    #[test]
    fn parse_interval_invalid() {
        check(&[("", None), ("abc", None)]);
    }
}