nodedb_sql/parser/preprocess/
function_args.rs1use super::lex::{SqlSegment, segments};
7use crate::parser::object_literal::parse_object_literal;
8
9pub(super) fn rewrite_object_literal_args(sql: &str) -> Option<String> {
13 let mut result = String::with_capacity(sql.len());
14 let mut found = false;
15 let mut paren_depth: i32 = 0;
16
17 for seg in segments(sql) {
18 match seg {
19 SqlSegment::SingleQuotedString(s)
22 | SqlSegment::QuotedIdent(s)
23 | SqlSegment::LineComment(s)
24 | SqlSegment::BlockComment(s) => {
25 result.push_str(s);
26 }
27 SqlSegment::Text(t) => {
28 let chars: Vec<char> = t.chars().collect();
31 let mut i = 0;
32 let mut seg_result = String::with_capacity(t.len());
33
34 while i < chars.len() {
35 match chars[i] {
36 '(' => {
37 paren_depth += 1;
38 seg_result.push('(');
39 i += 1;
40 }
41 ')' => {
42 paren_depth = paren_depth.saturating_sub(1);
43 seg_result.push(')');
44 i += 1;
45 }
46 '{' if paren_depth > 0 => {
47 let remaining: String = chars[i..].iter().collect();
51 if let Some(Ok(fields)) = parse_object_literal(&remaining)
52 && let Some(end) = find_matching_brace(&chars, i)
53 {
54 let json = value_map_to_json(&fields);
55 seg_result.push('\'');
56 seg_result.push_str(&json);
57 seg_result.push('\'');
58 i = end + 1;
59 found = true;
60 continue;
61 }
62 seg_result.push('{');
63 i += 1;
64 }
65 c => {
66 seg_result.push(c);
67 i += 1;
68 }
69 }
70 }
71
72 result.push_str(&seg_result);
73 }
74 }
75 }
76
77 if found { Some(result) } else { None }
78}
79
80pub(super) fn value_map_to_json(
82 fields: &std::collections::HashMap<String, nodedb_types::Value>,
83) -> String {
84 let mut parts = Vec::with_capacity(fields.len());
85 let mut entries: Vec<_> = fields.iter().collect();
86 entries.sort_by_key(|(k, _)| k.as_str());
87 for (key, val) in entries {
88 parts.push(format!("\"{}\":{}", key, value_to_json(val)));
89 }
90 format!("{{{}}}", parts.join(","))
91}
92
93fn value_to_json(value: &nodedb_types::Value) -> String {
95 match value {
96 nodedb_types::Value::String(s) => {
97 format!("\"{}\"", s.replace('\\', "\\\\").replace('"', "\\\""))
98 }
99 nodedb_types::Value::Integer(n) => n.to_string(),
100 nodedb_types::Value::Float(f) => {
101 if f.is_finite() {
102 format!("{f}")
103 } else {
104 "null".to_string()
107 }
108 }
109 nodedb_types::Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
110 nodedb_types::Value::Null => "null".to_string(),
111 nodedb_types::Value::Array(items) => {
112 let inner: Vec<String> = items.iter().map(value_to_json).collect();
113 format!("[{}]", inner.join(","))
114 }
115 nodedb_types::Value::Object(map) => value_map_to_json(map),
116 _ => format!("\"{}\"", format!("{value:?}").replace('"', "\\\"")),
117 }
118}
119
/// Scans `chars` from index `start` and returns the index of the `}` at which
/// the running brace balance returns to zero.
///
/// Every `{` raises the balance and every `}` lowers it. Returns `None` when
/// the end of the slice is reached first, or when `start` is past the end.
fn find_matching_brace(chars: &[char], start: usize) -> Option<usize> {
    let mut open = 0;
    for (idx, ch) in chars.iter().enumerate().skip(start) {
        if *ch == '{' {
            open += 1;
        } else if *ch == '}' {
            open -= 1;
            // Only a closing brace can complete the match.
            if open == 0 {
                return Some(idx);
            }
        }
    }
    None
}
143
#[cfg(test)]
mod tests {
    use super::rewrite_object_literal_args;

    /// Braces that live inside a single-quoted SQL string are not object
    /// literals, so the rewriter must report that nothing changed.
    #[test]
    fn string_literal_with_brace_not_rewritten() {
        let rewritten = rewrite_object_literal_args("SELECT func('{ foo }')");
        assert!(rewritten.is_none());
    }

    /// An object literal passed as a function argument is replaced by its
    /// JSON form, and the original brace syntax disappears from the output.
    #[test]
    fn object_literal_in_function_rewritten() {
        let result = rewrite_object_literal_args(
            "SELECT * FROM articles WHERE text_match(body, 'query', { fuzzy: true })",
        )
        .unwrap();
        assert!(result.contains("\"fuzzy\""), "got: {result}");
        assert!(!result.contains("{ fuzzy"), "got: {result}");
    }
}