nodedb_sql/parser/preprocess/
pipeline.rs1use super::function_args::rewrite_object_literal_args;
12use super::object_literal_stmt::try_rewrite_object_literal;
13use super::vector_ops::rewrite_arrow_distance;
14
/// Result of [`preprocess`]: SQL rewritten into a form the standard SQL parser
/// accepts, plus flags for NodeDB extensions that the rewrite erased.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreprocessedSql {
    /// The rewritten SQL text (trimmed of surrounding whitespace).
    pub sql: String,
    /// `true` when the input was an `UPSERT INTO ...` statement. In that case
    /// `sql` holds the equivalent `INSERT INTO ...` statement, and this flag is
    /// the only remaining record of the upsert intent.
    pub is_upsert: bool,
}
22
23pub fn preprocess(sql: &str) -> Option<PreprocessedSql> {
27 let trimmed = sql.trim();
28 let upper = trimmed.to_uppercase();
29
30 let is_upsert = upper.starts_with("UPSERT INTO ");
31
32 if is_upsert {
33 let rewritten = format!("INSERT INTO {}", &trimmed["UPSERT INTO ".len()..]);
34 if let Some(result) = try_rewrite_object_literal(&rewritten) {
35 return Some(PreprocessedSql {
36 sql: result,
37 is_upsert: true,
38 });
39 }
40 return Some(PreprocessedSql {
41 sql: rewritten,
42 is_upsert: true,
43 });
44 }
45
46 if upper.starts_with("INSERT INTO ")
47 && let Some(result) = try_rewrite_object_literal(trimmed)
48 {
49 return Some(PreprocessedSql {
50 sql: result,
51 is_upsert: false,
52 });
53 }
54
55 let mut sql_buf = trimmed.to_string();
56 let mut any_rewrite = false;
57
58 if sql_buf.contains("<->")
59 && let Some(rewritten) = rewrite_arrow_distance(&sql_buf)
60 {
61 sql_buf = rewritten;
62 any_rewrite = true;
63 }
64
65 if (sql_buf.contains("{ ") || sql_buf.contains("{f") || sql_buf.contains("{d"))
66 && let Some(rewritten) = rewrite_object_literal_args(&sql_buf)
67 {
68 sql_buf = rewritten;
69 any_rewrite = true;
70 }
71
72 if any_rewrite {
73 return Some(PreprocessedSql {
74 sql: sql_buf,
75 is_upsert: false,
76 });
77 }
78
79 None
80}
81
#[cfg(test)]
mod tests {
    use super::super::function_args::rewrite_object_literal_args;
    use super::*;

    #[test]
    fn passthrough_standard_sql() {
        assert!(preprocess("SELECT * FROM users").is_none());
        assert!(preprocess("INSERT INTO users (name) VALUES ('alice')").is_none());
        assert!(preprocess("DELETE FROM users WHERE id = 1").is_none());
    }

    #[test]
    fn upsert_rewrite() {
        let result = preprocess("UPSERT INTO users (name) VALUES ('alice')").unwrap();
        assert!(result.is_upsert);
        assert_eq!(result.sql, "INSERT INTO users (name) VALUES ('alice')");
    }

    /// The UPSERT keyword is recognised regardless of ASCII letter case.
    #[test]
    fn upsert_keyword_is_case_insensitive() {
        let lower = preprocess("upsert into users (name) VALUES ('alice')").unwrap();
        assert!(lower.is_upsert);
        assert_eq!(lower.sql, "INSERT INTO users (name) VALUES ('alice')");

        let mixed = preprocess("Upsert Into users (name) VALUES ('alice')").unwrap();
        assert!(mixed.is_upsert);
        assert_eq!(mixed.sql, "INSERT INTO users (name) VALUES ('alice')");
    }

    /// Leading/trailing whitespace is stripped before keyword detection and is
    /// absent from the rewritten SQL.
    #[test]
    fn surrounding_whitespace_is_trimmed() {
        let result = preprocess("  \n UPSERT INTO users (name) VALUES ('alice') \t").unwrap();
        assert!(result.is_upsert);
        assert_eq!(result.sql, "INSERT INTO users (name) VALUES ('alice')");

        assert!(preprocess("   SELECT * FROM users   ").is_none());
    }

    #[test]
    fn object_literal_insert() {
        let result = preprocess("INSERT INTO users { name: 'alice', age: 30 }").unwrap();
        assert!(!result.is_upsert);
        assert!(result.sql.starts_with("INSERT INTO users ("));
        assert!(result.sql.contains("'alice'"));
        assert!(result.sql.contains("30"));
    }

