// nodedb_sql/parser/preprocess/search_vector.rs
//
// Preprocessing rewrite for `SEARCH <collection> USING VECTOR(...)` statements.

/// Leading keyword of the vector-search statement.
const SEARCH_KEYWORD: &str = "SEARCH";
/// Keyword that follows the collection name and introduces the search method.
const USING_KEYWORD: &str = "USING";
/// The search method handled by this rewrite.
const VECTOR_KEYWORD: &str = "VECTOR";
13
14pub fn try_rewrite_search_using_vector(sql: &str) -> Option<String> {
15 let trimmed = sql.trim_end_matches(|c: char| c == ';' || c.is_whitespace());
16 let upper = trimmed.to_uppercase();
17 let stripped = upper.trim_start();
18 if !stripped.starts_with(SEARCH_KEYWORD) {
19 return None;
20 }
21 let leading = trimmed.len() - trimmed.trim_start().len();
22 let after_search = trimmed[leading + SEARCH_KEYWORD.len()..].trim_start();
23 if after_search.is_empty() {
24 return None;
25 }
26
27 let (collection, rest) = take_identifier(after_search)?;
28 let rest = rest.trim_start();
29 let rest_upper = rest.to_uppercase();
30 if !rest_upper.starts_with(USING_KEYWORD) {
31 return None;
32 }
33 let after_using = rest[USING_KEYWORD.len()..].trim_start();
34 let after_using_upper = after_using.to_uppercase();
35 if !after_using_upper.starts_with(VECTOR_KEYWORD) {
36 return None;
37 }
38 let after_vec = after_using[VECTOR_KEYWORD.len()..].trim_start();
39 let body = strip_parentheses(after_vec)?;
40 let (field, vector_expr, limit) = split_vector_args(body)?;
41
42 let trailing = sql[leading + (trimmed.len() - leading)..].to_string();
43 let order_by = match field {
44 Some(name) => format!("vector_distance({name}, {vector_expr})"),
45 None => format!("vector_distance({vector_expr})"),
46 };
47 Some(format!(
48 "SELECT * FROM {collection} ORDER BY {order_by} LIMIT {limit}{trailing}"
49 ))
50}
51
52fn take_identifier(input: &str) -> Option<(&str, &str)> {
53 let end = input
54 .char_indices()
55 .find(|(_, c)| !is_ident_char(*c))
56 .map(|(i, _)| i)
57 .unwrap_or(input.len());
58 if end == 0 {
59 return None;
60 }
61 Some((&input[..end], &input[end..]))
62}
63
/// Returns `true` for characters allowed in an unquoted identifier here:
/// ASCII letters, ASCII digits and underscore.
fn is_ident_char(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_')
}
67
/// Returns the trimmed contents of `input` when it is exactly one
/// parenthesized group `( ... )` (surrounding whitespace allowed).
///
/// The leading `(` must be closed by the final `)`. Inputs such as `(a) (b)`,
/// whose first and last characters are parentheses but not a matching pair,
/// are rejected. Unbalanced inputs such as `(a, (b)` are also rejected.
/// Parentheses inside single- or double-quoted strings are ignored during matching.
fn strip_parentheses(input: &str) -> Option<&str> {
    let inner = input.trim().strip_prefix('(')?.strip_suffix(')')?;

    let mut depth = 0usize;
    let mut in_single = false;
    let mut in_double = false;
    for c in inner.chars() {
        match c {
            '\'' if !in_double => in_single = !in_single,
            '"' if !in_single => in_double = !in_double,
            '(' if !in_single && !in_double => depth += 1,
            // Closing below depth 0 means the outer `(` was closed early.
            ')' if !in_single && !in_double => depth = depth.checked_sub(1)?,
            _ => {}
        }
    }
    // Leftover depth or an open quote means the final `)` is not the
    // partner of the leading `(`.
    if depth != 0 || in_single || in_double {
        return None;
    }
    Some(inner.trim())
}
76
/// Splits the body of `VECTOR(...)` into `(field, vector_expr, limit)`.
///
/// Accepts `field, vector, k` or `vector, k`. A blank `field` in the
/// three-argument form is treated as omitted. Returns `None` for any other
/// arity, or when the vector expression or the limit is blank. This covers a
/// dangling comma such as `emb, ARRAY[1],`, which must not be reinterpreted
/// as the two-argument form.
fn split_vector_args(body: &str) -> Option<(Option<String>, String, String)> {
    let parts = split_top_level_commas(body);
    let (field, vector_expr, limit) = match parts.as_slice() {
        [field, vector_expr, limit] => {
            let field = field.trim();
            let field = if field.is_empty() { None } else { Some(field) };
            (field, vector_expr.trim(), limit.trim())
        }
        [vector_expr, limit] => (None, vector_expr.trim(), limit.trim()),
        _ => return None,
    };
    if vector_expr.is_empty() || limit.is_empty() {
        return None;
    }
    Some((
        field.map(String::from),
        vector_expr.to_string(),
        limit.to_string(),
    ))
}

