1use super::error::ParseError;
4use super::Parser;
5use crate::ast::{QueryExpr, VectorQuery, VectorSource};
6use crate::lexer::Token;
7use reddb_types::distance::DistanceMetric;
8use reddb_types::vector_metadata::{MetadataFilter, MetadataValue};
9
10impl<'a> Parser<'a> {
11 pub fn parse_vector_query(&mut self) -> Result<QueryExpr, ParseError> {
24 self.expect(Token::Vector)?;
25 self.expect(Token::Search)?;
26
27 let collection = self.expect_ident()?;
29
30 self.expect(Token::Similar)?;
32 self.expect(Token::To)?;
33
34 let query_vector = self.parse_vector_source()?;
35
36 let mut filter: Option<MetadataFilter> = None;
38 let mut metric: Option<DistanceMetric> = None;
39 let mut threshold: Option<f32> = None;
40 let mut include_vectors = false;
41 let mut include_metadata = false;
42 let mut k: usize = 10; loop {
46 if self.consume(&Token::Where)? {
47 filter = Some(self.parse_metadata_filter()?);
48 } else if self.consume(&Token::Metric)? {
49 metric = Some(self.parse_distance_metric()?);
50 } else if self.consume(&Token::Threshold)? {
51 threshold = Some(self.parse_float()? as f32);
52 } else if self.consume(&Token::Include)? {
53 if self.consume(&Token::Vectors)? {
54 include_vectors = true;
55 } else if self.consume(&Token::Metadata)? {
56 include_metadata = true;
57 } else {
58 return Err(ParseError::expected(
59 vec!["VECTORS", "METADATA"],
60 self.peek(),
61 self.position(),
62 ));
63 }
64 } else if self.consume(&Token::Limit)? {
65 k = self.parse_integer()? as usize;
66 } else if self.consume(&Token::K)? {
67 self.expect(Token::Eq)?;
69 k = self.parse_integer()? as usize;
70 } else {
71 break;
72 }
73 }
74
75 Ok(QueryExpr::Vector(VectorQuery {
76 alias: None,
77 collection,
78 query_vector,
79 k,
80 filter,
81 metric,
82 include_vectors,
83 include_metadata,
84 threshold,
85 }))
86 }
87
88 pub fn parse_vector_source(&mut self) -> Result<VectorSource, ParseError> {
90 match self.peek() {
91 Token::LBracket => {
93 self.advance()?;
94 let mut values = Vec::new();
95 loop {
96 let value = self.parse_float()?;
97 values.push(value as f32);
98 if !self.consume(&Token::Comma)? {
99 break;
100 }
101 }
102 self.expect(Token::RBracket)?;
103 Ok(VectorSource::Literal(values))
104 }
105 Token::String(_) => {
107 let text = self.parse_string()?;
108 Ok(VectorSource::Text(text))
109 }
110 Token::LParen => {
112 self.advance()?;
113 if self.vector_source_starts_subquery() {
114 let expr = self.parse_query_expr()?;
115 self.expect(Token::RParen)?;
116 Ok(VectorSource::Subquery(Box::new(expr)))
117 } else {
118 let collection = self.expect_ident()?;
120 self.expect(Token::Comma)?;
121 let vector_id = self.parse_integer()? as u64;
122 self.expect(Token::RParen)?;
123 Ok(VectorSource::Reference {
124 collection,
125 vector_id,
126 })
127 }
128 }
129 Token::Ident(_) => {
131 let name = self.expect_ident()?;
132 if self.consume(&Token::LParen)? {
134 let vector_id = self.parse_integer()? as u64;
135 self.expect(Token::RParen)?;
136 Ok(VectorSource::Reference {
137 collection: name,
138 vector_id,
139 })
140 } else {
141 Ok(VectorSource::Text(name))
143 }
144 }
145 other => Err(ParseError::expected(
146 vec!["vector literal [...]", "string", "reference"],
147 other,
148 self.position(),
149 )),
150 }
151 }
152
153 fn vector_source_starts_subquery(&self) -> bool {
154 matches!(
155 self.peek(),
156 Token::Select
157 | Token::Match
158 | Token::Path
159 | Token::From
160 | Token::Vector
161 | Token::Hybrid
162 )
163 }
164
165 pub fn parse_metadata_filter(&mut self) -> Result<MetadataFilter, ParseError> {
167 self.parse_metadata_or_expr()
168 }
169
170 fn parse_metadata_or_expr(&mut self) -> Result<MetadataFilter, ParseError> {
172 let mut left = self.parse_metadata_and_expr()?;
173
174 while self.consume(&Token::Or)? {
175 let right = self.parse_metadata_and_expr()?;
176 left = MetadataFilter::Or(vec![left, right]);
177 }
178
179 Ok(left)
180 }
181
182 fn parse_metadata_and_expr(&mut self) -> Result<MetadataFilter, ParseError> {
184 let mut left = self.parse_metadata_primary()?;
185
186 while self.consume(&Token::And)? {
187 let right = self.parse_metadata_primary()?;
188 left = MetadataFilter::And(vec![left, right]);
189 }
190
191 Ok(left)
192 }
193
194 fn parse_metadata_primary(&mut self) -> Result<MetadataFilter, ParseError> {
196 if self.consume(&Token::LParen)? {
198 let expr = self.parse_metadata_filter()?;
199 self.expect(Token::RParen)?;
200 return Ok(expr);
201 }
202
203 let field = self.expect_ident()?;
205
206 if self.consume(&Token::Eq)? {
208 let value = self.parse_metadata_value()?;
209 Ok(MetadataFilter::Eq(field, value))
210 } else if self.consume(&Token::Ne)? {
211 let value = self.parse_metadata_value()?;
212 Ok(MetadataFilter::Ne(field, value))
213 } else if self.consume(&Token::Lt)? {
214 let value = self.parse_metadata_value()?;
215 Ok(MetadataFilter::Lt(field, value))
216 } else if self.consume(&Token::Le)? {
217 let value = self.parse_metadata_value()?;
218 Ok(MetadataFilter::Lte(field, value))
219 } else if self.consume(&Token::Gt)? {
220 let value = self.parse_metadata_value()?;
221 Ok(MetadataFilter::Gt(field, value))
222 } else if self.consume(&Token::Ge)? {
223 let value = self.parse_metadata_value()?;
224 Ok(MetadataFilter::Gte(field, value))
225 } else if self.consume(&Token::In)? {
226 self.expect(Token::LParen)?;
227 let values = self.parse_metadata_value_list()?;
228 self.expect(Token::RParen)?;
229 Ok(MetadataFilter::In(field, values))
230 } else if self.consume(&Token::Not)? {
231 self.expect(Token::In)?;
232 self.expect(Token::LParen)?;
233 let values = self.parse_metadata_value_list()?;
234 self.expect(Token::RParen)?;
235 Ok(MetadataFilter::NotIn(field, values))
236 } else if self.consume(&Token::Contains)? {
237 let value = self.parse_string()?;
238 Ok(MetadataFilter::Contains(field, value))
239 } else {
240 Err(ParseError::expected(
241 vec!["=", "<>", "<", "<=", ">", ">=", "IN", "NOT IN", "CONTAINS"],
242 self.peek(),
243 self.position(),
244 ))
245 }
246 }
247
248 fn parse_metadata_value(&mut self) -> Result<MetadataValue, ParseError> {
250 match self.peek() {
251 Token::String(_) => {
252 let s = self.parse_string()?;
253 Ok(MetadataValue::String(s))
254 }
255 Token::Integer(_) => {
256 let n = self.parse_integer()?;
257 Ok(MetadataValue::Integer(n))
258 }
259 Token::Float(_) => {
260 let n = self.parse_float()?;
261 Ok(MetadataValue::Float(n))
262 }
263 Token::True => {
264 self.advance()?;
265 Ok(MetadataValue::Bool(true))
266 }
267 Token::False => {
268 self.advance()?;
269 Ok(MetadataValue::Bool(false))
270 }
271 other => Err(ParseError::expected(
272 vec!["string", "number", "true", "false"],
273 other,
274 self.position(),
275 )),
276 }
277 }
278
279 fn parse_metadata_value_list(&mut self) -> Result<Vec<MetadataValue>, ParseError> {
281 let mut values = Vec::new();
282 loop {
283 values.push(self.parse_metadata_value()?);
284 if !self.consume(&Token::Comma)? {
285 break;
286 }
287 }
288 Ok(values)
289 }
290
291 pub fn parse_distance_metric(&mut self) -> Result<DistanceMetric, ParseError> {
293 match self.peek() {
294 Token::L2 => {
295 self.advance()?;
296 Ok(DistanceMetric::L2)
297 }
298 Token::Cosine => {
299 self.advance()?;
300 Ok(DistanceMetric::Cosine)
301 }
302 Token::InnerProduct => {
303 self.advance()?;
304 Ok(DistanceMetric::InnerProduct)
305 }
306 Token::Ident(name) => {
307 let name_upper = name.to_uppercase();
308 let name_clone = name.clone();
309 self.advance()?;
310 match name_upper.as_str() {
311 "L2" | "EUCLIDEAN" => Ok(DistanceMetric::L2),
312 "COSINE" | "COS" => Ok(DistanceMetric::Cosine),
313 "INNER_PRODUCT" | "IP" | "DOT" => Ok(DistanceMetric::InnerProduct),
314 _ => Err(ParseError::new(
315 format!(
316 "Unknown distance metric: {}. Valid: L2, COSINE, INNER_PRODUCT",
317 name_clone
318 ),
319 self.position(),
320 )),
321 }
322 }
323 other => Err(ParseError::expected(
324 vec!["L2", "COSINE", "INNER_PRODUCT"],
325 other,
326 self.position(),
327 )),
328 }
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the full query parser and keeps only the parsed query expression.
    fn parse_query(input: &str) -> Result<QueryExpr, ParseError> {
        crate::parser::parse(input).map(|parsed| parsed.query)
    }

    /// Parses `input` and panics unless it succeeds and yields a vector query.
    fn expect_vector(input: &str) -> VectorQuery {
        let parsed = parse_query(input).unwrap();
        let QueryExpr::Vector(vector) = parsed else {
            panic!("expected vector query");
        };
        vector
    }

    /// Asserts that `source` is a stored-vector reference to `expected_collection` / `expected_id`.
    fn assert_reference(source: VectorSource, expected_collection: &str, expected_id: u64) {
        assert!(matches!(
            source,
            VectorSource::Reference {
                collection,
                vector_id,
            } if collection == expected_collection && vector_id == expected_id
        ));
    }

    #[test]
    fn vector_query_uses_defaults_for_bare_identifier_source() {
        let vector = expect_vector("VECTOR SEARCH embeddings SIMILAR TO nearest_neighbor");

        // Only the mandatory parts are present, so every optional clause keeps its default.
        assert_eq!(vector.collection, "embeddings");
        assert_eq!(vector.k, 10);
        assert!(vector.filter.is_none());
        assert_eq!(vector.metric, None);
        assert_eq!(vector.threshold, None);
        assert!(!vector.include_vectors);
        assert!(!vector.include_metadata);

        // An identifier not followed by `(` is taken as query text.
        assert!(matches!(
            vector.query_vector,
            VectorSource::Text(text) if text == "nearest_neighbor"
        ));
    }

    #[test]
    fn vector_query_parses_reference_sources_and_k_alias() {
        // Call-style reference plus the `K =` spelling of the result count.
        let call_style =
            expect_vector("VECTOR SEARCH embeddings SIMILAR TO docs(42) INCLUDE METADATA K = 7");
        assert_eq!(call_style.k, 7);
        assert!(call_style.include_metadata);
        assert_reference(call_style.query_vector, "docs", 42);

        // Tuple-style reference plus LIMIT.
        let tuple_style =
            expect_vector("VECTOR SEARCH embeddings SIMILAR TO (archive, 99) LIMIT 4");
        assert_eq!(tuple_style.k, 4);
        assert_reference(tuple_style.query_vector, "archive", 99);
    }

    #[test]
    fn vector_query_parses_subquery_source() {
        let vector =
            expect_vector("VECTOR SEARCH docs SIMILAR TO (SELECT id FROM seeds) LIMIT 2");

        assert_eq!(vector.collection, "docs");
        assert_eq!(vector.k, 2);

        let subquery = match vector.query_vector {
            VectorSource::Subquery(expr) => *expr,
            other => panic!("expected subquery source, got {other:?}"),
        };
        match subquery {
            QueryExpr::Table(table) => assert_eq!(table.table, "seeds"),
            other => panic!("expected table subquery, got {other:?}"),
        }
    }

    #[test]
    fn vector_query_parses_filter_sets_metric_threshold_and_includes() {
        let vector = expect_vector(
            "VECTOR SEARCH docs SIMILAR TO [0.1, 0.2] \
             WHERE (source IN ('nmap', 'nessus') OR severity NOT IN (1, 2)) \
             AND archived = false METRIC DOT THRESHOLD 0.25 INCLUDE VECTORS LIMIT 3",
        );

        assert_eq!(vector.k, 3);
        assert_eq!(vector.metric, Some(DistanceMetric::InnerProduct));
        assert_eq!(vector.threshold, Some(0.25));
        assert!(vector.include_vectors);
        assert!(
            matches!(vector.query_vector, VectorSource::Literal(values) if values == vec![0.1, 0.2])
        );

        // `( ... OR ... ) AND archived = false`: the parenthesised OR is the first AND operand.
        let Some(MetadataFilter::And(and_parts)) = vector.filter else {
            panic!("expected AND filter");
        };
        assert_eq!(and_parts.len(), 2);

        let or_parts = match &and_parts[0] {
            MetadataFilter::Or(parts) => parts,
            other => panic!("expected OR filter, got {other:?}"),
        };
        assert_eq!(or_parts.len(), 2);

        let expected_sources = vec![
            MetadataValue::String("nmap".to_string()),
            MetadataValue::String("nessus".to_string()),
        ];
        assert!(matches!(
            &or_parts[0],
            MetadataFilter::In(field, values)
                if field == "source" && values == &expected_sources
        ));

        let expected_severities = vec![MetadataValue::Integer(1), MetadataValue::Integer(2)];
        assert!(matches!(
            &or_parts[1],
            MetadataFilter::NotIn(field, values)
                if field == "severity" && values == &expected_severities
        ));

        assert!(matches!(
            &and_parts[1],
            MetadataFilter::Eq(field, MetadataValue::Bool(false)) if field == "archived"
        ));
    }

    #[test]
    fn metadata_filter_parses_comparisons_and_contains() {
        let vector = expect_vector(
            "VECTOR SEARCH docs SIMILAR TO [0.3] \
             WHERE score < 0.7 OR rank >= 10 AND title CONTAINS 'redis'",
        );

        // AND binds tighter than OR: `score < 0.7 OR (rank >= 10 AND title CONTAINS 'redis')`.
        let Some(MetadataFilter::Or(or_parts)) = vector.filter else {
            panic!("expected OR filter");
        };
        assert_eq!(or_parts.len(), 2);
        assert!(matches!(
            &or_parts[0],
            MetadataFilter::Lt(field, MetadataValue::Float(value))
                if field == "score" && (*value - 0.7).abs() < f64::EPSILON
        ));

        let and_parts = match &or_parts[1] {
            MetadataFilter::And(parts) => parts,
            other => panic!("expected AND filter, got {other:?}"),
        };
        assert_eq!(and_parts.len(), 2);
        assert!(matches!(
            &and_parts[0],
            MetadataFilter::Gte(field, MetadataValue::Integer(10)) if field == "rank"
        ));
        assert!(matches!(
            &and_parts[1],
            MetadataFilter::Contains(field, value)
                if field == "title" && value == "redis"
        ));
    }

    #[test]
    fn vector_parser_reports_malformed_queries() {
        const MALFORMED: [&str; 5] = [
            "VECTOR SEARCH docs SIMILAR TO []",
            "VECTOR SEARCH docs SIMILAR TO [0.1] INCLUDE SCORES",
            "VECTOR SEARCH docs SIMILAR TO [0.1] METRIC MANHATTAN",
            "VECTOR SEARCH docs SIMILAR TO [0.1] WHERE source",
            "VECTOR SEARCH docs SIMILAR TO (docs)",
        ];

        for sql in MALFORMED {
            let outcome = parse_query(sql);
            assert!(outcome.is_err(), "{sql} should not parse");
        }
    }
}