Skip to main content

nodedb_sql/parser/preprocess/
vector_ops.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Rewrite pgvector's distance operators (`<->`, `<=>`, `<#>`) into
//! `vector_distance()`, `vector_cosine_distance()` and
//! `vector_neg_inner_product()` calls that standard sqlparser can parse.
5
6use super::lex::{SqlSegment, find_operator_positions, has_operator_outside_literals, segments};
7
8/// Rewrite all occurrences of `expr <-> expr` to `vector_distance(expr, expr)`.
9///
10/// Handles: `column_name <-> ARRAY[...]`, `column <-> $param`, etc.
11/// Returns `None` if no valid `<->` patterns are found outside string
12/// literals, quoted identifiers, or comments.
13pub(super) fn rewrite_arrow_distance(sql: &str) -> Option<String> {
14    rewrite_binary_op(sql, "<->", "vector_distance")
15}
16
17/// Extract the left operand before `<->`: a column name or dotted path.
18fn extract_left_operand(before: &str) -> Option<String> {
19    // Determine what the effective "before" text is by looking only at the
20    // last Text segment — the operand cannot span across a quoted literal.
21    let text_before = last_text_segment(before);
22    let trimmed = text_before.trim_end();
23    let start = trimmed
24        .rfind(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '.')
25        .map(|p| p + 1)
26        .unwrap_or(0);
27    let ident = &trimmed[start..];
28    if ident.is_empty() {
29        return None;
30    }
31    // Reconstruct including the prefix text up to the ident within `before`.
32    // We need the ident string only (the caller uses its length for offset math
33    // and only the text content for the rewrite).
34    Some(ident.to_string())
35}
36
37/// Return the content of the last `Text` segment in `sql`, or the whole
38/// string if there are no non-text segments.
39fn last_text_segment(sql: &str) -> &str {
40    let segs = segments(sql);
41    for seg in segs.iter().rev() {
42        if let SqlSegment::Text(t) = seg {
43            return t;
44        }
45    }
46    ""
47}
48
/// Extract the right operand that follows a binary vector operator.
///
/// After skipping leading whitespace, one of these forms is recognised:
/// - `ARRAY[...]`: the keyword is matched ASCII case-insensitively and nested
///   brackets are balanced;
/// - a bind parameter such as `$1`;
/// - a column name or dotted path such as `t.embedding` (Unicode letters and
///   digits, `_`, `.`).
///
/// Returns `(operand_text, consumed_length)`. `consumed_length` is the byte
/// length of the operand measured from the first non-whitespace character of
/// `after`; leading whitespace is *not* included. Returns `None` for any other
/// form, including an unterminated `ARRAY[`.
fn extract_right_operand(after: &str) -> Option<(String, usize)> {
    let trimmed = after.trim_start();
    let bytes = trimmed.as_bytes();

    // Compare just the 6-byte prefix instead of upper-casing the whole
    // remainder of the statement on every call.
    if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"ARRAY[") {
        let mut depth = 0;
        for (i, c) in trimmed.char_indices() {
            match c {
                '[' => depth += 1,
                ']' => {
                    depth -= 1;
                    if depth == 0 {
                        // `]` is one byte, so `i + 1` is a char boundary.
                        return Some((trimmed[..=i].to_string(), i + 1));
                    }
                }
                _ => {}
            }
        }
        None
    } else if trimmed.starts_with('$') {
        let end = trimmed
            .find(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '$')
            .unwrap_or(trimmed.len());
        Some((trimmed[..end].to_string(), end))
    } else {
        // Unicode-aware, so identifiers such as `émb` are not cut in half.
        let end = trimmed
            .find(|c: char| !(c.is_alphanumeric() || c == '_' || c == '.'))
            .unwrap_or(trimmed.len());
        if end == 0 {
            return None;
        }
        Some((trimmed[..end].to_string(), end))
    }
}
85
86/// Rewrite all occurrences of `expr <=> expr` to `vector_cosine_distance(expr, expr)`.
87///
88/// Handles: `column_name <=> ARRAY[...]`, `column <=> $param`, etc.
89/// Returns `None` if no valid `<=>` patterns are found outside string
90/// literals, quoted identifiers, or comments.
91pub(super) fn rewrite_cosine_distance(sql: &str) -> Option<String> {
92    rewrite_binary_op(sql, "<=>", "vector_cosine_distance")
93}
94
95/// Rewrite all occurrences of `expr <#> expr` to `vector_neg_inner_product(expr, expr)`.
96///
97/// Handles: `column_name <#> ARRAY[...]`, `column <#> $param`, etc.
98/// Returns `None` if no valid `<#>` patterns are found outside string
99/// literals, quoted identifiers, or comments.
100pub(super) fn rewrite_neg_inner_product(sql: &str) -> Option<String> {
101    rewrite_binary_op(sql, "<#>", "vector_neg_inner_product")
102}
103
104/// Generic binary-operator rewriter: replaces all occurrences of `op`
105/// (outside literals/comments/quoted idents) with `func_name(left, right)`.
106fn rewrite_binary_op(sql: &str, op: &str, func_name: &str) -> Option<String> {
107    if !has_operator_outside_literals(sql, op) {
108        return None;
109    }
110
111    let positions = find_operator_positions(sql, op);
112    if positions.is_empty() {
113        return None;
114    }
115
116    let mut result = String::with_capacity(sql.len());
117    let mut consumed = 0usize;
118    let mut found = false;
119
120    for op_pos in positions {
121        if op_pos < consumed {
122            continue;
123        }
124
125        let before = &sql[consumed..op_pos];
126        let left = extract_left_operand(before)?;
127        // The left operand is the trailing identifier of `before`. Subtract
128        // its length from the *trimmed* end of `before` — using `before.len()`
129        // directly is off-by-one whenever there is whitespace between the
130        // column and the operator (e.g. `embedding <-> ARRAY[...]`), which
131        // would leave the column's first character behind in the result and
132        // corrupt the rewritten function name (`evector_distance(...)`).
133        let trimmed_end_len = before.trim_end().len();
134        let left_start = consumed + (trimmed_end_len - left.len());
135
136        let after_op = &sql[op_pos + op.len()..];
137        let (right, right_len) = extract_right_operand(after_op.trim_start())?;
138        let ws_skip = after_op.len() - after_op.trim_start().len();
139
140        result.push_str(&sql[consumed..left_start]);
141        result.push_str(&format!("{func_name}({left}, {right})"));
142        consumed = op_pos + op.len() + ws_skip + right_len;
143        found = true;
144    }
145
146    if !found {
147        return None;
148    }
149
150    result.push_str(&sql[consumed..]);
151    Some(result)
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    // ── <-> (vector_distance) ─────────────────────────────────────────────

    #[test]
    fn no_match_returns_none() {
        assert!(rewrite_arrow_distance("SELECT * FROM t").is_none());
    }

    #[test]
    fn arrow_in_string_literal_ignored() {
        // The `<->` is inside a string literal — no rewrite should happen.
        assert!(rewrite_arrow_distance("SELECT '<->'").is_none());
    }

    #[test]
    fn arrow_in_line_comment_ignored() {
        assert!(rewrite_arrow_distance("SELECT col -- has <-> in comment\nFROM t").is_none());
    }

    #[test]
    fn arrow_in_block_comment_ignored() {
        assert!(rewrite_arrow_distance("SELECT /* <-> */ x").is_none());
    }

    #[test]
    fn arrow_in_quoted_ident_ignored() {
        assert!(rewrite_arrow_distance(r#"SELECT "col_<->""#).is_none());
    }

    #[test]
    fn nested_block_comment_arrow_ignored() {
        assert!(rewrite_arrow_distance("SELECT /* /* nested */ <-> */ x").is_none());
    }

    #[test]
    fn basic_array_rewrite() {
        let rewritten = rewrite_arrow_distance(
            "SELECT title FROM articles ORDER BY embedding <-> ARRAY[0.1, 0.2, 0.3] LIMIT 5",
        )
        .unwrap();
        assert!(
            rewritten.contains("vector_distance(embedding, ARRAY[0.1, 0.2, 0.3])"),
            "got: {rewritten}"
        );
        assert!(!rewritten.contains("<->"));
    }

    #[test]
    fn rewrite_output_is_exact() {
        // `contains` alone cannot detect a stray character left in front of
        // the function name (e.g. `evector_distance(...)`), so pin the whole
        // output.
        assert_eq!(
            rewrite_arrow_distance(
                "SELECT title FROM articles ORDER BY embedding <-> ARRAY[0.1, 0.2, 0.3] LIMIT 5",
            )
            .as_deref(),
            Some(
                "SELECT title FROM articles ORDER BY vector_distance(embedding, ARRAY[0.1, 0.2, 0.3]) LIMIT 5"
            ),
        );
    }

    #[test]
    fn rewrite_with_param() {
        let rewritten =
            rewrite_arrow_distance("SELECT * FROM docs WHERE embedding <-> $1 < 0.5").unwrap();
        assert!(
            rewritten.contains("vector_distance(embedding, $1)"),
            "got: {rewritten}"
        );
    }

    #[test]
    fn rewrite_with_param_is_exact() {
        assert_eq!(
            rewrite_arrow_distance("SELECT * FROM docs WHERE embedding <-> $1 < 0.5").as_deref(),
            Some("SELECT * FROM docs WHERE vector_distance(embedding, $1) < 0.5"),
        );
    }

    #[test]
    fn rewrite_without_whitespace_around_operator() {
        assert_eq!(
            rewrite_arrow_distance("SELECT e<->ARRAY[1, 2] FROM t").as_deref(),
            Some("SELECT vector_distance(e, ARRAY[1, 2]) FROM t"),
        );
    }

    #[test]
    fn rewrite_dotted_column_and_lowercase_array() {
        assert_eq!(
            rewrite_arrow_distance("SELECT * FROM t ORDER BY t.embedding <-> array[1, 2]")
                .as_deref(),
            Some("SELECT * FROM t ORDER BY vector_distance(t.embedding, array[1, 2])"),
        );
    }

    #[test]
    fn rewrite_multiple_occurrences() {
        assert_eq!(
            rewrite_arrow_distance("SELECT a <-> ARRAY[1], b <-> ARRAY[2] FROM t").as_deref(),
            Some("SELECT vector_distance(a, ARRAY[1]), vector_distance(b, ARRAY[2]) FROM t"),
        );
    }

    #[test]
    fn unterminated_array_returns_none() {
        // No partial rewrite is produced when an operand cannot be extracted.
        assert!(rewrite_arrow_distance("SELECT e <-> ARRAY[1, 2 FROM t").is_none());
    }

    // ── <=> (cosine) ──────────────────────────────────────────────────────

    #[test]
    fn cosine_basic_array_rewrite() {
        let rewritten = rewrite_cosine_distance(
            "SELECT title FROM articles ORDER BY embedding <=> ARRAY[0.1, 0.2, 0.3] LIMIT 5",
        )
        .unwrap();
        assert!(
            rewritten.contains("vector_cosine_distance(embedding, ARRAY[0.1, 0.2, 0.3])"),
            "got: {rewritten}"
        );
        assert!(!rewritten.contains("<=>"));
    }

    #[test]
    fn cosine_in_string_literal_ignored() {
        assert!(rewrite_cosine_distance("SELECT '<=>'").is_none());
    }

    #[test]
    fn cosine_in_line_comment_ignored() {
        assert!(rewrite_cosine_distance("SELECT col -- has <=> in comment\nFROM t").is_none());
    }

    #[test]
    fn cosine_in_quoted_ident_ignored() {
        assert!(rewrite_cosine_distance(r#"SELECT "col_<=>""#).is_none());
    }

    #[test]
    fn cosine_no_match_returns_none() {
        assert!(rewrite_cosine_distance("SELECT * FROM t").is_none());
    }

    // ── <#> (neg-inner-product) ───────────────────────────────────────────

    #[test]
    fn nip_basic_array_rewrite() {
        let rewritten = rewrite_neg_inner_product(
            "SELECT title FROM articles ORDER BY embedding <#> ARRAY[0.1, 0.2, 0.3] LIMIT 5",
        )
        .unwrap();
        assert!(
            rewritten.contains("vector_neg_inner_product(embedding, ARRAY[0.1, 0.2, 0.3])"),
            "got: {rewritten}"
        );
        assert!(!rewritten.contains("<#>"));
    }

    #[test]
    fn nip_in_string_literal_ignored() {
        assert!(rewrite_neg_inner_product("SELECT '<#>'").is_none());
    }

    #[test]
    fn nip_in_line_comment_ignored() {
        assert!(rewrite_neg_inner_product("SELECT col -- has <#> in comment\nFROM t").is_none());
    }

    #[test]
    fn nip_in_quoted_ident_ignored() {
        assert!(rewrite_neg_inner_product(r#"SELECT "col_<#>""#).is_none());
    }

    #[test]
    fn nip_no_match_returns_none() {
        assert!(rewrite_neg_inner_product("SELECT * FROM t").is_none());
    }
}