use super::error::SqlError;
use crate::query::ir::{QueryOp, TraversalDepth};
/// Upper bound applied to a variable-length range whose maximum is omitted
/// (`*n..` or `*..`); see `parse_depth_spec`.
const DEFAULT_MAX_DEPTH: usize = 10;
/// One `(source)-[edge]->(target)` pattern parsed from a MATCH clause.
#[derive(Debug, Clone)]
pub struct GraphPattern {
    /// Alias of the node on the left of the edge, e.g. `a` in `(a)-[:X]->(b)`.
    // `to_query_op` does not read this field (it is only asserted on in tests);
    // the allow silences the unused-field lint.
    #[allow(dead_code)]
    pub source_alias: String,
    /// Alias of the node on the right of the edge, e.g. `b` in `(a)-[:X]->(b)`.
    // Not read by `to_query_op` either; same reason for the allow.
    #[allow(dead_code)]
    pub target_alias: String,
    /// Edge type written after `:` in the brackets; `None` when the brackets carry
    /// no type (`[]`, `[:]`, `[:*2]`).
    pub edge_type: Option<String>,
    /// Arrow direction of the edge.
    pub direction: PatternDirection,
    /// Hop count: `Exact(1)` when no `*` is given, otherwise parsed from the `*` suffix.
    pub depth: TraversalDepth,
}
/// Arrow direction of a MATCH edge.
///
/// A fieldless enum, so it derives `Copy` (cheap to pass by value) and `Eq`
/// (equality is total) in addition to the original traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PatternDirection {
    /// `(a)-[...]->(b)`
    Outgoing,
    /// `(a)<-[...]-(b)`
    Incoming,
    /// `(a)-[...]-(b)`: no arrow head.
    Both,
}
/// Result of [`extract_match_clauses`]: the SQL with its MATCH clauses cut out, plus
/// the patterns parsed from those clauses.
#[derive(Debug)]
pub struct ExtractedMatch {
    /// Input SQL with every `MATCH ...` clause removed; the surrounding fragments are
    /// rejoined with a single space.
    pub cleaned_sql: String,
    /// Patterns from all MATCH clauses, in the order the clauses appear in the input.
    pub patterns: Vec<GraphPattern>,
}
impl GraphPattern {
    /// Lowers this pattern to the traversal [`QueryOp`] for its direction.
    ///
    /// `Outgoing` maps to `TraverseOut`, `Incoming` to `TraverseIn` and `Both` to
    /// `TraverseBoth`. The edge type is cloned into `label` and the depth is copied
    /// through unchanged.
    pub fn to_query_op(&self) -> QueryOp {
        let label = self.edge_type.clone();
        let depth = self.depth;
        match self.direction {
            PatternDirection::Outgoing => QueryOp::TraverseOut { label, depth },
            PatternDirection::Incoming => QueryOp::TraverseIn { label, depth },
            PatternDirection::Both => QueryOp::TraverseBoth { label, depth },
        }
    }
}
/// Removes every `MATCH <pattern>` clause from `sql` and parses the patterns.
///
/// Each clause runs from the `MATCH` keyword (found outside single-quoted strings)
/// to the next clause keyword located by `find_match_end`, or to the end of the
/// input. The SQL on either side of a removed clause is rejoined with one space.
/// Clauses are processed left to right until no `MATCH` remains.
///
/// # Errors
///
/// Returns `SqlError::ParseError` if a clause body is empty, or propagates the
/// error from `parse_graph_patterns` when a pattern is invalid.
pub fn extract_match_clauses(sql: &str) -> Result<ExtractedMatch, SqlError> {
    const KEYWORD: &str = "MATCH";
    let mut remaining_sql = String::from(sql);
    let mut collected = Vec::new();
    loop {
        let Some(kw_start) = find_keyword_outside_strings(&remaining_sql, KEYWORD) else {
            break;
        };
        let body_start = kw_start + KEYWORD.len();
        let body_end = find_match_end(&remaining_sql, body_start);
        let body = remaining_sql[body_start..body_end].trim();
        if body.is_empty() {
            return Err(SqlError::ParseError("Empty MATCH pattern".to_string()));
        }
        collected.extend(parse_graph_patterns(body)?);
        let head = remaining_sql[..kw_start].trim_end();
        let tail = remaining_sql[body_end..].trim_start();
        remaining_sql = format!("{} {}", head, tail);
    }
    Ok(ExtractedMatch {
        cleaned_sql: remaining_sql,
        patterns: collected,
    })
}
/// Finds `keyword` as a whole word outside single-quoted string literals.
///
/// Matching is ASCII case-insensitive. A hit counts only when the characters
/// immediately before and after it are not alphanumeric and not `_`. Inside a
/// `'...'` literal, a doubled quote (`''`) is treated as an escaped quote.
///
/// Returns the byte offset of the first hit, or `None` if there is none.
fn find_keyword_outside_strings(sql: &str, keyword: &str) -> Option<usize> {
    // Walk the original characters together with their byte offsets. Never compare
    // against `sql.to_uppercase()`: case mapping can change the character count
    // ('ß' -> "SS"), which would misalign indices and hide or misplace keywords.
    let chars: Vec<(usize, char)> = sql.char_indices().collect();
    let keyword: Vec<char> = keyword.chars().collect();
    let is_word_char = |c: char| c.is_alphanumeric() || c == '_';
    let mut i = 0;
    while i < chars.len() {
        if chars[i].1 == '\'' {
            // Skip the whole string literal; `''` inside it is an escaped quote.
            i += 1;
            while i < chars.len() {
                if chars[i].1 == '\'' {
                    if i + 1 < chars.len() && chars[i + 1].1 == '\'' {
                        i += 2;
                    } else {
                        i += 1;
                        break;
                    }
                } else {
                    i += 1;
                }
            }
            continue;
        }
        let end = i + keyword.len();
        // ASCII-only case folding: SQL keywords are ASCII, so lookalikes such as
        // the dotless 'ı' (which upper-cases to 'I') must not count as keywords.
        if end <= chars.len()
            && chars[i..end]
                .iter()
                .zip(&keyword)
                .all(|(&(_, c), k)| c.eq_ignore_ascii_case(k))
        {
            let before_ok = i == 0 || !is_word_char(chars[i - 1].1);
            let after_ok = end == chars.len() || !is_word_char(chars[end].1);
            if before_ok && after_ok {
                return Some(chars[i].0);
            }
        }
        i += 1;
    }
    None
}
/// Returns the byte offset in `sql` where a MATCH clause beginning at `start` ends.
///
/// The end is the earliest clause keyword (`WHERE`, `ORDER`, `LIMIT`, `OFFSET`,
/// `GROUP`, `HAVING`, or another `MATCH`) found at or after `start` outside string
/// literals. If none is found, the end is `sql.len()`.
fn find_match_end(sql: &str, start: usize) -> usize {
    const TERMINATORS: [&str; 7] = [
        "WHERE", "ORDER", "LIMIT", "OFFSET", "GROUP", "HAVING", "MATCH",
    ];
    let tail = &sql[start..];
    TERMINATORS
        .iter()
        .filter_map(|kw| find_keyword_outside_strings(tail, kw))
        .min()
        .map_or(sql.len(), |offset| start + offset)
}
/// Parses the body of one MATCH clause into graph patterns.
///
/// As in Cypher, several patterns may be listed with commas, e.g.
/// `(a)-[:KNOWS]->(b), (b)-[:WORKS_AT]->(c)`. Each entry is parsed independently
/// and the results are returned in order.
///
/// # Errors
///
/// Returns `SqlError::ParseError` if `text` or any comma-separated entry is empty.
/// Otherwise propagates the first error from `parse_single_pattern`.
fn parse_graph_patterns(text: &str) -> Result<Vec<GraphPattern>, SqlError> {
    let text = text.trim();
    if text.is_empty() {
        return Err(SqlError::ParseError("Empty MATCH pattern".to_string()));
    }
    // Split before parsing: `parse_single_pattern` handles exactly one
    // `(src)-[edge]-(dst)`, so a comma list must be broken into entries first.
    split_top_level_commas(text)
        .into_iter()
        .map(|entry| {
            let entry = entry.trim();
            if entry.is_empty() {
                Err(SqlError::ParseError(
                    "Empty pattern in comma-separated MATCH list".to_string(),
                ))
            } else {
                parse_single_pattern(entry)
            }
        })
        .collect()
}

