nodedb_sql/planner/
dml_helpers.rs1use sqlparser::ast;
6
7use crate::error::{Result, SqlError};
8use crate::parser::normalize::{normalize_ident, normalize_object_name};
9use crate::resolver::expr::convert_value;
10use crate::types::*;
11
12pub(super) fn convert_value_rows(
13 columns: &[String],
14 rows: &[Vec<ast::Expr>],
15) -> Result<Vec<Vec<(String, SqlValue)>>> {
16 rows.iter()
17 .map(|row| {
18 row.iter()
19 .enumerate()
20 .map(|(i, expr)| {
21 let col = columns.get(i).cloned().unwrap_or_else(|| format!("col{i}"));
22 let val = expr_to_sql_value(expr)?;
23 Ok((col, val))
24 })
25 .collect::<Result<Vec<_>>>()
26 })
27 .collect()
28}
29
30pub(super) fn expr_to_sql_value(expr: &ast::Expr) -> Result<SqlValue> {
31 match expr {
32 ast::Expr::Value(v) => convert_value(&v.value),
33 ast::Expr::UnaryOp {
34 op: ast::UnaryOperator::Minus,
35 expr: inner,
36 } => {
37 let val = expr_to_sql_value(inner)?;
38 match val {
39 SqlValue::Int(n) => Ok(SqlValue::Int(-n)),
40 SqlValue::Float(f) => Ok(SqlValue::Float(-f)),
41 _ => Err(SqlError::TypeMismatch {
42 detail: "cannot negate non-numeric value".into(),
43 }),
44 }
45 }
46 ast::Expr::Array(ast::Array { elem, .. }) => {
47 let vals = elem.iter().map(expr_to_sql_value).collect::<Result<_>>()?;
48 Ok(SqlValue::Array(vals))
49 }
50 ast::Expr::Function(func) => {
51 let func_name = func
52 .name
53 .0
54 .iter()
55 .map(|p| match p {
56 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
57 _ => String::new(),
58 })
59 .collect::<Vec<_>>()
60 .join(".")
61 .to_lowercase();
62 match func_name.as_str() {
63 "st_point" => {
64 let args = super::select::extract_func_args(func)?;
65 if args.len() >= 2 {
66 let lon = super::select::extract_float(&args[0])?;
67 let lat = super::select::extract_float(&args[1])?;
68 Ok(SqlValue::String(format!(
69 r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#
70 )))
71 } else {
72 Ok(SqlValue::String(format!("{expr}")))
73 }
74 }
75 "st_geomfromgeojson" => {
76 let args = super::select::extract_func_args(func)?;
77 if !args.is_empty() {
78 let s = super::select::extract_string_literal(&args[0])?;
79 Ok(SqlValue::String(s))
80 } else {
81 Ok(SqlValue::String(format!("{expr}")))
82 }
83 }
84 _ => {
85 if let Ok(sql_expr) = crate::resolver::expr::convert_expr(expr)
86 && let Some(v) = super::const_fold::fold_constant_default(&sql_expr)
87 {
88 Ok(v)
89 } else {
90 Ok(SqlValue::String(format!("{expr}")))
91 }
92 }
93 }
94 }
95 _ => Err(SqlError::Unsupported {
96 detail: format!("value expression: {expr}"),
97 }),
98 }
99}
100
101pub(super) fn extract_table_name_from_table_with_joins(
102 table: &ast::TableWithJoins,
103) -> Result<String> {
104 match &table.relation {
105 ast::TableFactor::Table { name, .. } => Ok(normalize_object_name(name)),
106 _ => Err(SqlError::Unsupported {
107 detail: "non-table target in DML".into(),
108 }),
109 }
110}
111
112pub(super) fn extract_point_keys(
114 selection: Option<&ast::Expr>,
115 info: &CollectionInfo,
116) -> Vec<SqlValue> {
117 let pk = match &info.primary_key {
118 Some(pk) => pk.clone(),
119 None => return Vec::new(),
120 };
121
122 let expr = match selection {
123 Some(e) => e,
124 None => return Vec::new(),
125 };
126
127 let mut keys = Vec::new();
128 collect_pk_equalities(expr, &pk, &mut keys);
129 keys
130}
131
132fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
133 match expr {
134 ast::Expr::BinaryOp {
135 left,
136 op: ast::BinaryOperator::Eq,
137 right,
138 } => {
139 if is_column(left, pk)
140 && let Ok(v) = expr_to_sql_value(right)
141 {
142 keys.push(v);
143 } else if is_column(right, pk)
144 && let Ok(v) = expr_to_sql_value(left)
145 {
146 keys.push(v);
147 }
148 }
149 ast::Expr::BinaryOp {
150 left,
151 op: ast::BinaryOperator::Or,
152 right,
153 } => {
154 collect_pk_equalities(left, pk, keys);
155 collect_pk_equalities(right, pk, keys);
156 }
157 ast::Expr::InList {
158 expr: inner,
159 list,
160 negated: false,
161 } if is_column(inner, pk) => {
162 for item in list {
163 if let Ok(v) = expr_to_sql_value(item) {
164 keys.push(v);
165 }
166 }
167 }
168 _ => {}
169 }
170}
171
172fn is_column(expr: &ast::Expr, name: &str) -> bool {
173 match expr {
174 ast::Expr::Identifier(ident) => normalize_ident(ident) == name,
175 ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
176 normalize_ident(&parts[1]) == name
177 }
178 _ => false,
179 }
180}