use sqlparser::ast::{self, Expr, Value};
use crate::error::{Result, SqlError};
use crate::parser::normalize::{SCHEMA_QUALIFIED_MSG, normalize_ident};
use crate::types::*;
use super::binary_ops::{convert_binary_op, convert_unary_op};
use super::functions::convert_function_depth;
use super::value::{convert_value, parse_interval_to_micros};
const MAX_CONVERT_DEPTH: usize = 128;
fn is_zero_arg_keyword_function(name: &str) -> bool {
matches!(
name,
"current_timestamp"
| "current_date"
| "current_time"
| "localtime"
| "localtimestamp"
| "current_user"
| "current_role"
| "current_schema"
| "session_user"
| "user"
)
}
pub fn convert_expr(expr: &Expr) -> Result<SqlExpr> {
convert_expr_depth(expr, &mut 0)
}
pub(super) fn convert_expr_depth(expr: &Expr, depth: &mut usize) -> Result<SqlExpr> {
*depth += 1;
if *depth > MAX_CONVERT_DEPTH {
return Err(SqlError::Unsupported {
detail: format!("expression nesting depth exceeds maximum of {MAX_CONVERT_DEPTH}"),
});
}
let result = convert_expr_inner(expr, depth);
*depth -= 1;
result
}
fn convert_expr_inner(expr: &Expr, depth: &mut usize) -> Result<SqlExpr> {
match expr {
Expr::Identifier(ident) => {
let name = normalize_ident(ident);
if is_zero_arg_keyword_function(&name) {
return Ok(SqlExpr::Function {
name,
args: vec![],
distinct: false,
});
}
Ok(SqlExpr::Column { table: None, name })
}
Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
let qualified: String = parts
.iter()
.map(normalize_ident)
.collect::<Vec<_>>()
.join(".");
Err(SqlError::Unsupported {
detail: format!(
"schema-qualified column reference '{qualified}': {SCHEMA_QUALIFIED_MSG}"
),
})
}
Expr::CompoundIdentifier(parts) if parts.len() == 2 => Ok(SqlExpr::Column {
table: Some(normalize_ident(&parts[0])),
name: normalize_ident(&parts[1]),
}),
Expr::Value(val) => Ok(SqlExpr::Literal(convert_value(&val.value)?)),
Expr::BinaryOp { left, op, right } => {
use ast::BinaryOperator;
let json_fn: Option<&str> = match op {
BinaryOperator::Arrow => Some("pg_json_get"),
BinaryOperator::LongArrow => Some("pg_json_get_text"),
BinaryOperator::HashArrow => Some("pg_json_path_get"),
BinaryOperator::HashLongArrow => Some("pg_json_path_get_text"),
BinaryOperator::AtArrow => Some("pg_json_contains"),
BinaryOperator::ArrowAt => Some("pg_json_contained_by"),
BinaryOperator::Question => Some("pg_json_has_key"),
BinaryOperator::QuestionAnd => Some("pg_json_has_all_keys"),
BinaryOperator::QuestionPipe => Some("pg_json_has_any_key"),
_ => None,
};
if let Some(name) = json_fn {
return Ok(SqlExpr::Function {
name: name.into(),
args: vec![
convert_expr_depth(left, depth)?,
convert_expr_depth(right, depth)?,
],
distinct: false,
});
}
if matches!(op, BinaryOperator::AtAt) {
let col_expr = convert_expr_depth(left, depth)?;
let query_expr = convert_expr_depth(right, depth)?;
return Ok(crate::functions::fts_ops::pg_fts_funcs::lower_pg_fts_match(
col_expr, query_expr,
));
}
Ok(SqlExpr::BinaryOp {
left: Box::new(convert_expr_depth(left, depth)?),
op: convert_binary_op(op)?,
right: Box::new(convert_expr_depth(right, depth)?),
})
}
Expr::UnaryOp { op, expr } => Ok(SqlExpr::UnaryOp {
op: convert_unary_op(op)?,
expr: Box::new(convert_expr_depth(expr, depth)?),
}),
Expr::Function(func) => convert_function_depth(func, depth),
Expr::Nested(inner) => convert_expr_depth(inner, depth),
Expr::IsNull(inner) => Ok(SqlExpr::IsNull {
expr: Box::new(convert_expr_depth(inner, depth)?),
negated: false,
}),
Expr::IsNotNull(inner) => Ok(SqlExpr::IsNull {
expr: Box::new(convert_expr_depth(inner, depth)?),
negated: true,
}),
Expr::InList {
expr,
list,
negated,
} => Ok(SqlExpr::InList {
expr: Box::new(convert_expr_depth(expr, depth)?),
list: list
.iter()
.map(|e| convert_expr_depth(e, depth))
.collect::<Result<_>>()?,
negated: *negated,
}),
Expr::Between {
expr,
low,
high,
negated,
} => Ok(SqlExpr::Between {
expr: Box::new(convert_expr_depth(expr, depth)?),
low: Box::new(convert_expr_depth(low, depth)?),
high: Box::new(convert_expr_depth(high, depth)?),
negated: *negated,
}),
Expr::Like {
expr,
pattern,
negated,
..
} => Ok(SqlExpr::Like {
expr: Box::new(convert_expr_depth(expr, depth)?),
pattern: Box::new(convert_expr_depth(pattern, depth)?),
negated: *negated,
case_insensitive: false,
}),
Expr::ILike {
expr,
pattern,
negated,
..
} => Ok(SqlExpr::Like {
expr: Box::new(convert_expr_depth(expr, depth)?),
pattern: Box::new(convert_expr_depth(pattern, depth)?),
negated: *negated,
case_insensitive: true,
}),
Expr::Case {
operand,
conditions,
else_result,
..
} => {
let when_then = conditions
.iter()
.map(|cw| {
Ok((
convert_expr_depth(&cw.condition, depth)?,
convert_expr_depth(&cw.result, depth)?,
))
})
.collect::<Result<Vec<_>>>()?;
Ok(SqlExpr::Case {
operand: operand
.as_ref()
.map(|e| convert_expr_depth(e, depth).map(Box::new))
.transpose()?,
when_then,
else_expr: else_result
.as_ref()
.map(|e| convert_expr_depth(e, depth).map(Box::new))
.transpose()?,
})
}
Expr::TypedString(ts) => {
let type_str = format!("{}", ts.data_type).to_ascii_uppercase();
let raw = match &ts.value.value {
Value::SingleQuotedString(s) => s.clone(),
other => {
return Err(SqlError::Unsupported {
detail: format!("typed string value: {other}"),
});
}
};
match type_str.as_str() {
"TIMESTAMP" => {
let dt =
nodedb_types::NdbDateTime::parse(&raw).ok_or_else(|| SqlError::Parse {
detail: format!("cannot parse TIMESTAMP literal: '{raw}'"),
})?;
return Ok(SqlExpr::Literal(SqlValue::Timestamp(dt)));
}
"TIMESTAMPTZ" | "TIMESTAMP WITH TIME ZONE" => {
let dt =
nodedb_types::NdbDateTime::parse(&raw).ok_or_else(|| SqlError::Parse {
detail: format!("cannot parse TIMESTAMPTZ literal: '{raw}'"),
})?;
return Ok(SqlExpr::Literal(SqlValue::Timestamptz(dt)));
}
_ => {}
}
Ok(SqlExpr::Literal(SqlValue::String(raw)))
}
Expr::Cast {
expr, data_type, ..
} => {
let type_str = format!("{data_type}").to_ascii_lowercase();
if type_str == "tsvector" || type_str == "tsquery" {
return convert_expr_depth(expr, depth);
}
let upper = type_str.to_uppercase();
if (upper == "TIMESTAMP"
|| upper == "TIMESTAMPTZ"
|| upper == "TIMESTAMP WITH TIME ZONE")
&& let Expr::Value(v) = expr.as_ref()
&& let Value::SingleQuotedString(s) = &v.value
{
let dt = nodedb_types::NdbDateTime::parse(s).ok_or_else(|| SqlError::Parse {
detail: format!("cannot parse timestamp cast: '{s}'"),
})?;
return Ok(SqlExpr::Literal(if upper == "TIMESTAMP" {
SqlValue::Timestamp(dt)
} else {
SqlValue::Timestamptz(dt)
}));
}
Ok(SqlExpr::Cast {
expr: Box::new(convert_expr_depth(expr, depth)?),
to_type: format!("{data_type}"),
})
}
Expr::Array(ast::Array { elem, .. }) => {
let elems = elem
.iter()
.map(|e| convert_expr_depth(e, depth))
.collect::<Result<_>>()?;
Ok(SqlExpr::ArrayLiteral(elems))
}
Expr::Wildcard(_) => Ok(SqlExpr::Wildcard),
Expr::Trim { expr, .. } => Ok(SqlExpr::Function {
name: "trim".into(),
args: vec![convert_expr_depth(expr, depth)?],
distinct: false,
}),
Expr::Ceil { expr, .. } => Ok(SqlExpr::Function {
name: "ceil".into(),
args: vec![convert_expr_depth(expr, depth)?],
distinct: false,
}),
Expr::Floor { expr, .. } => Ok(SqlExpr::Function {
name: "floor".into(),
args: vec![convert_expr_depth(expr, depth)?],
distinct: false,
}),
Expr::Substring {
expr,
substring_from,
substring_for,
..
} => {
let mut args = vec![convert_expr_depth(expr, depth)?];
if let Some(from) = substring_from {
args.push(convert_expr_depth(from, depth)?);
}
if let Some(len) = substring_for {
args.push(convert_expr_depth(len, depth)?);
}
Ok(SqlExpr::Function {
name: "substring".into(),
args,
distinct: false,
})
}
Expr::Interval(interval) => {
let interval_str = match interval.value.as_ref() {
Expr::Value(v) => match &v.value {
Value::SingleQuotedString(s) => s.clone(),
Value::Number(n, _) => {
if let Some(ref field) = interval.leading_field {
format!("{n} {field}")
} else {
n.clone()
}
}
_ => {
return Err(SqlError::Unsupported {
detail: format!("INTERVAL value: {}", interval.value),
});
}
},
_ => {
return Err(SqlError::Unsupported {
detail: format!("INTERVAL expression: {}", interval.value),
});
}
};
let full_str = if interval_str.chars().all(|c| c.is_ascii_digit())
&& let Some(ref field) = interval.leading_field
{
format!("{interval_str} {field}")
} else {
interval_str
};
let micros = parse_interval_to_micros(&full_str).ok_or_else(|| SqlError::Parse {
detail: format!("cannot parse INTERVAL '{full_str}'"),
})?;
Ok(SqlExpr::Literal(SqlValue::Int(micros)))
}
_ => Err(SqlError::Unsupported {
detail: format!("expression: {expr}"),
}),
}
}