/// Splits `text` on commas that are not nested inside `(...)` or `[...]`.
fn split_top_level_commas(text: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut depth: i32 = 0;
    let mut start = 0;
    for (idx, ch) in text.char_indices() {
        match ch {
            '(' | '[' => depth += 1,
            ')' | ']' => depth -= 1,
            ',' if depth == 0 => {
                parts.push(&text[start..idx]);
                start = idx + ch.len_utf8();
            }
            _ => {}
        }
    }
    parts.push(&text[start..]);
    parts
}
/// Parses one `(source)-[edge]->(target)` pattern: exactly two nodes and one edge.
///
/// The arrows around the brackets select the direction: `->` gives outgoing, a
/// leading `<-` gives incoming, and plain `-` gives both. The bracket content is
/// parsed by `parse_edge_spec`.
///
/// # Errors
///
/// Returns `SqlError::ParseError` if `[`/`]` are missing or out of order, or if a
/// second edge follows the first. Errors from the direction, alias, and edge-spec
/// helpers are propagated.
fn parse_single_pattern(text: &str) -> Result<GraphPattern, SqlError> {
    let bracket_start = text.find('[').ok_or_else(|| {
        SqlError::ParseError(format!("Invalid MATCH pattern: missing '[' in '{}'", text))
    })?;
    let bracket_end = text.find(']').ok_or_else(|| {
        SqlError::ParseError(format!("Invalid MATCH pattern: missing ']' in '{}'", text))
    })?;
    if bracket_end <= bracket_start {
        return Err(SqlError::ParseError(format!(
            "Invalid MATCH pattern: ']' before '[' in '{}'",
            text
        )));
    }
    let before_bracket = &text[..bracket_start];
    let after_bracket = &text[bracket_end + 1..];
    // Another bracket after the edge means a chained path such as
    // `(a)-[:X]->(b)-[:Y]->(c)`. Only the first edge would be kept, paired with
    // the last node, so reject it instead of returning a wrong pattern.
    if after_bracket.contains(|c: char| c == '[' || c == ']') {
        return Err(SqlError::ParseError(format!(
            "Invalid MATCH pattern: only one edge per pattern is supported in '{}'",
            text
        )));
    }
    let direction = determine_direction(before_bracket, after_bracket)?;
    // Look for each alias only on its own side of the edge. Otherwise a missing
    // source `(...)` would be filled in by the target node, or vice versa.
    let source_alias = extract_alias_from_start(before_bracket)?;
    let target_alias = extract_alias_from_end(after_bracket)?;
    let bracket_content = &text[bracket_start + 1..bracket_end];
    let (edge_type, depth) = parse_edge_spec(bracket_content)?;
    Ok(GraphPattern {
        source_alias,
        target_alias,
        edge_type,
        direction,
        depth,
    })
}
/// Infers the edge direction from the text around the `[...]` of a pattern.
///
/// Whitespace next to the brackets is ignored. The checks run in this order:
/// - `before_bracket` ending in `<-` gives `Incoming`;
/// - `after_bracket` starting with `->` gives `Outgoing`;
/// - `after_bracket` starting with `-` gives `Both`.
///
/// # Errors
///
/// Returns `SqlError::ParseError` (quoting the untrimmed inputs) when none of the
/// forms above applies.
fn determine_direction(
    before_bracket: &str,
    after_bracket: &str,
) -> Result<PatternDirection, SqlError> {
    let arrow_in = before_bracket.trim_end().ends_with("<-");
    let trailing = after_bracket.trim_start();
    if arrow_in {
        Ok(PatternDirection::Incoming)
    } else if trailing.starts_with("->") {
        Ok(PatternDirection::Outgoing)
    } else if trailing.starts_with('-') {
        Ok(PatternDirection::Both)
    } else {
        Err(SqlError::ParseError(format!(
            "Cannot determine direction in MATCH pattern: before='{}', after='{}'",
            before_bracket, after_bracket
        )))
    }
}
/// Returns the trimmed text between the first `(` and the first `)` in `text`.
///
/// # Errors
///
/// Returns `SqlError::ParseError` if either parenthesis is missing, if the first
/// `)` does not come after the first `(`, or if the alias is empty after trimming.
fn extract_alias_from_start(text: &str) -> Result<String, SqlError> {
    let Some(open) = text.find('(') else {
        return Err(SqlError::ParseError(format!(
            "Invalid MATCH pattern: missing '(' for source alias in '{}'",
            text
        )));
    };
    let Some(close) = text.find(')') else {
        return Err(SqlError::ParseError(format!(
            "Invalid MATCH pattern: missing ')' for source alias in '{}'",
            text
        )));
    };
    if open >= close {
        return Err(SqlError::ParseError(format!(
            "Invalid MATCH pattern: ')' before '(' in '{}'",
            text
        )));
    }
    let alias = text[open + 1..close].trim();
    if alias.is_empty() {
        return Err(SqlError::ParseError(
            "Empty source alias in MATCH pattern".to_string(),
        ));
    }
    Ok(alias.to_owned())
}
/// Returns the trimmed text between the last `(` and the last `)` in `text`.
///
/// # Errors
///
/// Returns `SqlError::ParseError` if either parenthesis is missing, if the last
/// `)` does not come after the last `(`, or if the alias is empty after trimming.
fn extract_alias_from_end(text: &str) -> Result<String, SqlError> {
    let open = match text.rfind('(') {
        Some(pos) => pos,
        None => {
            return Err(SqlError::ParseError(format!(
                "Invalid MATCH pattern: missing '(' for target alias in '{}'",
                text
            )))
        }
    };
    let close = match text.rfind(')') {
        Some(pos) => pos,
        None => {
            return Err(SqlError::ParseError(format!(
                "Invalid MATCH pattern: missing ')' for target alias in '{}'",
                text
            )))
        }
    };
    if open >= close {
        return Err(SqlError::ParseError(format!(
            "Invalid MATCH pattern: ')' before '(' for target in '{}'",
            text
        )));
    }
    match text[open + 1..close].trim() {
        "" => Err(SqlError::ParseError(
            "Empty target alias in MATCH pattern".to_string(),
        )),
        alias => Ok(alias.to_owned()),
    }
}
/// Parses the text between an edge's `[` and `]` into `(edge_type, depth)`.
///
/// Accepted forms (surrounding whitespace is ignored):
/// - empty: `(None, Exact(1))`;
/// - `:TYPE`: `(Some(TYPE), Exact(1))`;
/// - `:TYPE*spec` or `:*spec`: depth from `parse_depth_spec`;
/// - `*spec`: Cypher's untyped variable-length shorthand (e.g. `[*1..3]`), the
///   same as `:*spec`.
///
/// An empty type yields `None`.
///
/// # Errors
///
/// Returns `SqlError::ParseError` if a non-empty spec starts with neither `:` nor
/// `*`. Errors from `parse_depth_spec` are propagated.
fn parse_edge_spec(content: &str) -> Result<(Option<String>, TraversalDepth), SqlError> {
    let content = content.trim();
    if content.is_empty() {
        return Ok((None, TraversalDepth::Exact(1)));
    }
    let type_and_depth = if let Some(rest) = content.strip_prefix(':') {
        rest
    } else if content.starts_with('*') {
        // `[*spec]`: no type; the whole content is "*spec".
        content
    } else {
        return Err(SqlError::ParseError(format!(
            "Edge type must start with ':' in '[{}]'",
            content
        )));
    };
    let (type_part, depth) = match type_and_depth.split_once('*') {
        Some((type_part, depth_spec)) => (type_part, parse_depth_spec(depth_spec)?),
        None => (type_and_depth, TraversalDepth::Exact(1)),
    };
    let type_part = type_part.trim();
    let edge_type = if type_part.is_empty() {
        None
    } else {
        Some(type_part.to_string())
    };
    Ok((edge_type, depth))
}
/// Parses the depth text that follows `*` in an edge spec.
///
/// - empty: `TraversalDepth::Variable`
/// - `n`: `TraversalDepth::Exact(n)`
/// - `a..b`: `Range { min: a, max: b }`
/// - `..b`: `Range { min: 1, max: b }`
/// - `a..` or `..`: `Range { min: a (or 1), max: DEFAULT_MAX_DEPTH }`
///
/// Whitespace around the spec and around either range bound is ignored.
///
/// # Errors
///
/// Returns `SqlError::ParseError` in these cases:
/// - a bound is not a valid unsigned integer;
/// - `min > max`;
/// - an open-ended range (`a..`) has a `min` above `DEFAULT_MAX_DEPTH`, the
///   implicit maximum.
fn parse_depth_spec(spec: &str) -> Result<TraversalDepth, SqlError> {
    let spec = spec.trim();
    if spec.is_empty() {
        return Ok(TraversalDepth::Variable);
    }
    let Some((min_str, max_str)) = spec.split_once("..") else {
        let n = spec.parse::<usize>().map_err(|_| {
            SqlError::ParseError(format!("Invalid depth '{}' in traversal spec", spec))
        })?;
        return Ok(TraversalDepth::Exact(n));
    };
    // Tolerate spaces around each bound, e.g. `*1 .. 3`.
    let (min_str, max_str) = (min_str.trim(), max_str.trim());
    let min = if min_str.is_empty() {
        // `*..n` starts at one hop.
        1
    } else {
        min_str.parse::<usize>().map_err(|_| {
            SqlError::ParseError(format!(
                "Invalid minimum depth '{}' in traversal spec",
                min_str
            ))
        })?
    };
    let max = if max_str.is_empty() {
        // `*n..` is capped at DEFAULT_MAX_DEPTH. Report a min above that cap
        // explicitly; otherwise the generic `min > max` error below would cite a
        // max the user never wrote.
        if min > DEFAULT_MAX_DEPTH {
            return Err(SqlError::ParseError(format!(
                "Invalid depth range: min ({}) exceeds the default max depth ({}) used when no max is given; specify an explicit max",
                min, DEFAULT_MAX_DEPTH
            )));
        }
        DEFAULT_MAX_DEPTH
    } else {
        max_str.parse::<usize>().map_err(|_| {
            SqlError::ParseError(format!(
                "Invalid maximum depth '{}' in traversal spec",
                max_str
            ))
        })?
    };
    if min > max {
        return Err(SqlError::ParseError(format!(
            "Invalid depth range: min ({}) > max ({})",
            min, max
        )));
    }
    Ok(TraversalDepth::Range { min, max })
}
// Unit tests for MATCH-clause extraction, pattern parsing, and QueryOp lowering.
#[cfg(test)]
mod tests {
    use super::*;

