Skip to main content

reddb_server/storage/query/parser/
vector.rs

1//! Vector query parsing (VECTOR SEARCH ... SIMILAR TO ...)
2
3use super::super::ast::{QueryExpr, VectorQuery, VectorSource};
4use super::super::lexer::Token;
5use super::error::ParseError;
6use super::Parser;
7use crate::storage::engine::distance::DistanceMetric;
8use crate::storage::engine::vector_metadata::{MetadataFilter, MetadataValue};
9
10impl<'a> Parser<'a> {
11    /// Parse VECTOR SEARCH ... SIMILAR TO ... query
12    ///
13    /// Syntax:
14    /// ```text
15    /// VECTOR SEARCH collection
16    /// SIMILAR TO [0.1, 0.2, ...] | 'text query' | (subquery)
17    /// [WHERE metadata conditions]
18    /// [METRIC L2|COSINE|INNER_PRODUCT]
19    /// [THRESHOLD 0.5]
20    /// [INCLUDE VECTORS] [INCLUDE METADATA]
21    /// [LIMIT k]
22    /// ```
23    pub fn parse_vector_query(&mut self) -> Result<QueryExpr, ParseError> {
24        self.expect(Token::Vector)?;
25        self.expect(Token::Search)?;
26
27        // Collection name
28        let collection = self.expect_ident()?;
29
30        // SIMILAR TO clause
31        self.expect(Token::Similar)?;
32        self.expect(Token::To)?;
33
34        let query_vector = self.parse_vector_source()?;
35
36        // Parse optional clauses
37        let mut filter: Option<MetadataFilter> = None;
38        let mut metric: Option<DistanceMetric> = None;
39        let mut threshold: Option<f32> = None;
40        let mut include_vectors = false;
41        let mut include_metadata = false;
42        let mut k: usize = 10; // Default
43
44        // Parse optional clauses in any order
45        loop {
46            if self.consume(&Token::Where)? {
47                filter = Some(self.parse_metadata_filter()?);
48            } else if self.consume(&Token::Metric)? {
49                metric = Some(self.parse_distance_metric()?);
50            } else if self.consume(&Token::Threshold)? {
51                threshold = Some(self.parse_float()? as f32);
52            } else if self.consume(&Token::Include)? {
53                if self.consume(&Token::Vectors)? {
54                    include_vectors = true;
55                } else if self.consume(&Token::Metadata)? {
56                    include_metadata = true;
57                } else {
58                    return Err(ParseError::expected(
59                        vec!["VECTORS", "METADATA"],
60                        self.peek(),
61                        self.position(),
62                    ));
63                }
64            } else if self.consume(&Token::Limit)? {
65                k = self.parse_integer()? as usize;
66            } else if self.consume(&Token::K)? {
67                // Alternative: K = 10
68                self.expect(Token::Eq)?;
69                k = self.parse_integer()? as usize;
70            } else {
71                break;
72            }
73        }
74
75        Ok(QueryExpr::Vector(VectorQuery {
76            alias: None,
77            collection,
78            query_vector,
79            k,
80            filter,
81            metric,
82            include_vectors,
83            include_metadata,
84            threshold,
85        }))
86    }
87
88    /// Parse vector source: literal array, text, reference, or subquery
89    pub fn parse_vector_source(&mut self) -> Result<VectorSource, ParseError> {
90        match self.peek() {
91            // Literal vector: [0.1, 0.2, 0.3]
92            Token::LBracket => {
93                self.advance()?;
94                let mut values = Vec::new();
95                loop {
96                    let value = self.parse_float()?;
97                    values.push(value as f32);
98                    if !self.consume(&Token::Comma)? {
99                        break;
100                    }
101                }
102                self.expect(Token::RBracket)?;
103                Ok(VectorSource::Literal(values))
104            }
105            // Text query: 'find similar vulnerabilities'
106            Token::String(_) => {
107                let text = self.parse_string()?;
108                Ok(VectorSource::Text(text))
109            }
110            // Parenthesized source: subquery or (collection, vector_id) reference
111            Token::LParen => {
112                self.advance()?;
113                if self.vector_source_starts_subquery() {
114                    let expr = self.parse_query_expr()?;
115                    self.expect(Token::RParen)?;
116                    Ok(VectorSource::Subquery(Box::new(expr)))
117                } else {
118                    // Reference: (collection, vector_id)
119                    let collection = self.expect_ident()?;
120                    self.expect(Token::Comma)?;
121                    let vector_id = self.parse_integer()? as u64;
122                    self.expect(Token::RParen)?;
123                    Ok(VectorSource::Reference {
124                        collection,
125                        vector_id,
126                    })
127                }
128            }
129            // Reference by name: embedding_name
130            Token::Ident(_) => {
131                let name = self.expect_ident()?;
132                // Check for (collection, id) format
133                if self.consume(&Token::LParen)? {
134                    let vector_id = self.parse_integer()? as u64;
135                    self.expect(Token::RParen)?;
136                    Ok(VectorSource::Reference {
137                        collection: name,
138                        vector_id,
139                    })
140                } else {
141                    // Just a name reference, treat as text
142                    Ok(VectorSource::Text(name))
143                }
144            }
145            other => Err(ParseError::expected(
146                vec!["vector literal [...]", "string", "reference"],
147                other,
148                self.position(),
149            )),
150        }
151    }
152
153    fn vector_source_starts_subquery(&self) -> bool {
154        matches!(
155            self.peek(),
156            Token::Select
157                | Token::Match
158                | Token::Path
159                | Token::From
160                | Token::Vector
161                | Token::Hybrid
162        )
163    }
164
165    /// Parse metadata filter for vector queries
166    pub fn parse_metadata_filter(&mut self) -> Result<MetadataFilter, ParseError> {
167        self.parse_metadata_or_expr()
168    }
169
170    /// Parse OR expression in metadata filter
171    fn parse_metadata_or_expr(&mut self) -> Result<MetadataFilter, ParseError> {
172        let mut left = self.parse_metadata_and_expr()?;
173
174        while self.consume(&Token::Or)? {
175            let right = self.parse_metadata_and_expr()?;
176            left = MetadataFilter::Or(vec![left, right]);
177        }
178
179        Ok(left)
180    }
181
182    /// Parse AND expression in metadata filter
183    fn parse_metadata_and_expr(&mut self) -> Result<MetadataFilter, ParseError> {
184        let mut left = self.parse_metadata_primary()?;
185
186        while self.consume(&Token::And)? {
187            let right = self.parse_metadata_primary()?;
188            left = MetadataFilter::And(vec![left, right]);
189        }
190
191        Ok(left)
192    }
193
194    /// Parse primary metadata filter
195    fn parse_metadata_primary(&mut self) -> Result<MetadataFilter, ParseError> {
196        // Parenthesized expression
197        if self.consume(&Token::LParen)? {
198            let expr = self.parse_metadata_filter()?;
199            self.expect(Token::RParen)?;
200            return Ok(expr);
201        }
202
203        // field op value
204        let field = self.expect_ident()?;
205
206        // Handle different operators
207        if self.consume(&Token::Eq)? {
208            let value = self.parse_metadata_value()?;
209            Ok(MetadataFilter::Eq(field, value))
210        } else if self.consume(&Token::Ne)? {
211            let value = self.parse_metadata_value()?;
212            Ok(MetadataFilter::Ne(field, value))
213        } else if self.consume(&Token::Lt)? {
214            let value = self.parse_metadata_value()?;
215            Ok(MetadataFilter::Lt(field, value))
216        } else if self.consume(&Token::Le)? {
217            let value = self.parse_metadata_value()?;
218            Ok(MetadataFilter::Lte(field, value))
219        } else if self.consume(&Token::Gt)? {
220            let value = self.parse_metadata_value()?;
221            Ok(MetadataFilter::Gt(field, value))
222        } else if self.consume(&Token::Ge)? {
223            let value = self.parse_metadata_value()?;
224            Ok(MetadataFilter::Gte(field, value))
225        } else if self.consume(&Token::In)? {
226            self.expect(Token::LParen)?;
227            let values = self.parse_metadata_value_list()?;
228            self.expect(Token::RParen)?;
229            Ok(MetadataFilter::In(field, values))
230        } else if self.consume(&Token::Not)? {
231            self.expect(Token::In)?;
232            self.expect(Token::LParen)?;
233            let values = self.parse_metadata_value_list()?;
234            self.expect(Token::RParen)?;
235            Ok(MetadataFilter::NotIn(field, values))
236        } else if self.consume(&Token::Contains)? {
237            let value = self.parse_string()?;
238            Ok(MetadataFilter::Contains(field, value))
239        } else {
240            Err(ParseError::expected(
241                vec!["=", "<>", "<", "<=", ">", ">=", "IN", "NOT IN", "CONTAINS"],
242                self.peek(),
243                self.position(),
244            ))
245        }
246    }
247
248    /// Parse metadata value
249    fn parse_metadata_value(&mut self) -> Result<MetadataValue, ParseError> {
250        match self.peek() {
251            Token::String(_) => {
252                let s = self.parse_string()?;
253                Ok(MetadataValue::String(s))
254            }
255            Token::Integer(_) => {
256                let n = self.parse_integer()?;
257                Ok(MetadataValue::Integer(n))
258            }
259            Token::Float(_) => {
260                let n = self.parse_float()?;
261                Ok(MetadataValue::Float(n))
262            }
263            Token::True => {
264                self.advance()?;
265                Ok(MetadataValue::Bool(true))
266            }
267            Token::False => {
268                self.advance()?;
269                Ok(MetadataValue::Bool(false))
270            }
271            other => Err(ParseError::expected(
272                vec!["string", "number", "true", "false"],
273                other,
274                self.position(),
275            )),
276        }
277    }
278
279    /// Parse list of metadata values
280    fn parse_metadata_value_list(&mut self) -> Result<Vec<MetadataValue>, ParseError> {
281        let mut values = Vec::new();
282        loop {
283            values.push(self.parse_metadata_value()?);
284            if !self.consume(&Token::Comma)? {
285                break;
286            }
287        }
288        Ok(values)
289    }
290
291    /// Parse distance metric
292    pub fn parse_distance_metric(&mut self) -> Result<DistanceMetric, ParseError> {
293        match self.peek() {
294            Token::L2 => {
295                self.advance()?;
296                Ok(DistanceMetric::L2)
297            }
298            Token::Cosine => {
299                self.advance()?;
300                Ok(DistanceMetric::Cosine)
301            }
302            Token::InnerProduct => {
303                self.advance()?;
304                Ok(DistanceMetric::InnerProduct)
305            }
306            Token::Ident(name) => {
307                let name_upper = name.to_uppercase();
308                let name_clone = name.clone();
309                self.advance()?;
310                match name_upper.as_str() {
311                    "L2" | "EUCLIDEAN" => Ok(DistanceMetric::L2),
312                    "COSINE" | "COS" => Ok(DistanceMetric::Cosine),
313                    "INNER_PRODUCT" | "IP" | "DOT" => Ok(DistanceMetric::InnerProduct),
314                    _ => Err(ParseError::new(
315                        format!(
316                            "Unknown distance metric: {}. Valid: L2, COSINE, INNER_PRODUCT",
317                            name_clone
318                        ),
319                        self.position(),
320                    )),
321                }
322            }
323            other => Err(ParseError::expected(
324                vec!["L2", "COSINE", "INNER_PRODUCT"],
325                other,
326                self.position(),
327            )),
328        }
329    }
330}