Skip to main content

nodedb_sql/parser/
preprocess.rs

1//! SQL pre-processing: rewrite NodeDB-specific syntax into standard SQL
2//! before handing to sqlparser-rs.
3//!
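//! Handles:
//! - `UPSERT INTO coll (cols) VALUES (vals)` → `INSERT INTO coll (cols) VALUES (vals)` + upsert flag
//! - `INSERT INTO coll { key: 'val', ... }` → `INSERT INTO coll (key) VALUES ('val')`
//! - `INSERT INTO coll [{ ... }, { ... }]` → `INSERT INTO coll (cols) VALUES (...), (...)`
//! - `UPSERT INTO coll { key: 'val', ... }` → both rewrites combined
//! - `lhs <-> rhs` → `vector_distance(lhs, rhs)`
//! - `{ key: val }` inside function arguments → single-quoted JSON string literal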
8
9use super::object_literal::{parse_object_literal, parse_object_literal_array};
10
/// Result of pre-processing a SQL string.
///
/// Produced by [`preprocess`] whenever the input needed rewriting.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreprocessedSql {
    /// The rewritten SQL (standard SQL that sqlparser can handle).
    pub sql: String,
    /// Whether the original statement was UPSERT (not INSERT).
    pub is_upsert: bool,
}
18
19/// Pre-process a SQL string, rewriting NodeDB-specific syntax.
20///
21/// Returns `None` if no rewriting was needed (pass through to sqlparser as-is).
22pub fn preprocess(sql: &str) -> Option<PreprocessedSql> {
23    let trimmed = sql.trim();
24    let upper = trimmed.to_uppercase();
25
26    // Check for UPSERT INTO.
27    let is_upsert = upper.starts_with("UPSERT INTO ");
28
29    if is_upsert {
30        // Rewrite UPSERT INTO → INSERT INTO, then check for { } literal.
31        let rewritten = format!("INSERT INTO {}", &trimmed["UPSERT INTO ".len()..]);
32        if let Some(result) = try_rewrite_object_literal(&rewritten) {
33            return Some(PreprocessedSql {
34                sql: result,
35                is_upsert: true,
36            });
37        }
38        return Some(PreprocessedSql {
39            sql: rewritten,
40            is_upsert: true,
41        });
42    }
43
44    // Check for INSERT INTO coll { ... } object literal syntax.
45    if upper.starts_with("INSERT INTO ")
46        && let Some(result) = try_rewrite_object_literal(trimmed)
47    {
48        return Some(PreprocessedSql {
49            sql: result,
50            is_upsert: false,
51        });
52    }
53
54    // Apply expression-level rewrites: `<->` operator, `{ }` in function args.
55    let mut sql_buf = trimmed.to_string();
56    let mut any_rewrite = false;
57
58    // Rewrite pgvector `<->` operator to `vector_distance()` function call.
59    if sql_buf.contains("<->")
60        && let Some(rewritten) = rewrite_arrow_distance(&sql_buf)
61    {
62        sql_buf = rewritten;
63        any_rewrite = true;
64    }
65
66    // Rewrite `{ key: val }` inside function args to JSON string literals.
67    // e.g., text_match(body, 'q', { fuzzy: true }) → text_match(body, 'q', '{"fuzzy":true}')
68    if (sql_buf.contains("{ ") || sql_buf.contains("{f") || sql_buf.contains("{d"))
69        && let Some(rewritten) = rewrite_object_literal_args(&sql_buf)
70    {
71        sql_buf = rewritten;
72        any_rewrite = true;
73    }
74
75    if any_rewrite {
76        return Some(PreprocessedSql {
77            sql: sql_buf,
78            is_upsert: false,
79        });
80    }
81
82    None
83}
84
85/// Try to rewrite `INSERT INTO coll { ... }` or `INSERT INTO coll [{ ... }, { ... }]`
86/// into standard `INSERT INTO coll (cols) VALUES (row1), (row2)`.
87///
88/// Returns `None` if the statement doesn't use object literal syntax.
89fn try_rewrite_object_literal(sql: &str) -> Option<String> {
90    // Find collection name after INSERT INTO.
91    let after_into = sql["INSERT INTO ".len()..].trim_start();
92    let coll_end = after_into.find(|c: char| c.is_whitespace())?;
93    let coll_name = &after_into[..coll_end];
94    let rest = after_into[coll_end..].trim_start();
95
96    // Strip trailing semicolon before parsing.
97    let obj_str = rest.trim_end_matches(';').trim_end();
98
99    if obj_str.starts_with('[') {
100        // Array form: INSERT INTO coll [{ ... }, { ... }]
101        return rewrite_array_form(coll_name, obj_str);
102    }
103
104    if !obj_str.starts_with('{') {
105        return None;
106    }
107
108    // Single object form: INSERT INTO coll { ... }
109    let fields = parse_object_literal(obj_str)?.ok()?;
110    if fields.is_empty() {
111        return None;
112    }
113    Some(fields_to_values_sql(coll_name, &[fields]))
114}
115
116/// Rewrite `[{ ... }, { ... }]` → multi-row VALUES.
117fn rewrite_array_form(coll_name: &str, obj_str: &str) -> Option<String> {
118    let objects = parse_object_literal_array(obj_str)?.ok()?;
119    if objects.is_empty() {
120        return None;
121    }
122    Some(fields_to_values_sql(coll_name, &objects))
123}
124
125/// Build `INSERT INTO coll (col_union) VALUES (row1), (row2), ...`
126///
127/// Collects the union of all keys across all rows. Missing keys get NULL.
128fn fields_to_values_sql(
129    coll_name: &str,
130    rows: &[std::collections::HashMap<String, nodedb_types::Value>],
131) -> String {
132    // Collect union of all keys, sorted for deterministic output.
133    let mut all_keys: Vec<String> = rows
134        .iter()
135        .flat_map(|r| r.keys().cloned())
136        .collect::<std::collections::BTreeSet<_>>()
137        .into_iter()
138        .collect();
139    all_keys.sort();
140
141    let col_list = all_keys.join(", ");
142
143    let row_strs: Vec<String> = rows
144        .iter()
145        .map(|row| {
146            let vals: Vec<String> = all_keys
147                .iter()
148                .map(|k| match row.get(k) {
149                    Some(v) => value_to_sql_literal(v),
150                    None => "NULL".to_string(),
151                })
152                .collect();
153            format!("({})", vals.join(", "))
154        })
155        .collect();
156
157    format!(
158        "INSERT INTO {} ({}) VALUES {}",
159        coll_name,
160        col_list,
161        row_strs.join(", ")
162    )
163}
164
165/// Rewrite `{ key: val }` object literals inside function argument positions
166/// to JSON string literals: `'{"key": val}'`.
167///
168/// Detects patterns like `func(arg1, arg2, { key: val })` and rewrites the
169/// `{ }` to a single-quoted JSON string. Only rewrites `{ }` that appear
170/// inside parentheses (function calls), not at statement level (INSERT).
171fn rewrite_object_literal_args(sql: &str) -> Option<String> {
172    let mut result = String::with_capacity(sql.len());
173    let chars: Vec<char> = sql.chars().collect();
174    let mut i = 0;
175    let mut found = false;
176    let mut paren_depth: i32 = 0;
177
178    while i < chars.len() {
179        match chars[i] {
180            '(' => {
181                paren_depth += 1;
182                result.push('(');
183                i += 1;
184            }
185            ')' => {
186                paren_depth = paren_depth.saturating_sub(1);
187                result.push(')');
188                i += 1;
189            }
190            '\'' => {
191                // Skip quoted strings entirely.
192                result.push('\'');
193                i += 1;
194                while i < chars.len() {
195                    result.push(chars[i]);
196                    if chars[i] == '\'' {
197                        // Handle escaped quotes ('').
198                        if i + 1 < chars.len() && chars[i + 1] == '\'' {
199                            i += 1;
200                            result.push(chars[i]);
201                        } else {
202                            break;
203                        }
204                    }
205                    i += 1;
206                }
207                i += 1;
208            }
209            '{' if paren_depth > 0 => {
210                // Object literal inside function args — parse and convert to JSON string.
211                let remaining: String = chars[i..].iter().collect();
212                if let Some(Ok(fields)) = parse_object_literal(&remaining) {
213                    // Find the end of the object literal to know how many chars to skip.
214                    let obj_end = find_matching_brace(&chars, i);
215                    if let Some(end) = obj_end {
216                        let json = value_map_to_json(&fields);
217                        result.push('\'');
218                        result.push_str(&json);
219                        result.push('\'');
220                        i = end + 1;
221                        found = true;
222                        continue;
223                    }
224                }
225                // Not a valid object literal — pass through.
226                result.push('{');
227                i += 1;
228            }
229            _ => {
230                result.push(chars[i]);
231                i += 1;
232            }
233        }
234    }
235
236    if found { Some(result) } else { None }
237}
238
239/// Convert a parsed field map to a JSON string without external serializer.
240fn value_map_to_json(fields: &std::collections::HashMap<String, nodedb_types::Value>) -> String {
241    let mut parts = Vec::with_capacity(fields.len());
242    let mut entries: Vec<_> = fields.iter().collect();
243    entries.sort_by_key(|(k, _)| k.as_str());
244    for (key, val) in entries {
245        parts.push(format!("\"{}\":{}", key, value_to_json(val)));
246    }
247    format!("{{{}}}", parts.join(","))
248}
249
250/// Convert a single `Value` to JSON text.
251fn value_to_json(value: &nodedb_types::Value) -> String {
252    match value {
253        nodedb_types::Value::String(s) => {
254            format!("\"{}\"", s.replace('\\', "\\\\").replace('"', "\\\""))
255        }
256        nodedb_types::Value::Integer(n) => n.to_string(),
257        nodedb_types::Value::Float(f) => format!("{f}"),
258        nodedb_types::Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
259        nodedb_types::Value::Null => "null".to_string(),
260        nodedb_types::Value::Array(items) => {
261            let inner: Vec<String> = items.iter().map(value_to_json).collect();
262            format!("[{}]", inner.join(","))
263        }
264        nodedb_types::Value::Object(map) => value_map_to_json(map),
265        _ => format!("\"{}\"", format!("{value:?}").replace('"', "\\\"")),
266    }
267}
268
/// Find the index of the matching `}` for a `{` at position `start`.
///
/// Nested braces are balanced. Braces inside single-quoted strings are
/// ignored, and a doubled quote (`''`) inside a string is treated as an
/// escaped quote rather than the end of the string.
///
/// Returns `None` when `chars[start]` is not `{` (including `start` being out
/// of range) or when the brace is never closed.
fn find_matching_brace(chars: &[char], start: usize) -> Option<usize> {
    if chars.get(start) != Some(&'{') {
        return None;
    }

    let mut depth = 0usize;
    let mut in_string = false;
    let mut i = start;
    while i < chars.len() {
        match chars[i] {
            '\'' if in_string => {
                if chars.get(i + 1) == Some(&'\'') {
                    // Escaped quote: step over its second half as well, so it
                    // is not mistaken for the closing quote.
                    i += 1;
                } else {
                    in_string = false;
                }
            }
            '\'' => in_string = true,
            '{' if !in_string => depth += 1,
            '}' if !in_string => {
                // `depth` is at least 1 here: the scan starts on a `{` and
                // returns as soon as the depth drops back to zero.
                depth -= 1;
                if depth == 0 {
                    return Some(i);
                }
            }
            _ => {}
        }
        i += 1;
    }
    None
}
295
/// Rewrite all occurrences of `expr <-> expr` to `vector_distance(expr, expr)`.
///
/// Handles: `column_name <-> ARRAY[...]`, `column <-> $param`, etc.
/// Whitespace around the operator is consumed; everything else is copied
/// through unchanged.
///
/// Returns `None` if the input contains no `<->`, or if any `<->` lacks a
/// recognizable left or right operand (earlier rewrites are discarded then).
fn rewrite_arrow_distance(sql: &str) -> Option<String> {
    const ARROW: &str = "<->";

    let mut result = String::with_capacity(sql.len());
    let mut remaining = sql;
    let mut found = false;

    while let Some(arrow_pos) = remaining.find(ARROW) {
        // Left operand: the identifier that ends the text before `<->`,
        // ignoring whitespace between it and the operator. Its start offset
        // is measured against the trimmed text so that whitespace is not
        // mistaken for part of the operand.
        let before = remaining[..arrow_pos].trim_end();
        let left = extract_left_operand(before)?;
        let left_start = before.len() - left.len();

        // Right operand: starts after `<->` and any following whitespace.
        let after = &remaining[arrow_pos + ARROW.len()..];
        let after_trimmed = after.trim_start();
        let (right, right_len) = extract_right_operand(after_trimmed)?;
        let ws_skip = after.len() - after_trimmed.len();

        // Build: everything before left operand + vector_distance(left, right).
        result.push_str(&remaining[..left_start]);
        result.push_str("vector_distance(");
        result.push_str(&left);
        result.push_str(", ");
        result.push_str(&right);
        result.push(')');

        remaining = &after[ws_skip + right_len..];
        found = true;
    }

    if !found {
        return None;
    }

    result.push_str(remaining);
    Some(result)
}

/// Extract the left operand before `<->`: a column name or dotted path.
///
/// Takes the longest run of ASCII alphanumerics, `_` and `.` at the end of
/// `before` (trailing whitespace ignored). Returns `None` when that run is
/// empty.
fn extract_left_operand(before: &str) -> Option<String> {
    let trimmed = before.trim_end();
    // Strip the identifier off the end; what remains is its prefix. Working
    // with the prefix length keeps the cut on a character boundary even when
    // the character before the identifier is multi-byte.
    let prefix =
        trimmed.trim_end_matches(|c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.');
    let ident = &trimmed[prefix.len()..];
    if ident.is_empty() {
        return None;
    }
    Some(ident.to_string())
}

/// Extract the right operand after `<->`: ARRAY[...], $param, or identifier.
///
/// Returns (operand_text, consumed_length). The length is counted from the
/// first non-whitespace character of `after`. Returns `None` for an
/// unterminated `ARRAY[` or when no operand characters are present.
fn extract_right_operand(after: &str) -> Option<(String, usize)> {
    const ARRAY_PREFIX: &str = "ARRAY[";

    let trimmed = after.trim_start();

    // Case-insensitive keyword test on just the first bytes; no need to
    // uppercase the rest of the statement.
    let is_array = trimmed
        .get(..ARRAY_PREFIX.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(ARRAY_PREFIX));

    if is_array {
        // Find the `]` that closes the opening bracket, allowing nesting.
        let mut depth = 0usize;
        for (i, c) in trimmed.char_indices() {
            match c {
                '[' => depth += 1,
                ']' => {
                    // The text starts with `ARRAY[`, so depth is >= 1 here.
                    depth -= 1;
                    if depth == 0 {
                        return Some((trimmed[..=i].to_string(), i + 1));
                    }
                }
                _ => {}
            }
        }
        return None; // Unmatched bracket.
    }

    // Parameter references ($1, $query_vec) may contain `$`; plain
    // identifiers may contain `.` for dotted paths.
    let extra = if trimmed.starts_with('$') { '$' } else { '.' };
    let end = trimmed
        .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_' || c == extra))
        .unwrap_or(trimmed.len());
    if end == 0 {
        return None;
    }
    Some((trimmed[..end].to_string(), end))
}
385
386/// Convert a `nodedb_types::Value` to a SQL literal string.
387///
388/// Used by pre-processing and by Origin's pgwire handlers to build SQL
389/// from parsed field maps. Handles all Value variants.
390pub fn value_to_sql_literal(value: &nodedb_types::Value) -> String {
391    match value {
392        nodedb_types::Value::String(s) => format!("'{}'", s.replace('\'', "''")),
393        nodedb_types::Value::Integer(n) => n.to_string(),
394        nodedb_types::Value::Float(f) => format!("{f}"),
395        nodedb_types::Value::Bool(b) => if *b { "TRUE" } else { "FALSE" }.to_string(),
396        nodedb_types::Value::Null => "NULL".to_string(),
397        nodedb_types::Value::Array(items) => {
398            let inner: Vec<String> = items.iter().map(value_to_sql_literal).collect();
399            format!("ARRAY[{}]", inner.join(", "))
400        }
401        nodedb_types::Value::Bytes(b) => {
402            let hex: String = b.iter().map(|byte| format!("{byte:02x}")).collect();
403            format!("'\\x{hex}'")
404        }
405        nodedb_types::Value::Object(_) => "NULL".to_string(),
406        nodedb_types::Value::Uuid(u) => format!("'{u}'"),
407        nodedb_types::Value::Ulid(u) => format!("'{u}'"),
408        nodedb_types::Value::DateTime(dt) => format!("'{dt}'"),
409        nodedb_types::Value::Duration(d) => format!("'{d}'"),
410        nodedb_types::Value::Decimal(d) => d.to_string(),
411        // Exotic types: format as string literal for SQL passthrough.
412        other => format!("'{}'", format!("{other:?}").replace('\'', "''")),
413    }
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `preprocess` on a statement that is expected to be rewritten.
    fn rewritten(sql: &str) -> PreprocessedSql {
        preprocess(sql).unwrap()
    }

    #[test]
    fn passthrough_standard_sql() {
        let untouched = [
            "SELECT * FROM users",
            "INSERT INTO users (name) VALUES ('alice')",
            "DELETE FROM users WHERE id = 1",
        ];
        for sql in untouched {
            assert!(preprocess(sql).is_none());
        }
    }

    #[test]
    fn upsert_rewrite() {
        let out = rewritten("UPSERT INTO users (name) VALUES ('alice')");
        assert!(out.is_upsert);
        assert_eq!(out.sql, "INSERT INTO users (name) VALUES ('alice')");
    }

    #[test]
    fn object_literal_insert() {
        let out = rewritten("INSERT INTO users { name: 'alice', age: 30 }");
        assert!(!out.is_upsert);
        assert!(out.sql.starts_with("INSERT INTO users ("));
        for needle in ["'alice'", "30"] {
            assert!(out.sql.contains(needle));
        }
    }

    #[test]
    fn object_literal_upsert() {
        let out = rewritten("UPSERT INTO users { name: 'bob' }");
        assert!(out.is_upsert);
        assert!(out.sql.starts_with("INSERT INTO users ("));
        assert!(out.sql.contains("'bob'"));
    }

    #[test]
    fn batch_array_insert() {
        let out =
            rewritten("INSERT INTO users [{ name: 'alice', age: 30 }, { name: 'bob', age: 25 }]");
        assert!(!out.is_upsert);
        // Multi-row form: ... VALUES (...), (...)
        for needle in ["VALUES", "'alice'", "'bob'", "30", "25"] {
            assert!(out.sql.contains(needle));
        }
        // Exactly one opening parenthesis per row after the VALUES keyword.
        let values_part = out.sql.split("VALUES").nth(1).unwrap();
        let row_count = values_part.matches('(').count();
        assert_eq!(row_count, 2, "should have 2 row groups: {}", out.sql);
    }

    #[test]
    fn batch_array_heterogeneous_keys() {
        let out =
            rewritten("INSERT INTO docs [{ id: 'a', name: 'Alice' }, { id: 'b', role: 'admin' }]");
        // Columns are the union id, name, role; absent ones are filled with NULL.
        for needle in ["NULL", "'Alice'", "'admin'"] {
            assert!(out.sql.contains(needle));
        }
    }

    #[test]
    fn batch_array_upsert() {
        let out =
            rewritten("UPSERT INTO users [{ id: 'u1', name: 'a' }, { id: 'u2', name: 'b' }]");
        assert!(out.is_upsert);
        assert!(out.sql.contains("VALUES"));
    }

    #[test]
    fn arrow_distance_operator_select() {
        let out = rewritten(
            "SELECT title FROM articles ORDER BY embedding <-> ARRAY[0.1, 0.2, 0.3] LIMIT 5",
        );
        let expected = "vector_distance(embedding, ARRAY[0.1, 0.2, 0.3])";
        assert!(out.sql.contains(expected), "got: {}", out.sql);
        assert!(!out.sql.contains("<->"));
    }

    #[test]
    fn arrow_distance_operator_where() {
        let out = rewritten("SELECT * FROM docs WHERE embedding <-> ARRAY[1.0, 2.0] < 0.5");
        let expected = "vector_distance(embedding, ARRAY[1.0, 2.0])";
        assert!(out.sql.contains(expected), "got: {}", out.sql);
    }

    #[test]
    fn arrow_distance_no_match() {
        // Without a `<->` there is nothing to rewrite.
        assert!(preprocess("SELECT * FROM users WHERE age > 30").is_none());
    }

    #[test]
    fn arrow_distance_with_alias() {
        let out = rewritten("SELECT embedding <-> ARRAY[0.1, 0.2] AS dist FROM articles");
        let expected = "vector_distance(embedding, ARRAY[0.1, 0.2]) AS dist";
        assert!(out.sql.contains(expected), "got: {}", out.sql);
    }

    #[test]
    fn fuzzy_object_literal_in_function() {
        let sql = "SELECT * FROM articles WHERE text_match(body, 'query', { fuzzy: true })";

        // First exercise the function-argument rewriter on its own.
        let direct = rewrite_object_literal_args(sql);
        assert!(direct.is_some(), "rewrite_object_literal_args should match");
        let direct_sql = direct.unwrap();
        assert!(
            direct_sql.contains("\"fuzzy\""),
            "direct rewrite should contain JSON, got: {}",
            direct_sql
        );

        // Then the full pipeline.
        let out = rewritten(sql);
        assert!(
            !out.sql.contains("{ fuzzy"),
            "should not contain object literal, got: {}",
            out.sql
        );
    }

    #[test]
    fn fuzzy_object_literal_with_distance() {
        let out = rewritten(
            "SELECT * FROM articles WHERE text_match(title, 'test', { fuzzy: true, distance: 2 })",
        );
        for needle in ["\"fuzzy\"", "\"distance\""] {
            assert!(out.sql.contains(needle), "got: {}", out.sql);
        }
    }

    #[test]
    fn object_literal_not_rewritten_outside_function() {
        // A statement-level INSERT { } is handled by try_rewrite_object_literal,
        // so it must come out as VALUES rather than a JSON string.
        let out = rewritten("INSERT INTO docs { name: 'Alice' }");
        assert!(out.sql.contains("VALUES"), "got: {}", out.sql);
    }
}
579}