nodedb_sql/parser/preprocess/
pipeline.rs1use super::function_args::rewrite_object_literal_args;
14use super::lex::{first_sql_word, has_brace_outside_literals, has_operator_outside_literals};
15use super::object_literal_stmt::try_rewrite_object_literal;
16use super::search_vector::try_rewrite_search_using_vector;
17use super::temporal::extract as extract_temporal;
18use super::vector_ops::{
19 rewrite_arrow_distance, rewrite_cosine_distance, rewrite_neg_inner_product,
20};
21use crate::error::SqlError;
22use crate::temporal::TemporalScope;
23
24#[derive(Debug)]
26pub struct PreprocessedSql {
27 pub sql: String,
29 pub is_upsert: bool,
31 pub temporal: TemporalScope,
35}
36
37pub fn preprocess(sql: &str) -> Result<Option<PreprocessedSql>, SqlError> {
43 let trimmed = sql.trim();
44
45 let (temporal_sql, temporal) =
49 match extract_temporal(trimmed).map_err(|e| SqlError::Parse { detail: e.0 })? {
50 Some(ex) => (ex.sql, ex.temporal),
51 None => (trimmed.to_string(), TemporalScope::default()),
52 };
53 let any_temporal = temporal != TemporalScope::default();
54
55 let (temporal_sql, search_vector_rewritten) =
60 match try_rewrite_search_using_vector(&temporal_sql) {
61 Some(rewritten) => (rewritten, true),
62 None => (temporal_sql, false),
63 };
64
65 let first_word = first_sql_word(&temporal_sql)
66 .map(|w| w.to_uppercase())
67 .unwrap_or_default();
68 let is_upsert = first_word == "UPSERT";
69
70 if is_upsert {
71 let rewritten = format!("INSERT INTO {}", &temporal_sql["UPSERT INTO ".len()..]);
72 if let Some(result) = try_rewrite_object_literal(&rewritten) {
73 return Ok(Some(PreprocessedSql {
74 sql: result,
75 is_upsert: true,
76 temporal,
77 }));
78 }
79 return Ok(Some(PreprocessedSql {
80 sql: rewritten,
81 is_upsert: true,
82 temporal,
83 }));
84 }
85
86 if first_word == "INSERT"
87 && let Some(result) = try_rewrite_object_literal(&temporal_sql)
88 {
89 return Ok(Some(PreprocessedSql {
90 sql: result,
91 is_upsert: false,
92 temporal,
93 }));
94 }
95
96 let mut sql_buf = temporal_sql;
97 let mut any_rewrite = any_temporal || search_vector_rewritten;
98
99 if has_operator_outside_literals(&sql_buf, "<->")
100 && let Some(rewritten) = rewrite_arrow_distance(&sql_buf)
101 {
102 sql_buf = rewritten;
103 any_rewrite = true;
104 }
105
106 if has_operator_outside_literals(&sql_buf, "<=>")
107 && let Some(rewritten) = rewrite_cosine_distance(&sql_buf)
108 {
109 sql_buf = rewritten;
110 any_rewrite = true;
111 }
112
113 if has_operator_outside_literals(&sql_buf, "<#>")
114 && let Some(rewritten) = rewrite_neg_inner_product(&sql_buf)
115 {
116 sql_buf = rewritten;
117 any_rewrite = true;
118 }
119
120 if has_brace_outside_literals(&sql_buf)
121 && let Some(rewritten) = rewrite_object_literal_args(&sql_buf)
122 {
123 sql_buf = rewritten;
124 any_rewrite = true;
125 }
126
127 if any_rewrite {
128 return Ok(Some(PreprocessedSql {
129 sql: sql_buf,
130 is_upsert: false,
131 temporal,
132 }));
133 }
134
135 Ok(None)
136}
137
#[cfg(test)]
mod tests {
    use super::super::function_args::rewrite_object_literal_args;
    use super::*;

    /// Runs the preprocessor and unwraps the `Result`. `None` means the input
    /// passed through unchanged.
    fn pp(sql: &str) -> Option<PreprocessedSql> {
        super::preprocess(sql).unwrap()
    }

    /// Like [`pp`], for inputs that are expected to be rewritten.
    #[track_caller]
    fn pp_some(sql: &str) -> PreprocessedSql {
        pp(sql).unwrap()
    }

    /// Asserts that `sql` contains `needle`, echoing `sql` on failure.
    #[track_caller]
    fn assert_has(sql: &str, needle: &str) {
        assert!(sql.contains(needle), "got: {}", sql);
    }

    #[test]
    fn passthrough_standard_sql() {
        let untouched = [
            "SELECT * FROM users",
            "INSERT INTO users (name) VALUES ('alice')",
            "DELETE FROM users WHERE id = 1",
        ];
        for sql in untouched {
            assert!(pp(sql).is_none());
        }
    }

    #[test]
    fn upsert_rewrite() {
        let out = pp_some("UPSERT INTO users (name) VALUES ('alice')");
        assert!(out.is_upsert);
        assert_eq!(out.sql, "INSERT INTO users (name) VALUES ('alice')");
    }

    #[test]
    fn object_literal_insert() {
        let out = pp_some("INSERT INTO users { name: 'alice', age: 30 }");
        assert!(!out.is_upsert);
        assert!(out.sql.starts_with("INSERT INTO users ("));
        for needle in ["'alice'", "30"] {
            assert!(out.sql.contains(needle));
        }
    }