    // SQL without a MATCH keyword is returned unchanged with no patterns.
    #[test]
    fn test_extract_no_match() {
        let result = extract_match_clauses("SELECT * FROM nodes WHERE x = 1").unwrap();
        assert!(result.patterns.is_empty());
        assert_eq!(result.cleaned_sql, "SELECT * FROM nodes WHERE x = 1");
    }

    // `->` pattern: aliases, type, direction and default depth are all captured, and
    // the clause is removed while the following WHERE is kept.
    #[test]
    fn test_extract_simple_outgoing() {
        let result = extract_match_clauses(
            "SELECT * FROM nodes AS a MATCH (a)-[:KNOWS]->(b) WHERE a.name = 'Alice'",
        )
        .unwrap();
        assert_eq!(result.patterns.len(), 1);
        assert_eq!(result.patterns[0].source_alias, "a");
        assert_eq!(result.patterns[0].target_alias, "b");
        assert_eq!(result.patterns[0].edge_type, Some("KNOWS".to_string()));
        assert_eq!(result.patterns[0].direction, PatternDirection::Outgoing);
        assert_eq!(result.patterns[0].depth, TraversalDepth::Exact(1));
        assert!(!result.cleaned_sql.to_uppercase().contains("MATCH"));
        assert!(result.cleaned_sql.contains("WHERE"));
    }

