use super::ast::{MatchClause, MatchQuery, PatternChain, PatternTriple};
use crate::engine::graph::csr::CsrIndex;
/// Reorders the triples inside every pattern chain of `query`, using the edge
/// statistics held by `csr` to estimate costs.
///
/// Only the order of triples within each [`PatternChain`] is changed; clauses
/// and the chains inside a clause keep their positions.
pub fn optimize(query: &mut MatchQuery, csr: &CsrIndex) {
    query
        .clauses
        .iter_mut()
        .for_each(|clause| optimize_clause(clause, csr));
}
/// Optimizes each pattern chain of `clause` independently of the others.
fn optimize_clause(clause: &mut MatchClause, csr: &CsrIndex) {
    clause
        .patterns
        .iter_mut()
        .for_each(|chain| optimize_chain(chain, csr));
}
fn optimize_chain(chain: &mut PatternChain, csr: &CsrIndex) {
let n = chain.triples.len();
if n <= 1 {
return; }
let mut placed: Vec<PatternTriple> = Vec::with_capacity(n);
let mut remaining: Vec<PatternTriple> = chain.triples.drain(..).collect();
let mut bound_vars: std::collections::HashSet<String> = std::collections::HashSet::new();
for _ in 0..n {
let mut best_idx = 0;
let mut best_cost = f64::INFINITY;
for (idx, triple) in remaining.iter().enumerate() {
let cost = score_triple(triple, csr, &bound_vars);
if cost < best_cost {
best_cost = cost;
best_idx = idx;
}
}
let triple = remaining.swap_remove(best_idx);
if let Some(ref name) = triple.src.name {
bound_vars.insert(name.clone());
}
if let Some(ref name) = triple.dst.name {
bound_vars.insert(name.clone());
}
if let Some(ref name) = triple.edge.name {
bound_vars.insert(name.clone());
}
placed.push(triple);
}
chain.triples = placed;
}
/// Estimates the cost of evaluating `triple`, given the variables in `bound_vars`.
///
/// The estimate starts from the number of edges carrying the triple's edge type,
/// or from the total edge count when no type is given. If that count is zero the
/// cost is `0.0`, which makes the greedy planner pick the triple first.
/// Otherwise the count is scaled as follows:
/// - by 0.1 when exactly one of the source/destination variables is already
///   bound, or by 0.01 when both are;
/// - by `max_hops` when the edge is variable-length.
fn score_triple(
    triple: &PatternTriple,
    csr: &CsrIndex,
    bound_vars: &std::collections::HashSet<String>,
) -> f64 {
    let edges_for_label = match &triple.edge.edge_type {
        Some(label) => csr.label_edge_count(label),
        None => csr.edge_count(),
    };
    if edges_for_label == 0 {
        return 0.0;
    }

    // Number of endpoints (src, dst) whose variable is already bound.
    let bound_endpoints = [&triple.src.name, &triple.dst.name]
        .iter()
        .filter(|name| matches!(name, Some(n) if bound_vars.contains(n)))
        .count();
    let selectivity = match bound_endpoints {
        0 => 1.0,
        1 => 0.1,
        _ => 0.01,
    };

    let hop_multiplier = if triple.edge.is_variable_length() {
        triple.edge.max_hops as f64
    } else {
        1.0
    };

    (edges_for_label as f64) * selectivity * hop_multiplier
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::graph::pattern::ast::*;
    use crate::engine::graph::pattern::compiler;

    /// 100 `KNOWS` edges forming a ring over `p0..p99`, plus 5 `CREATED`
    /// edges from `p0..p4` to `doc0..doc4`.
    fn social_graph() -> CsrIndex {
        let mut csr = CsrIndex::new();
        for i in 0..100 {
            csr.add_edge(&format!("p{i}"), "KNOWS", &format!("p{}", (i + 1) % 100));
        }
        for i in 0..5 {
            csr.add_edge(&format!("p{i}"), "CREATED", &format!("doc{i}"));
        }
        csr.compact();
        csr
    }

    /// Builds a fixed-length (1..1 hop), right-directed triple
    /// `(src)-[:edge_type]->(dst)` with named endpoints and an anonymous edge.
    fn triple(src: &str, edge_type: &str, dst: &str) -> PatternTriple {
        PatternTriple {
            src: NodeBinding {
                name: Some(src.into()),
                label: None,
            },
            edge: EdgeBinding {
                name: None,
                edge_type: Some(edge_type.into()),
                direction: EdgeDirection::Right,
                min_hops: 1,
                max_hops: 1,
            },
            dst: NodeBinding {
                name: Some(dst.into()),
                label: None,
            },
        }
    }

    /// Edge types of `chain`'s triples, in evaluation order.
    fn edge_types(chain: &PatternChain) -> Vec<Option<&str>> {
        chain
            .triples
            .iter()
            .map(|t| t.edge.edge_type.as_deref())
            .collect()
    }

    #[test]
    fn optimize_leaves_independent_patterns_alone() {
        let csr = social_graph();
        let mut query =
            compiler::parse("MATCH (a)-[:KNOWS]->(b), (b)-[:CREATED]->(c) RETURN a, b, c").unwrap();
        optimize(&mut query, &csr);
        // Comma-separated patterns are separate chains; `optimize` never moves
        // triples between chains.
        let patterns = &query.clauses[0].patterns;
        assert_eq!(patterns.len(), 2);
        assert_eq!(edge_types(&patterns[0]), [Some("KNOWS")]);
        assert_eq!(edge_types(&patterns[1]), [Some("CREATED")]);
    }

    #[test]
    fn optimize_reorders_by_selectivity() {
        let csr = social_graph();
        let mut query =
            compiler::parse("MATCH (a)-[:KNOWS]->(b), (b)-[:CREATED]->(c) RETURN a, b, c").unwrap();
        let patterns = &mut query.clauses[0].patterns;
        assert_eq!(patterns.len(), 2);
        assert_eq!(edge_types(&patterns[0]), [Some("KNOWS")]);
        assert_eq!(edge_types(&patterns[1]), [Some("CREATED")]);

        // Merge both patterns into one chain so there is something to reorder.
        let mut created_chain = patterns.remove(1);
        patterns[0].triples.append(&mut created_chain.triples);

        optimize(&mut query, &csr);

        let patterns = &query.clauses[0].patterns;
        assert_eq!(patterns.len(), 1);
        // 5 CREATED edges vs 100 KNOWS edges: the rarer label is evaluated first.
        assert_eq!(edge_types(&patterns[0]), [Some("CREATED"), Some("KNOWS")]);
    }

    #[test]
    fn optimize_prefers_selective_label() {
        let csr = social_graph();
        let mut chain = PatternChain {
            triples: vec![triple("a", "KNOWS", "b"), triple("b", "CREATED", "c")],
        };
        optimize_chain(&mut chain, &csr);
        assert_eq!(edge_types(&chain), [Some("CREATED"), Some("KNOWS")]);
    }

    #[test]
    fn optimize_prefers_bound_variables() {
        let csr = social_graph();
        // Both KNOWS triples start at cost 100. Once CREATED (cost 5) binds `b`,
        // the KNOWS triple starting at `b` drops to 10 and must be placed next,
        // even though it was written after the unrelated `(c)-[:KNOWS]->(d)`.
        let mut chain = PatternChain {
            triples: vec![
                triple("c", "KNOWS", "d"),
                triple("b", "KNOWS", "e"),
                triple("a", "CREATED", "b"),
            ],
        };
        optimize_chain(&mut chain, &csr);
        let srcs: Vec<_> = chain
            .triples
            .iter()
            .map(|t| t.src.name.as_deref())
            .collect();
        assert_eq!(srcs, [Some("a"), Some("b"), Some("c")]);
    }

    #[test]
    fn optimize_single_triple_unchanged() {
        let csr = social_graph();
        let mut chain = PatternChain {
            triples: vec![triple("a", "KNOWS", "b")],
        };
        optimize_chain(&mut chain, &csr);
        assert_eq!(edge_types(&chain), [Some("KNOWS")]);
    }

    #[test]
    fn score_bound_vs_unbound() {
        let csr = social_graph();
        let t = triple("a", "KNOWS", "b");
        let mut bound = std::collections::HashSet::new();
        let unbound_score = score_triple(&t, &csr, &bound);
        bound.insert("a".to_string());
        let one_bound_score = score_triple(&t, &csr, &bound);
        bound.insert("b".to_string());
        let both_bound_score = score_triple(&t, &csr, &bound);
        assert!(one_bound_score < unbound_score);
        assert!(both_bound_score < one_bound_score);
    }

    #[test]
    fn score_variable_length_penalized() {
        let csr = social_graph();
        let bound = std::collections::HashSet::new();
        let fixed = triple("a", "KNOWS", "b");
        let mut variable = triple("a", "KNOWS", "b");
        variable.edge.max_hops = 3;
        let fixed_score = score_triple(&fixed, &csr, &bound);
        let variable_score = score_triple(&variable, &csr, &bound);
        assert!(
            variable_score > fixed_score,
            "variable-length should be more expensive"
        );
    }
}