arrow_graph/sql/
gql_parser.rs

1use std::collections::HashMap;
2use crate::error::{GraphError, Result};
3
/// Stateless parser for a small subset of GQL graph-pattern syntax.
///
/// Accepts queries shaped like `MATCH (a)-[r]->(b) WHERE a.type = 'user'` and turns them
/// into a [`GqlPattern`].
#[derive(Clone, Debug)]
pub struct GqlParser;
8
/// A single node pattern such as `(a:User {name: 'Alice'})`.
#[derive(Clone, Debug, PartialEq)]
pub struct GqlNode {
    /// Variable bound to the node (`a` above); empty when the pattern names none, e.g. `(:User)`.
    pub variable: String,
    /// Label written after the colon (`User` above), if any.
    pub label: Option<String>,
    /// Inline property constraints from the `{...}` map, keyed by property name.
    pub properties: HashMap<String, String>,
}
16
17/// Represents an edge in a GQL pattern
18#[derive(Debug, Clone, PartialEq)]
19pub struct GqlEdge {
20    pub variable: Option<String>,
21    pub label: Option<String>,
22    pub direction: EdgeDirection,
23    pub properties: HashMap<String, String>,
24}
25
/// Arrow direction of an edge in a GQL pattern.
///
/// This is a plain fieldless tag, so it is `Copy` and can be compared, hashed, and used as a
/// map key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeDirection {
    /// Written with a trailing `>`: `-->` or `-[r]->`.
    Outgoing,
    /// Written with a leading `<`: `<--` or `<-[r]-`.
    Incoming,
    /// Written without an arrowhead: `--` or `-[r]-`.
    Undirected,
}
33
34/// Represents a complete GQL pattern
35#[derive(Debug, Clone)]
36pub struct GqlPattern {
37    pub nodes: Vec<GqlNode>,
38    pub edges: Vec<GqlEdge>,
39    pub where_conditions: Vec<WhereCondition>,
40}
41
42/// WHERE clause conditions
43#[derive(Debug, Clone, PartialEq)]
44pub struct WhereCondition {
45    pub variable: String,
46    pub property: String,
47    pub operator: ComparisonOperator,
48    pub value: String,
49}
50
/// Comparison operator of a [`WhereCondition`].
///
/// A fieldless tag, so it is `Copy` and can be compared and hashed freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComparisonOperator {
    /// `=`
    Equal,
    /// `!=`
    NotEqual,
    /// `>`
    GreaterThan,
    /// `<`
    LessThan,
    /// `>=`
    GreaterThanOrEqual,
    /// `<=`
    LessThanOrEqual,
    /// Containment comparison (a `CONTAINS`-style operator).
    Contains,
}
61
62impl GqlParser {
63    pub fn new() -> Self {
64        Self
65    }
66
67    /// Parse a GQL MATCH pattern
68    /// Example: "MATCH (a:User)-[r:FOLLOWS]->(b:User) WHERE a.name = 'Alice'"
69    pub fn parse_match_pattern(&self, pattern: &str) -> Result<GqlPattern> {
70        let pattern = pattern.trim();
71        
72        // Split MATCH and WHERE clauses
73        let (match_part, where_part) = if pattern.to_uppercase().contains("WHERE") {
74            let parts: Vec<&str> = pattern.splitn(2, "WHERE").collect();
75            if parts.len() == 2 {
76                (parts[0].trim(), Some(parts[1].trim()))
77            } else {
78                (pattern, None)
79            }
80        } else {
81            (pattern, None)
82        };
83
84        // Remove MATCH keyword
85        let match_part = if match_part.to_uppercase().starts_with("MATCH") {
86            match_part[5..].trim()
87        } else {
88            match_part
89        };
90
91        // Parse the graph pattern
92        let (nodes, edges) = self.parse_graph_pattern(match_part)?;
93        
94        // Parse WHERE conditions
95        let where_conditions = if let Some(where_str) = where_part {
96            self.parse_where_conditions(where_str)?
97        } else {
98            Vec::new()
99        };
100
101        Ok(GqlPattern {
102            nodes,
103            edges,
104            where_conditions,
105        })
106    }
107
108    /// Parse the graph pattern part: (a:User)-[r:FOLLOWS]->(b:User)
109    fn parse_graph_pattern(&self, pattern: &str) -> Result<(Vec<GqlNode>, Vec<GqlEdge>)> {
110        let mut nodes = Vec::new();
111        let mut edges = Vec::new();
112        
113        // Simple regex-like parsing for basic patterns
114        // This is a simplified implementation - a full parser would use a proper grammar
115        
116        let pattern = pattern.trim();
117        
118        // Example pattern: (a:User)-[r:FOLLOWS]->(b:User)
119        // We'll parse this step by step
120        
121        if pattern.starts_with('(') {
122            // Find node patterns and edge patterns
123            let chars: Vec<char> = pattern.chars().collect();
124            let mut i = 0;
125            
126            while i < chars.len() {
127                if chars[i] == '(' {
128                    // Parse node
129                    let (node, end_pos) = self.parse_node_pattern(&chars, i)?;
130                    nodes.push(node);
131                    i = end_pos + 1;
132                } else if chars[i] == '-' || chars[i] == '<' {
133                    // Parse edge
134                    let (edge, end_pos) = self.parse_edge_pattern(&chars, i)?;
135                    edges.push(edge);
136                    i = end_pos + 1;
137                } else {
138                    i += 1;
139                }
140            }
141        } else {
142            return Err(GraphError::sql_parsing("Invalid GQL pattern: must start with node"));
143        }
144
145        Ok((nodes, edges))
146    }
147
148    /// Parse a node pattern: (a:User {name: 'Alice'})
149    fn parse_node_pattern(&self, chars: &[char], start: usize) -> Result<(GqlNode, usize)> {
150        if start >= chars.len() || chars[start] != '(' {
151            return Err(GraphError::sql_parsing("Node pattern must start with ("));
152        }
153
154        let mut i = start + 1;
155        let mut content = String::new();
156        let mut paren_count = 1;
157
158        // Find the matching closing parenthesis
159        while i < chars.len() && paren_count > 0 {
160            if chars[i] == '(' {
161                paren_count += 1;
162            } else if chars[i] == ')' {
163                paren_count -= 1;
164            }
165            
166            if paren_count > 0 {
167                content.push(chars[i]);
168            }
169            i += 1;
170        }
171
172        if paren_count > 0 {
173            return Err(GraphError::sql_parsing("Unclosed node pattern"));
174        }
175
176        // Parse node content: variable:label {properties}
177        let (variable, label, properties) = self.parse_node_content(&content)?;
178
179        Ok((GqlNode {
180            variable,
181            label,
182            properties,
183        }, i - 1))
184    }
185
186    /// Parse node content: a:User {name: 'Alice'}
187    fn parse_node_content(&self, content: &str) -> Result<(String, Option<String>, HashMap<String, String>)> {
188        let content = content.trim();
189        
190        // Split by { to separate variable:label from properties
191        let (var_label_part, properties_part) = if content.contains('{') {
192            let parts: Vec<&str> = content.splitn(2, '{').collect();
193            (parts[0].trim(), Some(parts[1].trim_end_matches('}').trim()))
194        } else {
195            (content, None)
196        };
197
198        // Parse variable and label
199        let (variable, label) = if var_label_part.contains(':') {
200            let parts: Vec<&str> = var_label_part.splitn(2, ':').collect();
201            (parts[0].trim().to_string(), Some(parts[1].trim().to_string()))
202        } else {
203            (var_label_part.to_string(), None)
204        };
205
206        // Parse properties (simplified)
207        let properties = if let Some(props_str) = properties_part {
208            self.parse_properties(props_str)?
209        } else {
210            HashMap::new()
211        };
212
213        Ok((variable, label, properties))
214    }
215
216    /// Parse edge pattern: -[r:FOLLOWS]->
217    fn parse_edge_pattern(&self, chars: &[char], start: usize) -> Result<(GqlEdge, usize)> {
218        let mut i = start;
219        let mut direction = EdgeDirection::Undirected;
220        let mut edge_content = String::new();
221
222        // Determine direction
223        if i < chars.len() && chars[i] == '<' {
224            direction = EdgeDirection::Incoming;
225            i += 1;
226        }
227
228        // Skip dashes
229        while i < chars.len() && chars[i] == '-' {
230            i += 1;
231        }
232
233        // Parse edge content if present
234        if i < chars.len() && chars[i] == '[' {
235            i += 1; // Skip [
236            while i < chars.len() && chars[i] != ']' {
237                edge_content.push(chars[i]);
238                i += 1;
239            }
240            if i < chars.len() {
241                i += 1; // Skip ]
242            }
243        }
244
245        // Skip more dashes
246        while i < chars.len() && chars[i] == '-' {
247            i += 1;
248        }
249
250        // Check for outgoing direction
251        if i < chars.len() && chars[i] == '>' {
252            if direction == EdgeDirection::Incoming {
253                return Err(GraphError::sql_parsing("Invalid edge direction: cannot be both incoming and outgoing"));
254            }
255            direction = EdgeDirection::Outgoing;
256            i += 1;
257        }
258
259        // Parse edge content
260        let (variable, label, properties) = if edge_content.is_empty() {
261            (None, None, HashMap::new())
262        } else {
263            let (var, lab, props) = self.parse_node_content(&edge_content)?;
264            (Some(var), lab, props)
265        };
266
267        Ok((GqlEdge {
268            variable,
269            label,
270            direction,
271            properties,
272        }, i - 1))
273    }
274
275    /// Parse WHERE conditions: a.name = 'Alice' AND b.age > 25
276    fn parse_where_conditions(&self, where_str: &str) -> Result<Vec<WhereCondition>> {
277        let mut conditions = Vec::new();
278        
279        // Split by AND/OR (simplified - just AND for now)
280        let condition_parts: Vec<&str> = where_str.split(" AND ").collect();
281        
282        for condition_str in condition_parts {
283            let condition = self.parse_single_condition(condition_str.trim())?;
284            conditions.push(condition);
285        }
286
287        Ok(conditions)
288    }
289
290    /// Parse a single condition: a.name = 'Alice'
291    fn parse_single_condition(&self, condition: &str) -> Result<WhereCondition> {
292        // Find the operator
293        let operators = [">=", "<=", "!=", "=", ">", "<"];
294        
295        for op_str in &operators {
296            if condition.contains(op_str) {
297                let parts: Vec<&str> = condition.splitn(2, op_str).collect();
298                if parts.len() == 2 {
299                    let left = parts[0].trim();
300                    let right = parts[1].trim().trim_matches('\'').trim_matches('"');
301                    
302                    // Parse left side (variable.property)
303                    if let Some(dot_pos) = left.find('.') {
304                        let variable = left[..dot_pos].to_string();
305                        let property = left[dot_pos + 1..].to_string();
306                        
307                        let operator = match *op_str {
308                            "=" => ComparisonOperator::Equal,
309                            "!=" => ComparisonOperator::NotEqual,
310                            ">" => ComparisonOperator::GreaterThan,
311                            "<" => ComparisonOperator::LessThan,
312                            ">=" => ComparisonOperator::GreaterThanOrEqual,
313                            "<=" => ComparisonOperator::LessThanOrEqual,
314                            _ => return Err(GraphError::sql_parsing(&format!("Unknown operator: {}", op_str))),
315                        };
316                        
317                        return Ok(WhereCondition {
318                            variable,
319                            property,
320                            operator,
321                            value: right.to_string(),
322                        });
323                    }
324                }
325            }
326        }
327
328        Err(GraphError::sql_parsing(&format!("Invalid WHERE condition: {}", condition)))
329    }
330
331    /// Parse properties: name: 'Alice', age: 25
332    fn parse_properties(&self, props_str: &str) -> Result<HashMap<String, String>> {
333        let mut properties = HashMap::new();
334        
335        // Split by comma
336        let prop_parts: Vec<&str> = props_str.split(',').collect();
337        
338        for prop_str in prop_parts {
339            let prop_str = prop_str.trim();
340            if prop_str.contains(':') {
341                let parts: Vec<&str> = prop_str.splitn(2, ':').collect();
342                if parts.len() == 2 {
343                    let key = parts[0].trim().to_string();
344                    let value = parts[1].trim().trim_matches('\'').trim_matches('"').to_string();
345                    properties.insert(key, value);
346                }
347            }
348        }
349
350        Ok(properties)
351    }
352}
353
354impl Default for GqlParser {
355    fn default() -> Self {
356        Self::new()
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `query`, panicking with the query text if the parser rejects it.
    fn parse(query: &str) -> GqlPattern {
        GqlParser::new()
            .parse_match_pattern(query)
            .unwrap_or_else(|e| panic!("failed to parse {:?}: {:?}", query, e))
    }

    #[test]
    fn test_parse_simple_node_pattern() {
        let result = parse("MATCH (a:User)");

        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].variable, "a");
        assert_eq!(result.nodes[0].label.as_deref(), Some("User"));
        assert!(result.edges.is_empty());
        assert!(result.where_conditions.is_empty());
    }

    #[test]
    fn test_parse_simple_edge_pattern() {
        let result = parse("MATCH (a)-[r]->(b)");

        assert_eq!(result.nodes.len(), 2);
        assert_eq!(result.edges.len(), 1);
        assert_eq!(result.edges[0].variable.as_deref(), Some("r"));
        assert_eq!(result.edges[0].direction, EdgeDirection::Outgoing);
    }

    #[test]
    fn test_parse_incoming_edge() {
        let result = parse("MATCH (a)<-[r:KNOWS]-(b)");

        assert_eq!(result.nodes.len(), 2);
        assert_eq!(result.edges.len(), 1);
        assert_eq!(result.edges[0].variable.as_deref(), Some("r"));
        assert_eq!(result.edges[0].label.as_deref(), Some("KNOWS"));
        assert_eq!(result.edges[0].direction, EdgeDirection::Incoming);
    }

    #[test]
    fn test_parse_undirected_edge() {
        let result = parse("MATCH (a)-[r]-(b)");

        assert_eq!(result.edges.len(), 1);
        assert_eq!(result.edges[0].direction, EdgeDirection::Undirected);
    }

    #[test]
    fn test_parse_anonymous_arrow() {
        let result = parse("MATCH (a)-->(b)");

        assert_eq!(result.nodes.len(), 2);
        assert_eq!(result.edges.len(), 1);
        assert_eq!(result.edges[0].variable, None);
        assert_eq!(result.edges[0].label, None);
        assert_eq!(result.edges[0].direction, EdgeDirection::Outgoing);
    }

    #[test]
    fn test_parse_node_properties() {
        let result = parse("MATCH (a:User {name: 'Alice', age: 30})");

        let props = &result.nodes[0].properties;
        assert_eq!(props.len(), 2);
        assert_eq!(props.get("name").map(String::as_str), Some("Alice"));
        assert_eq!(props.get("age").map(String::as_str), Some("30"));
    }

    #[test]
    fn test_parse_with_where_condition() {
        let result = parse("MATCH (a:User) WHERE a.name = 'Alice'");

        assert_eq!(result.where_conditions.len(), 1);
        let cond = &result.where_conditions[0];
        assert_eq!(cond.variable, "a");
        assert_eq!(cond.property, "name");
        assert_eq!(cond.value, "Alice");
        assert_eq!(cond.operator, ComparisonOperator::Equal);
    }

    #[test]
    fn test_parse_double_quoted_value() {
        let result = parse("MATCH (a) WHERE a.name = \"Bob\"");

        assert_eq!(result.where_conditions.len(), 1);
        assert_eq!(result.where_conditions[0].value, "Bob");
    }

    #[test]
    fn test_parse_comparison_operators() {
        let cases = vec![
            ("a.age >= 18", ComparisonOperator::GreaterThanOrEqual),
            ("a.age <= 18", ComparisonOperator::LessThanOrEqual),
            ("a.age != 18", ComparisonOperator::NotEqual),
            ("a.age > 18", ComparisonOperator::GreaterThan),
            ("a.age < 18", ComparisonOperator::LessThan),
        ];

        for (condition, expected) in cases {
            let result = parse(&format!("MATCH (a) WHERE {}", condition));
            assert_eq!(result.where_conditions.len(), 1, "{}", condition);
            assert_eq!(result.where_conditions[0].operator, expected, "{}", condition);
            assert_eq!(result.where_conditions[0].property, "age", "{}", condition);
            assert_eq!(result.where_conditions[0].value, "18", "{}", condition);
        }
    }

    #[test]
    fn test_parse_complex_pattern() {
        let result =
            parse("MATCH (a:User)-[r:FOLLOWS]->(b:User) WHERE a.name = 'Alice' AND b.age > 25");

        assert_eq!(result.nodes.len(), 2);
        assert_eq!(result.edges.len(), 1);
        assert_eq!(result.where_conditions.len(), 2);

        // Nodes
        assert_eq!(result.nodes[0].variable, "a");
        assert_eq!(result.nodes[0].label.as_deref(), Some("User"));
        assert_eq!(result.nodes[1].variable, "b");
        assert_eq!(result.nodes[1].label.as_deref(), Some("User"));

        // Edge
        let edge = &result.edges[0];
        assert_eq!(edge.variable.as_deref(), Some("r"));
        assert_eq!(edge.label.as_deref(), Some("FOLLOWS"));
        assert_eq!(edge.direction, EdgeDirection::Outgoing);
    }

    #[test]
    fn test_parse_invalid_pattern() {
        assert!(GqlParser::new().parse_match_pattern("INVALID PATTERN").is_err());
    }

    #[test]
    fn test_parse_unclosed_node_is_error() {
        assert!(GqlParser::new().parse_match_pattern("MATCH (a:User").is_err());
    }
}