    // A leading `<-` makes the edge incoming.
    #[test]
    fn test_extract_incoming() {
        let result = extract_match_clauses(
            "SELECT * FROM nodes AS child MATCH (parent)<-[:PARENT_OF]-(child)",
        )
        .unwrap();
        assert_eq!(result.patterns.len(), 1);
        assert_eq!(result.patterns[0].direction, PatternDirection::Incoming);
        assert_eq!(result.patterns[0].edge_type, Some("PARENT_OF".to_string()));
    }

    // No arrow head on either side means both directions.
    #[test]
    fn test_extract_bidirectional() {
        let result =
            extract_match_clauses("SELECT * FROM nodes AS n MATCH (n)-[:RELATED]-(related)")
                .unwrap();
        assert_eq!(result.patterns.len(), 1);
        assert_eq!(result.patterns[0].direction, PatternDirection::Both);
    }

    // `*1..3` gives a closed range.
    #[test]
    fn test_extract_with_variable_depth() {
        let result =
            extract_match_clauses("SELECT * FROM nodes AS a MATCH (a)-[:KNOWS*1..3]->(b)").unwrap();
        assert_eq!(
            result.patterns[0].depth,
            TraversalDepth::Range { min: 1, max: 3 }
        );
    }

    // `*2` gives an exact hop count.
    #[test]
    fn test_extract_exact_depth() {
        let result =
            extract_match_clauses("SELECT * FROM nodes AS a MATCH (a)-[:KNOWS*2]->(b)").unwrap();
        assert_eq!(result.patterns[0].depth, TraversalDepth::Exact(2));
    }

