// nodedb_sql/resolver/expr/convert.rs

// SPDX-License-Identifier: Apache-2.0

//! Convert sqlparser AST expressions to our `SqlExpr` IR.

5use sqlparser::ast::{self, Expr, Value};
6
7use crate::error::{Result, SqlError};
8use crate::parser::normalize::{SCHEMA_QUALIFIED_MSG, normalize_ident};
9use crate::types::*;
10
11use super::binary_ops::{convert_binary_op, convert_unary_op};
12use super::functions::convert_function_depth;
13use super::value::{convert_value, parse_interval_to_micros};
14
/// Upper bound on how many nested `convert_expr_depth` frames may be active
/// at once while lowering a single sqlparser AST.
///
/// Deeper inputs are rejected with an `Err` rather than risking a stack
/// overflow.
const MAX_CONVERT_DEPTH: usize = 128;
18
/// Reports whether `name` is one of the SQL-standard niladic functions that
/// are written without parentheses (`SELECT current_timestamp`,
/// `SELECT current_user`, ...).
///
/// The parser hands these over as bare identifiers. Recognising them lets the
/// caller promote them to zero-argument calls, so they fold to a value at
/// plan time instead of being resolved as a column. Matching is exact and
/// case-sensitive, so callers pass the already-normalized identifier.
fn is_zero_arg_keyword_function(name: &str) -> bool {
    const NILADIC_KEYWORDS: [&str; 10] = [
        "current_timestamp",
        "current_date",
        "current_time",
        "localtime",
        "localtimestamp",
        "current_user",
        "current_role",
        "current_schema",
        "session_user",
        "user",
    ];
    NILADIC_KEYWORDS.contains(&name)
}
37
38/// Convert a sqlparser `Expr` to our `SqlExpr`.
39pub fn convert_expr(expr: &Expr) -> Result<SqlExpr> {
40    convert_expr_depth(expr, &mut 0)
41}
42
43/// Internal recursive helper that carries a depth counter to enforce
44/// `MAX_CONVERT_DEPTH` and prevent stack overflow on malformed ASTs.
45pub(super) fn convert_expr_depth(expr: &Expr, depth: &mut usize) -> Result<SqlExpr> {
46    *depth += 1;
47    if *depth > MAX_CONVERT_DEPTH {
48        return Err(SqlError::Unsupported {
49            detail: format!("expression nesting depth exceeds maximum of {MAX_CONVERT_DEPTH}"),
50        });
51    }
52    let result = convert_expr_inner(expr, depth);
53    *depth -= 1;
54    result
55}
56
57fn convert_expr_inner(expr: &Expr, depth: &mut usize) -> Result<SqlExpr> {
58    match expr {
59        Expr::Identifier(ident) => {
60            let name = normalize_ident(ident);
61            // SQL-standard zero-arg keyword functions parse as bare
62            // identifiers (no parentheses): `SELECT current_timestamp`,
63            // `SELECT current_user`, etc. Promote them to function calls
64            // so const folding evaluates them like the parenthesised form.
65            if is_zero_arg_keyword_function(&name) {
66                return Ok(SqlExpr::Function {
67                    name,
68                    args: vec![],
69                    distinct: false,
70                });
71            }
72            Ok(SqlExpr::Column { table: None, name })
73        }
74        Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
75            let qualified: String = parts
76                .iter()
77                .map(normalize_ident)
78                .collect::<Vec<_>>()
79                .join(".");
80            Err(SqlError::Unsupported {
81                detail: format!(
82                    "schema-qualified column reference '{qualified}': {SCHEMA_QUALIFIED_MSG}"
83                ),
84            })
85        }
86        Expr::CompoundIdentifier(parts) if parts.len() == 2 => Ok(SqlExpr::Column {
87            table: Some(normalize_ident(&parts[0])),
88            name: normalize_ident(&parts[1]),
89        }),
90        Expr::Value(val) => Ok(SqlExpr::Literal(convert_value(&val.value)?)),
91        Expr::BinaryOp { left, op, right } => {
92            // JSON and FTS operators are lowered to function calls before the
93            // generic binary-op path so they are never passed to
94            // convert_binary_op.
95            use ast::BinaryOperator;
96            let json_fn: Option<&str> = match op {
97                BinaryOperator::Arrow => Some("pg_json_get"),
98                BinaryOperator::LongArrow => Some("pg_json_get_text"),
99                BinaryOperator::HashArrow => Some("pg_json_path_get"),
100                BinaryOperator::HashLongArrow => Some("pg_json_path_get_text"),
101                BinaryOperator::AtArrow => Some("pg_json_contains"),
102                BinaryOperator::ArrowAt => Some("pg_json_contained_by"),
103                BinaryOperator::Question => Some("pg_json_has_key"),
104                BinaryOperator::QuestionAnd => Some("pg_json_has_all_keys"),
105                BinaryOperator::QuestionPipe => Some("pg_json_has_any_key"),
106                _ => None,
107            };
108            if let Some(name) = json_fn {
109                return Ok(SqlExpr::Function {
110                    name: name.into(),
111                    args: vec![
112                        convert_expr_depth(left, depth)?,
113                        convert_expr_depth(right, depth)?,
114                    ],
115                    distinct: false,
116                });
117            }
118            // `col @@ query` → pg_fts_match(col, query)
119            if matches!(op, BinaryOperator::AtAt) {
120                let col_expr = convert_expr_depth(left, depth)?;
121                let query_expr = convert_expr_depth(right, depth)?;
122                return Ok(crate::functions::fts_ops::pg_fts_funcs::lower_pg_fts_match(
123                    col_expr, query_expr,
124                ));
125            }
126            Ok(SqlExpr::BinaryOp {
127                left: Box::new(convert_expr_depth(left, depth)?),
128                op: convert_binary_op(op)?,
129                right: Box::new(convert_expr_depth(right, depth)?),
130            })
131        }
132        Expr::UnaryOp { op, expr } => Ok(SqlExpr::UnaryOp {
133            op: convert_unary_op(op)?,
134            expr: Box::new(convert_expr_depth(expr, depth)?),
135        }),
136        Expr::Function(func) => convert_function_depth(func, depth),
137        Expr::Nested(inner) => convert_expr_depth(inner, depth),
138        Expr::IsNull(inner) => Ok(SqlExpr::IsNull {
139            expr: Box::new(convert_expr_depth(inner, depth)?),
140            negated: false,
141        }),
142        Expr::IsNotNull(inner) => Ok(SqlExpr::IsNull {
143            expr: Box::new(convert_expr_depth(inner, depth)?),
144            negated: true,
145        }),
146        Expr::InList {
147            expr,
148            list,
149            negated,
150        } => Ok(SqlExpr::InList {
151            expr: Box::new(convert_expr_depth(expr, depth)?),
152            list: list
153                .iter()
154                .map(|e| convert_expr_depth(e, depth))
155                .collect::<Result<_>>()?,
156            negated: *negated,
157        }),
158        Expr::Between {
159            expr,
160            low,
161            high,
162            negated,
163        } => Ok(SqlExpr::Between {
164            expr: Box::new(convert_expr_depth(expr, depth)?),
165            low: Box::new(convert_expr_depth(low, depth)?),
166            high: Box::new(convert_expr_depth(high, depth)?),
167            negated: *negated,
168        }),
169        Expr::Like {
170            expr,
171            pattern,
172            negated,
173            ..
174        } => Ok(SqlExpr::Like {
175            expr: Box::new(convert_expr_depth(expr, depth)?),
176            pattern: Box::new(convert_expr_depth(pattern, depth)?),
177            negated: *negated,
178            case_insensitive: false,
179        }),
180        Expr::ILike {
181            expr,
182            pattern,
183            negated,
184            ..
185        } => Ok(SqlExpr::Like {
186            expr: Box::new(convert_expr_depth(expr, depth)?),
187            pattern: Box::new(convert_expr_depth(pattern, depth)?),
188            negated: *negated,
189            case_insensitive: true,
190        }),
191        Expr::Case {
192            operand,
193            conditions,
194            else_result,
195            ..
196        } => {
197            let when_then = conditions
198                .iter()
199                .map(|cw| {
200                    Ok((
201                        convert_expr_depth(&cw.condition, depth)?,
202                        convert_expr_depth(&cw.result, depth)?,
203                    ))
204                })
205                .collect::<Result<Vec<_>>>()?;
206            Ok(SqlExpr::Case {
207                operand: operand
208                    .as_ref()
209                    .map(|e| convert_expr_depth(e, depth).map(Box::new))
210                    .transpose()?,
211                when_then,
212                else_expr: else_result
213                    .as_ref()
214                    .map(|e| convert_expr_depth(e, depth).map(Box::new))
215                    .transpose()?,
216            })
217        }
218        Expr::TypedString(ts) => {
219            // TIMESTAMP '...' and TIMESTAMPTZ '...' typed string literals.
220            let type_str = format!("{}", ts.data_type).to_ascii_uppercase();
221            let raw = match &ts.value.value {
222                Value::SingleQuotedString(s) => s.clone(),
223                other => {
224                    return Err(SqlError::Unsupported {
225                        detail: format!("typed string value: {other}"),
226                    });
227                }
228            };
229            match type_str.as_str() {
230                "TIMESTAMP" => {
231                    let dt =
232                        nodedb_types::NdbDateTime::parse(&raw).ok_or_else(|| SqlError::Parse {
233                            detail: format!("cannot parse TIMESTAMP literal: '{raw}'"),
234                        })?;
235                    return Ok(SqlExpr::Literal(SqlValue::Timestamp(dt)));
236                }
237                "TIMESTAMPTZ" | "TIMESTAMP WITH TIME ZONE" => {
238                    let dt =
239                        nodedb_types::NdbDateTime::parse(&raw).ok_or_else(|| SqlError::Parse {
240                            detail: format!("cannot parse TIMESTAMPTZ literal: '{raw}'"),
241                        })?;
242                    return Ok(SqlExpr::Literal(SqlValue::Timestamptz(dt)));
243                }
244                _ => {}
245            }
246            // Fall through: return as a generic literal string.
247            Ok(SqlExpr::Literal(SqlValue::String(raw)))
248        }
249        Expr::Cast {
250            expr, data_type, ..
251        } => {
252            // `::tsvector` and `::tsquery` casts are PG surface notation; the
253            // inner expression is the actual text value.  Elide the cast and
254            // return the inner expression directly — no runtime type change is
255            // needed since we operate on plain strings internally.
256            let type_str = format!("{data_type}").to_ascii_lowercase();
257            if type_str == "tsvector" || type_str == "tsquery" {
258                return convert_expr_depth(expr, depth);
259            }
260            // `'...'::TIMESTAMP` and `'...'::TIMESTAMPTZ` — promote string literals
261            // to typed SqlValue when the inner expression is a string literal.
262            let upper = type_str.to_uppercase();
263            if (upper == "TIMESTAMP"
264                || upper == "TIMESTAMPTZ"
265                || upper == "TIMESTAMP WITH TIME ZONE")
266                && let Expr::Value(v) = expr.as_ref()
267                && let Value::SingleQuotedString(s) = &v.value
268            {
269                let dt = nodedb_types::NdbDateTime::parse(s).ok_or_else(|| SqlError::Parse {
270                    detail: format!("cannot parse timestamp cast: '{s}'"),
271                })?;
272                return Ok(SqlExpr::Literal(if upper == "TIMESTAMP" {
273                    SqlValue::Timestamp(dt)
274                } else {
275                    SqlValue::Timestamptz(dt)
276                }));
277            }
278            Ok(SqlExpr::Cast {
279                expr: Box::new(convert_expr_depth(expr, depth)?),
280                to_type: format!("{data_type}"),
281            })
282        }
283        Expr::Array(ast::Array { elem, .. }) => {
284            let elems = elem
285                .iter()
286                .map(|e| convert_expr_depth(e, depth))
287                .collect::<Result<_>>()?;
288            Ok(SqlExpr::ArrayLiteral(elems))
289        }
290        Expr::Wildcard(_) => Ok(SqlExpr::Wildcard),
291        // TRIM([BOTH|LEADING|TRAILING] [what FROM] expr)
292        Expr::Trim { expr, .. } => Ok(SqlExpr::Function {
293            name: "trim".into(),
294            args: vec![convert_expr_depth(expr, depth)?],
295            distinct: false,
296        }),
297        // CEIL(expr) / FLOOR(expr)
298        Expr::Ceil { expr, .. } => Ok(SqlExpr::Function {
299            name: "ceil".into(),
300            args: vec![convert_expr_depth(expr, depth)?],
301            distinct: false,
302        }),
303        Expr::Floor { expr, .. } => Ok(SqlExpr::Function {
304            name: "floor".into(),
305            args: vec![convert_expr_depth(expr, depth)?],
306            distinct: false,
307        }),
308        // SUBSTRING(expr FROM start FOR len)
309        Expr::Substring {
310            expr,
311            substring_from,
312            substring_for,
313            ..
314        } => {
315            let mut args = vec![convert_expr_depth(expr, depth)?];
316            if let Some(from) = substring_from {
317                args.push(convert_expr_depth(from, depth)?);
318            }
319            if let Some(len) = substring_for {
320                args.push(convert_expr_depth(len, depth)?);
321            }
322            Ok(SqlExpr::Function {
323                name: "substring".into(),
324                args,
325                distinct: false,
326            })
327        }
328        Expr::Interval(interval) => {
329            // INTERVAL '1 hour' → microseconds as i64 literal.
330            // The interval value is typically a string literal.
331            let interval_str = match interval.value.as_ref() {
332                Expr::Value(v) => match &v.value {
333                    Value::SingleQuotedString(s) => s.clone(),
334                    Value::Number(n, _) => {
335                        // INTERVAL 5 HOUR → combine number with leading_field.
336                        if let Some(ref field) = interval.leading_field {
337                            format!("{n} {field}")
338                        } else {
339                            n.clone()
340                        }
341                    }
342                    _ => {
343                        return Err(SqlError::Unsupported {
344                            detail: format!("INTERVAL value: {}", interval.value),
345                        });
346                    }
347                },
348                _ => {
349                    return Err(SqlError::Unsupported {
350                        detail: format!("INTERVAL expression: {}", interval.value),
351                    });
352                }
353            };
354
355            // If leading_field is specified, append it: INTERVAL '5' HOUR → "5 HOUR"
356            let full_str = if interval_str.chars().all(|c| c.is_ascii_digit())
357                && let Some(ref field) = interval.leading_field
358            {
359                format!("{interval_str} {field}")
360            } else {
361                interval_str
362            };
363
364            let micros = parse_interval_to_micros(&full_str).ok_or_else(|| SqlError::Parse {
365                detail: format!("cannot parse INTERVAL '{full_str}'"),
366            })?;
367
368            Ok(SqlExpr::Literal(SqlValue::Int(micros)))
369        }
370        _ => Err(SqlError::Unsupported {
371            detail: format!("expression: {expr}"),
372        }),
373    }
374}