1use sqlparser::ast::{self, Expr, Value};
6
7use crate::error::{Result, SqlError};
8use crate::parser::normalize::{SCHEMA_QUALIFIED_MSG, normalize_ident};
9use crate::types::*;
10
11use super::binary_ops::{convert_binary_op, convert_unary_op};
12use super::functions::convert_function_depth;
13use super::value::{convert_value, parse_interval_to_micros};
14
/// Maximum expression nesting depth accepted by [`convert_expr_depth`].
///
/// Expression trees nested deeper than this are rejected with
/// `SqlError::Unsupported` instead of being recursed into further.
const MAX_CONVERT_DEPTH: usize = 128;
18
/// Returns `true` when `name` is a SQL keyword function that is written
/// without parentheses (for example `SELECT current_date`).
///
/// Such functions reach the converter as plain identifiers, not as
/// function calls. The comparison is exact and case-sensitive. Callers pass
/// an identifier that has already gone through `normalize_ident`.
fn is_zero_arg_keyword_function(name: &str) -> bool {
    const PAREN_LESS_FUNCTIONS: [&str; 10] = [
        "current_timestamp",
        "current_date",
        "current_time",
        "localtime",
        "localtimestamp",
        "current_user",
        "current_role",
        "current_schema",
        "session_user",
        "user",
    ];
    PAREN_LESS_FUNCTIONS
        .iter()
        .any(|&candidate| candidate == name)
}
37
38pub fn convert_expr(expr: &Expr) -> Result<SqlExpr> {
40 convert_expr_depth(expr, &mut 0)
41}
42
43pub(super) fn convert_expr_depth(expr: &Expr, depth: &mut usize) -> Result<SqlExpr> {
46 *depth += 1;
47 if *depth > MAX_CONVERT_DEPTH {
48 return Err(SqlError::Unsupported {
49 detail: format!("expression nesting depth exceeds maximum of {MAX_CONVERT_DEPTH}"),
50 });
51 }
52 let result = convert_expr_inner(expr, depth);
53 *depth -= 1;
54 result
55}
56
57fn convert_expr_inner(expr: &Expr, depth: &mut usize) -> Result<SqlExpr> {
58 match expr {
59 Expr::Identifier(ident) => {
60 let name = normalize_ident(ident);
61 if is_zero_arg_keyword_function(&name) {
66 return Ok(SqlExpr::Function {
67 name,
68 args: vec![],
69 distinct: false,
70 });
71 }
72 Ok(SqlExpr::Column { table: None, name })
73 }
74 Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
75 let qualified: String = parts
76 .iter()
77 .map(normalize_ident)
78 .collect::<Vec<_>>()
79 .join(".");
80 Err(SqlError::Unsupported {
81 detail: format!(
82 "schema-qualified column reference '{qualified}': {SCHEMA_QUALIFIED_MSG}"
83 ),
84 })
85 }
86 Expr::CompoundIdentifier(parts) if parts.len() == 2 => Ok(SqlExpr::Column {
87 table: Some(normalize_ident(&parts[0])),
88 name: normalize_ident(&parts[1]),
89 }),
90 Expr::Value(val) => Ok(SqlExpr::Literal(convert_value(&val.value)?)),
91 Expr::BinaryOp { left, op, right } => {
92 use ast::BinaryOperator;
96 let json_fn: Option<&str> = match op {
97 BinaryOperator::Arrow => Some("pg_json_get"),
98 BinaryOperator::LongArrow => Some("pg_json_get_text"),
99 BinaryOperator::HashArrow => Some("pg_json_path_get"),
100 BinaryOperator::HashLongArrow => Some("pg_json_path_get_text"),
101 BinaryOperator::AtArrow => Some("pg_json_contains"),
102 BinaryOperator::ArrowAt => Some("pg_json_contained_by"),
103 BinaryOperator::Question => Some("pg_json_has_key"),
104 BinaryOperator::QuestionAnd => Some("pg_json_has_all_keys"),
105 BinaryOperator::QuestionPipe => Some("pg_json_has_any_key"),
106 _ => None,
107 };
108 if let Some(name) = json_fn {
109 return Ok(SqlExpr::Function {
110 name: name.into(),
111 args: vec![
112 convert_expr_depth(left, depth)?,
113 convert_expr_depth(right, depth)?,
114 ],
115 distinct: false,
116 });
117 }
118 if matches!(op, BinaryOperator::AtAt) {
120 let col_expr = convert_expr_depth(left, depth)?;
121 let query_expr = convert_expr_depth(right, depth)?;
122 return Ok(crate::functions::fts_ops::pg_fts_funcs::lower_pg_fts_match(
123 col_expr, query_expr,
124 ));
125 }
126 Ok(SqlExpr::BinaryOp {
127 left: Box::new(convert_expr_depth(left, depth)?),
128 op: convert_binary_op(op)?,
129 right: Box::new(convert_expr_depth(right, depth)?),
130 })
131 }
132 Expr::UnaryOp { op, expr } => Ok(SqlExpr::UnaryOp {
133 op: convert_unary_op(op)?,
134 expr: Box::new(convert_expr_depth(expr, depth)?),
135 }),
136 Expr::Function(func) => convert_function_depth(func, depth),
137 Expr::Nested(inner) => convert_expr_depth(inner, depth),
138 Expr::IsNull(inner) => Ok(SqlExpr::IsNull {
139 expr: Box::new(convert_expr_depth(inner, depth)?),
140 negated: false,
141 }),
142 Expr::IsNotNull(inner) => Ok(SqlExpr::IsNull {
143 expr: Box::new(convert_expr_depth(inner, depth)?),
144 negated: true,
145 }),
146 Expr::InList {
147 expr,
148 list,
149 negated,
150 } => Ok(SqlExpr::InList {
151 expr: Box::new(convert_expr_depth(expr, depth)?),
152 list: list
153 .iter()
154 .map(|e| convert_expr_depth(e, depth))
155 .collect::<Result<_>>()?,
156 negated: *negated,
157 }),
158 Expr::Between {
159 expr,
160 low,
161 high,
162 negated,
163 } => Ok(SqlExpr::Between {
164 expr: Box::new(convert_expr_depth(expr, depth)?),
165 low: Box::new(convert_expr_depth(low, depth)?),
166 high: Box::new(convert_expr_depth(high, depth)?),
167 negated: *negated,
168 }),
169 Expr::Like {
170 expr,
171 pattern,
172 negated,
173 ..
174 } => Ok(SqlExpr::Like {
175 expr: Box::new(convert_expr_depth(expr, depth)?),
176 pattern: Box::new(convert_expr_depth(pattern, depth)?),
177 negated: *negated,
178 case_insensitive: false,
179 }),
180 Expr::ILike {
181 expr,
182 pattern,
183 negated,
184 ..
185 } => Ok(SqlExpr::Like {
186 expr: Box::new(convert_expr_depth(expr, depth)?),
187 pattern: Box::new(convert_expr_depth(pattern, depth)?),
188 negated: *negated,
189 case_insensitive: true,
190 }),
191 Expr::Case {
192 operand,
193 conditions,
194 else_result,
195 ..
196 } => {
197 let when_then = conditions
198 .iter()
199 .map(|cw| {
200 Ok((
201 convert_expr_depth(&cw.condition, depth)?,
202 convert_expr_depth(&cw.result, depth)?,
203 ))
204 })
205 .collect::<Result<Vec<_>>>()?;
206 Ok(SqlExpr::Case {
207 operand: operand
208 .as_ref()
209 .map(|e| convert_expr_depth(e, depth).map(Box::new))
210 .transpose()?,
211 when_then,
212 else_expr: else_result
213 .as_ref()
214 .map(|e| convert_expr_depth(e, depth).map(Box::new))
215 .transpose()?,
216 })
217 }
218 Expr::TypedString(ts) => {
219 let type_str = format!("{}", ts.data_type).to_ascii_uppercase();
221 let raw = match &ts.value.value {
222 Value::SingleQuotedString(s) => s.clone(),
223 other => {
224 return Err(SqlError::Unsupported {
225 detail: format!("typed string value: {other}"),
226 });
227 }
228 };
229 match type_str.as_str() {
230 "TIMESTAMP" => {
231 let dt =
232 nodedb_types::NdbDateTime::parse(&raw).ok_or_else(|| SqlError::Parse {
233 detail: format!("cannot parse TIMESTAMP literal: '{raw}'"),
234 })?;
235 return Ok(SqlExpr::Literal(SqlValue::Timestamp(dt)));
236 }
237 "TIMESTAMPTZ" | "TIMESTAMP WITH TIME ZONE" => {
238 let dt =
239 nodedb_types::NdbDateTime::parse(&raw).ok_or_else(|| SqlError::Parse {
240 detail: format!("cannot parse TIMESTAMPTZ literal: '{raw}'"),
241 })?;
242 return Ok(SqlExpr::Literal(SqlValue::Timestamptz(dt)));
243 }
244 _ => {}
245 }
246 Ok(SqlExpr::Literal(SqlValue::String(raw)))
248 }
249 Expr::Cast {
250 expr, data_type, ..
251 } => {
252 let type_str = format!("{data_type}").to_ascii_lowercase();
257 if type_str == "tsvector" || type_str == "tsquery" {
258 return convert_expr_depth(expr, depth);
259 }
260 let upper = type_str.to_uppercase();
263 if (upper == "TIMESTAMP"
264 || upper == "TIMESTAMPTZ"
265 || upper == "TIMESTAMP WITH TIME ZONE")
266 && let Expr::Value(v) = expr.as_ref()
267 && let Value::SingleQuotedString(s) = &v.value
268 {
269 let dt = nodedb_types::NdbDateTime::parse(s).ok_or_else(|| SqlError::Parse {
270 detail: format!("cannot parse timestamp cast: '{s}'"),
271 })?;
272 return Ok(SqlExpr::Literal(if upper == "TIMESTAMP" {
273 SqlValue::Timestamp(dt)
274 } else {
275 SqlValue::Timestamptz(dt)
276 }));
277 }
278 Ok(SqlExpr::Cast {
279 expr: Box::new(convert_expr_depth(expr, depth)?),
280 to_type: format!("{data_type}"),
281 })
282 }
283 Expr::Array(ast::Array { elem, .. }) => {
284 let elems = elem
285 .iter()
286 .map(|e| convert_expr_depth(e, depth))
287 .collect::<Result<_>>()?;
288 Ok(SqlExpr::ArrayLiteral(elems))
289 }
290 Expr::Wildcard(_) => Ok(SqlExpr::Wildcard),
291 Expr::Trim { expr, .. } => Ok(SqlExpr::Function {
293 name: "trim".into(),
294 args: vec![convert_expr_depth(expr, depth)?],
295 distinct: false,
296 }),
297 Expr::Ceil { expr, .. } => Ok(SqlExpr::Function {
299 name: "ceil".into(),
300 args: vec![convert_expr_depth(expr, depth)?],
301 distinct: false,
302 }),
303 Expr::Floor { expr, .. } => Ok(SqlExpr::Function {
304 name: "floor".into(),
305 args: vec![convert_expr_depth(expr, depth)?],
306 distinct: false,
307 }),
308 Expr::Substring {
310 expr,
311 substring_from,
312 substring_for,
313 ..
314 } => {
315 let mut args = vec![convert_expr_depth(expr, depth)?];
316 if let Some(from) = substring_from {
317 args.push(convert_expr_depth(from, depth)?);
318 }
319 if let Some(len) = substring_for {
320 args.push(convert_expr_depth(len, depth)?);
321 }
322 Ok(SqlExpr::Function {
323 name: "substring".into(),
324 args,
325 distinct: false,
326 })
327 }
328 Expr::Interval(interval) => {
329 let interval_str = match interval.value.as_ref() {
332 Expr::Value(v) => match &v.value {
333 Value::SingleQuotedString(s) => s.clone(),
334 Value::Number(n, _) => {
335 if let Some(ref field) = interval.leading_field {
337 format!("{n} {field}")
338 } else {
339 n.clone()
340 }
341 }
342 _ => {
343 return Err(SqlError::Unsupported {
344 detail: format!("INTERVAL value: {}", interval.value),
345 });
346 }
347 },
348 _ => {
349 return Err(SqlError::Unsupported {
350 detail: format!("INTERVAL expression: {}", interval.value),
351 });
352 }
353 };
354
355 let full_str = if interval_str.chars().all(|c| c.is_ascii_digit())
357 && let Some(ref field) = interval.leading_field
358 {
359 format!("{interval_str} {field}")
360 } else {
361 interval_str
362 };
363
364 let micros = parse_interval_to_micros(&full_str).ok_or_else(|| SqlError::Parse {
365 detail: format!("cannot parse INTERVAL '{full_str}'"),
366 })?;
367
368 Ok(SqlExpr::Literal(SqlValue::Int(micros)))
369 }
370 _ => Err(SqlError::Unsupported {
371 detail: format!("expression: {expr}"),
372 }),
373 }
374}