    // A bare `*` gives `TraversalDepth::Variable`.
    #[test]
    fn test_extract_unbounded() {
        let result =
            extract_match_clauses("SELECT * FROM nodes AS a MATCH (a)-[:KNOWS*]->(b)").unwrap();
        assert_eq!(result.patterns[0].depth, TraversalDepth::Variable);
    }

    // NOTE(review): despite the name, this covers an omitted *max* (`*2..`), which
    // falls back to DEFAULT_MAX_DEPTH.
    #[test]
    fn test_extract_open_min() {
        let result =
            extract_match_clauses("SELECT * FROM nodes AS a MATCH (a)-[:KNOWS*2..]->(b)").unwrap();
        assert_eq!(
            result.patterns[0].depth,
            TraversalDepth::Range {
                min: 2,
                max: DEFAULT_MAX_DEPTH
            }
        );
    }

    // NOTE(review): despite the name, this covers an omitted *min* (`*..3`), which
    // defaults to 1.
    #[test]
    fn test_extract_open_max() {
        let result =
            extract_match_clauses("SELECT * FROM nodes AS a MATCH (a)-[:KNOWS*..3]->(b)").unwrap();
        assert_eq!(
            result.patterns[0].depth,
            TraversalDepth::Range { min: 1, max: 3 }
        );
    }

