Skip to main content

nodedb_sql/parser/preprocess/
search_vector.rs

// SPDX-License-Identifier: Apache-2.0

//! Rewrite `SEARCH <coll> USING VECTOR(<field>, ARRAY[...], <k>)` to the
//! canonical `SELECT * FROM <coll> ORDER BY vector_distance(<field>, ARRAY[...]) LIMIT <k>`.
//!
//! `<field>` may be omitted, either by leaving its slot empty or by passing only
//! two arguments, `VECTOR(ARRAY[...], <k>)`, in which case the second argument is
//! the limit. When the collection has a single declared vector column the planner
//! resolves the field; otherwise `vector_distance` rejects the call with a typed error.

/// Leading keyword of the statement form this module rewrites.
const SEARCH_KEYWORD: &str = "SEARCH";
/// Keyword that follows the collection name and introduces the search method.
const USING_KEYWORD: &str = "USING";
/// The only search method rewritten here; other methods (e.g. `FUSION`) are left alone.
const VECTOR_KEYWORD: &str = "VECTOR";

14pub fn try_rewrite_search_using_vector(sql: &str) -> Option<String> {
15    let trimmed = sql.trim_end_matches(|c: char| c == ';' || c.is_whitespace());
16    let upper = trimmed.to_uppercase();
17    let stripped = upper.trim_start();
18    if !stripped.starts_with(SEARCH_KEYWORD) {
19        return None;
20    }
21    let leading = trimmed.len() - trimmed.trim_start().len();
22    let after_search = trimmed[leading + SEARCH_KEYWORD.len()..].trim_start();
23    if after_search.is_empty() {
24        return None;
25    }
26
27    let (collection, rest) = take_identifier(after_search)?;
28    let rest = rest.trim_start();
29    let rest_upper = rest.to_uppercase();
30    if !rest_upper.starts_with(USING_KEYWORD) {
31        return None;
32    }
33    let after_using = rest[USING_KEYWORD.len()..].trim_start();
34    let after_using_upper = after_using.to_uppercase();
35    if !after_using_upper.starts_with(VECTOR_KEYWORD) {
36        return None;
37    }
38    let after_vec = after_using[VECTOR_KEYWORD.len()..].trim_start();
39    let body = strip_parentheses(after_vec)?;
40    let (field, vector_expr, limit) = split_vector_args(body)?;
41
42    let trailing = sql[leading + (trimmed.len() - leading)..].to_string();
43    let order_by = match field {
44        Some(name) => format!("vector_distance({name}, {vector_expr})"),
45        None => format!("vector_distance({vector_expr})"),
46    };
47    Some(format!(
48        "SELECT * FROM {collection} ORDER BY {order_by} LIMIT {limit}{trailing}"
49    ))
50}
51
52fn take_identifier(input: &str) -> Option<(&str, &str)> {
53    let end = input
54        .char_indices()
55        .find(|(_, c)| !is_ident_char(*c))
56        .map(|(i, _)| i)
57        .unwrap_or(input.len());
58    if end == 0 {
59        return None;
60    }
61    Some((&input[..end], &input[end..]))
62}
63
/// Returns `true` for characters allowed in an unquoted identifier here:
/// ASCII letters, ASCII digits and `_`.
fn is_ident_char(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_')
}

/// Strips one pair of enclosing parentheses from `input` (after trimming).
///
/// Returns the trimmed text between them. Returns `None` unless `input`
/// starts with `(` and that parenthesis is closed by the very last character.
/// Parentheses inside single- or double-quoted text are ignored while
/// matching. Input such as `(a, b, 3) OR (x)` is therefore rejected instead
/// of yielding the mis-nested body `a, b, 3) OR (x`.
fn strip_parentheses(input: &str) -> Option<&str> {
    let inner = input.trim().strip_prefix('(')?;
    let mut depth = 1usize;
    let mut in_single = false;
    let mut in_double = false;
    for (idx, c) in inner.char_indices() {
        match c {
            '\'' if !in_double => in_single = !in_single,
            '"' if !in_single => in_double = !in_double,
            '(' if !in_single && !in_double => depth += 1,
            ')' if !in_single && !in_double => {
                depth -= 1;
                if depth == 0 {
                    // Accept only when the partner of the opening `(` ends the input.
                    return if inner[idx + 1..].trim().is_empty() {
                        Some(inner[..idx].trim())
                    } else {
                        None
                    };
                }
            }
            _ => {}
        }
    }
    // The opening `(` was never closed.
    None
}

