1use sqlparser::ast;
4
5use crate::error::{Result, SqlError};
6use crate::parser::normalize::{normalize_ident, normalize_object_name_checked};
7use crate::resolver::expr::convert_value;
8use crate::types::*;
9
10pub(super) fn convert_value_rows(
11 columns: &[String],
12 rows: &[Vec<ast::Expr>],
13) -> Result<Vec<Vec<(String, SqlValue)>>> {
14 rows.iter()
15 .map(|row| {
16 row.iter()
17 .enumerate()
18 .map(|(i, expr)| {
19 let col = columns.get(i).cloned().unwrap_or_else(|| format!("col{i}"));
20 let val = expr_to_sql_value(expr)?;
21 Ok((col, val))
22 })
23 .collect::<Result<Vec<_>>>()
24 })
25 .collect()
26}
27
28pub(super) fn expr_to_sql_value(expr: &ast::Expr) -> Result<SqlValue> {
29 match expr {
30 ast::Expr::Value(v) => convert_value(&v.value),
31 ast::Expr::Array(ast::Array { elem, .. }) => {
34 let vals = elem.iter().map(expr_to_sql_value).collect::<Result<_>>()?;
35 Ok(SqlValue::Array(vals))
36 }
37 ast::Expr::Function(func) => match SpatialConstructor::from_function(func) {
41 Some(ctor) => spatial_constructor_to_value(ctor, func),
42 None => fold_constant_value(expr),
45 },
46 _ => fold_constant_value(expr),
53 }
54}
55
56fn fold_constant_value(expr: &ast::Expr) -> Result<SqlValue> {
57 let sql_expr = crate::resolver::expr::convert_expr(expr)?;
58 super::const_fold::fold_constant_default(&sql_expr).ok_or_else(|| SqlError::Unsupported {
59 detail: format!("value expression: {expr}"),
60 })
61}
62
/// Spatial constructor functions that are evaluated while planning into string values
/// holding GeoJSON text.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum SpatialConstructor {
    /// `ST_Point(longitude, latitude)`.
    Point,
    /// `ST_GeomFromGeoJSON(geojson_text)`.
    GeomFromGeoJson,
}
72
73impl SpatialConstructor {
74 fn from_function(func: &ast::Function) -> Option<Self> {
75 let name = func
76 .name
77 .0
78 .iter()
79 .map(|p| match p {
80 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
81 _ => String::new(),
82 })
83 .collect::<Vec<_>>()
84 .join(".")
85 .to_lowercase();
86 match name.as_str() {
87 "st_point" => Some(Self::Point),
88 "st_geomfromgeojson" => Some(Self::GeomFromGeoJson),
89 _ => None,
90 }
91 }
92
93 fn display_name(self) -> &'static str {
94 match self {
95 Self::Point => "ST_Point",
96 Self::GeomFromGeoJson => "ST_GeomFromGeoJSON",
97 }
98 }
99}
100
101fn spatial_constructor_to_value(
102 ctor: SpatialConstructor,
103 func: &ast::Function,
104) -> Result<SqlValue> {
105 let args = super::select::extract_func_args(func)?;
106 match ctor {
107 SpatialConstructor::Point => {
108 if args.len() < 2 {
109 return Err(SqlError::InvalidFunction {
110 detail: format!(
111 "{} requires 2 arguments (longitude, latitude), got {}",
112 ctor.display_name(),
113 args.len()
114 ),
115 });
116 }
117 let lon = super::select::extract_float(&args[0])?;
118 let lat = super::select::extract_float(&args[1])?;
119 Ok(SqlValue::String(format!(
120 r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#
121 )))
122 }
123 SpatialConstructor::GeomFromGeoJson => {
124 if args.is_empty() {
125 return Err(SqlError::InvalidFunction {
126 detail: format!(
127 "{} requires 1 argument (GeoJSON string)",
128 ctor.display_name()
129 ),
130 });
131 }
132 let s = super::select::extract_string_literal(&args[0])?;
133 Ok(SqlValue::String(s))
134 }
135 }
136}
137
138pub(super) fn extract_table_name_from_table_with_joins(
139 table: &ast::TableWithJoins,
140) -> Result<String> {
141 match &table.relation {
142 ast::TableFactor::Table { name, .. } => Ok(normalize_object_name_checked(name)?),
143 _ => Err(SqlError::Unsupported {
144 detail: "non-table target in DML".into(),
145 }),
146 }
147}
148
149pub fn extract_point_keys(selection: Option<&ast::Expr>, info: &CollectionInfo) -> Vec<SqlValue> {
151 let pk = match &info.primary_key {
152 Some(pk) => pk.clone(),
153 None => return Vec::new(),
154 };
155
156 let expr = match selection {
157 Some(e) => e,
158 None => return Vec::new(),
159 };
160
161 let mut keys = Vec::new();
162 collect_pk_equalities(expr, &pk, &mut keys);
163 keys
164}
165
166fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
167 match expr {
168 ast::Expr::BinaryOp {
169 left,
170 op: ast::BinaryOperator::Eq,
171 right,
172 } => {
173 if is_column(left, pk)
174 && let Ok(v) = expr_to_sql_value(right)
175 {
176 keys.push(v);
177 } else if is_column(right, pk)
178 && let Ok(v) = expr_to_sql_value(left)
179 {
180 keys.push(v);
181 }
182 }
183 ast::Expr::BinaryOp {
184 left,
185 op: ast::BinaryOperator::Or,
186 right,
187 } => {
188 collect_pk_equalities(left, pk, keys);
189 collect_pk_equalities(right, pk, keys);
190 }
191 ast::Expr::InList {
192 expr: inner,
193 list,
194 negated: false,
195 } if is_column(inner, pk) => {
196 for item in list {
197 if let Ok(v) = expr_to_sql_value(item) {
198 keys.push(v);
199 }
200 }
201 }
202 _ => {}
203 }
204}
205
206fn is_column(expr: &ast::Expr, name: &str) -> bool {
207 match expr {
208 ast::Expr::Identifier(ident) => normalize_ident(ident) == name,
209 ast::Expr::CompoundIdentifier(parts) if parts.len() >= 3 => false,
211 ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
212 normalize_ident(&parts[1]) == name
213 }
214 _ => false,
215 }
216}
217
218pub(super) fn build_vector_primary_insert_plan(
224 collection: &str,
225 vpc: &nodedb_types::VectorPrimaryConfig,
226 _columns: &[String],
227 rows: Vec<Vec<(String, SqlValue)>>,
228) -> Result<Vec<SqlPlan>> {
229 let mut result_rows = Vec::with_capacity(rows.len());
230 for row in rows {
231 let mut vector: Option<Vec<f32>> = None;
232 let mut payload_fields = std::collections::HashMap::new();
233
234 for (col, val) in row {
235 if col == vpc.vector_field {
236 match val {
237 SqlValue::Array(items) => {
238 let floats: Result<Vec<f32>> = items
239 .iter()
240 .map(|v| match v {
241 SqlValue::Float(f) => Ok(*f as f32),
242 SqlValue::Int(i) => Ok(*i as f32),
243 SqlValue::Decimal(d) => {
244 use rust_decimal::prelude::ToPrimitive;
245 d.to_f32().ok_or_else(|| SqlError::Parse {
246 detail: format!(
247 "vector element decimal '{d}' is out of f32 range"
248 ),
249 })
250 }
251 other => Err(SqlError::Parse {
252 detail: format!(
253 "vector field must contain numbers, got {other:?}"
254 ),
255 }),
256 })
257 .collect();
258 vector = Some(floats?);
259 }
260 other => {
261 return Err(SqlError::Parse {
262 detail: format!(
263 "vector field '{}' must be an array literal, got {other:?}",
264 vpc.vector_field
265 ),
266 });
267 }
268 }
269 } else {
270 payload_fields.insert(col, val);
271 }
272 }
273
274 let vector = vector.ok_or_else(|| SqlError::Parse {
275 detail: format!(
276 "vector-primary INSERT missing required vector field '{}'",
277 vpc.vector_field
278 ),
279 })?;
280
281 result_rows.push(VectorPrimaryRow {
282 surrogate: nodedb_types::Surrogate::ZERO,
283 vector,
284 payload_fields,
285 });
286 }
287
288 Ok(vec![SqlPlan::VectorPrimaryInsert {
289 collection: collection.to_string(),
290 field: vpc.vector_field.clone(),
291 quantization: vpc.quantization,
292 payload_indexes: vpc.payload_indexes.clone(),
293 rows: result_rows,
294 }])
295}
296
297pub(super) fn build_kv_insert_plan(
308 table_name: String,
309 columns: &[String],
310 rows_ast: &[Vec<ast::Expr>],
311 intent: KvInsertIntent,
312 on_conflict_updates: Vec<(String, SqlExpr)>,
313 pk_col: Option<&str>,
314) -> Result<Vec<SqlPlan>> {
315 let key_col_name = pk_col.unwrap_or("key");
316 let key_idx = columns.iter().position(|c| c == key_col_name);
317 let ttl_idx = columns.iter().position(|c| c == "ttl");
318 let exclude_from_value: std::collections::HashSet<usize> = {
326 let mut s = std::collections::HashSet::new();
327 if key_col_name == "key"
329 && let Some(idx) = key_idx
330 {
331 s.insert(idx);
332 }
333 if let Some(idx) = ttl_idx {
334 s.insert(idx);
335 }
336 s
337 };
338 let mut entries = Vec::with_capacity(rows_ast.len());
339 let mut ttl_secs: u64 = 0;
340 for row_exprs in rows_ast {
341 let key_val = match key_idx {
342 Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
343 None => SqlValue::String(String::new()),
344 };
345 if let Some(idx) = ttl_idx {
346 match expr_to_sql_value(&row_exprs[idx]) {
347 Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
348 Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
349 _ => {}
350 }
351 }
352 let value_cols: Vec<(String, SqlValue)> = columns
353 .iter()
354 .enumerate()
355 .filter(|(i, _)| !exclude_from_value.contains(i))
356 .map(|(i, col)| {
357 let val = expr_to_sql_value(&row_exprs[i])?;
358 Ok((col.clone(), val))
359 })
360 .collect::<Result<Vec<_>>>()?;
361 entries.push((key_val, value_cols));
362 }
363 Ok(vec![SqlPlan::KvInsert {
364 collection: table_name,
365 entries,
366 ttl_secs,
367 intent,
368 on_conflict_updates,
369 }])
370}