1use super::object_literal::{parse_object_literal, parse_object_literal_array};
10
/// Result of [`preprocess`]: the rewritten SQL plus whether the original
/// statement was an `UPSERT`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreprocessedSql {
    /// SQL text after NodeDB-specific syntax has been rewritten to standard SQL.
    pub sql: String,
    /// `true` when the input started with `UPSERT INTO`, which `preprocess`
    /// rewrites to `INSERT INTO`.
    pub is_upsert: bool,
}
18
19pub fn preprocess(sql: &str) -> Option<PreprocessedSql> {
23 let trimmed = sql.trim();
24 let upper = trimmed.to_uppercase();
25
26 let is_upsert = upper.starts_with("UPSERT INTO ");
28
29 if is_upsert {
30 let rewritten = format!("INSERT INTO {}", &trimmed["UPSERT INTO ".len()..]);
32 if let Some(result) = try_rewrite_object_literal(&rewritten) {
33 return Some(PreprocessedSql {
34 sql: result,
35 is_upsert: true,
36 });
37 }
38 return Some(PreprocessedSql {
39 sql: rewritten,
40 is_upsert: true,
41 });
42 }
43
44 if upper.starts_with("INSERT INTO ")
46 && let Some(result) = try_rewrite_object_literal(trimmed)
47 {
48 return Some(PreprocessedSql {
49 sql: result,
50 is_upsert: false,
51 });
52 }
53
54 let mut sql_buf = trimmed.to_string();
56 let mut any_rewrite = false;
57
58 if sql_buf.contains("<->")
60 && let Some(rewritten) = rewrite_arrow_distance(&sql_buf)
61 {
62 sql_buf = rewritten;
63 any_rewrite = true;
64 }
65
66 if (sql_buf.contains("{ ") || sql_buf.contains("{f") || sql_buf.contains("{d"))
69 && let Some(rewritten) = rewrite_object_literal_args(&sql_buf)
70 {
71 sql_buf = rewritten;
72 any_rewrite = true;
73 }
74
75 if any_rewrite {
76 return Some(PreprocessedSql {
77 sql: sql_buf,
78 is_upsert: false,
79 });
80 }
81
82 None
83}
84
85fn try_rewrite_object_literal(sql: &str) -> Option<String> {
90 let after_into = sql["INSERT INTO ".len()..].trim_start();
92 let coll_end = after_into.find(|c: char| c.is_whitespace())?;
93 let coll_name = &after_into[..coll_end];
94 let rest = after_into[coll_end..].trim_start();
95
96 let obj_str = rest.trim_end_matches(';').trim_end();
98
99 if obj_str.starts_with('[') {
100 return rewrite_array_form(coll_name, obj_str);
102 }
103
104 if !obj_str.starts_with('{') {
105 return None;
106 }
107
108 let fields = parse_object_literal(obj_str)?.ok()?;
110 if fields.is_empty() {
111 return None;
112 }
113 Some(fields_to_values_sql(coll_name, &[fields]))
114}
115
116fn rewrite_array_form(coll_name: &str, obj_str: &str) -> Option<String> {
118 let objects = parse_object_literal_array(obj_str)?.ok()?;
119 if objects.is_empty() {
120 return None;
121 }
122 Some(fields_to_values_sql(coll_name, &objects))
123}
124
125fn fields_to_values_sql(
129 coll_name: &str,
130 rows: &[std::collections::HashMap<String, nodedb_types::Value>],
131) -> String {
132 let mut all_keys: Vec<String> = rows
134 .iter()
135 .flat_map(|r| r.keys().cloned())
136 .collect::<std::collections::BTreeSet<_>>()
137 .into_iter()
138 .collect();
139 all_keys.sort();
140
141 let col_list = all_keys.join(", ");
142
143 let row_strs: Vec<String> = rows
144 .iter()
145 .map(|row| {
146 let vals: Vec<String> = all_keys
147 .iter()
148 .map(|k| match row.get(k) {
149 Some(v) => value_to_sql_literal(v),
150 None => "NULL".to_string(),
151 })
152 .collect();
153 format!("({})", vals.join(", "))
154 })
155 .collect();
156
157 format!(
158 "INSERT INTO {} ({}) VALUES {}",
159 coll_name,
160 col_list,
161 row_strs.join(", ")
162 )
163}
164
165fn rewrite_object_literal_args(sql: &str) -> Option<String> {
172 let mut result = String::with_capacity(sql.len());
173 let chars: Vec<char> = sql.chars().collect();
174 let mut i = 0;
175 let mut found = false;
176 let mut paren_depth: i32 = 0;
177
178 while i < chars.len() {
179 match chars[i] {
180 '(' => {
181 paren_depth += 1;
182 result.push('(');
183 i += 1;
184 }
185 ')' => {
186 paren_depth = paren_depth.saturating_sub(1);
187 result.push(')');
188 i += 1;
189 }
190 '\'' => {
191 result.push('\'');
193 i += 1;
194 while i < chars.len() {
195 result.push(chars[i]);
196 if chars[i] == '\'' {
197 if i + 1 < chars.len() && chars[i + 1] == '\'' {
199 i += 1;
200 result.push(chars[i]);
201 } else {
202 break;
203 }
204 }
205 i += 1;
206 }
207 i += 1;
208 }
209 '{' if paren_depth > 0 => {
210 let remaining: String = chars[i..].iter().collect();
212 if let Some(Ok(fields)) = parse_object_literal(&remaining) {
213 let obj_end = find_matching_brace(&chars, i);
215 if let Some(end) = obj_end {
216 let json = value_map_to_json(&fields);
217 result.push('\'');
218 result.push_str(&json);
219 result.push('\'');
220 i = end + 1;
221 found = true;
222 continue;
223 }
224 }
225 result.push('{');
227 i += 1;
228 }
229 _ => {
230 result.push(chars[i]);
231 i += 1;
232 }
233 }
234 }
235
236 if found { Some(result) } else { None }
237}
238
239fn value_map_to_json(fields: &std::collections::HashMap<String, nodedb_types::Value>) -> String {
241 let mut parts = Vec::with_capacity(fields.len());
242 let mut entries: Vec<_> = fields.iter().collect();
243 entries.sort_by_key(|(k, _)| k.as_str());
244 for (key, val) in entries {
245 parts.push(format!("\"{}\":{}", key, value_to_json(val)));
246 }
247 format!("{{{}}}", parts.join(","))
248}
249
250fn value_to_json(value: &nodedb_types::Value) -> String {
252 match value {
253 nodedb_types::Value::String(s) => {
254 format!("\"{}\"", s.replace('\\', "\\\\").replace('"', "\\\""))
255 }
256 nodedb_types::Value::Integer(n) => n.to_string(),
257 nodedb_types::Value::Float(f) => format!("{f}"),
258 nodedb_types::Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
259 nodedb_types::Value::Null => "null".to_string(),
260 nodedb_types::Value::Array(items) => {
261 let inner: Vec<String> = items.iter().map(value_to_json).collect();
262 format!("[{}]", inner.join(","))
263 }
264 nodedb_types::Value::Object(map) => value_map_to_json(map),
265 _ => format!("\"{}\"", format!("{value:?}").replace('"', "\\\"")),
266 }
267}
268
/// Returns the index of the `}` that closes the brace opened at
/// `chars[start]` (expected to be `{`), or `None` if it is never closed.
///
/// Braces inside single-quoted SQL string literals are ignored. A doubled
/// quote (`''`) inside a literal is an escaped quote: both characters are
/// skipped and the literal continues.
fn find_matching_brace(chars: &[char], start: usize) -> Option<usize> {
    let mut depth: i32 = 0;
    let mut in_string = false;
    let mut i = start;

    while i < chars.len() {
        match chars[i] {
            '\'' if in_string => {
                if chars.get(i + 1) == Some(&'\'') {
                    // Skip the second quote of the escape pair too; otherwise
                    // it would be mistaken for the closing quote.
                    i += 1;
                } else {
                    in_string = false;
                }
            }
            '\'' => in_string = true,
            '{' if !in_string => depth += 1,
            '}' if !in_string => {
                depth -= 1;
                if depth == 0 {
                    return Some(i);
                }
            }
            _ => {}
        }
        i += 1;
    }

    None
}
295
/// Rewrites every `left <-> right` distance operator into
/// `vector_distance(left, right)`.
///
/// The left operand is the identifier (alphanumerics, `_`, `.`) immediately
/// before the operator, ignoring whitespace. The right operand is an
/// `ARRAY[...]` literal (any ASCII case), a `$n` placeholder, or an
/// identifier. Returns `None` when there is no operator, or when any
/// occurrence has an operand that cannot be extracted (then nothing is
/// rewritten).
fn rewrite_arrow_distance(sql: &str) -> Option<String> {
    const ARROW: &str = "<->";

    let mut result = String::with_capacity(sql.len());
    let mut remaining = sql;
    let mut found = false;

    while let Some(arrow_pos) = remaining.find(ARROW) {
        // Measure the operand's start from the *trimmed* text so whitespace
        // between the operand and `<->` is not mistaken for operand bytes.
        let before = remaining[..arrow_pos].trim_end();
        let left = extract_left_operand(before)?;
        let left_start = before.len() - left.len();

        let after = &remaining[arrow_pos + ARROW.len()..];
        let after_trimmed = after.trim_start();
        let (right, right_len) = extract_right_operand(after_trimmed)?;
        let consumed =
            arrow_pos + ARROW.len() + (after.len() - after_trimmed.len()) + right_len;

        result.push_str(&remaining[..left_start]);
        result.push_str("vector_distance(");
        result.push_str(&left);
        result.push_str(", ");
        result.push_str(&right);
        result.push(')');

        remaining = &remaining[consumed..];
        found = true;
    }

    if !found {
        return None;
    }

    result.push_str(remaining);
    Some(result)
}

