Skip to main content

nodedb_sql/planner/
dml_helpers.rs

//! DML planning helpers extracted from `dml.rs` to keep both files under
//! the 500-line limit. Visibility is `pub(super)`, so these helpers are
//! reachable only from within the `planner` module and its submodules
//! (such as `planner::dml`).
4
5use sqlparser::ast;
6
7use crate::error::{Result, SqlError};
8use crate::parser::normalize::{normalize_ident, normalize_object_name};
9use crate::resolver::expr::convert_value;
10use crate::types::*;
11
12pub(super) fn convert_value_rows(
13    columns: &[String],
14    rows: &[Vec<ast::Expr>],
15) -> Result<Vec<Vec<(String, SqlValue)>>> {
16    rows.iter()
17        .map(|row| {
18            row.iter()
19                .enumerate()
20                .map(|(i, expr)| {
21                    let col = columns.get(i).cloned().unwrap_or_else(|| format!("col{i}"));
22                    let val = expr_to_sql_value(expr)?;
23                    Ok((col, val))
24                })
25                .collect::<Result<Vec<_>>>()
26        })
27        .collect()
28}
29
30pub(super) fn expr_to_sql_value(expr: &ast::Expr) -> Result<SqlValue> {
31    match expr {
32        ast::Expr::Value(v) => convert_value(&v.value),
33        ast::Expr::UnaryOp {
34            op: ast::UnaryOperator::Minus,
35            expr: inner,
36        } => {
37            let val = expr_to_sql_value(inner)?;
38            match val {
39                SqlValue::Int(n) => Ok(SqlValue::Int(-n)),
40                SqlValue::Float(f) => Ok(SqlValue::Float(-f)),
41                _ => Err(SqlError::TypeMismatch {
42                    detail: "cannot negate non-numeric value".into(),
43                }),
44            }
45        }
46        ast::Expr::Array(ast::Array { elem, .. }) => {
47            let vals = elem.iter().map(expr_to_sql_value).collect::<Result<_>>()?;
48            Ok(SqlValue::Array(vals))
49        }
50        ast::Expr::Function(func) => {
51            let func_name = func
52                .name
53                .0
54                .iter()
55                .map(|p| match p {
56                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
57                    _ => String::new(),
58                })
59                .collect::<Vec<_>>()
60                .join(".")
61                .to_lowercase();
62            match func_name.as_str() {
63                "st_point" => {
64                    let args = super::select::extract_func_args(func)?;
65                    if args.len() >= 2 {
66                        let lon = super::select::extract_float(&args[0])?;
67                        let lat = super::select::extract_float(&args[1])?;
68                        Ok(SqlValue::String(format!(
69                            r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#
70                        )))
71                    } else {
72                        Ok(SqlValue::String(format!("{expr}")))
73                    }
74                }
75                "st_geomfromgeojson" => {
76                    let args = super::select::extract_func_args(func)?;
77                    if !args.is_empty() {
78                        let s = super::select::extract_string_literal(&args[0])?;
79                        Ok(SqlValue::String(s))
80                    } else {
81                        Ok(SqlValue::String(format!("{expr}")))
82                    }
83                }
84                _ => {
85                    if let Ok(sql_expr) = crate::resolver::expr::convert_expr(expr)
86                        && let Some(v) = super::const_fold::fold_constant_default(&sql_expr)
87                    {
88                        Ok(v)
89                    } else {
90                        Ok(SqlValue::String(format!("{expr}")))
91                    }
92                }
93            }
94        }
95        _ => Err(SqlError::Unsupported {
96            detail: format!("value expression: {expr}"),
97        }),
98    }
99}
100
101pub(super) fn extract_table_name_from_table_with_joins(
102    table: &ast::TableWithJoins,
103) -> Result<String> {
104    match &table.relation {
105        ast::TableFactor::Table { name, .. } => Ok(normalize_object_name(name)),
106        _ => Err(SqlError::Unsupported {
107            detail: "non-table target in DML".into(),
108        }),
109    }
110}
111
112/// Extract point-operation keys from WHERE clause (WHERE pk = literal OR pk IN (...)).
113pub(super) fn extract_point_keys(
114    selection: Option<&ast::Expr>,
115    info: &CollectionInfo,
116) -> Vec<SqlValue> {
117    let pk = match &info.primary_key {
118        Some(pk) => pk.clone(),
119        None => return Vec::new(),
120    };
121
122    let expr = match selection {
123        Some(e) => e,
124        None => return Vec::new(),
125    };
126
127    let mut keys = Vec::new();
128    collect_pk_equalities(expr, &pk, &mut keys);
129    keys
130}
131
132fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
133    match expr {
134        ast::Expr::BinaryOp {
135            left,
136            op: ast::BinaryOperator::Eq,
137            right,
138        } => {
139            if is_column(left, pk)
140                && let Ok(v) = expr_to_sql_value(right)
141            {
142                keys.push(v);
143            } else if is_column(right, pk)
144                && let Ok(v) = expr_to_sql_value(left)
145            {
146                keys.push(v);
147            }
148        }
149        ast::Expr::BinaryOp {
150            left,
151            op: ast::BinaryOperator::Or,
152            right,
153        } => {
154            collect_pk_equalities(left, pk, keys);
155            collect_pk_equalities(right, pk, keys);
156        }
157        ast::Expr::InList {
158            expr: inner,
159            list,
160            negated: false,
161        } if is_column(inner, pk) => {
162            for item in list {
163                if let Ok(v) = expr_to_sql_value(item) {
164                    keys.push(v);
165                }
166            }
167        }
168        _ => {}
169    }
170}
171
172fn is_column(expr: &ast::Expr, name: &str) -> bool {
173    match expr {
174        ast::Expr::Identifier(ident) => normalize_ident(ident) == name,
175        ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
176            normalize_ident(&parts[1]) == name
177        }
178        _ => false,
179    }
180}