Skip to main content

nodedb_sql/parser/preprocess/
pipeline.rs

1// SPDX-License-Identifier: Apache-2.0
2
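//! SQL pre-processing orchestrator: rewrite NodeDB-specific syntax into
//! standard SQL before handing it to sqlparser-rs.
//!
//! Handles:
//! - `FOR SYSTEM_TIME` / `FOR VALID_TIME` / `__system_as_of__(...)` → stripped
//!   from the SQL and returned separately as a `TemporalScope`
//! - `SEARCH coll USING VECTOR(...)` → `SELECT * FROM coll ORDER BY vector_distance(...) LIMIT k`
//! - `UPSERT INTO coll (cols) VALUES (vals)` → `INSERT INTO ...` + upsert flag
//! - `INSERT INTO coll { key: 'val', ... }` → `INSERT INTO coll (key) VALUES ('val')`
//! - `INSERT INTO coll [{ ... }, { ... }]` → one `VALUES` row per object
//! - `UPSERT INTO coll { ... }` → both rewrites combined
//! - `expr <-> expr` → `vector_distance(expr, expr)`
//! - `expr <=> expr` / `expr <#> expr` → cosine-distance / negative-inner-product
//!   function calls (see the `vector_ops` rewriters)
//! - `{ key: val }` in function args → JSON string literal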
12
13use super::function_args::rewrite_object_literal_args;
14use super::lex::{first_sql_word, has_brace_outside_literals, has_operator_outside_literals};
15use super::object_literal_stmt::try_rewrite_object_literal;
16use super::search_vector::try_rewrite_search_using_vector;
17use super::temporal::extract as extract_temporal;
18use super::vector_ops::{
19    rewrite_arrow_distance, rewrite_cosine_distance, rewrite_neg_inner_product,
20};
21use crate::error::SqlError;
22use crate::temporal::TemporalScope;
23
24/// Result of pre-processing a SQL string.
25#[derive(Debug)]
26pub struct PreprocessedSql {
27    /// The rewritten SQL (standard SQL that sqlparser can handle).
28    pub sql: String,
29    /// Whether the original statement was UPSERT (not INSERT).
30    pub is_upsert: bool,
31    /// Bitemporal qualifier extracted from `FOR SYSTEM_TIME` /
32    /// `FOR VALID_TIME` / `__system_as_of__(...)`. Default when none
33    /// was present.
34    pub temporal: TemporalScope,
35}
36
37/// Pre-process a SQL string, rewriting NodeDB-specific syntax.
38///
39/// Returns `Ok(None)` if no rewriting was needed. Temporal clause parse
40/// errors bubble up as `SqlError::Parse` so they surface to the caller
41/// identically to sqlparser errors.
42pub fn preprocess(sql: &str) -> Result<Option<PreprocessedSql>, SqlError> {
43    let trimmed = sql.trim();
44
45    // Extract temporal clauses first — they can appear in both SELECT and
46    // INSERT...SELECT, and stripping them before the UPSERT/object-literal
47    // rewrites keeps those rewriters pattern-free of NodeDB extensions.
48    let (temporal_sql, temporal) =
49        match extract_temporal(trimmed).map_err(|e| SqlError::Parse { detail: e.0 })? {
50            Some(ex) => (ex.sql, ex.temporal),
51            None => (trimmed.to_string(), TemporalScope::default()),
52        };
53    let any_temporal = temporal != TemporalScope::default();
54
55    // Rewrite `SEARCH <coll> USING VECTOR(...)` to canonical
56    // `SELECT * FROM <coll> ORDER BY vector_distance(...) LIMIT k` before any
57    // first-word dispatch — the rewritten form re-enters the rest of the
58    // pipeline as a plain SELECT.
59    let (temporal_sql, search_vector_rewritten) =
60        match try_rewrite_search_using_vector(&temporal_sql) {
61            Some(rewritten) => (rewritten, true),
62            None => (temporal_sql, false),
63        };
64
65    let first_word = first_sql_word(&temporal_sql)
66        .map(|w| w.to_uppercase())
67        .unwrap_or_default();
68    let is_upsert = first_word == "UPSERT";
69
70    if is_upsert {
71        let rewritten = format!("INSERT INTO {}", &temporal_sql["UPSERT INTO ".len()..]);
72        if let Some(result) = try_rewrite_object_literal(&rewritten) {
73            return Ok(Some(PreprocessedSql {
74                sql: result,
75                is_upsert: true,
76                temporal,
77            }));
78        }
79        return Ok(Some(PreprocessedSql {
80            sql: rewritten,
81            is_upsert: true,
82            temporal,
83        }));
84    }
85
86    if first_word == "INSERT"
87        && let Some(result) = try_rewrite_object_literal(&temporal_sql)
88    {
89        return Ok(Some(PreprocessedSql {
90            sql: result,
91            is_upsert: false,
92            temporal,
93        }));
94    }
95
96    let mut sql_buf = temporal_sql;
97    let mut any_rewrite = any_temporal || search_vector_rewritten;
98
99    if has_operator_outside_literals(&sql_buf, "<->")
100        && let Some(rewritten) = rewrite_arrow_distance(&sql_buf)
101    {
102        sql_buf = rewritten;
103        any_rewrite = true;
104    }
105
106    if has_operator_outside_literals(&sql_buf, "<=>")
107        && let Some(rewritten) = rewrite_cosine_distance(&sql_buf)
108    {
109        sql_buf = rewritten;
110        any_rewrite = true;
111    }
112
113    if has_operator_outside_literals(&sql_buf, "<#>")
114        && let Some(rewritten) = rewrite_neg_inner_product(&sql_buf)
115    {
116        sql_buf = rewritten;
117        any_rewrite = true;
118    }
119
120    if has_brace_outside_literals(&sql_buf)
121        && let Some(rewritten) = rewrite_object_literal_args(&sql_buf)
122    {
123        sql_buf = rewritten;
124        any_rewrite = true;
125    }
126
127    if any_rewrite {
128        return Ok(Some(PreprocessedSql {
129            sql: sql_buf,
130            is_upsert: false,
131            temporal,
132        }));
133    }
134
135    Ok(None)
136}
137
#[cfg(test)]
mod tests {
    use super::super::function_args::rewrite_object_literal_args;
    use super::*;

    /// Fuzzy text-match query whose trailing `{ ... }` argument must be
    /// turned into a JSON string literal.
    const FUZZY_QUERY: &str =
        "SELECT * FROM articles WHERE text_match(body, 'query', { fuzzy: true })";

    /// Pre-process `sql`; every input in this module is valid, so an `Err`
    /// fails the test on the spot.
    fn pp(sql: &str) -> Option<PreprocessedSql> {
        super::preprocess(sql).unwrap()
    }

    /// Like [`pp`], but the input is required to have triggered a rewrite.
    fn must_rewrite(sql: &str) -> PreprocessedSql {
        pp(sql).unwrap()
    }

    #[test]
    fn passthrough_standard_sql() {
        let standard = [
            "SELECT * FROM users",
            "INSERT INTO users (name) VALUES ('alice')",
            "DELETE FROM users WHERE id = 1",
        ];
        for sql in standard {
            assert!(pp(sql).is_none());
        }
    }

    #[test]
    fn upsert_rewrite() {
        let out = must_rewrite("UPSERT INTO users (name) VALUES ('alice')");
        assert!(out.is_upsert);
        assert_eq!(out.sql, "INSERT INTO users (name) VALUES ('alice')");
    }

    #[test]
    fn object_literal_insert() {
        let PreprocessedSql { sql, is_upsert, .. } =
            must_rewrite("INSERT INTO users { name: 'alice', age: 30 }");
        assert!(!is_upsert);
        assert!(sql.starts_with("INSERT INTO users ("));
        for needle in ["'alice'", "30"] {
            assert!(sql.contains(needle));
        }
    }

    #[test]
    fn object_literal_upsert() {
        let PreprocessedSql { sql, is_upsert, .. } =
            must_rewrite("UPSERT INTO users { name: 'bob' }");
        assert!(is_upsert);
        assert!(sql.starts_with("INSERT INTO users ("));
        assert!(sql.contains("'bob'"));
    }

    #[test]
    fn batch_array_insert() {
        let PreprocessedSql { sql, is_upsert, .. } = must_rewrite(
            "INSERT INTO users [{ name: 'alice', age: 30 }, { name: 'bob', age: 25 }]",
        );
        assert!(!is_upsert);
        for needle in ["VALUES", "'alice'", "'bob'", "30", "25"] {
            assert!(sql.contains(needle));
        }
        // Text between the first `VALUES` and the next one (or the end).
        let values_part = sql.split("VALUES").nth(1).unwrap();
        let row_groups = values_part.chars().filter(|&c| c == '(').count();
        assert_eq!(row_groups, 2, "should have 2 row groups: {}", sql);
    }

    #[test]
    fn batch_array_heterogeneous_keys() {
        let out = must_rewrite(
            "INSERT INTO docs [{ id: 'a', name: 'Alice' }, { id: 'b', role: 'admin' }]",
        );
        // Keys absent from one of the objects are expected to surface as NULL.
        for needle in ["NULL", "'Alice'", "'admin'"] {
            assert!(out.sql.contains(needle));
        }
    }

    #[test]
    fn batch_array_upsert() {
        let out = must_rewrite(
            "UPSERT INTO users [{ id: 'u1', name: 'a' }, { id: 'u2', name: 'b' }]",
        );
        assert!(out.is_upsert);
        assert!(out.sql.contains("VALUES"));
    }

    #[test]
    fn arrow_distance_operator_select() {
        let sql = must_rewrite(
            "SELECT title FROM articles ORDER BY embedding <-> ARRAY[0.1, 0.2, 0.3] LIMIT 5",
        )
        .sql;
        assert!(
            sql.contains("vector_distance(embedding, ARRAY[0.1, 0.2, 0.3])"),
            "got: {}",
            sql
        );
        assert!(!sql.contains("<->"));
    }

    #[test]
    fn arrow_distance_operator_where() {
        let sql =
            must_rewrite("SELECT * FROM docs WHERE embedding <-> ARRAY[1.0, 2.0] < 0.5").sql;
        assert!(
            sql.contains("vector_distance(embedding, ARRAY[1.0, 2.0])"),
            "got: {}",
            sql
        );
        // Regression guard: the left-operand start used to be computed as
        // `before.len() - left.len()`, which is off by one when whitespace
        // separates the column from `<->`. The column's first character then
        // stayed in the prefix, yielding `WHERE evector_distance(embedding, ...)`.
        assert!(
            !sql.contains("evector_distance"),
            "rewriter must not leave the column's leading char in the prefix; got: {}",
            sql
        );
        assert!(
            sql.contains("WHERE vector_distance("),
            "rewritten WHERE clause must start `WHERE vector_distance(`; got: {}",
            sql
        );
    }

    #[test]
    fn arrow_distance_operator_where_no_whitespace_before_op() {
        // With no whitespace before `<->` the spacing-sensitive off-by-one
        // cannot trigger, but the output must still be a clean
        // `vector_distance(embedding, ARRAY[...])` call.
        let sql = must_rewrite("SELECT id FROM docs WHERE embedding<->ARRAY[1.0, 2.0]").sql;
        assert!(
            sql.contains("vector_distance(embedding, ARRAY[1.0, 2.0])"),
            "got: {}",
            sql
        );
        assert!(!sql.contains("evector_distance"), "got: {}", sql);
    }

    #[test]
    fn arrow_distance_no_match() {
        let outcome = pp("SELECT * FROM users WHERE age > 30");
        assert!(outcome.is_none());
    }

    #[test]
    fn arrow_distance_with_alias() {
        let sql = must_rewrite("SELECT embedding <-> ARRAY[0.1, 0.2] AS dist FROM articles").sql;
        assert!(
            sql.contains("vector_distance(embedding, ARRAY[0.1, 0.2]) AS dist"),
            "got: {}",
            sql
        );
    }

    #[test]
    fn fuzzy_object_literal_in_function() {
        // The low-level rewriter must match on its own...
        let direct = rewrite_object_literal_args(FUZZY_QUERY);
        assert!(direct.is_some(), "rewrite_object_literal_args should match");
        let json_sql = direct.unwrap();
        assert!(
            json_sql.contains("\"fuzzy\""),
            "direct rewrite should contain JSON, got: {}",
            json_sql
        );

        // ...and the full pipeline must not leave the raw object literal behind.
        let piped = must_rewrite(FUZZY_QUERY).sql;
        assert!(
            !piped.contains("{ fuzzy"),
            "should not contain object literal, got: {}",
            piped
        );
    }

    #[test]
    fn fuzzy_object_literal_with_distance() {
        let sql = must_rewrite(
            "SELECT * FROM articles WHERE text_match(title, 'test', { fuzzy: true, distance: 2 })",
        )
        .sql;
        for needle in ["\"fuzzy\"", "\"distance\""] {
            assert!(sql.contains(needle), "got: {}", sql);
        }
    }

    #[test]
    fn object_literal_not_rewritten_outside_function() {
        let sql = must_rewrite("INSERT INTO docs { name: 'Alice' }").sql;
        assert!(sql.contains("VALUES"), "got: {}", sql);
    }

    #[test]
    fn comment_prefix_insert_routes_correctly() {
        // A leading block comment must not divert INSERT into the UPSERT
        // path; the statement may be rewritten or passed through.
        let outcome = pp("/* hint */ INSERT INTO t VALUES ({})");
        assert!(outcome.is_none_or(|r| !r.is_upsert));
    }

    #[test]
    fn comment_prefix_upsert_routes_correctly() {
        // A leading block comment must not hide the UPSERT keyword.
        assert!(must_rewrite("/* hint */ UPSERT INTO t (name) VALUES ('a')").is_upsert);
    }

    #[test]
    fn line_comment_before_insert_does_not_trigger_insert() {
        // Only `SELECT 1` is live SQL; the INSERT sits inside a `--` comment,
        // so nothing may be rewritten.
        let result = pp("-- INSERT INTO t\nSELECT 1");
        assert!(
            result.is_none(),
            "line-commented INSERT must pass through, got: {result:?}"
        );
    }

    #[test]
    fn string_literal_brace_not_rewritten_as_object_literal() {
        // The braces live inside a string literal and must be left alone.
        let outcome = pp("SELECT func('{ foo }')");
        assert!(
            outcome.is_none_or(|r| !r.sql.contains("\"foo\"")),
            "string literal brace must not be rewritten"
        );
    }
}