    #[test]
    fn object_literal_upsert() {
        let out = pp_some("UPSERT INTO users { name: 'bob' }");
        assert!(out.is_upsert);
        assert!(out.sql.starts_with("INSERT INTO users ("));
        assert!(out.sql.contains("'bob'"));
    }

    #[test]
    fn batch_array_insert() {
        let out =
            pp_some("INSERT INTO users [{ name: 'alice', age: 30 }, { name: 'bob', age: 25 }]");
        assert!(!out.is_upsert);
        for needle in ["VALUES", "'alice'", "'bob'", "30", "25"] {
            assert!(out.sql.contains(needle));
        }
        let row_count = out
            .sql
            .split("VALUES")
            .nth(1)
            .map(|values| values.matches('(').count())
            .unwrap();
        assert_eq!(row_count, 2, "should have 2 row groups: {}", out.sql);
    }

    #[test]
    fn batch_array_heterogeneous_keys() {
        let out =
            pp_some("INSERT INTO docs [{ id: 'a', name: 'Alice' }, { id: 'b', role: 'admin' }]");
        for needle in ["NULL", "'Alice'", "'admin'"] {
            assert!(out.sql.contains(needle));
        }
    }

    #[test]
    fn batch_array_upsert() {
        let out = pp_some("UPSERT INTO users [{ id: 'u1', name: 'a' }, { id: 'u2', name: 'b' }]");
        assert!(out.is_upsert);
        assert!(out.sql.contains("VALUES"));
    }

    #[test]
    fn arrow_distance_operator_select() {
        let out =
            pp_some("SELECT title FROM articles ORDER BY embedding <-> ARRAY[0.1, 0.2, 0.3] LIMIT 5");
        assert_has(&out.sql, "vector_distance(embedding, ARRAY[0.1, 0.2, 0.3])");
        assert!(!out.sql.contains("<->"));
    }

    #[test]
    fn arrow_distance_operator_where() {
        let out = pp_some("SELECT * FROM docs WHERE embedding <-> ARRAY[1.0, 2.0] < 0.5");
        assert_has(&out.sql, "vector_distance(embedding, ARRAY[1.0, 2.0])");
        assert!(
            !out.sql.contains("evector_distance"),
            "rewriter must not leave the column's leading char in the prefix; got: {}",
            out.sql
        );
        assert!(
            out.sql.contains("WHERE vector_distance("),
            "rewritten WHERE clause must start `WHERE vector_distance(`; got: {}",
            out.sql
        );
    }

    #[test]
    fn arrow_distance_operator_where_no_whitespace_before_op() {
        let out = pp_some("SELECT id FROM docs WHERE embedding<->ARRAY[1.0, 2.0]");
        assert_has(&out.sql, "vector_distance(embedding, ARRAY[1.0, 2.0])");
        assert!(!out.sql.contains("evector_distance"), "got: {}", out.sql);
    }

    #[test]
    fn arrow_distance_no_match() {
        assert!(pp("SELECT * FROM users WHERE age > 30").is_none());
    }

    #[test]
    fn arrow_distance_with_alias() {
        let out = pp_some("SELECT embedding <-> ARRAY[0.1, 0.2] AS dist FROM articles");
        assert_has(&out.sql, "vector_distance(embedding, ARRAY[0.1, 0.2]) AS dist");
    }

    #[test]
    fn fuzzy_object_literal_in_function() {
        const QUERY: &str =
            "SELECT * FROM articles WHERE text_match(body, 'query', { fuzzy: true })";

        let Some(direct) = rewrite_object_literal_args(QUERY) else {
            panic!("rewrite_object_literal_args should match");
        };
        assert!(
            direct.contains("\"fuzzy\""),
            "direct rewrite should contain JSON, got: {}",
            direct
        );

        let out = pp_some(QUERY);
        assert!(
            !out.sql.contains("{ fuzzy"),
            "should not contain object literal, got: {}",
            out.sql
        );
    }

    #[test]
    fn fuzzy_object_literal_with_distance() {
        let out = pp_some(
            "SELECT * FROM articles WHERE text_match(title, 'test', { fuzzy: true, distance: 2 })",
        );
        assert_has(&out.sql, "\"fuzzy\"");
        assert_has(&out.sql, "\"distance\"");
    }

    #[test]
    fn object_literal_not_rewritten_outside_function() {
        let out = pp_some("INSERT INTO docs { name: 'Alice' }");
        assert_has(&out.sql, "VALUES");
    }

    #[test]
    fn comment_prefix_insert_routes_correctly() {
        // A rewrite is optional here; if one happens it must not be an upsert.
        assert!(pp("/* hint */ INSERT INTO t VALUES ({})").is_none_or(|out| !out.is_upsert));
    }

    #[test]
    fn comment_prefix_upsert_routes_correctly() {
        assert!(pp_some("/* hint */ UPSERT INTO t (name) VALUES ('a')").is_upsert);
    }

    #[test]
    fn line_comment_before_insert_does_not_trigger_insert() {
        let result = pp("-- INSERT INTO t\nSELECT 1");
        assert!(
            result.is_none(),
            "line-commented INSERT must pass through, got: {result:?}"
        );
    }

    #[test]
    fn string_literal_brace_not_rewritten_as_object_literal() {
        let untouched = pp("SELECT func('{ foo }')").is_none_or(|out| !out.sql.contains("\"foo\""));
        assert!(untouched, "string literal brace must not be rewritten");
    }
}