use petgraph::algo::{connected_components, is_cyclic_directed};
use petgraph::graph::{Graph, NodeIndex, UnGraph};
use petgraph::visit::EdgeRef;
use std::collections::{HashMap, HashSet};
use crate::ast::ast::{Edge, PathPattern, PatternElement};
use crate::plan::pattern_optimization::pattern_analysis::{
LinearPath, PatternConnectivity, TraversalStep,
};
/// Analyzes a list of path patterns for shared variables and for how the
/// patterns connect to each other through those variables.
///
/// The only state is `debug_mode`, which gates the `log::debug!` output
/// produced by the analysis methods.
#[allow(dead_code)]
#[derive(Debug)]
pub struct PatternAnalyzer {
/// When `true`, analysis methods emit `log::debug!` diagnostics.
debug_mode: bool,
}
impl PatternAnalyzer {
pub fn new() -> Self {
PatternAnalyzer { debug_mode: false }
}
#[allow(dead_code)] pub fn new_debug() -> Self {
PatternAnalyzer { debug_mode: true }
}
/// Computes connectivity information for `patterns`.
///
/// Collects the variables that occur in more than one pattern, builds a
/// graph with one node per pattern and one edge per shared variable between
/// a pair of patterns, and returns both together with the patterns
/// themselves (which are moved into the result).
pub fn analyze_patterns(&self, patterns: Vec<PathPattern>) -> PatternConnectivity {
    let verbose = self.debug_mode;
    if verbose {
        log::debug!("Analyzing {} patterns for connectivity", patterns.len());
    }

    let shared_variables = self.find_shared_variables(&patterns);
    let connectivity_graph = self.build_connectivity_graph(&patterns, &shared_variables);

    if verbose {
        let names: Vec<_> = shared_variables.keys().collect();
        log::debug!(
            "Found {} shared variables: {:?}",
            shared_variables.len(),
            names
        );
    }

    PatternConnectivity {
        patterns,
        shared_variables,
        connectivity_graph,
    }
}
/// Maps every variable used by at least two patterns to the indices of the
/// patterns that use it.
///
/// Indices are recorded in ascending pattern order, and each pattern
/// contributes at most once per variable because its variables are gathered
/// into a set first.
fn find_shared_variables(&self, patterns: &[PathPattern]) -> HashMap<String, Vec<usize>> {
    let mut occurrences: HashMap<String, Vec<usize>> = HashMap::new();

    for (idx, pattern) in patterns.iter().enumerate() {
        let names = self.extract_pattern_variables(pattern);
        if self.debug_mode {
            log::debug!("Pattern {}: variables {:?}", idx, names);
        }
        for name in names {
            occurrences.entry(name).or_default().push(idx);
        }
    }

    // A variable seen in only one pattern is not shared; drop it.
    occurrences.retain(|_, users| users.len() >= 2);
    occurrences
}
fn extract_pattern_variables(&self, pattern: &PathPattern) -> HashSet<String> {
let mut variables = HashSet::new();
for element in &pattern.elements {
match element {
PatternElement::Node(node) => {
if let Some(ref identifier) = node.identifier {
variables.insert(identifier.clone());
}
}
PatternElement::Edge(edge) => {
if let Some(ref identifier) = edge.identifier {
variables.insert(identifier.clone());
}
}
}
}
variables
}
/// Builds the pattern connectivity graph.
///
/// The graph has one node per pattern (weight = pattern index). For every
/// shared variable, one edge labelled with the variable name is added between
/// each pair of patterns that use it, directed from the pattern listed
/// earlier in `pattern_indices` to the one listed later.
///
/// Variables are visited in sorted name order so that edge indices (and with
/// them neighbour iteration order) are the same on every run instead of
/// following `HashMap` iteration order.
fn build_connectivity_graph(
    &self,
    patterns: &[PathPattern],
    shared_vars: &HashMap<String, Vec<usize>>,
) -> Graph<usize, String> {
    let mut graph: Graph<usize, String> = Graph::new();

    // Position `i` holds the graph node created for pattern `i`.
    let node_indices: Vec<NodeIndex> = (0..patterns.len())
        .map(|pattern_idx| graph.add_node(pattern_idx))
        .collect();

    let mut ordered_vars: Vec<(&String, &Vec<usize>)> = shared_vars.iter().collect();
    ordered_vars.sort_unstable_by(|left, right| left.0.cmp(right.0));

    for (var_name, pattern_indices) in ordered_vars {
        for (position, &from_pattern) in pattern_indices.iter().enumerate() {
            for &to_pattern in &pattern_indices[position + 1..] {
                // Indices that do not correspond to a pattern are skipped.
                if let (Some(&from_node), Some(&to_node)) =
                    (node_indices.get(from_pattern), node_indices.get(to_pattern))
                {
                    graph.add_edge(from_node, to_node, var_name.clone());
                    if self.debug_mode {
                        log::debug!(
                            "Connected patterns {} and {} via variable '{}'",
                            from_pattern,
                            to_pattern,
                            var_name
                        );
                    }
                }
            }
        }
    }

    graph
}
pub fn detect_linear_path(&self, connectivity: &PatternConnectivity) -> Option<LinearPath> {
if connectivity.patterns.len() < 2 {
return None;
}
if self.debug_mode {
log::debug!(
"Checking for linear path in {} patterns",
connectivity.patterns.len()
);
}
let undirected = self.to_undirected_graph(&connectivity.connectivity_graph);
if !self.is_simple_path(&undirected) {
if self.debug_mode {
log::debug!("Not a simple path - has cycles or branching");
}
return None;
}
let endpoints = self.find_path_endpoints(&undirected);
if endpoints.len() != 2 {
if self.debug_mode {
log::debug!(
"Path should have exactly 2 endpoints, found {}",
endpoints.len()
);
}
return None;
}
let start_pattern_idx = connectivity.connectivity_graph[endpoints[0]];
self.build_linear_path_from_start(connectivity, start_pattern_idx)
}
fn to_undirected_graph(&self, directed: &Graph<usize, String>) -> UnGraph<usize, String> {
let mut undirected = UnGraph::new_undirected();
let mut node_map = HashMap::new();
for node_idx in directed.node_indices() {
let pattern_idx = directed[node_idx];
let new_node = undirected.add_node(pattern_idx);
node_map.insert(node_idx, new_node);
}
for edge_idx in directed.edge_indices() {
if let Some((from, to)) = directed.edge_endpoints(edge_idx) {
let edge_data = &directed[edge_idx];
if let (Some(&from_new), Some(&to_new)) = (node_map.get(&from), node_map.get(&to)) {
undirected.add_edge(from_new, to_new, edge_data.clone());
}
}
}
undirected
}
/// Returns `true` when `graph` has at least two nodes and no node has more
/// than two incident edges.
///
/// This is only a branching check: it does not verify that the graph is
/// connected or free of cycles. Parallel edges each count towards a node's
/// degree.
fn is_simple_path(&self, graph: &UnGraph<usize, String>) -> bool {
    graph.node_count() >= 2
        && graph
            .node_indices()
            .all(|node| graph.edges(node).count() <= 2)
}
/// Collects, in node-index order, every node of `graph` that has exactly one
/// incident edge.
fn find_path_endpoints(&self, graph: &UnGraph<usize, String>) -> Vec<NodeIndex> {
    let mut endpoints = Vec::new();
    for node in graph.node_indices() {
        let degree = graph.edges(node).count();
        if degree == 1 {
            endpoints.push(node);
        }
    }
    endpoints
}
/// Walks from `start_pattern_idx` to unvisited connected patterns, one at a
/// time, producing one traversal step per hop.
///
/// Returns `None` if no step could be produced at all, or if a shared
/// variable or traversal step cannot be determined for any hop along the way.
fn build_linear_path_from_start(
    &self,
    connectivity: &PatternConnectivity,
    start_pattern_idx: usize,
) -> Option<LinearPath> {
    let patterns = &connectivity.patterns;

    let mut seen: HashSet<usize> = HashSet::new();
    seen.insert(start_pattern_idx);

    let mut hops = Vec::new();
    let mut cursor = start_pattern_idx;

    while let Some(following) = self.find_next_connected_pattern(connectivity, cursor, &seen) {
        let link = self.find_shared_variable_between_patterns(connectivity, cursor, following)?;
        let hop =
            self.create_traversal_step(&patterns[cursor], &patterns[following], &link, following)?;
        hops.push(hop);
        seen.insert(following);
        cursor = following;
    }

    if hops.is_empty() {
        None
    } else {
        Some(LinearPath::new(patterns[start_pattern_idx].clone(), hops))
    }
}
fn find_next_connected_pattern(
&self,
connectivity: &PatternConnectivity,
current_idx: usize,
visited: &HashSet<usize>,
) -> Option<usize> {
let current_node = connectivity
.connectivity_graph
.node_indices()
.find(|&node| connectivity.connectivity_graph[node] == current_idx)?;
for edge in connectivity.connectivity_graph.edges(current_node) {
let target_node = edge.target();
let target_pattern_idx = connectivity.connectivity_graph[target_node];
if !visited.contains(&target_pattern_idx) {
return Some(target_pattern_idx);
}
}
None
}
/// Returns the name of a shared variable used by both given patterns, or
/// `None` if they have none in common.
///
/// If several variables qualify, the one encountered first while iterating
/// the shared-variable map is returned.
fn find_shared_variable_between_patterns(
    &self,
    connectivity: &PatternConnectivity,
    pattern1_idx: usize,
    pattern2_idx: usize,
) -> Option<String> {
    connectivity
        .shared_variables
        .iter()
        .find(|(_, users)| users.contains(&pattern1_idx) && users.contains(&pattern2_idx))
        .map(|(name, _)| name.clone())
}
/// Assembles the traversal step that leads from `from_pattern` to
/// `to_pattern` across `shared_var`.
///
/// Returns `None` if no relationship or no pair of endpoint variables can be
/// found. The step's selectivity is fixed at `0.1`, and its `pattern_index`
/// is `to_pattern_idx`.
fn create_traversal_step(
    &self,
    from_pattern: &PathPattern,
    to_pattern: &PathPattern,
    shared_var: &str,
    to_pattern_idx: usize,
) -> Option<TraversalStep> {
    let relationship = self.find_connecting_relationship(from_pattern, to_pattern, shared_var)?;
    let endpoints = self.determine_traversal_direction(from_pattern, to_pattern, shared_var);

    endpoints.map(|(from_var, to_var)| TraversalStep {
        from_var,
        relationship,
        to_var,
        selectivity: 0.1,
        pattern_index: to_pattern_idx,
    })
}
/// Returns a clone of the first edge element found, looking through
/// `from_pattern` first and `to_pattern` second.
///
/// `_shared_var` is accepted but not consulted. Returns `None` when neither
/// pattern contains an edge.
fn find_connecting_relationship(
    &self,
    from_pattern: &PathPattern,
    to_pattern: &PathPattern,
    _shared_var: &str,
) -> Option<Edge> {
    from_pattern
        .elements
        .iter()
        .chain(to_pattern.elements.iter())
        .find_map(|element| match element {
            PatternElement::Edge(edge) => Some(edge.clone()),
            PatternElement::Node(_) => None,
        })
}
fn determine_traversal_direction(
&self,
from_pattern: &PathPattern,
to_pattern: &PathPattern,
shared_var: &str,
) -> Option<(String, String)> {
let from_vars = self.extract_pattern_variables(from_pattern);
let to_vars = self.extract_pattern_variables(to_pattern);
let from_var = from_vars.iter().find(|&var| var != shared_var).cloned()?;
let to_var = to_vars.iter().find(|&var| var != shared_var).cloned()?;
Some((from_var, to_var))
}
#[allow(dead_code)] pub fn is_star_pattern(&self, connectivity: &PatternConnectivity) -> bool {
if connectivity.patterns.len() < 3 {
return false;
}
for (var_name, pattern_indices) in &connectivity.shared_variables {
if pattern_indices.len() >= connectivity.patterns.len() - 1 {
if self.debug_mode {
log::debug!("Star pattern detected with center variable '{}'", var_name);
}
return true;
}
}
let max_degree = connectivity
.connectivity_graph
.node_indices()
.map(|node| connectivity.connectivity_graph.edges(node).count())
.max()
.unwrap_or(0);
max_degree >= connectivity.patterns.len() - 1
}
/// Reports whether the stored (directed) connectivity graph contains a
/// directed cycle.
///
/// NOTE(review): `build_connectivity_graph` in this file only adds edges from
/// a lower pattern index to a higher one, so for graphs produced by
/// `analyze_patterns` this directed check looks unable to report a cycle —
/// confirm whether that is the intended meaning of "cycle" here.
#[allow(dead_code)]
pub fn has_cycle(&self, connectivity: &PatternConnectivity) -> bool {
    let graph = &connectivity.connectivity_graph;
    is_cyclic_directed(graph)
}
/// Counts the connected components of the connectivity graph, ignoring edge
/// direction.
#[allow(dead_code)]
pub fn count_connected_components(&self, connectivity: &PatternConnectivity) -> usize {
    connected_components(&self.to_undirected_graph(&connectivity.connectivity_graph))
}
}
impl Default for PatternAnalyzer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ast::ast::{EdgeDirection, Location, Node};
/// Builds a node element named `id`; an empty `label` yields no labels.
fn create_test_node(id: &str, label: &str) -> PatternElement {
    let labels = match label {
        "" => Vec::new(),
        other => vec![other.to_string()],
    };
    let node = Node {
        identifier: Some(id.to_string()),
        labels,
        properties: None,
        location: Location::default(),
    };
    PatternElement::Node(node)
}
/// Builds an outgoing, unquantified edge element with a single label and an
/// optional identifier.
fn create_test_edge(id: Option<&str>, label: &str) -> PatternElement {
    let edge = Edge {
        identifier: id.map(String::from),
        labels: vec![String::from(label)],
        properties: None,
        direction: EdgeDirection::Outgoing,
        quantifier: None,
        location: Location::default(),
    };
    PatternElement::Edge(edge)
}
/// Wraps `elements` in a path pattern with no assignment and no path type.
fn create_test_pattern(elements: Vec<PatternElement>) -> PathPattern {
    let location = Location::default();
    PathPattern {
        elements,
        location,
        assignment: None,
        path_type: None,
    }
}
/// Node and edge identifiers of a single pattern are all extracted.
#[test]
fn test_variable_extraction() {
    let analyzer = PatternAnalyzer::new();
    let elements = vec![
        create_test_node("a", "Person"),
        create_test_edge(Some("r"), "KNOWS"),
        create_test_node("b", "Person"),
    ];

    let vars = analyzer.extract_pattern_variables(&create_test_pattern(elements));

    for expected in ["a", "b", "r"].iter() {
        assert!(vars.contains(*expected));
    }
    assert_eq!(vars.len(), 3);
}
/// A variable used by two patterns is reported with both pattern indices.
#[test]
fn test_shared_variable_detection() {
    let hop = |from: (&str, &str), rel: &str, to: (&str, &str)| {
        create_test_pattern(vec![
            create_test_node(from.0, from.1),
            create_test_edge(None, rel),
            create_test_node(to.0, to.1),
        ])
    };
    let patterns = vec![
        hop(("a", "Person"), "KNOWS", ("b", "Person")),
        hop(("b", "Person"), "WORKS_IN", ("c", "Department")),
    ];

    let shared = PatternAnalyzer::new().find_shared_variables(&patterns);

    assert!(shared.contains_key("b"));
    assert_eq!(shared.get("b"), Some(&vec![0, 1]));
    assert_eq!(shared.len(), 1);
}
/// Two patterns chained through one shared variable give a one-step path.
#[test]
fn test_linear_path_detection() {
    let hop = |from: (&str, &str), rel: &str, to: (&str, &str)| {
        create_test_pattern(vec![
            create_test_node(from.0, from.1),
            create_test_edge(None, rel),
            create_test_node(to.0, to.1),
        ])
    };
    let analyzer = PatternAnalyzer::new_debug();
    let connectivity = analyzer.analyze_patterns(vec![
        hop(("a", "Person"), "KNOWS", ("b", "Person")),
        hop(("b", "Person"), "WORKS_IN", ("c", "Department")),
    ]);

    let detected = analyzer.detect_linear_path(&connectivity);

    assert!(detected.is_some(), "Should detect linear path");
    if let Some(path) = detected {
        assert_eq!(path.length(), 1);
    }
}
/// Three patterns that all start at the same variable are reported as a star
/// and not as a cycle.
#[test]
fn test_star_pattern_detection() {
    let spoke = |rel: &str, to: (&str, &str)| {
        create_test_pattern(vec![
            create_test_node("a", "Person"),
            create_test_edge(None, rel),
            create_test_node(to.0, to.1),
        ])
    };
    let analyzer = PatternAnalyzer::new_debug();
    let connectivity = analyzer.analyze_patterns(vec![
        spoke("KNOWS", ("b", "Person")),
        spoke("WORKS_IN", ("d", "Department")),
        spoke("LIVES_IN", ("c", "City")),
    ]);

    assert!(analyzer.is_star_pattern(&connectivity));
    assert!(!analyzer.has_cycle(&connectivity));
}
/// Patterns with disjoint variables are unconnected: no shared variables, no
/// linear path, no star, and one component per pattern.
#[test]
fn test_no_shared_variables() {
    let hop = |from: (&str, &str), rel: &str, to: (&str, &str)| {
        create_test_pattern(vec![
            create_test_node(from.0, from.1),
            create_test_edge(None, rel),
            create_test_node(to.0, to.1),
        ])
    };
    let analyzer = PatternAnalyzer::new();
    let connectivity = analyzer.analyze_patterns(vec![
        hop(("a", "Person"), "KNOWS", ("b", "Person")),
        hop(("c", "Department"), "HAS_EMPLOYEE", ("d", "Person")),
    ]);

    assert!(!connectivity.has_shared_variables());
    assert!(analyzer.detect_linear_path(&connectivity).is_none());
    assert!(!analyzer.is_star_pattern(&connectivity));
    assert_eq!(analyzer.count_connected_components(&connectivity), 2);
}
}