Skip to main content

nodedb_sql/parser/preprocess/object_literal_stmt.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Rewrite `INSERT/UPSERT INTO coll { ... }` (and `[{ ... }, ...]`) into
4//! standard `INSERT INTO coll (cols) VALUES (row), ...`.
5
6use super::literal::value_to_sql_literal;
7use crate::parser::object_literal::{parse_object_literal, parse_object_literal_array};
8
9/// Try to rewrite `INSERT INTO coll { ... }` or `INSERT INTO coll [{ ... }, { ... }]`
10/// into standard `INSERT INTO coll (cols) VALUES (row1), (row2)`.
11///
12/// Returns `None` if the statement doesn't use object literal syntax.
13pub(super) fn try_rewrite_object_literal(sql: &str) -> Option<String> {
14    let after_into = sql["INSERT INTO ".len()..].trim_start();
15    let coll_end = after_into.find(|c: char| c.is_whitespace())?;
16    let coll_name = &after_into[..coll_end];
17    let rest = after_into[coll_end..].trim_start();
18
19    // Strip trailing semicolon before parsing.
20    let obj_str = rest.trim_end_matches(';').trim_end();
21
22    if obj_str.starts_with('[') {
23        return rewrite_array_form(coll_name, obj_str);
24    }
25
26    if !obj_str.starts_with('{') {
27        return None;
28    }
29
30    let fields = parse_object_literal(obj_str)?.ok()?;
31    if fields.is_empty() {
32        return None;
33    }
34    Some(fields_to_values_sql(coll_name, &[fields]))
35}
36
37/// Rewrite `[{ ... }, { ... }]` → multi-row VALUES.
38fn rewrite_array_form(coll_name: &str, obj_str: &str) -> Option<String> {
39    let objects = parse_object_literal_array(obj_str)?.ok()?;
40    if objects.is_empty() {
41        return None;
42    }
43    Some(fields_to_values_sql(coll_name, &objects))
44}
45
46/// Build `INSERT INTO coll (col_union) VALUES (row1), (row2), ...`
47///
48/// Collects the union of all keys across all rows. Missing keys get NULL.
49fn fields_to_values_sql(
50    coll_name: &str,
51    rows: &[std::collections::HashMap<String, nodedb_types::Value>],
52) -> String {
53    let mut all_keys: Vec<String> = rows
54        .iter()
55        .flat_map(|r| r.keys().cloned())
56        .collect::<std::collections::BTreeSet<_>>()
57        .into_iter()
58        .collect();
59    all_keys.sort();
60
61    let col_list = all_keys.join(", ");
62
63    let row_strs: Vec<String> = rows
64        .iter()
65        .map(|row| {
66            let vals: Vec<String> = all_keys
67                .iter()
68                .map(|k| match row.get(k) {
69                    Some(v) => value_to_sql_literal(v),
70                    None => "NULL".to_string(),
71                })
72                .collect();
73            format!("({})", vals.join(", "))
74        })
75        .collect();
76
77    format!(
78        "INSERT INTO {} ({}) VALUES {}",
79        coll_name,
80        col_list,
81        row_strs.join(", ")
82    )
83}