/// Splits `body` on commas that are not nested inside `()`/`[]` and not
/// inside single- or double-quoted strings.
///
/// Segments are returned untrimmed. A blank trailing segment is kept
/// (`"a,"` → `["a", ""]`) so callers can detect dangling commas. An empty or
/// whitespace-only `body` yields no segments.
fn split_top_level_commas(body: &str) -> Vec<String> {
    if body.trim().is_empty() {
        return Vec::new();
    }
    let mut depth_paren = 0i32;
    let mut depth_bracket = 0i32;
    let mut in_single = false;
    let mut in_double = false;
    let mut current = String::new();
    let mut out = Vec::new();
    for c in body.chars() {
        match c {
            '\'' if !in_double => in_single = !in_single,
            '"' if !in_single => in_double = !in_double,
            '(' if !in_single && !in_double => depth_paren += 1,
            ')' if !in_single && !in_double => depth_paren -= 1,
            '[' if !in_single && !in_double => depth_bracket += 1,
            ']' if !in_single && !in_double => depth_bracket -= 1,
            ',' if !in_single && !in_double && depth_paren == 0 && depth_bracket == 0 => {
                out.push(std::mem::take(&mut current));
                continue;
            }
            _ => {}
        }
        current.push(c);
    }
    // Always keep the final segment, even when blank, so a trailing comma is
    // reported as an empty argument instead of disappearing.
    out.push(current);
    out
}
122
#[cfg(test)]
mod tests {
    use super::*;

    // Runs the rewriter on input that is expected to be rewritten.
    fn rewrite(sql: &str) -> String {
        try_rewrite_search_using_vector(sql).unwrap()
    }

    #[test]
    fn rewrites_three_arg_form() {
        let expected = "SELECT * FROM articles ORDER BY vector_distance(embedding, ARRAY[0.1, 0.3, -0.2]) LIMIT 10";
        let actual =
            rewrite("SEARCH articles USING VECTOR(embedding, ARRAY[0.1, 0.3, -0.2], 10)");
        assert_eq!(actual, expected);
    }

    #[test]
    fn rewrites_two_arg_form_when_field_omitted() {
        let expected = "SELECT * FROM articles ORDER BY vector_distance(ARRAY[0.1, 0.3]) LIMIT 5";
        let actual = rewrite("SEARCH articles USING VECTOR(ARRAY[0.1, 0.3], 5)");
        assert_eq!(actual, expected);
    }

    #[test]
    fn returns_none_when_not_search() {
        assert_eq!(try_rewrite_search_using_vector("SELECT * FROM t"), None);
    }

    #[test]
    fn returns_none_when_using_fusion() {
        let sql = "SEARCH c USING FUSION(ARRAY[0.5])";
        assert_eq!(try_rewrite_search_using_vector(sql), None);
    }

    #[test]
    fn handles_trailing_semicolon() {
        let actual = rewrite("SEARCH t USING VECTOR(emb, ARRAY[1.0], 3);");
        let prefix = "SELECT * FROM t ORDER BY vector_distance(emb, ARRAY[1.0]) LIMIT 3";
        assert!(actual.starts_with(prefix));
    }
}
168}