lance_graph/
ast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Abstract Syntax Tree for Cypher queries
5//!
6//! This module defines the AST nodes for representing parsed Cypher queries.
7//! The AST is designed to capture the essential graph patterns while being
8//! simple enough to translate to SQL efficiently.
9
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// A complete Cypher query
///
/// Root node of the AST: it groups the MATCH patterns, the optional WHERE
/// filter, the RETURN projection and the optional ORDER BY / SKIP / LIMIT
/// modifiers of a single query.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CypherQuery {
    /// MATCH clauses (one entry per clause; each holds its own patterns)
    pub match_clauses: Vec<MatchClause>,
    /// WHERE clause (optional); `None` when the query has no filter
    pub where_clause: Option<WhereClause>,
    /// RETURN clause (not optional in this AST)
    pub return_clause: ReturnClause,
    /// LIMIT clause (optional); `None` when no LIMIT was given
    pub limit: Option<u64>,
    /// ORDER BY clause (optional); `None` when no ordering was given
    pub order_by: Option<OrderByClause>,
    /// SKIP/OFFSET clause (optional); `None` when no SKIP was given
    pub skip: Option<u64>,
}
29
30impl CypherQuery {
31    /// Extract all node labels referenced in the query
32    pub fn get_node_labels(&self) -> Vec<String> {
33        let mut labels = Vec::new();
34        for match_clause in &self.match_clauses {
35            for pattern in &match_clause.patterns {
36                match pattern {
37                    GraphPattern::Node(node) => {
38                        for label in &node.labels {
39                            if !labels.contains(label) {
40                                labels.push(label.clone());
41                            }
42                        }
43                    }
44                    GraphPattern::Path(path) => {
45                        for label in &path.start_node.labels {
46                            if !labels.contains(label) {
47                                labels.push(label.clone());
48                            }
49                        }
50                        for segment in &path.segments {
51                            for label in &segment.end_node.labels {
52                                if !labels.contains(label) {
53                                    labels.push(label.clone());
54                                }
55                            }
56                        }
57                    }
58                }
59            }
60        }
61        labels
62    }
63
64    /// Extract all relationship types referenced in the query
65    pub fn get_relationship_types(&self) -> Vec<String> {
66        let mut types = Vec::new();
67        for match_clause in &self.match_clauses {
68            for pattern in &match_clause.patterns {
69                if let GraphPattern::Path(path) = pattern {
70                    for segment in &path.segments {
71                        for rel_type in &segment.relationship.types {
72                            if !types.contains(rel_type) {
73                                types.push(rel_type.clone());
74                            }
75                        }
76                    }
77                }
78            }
79        }
80        types
81    }
82}
83
/// A MATCH clause containing graph patterns
///
/// Each entry in `patterns` is either a standalone node or a path
/// (see [`GraphPattern`]).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatchClause {
    /// Graph patterns to match
    pub patterns: Vec<GraphPattern>,
}
90
/// A graph pattern (nodes and relationships)
///
/// A pattern is either a lone node or a path that starts at a node and
/// continues through relationship/node segments.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GraphPattern {
    /// A single node pattern (no relationships attached)
    Node(NodePattern),
    /// A path pattern (node-relationship-node sequence)
    Path(PathPattern),
}
99
/// A node pattern in a graph query
///
/// Describes one node by an optional binding variable, a list of labels and
/// a map of property constraints keyed by property name.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct NodePattern {
    /// Variable name for the node (e.g., 'n' in (n:Person)); `None` when the
    /// pattern does not bind a variable
    pub variable: Option<String>,
    /// Node labels (e.g., ['Person', 'Employee']); may be empty
    pub labels: Vec<String>,
    /// Property constraints (e.g., {name: 'John', age: 30}), keyed by
    /// property name
    pub properties: HashMap<String, PropertyValue>,
}
110
/// A path pattern connecting nodes through relationships
///
/// The path begins at `start_node`; every entry of `segments` then adds one
/// relationship together with the node that relationship leads to.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PathPattern {
    /// Starting node
    pub start_node: NodePattern,
    /// Relationships and the nodes they lead to (intermediate and final)
    pub segments: Vec<PathSegment>,
}
119
/// A segment of a path (relationship + end node)
///
/// The node the relationship starts from is not stored here; it is the
/// preceding node of the enclosing [`PathPattern`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PathSegment {
    /// The relationship in this segment
    pub relationship: RelationshipPattern,
    /// The end node of this segment
    pub end_node: NodePattern,
}
128
/// A relationship pattern
///
/// Describes one relationship by an optional binding variable, a list of
/// relationship types, a direction, property constraints and an optional
/// length range.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RelationshipPattern {
    /// Variable name for the relationship (e.g., 'r' in [r:KNOWS]); `None`
    /// when the pattern does not bind a variable
    pub variable: Option<String>,
    /// Relationship types (e.g., ['KNOWS', 'FRIEND_OF']); may be empty
    pub types: Vec<String>,
    /// Direction of the relationship
    pub direction: RelationshipDirection,
    /// Property constraints on the relationship, keyed by property name
    pub properties: HashMap<String, PropertyValue>,
    /// Length constraints (for variable-length paths); `None` when no length
    /// range was given.
    // NOTE(review): presumably `None` means a plain single-hop relationship —
    // confirm against the parser/planner.
    pub length: Option<LengthRange>,
}
143
/// Direction of a relationship
///
/// The arrow notation in each variant's description is the textual form used
/// in a relationship pattern.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RelationshipDirection {
    /// Outgoing relationship (->)
    Outgoing,
    /// Incoming relationship (<-)
    Incoming,
    /// Undirected relationship (-)
    Undirected,
}
154
/// Length range for variable-length paths
///
/// Both bounds are optional and inclusive when present.
// NOTE(review): how an omitted bound is defaulted is decided by whoever
// consumes the AST; this module does not define it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LengthRange {
    /// Minimum length (inclusive); `None` when no lower bound was given
    pub min: Option<u32>,
    /// Maximum length (inclusive); `None` when no upper bound was given
    pub max: Option<u32>,
}
163
/// Property value in patterns and expressions
///
/// Used as the value type of property-constraint maps on node and
/// relationship patterns and as the payload of literal value expressions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PropertyValue {
    /// String literal
    String(String),
    /// Integer literal (64-bit signed)
    Integer(i64),
    /// Float literal (64-bit)
    Float(f64),
    /// Boolean literal
    Boolean(bool),
    /// Null value
    Null,
    /// Parameter reference (e.g., $param); holds the parameter name
    // NOTE(review): whether the leading `$` is stored in the name is not
    // determined by this module — confirm against the parser.
    Parameter(String),
    /// Property reference (e.g., node.property)
    Property(PropertyRef),
}
182
/// Reference to a property of a node or relationship
///
/// Pairs the variable a pattern bound with the name of the property being
/// accessed, i.e. the two sides of `variable.property`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PropertyRef {
    /// Variable name (e.g., 'n' in n.name)
    pub variable: String,
    /// Property name (e.g., 'name' in n.name)
    pub property: String,
}
191
/// WHERE clause for filtering
///
/// Thin wrapper around the single boolean expression that makes up the
/// filter.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WhereClause {
    /// Boolean expression for filtering
    pub expression: BooleanExpression,
}
198
/// Boolean expressions in WHERE clauses
///
/// A recursive tree: the logical variants (`And`, `Or`, `Not`) box their
/// operands, while the remaining variants are leaf predicates over
/// [`ValueExpression`]s or a [`PropertyRef`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum BooleanExpression {
    /// Comparison operation (=, <>, <, >, <=, >=)
    Comparison {
        /// Left-hand operand
        left: ValueExpression,
        /// Which comparison to apply
        operator: ComparisonOperator,
        /// Right-hand operand
        right: ValueExpression,
    },
    /// Logical AND of two sub-expressions
    And(Box<BooleanExpression>, Box<BooleanExpression>),
    /// Logical OR of two sub-expressions
    Or(Box<BooleanExpression>, Box<BooleanExpression>),
    /// Logical NOT of a sub-expression
    Not(Box<BooleanExpression>),
    /// Property existence check
    Exists(PropertyRef),
    /// IN clause
    In {
        /// Value being tested for membership
        expression: ValueExpression,
        /// Candidate values
        list: Vec<ValueExpression>,
    },
    /// LIKE pattern matching (case-sensitive)
    Like {
        /// Value being matched
        expression: ValueExpression,
        /// Pattern text (wildcard syntax is not defined in this module)
        pattern: String,
    },
    /// ILIKE pattern matching (case-insensitive)
    ILike {
        /// Value being matched
        expression: ValueExpression,
        /// Pattern text (wildcard syntax is not defined in this module)
        pattern: String,
    },
    /// CONTAINS substring matching
    Contains {
        /// Value being searched
        expression: ValueExpression,
        /// Substring to look for
        substring: String,
    },
    /// STARTS WITH prefix matching
    StartsWith {
        /// Value being tested
        expression: ValueExpression,
        /// Required prefix
        prefix: String,
    },
    /// ENDS WITH suffix matching
    EndsWith {
        /// Value being tested
        expression: ValueExpression,
        /// Required suffix
        suffix: String,
    },
    /// IS NULL check on a value expression
    IsNull(ValueExpression),
    /// IS NOT NULL check on a value expression
    IsNotNull(ValueExpression),
}
251
/// Distance metric for vector similarity
///
/// Selected by the `metric` field of the vector distance / similarity value
/// expressions. The derived `Default` implementation yields
/// [`DistanceMetric::Cosine`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub enum DistanceMetric {
    /// Euclidean distance (L2)
    L2,
    /// Cosine metric; this is the default variant.
    // NOTE(review): the original comment read "Cosine similarity (1 - cosine
    // distance)"; confirm whether consumers treat this as a distance or as a
    // similarity.
    #[default]
    Cosine,
    /// Dot product
    Dot,
}
263
/// Comparison operators
///
/// Used by the comparison variant of the boolean expression tree, whose
/// documentation lists the textual forms `=, <>, <, >, <=, >=`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ComparisonOperator {
    /// Equality (`=`)
    Equal,
    /// Inequality (`<>`)
    NotEqual,
    /// Strictly less than (`<`)
    LessThan,
    /// Less than or equal (`<=`)
    LessThanOrEqual,
    /// Strictly greater than (`>`)
    GreaterThan,
    /// Greater than or equal (`>=`)
    GreaterThanOrEqual,
}
274
/// Value expressions (for comparisons and return values)
///
/// A recursive tree: arithmetic and vector variants box their operands, and
/// function calls hold a list of argument expressions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ValueExpression {
    /// Variable reference (holds the variable name)
    Variable(String),
    /// Property reference
    Property(PropertyRef),
    /// Literal value
    Literal(PropertyValue),
    /// Function call
    Function {
        /// Function name
        name: String,
        /// Argument expressions
        args: Vec<ValueExpression>,
    },
    /// Arithmetic operation
    Arithmetic {
        /// Left-hand operand
        left: Box<ValueExpression>,
        /// Which arithmetic operation to apply
        operator: ArithmeticOperator,
        /// Right-hand operand
        right: Box<ValueExpression>,
    },
    /// Vector distance function: vector_distance(left, right, metric)
    /// Returns the distance as a float (lower = more similar for L2/Cosine)
    VectorDistance {
        /// First vector operand
        left: Box<ValueExpression>,
        /// Second vector operand
        right: Box<ValueExpression>,
        /// Metric used to compute the distance
        metric: DistanceMetric,
    },
    /// Vector similarity function: vector_similarity(left, right, metric)
    /// Returns the similarity score as a float (higher = more similar)
    VectorSimilarity {
        /// First vector operand
        left: Box<ValueExpression>,
        /// Second vector operand
        right: Box<ValueExpression>,
        /// Metric used to compute the similarity
        metric: DistanceMetric,
    },
    /// Parameter reference for query parameters (e.g., $query_vector)
    // NOTE(review): the property-value enum also has a `Parameter` variant;
    // which of the two a parser emits is not determined by this module.
    Parameter(String),
    /// Vector literal: [0.1, 0.2, 0.3]
    /// Represents an inline vector for similarity search
    VectorLiteral(Vec<f32>),
}
315
/// Arithmetic operators
///
/// Used by the arithmetic variant of the value expression tree.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ArithmeticOperator {
    /// Addition
    Add,
    /// Subtraction
    Subtract,
    /// Multiplication
    Multiply,
    /// Division
    Divide,
    /// Modulo (remainder)
    Modulo,
}
325
/// RETURN clause specifying what to return
///
/// Holds the projected items together with the DISTINCT flag.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReturnClause {
    /// Whether DISTINCT was specified
    pub distinct: bool,
    /// Items to return
    pub items: Vec<ReturnItem>,
}
334
/// An item in the RETURN clause
///
/// One projected expression with an optional alias.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReturnItem {
    /// The expression to return
    pub expression: ValueExpression,
    /// Alias for the returned value; `None` when no alias was given
    pub alias: Option<String>,
}
343
/// ORDER BY clause
///
/// Holds the list of sort keys, each with its own direction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByClause {
    /// Items to order by
    pub items: Vec<OrderByItem>,
}
350
/// An item in the ORDER BY clause
///
/// One sort key: the expression to sort on and the direction to sort in.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByItem {
    /// Expression to order by
    pub expression: ValueExpression,
    /// Sort direction
    pub direction: SortDirection,
}
359
/// Sort direction for ORDER BY
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SortDirection {
    /// Ascending order
    Ascending,
    /// Descending order
    Descending,
}
366
367impl NodePattern {
368    /// Create a new node pattern
369    pub fn new(variable: Option<String>) -> Self {
370        Self {
371            variable,
372            labels: Vec::new(),
373            properties: HashMap::new(),
374        }
375    }
376
377    /// Add a label to the node pattern
378    pub fn with_label<S: Into<String>>(mut self, label: S) -> Self {
379        self.labels.push(label.into());
380        self
381    }
382
383    /// Add a property constraint to the node pattern
384    pub fn with_property<S: Into<String>>(mut self, key: S, value: PropertyValue) -> Self {
385        self.properties.insert(key.into(), value);
386        self
387    }
388}
389
390impl RelationshipPattern {
391    /// Create a new relationship pattern
392    pub fn new(direction: RelationshipDirection) -> Self {
393        Self {
394            variable: None,
395            types: Vec::new(),
396            direction,
397            properties: HashMap::new(),
398            length: None,
399        }
400    }
401
402    /// Set the variable name for the relationship
403    pub fn with_variable<S: Into<String>>(mut self, variable: S) -> Self {
404        self.variable = Some(variable.into());
405        self
406    }
407
408    /// Add a type to the relationship pattern
409    pub fn with_type<S: Into<String>>(mut self, rel_type: S) -> Self {
410        self.types.push(rel_type.into());
411        self
412    }
413
414    /// Add a property constraint to the relationship pattern
415    pub fn with_property<S: Into<String>>(mut self, key: S, value: PropertyValue) -> Self {
416        self.properties.insert(key.into(), value);
417        self
418    }
419}
420
421impl PropertyRef {
422    /// Create a new property reference
423    pub fn new<S: Into<String>>(variable: S, property: S) -> Self {
424        Self {
425            variable: variable.into(),
426            property: property.into(),
427        }
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// The node builder records the variable, the label and the property.
    #[test]
    fn test_node_pattern_creation() {
        let constraint = PropertyValue::String("John".to_string());
        let base = NodePattern::new(Some("n".to_string()));
        let node = base.with_label("Person").with_property("name", constraint);

        assert_eq!(node.variable, Some("n".to_string()));
        assert_eq!(node.labels, vec!["Person"]);
        assert_eq!(node.properties.len(), 1);
    }

    /// The relationship builder records the variable, type and direction.
    #[test]
    fn test_relationship_pattern_creation() {
        let base = RelationshipPattern::new(RelationshipDirection::Outgoing);
        let rel = base.with_variable("r").with_type("KNOWS");

        assert_eq!(rel.variable, Some("r".to_string()));
        assert_eq!(rel.types, vec!["KNOWS"]);
        assert_eq!(rel.direction, RelationshipDirection::Outgoing);
    }

    /// `PropertyRef::new` stores both names unchanged.
    #[test]
    fn test_property_ref() {
        let PropertyRef { variable, property } = PropertyRef::new("n", "name");

        assert_eq!(variable, "n");
        assert_eq!(property, "name");
    }
}