Skip to main content

nodedb_sql/parser/preprocess/
function_args.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Rewrite `{ key: val }` object literals appearing inside function-call
4//! argument positions to JSON string literals: `'{"key": val}'`.
5
6use super::lex::{SqlSegment, segments};
7use crate::parser::object_literal::parse_object_literal;
8
9/// Detect patterns like `func(arg1, arg2, { key: val })` and rewrite the
10/// `{ }` to a single-quoted JSON string. Only rewrites `{ }` that appear
11/// inside parentheses (function calls), not at statement level (INSERT).
12pub(super) fn rewrite_object_literal_args(sql: &str) -> Option<String> {
13    let mut result = String::with_capacity(sql.len());
14    let mut found = false;
15    let mut paren_depth: i32 = 0;
16
17    for seg in segments(sql) {
18        match seg {
19            // Non-text segments (string literals, quoted idents, comments) are
20            // passed through opaquely — never scanned for `{`.
21            SqlSegment::SingleQuotedString(s)
22            | SqlSegment::QuotedIdent(s)
23            | SqlSegment::LineComment(s)
24            | SqlSegment::BlockComment(s) => {
25                result.push_str(s);
26            }
27            SqlSegment::Text(t) => {
28                // Walk the text segment character-by-character so we can track
29                // paren depth and detect `{` only inside function calls.
30                let chars: Vec<char> = t.chars().collect();
31                let mut i = 0;
32                let mut seg_result = String::with_capacity(t.len());
33
34                while i < chars.len() {
35                    match chars[i] {
36                        '(' => {
37                            paren_depth += 1;
38                            seg_result.push('(');
39                            i += 1;
40                        }
41                        ')' => {
42                            paren_depth = paren_depth.saturating_sub(1);
43                            seg_result.push(')');
44                            i += 1;
45                        }
46                        '{' if paren_depth > 0 => {
47                            // Reconstruct the remaining text from position `i`
48                            // in this segment and attempt to parse an object
49                            // literal.
50                            let remaining: String = chars[i..].iter().collect();
51                            if let Some(Ok(fields)) = parse_object_literal(&remaining)
52                                && let Some(end) = find_matching_brace(&chars, i)
53                            {
54                                let json = value_map_to_json(&fields);
55                                seg_result.push('\'');
56                                seg_result.push_str(&json);
57                                seg_result.push('\'');
58                                i = end + 1;
59                                found = true;
60                                continue;
61                            }
62                            seg_result.push('{');
63                            i += 1;
64                        }
65                        c => {
66                            seg_result.push(c);
67                            i += 1;
68                        }
69                    }
70                }
71
72                result.push_str(&seg_result);
73            }
74        }
75    }
76
77    if found { Some(result) } else { None }
78}
79
80/// Convert a parsed field map to a JSON string without external serializer.
81pub(super) fn value_map_to_json(
82    fields: &std::collections::HashMap<String, nodedb_types::Value>,
83) -> String {
84    let mut parts = Vec::with_capacity(fields.len());
85    let mut entries: Vec<_> = fields.iter().collect();
86    entries.sort_by_key(|(k, _)| k.as_str());
87    for (key, val) in entries {
88        parts.push(format!("\"{}\":{}", key, value_to_json(val)));
89    }
90    format!("{{{}}}", parts.join(","))
91}
92
93/// Convert a single `Value` to JSON text.
94fn value_to_json(value: &nodedb_types::Value) -> String {
95    match value {
96        nodedb_types::Value::String(s) => {
97            format!("\"{}\"", s.replace('\\', "\\\\").replace('"', "\\\""))
98        }
99        nodedb_types::Value::Integer(n) => n.to_string(),
100        nodedb_types::Value::Float(f) => {
101            if f.is_finite() {
102                format!("{f}")
103            } else {
104                // JSON has no representation for NaN / ±inf; serialize as
105                // `null` to keep the output parseable.
106                "null".to_string()
107            }
108        }
109        nodedb_types::Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
110        nodedb_types::Value::Null => "null".to_string(),
111        nodedb_types::Value::Array(items) => {
112            let inner: Vec<String> = items.iter().map(value_to_json).collect();
113            format!("[{}]", inner.join(","))
114        }
115        nodedb_types::Value::Object(map) => value_map_to_json(map),
116        _ => format!("\"{}\"", format!("{value:?}").replace('"', "\\\"")),
117    }
118}
119
/// Return the index of the `}` that closes the `{` found at `start`, or
/// `None` if no balancing `}` exists in `chars`.
///
/// There is intentionally no quote handling here. The caller only passes
/// characters from `Text` segments, and the lexer has already split string
/// literals, quoted identifiers and comments out of those segments.
fn find_matching_brace(chars: &[char], start: usize) -> Option<usize> {
    let mut open = 0;
    for (idx, &ch) in chars.iter().enumerate().skip(start) {
        if ch == '{' {
            open += 1;
        } else if ch == '}' {
            open -= 1;
            if open == 0 {
                return Some(idx);
            }
        }
    }
    None
}
143
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn string_literal_with_brace_not_rewritten() {
        // `{ foo }` lives inside a string literal — must pass through unchanged.
        let sql = "SELECT func('{ foo }')";
        assert!(rewrite_object_literal_args(sql).is_none());
    }

    #[test]
    fn object_literal_in_function_rewritten() {
        let sql = "SELECT * FROM articles WHERE text_match(body, 'query', { fuzzy: true })";
        let result = rewrite_object_literal_args(sql).unwrap();
        assert!(result.contains("\"fuzzy\""), "got: {result}");
        assert!(!result.contains("{ fuzzy"), "got: {result}");
        // The literal becomes a single-quoted JSON object in argument position.
        assert!(result.contains("'{\"fuzzy\":"), "got: {result}");
        assert!(result.ends_with("}')"), "got: {result}");
    }

    #[test]
    fn sql_without_braces_returns_none() {
        assert!(rewrite_object_literal_args("SELECT count(*) FROM t WHERE f(a, b)").is_none());
    }

    #[test]
    fn statement_level_object_not_rewritten() {
        // Outside any parentheses (depth 0) the `{ }` is left for the INSERT path.
        assert!(rewrite_object_literal_args("INSERT INTO docs { fuzzy: true }").is_none());
    }

    #[test]
    fn unterminated_object_literal_not_rewritten() {
        // No matching `}` in the segment, so the `{` is emitted unchanged.
        assert!(rewrite_object_literal_args("SELECT f(a, { fuzzy: true)").is_none());
    }
}