1use crate::ast::{
49 BinaryOp, CreateCollectionStatement, DataType, DistanceMetric, DropCollectionStatement,
50 EdgeDirection, EdgeLength, EdgePattern, Expr, GraphPattern, Identifier, MatchStatement,
51 NodePattern, OrderByExpr, ParameterRef, PathPattern, PayloadFieldDef, PropertyCondition,
52 QualifiedName, ReturnItem, SelectStatement, ShortestPathPattern, Statement, VectorDef,
53 VectorTypeDef, WeightSpec,
54};
55use crate::error::{ParseError, ParseResult};
56use crate::parser::sql;
57
/// Parser for the extended query dialect.
///
/// Standalone `MATCH ... RETURN ...` graph queries and `CREATE` / `DROP COLLECTION`
/// statements are parsed directly. Any other input is treated as SQL: embedded
/// `MATCH` / `OPTIONAL MATCH` clauses are lifted out of the text before it is handed
/// to the SQL parser, and re-attached to the resulting statements afterwards.
pub struct ExtendedParser;
60
61impl ExtendedParser {
62 pub fn parse(input: &str) -> ParseResult<Vec<Statement>> {
75 if input.trim().is_empty() {
76 return Err(ParseError::EmptyQuery);
77 }
78
79 if Self::is_standalone_match(input) {
81 return Self::parse_standalone_match(input);
82 }
83
84 if Self::is_create_collection(input) {
86 return Self::parse_create_collection(input);
87 }
88
89 if Self::is_drop_collection(input) {
91 return Self::parse_drop_collection(input);
92 }
93
94 let (sql_without_match, match_patterns, optional_patterns) =
96 Self::extract_match_clauses(input)?;
97
98 let preprocessed = Self::preprocess_vector_ops(&sql_without_match);
100
101 let mut statements = sql::parse_sql(&preprocessed)?;
103
104 for (i, stmt) in statements.iter_mut().enumerate() {
106 Self::restore_vector_ops(stmt);
107 if let Some(pattern) = match_patterns.get(i) {
108 Self::add_match_clause(stmt, pattern.clone());
109 }
110 if let Some(opt_patterns) = optional_patterns.get(i) {
111 for pattern in opt_patterns {
112 Self::add_optional_match_clause(stmt, pattern.clone());
113 }
114 }
115 }
116
117 Ok(statements)
118 }
119
120 fn is_standalone_match(input: &str) -> bool {
124 let trimmed = input.trim();
125 let upper = trimmed.to_uppercase();
126
127 if !upper.starts_with("MATCH") {
129 return false;
130 }
131
132 upper.contains("RETURN")
134 }
135
136 fn is_create_collection(input: &str) -> bool {
138 let upper = input.trim().to_uppercase();
139 upper.starts_with("CREATE COLLECTION")
140 || upper.starts_with("CREATE IF NOT EXISTS COLLECTION")
141 }
142
143 fn is_drop_collection(input: &str) -> bool {
145 let upper = input.trim().to_uppercase();
146 upper.starts_with("DROP COLLECTION") || upper.starts_with("DROP IF EXISTS COLLECTION")
147 }
148
149 fn parse_drop_collection(input: &str) -> ParseResult<Vec<Statement>> {
156 let input = input.trim().trim_end_matches(';');
157 let upper = input.to_uppercase();
158
159 let (if_exists, after_if_exists) = if upper.starts_with("DROP IF EXISTS") {
161 (true, input[14..].trim_start()) } else if upper.starts_with("DROP") {
163 (false, input[4..].trim_start()) } else {
165 return Err(ParseError::SqlSyntax("expected DROP keyword".to_string()));
166 };
167
168 let upper_rest = after_if_exists.to_uppercase();
170 if !upper_rest.starts_with("COLLECTION") {
171 return Err(ParseError::SqlSyntax(
172 "expected COLLECTION keyword after DROP".to_string(),
173 ));
174 }
175 let after_collection = after_if_exists[10..].trim_start(); let names: Vec<Identifier> =
179 after_collection.split(',').map(|s| Identifier::new(s.trim())).collect();
180
181 if names.is_empty() || names.iter().any(|n| n.name.is_empty()) {
182 return Err(ParseError::SqlSyntax("expected collection name(s)".to_string()));
183 }
184
185 let mut stmt = DropCollectionStatement::new(names);
186 if if_exists {
187 stmt = stmt.if_exists();
188 }
189
190 Ok(vec![Statement::DropCollection(stmt)])
191 }
192
193 fn parse_create_collection(input: &str) -> ParseResult<Vec<Statement>> {
209 let input = input.trim();
210 let upper = input.to_uppercase();
211
212 if !upper.starts_with("CREATE") {
214 return Err(ParseError::SqlSyntax("expected CREATE keyword".to_string()));
215 }
216 let after_create = input[6..].trim_start();
217 let upper_after_create = after_create.to_uppercase();
218
219 let (if_not_exists, after_if_not_exists) =
221 if upper_after_create.starts_with("IF NOT EXISTS") {
222 (true, after_create[13..].trim_start())
223 } else {
224 (false, after_create)
225 };
226
227 let upper_rest = after_if_not_exists.to_uppercase();
229 if !upper_rest.starts_with("COLLECTION") {
230 return Err(ParseError::SqlSyntax(
231 "expected COLLECTION keyword after CREATE".to_string(),
232 ));
233 }
234 let after_collection = after_if_not_exists[10..].trim_start();
235
236 let name_end = after_collection
238 .find(|c: char| c == '(' || c.is_whitespace())
239 .unwrap_or(after_collection.len());
240 let collection_name = &after_collection[..name_end];
241 if collection_name.is_empty() {
242 return Err(ParseError::SqlSyntax("expected collection name".to_string()));
243 }
244 let name = Identifier::new(collection_name.trim());
245
246 let after_name = after_collection[name_end..].trim_start();
248 if !after_name.starts_with('(') {
249 return Err(ParseError::SqlSyntax("expected '(' after collection name".to_string()));
250 }
251
252 let close_paren = Self::find_matching_paren(after_name, 0).ok_or_else(|| {
254 ParseError::SqlSyntax("unclosed parenthesis in CREATE COLLECTION".to_string())
255 })?;
256
257 let defs_str = &after_name[1..close_paren];
258
259 let (vectors, payload_fields) = Self::parse_collection_definitions(defs_str)?;
261
262 if vectors.is_empty() {
263 return Err(ParseError::SqlSyntax(
264 "CREATE COLLECTION requires at least one vector definition".to_string(),
265 ));
266 }
267
268 let stmt = CreateCollectionStatement { if_not_exists, name, vectors, payload_fields };
269
270 Ok(vec![Statement::CreateCollection(Box::new(stmt))])
271 }
272
273 fn parse_collection_definitions(
279 input: &str,
280 ) -> ParseResult<(Vec<VectorDef>, Vec<PayloadFieldDef>)> {
281 let input = input.trim();
282 if input.is_empty() {
283 return Ok((vec![], vec![]));
284 }
285
286 let mut vectors = Vec::new();
287 let mut payload_fields = Vec::new();
288 let mut current = String::new();
289 let mut paren_depth: i32 = 0;
290
291 for c in input.chars() {
293 match c {
294 '(' => {
295 paren_depth += 1;
296 current.push(c);
297 }
298 ')' => {
299 paren_depth = paren_depth.saturating_sub(1);
300 current.push(c);
301 }
302 ',' if paren_depth == 0 => {
303 if !current.trim().is_empty() {
304 Self::parse_collection_item(
305 current.trim(),
306 &mut vectors,
307 &mut payload_fields,
308 )?;
309 }
310 current.clear();
311 }
312 _ => current.push(c),
313 }
314 }
315
316 if !current.trim().is_empty() {
318 Self::parse_collection_item(current.trim(), &mut vectors, &mut payload_fields)?;
319 }
320
321 Ok((vectors, payload_fields))
322 }
323
324 fn parse_collection_item(
326 input: &str,
327 vectors: &mut Vec<VectorDef>,
328 payload_fields: &mut Vec<PayloadFieldDef>,
329 ) -> ParseResult<()> {
330 let input = input.trim();
331 let upper = input.to_uppercase();
332
333 if upper.starts_with("VECTOR ") {
335 let vector = Self::parse_new_style_vector_def(input)?;
336 vectors.push(vector);
337 return Ok(());
338 }
339
340 if Self::is_legacy_vector_def(input) {
342 let vector = Self::parse_single_vector_def(input)?;
343 vectors.push(vector);
344 return Ok(());
345 }
346
347 let field = Self::parse_payload_field(input)?;
349 payload_fields.push(field);
350 Ok(())
351 }
352
353 fn is_legacy_vector_def(input: &str) -> bool {
355 let upper = input.to_uppercase();
356 let name_end = input.find(|c: char| c.is_whitespace()).unwrap_or(input.len());
358 let after_name = upper[name_end..].trim_start();
359 after_name.starts_with("VECTOR")
360 || after_name.starts_with("SPARSE_VECTOR")
361 || after_name.starts_with("MULTI_VECTOR")
362 || after_name.starts_with("BINARY_VECTOR")
363 }
364
365 fn parse_new_style_vector_def(input: &str) -> ParseResult<VectorDef> {
367 let input = input.trim();
368 let upper = input.to_uppercase();
369
370 if !upper.starts_with("VECTOR ") {
372 return Err(ParseError::SqlSyntax("expected VECTOR keyword".to_string()));
373 }
374 let after_vector = input[7..].trim_start();
375
376 let name_end = after_vector.find(|c: char| c.is_whitespace()).unwrap_or(after_vector.len());
378 let name_str = &after_vector[..name_end];
379 if name_str.is_empty() {
380 return Err(ParseError::SqlSyntax("expected vector name after VECTOR".to_string()));
381 }
382 let name = Identifier::new(name_str);
383
384 let after_name = after_vector[name_end..].trim_start();
385 let upper_after_name = after_name.to_uppercase();
386
387 if !upper_after_name.starts_with("DIMENSION") {
389 return Err(ParseError::SqlSyntax(
390 "expected DIMENSION keyword after vector name".to_string(),
391 ));
392 }
393 let after_dimension = after_name[9..].trim_start();
394
395 let dim_end =
397 after_dimension.find(|c: char| c.is_whitespace()).unwrap_or(after_dimension.len());
398 let dim_str = &after_dimension[..dim_end];
399 let dimension = dim_str
400 .trim()
401 .parse::<u32>()
402 .map_err(|_| ParseError::SqlSyntax(format!("invalid DIMENSION value: {dim_str}")))?;
403
404 let vector_type = VectorTypeDef::Vector { dimension };
405
406 let rest = after_dimension[dim_end..].trim_start();
408 let upper_rest = rest.to_uppercase();
409
410 let (using_method, after_using) = if upper_rest.starts_with("USING") {
411 let after_using_kw = rest[5..].trim_start();
412 let method_end =
413 after_using_kw.find(|c: char| c.is_whitespace()).unwrap_or(after_using_kw.len());
414 let method = &after_using_kw[..method_end];
415 (Some(method.to_lowercase()), after_using_kw[method_end..].trim_start())
416 } else {
417 (None, rest)
418 };
419
420 let upper_after_using = after_using.to_uppercase();
421 let with_options = if upper_after_using.starts_with("WITH") {
422 let after_with = after_using[4..].trim_start();
423 Self::parse_with_options(after_with)?
424 } else {
425 vec![]
426 };
427
428 Ok(VectorDef { name, vector_type, using: using_method, with_options })
429 }
430
431 fn parse_payload_field(input: &str) -> ParseResult<PayloadFieldDef> {
433 let input = input.trim();
434
435 let name_end = input.find(|c: char| c.is_whitespace()).unwrap_or(input.len());
437 let name_str = &input[..name_end];
438 if name_str.is_empty() {
439 return Err(ParseError::SqlSyntax("expected field name".to_string()));
440 }
441 let name = Identifier::new(name_str);
442
443 let after_name = input[name_end..].trim_start();
444 let upper_after_name = after_name.to_uppercase();
445
446 let (data_type, after_type) = Self::parse_payload_data_type(&upper_after_name, after_name)?;
448
449 let upper_after_type = after_type.to_uppercase();
451 let indexed = upper_after_type.trim().starts_with("INDEXED");
452
453 Ok(PayloadFieldDef { name, data_type, indexed })
454 }
455
456 fn parse_payload_data_type<'a>(
458 upper: &str,
459 original: &'a str,
460 ) -> ParseResult<(DataType, &'a str)> {
461 if upper.starts_with("TEXT") {
462 return Ok((DataType::Text, &original[4..]));
463 }
464 if upper.starts_with("INTEGER") {
465 return Ok((DataType::Integer, &original[7..]));
466 }
467 if upper.starts_with("INT") {
468 return Ok((DataType::Integer, &original[3..]));
469 }
470 if upper.starts_with("BIGINT") {
471 return Ok((DataType::BigInt, &original[6..]));
472 }
473 if upper.starts_with("FLOAT") {
474 return Ok((DataType::Real, &original[5..]));
475 }
476 if upper.starts_with("REAL") {
477 return Ok((DataType::Real, &original[4..]));
478 }
479 if upper.starts_with("DOUBLE") {
480 return Ok((DataType::DoublePrecision, &original[6..]));
481 }
482 if upper.starts_with("BOOLEAN") || upper.starts_with("BOOL") {
483 let len = if upper.starts_with("BOOLEAN") { 7 } else { 4 };
484 return Ok((DataType::Boolean, &original[len..]));
485 }
486 if upper.starts_with("JSON") {
487 return Ok((DataType::Json, &original[4..]));
488 }
489 if upper.starts_with("JSONB") {
490 return Ok((DataType::Jsonb, &original[5..]));
491 }
492 if upper.starts_with("BYTEA") {
493 return Ok((DataType::Bytea, &original[5..]));
494 }
495 if upper.starts_with("TIMESTAMP") {
496 return Ok((DataType::Timestamp, &original[9..]));
497 }
498 if upper.starts_with("DATE") {
499 return Ok((DataType::Date, &original[4..]));
500 }
501 if upper.starts_with("UUID") {
502 return Ok((DataType::Uuid, &original[4..]));
503 }
504 if upper.starts_with("VARCHAR") {
506 let after_varchar = original[7..].trim_start();
507 if after_varchar.starts_with('(') {
508 let close = after_varchar.find(')').ok_or_else(|| {
509 ParseError::SqlSyntax("unclosed parenthesis in VARCHAR".to_string())
510 })?;
511 let len_str = &after_varchar[1..close];
512 let len = len_str.trim().parse::<u32>().map_err(|_| {
513 ParseError::SqlSyntax(format!("invalid VARCHAR length: {len_str}"))
514 })?;
515 return Ok((DataType::Varchar(Some(len)), &after_varchar[close + 1..]));
516 }
517 return Ok((DataType::Varchar(None), &original[7..]));
518 }
519
520 Err(ParseError::SqlSyntax(format!(
521 "expected data type (TEXT, INTEGER, FLOAT, BOOLEAN, JSON, BYTEA, etc.), found: {}",
522 &original[..original.len().min(20)]
523 )))
524 }
525
526 #[allow(dead_code)]
530 fn parse_vector_definitions(input: &str) -> ParseResult<Vec<VectorDef>> {
531 let input = input.trim();
532 if input.is_empty() {
533 return Ok(vec![]);
534 }
535
536 let mut vectors = Vec::new();
537 let mut current = String::new();
538 let mut paren_depth: i32 = 0;
539
540 for c in input.chars() {
542 match c {
543 '(' => {
544 paren_depth += 1;
545 current.push(c);
546 }
547 ')' => {
548 paren_depth = paren_depth.saturating_sub(1);
549 current.push(c);
550 }
551 ',' if paren_depth == 0 => {
552 if !current.trim().is_empty() {
553 vectors.push(Self::parse_single_vector_def(current.trim())?);
554 }
555 current.clear();
556 }
557 _ => current.push(c),
558 }
559 }
560
561 if !current.trim().is_empty() {
563 vectors.push(Self::parse_single_vector_def(current.trim())?);
564 }
565
566 Ok(vectors)
567 }
568
569 fn parse_single_vector_def(input: &str) -> ParseResult<VectorDef> {
573 let input = input.trim();
574
575 let name_end = input.find(|c: char| c.is_whitespace()).unwrap_or(input.len());
577 let name_str = &input[..name_end];
578 if name_str.is_empty() {
579 return Err(ParseError::SqlSyntax("expected vector name in definition".to_string()));
580 }
581 let name = Identifier::new(name_str);
582
583 let after_name = input[name_end..].trim_start();
584 let upper_after_name = after_name.to_uppercase();
585
586 let (vector_type, after_type) = Self::parse_vector_type(&upper_after_name, after_name)?;
588
589 let after_type_trimmed = after_type.trim_start();
591 let upper_after_type = after_type_trimmed.to_uppercase();
592 let (using_method, after_using) = if upper_after_type.starts_with("USING") {
593 let after_using_kw = after_type_trimmed[5..].trim_start();
594 let upper_after_using_kw = after_using_kw.to_uppercase();
595
596 let method_end =
598 if let Some(with_pos) = Self::find_keyword_pos(&upper_after_using_kw, "WITH") {
599 with_pos
600 } else {
601 after_using_kw.len()
602 };
603
604 let method = after_using_kw[..method_end].trim();
605 (Some(method.to_lowercase()), after_using_kw[method_end..].trim_start())
606 } else {
607 (None, after_type_trimmed)
608 };
609
610 let upper_after_using = after_using.to_uppercase();
612 let with_options = if upper_after_using.starts_with("WITH") {
613 let after_with = after_using[4..].trim_start();
614 Self::parse_with_options(after_with)?
615 } else {
616 vec![]
617 };
618
619 Ok(VectorDef { name, vector_type, using: using_method, with_options })
620 }
621
622 fn parse_vector_type<'a>(
624 upper: &str,
625 original: &'a str,
626 ) -> ParseResult<(VectorTypeDef, &'a str)> {
627 if upper.starts_with("VECTOR") {
629 let after_vector = original[6..].trim_start();
630 if after_vector.starts_with('(') {
631 let close = after_vector.find(')').ok_or_else(|| {
632 ParseError::SqlSyntax("unclosed parenthesis in VECTOR type".to_string())
633 })?;
634 let dim_str = &after_vector[1..close];
635 let dimension = dim_str.trim().parse::<u32>().map_err(|_| {
636 ParseError::SqlSyntax(format!("invalid dimension in VECTOR: {dim_str}"))
637 })?;
638 return Ok((VectorTypeDef::Vector { dimension }, &after_vector[close + 1..]));
639 }
640 return Err(ParseError::SqlSyntax(
641 "VECTOR type requires dimension: VECTOR(dim)".to_string(),
642 ));
643 }
644
645 if upper.starts_with("SPARSE_VECTOR") {
647 let after_sparse = original[13..].trim_start();
648 if after_sparse.starts_with('(') {
649 let close = after_sparse.find(')').ok_or_else(|| {
650 ParseError::SqlSyntax("unclosed parenthesis in SPARSE_VECTOR type".to_string())
651 })?;
652 let dim_str = &after_sparse[1..close];
653 let max_dimension = Some(dim_str.trim().parse::<u32>().map_err(|_| {
654 ParseError::SqlSyntax(format!(
655 "invalid max dimension in SPARSE_VECTOR: {dim_str}"
656 ))
657 })?);
658 return Ok((
659 VectorTypeDef::SparseVector { max_dimension },
660 &after_sparse[close + 1..],
661 ));
662 }
663 return Ok((VectorTypeDef::SparseVector { max_dimension: None }, after_sparse));
664 }
665
666 if upper.starts_with("MULTI_VECTOR") {
668 let after_multi = original[12..].trim_start();
669 if after_multi.starts_with('(') {
670 let close = after_multi.find(')').ok_or_else(|| {
671 ParseError::SqlSyntax("unclosed parenthesis in MULTI_VECTOR type".to_string())
672 })?;
673 let dim_str = &after_multi[1..close];
674 let token_dim = dim_str.trim().parse::<u32>().map_err(|_| {
675 ParseError::SqlSyntax(format!("invalid dimension in MULTI_VECTOR: {dim_str}"))
676 })?;
677 return Ok((VectorTypeDef::MultiVector { token_dim }, &after_multi[close + 1..]));
678 }
679 return Err(ParseError::SqlSyntax(
680 "MULTI_VECTOR type requires dimension: MULTI_VECTOR(dim)".to_string(),
681 ));
682 }
683
684 if upper.starts_with("BINARY_VECTOR") {
686 let after_binary = original[13..].trim_start();
687 if after_binary.starts_with('(') {
688 let close = after_binary.find(')').ok_or_else(|| {
689 ParseError::SqlSyntax("unclosed parenthesis in BINARY_VECTOR type".to_string())
690 })?;
691 let bits_str = &after_binary[1..close];
692 let bits = bits_str.trim().parse::<u32>().map_err(|_| {
693 ParseError::SqlSyntax(format!("invalid bits in BINARY_VECTOR: {bits_str}"))
694 })?;
695 return Ok((VectorTypeDef::BinaryVector { bits }, &after_binary[close + 1..]));
696 }
697 return Err(ParseError::SqlSyntax(
698 "BINARY_VECTOR type requires bit count: BINARY_VECTOR(bits)".to_string(),
699 ));
700 }
701
702 Err(ParseError::SqlSyntax(format!(
703 "expected vector type (VECTOR, SPARSE_VECTOR, MULTI_VECTOR, BINARY_VECTOR), found: {}",
704 &original[..original.len().min(20)]
705 )))
706 }
707
708 fn parse_with_options(input: &str) -> ParseResult<Vec<(String, String)>> {
710 let input = input.trim();
711 if !input.starts_with('(') {
712 return Err(ParseError::SqlSyntax("expected '(' after WITH keyword".to_string()));
713 }
714
715 let close = input.find(')').ok_or_else(|| {
716 ParseError::SqlSyntax("unclosed parenthesis in WITH options".to_string())
717 })?;
718
719 let options_str = &input[1..close];
720 let mut options = Vec::new();
721
722 for pair in options_str.split(',') {
724 let pair = pair.trim();
725 if pair.is_empty() {
726 continue;
727 }
728
729 let eq_pos = pair.find('=').ok_or_else(|| {
730 ParseError::SqlSyntax(format!("expected '=' in WITH option: {pair}"))
731 })?;
732
733 let key = pair[..eq_pos].trim().to_lowercase();
734 let value_part = pair[eq_pos + 1..].trim();
735
736 let value = if (value_part.starts_with('\'') && value_part.ends_with('\''))
738 || (value_part.starts_with('"') && value_part.ends_with('"'))
739 {
740 value_part[1..value_part.len() - 1].to_string()
741 } else {
742 value_part.to_string()
743 };
744
745 options.push((key, value));
746 }
747
748 Ok(options)
749 }
750
751 fn parse_standalone_match(input: &str) -> ParseResult<Vec<Statement>> {
763 let input = input.trim();
764 let upper = input.to_uppercase();
765
766 if !upper.starts_with("MATCH") {
768 return Err(ParseError::InvalidPattern("expected MATCH keyword".to_string()));
769 }
770
771 let after_match = input[5..].trim_start();
773
774 let upper_after_match = after_match.to_uppercase();
776
777 let return_pos = Self::find_keyword_pos(&upper_after_match, "RETURN").ok_or_else(|| {
779 ParseError::InvalidPattern("MATCH requires RETURN clause".to_string())
780 })?;
781
782 let where_pos = Self::find_keyword_pos(&upper_after_match[..return_pos], "WHERE");
784
785 let pattern_end = where_pos.unwrap_or(return_pos);
787 let pattern_str = after_match[..pattern_end].trim();
788 let pattern = Self::parse_graph_pattern(pattern_str)?;
789
790 let where_clause = if let Some(wp) = where_pos {
792 let where_content = &after_match[wp + 5..return_pos]; Some(Self::parse_where_expression(where_content.trim())?)
794 } else {
795 None
796 };
797
798 let after_return = after_match[return_pos + 6..].trim_start(); let upper_after_return = after_return.to_uppercase();
801
802 let (distinct, return_content_start) = if upper_after_return.starts_with("DISTINCT") {
804 (true, 8) } else {
806 (false, 0)
807 };
808
809 let return_and_rest = after_return[return_content_start..].trim_start();
810 let upper_return_rest = return_and_rest.to_uppercase();
811
812 let order_by_pos = Self::find_keyword_pos(&upper_return_rest, "ORDER BY");
814 let skip_pos = Self::find_keyword_pos(&upper_return_rest, "SKIP");
815 let limit_pos = Self::find_keyword_pos(&upper_return_rest, "LIMIT");
816
817 let return_items_end = [order_by_pos, skip_pos, limit_pos]
819 .iter()
820 .filter_map(|&p| p)
821 .min()
822 .unwrap_or(return_and_rest.len());
823
824 let return_items_str = return_and_rest[..return_items_end].trim();
826 let return_items = Self::parse_return_items(return_items_str)?;
827
828 let order_by = if let Some(obp) = order_by_pos {
830 let order_end = [skip_pos, limit_pos]
831 .iter()
832 .filter_map(|&p| p)
833 .min()
834 .unwrap_or(return_and_rest.len());
835
836 if order_end > obp + 8 {
837 let order_content = &return_and_rest[obp + 8..order_end]; Self::parse_order_by(order_content.trim())?
839 } else {
840 vec![]
841 }
842 } else {
843 vec![]
844 };
845
846 let skip = if let Some(sp) = skip_pos {
848 let skip_end = limit_pos.unwrap_or(return_and_rest.len());
849 if skip_end > sp + 4 {
850 let skip_content = &return_and_rest[sp + 4..skip_end]; Some(Self::parse_limit_expr(skip_content.trim())?)
852 } else {
853 None
854 }
855 } else {
856 None
857 };
858
859 let limit = if let Some(lp) = limit_pos {
861 let limit_content = &return_and_rest[lp + 5..]; let limit_end = limit_content.find(';').unwrap_or(limit_content.len());
863 let limit_str = limit_content[..limit_end].trim();
864 if limit_str.is_empty() {
865 None
866 } else {
867 Some(Self::parse_limit_expr(limit_str)?)
868 }
869 } else {
870 None
871 };
872
873 let match_stmt = MatchStatement {
875 pattern,
876 where_clause,
877 return_clause: return_items,
878 distinct,
879 order_by,
880 skip,
881 limit,
882 };
883
884 Ok(vec![Statement::Match(Box::new(match_stmt))])
885 }
886
887 fn find_keyword_pos(input: &str, keyword: &str) -> Option<usize> {
889 let mut search_from = 0;
890
891 while let Some(pos) = input[search_from..].find(keyword) {
892 let absolute_pos = search_from + pos;
893
894 let before_ok =
896 absolute_pos == 0 || !input.as_bytes()[absolute_pos - 1].is_ascii_alphanumeric();
897 let after_ok = absolute_pos + keyword.len() >= input.len()
898 || !input.as_bytes()[absolute_pos + keyword.len()].is_ascii_alphanumeric();
899
900 if before_ok && after_ok {
901 return Some(absolute_pos);
902 }
903
904 search_from = absolute_pos + keyword.len();
905 }
906
907 None
908 }
909
910 fn parse_where_expression(input: &str) -> ParseResult<Expr> {
912 Self::parse_simple_expression(input)
915 }
916
917 fn parse_simple_expression(input: &str) -> ParseResult<Expr> {
919 let input = input.trim();
920
921 if input.is_empty() {
922 return Err(ParseError::InvalidPattern("empty expression".to_string()));
923 }
924
925 if let Some(and_pos) = Self::find_top_level_keyword(input, " AND ") {
927 let left = Self::parse_simple_expression(&input[..and_pos])?;
928 let right = Self::parse_simple_expression(&input[and_pos + 5..])?;
929 return Ok(Expr::BinaryOp {
930 left: Box::new(left),
931 op: crate::ast::BinaryOp::And,
932 right: Box::new(right),
933 });
934 }
935
936 if let Some(or_pos) = Self::find_top_level_keyword(input, " OR ") {
937 let left = Self::parse_simple_expression(&input[..or_pos])?;
938 let right = Self::parse_simple_expression(&input[or_pos + 4..])?;
939 return Ok(Expr::BinaryOp {
940 left: Box::new(left),
941 op: crate::ast::BinaryOp::Or,
942 right: Box::new(right),
943 });
944 }
945
946 let comparisons = [
948 ("<>", crate::ast::BinaryOp::NotEq),
949 ("!=", crate::ast::BinaryOp::NotEq),
950 ("<=", crate::ast::BinaryOp::LtEq),
951 (">=", crate::ast::BinaryOp::GtEq),
952 ("<", crate::ast::BinaryOp::Lt),
953 (">", crate::ast::BinaryOp::Gt),
954 ("=", crate::ast::BinaryOp::Eq),
955 ];
956
957 for (op_str, op) in &comparisons {
958 if let Some(pos) = Self::find_top_level_operator(input, op_str) {
959 let left = Self::parse_simple_expression(&input[..pos])?;
960 let right = Self::parse_simple_expression(&input[pos + op_str.len()..])?;
961 return Ok(Expr::BinaryOp {
962 left: Box::new(left),
963 op: *op,
964 right: Box::new(right),
965 });
966 }
967 }
968
969 Ok(Self::parse_property_value(input))
971 }
972
973 fn find_top_level_keyword(input: &str, keyword: &str) -> Option<usize> {
975 let upper = input.to_uppercase();
976 let keyword_upper = keyword.to_uppercase();
977 let mut depth: i32 = 0;
978 let mut i = 0;
979 let bytes = input.as_bytes();
980
981 while i < input.len() {
982 if bytes[i] == b'(' {
983 depth += 1;
984 } else if bytes[i] == b')' {
985 depth = depth.saturating_sub(1);
986 } else if depth == 0 && upper[i..].starts_with(&keyword_upper) {
987 return Some(i);
988 }
989 i += 1;
990 }
991 None
992 }
993
994 fn find_top_level_operator(input: &str, op: &str) -> Option<usize> {
996 let mut depth: i32 = 0;
997 let mut in_string = false;
998 let mut string_char = '"';
999 let bytes = input.as_bytes();
1000 let op_bytes = op.as_bytes();
1001
1002 if op.len() > input.len() {
1003 return None;
1004 }
1005
1006 let mut i = 0;
1007 while i < bytes.len() {
1008 let c = bytes[i];
1009
1010 if in_string {
1011 if c == string_char as u8 {
1012 in_string = false;
1013 }
1014 i += 1;
1015 continue;
1016 }
1017
1018 match c {
1019 b'\'' | b'"' => {
1020 in_string = true;
1021 string_char = c as char;
1022 }
1023 b'(' => depth += 1,
1024 b')' => depth = depth.saturating_sub(1),
1025 _ if depth == 0 && i + op.len() <= bytes.len() => {
1026 if &bytes[i..i + op.len()] == op_bytes {
1027 return Some(i);
1028 }
1029 }
1030 _ => {}
1031 }
1032 i += 1;
1033 }
1034 None
1035 }
1036
1037 fn parse_return_items(input: &str) -> ParseResult<Vec<ReturnItem>> {
1039 let input = input.trim();
1040
1041 if input.is_empty() {
1042 return Err(ParseError::InvalidPattern("empty RETURN clause".to_string()));
1043 }
1044
1045 if input == "*" {
1047 return Ok(vec![ReturnItem::Wildcard]);
1048 }
1049
1050 let mut items = Vec::new();
1052 let mut current = String::new();
1053 let mut depth: i32 = 0;
1054
1055 for c in input.chars() {
1056 match c {
1057 '(' => {
1058 depth += 1;
1059 current.push(c);
1060 }
1061 ')' => {
1062 depth = depth.saturating_sub(1);
1063 current.push(c);
1064 }
1065 ',' if depth == 0 => {
1066 if !current.trim().is_empty() {
1067 items.push(Self::parse_return_item(current.trim())?);
1068 }
1069 current.clear();
1070 }
1071 _ => current.push(c),
1072 }
1073 }
1074
1075 if !current.trim().is_empty() {
1076 items.push(Self::parse_return_item(current.trim())?);
1077 }
1078
1079 if items.is_empty() {
1080 return Err(ParseError::InvalidPattern("empty RETURN clause".to_string()));
1081 }
1082
1083 Ok(items)
1084 }
1085
1086 fn parse_return_item(input: &str) -> ParseResult<ReturnItem> {
1088 let input = input.trim();
1089
1090 if input == "*" {
1091 return Ok(ReturnItem::Wildcard);
1092 }
1093
1094 let upper = input.to_uppercase();
1096 if let Some(as_pos) = Self::find_top_level_keyword(&upper, " AS ") {
1097 let expr_str = &input[..as_pos];
1098 let alias_str = &input[as_pos + 4..]; let expr = Self::parse_simple_expression(expr_str.trim())?;
1100 let alias = Identifier::new(alias_str.trim());
1101 return Ok(ReturnItem::Expr { expr, alias: Some(alias) });
1102 }
1103
1104 let expr = Self::parse_simple_expression(input)?;
1106 Ok(ReturnItem::Expr { expr, alias: None })
1107 }
1108
1109 fn parse_order_by(input: &str) -> ParseResult<Vec<OrderByExpr>> {
1111 let input = input.trim();
1112
1113 if input.is_empty() {
1114 return Ok(vec![]);
1115 }
1116
1117 let mut orders = Vec::new();
1118 let mut current = String::new();
1119 let mut depth: i32 = 0;
1120
1121 for c in input.chars() {
1122 match c {
1123 '(' => {
1124 depth += 1;
1125 current.push(c);
1126 }
1127 ')' => {
1128 depth = depth.saturating_sub(1);
1129 current.push(c);
1130 }
1131 ',' if depth == 0 => {
1132 if !current.trim().is_empty() {
1133 orders.push(Self::parse_order_item(current.trim())?);
1134 }
1135 current.clear();
1136 }
1137 _ => current.push(c),
1138 }
1139 }
1140
1141 if !current.trim().is_empty() {
1142 orders.push(Self::parse_order_item(current.trim())?);
1143 }
1144
1145 Ok(orders)
1146 }
1147
1148 fn parse_order_item(input: &str) -> ParseResult<OrderByExpr> {
1150 let input = input.trim();
1151 let upper = input.to_uppercase();
1152
1153 let (expr_str, asc) = if upper.ends_with(" DESC") {
1155 (&input[..input.len() - 5], false)
1156 } else if upper.ends_with(" ASC") {
1157 (&input[..input.len() - 4], true)
1158 } else {
1159 (input, true) };
1161
1162 let expr = Self::parse_simple_expression(expr_str.trim())?;
1163
1164 Ok(OrderByExpr { expr: Box::new(expr), asc, nulls_first: None })
1165 }
1166
1167 fn parse_limit_expr(input: &str) -> ParseResult<Expr> {
1169 let input = input.trim();
1170
1171 if let Ok(n) = input.parse::<i64>() {
1173 return Ok(Expr::integer(n));
1174 }
1175
1176 if let Some(rest) = input.strip_prefix('$') {
1178 if let Ok(n) = rest.parse::<u32>() {
1179 return Ok(Expr::Parameter(ParameterRef::Positional(n)));
1180 }
1181 return Ok(Expr::Parameter(ParameterRef::Named(rest.to_string())));
1182 }
1183
1184 Err(ParseError::InvalidPattern(format!("invalid LIMIT/SKIP value: {input}")))
1185 }
1186
1187 pub fn parse_single(input: &str) -> ParseResult<Statement> {
1193 let mut stmts = Self::parse(input)?;
1194 if stmts.len() != 1 {
1195 return Err(ParseError::SqlSyntax(format!(
1196 "expected 1 statement, found {}",
1197 stmts.len()
1198 )));
1199 }
1200 Ok(stmts.remove(0))
1201 }
1202
1203 fn extract_match_clauses(
1214 input: &str,
1215 ) -> ParseResult<(String, Vec<GraphPattern>, Vec<Vec<GraphPattern>>)> {
1216 let mut result = String::with_capacity(input.len());
1217 let mut match_patterns: Vec<GraphPattern> = Vec::new();
1218 let mut optional_patterns: Vec<Vec<GraphPattern>> = Vec::new();
1219 let mut remaining = input;
1220
1221 loop {
1222 let optional_pos = Self::find_optional_match_keyword(remaining);
1224 let match_pos = Self::find_match_keyword(remaining);
1225
1226 match (optional_pos, match_pos) {
1228 (Some(opt_pos), Some(m_pos)) if opt_pos < m_pos => {
1229 result.push_str(&remaining[..opt_pos]);
1232
1233 let after_optional = &remaining[opt_pos + 8..]; let after_optional_trimmed = after_optional.trim_start();
1236 let whitespace_len = after_optional.len() - after_optional_trimmed.len();
1237
1238 let after_match = &after_optional[whitespace_len + 5..];
1240 let end_pos = Self::find_match_end(after_match);
1241
1242 let pattern_str = after_match[..end_pos].trim();
1243 let pattern = Self::parse_graph_pattern(pattern_str)?;
1244
1245 if let Some(last_optionals) = optional_patterns.last_mut() {
1247 last_optionals.push(pattern);
1248 } else {
1249 optional_patterns.push(vec![pattern]);
1251 }
1252
1253 remaining = &after_match[end_pos..];
1254 }
1255 (Some(opt_pos), None) => {
1256 result.push_str(&remaining[..opt_pos]);
1258
1259 let after_optional = &remaining[opt_pos + 8..];
1260 let after_optional_trimmed = after_optional.trim_start();
1261 let whitespace_len = after_optional.len() - after_optional_trimmed.len();
1262
1263 let after_match = &after_optional[whitespace_len + 5..];
1264 let end_pos = Self::find_match_end(after_match);
1265
1266 let pattern_str = after_match[..end_pos].trim();
1267 let pattern = Self::parse_graph_pattern(pattern_str)?;
1268
1269 if let Some(last_optionals) = optional_patterns.last_mut() {
1271 last_optionals.push(pattern);
1272 } else {
1273 optional_patterns.push(vec![pattern]);
1275 }
1276
1277 remaining = &after_match[end_pos..];
1278 }
1279 (_, Some(m_pos)) => {
1280 result.push_str(&remaining[..m_pos]);
1282
1283 let after_match = &remaining[m_pos + 5..]; let end_pos = Self::find_match_end(after_match);
1285
1286 let pattern_str = after_match[..end_pos].trim();
1287 let pattern = Self::parse_graph_pattern(pattern_str)?;
1288
1289 match_patterns.push(pattern);
1291 optional_patterns.push(Vec::new());
1293
1294 remaining = &after_match[end_pos..];
1295 }
1296 (None, None) => {
1297 break;
1299 }
1300 }
1301 }
1302
1303 result.push_str(remaining);
1304
1305 Ok((result, match_patterns, optional_patterns))
1306 }
1307
1308 fn find_optional_match_keyword(input: &str) -> Option<usize> {
1310 let input_upper = input.to_uppercase();
1311 let mut search_from = 0;
1312
1313 while let Some(pos) = input_upper[search_from..].find("OPTIONAL") {
1314 let absolute_pos = search_from + pos;
1315
1316 let before_ok =
1318 absolute_pos == 0 || !input.as_bytes()[absolute_pos - 1].is_ascii_alphanumeric();
1319 let after_ok = absolute_pos + 8 >= input.len()
1320 || !input.as_bytes()[absolute_pos + 8].is_ascii_alphanumeric();
1321
1322 if before_ok && after_ok {
1323 let after_optional = &input_upper[absolute_pos + 8..];
1325 let after_optional_trimmed = after_optional.trim_start();
1326
1327 if after_optional_trimmed.starts_with("MATCH") {
1328 let match_start =
1330 absolute_pos + 8 + (after_optional.len() - after_optional_trimmed.len());
1331 let match_end = match_start + 5;
1332 let match_after_ok = match_end >= input.len()
1333 || !input.as_bytes()[match_end].is_ascii_alphanumeric();
1334
1335 if match_after_ok {
1336 return Some(absolute_pos);
1337 }
1338 }
1339 }
1340
1341 search_from = absolute_pos + 8;
1342 }
1343
1344 None
1345 }
1346
1347 fn find_match_keyword(input: &str) -> Option<usize> {
1351 let input_upper = input.to_uppercase();
1352 let mut search_from = 0;
1353
1354 while let Some(pos) = input_upper[search_from..].find("MATCH") {
1355 let absolute_pos = search_from + pos;
1356
1357 let before_ok =
1359 absolute_pos == 0 || !input.as_bytes()[absolute_pos - 1].is_ascii_alphanumeric();
1360 let after_ok = absolute_pos + 5 >= input.len()
1361 || !input.as_bytes()[absolute_pos + 5].is_ascii_alphanumeric();
1362
1363 if before_ok && after_ok {
1364 let is_optional_match = Self::is_preceded_by_optional(&input_upper, absolute_pos);
1366
1367 if !is_optional_match {
1368 return Some(absolute_pos);
1369 }
1370 }
1371
1372 search_from = absolute_pos + 5;
1373 }
1374
1375 None
1376 }
1377
1378 fn is_preceded_by_optional(input_upper: &str, match_pos: usize) -> bool {
1380 if match_pos < 8 {
1381 return false;
1383 }
1384
1385 let before_match = &input_upper[..match_pos];
1387 let trimmed = before_match.trim_end();
1388
1389 if trimmed.len() >= 8 && trimmed.ends_with("OPTIONAL") {
1391 let optional_start = trimmed.len() - 8;
1393 if optional_start == 0 {
1394 return true;
1395 }
1396 let byte_before = trimmed.as_bytes()[optional_start - 1];
1397 return !byte_before.is_ascii_alphanumeric();
1398 }
1399
1400 false
1401 }
1402
1403 fn find_match_end(input: &str) -> usize {
1408 let input_upper = input.to_uppercase();
1409
1410 let keywords = [
1411 "WHERE",
1412 "ORDER",
1413 "GROUP",
1414 "HAVING",
1415 "LIMIT",
1416 "OFFSET",
1417 "UNION",
1418 "INTERSECT",
1419 "EXCEPT",
1420 "OPTIONAL", "MATCH", ];
1423
1424 let mut min_pos = input.len();
1425
1426 for keyword in &keywords {
1427 if let Some(pos) = input_upper.find(keyword) {
1428 let before_ok = pos == 0 || !input.as_bytes()[pos - 1].is_ascii_alphanumeric();
1430 let after_ok = pos + keyword.len() >= input.len()
1431 || !input.as_bytes()[pos + keyword.len()].is_ascii_alphanumeric();
1432 if before_ok && after_ok && pos < min_pos {
1433 min_pos = pos;
1434 }
1435 }
1436 }
1437
1438 if let Some(pos) = input.find(';') {
1440 if pos < min_pos {
1441 min_pos = pos;
1442 }
1443 }
1444
1445 min_pos
1446 }
1447
1448 fn preprocess_vector_ops(input: &str) -> String {
1452 if !input.contains("<->")
1454 && !input.contains("<=>")
1455 && !input.contains("<#>")
1456 && !input.contains("<##>")
1457 {
1458 return input.to_string();
1459 }
1460
1461 let mut result = input.to_string();
1462 result = Self::replace_vector_op(&result, "<->", "__VEC_EUCLIDEAN__");
1463 result = Self::replace_vector_op(&result, "<=>", "__VEC_COSINE__");
1464 result = Self::replace_vector_op(&result, "<##>", "__VEC_MAXSIM__");
1466 result = Self::replace_vector_op(&result, "<#>", "__VEC_INNER__");
1467 result
1468 }
1469
1470 fn replace_vector_op(input: &str, op: &str, func_name: &str) -> String {
1474 let chars: Vec<char> = input.chars().collect();
1475 let op_chars: Vec<char> = op.chars().collect();
1476
1477 let Some(op_pos) = Self::find_operator(&chars, &op_chars) else {
1479 return input.to_string();
1480 };
1481
1482 let left_end = op_pos;
1484 let left_start = Self::find_expr_start(&chars, left_end);
1485
1486 let right_start = op_pos + op_chars.len();
1488 let right_end = Self::find_expr_end(&chars, right_start);
1489
1490 let mut result = String::with_capacity(input.len() + 64);
1492
1493 result.extend(&chars[..left_start]);
1495
1496 result.push_str(func_name);
1498 result.push('(');
1499 let left_expr: String = chars[left_start..left_end].iter().collect();
1500 result.push_str(left_expr.trim());
1501 result.push_str(", ");
1502 let right_expr: String = chars[right_start..right_end].iter().collect();
1503 result.push_str(right_expr.trim());
1504 result.push(')');
1505
1506 result.extend(&chars[right_end..]);
1508
1509 Self::replace_vector_op(&result, op, func_name)
1511 }
1512
1513 fn find_operator(chars: &[char], op_chars: &[char]) -> Option<usize> {
1515 (0..chars.len()).find(|&i| {
1516 i + op_chars.len() <= chars.len() && chars[i..i + op_chars.len()] == op_chars[..]
1517 })
1518 }
1519
1520 fn find_expr_start(chars: &[char], end: usize) -> usize {
1522 let mut pos = end;
1523 let mut paren_depth = 0;
1524
1525 while pos > 0 && chars[pos - 1].is_whitespace() {
1527 pos -= 1;
1528 }
1529
1530 while pos > 0 {
1532 let c = chars[pos - 1];
1533 match c {
1534 ')' => {
1535 paren_depth += 1;
1536 pos -= 1;
1537 }
1538 '(' => {
1539 if paren_depth > 0 {
1540 paren_depth -= 1;
1541 pos -= 1;
1542 } else {
1543 break;
1544 }
1545 }
1546 ']' => {
1547 let mut bracket_depth = 1;
1549 pos -= 1;
1550 while pos > 0 && bracket_depth > 0 {
1551 match chars[pos - 1] {
1552 ']' => bracket_depth += 1,
1553 '[' => bracket_depth -= 1,
1554 _ => {}
1555 }
1556 pos -= 1;
1557 }
1558 }
1559 _ if c.is_alphanumeric() || c == '_' || c == '.' || c == '$' => {
1560 pos -= 1;
1561 }
1562 _ if paren_depth > 0 => {
1563 pos -= 1;
1564 }
1565 ',' | ';' | '=' | '>' | '<' | '+' | '-' | '*' | '/' => {
1566 break;
1567 }
1568 _ if c.is_whitespace() => {
1569 break;
1572 }
1573 _ => {
1574 break;
1575 }
1576 }
1577 }
1578
1579 pos
1580 }
1581
1582 fn find_expr_end(chars: &[char], start: usize) -> usize {
1584 let mut pos = start;
1585 let mut paren_depth = 0;
1586
1587 while pos < chars.len() && chars[pos].is_whitespace() {
1589 pos += 1;
1590 }
1591
1592 while pos < chars.len() {
1594 let c = chars[pos];
1595 match c {
1596 '(' => {
1597 paren_depth += 1;
1598 pos += 1;
1599 }
1600 ')' => {
1601 if paren_depth > 0 {
1602 paren_depth -= 1;
1603 pos += 1;
1604 } else {
1605 break;
1606 }
1607 }
1608 '[' => {
1609 let mut bracket_depth = 1;
1611 pos += 1;
1612 while pos < chars.len() && bracket_depth > 0 {
1613 match chars[pos] {
1614 '[' => bracket_depth += 1,
1615 ']' => bracket_depth -= 1,
1616 _ => {}
1617 }
1618 pos += 1;
1619 }
1620 }
1621 _ if c.is_alphanumeric() || c == '_' || c == '.' || c == '$' => {
1622 pos += 1;
1623 }
1624 _ if paren_depth > 0 => {
1625 pos += 1;
1626 }
1627 _ if c.is_whitespace() => {
1628 break;
1630 }
1631 ',' | ';' | '<' | '>' | '=' | '+' | '-' | '*' | '/' => {
1632 break;
1633 }
1634 _ => {
1635 break;
1636 }
1637 }
1638 }
1639
1640 pos
1641 }
1642
1643 fn restore_vector_ops(stmt: &mut Statement) {
1645 match stmt {
1646 Statement::Select(select) => Self::restore_vector_ops_in_select(select),
1647 Statement::Update(update) => {
1648 if let Some(ref mut expr) = update.where_clause {
1649 Self::restore_vector_ops_in_expr(expr);
1650 }
1651 }
1652 Statement::Delete(delete) => {
1653 if let Some(ref mut expr) = delete.where_clause {
1654 Self::restore_vector_ops_in_expr(expr);
1655 }
1656 }
1657 Statement::Explain(inner) => Self::restore_vector_ops(inner),
1658 _ => {}
1659 }
1660 }
1661
1662 fn restore_vector_ops_in_select(select: &mut SelectStatement) {
1664 for item in &mut select.projection {
1666 if let crate::ast::SelectItem::Expr { expr, .. } = item {
1667 Self::restore_vector_ops_in_expr(expr);
1668 }
1669 }
1670
1671 if let Some(ref mut expr) = select.where_clause {
1673 Self::restore_vector_ops_in_expr(expr);
1674 }
1675
1676 for order in &mut select.order_by {
1678 Self::restore_vector_ops_in_expr(&mut order.expr);
1679 }
1680
1681 if let Some(ref mut expr) = select.having {
1683 Self::restore_vector_ops_in_expr(expr);
1684 }
1685 }
1686
1687 fn restore_vector_ops_in_expr(expr: &mut Expr) {
1691 match expr {
1693 Expr::BinaryOp { left, right, .. } => {
1694 Self::restore_vector_ops_in_expr(left);
1695 Self::restore_vector_ops_in_expr(right);
1696 }
1697 Expr::UnaryOp { operand, .. } => {
1698 Self::restore_vector_ops_in_expr(operand);
1699 }
1700 Expr::Function(func) => {
1701 for arg in &mut func.args {
1703 Self::restore_vector_ops_in_expr(arg);
1704 }
1705 }
1706 Expr::Case(case) => {
1707 if let Some(ref mut operand) = case.operand {
1708 Self::restore_vector_ops_in_expr(operand);
1709 }
1710 for (cond, result) in &mut case.when_clauses {
1711 Self::restore_vector_ops_in_expr(cond);
1712 Self::restore_vector_ops_in_expr(result);
1713 }
1714 if let Some(ref mut else_result) = case.else_result {
1715 Self::restore_vector_ops_in_expr(else_result);
1716 }
1717 }
1718 Expr::HybridSearch { components, .. } => {
1719 for comp in components {
1720 Self::restore_vector_ops_in_expr(&mut comp.distance_expr);
1721 }
1722 }
1723 _ => {}
1724 }
1725
1726 Self::convert_vector_function(expr);
1728 }
1729
1730 fn convert_vector_function(expr: &mut Expr) {
1739 let replacement = if let Expr::Function(func) = expr {
1740 let func_name = func.name.name().map(|id| id.name.as_str()).unwrap_or("");
1741
1742 if func_name.eq_ignore_ascii_case("HYBRID") || func_name.eq_ignore_ascii_case("RRF") {
1744 Self::parse_hybrid_function(func, func_name.eq_ignore_ascii_case("RRF"))
1745 } else {
1746 let op = Self::parse_distance_function_name(func_name);
1748
1749 if let Some(op) = op {
1750 if func.args.len() == 2 {
1751 let mut args = std::mem::take(&mut func.args);
1752 let right = args.pop().expect("checked len");
1753 let left = args.pop().expect("checked len");
1754 Some(Expr::BinaryOp { left: Box::new(left), op, right: Box::new(right) })
1755 } else {
1756 None
1757 }
1758 } else {
1759 None
1760 }
1761 }
1762 } else {
1763 None
1764 };
1765
1766 if let Some(new_expr) = replacement {
1767 *expr = new_expr;
1768 }
1769 }
1770
1771 fn parse_distance_function_name(func_name: &str) -> Option<BinaryOp> {
1775 if func_name == "__VEC_EUCLIDEAN__" {
1777 return Some(BinaryOp::EuclideanDistance);
1778 }
1779 if func_name == "__VEC_COSINE__" {
1780 return Some(BinaryOp::CosineDistance);
1781 }
1782 if func_name == "__VEC_INNER__" {
1783 return Some(BinaryOp::InnerProduct);
1784 }
1785 if func_name == "__VEC_MAXSIM__" {
1786 return Some(BinaryOp::MaxSim);
1787 }
1788
1789 let upper = func_name.to_uppercase();
1791 match upper.as_str() {
1792 "COSINE_SIMILARITY" | "COSINE_DISTANCE" | "COS_DISTANCE" | "COS_SIM" => {
1794 Some(BinaryOp::CosineDistance)
1795 }
1796 "EUCLIDEAN_DISTANCE" | "L2_DISTANCE" | "EUCLIDEAN" | "L2" => {
1798 Some(BinaryOp::EuclideanDistance)
1799 }
1800 "INNER_PRODUCT" | "DOT_PRODUCT" | "DOT" => Some(BinaryOp::InnerProduct),
1802 "MAXSIM" | "MAX_SIM" => Some(BinaryOp::MaxSim),
1804 _ => None,
1805 }
1806 }
1807
1808 fn parse_hybrid_function(func: &mut crate::ast::FunctionCall, is_rrf: bool) -> Option<Expr> {
1813 use crate::ast::{HybridCombinationMethod, HybridSearchComponent};
1814
1815 if is_rrf {
1816 if func.args.is_empty() {
1818 return None;
1819 }
1820
1821 let components: Vec<HybridSearchComponent> = std::mem::take(&mut func.args)
1822 .into_iter()
1823 .map(|arg| HybridSearchComponent::new(arg, 1.0))
1824 .collect();
1825
1826 Some(Expr::HybridSearch { components, method: HybridCombinationMethod::RRF { k: 60 } })
1827 } else {
1828 if func.args.len() < 4 || func.args.len() % 2 != 0 {
1831 return None;
1832 }
1833
1834 let mut components = Vec::new();
1835 let args = std::mem::take(&mut func.args);
1836
1837 let mut iter = args.into_iter();
1838 while let Some(distance_expr) = iter.next() {
1839 let weight_expr = iter.next()?;
1840
1841 let weight = Self::extract_weight(&weight_expr)?;
1843
1844 components.push(HybridSearchComponent::new(distance_expr, weight));
1845 }
1846
1847 Some(Expr::HybridSearch { components, method: HybridCombinationMethod::WeightedSum })
1848 }
1849 }
1850
1851 fn extract_weight(expr: &Expr) -> Option<f64> {
1853 match expr {
1854 Expr::Literal(crate::ast::Literal::Float(f)) => Some(*f),
1855 Expr::Literal(crate::ast::Literal::Integer(i)) => Some(*i as f64),
1856 _ => None,
1857 }
1858 }
1859
1860 fn add_match_clause(stmt: &mut Statement, pattern: GraphPattern) {
1862 match stmt {
1863 Statement::Select(select) => {
1864 select.match_clause = Some(pattern);
1865 }
1866 Statement::Update(update) => {
1867 update.match_clause = Some(pattern);
1868 }
1869 Statement::Delete(delete) => {
1870 delete.match_clause = Some(pattern);
1871 }
1872 _ => {}
1873 }
1874 }
1875
1876 fn add_optional_match_clause(stmt: &mut Statement, pattern: GraphPattern) {
1880 if let Statement::Select(select) = stmt {
1881 select.optional_match_clauses.push(pattern);
1882 }
1883 }
1885
1886 fn parse_graph_pattern(input: &str) -> ParseResult<GraphPattern> {
1888 let input = input.trim();
1889 if input.is_empty() {
1890 return Err(ParseError::InvalidPattern("empty pattern".to_string()));
1891 }
1892
1893 let mut paths = Vec::with_capacity(2);
1895 let mut current = input;
1896
1897 while !current.is_empty() {
1898 let (path, remaining) = Self::parse_path_pattern(current)?;
1899 paths.push(path);
1900
1901 current = remaining.trim();
1902 if current.starts_with(',') {
1903 current = current[1..].trim();
1904 }
1905 }
1906
1907 if paths.is_empty() {
1908 return Err(ParseError::InvalidPattern("no paths in pattern".to_string()));
1909 }
1910
1911 Ok(GraphPattern::new(paths))
1912 }
1913
1914 fn parse_path_pattern(input: &str) -> ParseResult<(PathPattern, &str)> {
1916 let (start, remaining) = Self::parse_node_pattern(input)?;
1917 let mut path = PathPattern::node(start);
1918 let mut current = remaining;
1919
1920 loop {
1921 current = current.trim_start();
1922
1923 if current.starts_with('-') || current.starts_with('<') {
1925 let (edge, after_edge) = Self::parse_edge_pattern(current)?;
1926 let (node, after_node) = Self::parse_node_pattern(after_edge.trim_start())?;
1927 path = path.then(edge, node);
1928 current = after_node;
1929 } else {
1930 break;
1931 }
1932 }
1933
1934 Ok((path, current))
1935 }
1936
1937 fn parse_node_pattern(input: &str) -> ParseResult<(NodePattern, &str)> {
1939 let input = input.trim_start();
1940
1941 if !input.starts_with('(') {
1942 return Err(ParseError::InvalidPattern(format!(
1943 "expected '(' at start of node pattern, found: {}",
1944 input.chars().next().unwrap_or('?')
1945 )));
1946 }
1947
1948 let close_paren = Self::find_matching_paren(input, 0)
1949 .ok_or_else(|| ParseError::InvalidPattern("unclosed node pattern".to_string()))?;
1950
1951 let inner = &input[1..close_paren];
1952 let remaining = &input[close_paren + 1..];
1953
1954 let node = Self::parse_node_inner(inner)?;
1955 Ok((node, remaining))
1956 }
1957
1958 fn parse_node_inner(input: &str) -> ParseResult<NodePattern> {
1960 let input = input.trim();
1961
1962 if input.is_empty() {
1963 return Ok(NodePattern::anonymous());
1964 }
1965
1966 let mut variable = None;
1967 let mut labels = Vec::with_capacity(2);
1969 let mut properties = Vec::with_capacity(2);
1971
1972 let mut current = input;
1973
1974 if !current.starts_with(':') && !current.starts_with('{') {
1976 let end = current.find([':', '{', ' ']).unwrap_or(current.len());
1977 let var_name = ¤t[..end];
1978 if !var_name.is_empty() {
1979 variable = Some(Identifier::new(var_name));
1980 }
1981 current = ¤t[end..];
1982 }
1983
1984 while current.starts_with(':') {
1986 current = ¤t[1..]; let end = current.find([':', '{', ' ', ')']).unwrap_or(current.len());
1988 let label = ¤t[..end];
1989 if !label.is_empty() {
1990 labels.push(Identifier::new(label));
1991 }
1992 current = current[end..].trim_start();
1993 }
1994
1995 if current.starts_with('{') {
1997 let close_brace = current
1998 .find('}')
1999 .ok_or_else(|| ParseError::InvalidPattern("unclosed properties".to_string()))?;
2000 let props_str = ¤t[1..close_brace];
2001 properties = Self::parse_properties(props_str)?;
2002 }
2003
2004 Ok(NodePattern { variable, labels, properties })
2005 }
2006
2007 fn parse_edge_pattern(input: &str) -> ParseResult<(EdgePattern, &str)> {
2009 let input = input.trim_start();
2010
2011 let (direction, bracket_start) = if input.starts_with("<-[") {
2013 (EdgeDirection::Left, 2)
2014 } else if input.starts_with("-[") {
2015 (EdgeDirection::Right, 1) } else {
2018 return Err(ParseError::InvalidPattern(format!(
2019 "expected edge pattern, found: {}",
2020 &input[..input.len().min(10)]
2021 )));
2022 };
2023
2024 let bracket_end = input[bracket_start + 1..]
2026 .find(']')
2027 .map(|p| p + bracket_start + 1)
2028 .ok_or_else(|| ParseError::InvalidPattern("unclosed edge pattern".to_string()))?;
2029
2030 let inner = &input[bracket_start + 1..bracket_end];
2031 let after_bracket = &input[bracket_end + 1..];
2032
2033 let (actual_direction, remaining) = if let Some(rest) = after_bracket.strip_prefix("->") {
2035 (EdgeDirection::Right, rest)
2036 } else if let Some(rest) = after_bracket.strip_prefix('-') {
2037 if direction == EdgeDirection::Left {
2038 (EdgeDirection::Left, rest)
2039 } else {
2040 (EdgeDirection::Undirected, rest)
2041 }
2042 } else {
2043 return Err(ParseError::InvalidPattern("invalid edge ending".to_string()));
2044 };
2045
2046 let edge = Self::parse_edge_inner(inner, actual_direction)?;
2047 Ok((edge, remaining))
2048 }
2049
2050 fn parse_edge_inner(input: &str, direction: EdgeDirection) -> ParseResult<EdgePattern> {
2052 let input = input.trim();
2053
2054 let mut variable = None;
2055 let mut edge_types = Vec::with_capacity(2);
2057 let mut length = EdgeLength::Single;
2058 let mut properties = Vec::with_capacity(2);
2060
2061 if input.is_empty() {
2062 return Ok(EdgePattern { direction, variable, edge_types, properties, length });
2063 }
2064
2065 let mut current = input;
2066
2067 if !current.starts_with(':') && !current.starts_with('*') && !current.starts_with('{') {
2069 let end = current.find([':', '*', '{', ' ']).unwrap_or(current.len());
2070 let var_name = ¤t[..end];
2071 if !var_name.is_empty() {
2072 variable = Some(Identifier::new(var_name));
2073 }
2074 current = ¤t[end..];
2075 }
2076
2077 while current.starts_with(':') || current.starts_with('|') {
2079 current = ¤t[1..]; let end = current.find(['|', '*', '{', ' ', ']']).unwrap_or(current.len());
2081 let edge_type = ¤t[..end];
2082 if !edge_type.is_empty() {
2083 edge_types.push(Identifier::new(edge_type));
2084 }
2085 current = current[end..].trim_start();
2086 }
2087
2088 if current.starts_with('*') {
2090 current = ¤t[1..];
2091 length = Self::parse_edge_length(current)?;
2092
2093 let end = current.find(['{', ' ', ']']).unwrap_or(current.len());
2095 current = current[end..].trim_start();
2096 }
2097
2098 if current.starts_with('{') {
2100 let close_brace = current.find('}').ok_or_else(|| {
2101 ParseError::InvalidPattern("unclosed edge properties".to_string())
2102 })?;
2103 let props_str = ¤t[1..close_brace];
2104 properties = Self::parse_properties(props_str)?;
2105 }
2106
2107 Ok(EdgePattern { direction, variable, edge_types, properties, length })
2108 }
2109
2110 fn parse_edge_length(input: &str) -> ParseResult<EdgeLength> {
2112 let input = input.trim();
2113
2114 if input.is_empty()
2115 || input.starts_with('{')
2116 || input.starts_with(' ')
2117 || input.starts_with(']')
2118 {
2119 return Ok(EdgeLength::Any);
2120 }
2121
2122 if let Some(range_pos) = input.find("..") {
2124 let before = &input[..range_pos];
2125 let after_start = range_pos + 2;
2126 let after_end = input[after_start..]
2127 .find(|c: char| !c.is_ascii_digit())
2128 .map_or(input.len(), |p| after_start + p);
2129 let after = &input[after_start..after_end];
2130
2131 let min = if before.is_empty() {
2132 None
2133 } else {
2134 Some(before.parse::<u32>().map_err(|_| {
2135 ParseError::InvalidPattern(format!("invalid min in range: {before}"))
2136 })?)
2137 };
2138
2139 let max = if after.is_empty() {
2140 None
2141 } else {
2142 Some(after.parse::<u32>().map_err(|_| {
2143 ParseError::InvalidPattern(format!("invalid max in range: {after}"))
2144 })?)
2145 };
2146
2147 return Ok(EdgeLength::Range { min, max });
2148 }
2149
2150 let num_end = input.find(|c: char| !c.is_ascii_digit()).unwrap_or(input.len());
2152 let num_str = &input[..num_end];
2153
2154 if !num_str.is_empty() {
2155 let n = num_str.parse::<u32>().map_err(|_| {
2156 ParseError::InvalidPattern(format!("invalid edge length: {num_str}"))
2157 })?;
2158 return Ok(EdgeLength::Exact(n));
2159 }
2160
2161 Ok(EdgeLength::Any)
2162 }
2163
2164 fn parse_properties(input: &str) -> ParseResult<Vec<PropertyCondition>> {
2166 let input = input.trim();
2167 if input.is_empty() {
2168 return Ok(Vec::new());
2169 }
2170
2171 let estimated_count = input.matches(',').count() + 1;
2173 let mut properties = Vec::with_capacity(estimated_count);
2174
2175 for pair in input.split(',') {
2176 let pair = pair.trim();
2177 if pair.is_empty() {
2178 continue;
2179 }
2180
2181 let colon_pos = pair
2182 .find(':')
2183 .ok_or_else(|| ParseError::InvalidPattern(format!("invalid property: {pair}")))?;
2184
2185 let name = pair[..colon_pos].trim();
2186 let value_str = pair[colon_pos + 1..].trim();
2187
2188 let value = Self::parse_property_value(value_str);
2189
2190 properties.push(PropertyCondition { name: Identifier::new(name), value });
2191 }
2192
2193 Ok(properties)
2194 }
2195
2196 fn parse_property_value(input: &str) -> Expr {
2198 let input = input.trim();
2199
2200 if (input.starts_with('\'') && input.ends_with('\''))
2202 || (input.starts_with('"') && input.ends_with('"'))
2203 {
2204 let s = &input[1..input.len() - 1];
2205 return Expr::string(s);
2206 }
2207
2208 if input.eq_ignore_ascii_case("true") {
2210 return Expr::boolean(true);
2211 }
2212 if input.eq_ignore_ascii_case("false") {
2213 return Expr::boolean(false);
2214 }
2215
2216 if input.eq_ignore_ascii_case("null") {
2218 return Expr::null();
2219 }
2220
2221 if let Some(name) = input.strip_prefix('$') {
2223 if let Ok(n) = name.parse::<u32>() {
2224 return Expr::Parameter(ParameterRef::Positional(n));
2225 }
2226 return Expr::Parameter(ParameterRef::Named(name.to_string()));
2227 }
2228
2229 if let Ok(i) = input.parse::<i64>() {
2231 return Expr::integer(i);
2232 }
2233
2234 if let Ok(f) = input.parse::<f64>() {
2236 return Expr::float(f);
2237 }
2238
2239 Expr::column(QualifiedName::simple(input))
2241 }
2242
2243 fn find_matching_paren(input: &str, open_pos: usize) -> Option<usize> {
2245 let bytes = input.as_bytes();
2246 let mut depth = 0;
2247 let mut in_string = false;
2248 let mut string_char = b'"';
2249
2250 for (i, &byte) in bytes.iter().enumerate().skip(open_pos) {
2251 if in_string {
2252 if byte == string_char && (i == 0 || bytes[i - 1] != b'\\') {
2253 in_string = false;
2254 }
2255 continue;
2256 }
2257
2258 match byte {
2259 b'\'' | b'"' => {
2260 in_string = true;
2261 string_char = byte;
2262 }
2263 b'(' => depth += 1,
2264 b')' => {
2265 depth -= 1;
2266 if depth == 0 {
2267 return Some(i);
2268 }
2269 }
2270 _ => {}
2271 }
2272 }
2273
2274 None
2275 }
2276}
2277
2278pub fn parse_shortest_path(input: &str) -> ParseResult<(ShortestPathPattern, &str)> {
2290 let input = input.trim();
2291 let input_upper = input.to_uppercase();
2292
2293 let (find_all, remaining) = if input_upper.starts_with("ALL SHORTEST PATHS") {
2295 (true, input[18..].trim_start())
2296 } else if input_upper.starts_with("ALL SHORTEST PATH") {
2297 (true, input[17..].trim_start())
2299 } else if input_upper.starts_with("SHORTEST PATHS") {
2300 (true, input[14..].trim_start())
2301 } else if input_upper.starts_with("SHORTEST PATH") {
2302 (false, input[13..].trim_start())
2303 } else {
2304 return Err(ParseError::InvalidPattern(
2305 "expected SHORTEST PATH or ALL SHORTEST PATHS".to_string(),
2306 ));
2307 };
2308
2309 let (path, remaining) = ExtendedParser::parse_path_pattern(remaining)?;
2311
2312 let remaining = remaining.trim();
2314 let remaining_upper = remaining.to_uppercase();
2315
2316 let (weight, remaining) = if remaining_upper.starts_with("WEIGHTED BY") {
2317 let after_weighted = remaining[11..].trim_start();
2318 let (weight_spec, rest) = parse_weight_spec(after_weighted)?;
2319 (Some(weight_spec), rest)
2320 } else {
2321 (None, remaining)
2322 };
2323
2324 let pattern = ShortestPathPattern { path, find_all, weight };
2325
2326 Ok((pattern, remaining))
2327}
2328
2329fn parse_weight_spec(input: &str) -> ParseResult<(WeightSpec, &str)> {
2336 let input = input.trim();
2337
2338 let num_end =
2340 input.find(|c: char| !c.is_ascii_digit() && c != '.' && c != '-').unwrap_or(input.len());
2341
2342 if num_end > 0 {
2343 let potential_num = &input[..num_end];
2344 if let Ok(value) = potential_num.parse::<f64>() {
2345 if potential_num.chars().next().is_some_and(|c| c.is_ascii_digit() || c == '-') {
2347 return Ok((WeightSpec::Constant(value), &input[num_end..]));
2348 }
2349 }
2350 }
2351
2352 let ident_end = input.find(|c: char| !c.is_alphanumeric() && c != '_').unwrap_or(input.len());
2354
2355 if ident_end == 0 {
2356 return Err(ParseError::InvalidPattern(
2357 "expected property name or number after WEIGHTED BY".to_string(),
2358 ));
2359 }
2360
2361 let name = input[..ident_end].to_string();
2362 let remaining = input[ident_end..].trim_start();
2363 let remaining_upper = remaining.to_uppercase();
2364
2365 let (default, remaining) = if remaining_upper.starts_with("DEFAULT") {
2367 let after_default = remaining[7..].trim_start();
2368
2369 let default_end = after_default
2371 .find(|c: char| !c.is_ascii_digit() && c != '.' && c != '-')
2372 .unwrap_or(after_default.len());
2373
2374 if default_end == 0 {
2375 return Err(ParseError::InvalidPattern("expected number after DEFAULT".to_string()));
2376 }
2377
2378 let default_str = &after_default[..default_end];
2379 let default_value = default_str.parse::<f64>().map_err(|_| {
2380 ParseError::InvalidPattern(format!("invalid default value: {default_str}"))
2381 })?;
2382
2383 (Some(default_value), &after_default[default_end..])
2384 } else {
2385 (None, remaining)
2386 };
2387
2388 Ok((WeightSpec::Property { name, default }, remaining))
2389}
2390
2391pub fn parse_vector_distance(left: Expr, metric: DistanceMetric, right: Expr) -> Expr {
2398 Expr::BinaryOp {
2399 left: Box::new(left),
2400 op: match metric {
2401 DistanceMetric::Cosine => BinaryOp::CosineDistance,
2402 DistanceMetric::InnerProduct => BinaryOp::InnerProduct,
2403 DistanceMetric::Euclidean | DistanceMetric::Manhattan | DistanceMetric::Hamming => {
2405 BinaryOp::EuclideanDistance
2406 }
2407 },
2408 right: Box::new(right),
2409 }
2410}
2411
2412#[cfg(test)]
2413mod tests {
2414 use super::*;
2415 use crate::ast::HybridCombinationMethod;
2416
2417 #[test]
2418 fn parse_simple_node_pattern() {
2419 let (node, remaining) = ExtendedParser::parse_node_pattern("(p)").unwrap();
2420 assert!(remaining.is_empty());
2421 assert_eq!(node.variable.as_ref().map(|i| i.name.as_str()), Some("p"));
2422 assert!(node.labels.is_empty());
2423 }
2424
2425 #[test]
2426 fn parse_node_with_label() {
2427 let (node, _) = ExtendedParser::parse_node_pattern("(p:Person)").unwrap();
2428 assert_eq!(node.variable.as_ref().map(|i| i.name.as_str()), Some("p"));
2429 assert_eq!(node.labels.len(), 1);
2430 assert_eq!(node.labels[0].name, "Person");
2431 }
2432
2433 #[test]
2434 fn parse_node_with_multiple_labels() {
2435 let (node, _) = ExtendedParser::parse_node_pattern("(p:Person:Employee)").unwrap();
2436 assert_eq!(node.labels.len(), 2);
2437 assert_eq!(node.labels[0].name, "Person");
2438 assert_eq!(node.labels[1].name, "Employee");
2439 }
2440
2441 #[test]
2442 fn parse_anonymous_node() {
2443 let (node, _) = ExtendedParser::parse_node_pattern("()").unwrap();
2444 assert!(node.variable.is_none());
2445 assert!(node.labels.is_empty());
2446 }
2447
2448 #[test]
2449 fn parse_directed_edge() {
2450 let (edge, remaining) = ExtendedParser::parse_edge_pattern("-[:FOLLOWS]->").unwrap();
2451 assert!(remaining.is_empty());
2452 assert_eq!(edge.direction, EdgeDirection::Right);
2453 assert_eq!(edge.edge_types.len(), 1);
2454 assert_eq!(edge.edge_types[0].name, "FOLLOWS");
2455 }
2456
2457 #[test]
2458 fn parse_left_edge() {
2459 let (edge, _) = ExtendedParser::parse_edge_pattern("<-[:CREATED_BY]-").unwrap();
2460 assert_eq!(edge.direction, EdgeDirection::Left);
2461 assert_eq!(edge.edge_types[0].name, "CREATED_BY");
2462 }
2463
2464 #[test]
2465 fn parse_undirected_edge() {
2466 let (edge, _) = ExtendedParser::parse_edge_pattern("-[:KNOWS]-").unwrap();
2467 assert_eq!(edge.direction, EdgeDirection::Undirected);
2468 }
2469
2470 #[test]
2471 fn parse_edge_with_variable() {
2472 let (edge, _) = ExtendedParser::parse_edge_pattern("-[r:FOLLOWS]->").unwrap();
2473 assert_eq!(edge.variable.as_ref().map(|i| i.name.as_str()), Some("r"));
2474 }
2475
2476 #[test]
2477 fn parse_edge_with_length() {
2478 let (edge, _) = ExtendedParser::parse_edge_pattern("-[:FOLLOWS*1..3]->").unwrap();
2479 assert_eq!(edge.length, EdgeLength::Range { min: Some(1), max: Some(3) });
2480 }
2481
2482 #[test]
2483 fn parse_edge_any_length() {
2484 let (edge, _) = ExtendedParser::parse_edge_pattern("-[:PATH*]->").unwrap();
2485 assert_eq!(edge.length, EdgeLength::Any);
2486 }
2487
2488 #[test]
2489 fn parse_simple_path() {
2490 let (path, _) = ExtendedParser::parse_path_pattern("(a)-[:FOLLOWS]->(b)").unwrap();
2491 assert_eq!(path.start.variable.as_ref().map(|i| i.name.as_str()), Some("a"));
2492 assert_eq!(path.steps.len(), 1);
2493 }
2494
2495 #[test]
2496 fn parse_long_path() {
2497 let (path, _) =
2498 ExtendedParser::parse_path_pattern("(a)-[:KNOWS]->(b)-[:LIKES]->(c)").unwrap();
2499 assert_eq!(path.steps.len(), 2);
2500 }
2501
2502 #[test]
2503 fn parse_graph_pattern() {
2504 let pattern = ExtendedParser::parse_graph_pattern("(u:User)-[:FOLLOWS]->(f:User)").unwrap();
2505 assert_eq!(pattern.paths.len(), 1);
2506 }
2507
2508 #[test]
2509 fn parse_multiple_paths() {
2510 let pattern =
2511 ExtendedParser::parse_graph_pattern("(a)-[:R1]->(b), (b)-[:R2]->(c)").unwrap();
2512 assert_eq!(pattern.paths.len(), 2);
2513 }
2514
2515 #[test]
2516 fn extract_match_clause() {
2517 let (sql, patterns, optional_patterns) = ExtendedParser::extract_match_clauses(
2518 "SELECT * FROM users MATCH (u)-[:FOLLOWS]->(f) WHERE u.id = 1",
2519 )
2520 .unwrap();
2521
2522 assert!(sql.contains("SELECT * FROM users"));
2523 assert!(sql.contains("WHERE u.id = 1"));
2524 assert!(!sql.to_uppercase().contains("MATCH"));
2525 assert_eq!(patterns.len(), 1);
2526 assert!(optional_patterns.is_empty() || optional_patterns.iter().all(|v| v.is_empty()));
2528 }
2529
2530 #[test]
2531 fn parse_extended_select() {
2532 let stmts =
2533 ExtendedParser::parse("SELECT * FROM users MATCH (u)-[:FOLLOWS]->(f) WHERE u.id = 1")
2534 .unwrap();
2535
2536 assert_eq!(stmts.len(), 1);
2537 if let Statement::Select(select) = &stmts[0] {
2538 assert!(select.match_clause.is_some());
2539 assert!(select.where_clause.is_some());
2540 } else {
2541 panic!("expected SELECT");
2542 }
2543 }
2544
2545 #[test]
2546 fn preprocess_vector_ops() {
2547 let result = ExtendedParser::preprocess_vector_ops("a <-> b");
2548 assert!(result.contains("__VEC_EUCLIDEAN__"));
2549
2550 let result = ExtendedParser::preprocess_vector_ops("a <=> b");
2551 assert!(result.contains("__VEC_COSINE__"));
2552
2553 let result = ExtendedParser::preprocess_vector_ops("a <#> b");
2554 assert!(result.contains("__VEC_INNER__"));
2555
2556 let result = ExtendedParser::preprocess_vector_ops("a <##> b");
2557 assert!(result.contains("__VEC_MAXSIM__"));
2558 }
2559
2560 #[test]
2561 fn preprocess_maxsim_before_inner() {
2562 let result = ExtendedParser::preprocess_vector_ops("a <##> b");
2564 assert!(result.contains("__VEC_MAXSIM__"));
2565 assert!(!result.contains("__VEC_INNER__"));
2566 }
2567
2568 #[test]
2569 fn parse_node_with_properties() {
2570 let (node, _) =
2571 ExtendedParser::parse_node_pattern("(p:Person {name: 'Alice', age: 30})").unwrap();
2572 assert_eq!(node.properties.len(), 2);
2573 assert_eq!(node.properties[0].name.name, "name");
2574 assert_eq!(node.properties[1].name.name, "age");
2575 }
2576
2577 #[test]
2578 fn parse_edge_length_exact() {
2579 let length = ExtendedParser::parse_edge_length("3").unwrap();
2580 assert_eq!(length, EdgeLength::Exact(3));
2581 }
2582
2583 #[test]
2584 fn parse_edge_length_range() {
2585 let length = ExtendedParser::parse_edge_length("1..5").unwrap();
2586 assert_eq!(length, EdgeLength::Range { min: Some(1), max: Some(5) });
2587 }
2588
2589 #[test]
2590 fn parse_edge_length_min_only() {
2591 let length = ExtendedParser::parse_edge_length("2..").unwrap();
2592 assert_eq!(length, EdgeLength::Range { min: Some(2), max: None });
2593 }
2594
2595 #[test]
2596 fn parse_edge_length_max_only() {
2597 let length = ExtendedParser::parse_edge_length("..5").unwrap();
2598 assert_eq!(length, EdgeLength::Range { min: None, max: Some(5) });
2599 }
2600
2601 #[test]
2602 fn parse_shortest_path_unweighted() {
2603 let (sp, remaining) = parse_shortest_path("SHORTEST PATH (a)-[*]->(b)").unwrap();
2604 assert!(!sp.find_all);
2605 assert!(sp.weight.is_none());
2606 assert!(remaining.is_empty());
2607 }
2608
2609 #[test]
2610 fn parse_shortest_path_weighted() {
2611 let (sp, _) = parse_shortest_path("SHORTEST PATH (a)-[*]->(b) WEIGHTED BY cost").unwrap();
2612 assert!(!sp.find_all);
2613 assert!(sp.weight.is_some());
2614 match sp.weight.unwrap() {
2615 WeightSpec::Property { name, default } => {
2616 assert_eq!(name, "cost");
2617 assert!(default.is_none());
2618 }
2619 _ => panic!("expected Property weight spec"),
2620 }
2621 }
2622
2623 #[test]
2624 fn parse_shortest_path_weighted_with_default() {
2625 let (sp, _) =
2626 parse_shortest_path("SHORTEST PATH (a)-[*]->(b) WEIGHTED BY distance DEFAULT 1.0")
2627 .unwrap();
2628 match sp.weight.unwrap() {
2629 WeightSpec::Property { name, default } => {
2630 assert_eq!(name, "distance");
2631 assert_eq!(default, Some(1.0));
2632 }
2633 _ => panic!("expected Property weight spec"),
2634 }
2635 }
2636
2637 #[test]
2638 fn parse_all_shortest_paths() {
2639 let (sp, _) = parse_shortest_path("ALL SHORTEST PATHS (a)-[*]->(b)").unwrap();
2640 assert!(sp.find_all);
2641 assert!(sp.weight.is_none());
2642 }
2643
2644 #[test]
2645 fn parse_shortest_path_constant_weight() {
2646 let (sp, _) = parse_shortest_path("SHORTEST PATH (a)-[*]->(b) WEIGHTED BY 2.5").unwrap();
2647 match sp.weight.unwrap() {
2648 WeightSpec::Constant(v) => assert_eq!(v, 2.5),
2649 _ => panic!("expected Constant weight spec"),
2650 }
2651 }
2652
2653 #[test]
2654 fn parse_shortest_path_with_edge_type() {
2655 let (sp, _) =
2656 parse_shortest_path("SHORTEST PATH (a)-[:ROAD*]->(b) WEIGHTED BY distance").unwrap();
2657 assert_eq!(sp.path.steps.len(), 1);
2658 let (edge, _) = &sp.path.steps[0];
2659 assert_eq!(edge.edge_types.len(), 1);
2660 assert_eq!(edge.edge_types[0].name, "ROAD");
2661 }
2662
2663 #[test]
2664 fn parse_hybrid_function_basic() {
2665 let stmts = ExtendedParser::parse(
2666 "SELECT * FROM docs ORDER BY HYBRID(dense <=> $1, 0.7, sparse <#> $2, 0.3) LIMIT 10",
2667 )
2668 .unwrap();
2669 assert_eq!(stmts.len(), 1);
2670 if let Statement::Select(select) = &stmts[0] {
2671 assert_eq!(select.order_by.len(), 1);
2672 let order_expr = &*select.order_by[0].expr;
2674 assert!(matches!(order_expr, Expr::HybridSearch { .. }));
2675 if let Expr::HybridSearch { components, method } = order_expr {
2676 assert_eq!(components.len(), 2);
2677 assert!((components[0].weight - 0.7).abs() < 0.001);
2678 assert!((components[1].weight - 0.3).abs() < 0.001);
2679 assert!(matches!(method, HybridCombinationMethod::WeightedSum));
2680 }
2681 } else {
2682 panic!("Expected SELECT statement");
2683 }
2684 }
2685
2686 #[test]
2687 fn parse_rrf_function() {
2688 let stmts = ExtendedParser::parse(
2689 "SELECT * FROM docs ORDER BY RRF(dense <=> $1, sparse <#> $2) LIMIT 10",
2690 )
2691 .unwrap();
2692 assert_eq!(stmts.len(), 1);
2693 if let Statement::Select(select) = &stmts[0] {
2694 let order_expr = &*select.order_by[0].expr;
2695 if let Expr::HybridSearch { components, method } = order_expr {
2696 assert_eq!(components.len(), 2);
2697 assert!((components[0].weight - 1.0).abs() < 0.001);
2699 assert!((components[1].weight - 1.0).abs() < 0.001);
2700 assert!(matches!(method, HybridCombinationMethod::RRF { k: 60 }));
2701 } else {
2702 panic!("Expected HybridSearch expression");
2703 }
2704 } else {
2705 panic!("Expected SELECT statement");
2706 }
2707 }
2708
2709 #[test]
2710 fn parse_hybrid_function_preserves_vector_ops() {
2711 let stmts = ExtendedParser::parse(
2712 "SELECT * FROM docs ORDER BY HYBRID(embedding <=> $q1, 0.5, sparse <#> $q2, 0.5)",
2713 )
2714 .unwrap();
2715 if let Statement::Select(select) = &stmts[0] {
2716 if let Expr::HybridSearch { components, .. } = &*select.order_by[0].expr {
2717 if let Expr::BinaryOp { op: BinaryOp::CosineDistance, .. } =
2719 components[0].distance_expr.as_ref()
2720 {
2721 } else {
2723 panic!("Expected CosineDistance operator for first component");
2724 }
2725 if let Expr::BinaryOp { op: BinaryOp::InnerProduct, .. } =
2727 components[1].distance_expr.as_ref()
2728 {
2729 } else {
2731 panic!("Expected InnerProduct operator for second component");
2732 }
2733 } else {
2734 panic!("Expected HybridSearch expression");
2735 }
2736 }
2737 }
2738
2739 #[test]
2742 fn parse_create_collection_basic() {
2743 let stmts = ExtendedParser::parse(
2744 "CREATE COLLECTION documents (dense VECTOR(768) USING hnsw WITH (distance = 'cosine'))",
2745 )
2746 .unwrap();
2747 assert_eq!(stmts.len(), 1);
2748 if let Statement::CreateCollection(create) = &stmts[0] {
2749 assert_eq!(create.name.name, "documents");
2750 assert!(!create.if_not_exists);
2751 assert_eq!(create.vectors.len(), 1);
2752 assert_eq!(create.vectors[0].name.name, "dense");
2753 assert!(matches!(
2754 create.vectors[0].vector_type,
2755 VectorTypeDef::Vector { dimension: 768 }
2756 ));
2757 assert_eq!(create.vectors[0].using, Some("hnsw".to_string()));
2758 assert_eq!(create.vectors[0].with_options.len(), 1);
2759 assert_eq!(
2760 create.vectors[0].with_options[0],
2761 ("distance".to_string(), "cosine".to_string())
2762 );
2763 } else {
2764 panic!("Expected CreateCollection statement");
2765 }
2766 }
2767
2768 #[test]
2769 fn parse_create_collection_if_not_exists() {
2770 let stmts =
2771 ExtendedParser::parse("CREATE IF NOT EXISTS COLLECTION docs (v VECTOR(128))").unwrap();
2772 assert_eq!(stmts.len(), 1);
2773 if let Statement::CreateCollection(create) = &stmts[0] {
2774 assert!(create.if_not_exists);
2775 assert_eq!(create.name.name, "docs");
2776 } else {
2777 panic!("Expected CreateCollection statement");
2778 }
2779 }
2780
2781 #[test]
2782 fn parse_create_collection_multiple_vectors() {
2783 let stmts = ExtendedParser::parse(
2784 "CREATE COLLECTION documents (
2785 dense VECTOR(768) USING hnsw WITH (distance = 'cosine'),
2786 sparse SPARSE_VECTOR USING inverted,
2787 colbert MULTI_VECTOR(128) USING hnsw WITH (aggregation = 'maxsim')
2788 )",
2789 )
2790 .unwrap();
2791 assert_eq!(stmts.len(), 1);
2792 if let Statement::CreateCollection(create) = &stmts[0] {
2793 assert_eq!(create.vectors.len(), 3);
2794
2795 assert_eq!(create.vectors[0].name.name, "dense");
2797 assert!(matches!(
2798 create.vectors[0].vector_type,
2799 VectorTypeDef::Vector { dimension: 768 }
2800 ));
2801
2802 assert_eq!(create.vectors[1].name.name, "sparse");
2804 assert!(matches!(
2805 create.vectors[1].vector_type,
2806 VectorTypeDef::SparseVector { max_dimension: None }
2807 ));
2808 assert_eq!(create.vectors[1].using, Some("inverted".to_string()));
2809
2810 assert_eq!(create.vectors[2].name.name, "colbert");
2812 assert!(matches!(
2813 create.vectors[2].vector_type,
2814 VectorTypeDef::MultiVector { token_dim: 128 }
2815 ));
2816 } else {
2817 panic!("Expected CreateCollection statement");
2818 }
2819 }
2820
2821 #[test]
2822 fn parse_create_collection_sparse_with_max_dim() {
2823 let stmts = ExtendedParser::parse(
2824 "CREATE COLLECTION docs (keywords SPARSE_VECTOR(30522) USING inverted)",
2825 )
2826 .unwrap();
2827 assert_eq!(stmts.len(), 1);
2828 if let Statement::CreateCollection(create) = &stmts[0] {
2829 assert!(matches!(
2830 create.vectors[0].vector_type,
2831 VectorTypeDef::SparseVector { max_dimension: Some(30522) }
2832 ));
2833 } else {
2834 panic!("Expected CreateCollection statement");
2835 }
2836 }
2837
2838 #[test]
2839 fn parse_create_collection_binary_vector() {
2840 let stmts =
2841 ExtendedParser::parse("CREATE COLLECTION docs (hash BINARY_VECTOR(1024))").unwrap();
2842 assert_eq!(stmts.len(), 1);
2843 if let Statement::CreateCollection(create) = &stmts[0] {
2844 assert!(matches!(
2845 create.vectors[0].vector_type,
2846 VectorTypeDef::BinaryVector { bits: 1024 }
2847 ));
2848 } else {
2849 panic!("Expected CreateCollection statement");
2850 }
2851 }
2852
2853 #[test]
2854 fn parse_create_collection_multiple_with_options() {
2855 let stmts = ExtendedParser::parse(
2856 "CREATE COLLECTION docs (vec VECTOR(768) USING hnsw WITH (distance = 'euclidean', m = 16, ef_construction = 200))",
2857 )
2858 .unwrap();
2859 assert_eq!(stmts.len(), 1);
2860 if let Statement::CreateCollection(create) = &stmts[0] {
2861 assert_eq!(create.vectors[0].with_options.len(), 3);
2862 assert!(create.vectors[0]
2863 .with_options
2864 .contains(&("distance".to_string(), "euclidean".to_string())));
2865 assert!(create.vectors[0].with_options.contains(&("m".to_string(), "16".to_string())));
2866 assert!(create.vectors[0]
2867 .with_options
2868 .contains(&("ef_construction".to_string(), "200".to_string())));
2869 } else {
2870 panic!("Expected CreateCollection statement");
2871 }
2872 }
2873
2874 #[test]
2875 fn parse_create_collection_flat_index() {
2876 let stmts =
2877 ExtendedParser::parse("CREATE COLLECTION docs (vec VECTOR(768) USING flat)").unwrap();
2878 assert_eq!(stmts.len(), 1);
2879 if let Statement::CreateCollection(create) = &stmts[0] {
2880 assert_eq!(create.vectors[0].using, Some("flat".to_string()));
2881 } else {
2882 panic!("Expected CreateCollection statement");
2883 }
2884 }
2885
2886 #[test]
2887 fn parse_create_collection_no_using() {
2888 let stmts = ExtendedParser::parse("CREATE COLLECTION docs (vec VECTOR(768))").unwrap();
2889 assert_eq!(stmts.len(), 1);
2890 if let Statement::CreateCollection(create) = &stmts[0] {
2891 assert!(create.vectors[0].using.is_none());
2892 } else {
2893 panic!("Expected CreateCollection statement");
2894 }
2895 }
2896
2897 #[test]
2898 fn is_create_collection_detection() {
2899 assert!(ExtendedParser::is_create_collection("CREATE COLLECTION foo (v VECTOR(10))"));
2900 assert!(ExtendedParser::is_create_collection(" CREATE COLLECTION foo (v VECTOR(10)) "));
2901 assert!(ExtendedParser::is_create_collection(
2902 "CREATE IF NOT EXISTS COLLECTION foo (v VECTOR(10))"
2903 ));
2904 assert!(!ExtendedParser::is_create_collection("CREATE TABLE foo (id INT)"));
2905 assert!(!ExtendedParser::is_create_collection("SELECT * FROM foo"));
2906 }
2907
2908 #[test]
2911 fn parse_create_collection_new_syntax() {
2912 let stmts = ExtendedParser::parse(
2913 "CREATE COLLECTION documents (
2914 title TEXT,
2915 content TEXT,
2916 VECTOR text_embedding DIMENSION 1536,
2917 VECTOR image_embedding DIMENSION 512
2918 )",
2919 )
2920 .unwrap();
2921 assert_eq!(stmts.len(), 1);
2922 if let Statement::CreateCollection(create) = &stmts[0] {
2923 assert_eq!(create.name.name, "documents");
2924 assert_eq!(create.payload_fields.len(), 2);
2926 assert_eq!(create.payload_fields[0].name.name, "title");
2927 assert!(matches!(create.payload_fields[0].data_type, DataType::Text));
2928 assert_eq!(create.payload_fields[1].name.name, "content");
2929 assert!(matches!(create.payload_fields[1].data_type, DataType::Text));
2930 assert_eq!(create.vectors.len(), 2);
2932 assert_eq!(create.vectors[0].name.name, "text_embedding");
2933 assert!(matches!(
2934 create.vectors[0].vector_type,
2935 VectorTypeDef::Vector { dimension: 1536 }
2936 ));
2937 assert_eq!(create.vectors[1].name.name, "image_embedding");
2938 assert!(matches!(
2939 create.vectors[1].vector_type,
2940 VectorTypeDef::Vector { dimension: 512 }
2941 ));
2942 } else {
2943 panic!("Expected CreateCollection statement");
2944 }
2945 }
2946
2947 #[test]
2948 fn parse_create_collection_mixed_syntax() {
2949 let stmts = ExtendedParser::parse(
2950 "CREATE COLLECTION documents (
2951 title TEXT,
2952 category INTEGER INDEXED,
2953 dense VECTOR(768) USING hnsw,
2954 VECTOR summary_embedding DIMENSION 1536 USING hnsw WITH (distance = 'cosine')
2955 )",
2956 )
2957 .unwrap();
2958 assert_eq!(stmts.len(), 1);
2959 if let Statement::CreateCollection(create) = &stmts[0] {
2960 assert_eq!(create.payload_fields.len(), 2);
2962 assert_eq!(create.payload_fields[0].name.name, "title");
2963 assert!(!create.payload_fields[0].indexed);
2964 assert_eq!(create.payload_fields[1].name.name, "category");
2965 assert!(create.payload_fields[1].indexed);
2966 assert_eq!(create.vectors.len(), 2);
2968 assert_eq!(create.vectors[0].name.name, "dense");
2969 assert_eq!(create.vectors[1].name.name, "summary_embedding");
2970 assert!(matches!(
2971 create.vectors[1].vector_type,
2972 VectorTypeDef::Vector { dimension: 1536 }
2973 ));
2974 assert_eq!(create.vectors[1].using, Some("hnsw".to_string()));
2975 } else {
2976 panic!("Expected CreateCollection statement");
2977 }
2978 }
2979
2980 #[test]
2981 fn parse_create_collection_various_types() {
2982 let stmts = ExtendedParser::parse(
2983 "CREATE COLLECTION items (
2984 name VARCHAR(255),
2985 count INTEGER,
2986 price FLOAT,
2987 active BOOLEAN,
2988 metadata JSON,
2989 created_at TIMESTAMP,
2990 id UUID,
2991 VECTOR embedding DIMENSION 768
2992 )",
2993 )
2994 .unwrap();
2995 assert_eq!(stmts.len(), 1);
2996 if let Statement::CreateCollection(create) = &stmts[0] {
2997 assert_eq!(create.payload_fields.len(), 7);
2998 assert!(matches!(create.payload_fields[0].data_type, DataType::Varchar(Some(255))));
2999 assert!(matches!(create.payload_fields[1].data_type, DataType::Integer));
3000 assert!(matches!(create.payload_fields[2].data_type, DataType::Real));
3001 assert!(matches!(create.payload_fields[3].data_type, DataType::Boolean));
3002 assert!(matches!(create.payload_fields[4].data_type, DataType::Json));
3003 assert!(matches!(create.payload_fields[5].data_type, DataType::Timestamp));
3004 assert!(matches!(create.payload_fields[6].data_type, DataType::Uuid));
3005 assert_eq!(create.vectors.len(), 1);
3006 } else {
3007 panic!("Expected CreateCollection statement");
3008 }
3009 }
3010
3011 #[test]
3014 fn parse_cosine_similarity_function() {
3015 let stmts = ExtendedParser::parse(
3016 "SELECT * FROM documents ORDER BY COSINE_SIMILARITY(text_embedding, $query) LIMIT 10",
3017 )
3018 .unwrap();
3019 assert_eq!(stmts.len(), 1);
3020 if let Statement::Select(select) = &stmts[0] {
3021 assert_eq!(select.order_by.len(), 1);
3022 if let Expr::BinaryOp { op: BinaryOp::CosineDistance, left, right } =
3023 &*select.order_by[0].expr
3024 {
3025 if let Expr::Column(col) = left.as_ref() {
3027 assert_eq!(col.name().map(|id| id.name.as_str()), Some("text_embedding"));
3028 } else {
3029 panic!("Expected column reference for left operand");
3030 }
3031 assert!(matches!(right.as_ref(), Expr::Parameter(_)));
3033 } else {
3034 panic!("Expected BinaryOp with CosineDistance");
3035 }
3036 } else {
3037 panic!("Expected SELECT statement");
3038 }
3039 }
3040
3041 #[test]
3042 fn parse_euclidean_distance_function() {
3043 let stmts = ExtendedParser::parse(
3044 "SELECT * FROM docs ORDER BY EUCLIDEAN_DISTANCE(embedding, $1) ASC LIMIT 5",
3045 )
3046 .unwrap();
3047 assert_eq!(stmts.len(), 1);
3048 if let Statement::Select(select) = &stmts[0] {
3049 if let Expr::BinaryOp { op: BinaryOp::EuclideanDistance, .. } =
3050 &*select.order_by[0].expr
3051 {
3052 } else {
3054 panic!("Expected BinaryOp with EuclideanDistance");
3055 }
3056 } else {
3057 panic!("Expected SELECT statement");
3058 }
3059 }
3060
3061 #[test]
3062 fn parse_l2_distance_alias() {
3063 let stmts =
3064 ExtendedParser::parse("SELECT * FROM docs ORDER BY L2_DISTANCE(vec, $q) LIMIT 10")
3065 .unwrap();
3066 assert_eq!(stmts.len(), 1);
3067 if let Statement::Select(select) = &stmts[0] {
3068 assert!(matches!(
3069 &*select.order_by[0].expr,
3070 Expr::BinaryOp { op: BinaryOp::EuclideanDistance, .. }
3071 ));
3072 } else {
3073 panic!("Expected SELECT statement");
3074 }
3075 }
3076
3077 #[test]
3078 fn parse_inner_product_function() {
3079 let stmts = ExtendedParser::parse(
3080 "SELECT * FROM docs ORDER BY INNER_PRODUCT(vec, $q) DESC LIMIT 10",
3081 )
3082 .unwrap();
3083 assert_eq!(stmts.len(), 1);
3084 if let Statement::Select(select) = &stmts[0] {
3085 assert!(matches!(
3086 &*select.order_by[0].expr,
3087 Expr::BinaryOp { op: BinaryOp::InnerProduct, .. }
3088 ));
3089 } else {
3090 panic!("Expected SELECT statement");
3091 }
3092 }
3093
3094 #[test]
3095 fn parse_dot_product_alias() {
3096 let stmts =
3097 ExtendedParser::parse("SELECT * FROM docs ORDER BY DOT_PRODUCT(vec, $q) DESC LIMIT 10")
3098 .unwrap();
3099 assert_eq!(stmts.len(), 1);
3100 if let Statement::Select(select) = &stmts[0] {
3101 assert!(matches!(
3102 &*select.order_by[0].expr,
3103 Expr::BinaryOp { op: BinaryOp::InnerProduct, .. }
3104 ));
3105 } else {
3106 panic!("Expected SELECT statement");
3107 }
3108 }
3109
3110 #[test]
3111 fn parse_maxsim_function() {
3112 let stmts = ExtendedParser::parse(
3113 "SELECT * FROM docs ORDER BY MAXSIM(colbert_vec, $q) DESC LIMIT 10",
3114 )
3115 .unwrap();
3116 assert_eq!(stmts.len(), 1);
3117 if let Statement::Select(select) = &stmts[0] {
3118 assert!(matches!(
3119 &*select.order_by[0].expr,
3120 Expr::BinaryOp { op: BinaryOp::MaxSim, .. }
3121 ));
3122 } else {
3123 panic!("Expected SELECT statement");
3124 }
3125 }
3126
3127 #[test]
3128 fn parse_distance_function_case_insensitive() {
3129 let stmts = ExtendedParser::parse(
3131 "SELECT * FROM docs ORDER BY cosine_similarity(vec, $q) LIMIT 10",
3132 )
3133 .unwrap();
3134 assert_eq!(stmts.len(), 1);
3135 if let Statement::Select(select) = &stmts[0] {
3136 assert!(matches!(
3137 &*select.order_by[0].expr,
3138 Expr::BinaryOp { op: BinaryOp::CosineDistance, .. }
3139 ));
3140 } else {
3141 panic!("Expected SELECT statement");
3142 }
3143 }
3144
3145 #[test]
3146 fn parse_distance_function_name_mapping() {
3147 assert!(matches!(
3149 ExtendedParser::parse_distance_function_name("COSINE_SIMILARITY"),
3150 Some(BinaryOp::CosineDistance)
3151 ));
3152 assert!(matches!(
3153 ExtendedParser::parse_distance_function_name("COS_DISTANCE"),
3154 Some(BinaryOp::CosineDistance)
3155 ));
3156 assert!(matches!(
3157 ExtendedParser::parse_distance_function_name("EUCLIDEAN_DISTANCE"),
3158 Some(BinaryOp::EuclideanDistance)
3159 ));
3160 assert!(matches!(
3161 ExtendedParser::parse_distance_function_name("L2"),
3162 Some(BinaryOp::EuclideanDistance)
3163 ));
3164 assert!(matches!(
3165 ExtendedParser::parse_distance_function_name("INNER_PRODUCT"),
3166 Some(BinaryOp::InnerProduct)
3167 ));
3168 assert!(matches!(
3169 ExtendedParser::parse_distance_function_name("DOT"),
3170 Some(BinaryOp::InnerProduct)
3171 ));
3172 assert!(matches!(
3173 ExtendedParser::parse_distance_function_name("MAXSIM"),
3174 Some(BinaryOp::MaxSim)
3175 ));
3176 assert!(ExtendedParser::parse_distance_function_name("UNKNOWN").is_none());
3177 }
3178
3179 #[test]
3184 fn extract_optional_match_clause() {
3185 let (sql, patterns, optional_patterns) = ExtendedParser::extract_match_clauses(
3186 "SELECT u.name, p.title \
3187 FROM users \
3188 MATCH (u:User) \
3189 OPTIONAL MATCH (u)-[:LIKES]->(p:Post) \
3190 WHERE u.status = 'active'",
3191 )
3192 .unwrap();
3193
3194 assert!(!sql.to_uppercase().contains("MATCH"));
3196 assert!(!sql.to_uppercase().contains("OPTIONAL"));
3197
3198 assert_eq!(patterns.len(), 1);
3200
3201 assert_eq!(optional_patterns.len(), 1);
3203 assert_eq!(optional_patterns[0].len(), 1);
3204 }
3205
3206 #[test]
3207 fn extract_multiple_optional_match_clauses() {
3208 let (sql, patterns, optional_patterns) = ExtendedParser::extract_match_clauses(
3209 "SELECT u.name, p.title, c.text \
3210 FROM entities \
3211 MATCH (u:User) \
3212 OPTIONAL MATCH (u)-[:LIKES]->(p:Post) \
3213 OPTIONAL MATCH (u)-[:WROTE]->(c:Comment) \
3214 WHERE u.active = true",
3215 )
3216 .unwrap();
3217
3218 assert!(!sql.to_uppercase().contains("MATCH"));
3219 assert_eq!(patterns.len(), 1);
3220 assert_eq!(optional_patterns.len(), 1);
3221 assert_eq!(optional_patterns[0].len(), 2);
3223 }
3224
3225 #[test]
3226 fn parse_optional_match_in_select() {
3227 let stmts = ExtendedParser::parse(
3228 "SELECT u.name, p.title \
3229 FROM users \
3230 MATCH (u:User) \
3231 OPTIONAL MATCH (u)-[:LIKES]->(p:Post) \
3232 WHERE u.status = 'active'",
3233 )
3234 .unwrap();
3235
3236 assert_eq!(stmts.len(), 1);
3237 if let Statement::Select(select) = &stmts[0] {
3238 assert!(select.match_clause.is_some());
3240 assert_eq!(select.optional_match_clauses.len(), 1);
3242 assert!(select.where_clause.is_some());
3244 } else {
3245 panic!("Expected SELECT statement");
3246 }
3247 }
3248
3249 #[test]
3250 fn parse_optional_match_pattern_structure() {
3251 let stmts = ExtendedParser::parse(
3252 "SELECT * FROM entities MATCH (u:User) OPTIONAL MATCH (u)-[:FOLLOWS]->(f:User)",
3253 )
3254 .unwrap();
3255
3256 if let Statement::Select(select) = &stmts[0] {
3257 let optional = &select.optional_match_clauses[0];
3258 assert_eq!(optional.paths.len(), 1);
3260 let path = &optional.paths[0];
3261 assert_eq!(path.start.variable.as_ref().map(|v| v.name.as_str()), Some("u"));
3263 assert_eq!(path.steps.len(), 1);
3265 let (edge, node) = &path.steps[0];
3267 assert_eq!(edge.edge_types[0].name, "FOLLOWS");
3268 assert_eq!(node.variable.as_ref().map(|v| v.name.as_str()), Some("f"));
3270 assert_eq!(node.labels[0].name, "User");
3271 } else {
3272 panic!("Expected SELECT statement");
3273 }
3274 }
3275
3276 #[test]
3277 fn find_optional_match_keyword() {
3278 assert_eq!(ExtendedParser::find_optional_match_keyword("OPTIONAL MATCH (a)"), Some(0));
3280
3281 assert_eq!(ExtendedParser::find_optional_match_keyword(" OPTIONAL MATCH (a)"), Some(2));
3283
3284 assert_eq!(ExtendedParser::find_optional_match_keyword("MATCH (a)"), None);
3286
3287 assert_eq!(ExtendedParser::find_optional_match_keyword("OPTIONAL something else"), None);
3289
3290 assert_eq!(ExtendedParser::find_optional_match_keyword("optional match (a)"), Some(0));
3292 }
3293
3294 #[test]
3295 fn find_match_skips_optional_match() {
3296 let input = "OPTIONAL MATCH (a) MATCH (b)";
3298 let pos = ExtendedParser::find_match_keyword(input);
3299 assert_eq!(pos, Some(19)); assert_eq!(ExtendedParser::find_match_keyword("OPTIONAL MATCH (a)"), None);
3304 }
3305
3306 #[test]
3307 fn optional_match_order_of_clauses() {
3308 let (_, patterns, optional_patterns) = ExtendedParser::extract_match_clauses(
3310 "SELECT * FROM entities MATCH (a) OPTIONAL MATCH (b)",
3311 )
3312 .unwrap();
3313
3314 assert_eq!(patterns.len(), 1);
3315 assert_eq!(optional_patterns.len(), 1);
3316 assert_eq!(optional_patterns[0].len(), 1);
3317 }
3318}