lance_graph/
ast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Abstract Syntax Tree for Cypher queries
//!
//! This module defines the AST nodes for representing parsed Cypher queries.
//! The AST is designed to capture the essential graph patterns while being
//! simple enough to translate to SQL efficiently.
9
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13/// A complete Cypher query
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
15pub struct CypherQuery {
16    /// MATCH clauses
17    pub match_clauses: Vec<MatchClause>,
18    /// WHERE clause (optional)
19    pub where_clause: Option<WhereClause>,
20    /// RETURN clause
21    pub return_clause: ReturnClause,
22    /// LIMIT clause (optional)
23    pub limit: Option<u64>,
24    /// ORDER BY clause (optional)
25    pub order_by: Option<OrderByClause>,
26    /// SKIP/OFFSET clause (optional)
27    pub skip: Option<u64>,
28}
29
30impl CypherQuery {
31    /// Extract all node labels referenced in the query
32    pub fn get_node_labels(&self) -> Vec<String> {
33        let mut labels = Vec::new();
34        for match_clause in &self.match_clauses {
35            for pattern in &match_clause.patterns {
36                match pattern {
37                    GraphPattern::Node(node) => {
38                        for label in &node.labels {
39                            if !labels.contains(label) {
40                                labels.push(label.clone());
41                            }
42                        }
43                    }
44                    GraphPattern::Path(path) => {
45                        for label in &path.start_node.labels {
46                            if !labels.contains(label) {
47                                labels.push(label.clone());
48                            }
49                        }
50                        for segment in &path.segments {
51                            for label in &segment.end_node.labels {
52                                if !labels.contains(label) {
53                                    labels.push(label.clone());
54                                }
55                            }
56                        }
57                    }
58                }
59            }
60        }
61        labels
62    }
63
64    /// Extract all relationship types referenced in the query
65    pub fn get_relationship_types(&self) -> Vec<String> {
66        let mut types = Vec::new();
67        for match_clause in &self.match_clauses {
68            for pattern in &match_clause.patterns {
69                if let GraphPattern::Path(path) = pattern {
70                    for segment in &path.segments {
71                        for rel_type in &segment.relationship.types {
72                            if !types.contains(rel_type) {
73                                types.push(rel_type.clone());
74                            }
75                        }
76                    }
77                }
78            }
79        }
80        types
81    }
82}
83
84/// A MATCH clause containing graph patterns
85#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
86pub struct MatchClause {
87    /// Graph patterns to match
88    pub patterns: Vec<GraphPattern>,
89}
90
91/// A graph pattern (nodes and relationships)
92#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
93pub enum GraphPattern {
94    /// A single node pattern
95    Node(NodePattern),
96    /// A path pattern (node-relationship-node sequence)
97    Path(PathPattern),
98}
99
100/// A node pattern in a graph query
101#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
102pub struct NodePattern {
103    /// Variable name for the node (e.g., 'n' in (n:Person))
104    pub variable: Option<String>,
105    /// Node labels (e.g., ['Person', 'Employee'])
106    pub labels: Vec<String>,
107    /// Property constraints (e.g., {name: 'John', age: 30})
108    pub properties: HashMap<String, PropertyValue>,
109}
110
111/// A path pattern connecting nodes through relationships
112#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
113pub struct PathPattern {
114    /// Starting node
115    pub start_node: NodePattern,
116    /// Relationships and intermediate nodes
117    pub segments: Vec<PathSegment>,
118}
119
120/// A segment of a path (relationship + end node)
121#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
122pub struct PathSegment {
123    /// The relationship in this segment
124    pub relationship: RelationshipPattern,
125    /// The end node of this segment
126    pub end_node: NodePattern,
127}
128
129/// A relationship pattern
130#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
131pub struct RelationshipPattern {
132    /// Variable name for the relationship (e.g., 'r' in [r:KNOWS])
133    pub variable: Option<String>,
134    /// Relationship types (e.g., ['KNOWS', 'FRIEND_OF'])
135    pub types: Vec<String>,
136    /// Direction of the relationship
137    pub direction: RelationshipDirection,
138    /// Property constraints on the relationship
139    pub properties: HashMap<String, PropertyValue>,
140    /// Length constraints (for variable-length paths)
141    pub length: Option<LengthRange>,
142}
143
144/// Direction of a relationship
145#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
146pub enum RelationshipDirection {
147    /// Outgoing relationship (->)
148    Outgoing,
149    /// Incoming relationship (<-)
150    Incoming,
151    /// Undirected relationship (-)
152    Undirected,
153}
154
155/// Length range for variable-length paths
156#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
157pub struct LengthRange {
158    /// Minimum length (inclusive)
159    pub min: Option<u32>,
160    /// Maximum length (inclusive)
161    pub max: Option<u32>,
162}
163
164/// Property value in patterns and expressions
165#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
166pub enum PropertyValue {
167    /// String literal
168    String(String),
169    /// Integer literal
170    Integer(i64),
171    /// Float literal
172    Float(f64),
173    /// Boolean literal
174    Boolean(bool),
175    /// Null value
176    Null,
177    /// Parameter reference (e.g., $param)
178    Parameter(String),
179    /// Property reference (e.g., node.property)
180    Property(PropertyRef),
181}
182
183/// Reference to a property of a node or relationship
184#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
185pub struct PropertyRef {
186    /// Variable name (e.g., 'n' in n.name)
187    pub variable: String,
188    /// Property name
189    pub property: String,
190}
191
192/// WHERE clause for filtering
193#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
194pub struct WhereClause {
195    /// Boolean expression for filtering
196    pub expression: BooleanExpression,
197}
198
199/// Boolean expressions in WHERE clauses
200#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
201pub enum BooleanExpression {
202    /// Comparison operation (=, <>, <, >, <=, >=)
203    Comparison {
204        left: ValueExpression,
205        operator: ComparisonOperator,
206        right: ValueExpression,
207    },
208    /// Logical AND
209    And(Box<BooleanExpression>, Box<BooleanExpression>),
210    /// Logical OR
211    Or(Box<BooleanExpression>, Box<BooleanExpression>),
212    /// Logical NOT
213    Not(Box<BooleanExpression>),
214    /// Property existence check
215    Exists(PropertyRef),
216    /// IN clause
217    In {
218        expression: ValueExpression,
219        list: Vec<ValueExpression>,
220    },
221    /// LIKE pattern matching
222    Like {
223        expression: ValueExpression,
224        pattern: String,
225    },
226    /// IS NULL pattern matching
227    IsNull(ValueExpression),
228    /// IS NOT NULL pattern matching
229    IsNotNull(ValueExpression),
230}
231
232/// Comparison operators
233#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
234pub enum ComparisonOperator {
235    Equal,
236    NotEqual,
237    LessThan,
238    LessThanOrEqual,
239    GreaterThan,
240    GreaterThanOrEqual,
241}
242
243/// Value expressions (for comparisons and return values)
244#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
245pub enum ValueExpression {
246    /// Variable reference
247    Variable(String),
248    /// Property reference
249    Property(PropertyRef),
250    /// Literal value
251    Literal(PropertyValue),
252    /// Function call
253    Function {
254        name: String,
255        args: Vec<ValueExpression>,
256    },
257    /// Arithmetic operation
258    Arithmetic {
259        left: Box<ValueExpression>,
260        operator: ArithmeticOperator,
261        right: Box<ValueExpression>,
262    },
263}
264
265/// Arithmetic operators
266#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
267pub enum ArithmeticOperator {
268    Add,
269    Subtract,
270    Multiply,
271    Divide,
272    Modulo,
273}
274
275/// RETURN clause specifying what to return
276#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
277pub struct ReturnClause {
278    /// Whether DISTINCT was specified
279    pub distinct: bool,
280    /// Items to return
281    pub items: Vec<ReturnItem>,
282}
283
284/// An item in the RETURN clause
285#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
286pub struct ReturnItem {
287    /// The expression to return
288    pub expression: ValueExpression,
289    /// Alias for the returned value
290    pub alias: Option<String>,
291}
292
293/// ORDER BY clause
294#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
295pub struct OrderByClause {
296    /// Items to order by
297    pub items: Vec<OrderByItem>,
298}
299
300/// An item in the ORDER BY clause
301#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
302pub struct OrderByItem {
303    /// Expression to order by
304    pub expression: ValueExpression,
305    /// Sort direction
306    pub direction: SortDirection,
307}
308
309/// Sort direction for ORDER BY
310#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
311pub enum SortDirection {
312    Ascending,
313    Descending,
314}
315
316impl NodePattern {
317    /// Create a new node pattern
318    pub fn new(variable: Option<String>) -> Self {
319        Self {
320            variable,
321            labels: Vec::new(),
322            properties: HashMap::new(),
323        }
324    }
325
326    /// Add a label to the node pattern
327    pub fn with_label<S: Into<String>>(mut self, label: S) -> Self {
328        self.labels.push(label.into());
329        self
330    }
331
332    /// Add a property constraint to the node pattern
333    pub fn with_property<S: Into<String>>(mut self, key: S, value: PropertyValue) -> Self {
334        self.properties.insert(key.into(), value);
335        self
336    }
337}
338
339impl RelationshipPattern {
340    /// Create a new relationship pattern
341    pub fn new(direction: RelationshipDirection) -> Self {
342        Self {
343            variable: None,
344            types: Vec::new(),
345            direction,
346            properties: HashMap::new(),
347            length: None,
348        }
349    }
350
351    /// Set the variable name for the relationship
352    pub fn with_variable<S: Into<String>>(mut self, variable: S) -> Self {
353        self.variable = Some(variable.into());
354        self
355    }
356
357    /// Add a type to the relationship pattern
358    pub fn with_type<S: Into<String>>(mut self, rel_type: S) -> Self {
359        self.types.push(rel_type.into());
360        self
361    }
362
363    /// Add a property constraint to the relationship pattern
364    pub fn with_property<S: Into<String>>(mut self, key: S, value: PropertyValue) -> Self {
365        self.properties.insert(key.into(), value);
366        self
367    }
368}
369
370impl PropertyRef {
371    /// Create a new property reference
372    pub fn new<S: Into<String>>(variable: S, property: S) -> Self {
373        Self {
374            variable: variable.into(),
375            property: property.into(),
376        }
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap `patterns` in a single MATCH clause of an otherwise empty query.
    fn query_with(patterns: Vec<GraphPattern>) -> CypherQuery {
        CypherQuery {
            match_clauses: vec![MatchClause { patterns }],
            where_clause: None,
            return_clause: ReturnClause {
                distinct: false,
                items: Vec::new(),
            },
            limit: None,
            order_by: None,
            skip: None,
        }
    }

    /// `(a:Person)-[:KNOWS]->(b:Person)-[:WORKS_AT|KNOWS]->(c:Company)`
    fn sample_path() -> GraphPattern {
        GraphPattern::Path(PathPattern {
            start_node: NodePattern::new(Some("a".to_string())).with_label("Person"),
            segments: vec![
                PathSegment {
                    relationship: RelationshipPattern::new(RelationshipDirection::Outgoing)
                        .with_type("KNOWS"),
                    end_node: NodePattern::new(Some("b".to_string())).with_label("Person"),
                },
                PathSegment {
                    relationship: RelationshipPattern::new(RelationshipDirection::Outgoing)
                        .with_type("WORKS_AT")
                        .with_type("KNOWS"),
                    end_node: NodePattern::new(Some("c".to_string())).with_label("Company"),
                },
            ],
        })
    }

    #[test]
    fn test_node_pattern_creation() {
        let node = NodePattern::new(Some("n".to_string()))
            .with_label("Person")
            .with_property("name", PropertyValue::String("John".to_string()));

        assert_eq!(node.variable, Some("n".to_string()));
        assert_eq!(node.labels, vec!["Person"]);
        assert_eq!(node.properties.len(), 1);
    }

    #[test]
    fn test_node_pattern_property_overwrite() {
        // Setting the same key twice keeps only the latest value.
        let node = NodePattern::new(None)
            .with_property("age", PropertyValue::Integer(30))
            .with_property("age", PropertyValue::Integer(31));

        assert_eq!(node.properties.len(), 1);
        assert_eq!(node.properties.get("age"), Some(&PropertyValue::Integer(31)));
    }

    #[test]
    fn test_relationship_pattern_creation() {
        let rel = RelationshipPattern::new(RelationshipDirection::Outgoing)
            .with_variable("r")
            .with_type("KNOWS");

        assert_eq!(rel.variable, Some("r".to_string()));
        assert_eq!(rel.types, vec!["KNOWS"]);
        assert_eq!(rel.direction, RelationshipDirection::Outgoing);
    }

    #[test]
    fn test_relationship_pattern_defaults_and_properties() {
        let bare = RelationshipPattern::new(RelationshipDirection::Undirected);
        assert_eq!(bare.variable, None);
        assert!(bare.types.is_empty());
        assert!(bare.properties.is_empty());
        assert_eq!(bare.length, None);

        let rel = bare.with_property("since", PropertyValue::Integer(2020));
        assert_eq!(
            rel.properties.get("since"),
            Some(&PropertyValue::Integer(2020))
        );
    }

    #[test]
    fn test_property_ref() {
        let prop_ref = PropertyRef::new("n", "name");
        assert_eq!(prop_ref.variable, "n");
        assert_eq!(prop_ref.property, "name");
    }

    #[test]
    fn test_get_node_labels_deduplicates_in_order() {
        let standalone =
            GraphPattern::Node(NodePattern::new(Some("x".to_string())).with_label("City"));
        let query = query_with(vec![standalone, sample_path()]);

        assert_eq!(query.get_node_labels(), vec!["City", "Person", "Company"]);
    }

    #[test]
    fn test_get_relationship_types_deduplicates_in_order() {
        let standalone =
            GraphPattern::Node(NodePattern::new(Some("x".to_string())).with_label("City"));
        let query = query_with(vec![standalone, sample_path()]);

        assert_eq!(query.get_relationship_types(), vec!["KNOWS", "WORKS_AT"]);
    }

    #[test]
    fn test_empty_query_has_no_labels_or_types() {
        let query = query_with(Vec::new());

        assert!(query.get_node_labels().is_empty());
        assert!(query.get_relationship_types().is_empty());
    }
}