/// Splits the body of `VECTOR(...)` into `(field, vector_expr, limit)`.
///
/// Accepts `field, vector, k` (a blank `field` slot means "no field") or
/// `vector, k`. Returns `None` for any other argument count, or when the
/// vector expression or the limit is blank.
fn split_vector_args(body: &str) -> Option<(Option<String>, String, String)> {
    let parts = split_top_level_commas(body);
    let (field, vector_expr, limit) = match parts.as_slice() {
        [field, vec, k] => (Some(field.trim()), vec.trim(), k.trim()),
        [vec, k] => (None, vec.trim(), k.trim()),
        _ => return None,
    };
    // A blank vector or limit would produce invalid SQL such as `LIMIT ` or
    // `vector_distance(emb, )`; let the caller fall back instead.
    if vector_expr.is_empty() || limit.is_empty() {
        return None;
    }
    let field = field.filter(|f| !f.is_empty()).map(str::to_string);
    Some((field, vector_expr.to_string(), limit.to_string()))
}

/// Splits `body` on commas that are not nested inside `(...)`, `[...]`, or
/// single/double-quoted text. Segments are returned untrimmed.
///
/// A wholly blank body yields no segments. Once at least one separator has
/// been seen, the final segment is kept even when blank. A trailing comma
/// therefore shows up as an empty argument instead of silently changing the
/// argument count.
fn split_top_level_commas(body: &str) -> Vec<String> {
    let mut depth_paren = 0i32;
    let mut depth_bracket = 0i32;
    let mut in_single = false;
    let mut in_double = false;
    let mut current = String::new();
    let mut out = Vec::new();
    for c in body.chars() {
        match c {
            '\'' if !in_double => in_single = !in_single,
            '"' if !in_single => in_double = !in_double,
            '(' if !in_single && !in_double => depth_paren += 1,
            ')' if !in_single && !in_double => depth_paren -= 1,
            '[' if !in_single && !in_double => depth_bracket += 1,
            ']' if !in_single && !in_double => depth_bracket -= 1,
            ',' if !in_single && !in_double && depth_paren == 0 && depth_bracket == 0 => {
                out.push(std::mem::take(&mut current));
                continue;
            }
            _ => {}
        }
        current.push(c);
    }
    if !out.is_empty() || !current.trim().is_empty() {
        out.push(current);
    }
    out
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Rewrites `sql`, failing the test if it was not recognised.
    fn rewrite(sql: &str) -> String {
        try_rewrite_search_using_vector(sql).unwrap()
    }

    #[test]
    fn rewrites_three_arg_form() {
        let sql = "SEARCH articles USING VECTOR(embedding, ARRAY[0.1, 0.3, -0.2], 10)";
        let expected = "SELECT * FROM articles ORDER BY vector_distance(embedding, ARRAY[0.1, 0.3, -0.2]) LIMIT 10";
        assert_eq!(rewrite(sql), expected);
    }

    #[test]
    fn rewrites_two_arg_form_when_field_omitted() {
        let sql = "SEARCH articles USING VECTOR(ARRAY[0.1, 0.3], 5)";
        let expected = "SELECT * FROM articles ORDER BY vector_distance(ARRAY[0.1, 0.3]) LIMIT 5";
        assert_eq!(rewrite(sql), expected);
    }

    #[test]
    fn returns_none_when_not_search() {
        assert_eq!(try_rewrite_search_using_vector("SELECT * FROM t"), None);
    }

    #[test]
    fn returns_none_when_using_fusion() {
        let sql = "SEARCH c USING FUSION(ARRAY[0.5])";
        assert_eq!(try_rewrite_search_using_vector(sql), None);
    }

    #[test]
    fn handles_trailing_semicolon() {
        let out = rewrite("SEARCH t USING VECTOR(emb, ARRAY[1.0], 3);");
        let prefix = "SELECT * FROM t ORDER BY vector_distance(emb, ARRAY[1.0]) LIMIT 3";
        assert!(out.starts_with(prefix));
    }
}