    #[test]
    fn object_literal_upsert() {
        let result = preprocess("UPSERT INTO users { name: 'bob' }").unwrap();
        assert!(result.is_upsert);
        assert!(result.sql.starts_with("INSERT INTO users ("));
        assert!(result.sql.contains("'bob'"));
    }

    #[test]
    fn batch_array_insert() {
        let result =
            preprocess("INSERT INTO users [{ name: 'alice', age: 30 }, { name: 'bob', age: 25 }]")
                .unwrap();
        assert!(!result.is_upsert);
        assert!(result.sql.contains("VALUES"));
        assert!(result.sql.contains("'alice'"));
        assert!(result.sql.contains("'bob'"));
        assert!(result.sql.contains("30"));
        assert!(result.sql.contains("25"));
        let values_part = result.sql.split("VALUES").nth(1).unwrap();
        let row_count = values_part.matches('(').count();
        assert_eq!(row_count, 2, "should have 2 row groups: {}", result.sql);
    }

    #[test]
    fn batch_array_heterogeneous_keys() {
        let result =
            preprocess("INSERT INTO docs [{ id: 'a', name: 'Alice' }, { id: 'b', role: 'admin' }]")
                .unwrap();
        assert!(result.sql.contains("NULL"));
        assert!(result.sql.contains("'Alice'"));
        assert!(result.sql.contains("'admin'"));
    }

    #[test]
    fn batch_array_upsert() {
        let result =
            preprocess("UPSERT INTO users [{ id: 'u1', name: 'a' }, { id: 'u2', name: 'b' }]")
                .unwrap();
        assert!(result.is_upsert);
        assert!(result.sql.contains("VALUES"));
    }

    #[test]
    fn arrow_distance_operator_select() {
        let result = preprocess(
            "SELECT title FROM articles ORDER BY embedding <-> ARRAY[0.1, 0.2, 0.3] LIMIT 5",
        )
        .unwrap();
        assert!(
            result
                .sql
                .contains("vector_distance(embedding, ARRAY[0.1, 0.2, 0.3])"),
            "got: {}",
            result.sql
        );
        assert!(!result.sql.contains("<->"));
    }

    #[test]
    fn arrow_distance_operator_where() {
        let result =
            preprocess("SELECT * FROM docs WHERE embedding <-> ARRAY[1.0, 2.0] < 0.5").unwrap();
        assert!(
            result
                .sql
                .contains("vector_distance(embedding, ARRAY[1.0, 2.0])"),
            "got: {}",
            result.sql
        );
    }

    #[test]
    fn arrow_distance_no_match() {
        assert!(preprocess("SELECT * FROM users WHERE age > 30").is_none());
    }

    #[test]
    fn arrow_distance_with_alias() {
        let result =
            preprocess("SELECT embedding <-> ARRAY[0.1, 0.2] AS dist FROM articles").unwrap();
        assert!(
            result
                .sql
                .contains("vector_distance(embedding, ARRAY[0.1, 0.2]) AS dist"),
            "got: {}",
            result.sql
        );
    }

    #[test]
    fn fuzzy_object_literal_in_function() {
        let direct = rewrite_object_literal_args(
            "SELECT * FROM articles WHERE text_match(body, 'query', { fuzzy: true })",
        );
        assert!(direct.is_some(), "rewrite_object_literal_args should match");
        let rewritten = direct.unwrap();
        assert!(
            rewritten.contains("\"fuzzy\""),
            "direct rewrite should contain JSON, got: {}",
            rewritten
        );

        let result =
            preprocess("SELECT * FROM articles WHERE text_match(body, 'query', { fuzzy: true })")
                .unwrap();
        assert!(
            !result.sql.contains("{ fuzzy"),
            "should not contain object literal, got: {}",
            result.sql
        );
    }

    #[test]
    fn fuzzy_object_literal_with_distance() {
        let result = preprocess(
            "SELECT * FROM articles WHERE text_match(title, 'test', { fuzzy: true, distance: 2 })",
        )
        .unwrap();
        assert!(result.sql.contains("\"fuzzy\""), "got: {}", result.sql);
        assert!(result.sql.contains("\"distance\""), "got: {}", result.sql);
    }

    /// An `INSERT INTO t { ... }` row literal goes through the insert rewrite
    /// (producing `VALUES`), not the function-argument JSON rewrite.
    #[test]
    fn object_literal_not_rewritten_outside_function() {
        let result = preprocess("INSERT INTO docs { name: 'Alice' }").unwrap();
        assert!(result.sql.contains("VALUES"), "got: {}", result.sql);
    }
}