nodedb_sql/parser/preprocess/
vector_ops.rs1use super::lex::{SqlSegment, find_operator_positions, has_operator_outside_literals, segments};
7
8pub(super) fn rewrite_arrow_distance(sql: &str) -> Option<String> {
14 rewrite_binary_op(sql, "<->", "vector_distance")
15}
16
17fn extract_left_operand(before: &str) -> Option<String> {
19 let text_before = last_text_segment(before);
22 let trimmed = text_before.trim_end();
23 let start = trimmed
24 .rfind(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '.')
25 .map(|p| p + 1)
26 .unwrap_or(0);
27 let ident = &trimmed[start..];
28 if ident.is_empty() {
29 return None;
30 }
31 Some(ident.to_string())
35}
36
37fn last_text_segment(sql: &str) -> &str {
40 let segs = segments(sql);
41 for seg in segs.iter().rev() {
42 if let SqlSegment::Text(t) = seg {
43 return t;
44 }
45 }
46 ""
47}
48
/// Extracts the right operand of a binary operator from the SQL that
/// follows it.
///
/// Leading whitespace is skipped, then one of three shapes is recognised,
/// checked in this order:
///
/// 1. `ARRAY[...]` (keyword matched ASCII-case-insensitively): everything up
///    to the `]` that balances the opening bracket, nested brackets included.
///    `None` if the brackets never balance.
/// 2. A token starting with `$`: the run of ASCII alphanumerics, `_` and `$`.
/// 3. Anything else: the run of ASCII alphanumerics, `_` and `.`; `None` if
///    that run is empty.
///
/// On success returns the operand text together with its length in bytes.
/// The length is measured from the first non-whitespace byte of `after`, so
/// it does NOT include the skipped leading whitespace.
fn extract_right_operand(after: &str) -> Option<(String, usize)> {
    const ARRAY_PREFIX: &[u8] = b"ARRAY[";

    let trimmed = after.trim_start();

    // Compare only the prefix bytes, ignoring ASCII case, rather than
    // upper-casing (and allocating) a copy of the entire remaining statement.
    let is_array_literal = trimmed
        .as_bytes()
        .get(..ARRAY_PREFIX.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(ARRAY_PREFIX));

    if is_array_literal {
        // The first bracket seen is the `[` of the prefix, so `depth` is at
        // least 1 whenever a `]` is reached and the decrement cannot wrap.
        let mut depth = 0usize;
        for (i, c) in trimmed.char_indices() {
            match c {
                '[' => depth += 1,
                ']' => {
                    depth -= 1;
                    if depth == 0 {
                        // `]` is a single byte, so `i + 1` is the byte length.
                        return Some((trimmed[..=i].to_string(), i + 1));
                    }
                }
                _ => {}
            }
        }
        return None;
    }

    // Byte length of the leading run of ASCII alphanumerics, `_` and one
    // additional permitted character.
    let token_len = |extra: char| {
        trimmed
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_' || c == extra))
            .unwrap_or(trimmed.len())
    };

    if trimmed.starts_with('$') {
        let end = token_len('$');
        Some((trimmed[..end].to_string(), end))
    } else {
        let end = token_len('.');
        if end == 0 {
            return None;
        }
        Some((trimmed[..end].to_string(), end))
    }
}
85
86pub(super) fn rewrite_cosine_distance(sql: &str) -> Option<String> {
92 rewrite_binary_op(sql, "<=>", "vector_cosine_distance")
93}
94
95pub(super) fn rewrite_neg_inner_product(sql: &str) -> Option<String> {
101 rewrite_binary_op(sql, "<#>", "vector_neg_inner_product")
102}
103
104fn rewrite_binary_op(sql: &str, op: &str, func_name: &str) -> Option<String> {
107 if !has_operator_outside_literals(sql, op) {
108 return None;
109 }
110
111 let positions = find_operator_positions(sql, op);
112 if positions.is_empty() {
113 return None;
114 }
115
116 let mut result = String::with_capacity(sql.len());
117 let mut consumed = 0usize;
118 let mut found = false;
119
120 for op_pos in positions {
121 if op_pos < consumed {
122 continue;
123 }
124
125 let before = &sql[consumed..op_pos];
126 let left = extract_left_operand(before)?;
127 let trimmed_end_len = before.trim_end().len();
134 let left_start = consumed + (trimmed_end_len - left.len());
135
136 let after_op = &sql[op_pos + op.len()..];
137 let (right, right_len) = extract_right_operand(after_op.trim_start())?;
138 let ws_skip = after_op.len() - after_op.trim_start().len();
139
140 result.push_str(&sql[consumed..left_start]);
141 result.push_str(&format!("{func_name}({left}, {right})"));
142 consumed = op_pos + op.len() + ws_skip + right_len;
143 found = true;
144 }
145
146 if !found {
147 return None;
148 }
149
150 result.push_str(&sql[consumed..]);
151 Some(result)
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    // ── `<->` → vector_distance ─────────────────────────────────────────

    /// SQL that does not contain the operator yields no rewrite.
    #[test]
    fn no_match_returns_none() {
        assert_eq!(rewrite_arrow_distance("SELECT * FROM t"), None);
    }

    /// The operator inside a single-quoted string is not rewritten.
    #[test]
    fn arrow_in_string_literal_ignored() {
        assert_eq!(rewrite_arrow_distance("SELECT '<->'"), None);
    }

    /// The operator inside a `--` comment is not rewritten.
    #[test]
    fn arrow_in_line_comment_ignored() {
        let sql = "SELECT col -- has <-> in comment\nFROM t";
        assert_eq!(rewrite_arrow_distance(sql), None);
    }

    /// The operator inside a `/* */` comment is not rewritten.
    #[test]
    fn arrow_in_block_comment_ignored() {
        assert_eq!(rewrite_arrow_distance("SELECT /* <-> */ x"), None);
    }

    /// The operator inside a double-quoted identifier is not rewritten.
    #[test]
    fn arrow_in_quoted_ident_ignored() {
        assert_eq!(rewrite_arrow_distance(r#"SELECT "col_<->""#), None);
    }

    /// The operator inside a nested block comment is not rewritten.
    #[test]
    fn nested_block_comment_arrow_ignored() {
        let sql = "SELECT /* /* nested */ <-> */ x";
        assert_eq!(rewrite_arrow_distance(sql), None);
    }

    /// `column <-> ARRAY[...]` becomes a `vector_distance` call.
    #[test]
    fn basic_array_rewrite() {
        let sql =
            "SELECT title FROM articles ORDER BY embedding <-> ARRAY[0.1, 0.2, 0.3] LIMIT 5";
        let rewritten = rewrite_arrow_distance(sql).unwrap();
        let expected_call = "vector_distance(embedding, ARRAY[0.1, 0.2, 0.3])";
        assert!(rewritten.contains(expected_call), "got: {rewritten}");
        assert!(!rewritten.contains("<->"));
    }

    /// `column <-> $1` becomes a `vector_distance` call with the parameter.
    #[test]
    fn rewrite_with_param() {
        let sql = "SELECT * FROM docs WHERE embedding <-> $1 < 0.5";
        let rewritten = rewrite_arrow_distance(sql).unwrap();
        let expected_call = "vector_distance(embedding, $1)";
        assert!(rewritten.contains(expected_call), "got: {rewritten}");
    }

    // ── `<=>` → vector_cosine_distance ──────────────────────────────────

    /// `column <=> ARRAY[...]` becomes a `vector_cosine_distance` call.
    #[test]
    fn cosine_basic_array_rewrite() {
        let sql =
            "SELECT title FROM articles ORDER BY embedding <=> ARRAY[0.1, 0.2, 0.3] LIMIT 5";
        let rewritten = rewrite_cosine_distance(sql).unwrap();
        let expected_call = "vector_cosine_distance(embedding, ARRAY[0.1, 0.2, 0.3])";
        assert!(rewritten.contains(expected_call), "got: {rewritten}");
        assert!(!rewritten.contains("<=>"));
    }

    /// The operator inside a single-quoted string is not rewritten.
    #[test]
    fn cosine_in_string_literal_ignored() {
        assert_eq!(rewrite_cosine_distance("SELECT '<=>'"), None);
    }

    /// The operator inside a `--` comment is not rewritten.
    #[test]
    fn cosine_in_line_comment_ignored() {
        let sql = "SELECT col -- has <=> in comment\nFROM t";
        assert_eq!(rewrite_cosine_distance(sql), None);
    }

    /// The operator inside a double-quoted identifier is not rewritten.
    #[test]
    fn cosine_in_quoted_ident_ignored() {
        assert_eq!(rewrite_cosine_distance(r#"SELECT "col_<=>""#), None);
    }

    /// SQL that does not contain the operator yields no rewrite.
    #[test]
    fn cosine_no_match_returns_none() {
        assert_eq!(rewrite_cosine_distance("SELECT * FROM t"), None);
    }

    // ── `<#>` → vector_neg_inner_product ────────────────────────────────

    /// `column <#> ARRAY[...]` becomes a `vector_neg_inner_product` call.
    #[test]
    fn nip_basic_array_rewrite() {
        let sql =
            "SELECT title FROM articles ORDER BY embedding <#> ARRAY[0.1, 0.2, 0.3] LIMIT 5";
        let rewritten = rewrite_neg_inner_product(sql).unwrap();
        let expected_call = "vector_neg_inner_product(embedding, ARRAY[0.1, 0.2, 0.3])";
        assert!(rewritten.contains(expected_call), "got: {rewritten}");
        assert!(!rewritten.contains("<#>"));
    }

    /// The operator inside a single-quoted string is not rewritten.
    #[test]
    fn nip_in_string_literal_ignored() {
        assert_eq!(rewrite_neg_inner_product("SELECT '<#>'"), None);
    }

    /// The operator inside a `--` comment is not rewritten.
    #[test]
    fn nip_in_line_comment_ignored() {
        let sql = "SELECT col -- has <#> in comment\nFROM t";
        assert_eq!(rewrite_neg_inner_product(sql), None);
    }

    /// The operator inside a double-quoted identifier is not rewritten.
    #[test]
    fn nip_in_quoted_ident_ignored() {
        assert_eq!(rewrite_neg_inner_product(r#"SELECT "col_<#>""#), None);
    }

    /// SQL that does not contain the operator yields no rewrite.
    #[test]
    fn nip_no_match_returns_none() {
        assert_eq!(rewrite_neg_inner_product("SELECT * FROM t"), None);
    }
}