1use sqlparser::ast;
4
5use crate::error::{Result, SqlError};
6use crate::parser::normalize::{normalize_ident, normalize_object_name_checked};
7use crate::resolver::expr::convert_value;
8use crate::types::*;
9
10pub(super) fn convert_value_rows(
11 columns: &[String],
12 rows: &[Vec<ast::Expr>],
13) -> Result<Vec<Vec<(String, SqlValue)>>> {
14 rows.iter()
15 .map(|row| {
16 row.iter()
17 .enumerate()
18 .map(|(i, expr)| {
19 let col = columns.get(i).cloned().unwrap_or_else(|| format!("col{i}"));
20 let val = expr_to_sql_value(expr)?;
21 Ok((col, val))
22 })
23 .collect::<Result<Vec<_>>>()
24 })
25 .collect()
26}
27
28pub(super) fn expr_to_sql_value(expr: &ast::Expr) -> Result<SqlValue> {
29 match expr {
30 ast::Expr::Value(v) => convert_value(&v.value),
31 ast::Expr::UnaryOp {
32 op: ast::UnaryOperator::Minus,
33 expr: inner,
34 } => {
35 let val = expr_to_sql_value(inner)?;
36 match val {
37 SqlValue::Int(n) => Ok(SqlValue::Int(-n)),
38 SqlValue::Float(f) => Ok(SqlValue::Float(-f)),
39 SqlValue::Decimal(d) => Ok(SqlValue::Decimal(-d)),
40 _ => Err(SqlError::TypeMismatch {
41 detail: "cannot negate non-numeric value".into(),
42 }),
43 }
44 }
45 ast::Expr::Array(ast::Array { elem, .. }) => {
46 let vals = elem.iter().map(expr_to_sql_value).collect::<Result<_>>()?;
47 Ok(SqlValue::Array(vals))
48 }
49 ast::Expr::Function(func) => {
50 let func_name = func
51 .name
52 .0
53 .iter()
54 .map(|p| match p {
55 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
56 _ => String::new(),
57 })
58 .collect::<Vec<_>>()
59 .join(".")
60 .to_lowercase();
61 match func_name.as_str() {
62 "st_point" => {
63 let args = super::select::extract_func_args(func)?;
64 if args.len() >= 2 {
65 let lon = super::select::extract_float(&args[0])?;
66 let lat = super::select::extract_float(&args[1])?;
67 Ok(SqlValue::String(format!(
68 r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#
69 )))
70 } else {
71 Ok(SqlValue::String(format!("{expr}")))
72 }
73 }
74 "st_geomfromgeojson" => {
75 let args = super::select::extract_func_args(func)?;
76 if !args.is_empty() {
77 let s = super::select::extract_string_literal(&args[0])?;
78 Ok(SqlValue::String(s))
79 } else {
80 Ok(SqlValue::String(format!("{expr}")))
81 }
82 }
83 _ => {
84 if let Ok(sql_expr) = crate::resolver::expr::convert_expr(expr)
85 && let Some(v) = super::const_fold::fold_constant_default(&sql_expr)
86 {
87 Ok(v)
88 } else {
89 Ok(SqlValue::String(format!("{expr}")))
90 }
91 }
92 }
93 }
94 _ => Err(SqlError::Unsupported {
95 detail: format!("value expression: {expr}"),
96 }),
97 }
98}
99
100pub(super) fn extract_table_name_from_table_with_joins(
101 table: &ast::TableWithJoins,
102) -> Result<String> {
103 match &table.relation {
104 ast::TableFactor::Table { name, .. } => Ok(normalize_object_name_checked(name)?),
105 _ => Err(SqlError::Unsupported {
106 detail: "non-table target in DML".into(),
107 }),
108 }
109}
110
111pub fn extract_point_keys(selection: Option<&ast::Expr>, info: &CollectionInfo) -> Vec<SqlValue> {
113 let pk = match &info.primary_key {
114 Some(pk) => pk.clone(),
115 None => return Vec::new(),
116 };
117
118 let expr = match selection {
119 Some(e) => e,
120 None => return Vec::new(),
121 };
122
123 let mut keys = Vec::new();
124 collect_pk_equalities(expr, &pk, &mut keys);
125 keys
126}
127
128fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
129 match expr {
130 ast::Expr::BinaryOp {
131 left,
132 op: ast::BinaryOperator::Eq,
133 right,
134 } => {
135 if is_column(left, pk)
136 && let Ok(v) = expr_to_sql_value(right)
137 {
138 keys.push(v);
139 } else if is_column(right, pk)
140 && let Ok(v) = expr_to_sql_value(left)
141 {
142 keys.push(v);
143 }
144 }
145 ast::Expr::BinaryOp {
146 left,
147 op: ast::BinaryOperator::Or,
148 right,
149 } => {
150 collect_pk_equalities(left, pk, keys);
151 collect_pk_equalities(right, pk, keys);
152 }
153 ast::Expr::InList {
154 expr: inner,
155 list,
156 negated: false,
157 } if is_column(inner, pk) => {
158 for item in list {
159 if let Ok(v) = expr_to_sql_value(item) {
160 keys.push(v);
161 }
162 }
163 }
164 _ => {}
165 }
166}
167
168fn is_column(expr: &ast::Expr, name: &str) -> bool {
169 match expr {
170 ast::Expr::Identifier(ident) => normalize_ident(ident) == name,
171 ast::Expr::CompoundIdentifier(parts) if parts.len() >= 3 => false,
173 ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
174 normalize_ident(&parts[1]) == name
175 }
176 _ => false,
177 }
178}
179
180pub(super) fn build_vector_primary_insert_plan(
186 collection: &str,
187 vpc: &nodedb_types::VectorPrimaryConfig,
188 _columns: &[String],
189 rows: Vec<Vec<(String, SqlValue)>>,
190) -> Result<Vec<SqlPlan>> {
191 let mut result_rows = Vec::with_capacity(rows.len());
192 for row in rows {
193 let mut vector: Option<Vec<f32>> = None;
194 let mut payload_fields = std::collections::HashMap::new();
195
196 for (col, val) in row {
197 if col == vpc.vector_field {
198 match val {
199 SqlValue::Array(items) => {
200 let floats: Result<Vec<f32>> = items
201 .iter()
202 .map(|v| match v {
203 SqlValue::Float(f) => Ok(*f as f32),
204 SqlValue::Int(i) => Ok(*i as f32),
205 SqlValue::Decimal(d) => {
206 use rust_decimal::prelude::ToPrimitive;
207 d.to_f32().ok_or_else(|| SqlError::Parse {
208 detail: format!(
209 "vector element decimal '{d}' is out of f32 range"
210 ),
211 })
212 }
213 other => Err(SqlError::Parse {
214 detail: format!(
215 "vector field must contain numbers, got {other:?}"
216 ),
217 }),
218 })
219 .collect();
220 vector = Some(floats?);
221 }
222 other => {
223 return Err(SqlError::Parse {
224 detail: format!(
225 "vector field '{}' must be an array literal, got {other:?}",
226 vpc.vector_field
227 ),
228 });
229 }
230 }
231 } else {
232 payload_fields.insert(col, val);
233 }
234 }
235
236 let vector = vector.ok_or_else(|| SqlError::Parse {
237 detail: format!(
238 "vector-primary INSERT missing required vector field '{}'",
239 vpc.vector_field
240 ),
241 })?;
242
243 result_rows.push(VectorPrimaryRow {
244 surrogate: nodedb_types::Surrogate::ZERO,
245 vector,
246 payload_fields,
247 });
248 }
249
250 Ok(vec![SqlPlan::VectorPrimaryInsert {
251 collection: collection.to_string(),
252 field: vpc.vector_field.clone(),
253 quantization: vpc.quantization,
254 payload_indexes: vpc.payload_indexes.clone(),
255 rows: result_rows,
256 }])
257}
258
259pub(super) fn build_kv_insert_plan(
270 table_name: String,
271 columns: &[String],
272 rows_ast: &[Vec<ast::Expr>],
273 intent: KvInsertIntent,
274 on_conflict_updates: Vec<(String, SqlExpr)>,
275 pk_col: Option<&str>,
276) -> Result<Vec<SqlPlan>> {
277 let key_col_name = pk_col.unwrap_or("key");
278 let key_idx = columns.iter().position(|c| c == key_col_name);
279 let ttl_idx = columns.iter().position(|c| c == "ttl");
280 let exclude_from_value: std::collections::HashSet<usize> = {
288 let mut s = std::collections::HashSet::new();
289 if key_col_name == "key"
291 && let Some(idx) = key_idx
292 {
293 s.insert(idx);
294 }
295 if let Some(idx) = ttl_idx {
296 s.insert(idx);
297 }
298 s
299 };
300 let mut entries = Vec::with_capacity(rows_ast.len());
301 let mut ttl_secs: u64 = 0;
302 for row_exprs in rows_ast {
303 let key_val = match key_idx {
304 Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
305 None => SqlValue::String(String::new()),
306 };
307 if let Some(idx) = ttl_idx {
308 match expr_to_sql_value(&row_exprs[idx]) {
309 Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
310 Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
311 _ => {}
312 }
313 }
314 let value_cols: Vec<(String, SqlValue)> = columns
315 .iter()
316 .enumerate()
317 .filter(|(i, _)| !exclude_from_value.contains(i))
318 .map(|(i, col)| {
319 let val = expr_to_sql_value(&row_exprs[i])?;
320 Ok((col.clone(), val))
321 })
322 .collect::<Result<Vec<_>>>()?;
323 entries.push((key_val, value_cols));
324 }
325 Ok(vec![SqlPlan::KvInsert {
326 collection: table_name,
327 entries,
328 ttl_secs,
329 intent,
330 on_conflict_updates,
331 }])
332}