Skip to main content

nodedb_sql/parser/preprocess/
pipeline.rs

//! SQL pre-processing orchestrator: rewrite NodeDB-specific syntax into
//! standard SQL before handing to sqlparser-rs.
//!
//! Handles:
//! - `UPSERT INTO coll (cols) VALUES (vals)` → `INSERT INTO ...` + upsert flag
//! - `INSERT INTO coll { key: 'val', ... }` → `INSERT INTO coll (key) VALUES ('val')`
//! - `UPSERT INTO coll { ... }` → both rewrites combined
//! - `expr <-> expr` → `vector_distance(expr, expr)`
//! - `{ key: val }` in function args → JSON string literal
10
11use super::function_args::rewrite_object_literal_args;
12use super::object_literal_stmt::try_rewrite_object_literal;
13use super::vector_ops::rewrite_arrow_distance;
14
/// Result of pre-processing a SQL string.
///
/// Carries the rewritten statement together with the one piece of information
/// that is lost by the rewrite: whether the caller originally wrote `UPSERT`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreprocessedSql {
    /// The rewritten SQL (standard SQL that sqlparser can handle).
    pub sql: String,
    /// Whether the original statement was UPSERT (not INSERT).
    pub is_upsert: bool,
}
22
23/// Pre-process a SQL string, rewriting NodeDB-specific syntax.
24///
25/// Returns `None` if no rewriting was needed (pass through to sqlparser as-is).
26pub fn preprocess(sql: &str) -> Option<PreprocessedSql> {
27    let trimmed = sql.trim();
28    let upper = trimmed.to_uppercase();
29
30    let is_upsert = upper.starts_with("UPSERT INTO ");
31
32    if is_upsert {
33        let rewritten = format!("INSERT INTO {}", &trimmed["UPSERT INTO ".len()..]);
34        if let Some(result) = try_rewrite_object_literal(&rewritten) {
35            return Some(PreprocessedSql {
36                sql: result,
37                is_upsert: true,
38            });
39        }
40        return Some(PreprocessedSql {
41            sql: rewritten,
42            is_upsert: true,
43        });
44    }
45
46    if upper.starts_with("INSERT INTO ")
47        && let Some(result) = try_rewrite_object_literal(trimmed)
48    {
49        return Some(PreprocessedSql {
50            sql: result,
51            is_upsert: false,
52        });
53    }
54
55    let mut sql_buf = trimmed.to_string();
56    let mut any_rewrite = false;
57
58    if sql_buf.contains("<->")
59        && let Some(rewritten) = rewrite_arrow_distance(&sql_buf)
60    {
61        sql_buf = rewritten;
62        any_rewrite = true;
63    }
64
65    if (sql_buf.contains("{ ") || sql_buf.contains("{f") || sql_buf.contains("{d"))
66        && let Some(rewritten) = rewrite_object_literal_args(&sql_buf)
67    {
68        sql_buf = rewritten;
69        any_rewrite = true;
70    }
71
72    if any_rewrite {
73        return Some(PreprocessedSql {
74            sql: sql_buf,
75            is_upsert: false,
76        });
77    }
78
79    None
80}
81
#[cfg(test)]
mod tests {
    use super::super::function_args::rewrite_object_literal_args;
    use super::*;

    /// Run `preprocess` on input that is expected to be rewritten.
    #[track_caller]
    fn rewritten(sql: &str) -> PreprocessedSql {
        preprocess(sql).unwrap()
    }

    /// Assert that the rewritten SQL contains `needle`, echoing the SQL on failure.
    #[track_caller]
    fn assert_sql_contains(out: &PreprocessedSql, needle: &str) {
        assert!(out.sql.contains(needle), "got: {}", out.sql);
    }

    // --- pass-through -----------------------------------------------------

    #[test]
    fn passthrough_standard_sql() {
        for sql in [
            "SELECT * FROM users",
            "INSERT INTO users (name) VALUES ('alice')",
            "DELETE FROM users WHERE id = 1",
        ] {
            assert!(preprocess(sql).is_none());
        }
    }

    // --- UPSERT / object-literal statements ---------------------------------

    #[test]
    fn upsert_rewrite() {
        let out = rewritten("UPSERT INTO users (name) VALUES ('alice')");
        assert!(out.is_upsert);
        assert_eq!(out.sql, "INSERT INTO users (name) VALUES ('alice')");
    }

