1use std::collections::HashMap;
2use crate::error::{GraphError, Result};
3
/// Stateless parser for a small subset of GQL `MATCH` queries.
///
/// The parser carries no configuration, so a single value can be reused for any
/// number of queries.
#[derive(Clone, Debug)]
pub struct GqlParser;
8
/// A node pattern such as `(a:User {name: 'Alice'})`.
#[derive(Clone, Debug, PartialEq)]
pub struct GqlNode {
    /// Binding name written before any `:`; empty for anonymous nodes like `(:User)`.
    pub variable: String,
    /// Label written after the first `:`, if any.
    pub label: Option<String>,
    /// Entries of the `{key: value}` map, with surrounding quotes stripped from values.
    pub properties: HashMap<String, String>,
}
16
17#[derive(Debug, Clone, PartialEq)]
19pub struct GqlEdge {
20 pub variable: Option<String>,
21 pub label: Option<String>,
22 pub direction: EdgeDirection,
23 pub properties: HashMap<String, String>,
24}
25
/// Direction of an edge relative to the left-to-right order of its endpoints in the
/// pattern.
///
/// This is a fieldless enum, so it is `Copy` and can be used as a hash-map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeDirection {
    /// `-[...]->`: arrow head on the right.
    Outgoing,
    /// `<-[...]-`: arrow head on the left.
    Incoming,
    /// `-[...]-`: no arrow head.
    Undirected,
}
33
34#[derive(Debug, Clone)]
36pub struct GqlPattern {
37 pub nodes: Vec<GqlNode>,
38 pub edges: Vec<GqlEdge>,
39 pub where_conditions: Vec<WhereCondition>,
40}
41
42#[derive(Debug, Clone, PartialEq)]
44pub struct WhereCondition {
45 pub variable: String,
46 pub property: String,
47 pub operator: ComparisonOperator,
48 pub value: String,
49}
50
/// Comparison operator of a `WHERE` condition.
///
/// This is a fieldless enum, so it is `Copy` and can be used as a hash-map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComparisonOperator {
    /// `=`
    Equal,
    /// `!=`
    NotEqual,
    /// `>`
    GreaterThan,
    /// `<`
    LessThan,
    /// `>=`
    GreaterThanOrEqual,
    /// `<=`
    LessThanOrEqual,
    /// Containment test; how containment is evaluated is left to the consumer of the
    /// parsed condition.
    Contains,
}
61
62impl GqlParser {
63 pub fn new() -> Self {
64 Self
65 }
66
67 pub fn parse_match_pattern(&self, pattern: &str) -> Result<GqlPattern> {
70 let pattern = pattern.trim();
71
72 let (match_part, where_part) = if pattern.to_uppercase().contains("WHERE") {
74 let parts: Vec<&str> = pattern.splitn(2, "WHERE").collect();
75 if parts.len() == 2 {
76 (parts[0].trim(), Some(parts[1].trim()))
77 } else {
78 (pattern, None)
79 }
80 } else {
81 (pattern, None)
82 };
83
84 let match_part = if match_part.to_uppercase().starts_with("MATCH") {
86 match_part[5..].trim()
87 } else {
88 match_part
89 };
90
91 let (nodes, edges) = self.parse_graph_pattern(match_part)?;
93
94 let where_conditions = if let Some(where_str) = where_part {
96 self.parse_where_conditions(where_str)?
97 } else {
98 Vec::new()
99 };
100
101 Ok(GqlPattern {
102 nodes,
103 edges,
104 where_conditions,
105 })
106 }
107
108 fn parse_graph_pattern(&self, pattern: &str) -> Result<(Vec<GqlNode>, Vec<GqlEdge>)> {
110 let mut nodes = Vec::new();
111 let mut edges = Vec::new();
112
113 let pattern = pattern.trim();
117
118 if pattern.starts_with('(') {
122 let chars: Vec<char> = pattern.chars().collect();
124 let mut i = 0;
125
126 while i < chars.len() {
127 if chars[i] == '(' {
128 let (node, end_pos) = self.parse_node_pattern(&chars, i)?;
130 nodes.push(node);
131 i = end_pos + 1;
132 } else if chars[i] == '-' || chars[i] == '<' {
133 let (edge, end_pos) = self.parse_edge_pattern(&chars, i)?;
135 edges.push(edge);
136 i = end_pos + 1;
137 } else {
138 i += 1;
139 }
140 }
141 } else {
142 return Err(GraphError::sql_parsing("Invalid GQL pattern: must start with node"));
143 }
144
145 Ok((nodes, edges))
146 }
147
148 fn parse_node_pattern(&self, chars: &[char], start: usize) -> Result<(GqlNode, usize)> {
150 if start >= chars.len() || chars[start] != '(' {
151 return Err(GraphError::sql_parsing("Node pattern must start with ("));
152 }
153
154 let mut i = start + 1;
155 let mut content = String::new();
156 let mut paren_count = 1;
157
158 while i < chars.len() && paren_count > 0 {
160 if chars[i] == '(' {
161 paren_count += 1;
162 } else if chars[i] == ')' {
163 paren_count -= 1;
164 }
165
166 if paren_count > 0 {
167 content.push(chars[i]);
168 }
169 i += 1;
170 }
171
172 if paren_count > 0 {
173 return Err(GraphError::sql_parsing("Unclosed node pattern"));
174 }
175
176 let (variable, label, properties) = self.parse_node_content(&content)?;
178
179 Ok((GqlNode {
180 variable,
181 label,
182 properties,
183 }, i - 1))
184 }
185
186 fn parse_node_content(&self, content: &str) -> Result<(String, Option<String>, HashMap<String, String>)> {
188 let content = content.trim();
189
190 let (var_label_part, properties_part) = if content.contains('{') {
192 let parts: Vec<&str> = content.splitn(2, '{').collect();
193 (parts[0].trim(), Some(parts[1].trim_end_matches('}').trim()))
194 } else {
195 (content, None)
196 };
197
198 let (variable, label) = if var_label_part.contains(':') {
200 let parts: Vec<&str> = var_label_part.splitn(2, ':').collect();
201 (parts[0].trim().to_string(), Some(parts[1].trim().to_string()))
202 } else {
203 (var_label_part.to_string(), None)
204 };
205
206 let properties = if let Some(props_str) = properties_part {
208 self.parse_properties(props_str)?
209 } else {
210 HashMap::new()
211 };
212
213 Ok((variable, label, properties))
214 }
215
216 fn parse_edge_pattern(&self, chars: &[char], start: usize) -> Result<(GqlEdge, usize)> {
218 let mut i = start;
219 let mut direction = EdgeDirection::Undirected;
220 let mut edge_content = String::new();
221
222 if i < chars.len() && chars[i] == '<' {
224 direction = EdgeDirection::Incoming;
225 i += 1;
226 }
227
228 while i < chars.len() && chars[i] == '-' {
230 i += 1;
231 }
232
233 if i < chars.len() && chars[i] == '[' {
235 i += 1; while i < chars.len() && chars[i] != ']' {
237 edge_content.push(chars[i]);
238 i += 1;
239 }
240 if i < chars.len() {
241 i += 1; }
243 }
244
245 while i < chars.len() && chars[i] == '-' {
247 i += 1;
248 }
249
250 if i < chars.len() && chars[i] == '>' {
252 if direction == EdgeDirection::Incoming {
253 return Err(GraphError::sql_parsing("Invalid edge direction: cannot be both incoming and outgoing"));
254 }
255 direction = EdgeDirection::Outgoing;
256 i += 1;
257 }
258
259 let (variable, label, properties) = if edge_content.is_empty() {
261 (None, None, HashMap::new())
262 } else {
263 let (var, lab, props) = self.parse_node_content(&edge_content)?;
264 (Some(var), lab, props)
265 };
266
267 Ok((GqlEdge {
268 variable,
269 label,
270 direction,
271 properties,
272 }, i - 1))
273 }
274
275 fn parse_where_conditions(&self, where_str: &str) -> Result<Vec<WhereCondition>> {
277 let mut conditions = Vec::new();
278
279 let condition_parts: Vec<&str> = where_str.split(" AND ").collect();
281
282 for condition_str in condition_parts {
283 let condition = self.parse_single_condition(condition_str.trim())?;
284 conditions.push(condition);
285 }
286
287 Ok(conditions)
288 }
289
290 fn parse_single_condition(&self, condition: &str) -> Result<WhereCondition> {
292 let operators = [">=", "<=", "!=", "=", ">", "<"];
294
295 for op_str in &operators {
296 if condition.contains(op_str) {
297 let parts: Vec<&str> = condition.splitn(2, op_str).collect();
298 if parts.len() == 2 {
299 let left = parts[0].trim();
300 let right = parts[1].trim().trim_matches('\'').trim_matches('"');
301
302 if let Some(dot_pos) = left.find('.') {
304 let variable = left[..dot_pos].to_string();
305 let property = left[dot_pos + 1..].to_string();
306
307 let operator = match *op_str {
308 "=" => ComparisonOperator::Equal,
309 "!=" => ComparisonOperator::NotEqual,
310 ">" => ComparisonOperator::GreaterThan,
311 "<" => ComparisonOperator::LessThan,
312 ">=" => ComparisonOperator::GreaterThanOrEqual,
313 "<=" => ComparisonOperator::LessThanOrEqual,
314 _ => return Err(GraphError::sql_parsing(&format!("Unknown operator: {}", op_str))),
315 };
316
317 return Ok(WhereCondition {
318 variable,
319 property,
320 operator,
321 value: right.to_string(),
322 });
323 }
324 }
325 }
326 }
327
328 Err(GraphError::sql_parsing(&format!("Invalid WHERE condition: {}", condition)))
329 }
330
331 fn parse_properties(&self, props_str: &str) -> Result<HashMap<String, String>> {
333 let mut properties = HashMap::new();
334
335 let prop_parts: Vec<&str> = props_str.split(',').collect();
337
338 for prop_str in prop_parts {
339 let prop_str = prop_str.trim();
340 if prop_str.contains(':') {
341 let parts: Vec<&str> = prop_str.splitn(2, ':').collect();
342 if parts.len() == 2 {
343 let key = parts[0].trim().to_string();
344 let value = parts[1].trim().trim_matches('\'').trim_matches('"').to_string();
345 properties.insert(key, value);
346 }
347 }
348 }
349
350 Ok(properties)
351 }
352}
353
354impl Default for GqlParser {
355 fn default() -> Self {
356 Self::new()
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `query` with a fresh parser and unwraps the result.
    fn parse_ok(query: &str) -> GqlPattern {
        GqlParser::new().parse_match_pattern(query).unwrap()
    }

    #[test]
    fn test_parse_simple_node_pattern() {
        let parsed = parse_ok("MATCH (a:User)");

        assert_eq!(parsed.nodes.len(), 1);
        let node = &parsed.nodes[0];
        assert_eq!(node.variable, "a");
        assert_eq!(node.label.as_deref(), Some("User"));
        assert!(parsed.edges.is_empty() && parsed.where_conditions.is_empty());
    }

    #[test]
    fn test_parse_simple_edge_pattern() {
        let parsed = parse_ok("MATCH (a)-[r]->(b)");

        assert_eq!((parsed.nodes.len(), parsed.edges.len()), (2, 1));
        let edge = &parsed.edges[0];
        assert_eq!(edge.variable.as_deref(), Some("r"));
        assert_eq!(edge.direction, EdgeDirection::Outgoing);
    }

    #[test]
    fn test_parse_with_where_condition() {
        let parsed = parse_ok("MATCH (a:User) WHERE a.name = 'Alice'");

        let expected = WhereCondition {
            variable: "a".to_string(),
            property: "name".to_string(),
            operator: ComparisonOperator::Equal,
            value: "Alice".to_string(),
        };
        assert_eq!(parsed.where_conditions, vec![expected]);
    }

    #[test]
    fn test_parse_complex_pattern() {
        let parsed = parse_ok(
            "MATCH (a:User)-[r:FOLLOWS]->(b:User) WHERE a.name = 'Alice' AND b.age > 25",
        );

        assert_eq!(parsed.where_conditions.len(), 2);

        let node_summary: Vec<(&str, Option<&str>)> = parsed
            .nodes
            .iter()
            .map(|node| (node.variable.as_str(), node.label.as_deref()))
            .collect();
        assert_eq!(node_summary, [("a", Some("User")), ("b", Some("User"))]);

        assert_eq!(parsed.edges.len(), 1);
        let edge = &parsed.edges[0];
        assert_eq!(edge.variable.as_deref(), Some("r"));
        assert_eq!(edge.label.as_deref(), Some("FOLLOWS"));
        assert_eq!(edge.direction, EdgeDirection::Outgoing);
    }

    #[test]
    fn test_parse_invalid_pattern() {
        assert!(GqlParser::new()
            .parse_match_pattern("INVALID PATTERN")
            .is_err());
    }
}