    // A MATCH inside a single-quoted string literal is not treated as a clause.
    #[test]
    fn test_match_inside_string_ignored() {
        let result =
            extract_match_clauses("SELECT * FROM nodes WHERE name = 'MATCH (a)-[:X]->(b)'")
                .unwrap();
        assert!(result.patterns.is_empty());
    }

    // The keyword is recognised in lower case too.
    #[test]
    fn test_case_insensitive_match() {
        let result =
            extract_match_clauses("SELECT * FROM nodes AS a match (a)-[:KNOWS]->(b)").unwrap();
        assert_eq!(result.patterns.len(), 1);
    }

    // Outgoing lowers to TraverseOut, carrying the label and depth.
    #[test]
    fn test_to_query_op_outgoing() {
        let pattern = GraphPattern {
            source_alias: "a".to_string(),
            target_alias: "b".to_string(),
            edge_type: Some("KNOWS".to_string()),
            direction: PatternDirection::Outgoing,
            depth: TraversalDepth::Exact(1),
        };
        let op = pattern.to_query_op();
        assert!(matches!(
            op,
            QueryOp::TraverseOut {
                label: Some(ref l),
                depth: TraversalDepth::Exact(1),
            } if l == "KNOWS"
        ));
    }

    // Incoming lowers to TraverseIn.
    #[test]
    fn test_to_query_op_incoming() {
        let pattern = GraphPattern {
            source_alias: "a".to_string(),
            target_alias: "b".to_string(),
            edge_type: Some("PARENT_OF".to_string()),
            direction: PatternDirection::Incoming,
            depth: TraversalDepth::Exact(1),
        };
        let op = pattern.to_query_op();
        assert!(matches!(op, QueryOp::TraverseIn { .. }));
    }

