Skip to main content

nodedb_sql/parser/preprocess/
vector_ops.rs

//! Rewrite pgvector's `<->` distance operator into a `vector_distance()`
//! function call that the standard sqlparser can parse.
3
4/// Rewrite all occurrences of `expr <-> expr` to `vector_distance(expr, expr)`.
5///
6/// Handles: `column_name <-> ARRAY[...]`, `column <-> $param`, etc.
7/// Returns `None` if no valid `<->` patterns are found.
8pub(super) fn rewrite_arrow_distance(sql: &str) -> Option<String> {
9    let mut result = String::with_capacity(sql.len());
10    let mut remaining = sql;
11    let mut found = false;
12
13    while let Some(arrow_pos) = remaining.find("<->") {
14        let before = &remaining[..arrow_pos];
15        let left = extract_left_operand(before)?;
16        let left_start = arrow_pos - left.len();
17
18        let after = &remaining[arrow_pos + 3..];
19        let (right, right_len) = extract_right_operand(after.trim_start())?;
20        let ws_skip = after.len() - after.trim_start().len();
21
22        result.push_str(&remaining[..left_start]);
23        result.push_str(&format!("vector_distance({left}, {right})"));
24        remaining = &remaining[arrow_pos + 3 + ws_skip + right_len..];
25        found = true;
26    }
27
28    if !found {
29        return None;
30    }
31
32    result.push_str(remaining);
33    Some(result)
34}
35
/// Extract the left operand before `<->`: a column name or dotted path.
///
/// Trailing whitespace in `before` is ignored. The operand is the longest run
/// of ASCII letters, ASCII digits, `_` and `.` that ends the remaining text.
///
/// Returns `None` when that run is empty, e.g. when the text is blank or ends
/// in a character such as `)` that is not part of an identifier.
fn extract_left_operand(before: &str) -> Option<String> {
    let trimmed = before.trim_end();

    // Peel the identifier characters off the end and look at what is left in
    // front of them. The length of that prefix is always a char boundary, even
    // when the character immediately before the operand is multi-byte; taking
    // `rfind(..) + 1` instead would land inside such a character and make the
    // slice below panic.
    let prefix =
        trimmed.trim_end_matches(|c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.');
    let ident = &trimmed[prefix.len()..];

    if ident.is_empty() {
        None
    } else {
        Some(ident.to_string())
    }
}
49
/// Extract the right operand after `<->`: `ARRAY[...]`, `$param`, or identifier.
///
/// Leading whitespace in `after` is skipped. Returns
/// `(operand_text, consumed_length)`, where `consumed_length` is the operand's
/// length in bytes counted from the first non-whitespace byte; the skipped
/// whitespace is *not* included.
///
/// Recognised forms:
/// - `ARRAY[` (any ASCII case) up to its matching `]`, with nested brackets
///   balanced;
/// - `$` followed by any run of ASCII alphanumerics, `_` or `$`;
/// - a run of ASCII alphanumerics, `_` or `.`.
///
/// Returns `None` when an `ARRAY[` has no matching `]`, or when the text starts
/// with none of the forms above.
fn extract_right_operand(after: &str) -> Option<(String, usize)> {
    const ARRAY_PREFIX: &[u8] = b"ARRAY[";

    let trimmed = after.trim_start();

    // Only the first six bytes decide whether this is an array literal, so
    // compare just those, ignoring ASCII case, instead of allocating an
    // uppercase copy of the entire remaining SQL text.
    let is_array = trimmed
        .as_bytes()
        .get(..ARRAY_PREFIX.len())
        .map_or(false, |head| head.eq_ignore_ascii_case(ARRAY_PREFIX));

    // Length of the leading token made of ASCII alphanumerics, `_` and one
    // additional permitted character.
    let token_len = |extra: char| {
        trimmed
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_' || c == extra))
            .unwrap_or(trimmed.len())
    };

    let end = if is_array {
        // Track bracket nesting and stop at the `]` that closes the first `[`.
        let mut depth = 0;
        let close = trimmed.char_indices().find_map(|(i, c)| {
            match c {
                '[' => depth += 1,
                ']' => {
                    depth -= 1;
                    if depth == 0 {
                        return Some(i);
                    }
                }
                _ => {}
            }
            None
        })?;
        // `]` is a single byte, so the operand ends one past its index.
        close + 1
    } else if trimmed.starts_with('$') {
        token_len('$')
    } else {
        token_len('.')
    };

    // Only the plain-identifier form can come back empty.
    if end == 0 {
        return None;
    }
    Some((trimmed[..end].to_string(), end))
}
85}