/// Characters allowed in a column/identifier operand of `<->`.
fn is_operand_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_' || c == '.'
}

/// Returns the trailing identifier of `before` (after trimming trailing
/// whitespace), or `None` if `before` does not end in one.
fn extract_left_operand(before: &str) -> Option<String> {
    let trimmed = before.trim_end();
    // Walk back by chars (not bytes) so a non-ASCII delimiter never leads to
    // slicing in the middle of a UTF-8 sequence.
    let start = trimmed
        .char_indices()
        .rev()
        .find(|&(_, c)| !is_operand_char(c))
        .map_or(0, |(pos, c)| pos + c.len_utf8());
    let ident = &trimmed[start..];
    (!ident.is_empty()).then(|| ident.to_string())
}

/// Extracts the right operand of `<->` from the start of `after` (leading
/// whitespace is skipped).
///
/// Returns the operand text and its byte length within `after.trim_start()`.
fn extract_right_operand(after: &str) -> Option<(String, usize)> {
    let trimmed = after.trim_start();

    // Prefix check without uppercasing the entire remaining SQL.
    if trimmed
        .get(.."ARRAY[".len())
        .is_some_and(|head| head.eq_ignore_ascii_case("ARRAY["))
    {
        let mut depth: i32 = 0;
        for (i, c) in trimmed.char_indices() {
            match c {
                '[' => depth += 1,
                ']' => {
                    depth -= 1;
                    if depth == 0 {
                        return Some((trimmed[..=i].to_string(), i + 1));
                    }
                }
                _ => {}
            }
        }
        return None;
    }

    let end = if trimmed.starts_with('$') {
        trimmed
            .find(|c: char| !(c.is_alphanumeric() || c == '_' || c == '$'))
            .unwrap_or(trimmed.len())
    } else {
        trimmed
            .find(|c: char| !is_operand_char(c))
            .unwrap_or(trimmed.len())
    };
    (end > 0).then(|| (trimmed[..end].to_string(), end))
}
385
386pub fn value_to_sql_literal(value: &nodedb_types::Value) -> String {
391 match value {
392 nodedb_types::Value::String(s) => format!("'{}'", s.replace('\'', "''")),
393 nodedb_types::Value::Integer(n) => n.to_string(),
394 nodedb_types::Value::Float(f) => format!("{f}"),
395 nodedb_types::Value::Bool(b) => if *b { "TRUE" } else { "FALSE" }.to_string(),
396 nodedb_types::Value::Null => "NULL".to_string(),
397 nodedb_types::Value::Array(items) => {
398 let inner: Vec<String> = items.iter().map(value_to_sql_literal).collect();
399 format!("ARRAY[{}]", inner.join(", "))
400 }
401 nodedb_types::Value::Bytes(b) => {
402 let hex: String = b.iter().map(|byte| format!("{byte:02x}")).collect();
403 format!("'\\x{hex}'")
404 }
405 nodedb_types::Value::Object(_) => "NULL".to_string(),
406 nodedb_types::Value::Uuid(u) => format!("'{u}'"),
407 nodedb_types::Value::Ulid(u) => format!("'{u}'"),
408 nodedb_types::Value::DateTime(dt) => format!("'{dt}'"),
409 nodedb_types::Value::Duration(d) => format!("'{d}'"),
410 nodedb_types::Value::Decimal(d) => d.to_string(),
411 other => format!("'{}'", format!("{other:?}").replace('\'', "''")),
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Statements with no NodeDB-specific syntax yield `None` (use the SQL as-is).
    #[test]
    fn passthrough_standard_sql() {
        assert!(preprocess("SELECT * FROM users").is_none());
        assert!(preprocess("INSERT INTO users (name) VALUES ('alice')").is_none());
        assert!(preprocess("DELETE FROM users WHERE id = 1").is_none());
    }

    /// `UPSERT INTO` is rewritten to `INSERT INTO` and flagged as an upsert.
    #[test]
    fn upsert_rewrite() {
        let result = preprocess("UPSERT INTO users (name) VALUES ('alice')").unwrap();
        assert!(result.is_upsert);
        assert_eq!(result.sql, "INSERT INTO users (name) VALUES ('alice')");
    }

    /// A single object literal expands into a column list plus one VALUES row.
    #[test]
    fn object_literal_insert() {
        let result = preprocess("INSERT INTO users { name: 'alice', age: 30 }").unwrap();
        assert!(!result.is_upsert);
        assert!(result.sql.starts_with("INSERT INTO users ("));
        assert!(result.sql.contains("'alice'"));
        assert!(result.sql.contains("30"));
    }

    /// Object-literal expansion also applies after the UPSERT keyword swap.
    #[test]
    fn object_literal_upsert() {
        let result = preprocess("UPSERT INTO users { name: 'bob' }").unwrap();
        assert!(result.is_upsert);
        assert!(result.sql.starts_with("INSERT INTO users ("));
        assert!(result.sql.contains("'bob'"));
    }

    /// An array of object literals becomes a multi-row VALUES clause.
    #[test]
    fn batch_array_insert() {
        let result =
            preprocess("INSERT INTO users [{ name: 'alice', age: 30 }, { name: 'bob', age: 25 }]")
                .unwrap();
        assert!(!result.is_upsert);
        assert!(result.sql.contains("VALUES"));
        assert!(result.sql.contains("'alice'"));
        assert!(result.sql.contains("'bob'"));
        assert!(result.sql.contains("30"));
        assert!(result.sql.contains("25"));
        // Each row tuple contributes one '(' after VALUES (none of these
        // values contain parentheses themselves).
        let values_part = result.sql.split("VALUES").nth(1).unwrap();
        let row_count = values_part.matches('(').count();
        assert_eq!(row_count, 2, "should have 2 row groups: {}", result.sql);
    }

    /// Rows missing a column that another row has get NULL in that position.
    #[test]
    fn batch_array_heterogeneous_keys() {
        let result =
            preprocess("INSERT INTO docs [{ id: 'a', name: 'Alice' }, { id: 'b', role: 'admin' }]")
                .unwrap();
        assert!(result.sql.contains("NULL"));
        assert!(result.sql.contains("'Alice'"));
        assert!(result.sql.contains("'admin'"));
    }

    /// Batch (array) form combined with UPSERT keeps the upsert flag.
    #[test]
    fn batch_array_upsert() {
        let result =
            preprocess("UPSERT INTO users [{ id: 'u1', name: 'a' }, { id: 'u2', name: 'b' }]")
                .unwrap();
        assert!(result.is_upsert);
        assert!(result.sql.contains("VALUES"));
    }

    /// `<->` in ORDER BY is rewritten to a `vector_distance(...)` call.
    // NOTE(review): `contains` only checks the rewritten call itself, not the
    // text immediately before it; consider asserting the full rewritten SQL.
    #[test]
    fn arrow_distance_operator_select() {
        let result = preprocess(
            "SELECT title FROM articles ORDER BY embedding <-> ARRAY[0.1, 0.2, 0.3] LIMIT 5",
        )
        .unwrap();
        assert!(
            result
                .sql
                .contains("vector_distance(embedding, ARRAY[0.1, 0.2, 0.3])"),
            "got: {}",
            result.sql
        );
        assert!(!result.sql.contains("<->"));
    }

    /// `<->` inside a WHERE comparison is rewritten the same way.
    #[test]
    fn arrow_distance_operator_where() {
        let result =
            preprocess("SELECT * FROM docs WHERE embedding <-> ARRAY[1.0, 2.0] < 0.5").unwrap();
        assert!(
            result
                .sql
                .contains("vector_distance(embedding, ARRAY[1.0, 2.0])"),
            "got: {}",
            result.sql
        );
    }

    /// Plain comparison operators are not mistaken for `<->`.
    #[test]
    fn arrow_distance_no_match() {
        assert!(preprocess("SELECT * FROM users WHERE age > 30").is_none());
    }

    /// Text following the right operand (here an alias) is preserved.
    #[test]
    fn arrow_distance_with_alias() {
        let result =
            preprocess("SELECT embedding <-> ARRAY[0.1, 0.2] AS dist FROM articles").unwrap();
        assert!(
            result
                .sql
                .contains("vector_distance(embedding, ARRAY[0.1, 0.2]) AS dist"),
            "got: {}",
            result.sql
        );
    }

    /// An object literal passed as a function argument becomes a JSON string,
    /// both via the helper directly and through `preprocess`.
    #[test]
    fn fuzzy_object_literal_in_function() {
        let direct = rewrite_object_literal_args(
            "SELECT * FROM articles WHERE text_match(body, 'query', { fuzzy: true })",
        );
        assert!(direct.is_some(), "rewrite_object_literal_args should match");
        let rewritten = direct.unwrap();
        assert!(
            rewritten.contains("\"fuzzy\""),
            "direct rewrite should contain JSON, got: {}",
            rewritten
        );

        let result =
            preprocess("SELECT * FROM articles WHERE text_match(body, 'query', { fuzzy: true })")
                .unwrap();
        assert!(
            !result.sql.contains("{ fuzzy"),
            "should not contain object literal, got: {}",
            result.sql
        );
    }

    /// Multiple keys in a function-argument object literal all appear in the JSON.
    #[test]
    fn fuzzy_object_literal_with_distance() {
        let result = preprocess(
            "SELECT * FROM articles WHERE text_match(title, 'test', { fuzzy: true, distance: 2 })",
        )
        .unwrap();
        assert!(result.sql.contains("\"fuzzy\""), "got: {}", result.sql);
        assert!(result.sql.contains("\"distance\""), "got: {}", result.sql);
    }

    /// A top-level INSERT object literal takes the INSERT path (column list +
    /// VALUES), not the function-argument JSON rewrite.
    #[test]
    fn object_literal_not_rewritten_outside_function() {
        let result = preprocess("INSERT INTO docs { name: 'Alice' }").unwrap();
        assert!(result.sql.contains("VALUES"), "got: {}", result.sql);
    }
}