    // Both lowers to TraverseBoth.
    #[test]
    fn test_to_query_op_both() {
        let pattern = GraphPattern {
            source_alias: "a".to_string(),
            target_alias: "b".to_string(),
            edge_type: Some("RELATED".to_string()),
            direction: PatternDirection::Both,
            depth: TraversalDepth::Exact(1),
        };
        let op = pattern.to_query_op();
        assert!(matches!(op, QueryOp::TraverseBoth { .. }));
    }

    // Removing the clause leaves the other SQL clauses (WHERE/ORDER BY/LIMIT) intact.
    #[test]
    fn test_cleaned_sql_valid_for_sqlparser() {
        let result = extract_match_clauses(
            "SELECT * FROM nodes AS source MATCH (source)-[:KNOWS]->(target) WHERE source.name = 'Alice' ORDER BY target.name LIMIT 10",
        )
        .unwrap();
        assert!(result.cleaned_sql.contains("SELECT"));
        assert!(result.cleaned_sql.contains("FROM"));
        assert!(result.cleaned_sql.contains("WHERE"));
        assert!(result.cleaned_sql.contains("ORDER BY"));
        assert!(result.cleaned_sql.contains("LIMIT"));
        assert!(!result.cleaned_sql.to_uppercase().contains("MATCH"));
    }

    // Back-to-back MATCH clauses are all extracted, in source order.
    #[test]
    fn test_extract_multiple_match_clauses() {
        let result = extract_match_clauses(
            "SELECT * FROM nodes AS a MATCH (a)-[:KNOWS]->(b) MATCH (b)-[:WORKS_AT]->(c) WHERE a.name = 'Alice'",
        )
        .unwrap();
        assert_eq!(result.patterns.len(), 2);
        assert_eq!(result.patterns[0].edge_type, Some("KNOWS".to_string()));
        assert_eq!(result.patterns[0].direction, PatternDirection::Outgoing);
        assert_eq!(result.patterns[1].edge_type, Some("WORKS_AT".to_string()));
        assert_eq!(result.patterns[1].direction, PatternDirection::Outgoing);
        assert!(!result.cleaned_sql.to_uppercase().contains("MATCH"));
        assert!(result.cleaned_sql.contains("WHERE"));
    }

    // Empty brackets `[]` give no edge type and a single hop.
    #[test]
    fn test_extract_empty_edge_spec() {
        let result = extract_match_clauses("SELECT * FROM nodes AS a MATCH (a)-[]->(b)").unwrap();
        assert_eq!(result.patterns.len(), 1);
        assert_eq!(result.patterns[0].edge_type, None);
        assert_eq!(result.patterns[0].direction, PatternDirection::Outgoing);
        assert_eq!(result.patterns[0].depth, TraversalDepth::Exact(1));
    }
}