Skip to main content

reddb_rql/parser/
vector.rs

1//! Vector query parsing (VECTOR SEARCH ... SIMILAR TO ...)
2
3use super::error::ParseError;
4use super::Parser;
5use crate::ast::{QueryExpr, VectorQuery, VectorSource};
6use crate::lexer::Token;
7use reddb_types::distance::DistanceMetric;
8use reddb_types::vector_metadata::{MetadataFilter, MetadataValue};
9
10impl<'a> Parser<'a> {
11    /// Parse VECTOR SEARCH ... SIMILAR TO ... query
12    ///
13    /// Syntax:
14    /// ```text
15    /// VECTOR SEARCH collection
16    /// SIMILAR TO [0.1, 0.2, ...] | 'text query' | (subquery)
17    /// [WHERE metadata conditions]
18    /// [METRIC L2|COSINE|INNER_PRODUCT]
19    /// [THRESHOLD 0.5]
20    /// [INCLUDE VECTORS] [INCLUDE METADATA]
21    /// [LIMIT k]
22    /// ```
23    pub fn parse_vector_query(&mut self) -> Result<QueryExpr, ParseError> {
24        self.expect(Token::Vector)?;
25        self.expect(Token::Search)?;
26
27        // Collection name
28        let collection = self.expect_ident()?;
29
30        // SIMILAR TO clause
31        self.expect(Token::Similar)?;
32        self.expect(Token::To)?;
33
34        let query_vector = self.parse_vector_source()?;
35
36        // Parse optional clauses
37        let mut filter: Option<MetadataFilter> = None;
38        let mut metric: Option<DistanceMetric> = None;
39        let mut threshold: Option<f32> = None;
40        let mut include_vectors = false;
41        let mut include_metadata = false;
42        let mut k: usize = 10; // Default
43
44        // Parse optional clauses in any order
45        loop {
46            if self.consume(&Token::Where)? {
47                filter = Some(self.parse_metadata_filter()?);
48            } else if self.consume(&Token::Metric)? {
49                metric = Some(self.parse_distance_metric()?);
50            } else if self.consume(&Token::Threshold)? {
51                threshold = Some(self.parse_float()? as f32);
52            } else if self.consume(&Token::Include)? {
53                if self.consume(&Token::Vectors)? {
54                    include_vectors = true;
55                } else if self.consume(&Token::Metadata)? {
56                    include_metadata = true;
57                } else {
58                    return Err(ParseError::expected(
59                        vec!["VECTORS", "METADATA"],
60                        self.peek(),
61                        self.position(),
62                    ));
63                }
64            } else if self.consume(&Token::Limit)? {
65                k = self.parse_integer()? as usize;
66            } else if self.consume(&Token::K)? {
67                // Alternative: K = 10
68                self.expect(Token::Eq)?;
69                k = self.parse_integer()? as usize;
70            } else {
71                break;
72            }
73        }
74
75        Ok(QueryExpr::Vector(VectorQuery {
76            alias: None,
77            collection,
78            query_vector,
79            k,
80            filter,
81            metric,
82            include_vectors,
83            include_metadata,
84            threshold,
85        }))
86    }
87
88    /// Parse vector source: literal array, text, reference, or subquery
89    pub fn parse_vector_source(&mut self) -> Result<VectorSource, ParseError> {
90        match self.peek() {
91            // Literal vector: [0.1, 0.2, 0.3]
92            Token::LBracket => {
93                self.advance()?;
94                let mut values = Vec::new();
95                loop {
96                    let value = self.parse_float()?;
97                    values.push(value as f32);
98                    if !self.consume(&Token::Comma)? {
99                        break;
100                    }
101                }
102                self.expect(Token::RBracket)?;
103                Ok(VectorSource::Literal(values))
104            }
105            // Text query: 'find similar vulnerabilities'
106            Token::String(_) => {
107                let text = self.parse_string()?;
108                Ok(VectorSource::Text(text))
109            }
110            // Parenthesized source: subquery or (collection, vector_id) reference
111            Token::LParen => {
112                self.advance()?;
113                if self.vector_source_starts_subquery() {
114                    let expr = self.parse_query_expr()?;
115                    self.expect(Token::RParen)?;
116                    Ok(VectorSource::Subquery(Box::new(expr)))
117                } else {
118                    // Reference: (collection, vector_id)
119                    let collection = self.expect_ident()?;
120                    self.expect(Token::Comma)?;
121                    let vector_id = self.parse_integer()? as u64;
122                    self.expect(Token::RParen)?;
123                    Ok(VectorSource::Reference {
124                        collection,
125                        vector_id,
126                    })
127                }
128            }
129            // Reference by name: embedding_name
130            Token::Ident(_) => {
131                let name = self.expect_ident()?;
132                // Check for (collection, id) format
133                if self.consume(&Token::LParen)? {
134                    let vector_id = self.parse_integer()? as u64;
135                    self.expect(Token::RParen)?;
136                    Ok(VectorSource::Reference {
137                        collection: name,
138                        vector_id,
139                    })
140                } else {
141                    // Just a name reference, treat as text
142                    Ok(VectorSource::Text(name))
143                }
144            }
145            other => Err(ParseError::expected(
146                vec!["vector literal [...]", "string", "reference"],
147                other,
148                self.position(),
149            )),
150        }
151    }
152
153    fn vector_source_starts_subquery(&self) -> bool {
154        matches!(
155            self.peek(),
156            Token::Select
157                | Token::Match
158                | Token::Path
159                | Token::From
160                | Token::Vector
161                | Token::Hybrid
162        )
163    }
164
165    /// Parse metadata filter for vector queries
166    pub fn parse_metadata_filter(&mut self) -> Result<MetadataFilter, ParseError> {
167        self.parse_metadata_or_expr()
168    }
169
170    /// Parse OR expression in metadata filter
171    fn parse_metadata_or_expr(&mut self) -> Result<MetadataFilter, ParseError> {
172        let mut left = self.parse_metadata_and_expr()?;
173
174        while self.consume(&Token::Or)? {
175            let right = self.parse_metadata_and_expr()?;
176            left = MetadataFilter::Or(vec![left, right]);
177        }
178
179        Ok(left)
180    }
181
182    /// Parse AND expression in metadata filter
183    fn parse_metadata_and_expr(&mut self) -> Result<MetadataFilter, ParseError> {
184        let mut left = self.parse_metadata_primary()?;
185
186        while self.consume(&Token::And)? {
187            let right = self.parse_metadata_primary()?;
188            left = MetadataFilter::And(vec![left, right]);
189        }
190
191        Ok(left)
192    }
193
194    /// Parse primary metadata filter
195    fn parse_metadata_primary(&mut self) -> Result<MetadataFilter, ParseError> {
196        // Parenthesized expression
197        if self.consume(&Token::LParen)? {
198            let expr = self.parse_metadata_filter()?;
199            self.expect(Token::RParen)?;
200            return Ok(expr);
201        }
202
203        // field op value
204        let field = self.expect_ident()?;
205
206        // Handle different operators
207        if self.consume(&Token::Eq)? {
208            let value = self.parse_metadata_value()?;
209            Ok(MetadataFilter::Eq(field, value))
210        } else if self.consume(&Token::Ne)? {
211            let value = self.parse_metadata_value()?;
212            Ok(MetadataFilter::Ne(field, value))
213        } else if self.consume(&Token::Lt)? {
214            let value = self.parse_metadata_value()?;
215            Ok(MetadataFilter::Lt(field, value))
216        } else if self.consume(&Token::Le)? {
217            let value = self.parse_metadata_value()?;
218            Ok(MetadataFilter::Lte(field, value))
219        } else if self.consume(&Token::Gt)? {
220            let value = self.parse_metadata_value()?;
221            Ok(MetadataFilter::Gt(field, value))
222        } else if self.consume(&Token::Ge)? {
223            let value = self.parse_metadata_value()?;
224            Ok(MetadataFilter::Gte(field, value))
225        } else if self.consume(&Token::In)? {
226            self.expect(Token::LParen)?;
227            let values = self.parse_metadata_value_list()?;
228            self.expect(Token::RParen)?;
229            Ok(MetadataFilter::In(field, values))
230        } else if self.consume(&Token::Not)? {
231            self.expect(Token::In)?;
232            self.expect(Token::LParen)?;
233            let values = self.parse_metadata_value_list()?;
234            self.expect(Token::RParen)?;
235            Ok(MetadataFilter::NotIn(field, values))
236        } else if self.consume(&Token::Contains)? {
237            let value = self.parse_string()?;
238            Ok(MetadataFilter::Contains(field, value))
239        } else {
240            Err(ParseError::expected(
241                vec!["=", "<>", "<", "<=", ">", ">=", "IN", "NOT IN", "CONTAINS"],
242                self.peek(),
243                self.position(),
244            ))
245        }
246    }
247
248    /// Parse metadata value
249    fn parse_metadata_value(&mut self) -> Result<MetadataValue, ParseError> {
250        match self.peek() {
251            Token::String(_) => {
252                let s = self.parse_string()?;
253                Ok(MetadataValue::String(s))
254            }
255            Token::Integer(_) => {
256                let n = self.parse_integer()?;
257                Ok(MetadataValue::Integer(n))
258            }
259            Token::Float(_) => {
260                let n = self.parse_float()?;
261                Ok(MetadataValue::Float(n))
262            }
263            Token::True => {
264                self.advance()?;
265                Ok(MetadataValue::Bool(true))
266            }
267            Token::False => {
268                self.advance()?;
269                Ok(MetadataValue::Bool(false))
270            }
271            other => Err(ParseError::expected(
272                vec!["string", "number", "true", "false"],
273                other,
274                self.position(),
275            )),
276        }
277    }
278
279    /// Parse list of metadata values
280    fn parse_metadata_value_list(&mut self) -> Result<Vec<MetadataValue>, ParseError> {
281        let mut values = Vec::new();
282        loop {
283            values.push(self.parse_metadata_value()?);
284            if !self.consume(&Token::Comma)? {
285                break;
286            }
287        }
288        Ok(values)
289    }
290
291    /// Parse distance metric
292    pub fn parse_distance_metric(&mut self) -> Result<DistanceMetric, ParseError> {
293        match self.peek() {
294            Token::L2 => {
295                self.advance()?;
296                Ok(DistanceMetric::L2)
297            }
298            Token::Cosine => {
299                self.advance()?;
300                Ok(DistanceMetric::Cosine)
301            }
302            Token::InnerProduct => {
303                self.advance()?;
304                Ok(DistanceMetric::InnerProduct)
305            }
306            Token::Ident(name) => {
307                let name_upper = name.to_uppercase();
308                let name_clone = name.clone();
309                self.advance()?;
310                match name_upper.as_str() {
311                    "L2" | "EUCLIDEAN" => Ok(DistanceMetric::L2),
312                    "COSINE" | "COS" => Ok(DistanceMetric::Cosine),
313                    "INNER_PRODUCT" | "IP" | "DOT" => Ok(DistanceMetric::InnerProduct),
314                    _ => Err(ParseError::new(
315                        format!(
316                            "Unknown distance metric: {}. Valid: L2, COSINE, INNER_PRODUCT",
317                            name_clone
318                        ),
319                        self.position(),
320                    )),
321                }
322            }
323            other => Err(ParseError::expected(
324                vec!["L2", "COSINE", "INNER_PRODUCT"],
325                other,
326                self.position(),
327            )),
328        }
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Run the crate-level parser and keep only the parsed query expression.
    fn parse_query(input: &str) -> Result<QueryExpr, ParseError> {
        crate::parser::parse(input).map(|parsed| parsed.query)
    }

    /// Parse `input`, requiring success and a top-level vector query.
    fn parse_vector(input: &str) -> VectorQuery {
        let parsed = parse_query(input).unwrap();
        let QueryExpr::Vector(vector) = parsed else {
            panic!("expected vector query");
        };
        vector
    }

    #[test]
    fn vector_query_uses_defaults_for_bare_identifier_source() {
        let vq = parse_vector("VECTOR SEARCH embeddings SIMILAR TO nearest_neighbor");

        assert_eq!(vq.collection, "embeddings");
        assert_eq!(vq.k, 10);
        assert!(vq.filter.is_none());
        assert_eq!(vq.metric, None);
        assert_eq!(vq.threshold, None);
        assert!(!vq.include_vectors);
        assert!(!vq.include_metadata);
        assert!(matches!(
            vq.query_vector,
            VectorSource::Text(text) if text == "nearest_neighbor"
        ));
    }

    #[test]
    fn vector_query_parses_reference_sources_and_k_alias() {
        // name(id) reference combined with the `K = n` spelling.
        let by_call =
            parse_vector("VECTOR SEARCH embeddings SIMILAR TO docs(42) INCLUDE METADATA K = 7");
        assert_eq!(by_call.k, 7);
        assert!(by_call.include_metadata);
        assert!(matches!(
            by_call.query_vector,
            VectorSource::Reference {
                collection,
                vector_id,
            } if collection == "docs" && vector_id == 42
        ));

        // (collection, id) tuple reference combined with LIMIT.
        let by_tuple = parse_vector("VECTOR SEARCH embeddings SIMILAR TO (archive, 99) LIMIT 4");
        assert_eq!(by_tuple.k, 4);
        assert!(matches!(
            by_tuple.query_vector,
            VectorSource::Reference {
                collection,
                vector_id,
            } if collection == "archive" && vector_id == 99
        ));
    }

    #[test]
    fn vector_query_parses_subquery_source() {
        let vq = parse_vector("VECTOR SEARCH docs SIMILAR TO (SELECT id FROM seeds) LIMIT 2");

        assert_eq!(vq.collection, "docs");
        assert_eq!(vq.k, 2);
        let inner = match vq.query_vector {
            VectorSource::Subquery(expr) => *expr,
            other => panic!("expected subquery source, got {other:?}"),
        };
        match inner {
            QueryExpr::Table(table) => assert_eq!(table.table, "seeds"),
            other => panic!("expected table subquery, got {other:?}"),
        }
    }

    #[test]
    fn vector_query_parses_filter_sets_metric_threshold_and_includes() {
        let vq = parse_vector(
            "VECTOR SEARCH docs SIMILAR TO [0.1, 0.2] \
             WHERE (source IN ('nmap', 'nessus') OR severity NOT IN (1, 2)) \
             AND archived = false METRIC DOT THRESHOLD 0.25 INCLUDE VECTORS LIMIT 3",
        );

        assert_eq!(vq.k, 3);
        assert_eq!(vq.metric, Some(DistanceMetric::InnerProduct));
        assert_eq!(vq.threshold, Some(0.25));
        assert!(vq.include_vectors);
        assert!(matches!(vq.query_vector, VectorSource::Literal(values) if values == vec![0.1, 0.2]));

        let Some(MetadataFilter::And(conjuncts)) = vq.filter else {
            panic!("expected AND filter");
        };
        assert_eq!(conjuncts.len(), 2);

        let disjuncts = match &conjuncts[0] {
            MetadataFilter::Or(parts) => parts,
            other => panic!("expected OR filter, got {other:?}"),
        };
        assert_eq!(disjuncts.len(), 2);
        let expected_sources = vec![
            MetadataValue::String("nmap".to_string()),
            MetadataValue::String("nessus".to_string()),
        ];
        assert!(matches!(
            &disjuncts[0],
            MetadataFilter::In(field, values)
                if field == "source" && values == &expected_sources
        ));
        let expected_severities = vec![MetadataValue::Integer(1), MetadataValue::Integer(2)];
        assert!(matches!(
            &disjuncts[1],
            MetadataFilter::NotIn(field, values)
                if field == "severity" && values == &expected_severities
        ));

        assert!(matches!(
            &conjuncts[1],
            MetadataFilter::Eq(field, MetadataValue::Bool(false)) if field == "archived"
        ));
    }

    #[test]
    fn metadata_filter_parses_comparisons_and_contains() {
        let vq = parse_vector(
            "VECTOR SEARCH docs SIMILAR TO [0.3] \
             WHERE score < 0.7 OR rank >= 10 AND title CONTAINS 'redis'",
        );

        let Some(MetadataFilter::Or(disjuncts)) = vq.filter else {
            panic!("expected OR filter");
        };
        assert_eq!(disjuncts.len(), 2);
        assert!(matches!(
            &disjuncts[0],
            MetadataFilter::Lt(field, MetadataValue::Float(value))
                if field == "score" && (*value - 0.7).abs() < f64::EPSILON
        ));

        let conjuncts = match &disjuncts[1] {
            MetadataFilter::And(parts) => parts,
            other => panic!("expected AND filter, got {other:?}"),
        };
        assert_eq!(conjuncts.len(), 2);
        assert!(matches!(
            &conjuncts[0],
            MetadataFilter::Gte(field, MetadataValue::Integer(10)) if field == "rank"
        ));
        assert!(matches!(
            &conjuncts[1],
            MetadataFilter::Contains(field, value) if field == "title" && value == "redis"
        ));
    }

    #[test]
    fn vector_parser_reports_malformed_queries() {
        let malformed = [
            "VECTOR SEARCH docs SIMILAR TO []",
            "VECTOR SEARCH docs SIMILAR TO [0.1] INCLUDE SCORES",
            "VECTOR SEARCH docs SIMILAR TO [0.1] METRIC MANHATTAN",
            "VECTOR SEARCH docs SIMILAR TO [0.1] WHERE source",
            "VECTOR SEARCH docs SIMILAR TO (docs)",
        ];
        malformed.iter().for_each(|sql| {
            assert!(parse_query(sql).is_err(), "{sql} should not parse");
        });
    }
}