    #[test]
    fn object_literal_insert() {
        let out = rewritten("INSERT INTO users { name: 'alice', age: 30 }");
        assert!(!out.is_upsert);
        assert!(out.sql.starts_with("INSERT INTO users ("));
        for needle in ["'alice'", "30"] {
            assert!(out.sql.contains(needle));
        }
    }

    #[test]
    fn object_literal_upsert() {
        let out = rewritten("UPSERT INTO users { name: 'bob' }");
        assert!(out.is_upsert);
        assert!(out.sql.starts_with("INSERT INTO users ("));
        assert!(out.sql.contains("'bob'"));
    }

    // --- batch arrays -------------------------------------------------------

    #[test]
    fn batch_array_insert() {
        let out = rewritten(
            "INSERT INTO users [{ name: 'alice', age: 30 }, { name: 'bob', age: 25 }]",
        );
        assert!(!out.is_upsert);
        for needle in ["VALUES", "'alice'", "'bob'", "30", "25"] {
            assert!(out.sql.contains(needle));
        }
        // Every row group opens with one parenthesis after the VALUES keyword.
        let row_count = out
            .sql
            .split("VALUES")
            .nth(1)
            .map(|values_part| values_part.matches('(').count())
            .unwrap();
        assert_eq!(row_count, 2, "should have 2 row groups: {}", out.sql);
    }

    #[test]
    fn batch_array_heterogeneous_keys() {
        let out = rewritten(
            "INSERT INTO docs [{ id: 'a', name: 'Alice' }, { id: 'b', role: 'admin' }]",
        );
        for needle in ["NULL", "'Alice'", "'admin'"] {
            assert!(out.sql.contains(needle));
        }
    }

    #[test]
    fn batch_array_upsert() {
        let out =
            rewritten("UPSERT INTO users [{ id: 'u1', name: 'a' }, { id: 'u2', name: 'b' }]");
        assert!(out.is_upsert);
        assert!(out.sql.contains("VALUES"));
    }

    // --- `<->` operator -----------------------------------------------------

    #[test]
    fn arrow_distance_operator_select() {
        let out = rewritten(
            "SELECT title FROM articles ORDER BY embedding <-> ARRAY[0.1, 0.2, 0.3] LIMIT 5",
        );
        assert_sql_contains(&out, "vector_distance(embedding, ARRAY[0.1, 0.2, 0.3])");
        assert!(!out.sql.contains("<->"));
    }

    #[test]
    fn arrow_distance_operator_where() {
        let out = rewritten("SELECT * FROM docs WHERE embedding <-> ARRAY[1.0, 2.0] < 0.5");
        assert_sql_contains(&out, "vector_distance(embedding, ARRAY[1.0, 2.0])");
    }

    #[test]
    fn arrow_distance_no_match() {
        assert!(preprocess("SELECT * FROM users WHERE age > 30").is_none());
    }

    #[test]
    fn arrow_distance_with_alias() {
        let out = rewritten("SELECT embedding <-> ARRAY[0.1, 0.2] AS dist FROM articles");
        assert_sql_contains(
            &out,
            "vector_distance(embedding, ARRAY[0.1, 0.2]) AS dist",
        );
    }

    // --- object literals in function arguments ------------------------------

    #[test]
    fn fuzzy_object_literal_in_function() {
        const SQL: &str =
            "SELECT * FROM articles WHERE text_match(body, 'query', { fuzzy: true })";

        // The helper on its own must recognise the literal...
        let Some(direct) = rewrite_object_literal_args(SQL) else {
            panic!("rewrite_object_literal_args should match");
        };
        assert!(
            direct.contains("\"fuzzy\""),
            "direct rewrite should contain JSON, got: {}",
            direct
        );

        // ...and the full pipeline must not leave it behind.
        let out = rewritten(SQL);
        assert!(
            !out.sql.contains("{ fuzzy"),
            "should not contain object literal, got: {}",
            out.sql
        );
    }

    #[test]
    fn fuzzy_object_literal_with_distance() {
        let out = rewritten(
            "SELECT * FROM articles WHERE text_match(title, 'test', { fuzzy: true, distance: 2 })",
        );
        assert_sql_contains(&out, "\"fuzzy\"");
        assert_sql_contains(&out, "\"distance\"");
    }

    #[test]
    fn object_literal_not_rewritten_outside_function() {
        let out = rewritten("INSERT INTO docs { name: 'Alice' }");
        assert_sql_contains(&out, "VALUES");
    }
}