Skip to main content

lance_graph/
ast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Abstract Syntax Tree for Cypher queries
//!
//! This module defines the AST nodes for representing parsed Cypher queries.
//! The AST is designed to capture the essential graph patterns while being
//! simple enough to translate to SQL efficiently.
9
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13/// A complete Cypher query
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
15pub struct CypherQuery {
16    /// READING clauses (MATCH, UNWIND, etc.)
17    pub reading_clauses: Vec<ReadingClause>,
18    /// WHERE clause (optional, before WITH if present)
19    pub where_clause: Option<WhereClause>,
20    /// WITH clause (optional) - intermediate projection/aggregation
21    pub with_clause: Option<WithClause>,
22    /// Post-WITH READING clauses
23    pub post_with_reading_clauses: Vec<ReadingClause>,
24    /// WHERE clause after WITH (optional) - filters the WITH results
25    pub post_with_where_clause: Option<WhereClause>,
26    /// RETURN clause
27    pub return_clause: ReturnClause,
28    /// LIMIT clause (optional)
29    pub limit: Option<u64>,
30    /// ORDER BY clause (optional)
31    pub order_by: Option<OrderByClause>,
32    /// SKIP/OFFSET clause (optional)
33    pub skip: Option<u64>,
34}
35
36impl CypherQuery {
37    /// Extract all node labels referenced in the query
38    pub fn get_node_labels(&self) -> Vec<String> {
39        let mut labels = Vec::new();
40        // Iterate all match clauses directly
41        for clause in &self.reading_clauses {
42            if let ReadingClause::Match(match_clause) = clause {
43                for pattern in &match_clause.patterns {
44                    match pattern {
45                        GraphPattern::Node(node) => {
46                            for label in &node.labels {
47                                if !labels.contains(label) {
48                                    labels.push(label.clone());
49                                }
50                            }
51                        }
52                        GraphPattern::Path(path) => {
53                            for label in &path.start_node.labels {
54                                if !labels.contains(label) {
55                                    labels.push(label.clone());
56                                }
57                            }
58                            for segment in &path.segments {
59                                for label in &segment.end_node.labels {
60                                    if !labels.contains(label) {
61                                        labels.push(label.clone());
62                                    }
63                                }
64                            }
65                        }
66                    }
67                }
68            }
69        }
70        labels
71    }
72
73    /// Extract all relationship types referenced in the query
74    pub fn get_relationship_types(&self) -> Vec<String> {
75        let mut types = Vec::new();
76        for clause in &self.reading_clauses {
77            if let ReadingClause::Match(match_clause) = clause {
78                for pattern in &match_clause.patterns {
79                    self.collect_relationship_types_from_pattern(pattern, &mut types);
80                }
81            }
82        }
83        types
84    }
85
86    fn collect_relationship_types_from_pattern(
87        &self,
88        pattern: &GraphPattern,
89        types: &mut Vec<String>,
90    ) {
91        if let GraphPattern::Path(path) = pattern {
92            for segment in &path.segments {
93                for rel_type in &segment.relationship.types {
94                    if !types.contains(rel_type) {
95                        types.push(rel_type.clone());
96                    }
97                }
98            }
99        }
100    }
101}
102
103/// A clause that reads from the graph (MATCH, UNWIND)
104#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
105pub enum ReadingClause {
106    Match(MatchClause),
107    Unwind(UnwindClause),
108}
109
110/// A MATCH clause containing graph patterns
111#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
112pub struct MatchClause {
113    /// Graph patterns to match
114    pub patterns: Vec<GraphPattern>,
115}
116
117/// An UNWIND clause
118#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
119pub struct UnwindClause {
120    /// Expression to unwind
121    pub expression: ValueExpression,
122    /// Alias for the unwound values
123    pub alias: String,
124}
125
126/// A graph pattern (nodes and relationships)
127#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
128pub enum GraphPattern {
129    /// A single node pattern
130    Node(NodePattern),
131    /// A path pattern (node-relationship-node sequence)
132    Path(PathPattern),
133}
134
135/// A node pattern in a graph query
136#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
137pub struct NodePattern {
138    /// Variable name for the node (e.g., 'n' in (n:Person))
139    pub variable: Option<String>,
140    /// Node labels (e.g., ['Person', 'Employee'])
141    pub labels: Vec<String>,
142    /// Property constraints (e.g., {name: 'John', age: 30})
143    pub properties: HashMap<String, PropertyValue>,
144}
145
146/// A path pattern connecting nodes through relationships
147#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
148pub struct PathPattern {
149    /// Starting node
150    pub start_node: NodePattern,
151    /// Relationships and intermediate nodes
152    pub segments: Vec<PathSegment>,
153}
154
155/// A segment of a path (relationship + end node)
156#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
157pub struct PathSegment {
158    /// The relationship in this segment
159    pub relationship: RelationshipPattern,
160    /// The end node of this segment
161    pub end_node: NodePattern,
162}
163
164/// A relationship pattern
165#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
166pub struct RelationshipPattern {
167    /// Variable name for the relationship (e.g., 'r' in [r:KNOWS])
168    pub variable: Option<String>,
169    /// Relationship types (e.g., ['KNOWS', 'FRIEND_OF'])
170    pub types: Vec<String>,
171    /// Direction of the relationship
172    pub direction: RelationshipDirection,
173    /// Property constraints on the relationship
174    pub properties: HashMap<String, PropertyValue>,
175    /// Length constraints (for variable-length paths)
176    pub length: Option<LengthRange>,
177}
178
179/// Direction of a relationship
180#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
181pub enum RelationshipDirection {
182    /// Outgoing relationship (->)
183    Outgoing,
184    /// Incoming relationship (<-)
185    Incoming,
186    /// Undirected relationship (-)
187    Undirected,
188}
189
190/// Length range for variable-length paths
191#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
192pub struct LengthRange {
193    /// Minimum length (inclusive)
194    pub min: Option<u32>,
195    /// Maximum length (inclusive)
196    pub max: Option<u32>,
197}
198
199/// Property value in patterns and expressions
200#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
201pub enum PropertyValue {
202    /// String literal
203    String(String),
204    /// Integer literal
205    Integer(i64),
206    /// Float literal
207    Float(f64),
208    /// Boolean literal
209    Boolean(bool),
210    /// Null value
211    Null,
212    /// Parameter reference (e.g., $param)
213    Parameter(String),
214    /// Property reference (e.g., node.property)
215    Property(PropertyRef),
216}
217
218/// Reference to a property of a node or relationship
219#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
220pub struct PropertyRef {
221    /// Variable name (e.g., 'n' in n.name)
222    pub variable: String,
223    /// Property name
224    pub property: String,
225}
226
227/// WHERE clause for filtering
228#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
229pub struct WhereClause {
230    /// Boolean expression for filtering
231    pub expression: BooleanExpression,
232}
233
234/// Boolean expressions in WHERE clauses
235#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
236pub enum BooleanExpression {
237    /// Comparison operation (=, <>, <, >, <=, >=)
238    Comparison {
239        left: ValueExpression,
240        operator: ComparisonOperator,
241        right: ValueExpression,
242    },
243    /// Logical AND
244    And(Box<BooleanExpression>, Box<BooleanExpression>),
245    /// Logical OR
246    Or(Box<BooleanExpression>, Box<BooleanExpression>),
247    /// Logical NOT
248    Not(Box<BooleanExpression>),
249    /// Property existence check
250    Exists(PropertyRef),
251    /// IN clause
252    In {
253        expression: ValueExpression,
254        list: Vec<ValueExpression>,
255    },
256    /// LIKE pattern matching (case-sensitive)
257    Like {
258        expression: ValueExpression,
259        pattern: String,
260    },
261    /// ILIKE pattern matching (case-insensitive)
262    ILike {
263        expression: ValueExpression,
264        pattern: String,
265    },
266    /// CONTAINS substring matching
267    Contains {
268        expression: ValueExpression,
269        substring: String,
270    },
271    /// STARTS WITH prefix matching
272    StartsWith {
273        expression: ValueExpression,
274        prefix: String,
275    },
276    /// ENDS WITH suffix matching
277    EndsWith {
278        expression: ValueExpression,
279        suffix: String,
280    },
281    /// IS NULL pattern matching
282    IsNull(ValueExpression),
283    /// IS NOT NULL pattern matching
284    IsNotNull(ValueExpression),
285}
286
287/// Distance metric for vector similarity
288#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
289pub enum DistanceMetric {
290    /// Euclidean distance (L2)
291    L2,
292    /// Cosine similarity (1 - cosine distance)
293    #[default]
294    Cosine,
295    /// Dot product
296    Dot,
297}
298
299/// Comparison operators
300#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
301pub enum ComparisonOperator {
302    Equal,
303    NotEqual,
304    LessThan,
305    LessThanOrEqual,
306    GreaterThan,
307    GreaterThanOrEqual,
308}
309
310/// Value expressions (for comparisons and return values)
311#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
312pub enum ValueExpression {
313    /// Variable reference
314    Variable(String),
315    /// Property reference
316    Property(PropertyRef),
317    /// Literal value
318    Literal(PropertyValue),
319    /// Scalar function call (toLower, upper, etc.)
320    /// These are row-level functions that operate on individual values
321    ScalarFunction {
322        name: String,
323        args: Vec<ValueExpression>,
324    },
325    /// Aggregate function call (COUNT, SUM, AVG, MIN, MAX, COLLECT)
326    /// These functions operate across multiple rows and support DISTINCT
327    AggregateFunction {
328        name: String,
329        args: Vec<ValueExpression>,
330        /// Whether DISTINCT keyword was specified (e.g., COUNT(DISTINCT x))
331        distinct: bool,
332    },
333    /// Arithmetic operation
334    Arithmetic {
335        left: Box<ValueExpression>,
336        operator: ArithmeticOperator,
337        right: Box<ValueExpression>,
338    },
339    /// Vector distance function: vector_distance(left, right, metric)
340    /// Returns the distance as a float (lower = more similar for L2/Cosine)
341    VectorDistance {
342        left: Box<ValueExpression>,
343        right: Box<ValueExpression>,
344        metric: DistanceMetric,
345    },
346    /// Vector similarity function: vector_similarity(left, right, metric)
347    /// Returns the similarity score as a float (higher = more similar)
348    VectorSimilarity {
349        left: Box<ValueExpression>,
350        right: Box<ValueExpression>,
351        metric: DistanceMetric,
352    },
353    /// Parameter reference for query parameters (e.g., $query_vector)
354    Parameter(String),
355    /// Vector literal: [0.1, 0.2, 0.3]
356    /// Represents an inline vector for similarity search
357    VectorLiteral(Vec<f32>),
358}
359
/// Function type classification.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the classification
/// can be used as a key in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FunctionType {
    /// Aggregate function (operates across multiple rows)
    Aggregate,
    /// Scalar function (operates on individual values)
    Scalar,
    /// Unknown function type
    Unknown,
}
370
371/// Classify a function by name
372pub fn classify_function(name: &str) -> FunctionType {
373    match name.to_lowercase().as_str() {
374        "count" | "sum" | "avg" | "min" | "max" | "collect" => FunctionType::Aggregate,
375        "tolower" | "lower" | "toupper" | "upper" => FunctionType::Scalar,
376        // Vector functions are handled separately as special variants
377        _ => FunctionType::Unknown,
378    }
379}
380
381/// Arithmetic operators
382#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
383pub enum ArithmeticOperator {
384    Add,
385    Subtract,
386    Multiply,
387    Divide,
388    Modulo,
389}
390
391/// WITH clause for intermediate projections/aggregations
392///
393/// WITH acts as a query stage boundary, projecting results that become
394/// the input for subsequent clauses.
395#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
396pub struct WithClause {
397    /// Items to project (similar to RETURN)
398    pub items: Vec<ReturnItem>,
399    /// Optional ORDER BY within WITH
400    pub order_by: Option<OrderByClause>,
401    /// Optional LIMIT within WITH
402    pub limit: Option<u64>,
403}
404
405/// RETURN clause specifying what to return
406#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
407pub struct ReturnClause {
408    /// Whether DISTINCT was specified
409    pub distinct: bool,
410    /// Items to return
411    pub items: Vec<ReturnItem>,
412}
413
414/// An item in the RETURN clause
415#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
416pub struct ReturnItem {
417    /// The expression to return
418    pub expression: ValueExpression,
419    /// Alias for the returned value
420    pub alias: Option<String>,
421}
422
423/// ORDER BY clause
424#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
425pub struct OrderByClause {
426    /// Items to order by
427    pub items: Vec<OrderByItem>,
428}
429
430/// An item in the ORDER BY clause
431#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
432pub struct OrderByItem {
433    /// Expression to order by
434    pub expression: ValueExpression,
435    /// Sort direction
436    pub direction: SortDirection,
437}
438
439/// Sort direction for ORDER BY
440#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
441pub enum SortDirection {
442    Ascending,
443    Descending,
444}
445
446impl NodePattern {
447    /// Create a new node pattern
448    pub fn new(variable: Option<String>) -> Self {
449        Self {
450            variable,
451            labels: Vec::new(),
452            properties: HashMap::new(),
453        }
454    }
455
456    /// Add a label to the node pattern
457    pub fn with_label<S: Into<String>>(mut self, label: S) -> Self {
458        self.labels.push(label.into());
459        self
460    }
461
462    /// Add a property constraint to the node pattern
463    pub fn with_property<S: Into<String>>(mut self, key: S, value: PropertyValue) -> Self {
464        self.properties.insert(key.into(), value);
465        self
466    }
467}
468
469impl RelationshipPattern {
470    /// Create a new relationship pattern
471    pub fn new(direction: RelationshipDirection) -> Self {
472        Self {
473            variable: None,
474            types: Vec::new(),
475            direction,
476            properties: HashMap::new(),
477            length: None,
478        }
479    }
480
481    /// Set the variable name for the relationship
482    pub fn with_variable<S: Into<String>>(mut self, variable: S) -> Self {
483        self.variable = Some(variable.into());
484        self
485    }
486
487    /// Add a type to the relationship pattern
488    pub fn with_type<S: Into<String>>(mut self, rel_type: S) -> Self {
489        self.types.push(rel_type.into());
490        self
491    }
492
493    /// Add a property constraint to the relationship pattern
494    pub fn with_property<S: Into<String>>(mut self, key: S, value: PropertyValue) -> Self {
495        self.properties.insert(key.into(), value);
496        self
497    }
498}
499
500impl PropertyRef {
501    /// Create a new property reference
502    pub fn new<S: Into<String>>(variable: S, property: S) -> Self {
503        Self {
504            variable: variable.into(),
505            property: property.into(),
506        }
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// The `NodePattern` builder keeps the variable, the label and the property.
    #[test]
    fn test_node_pattern_creation() {
        let name = PropertyValue::String("John".to_string());
        let node = NodePattern::new(Some("n".to_string()))
            .with_label("Person")
            .with_property("name", name);

        assert_eq!(node.variable.as_deref(), Some("n"));
        assert_eq!(node.labels, vec!["Person"]);
        assert_eq!(node.properties.len(), 1);
    }

    /// The `RelationshipPattern` builder keeps the variable, type and direction.
    #[test]
    fn test_relationship_pattern_creation() {
        let direction = RelationshipDirection::Outgoing;
        let rel = RelationshipPattern::new(direction)
            .with_variable("r")
            .with_type("KNOWS");

        assert_eq!(rel.variable.as_deref(), Some("r"));
        assert_eq!(rel.types, vec!["KNOWS"]);
        assert_eq!(rel.direction, RelationshipDirection::Outgoing);
    }

    /// `PropertyRef::new` stores the variable and property names as given.
    #[test]
    fn test_property_ref() {
        let PropertyRef { variable, property } = PropertyRef::new("n", "name");
        assert_eq!(variable, "n");
        assert_eq!(property, "name");
    }
}