use crate::core::KnowledgeGraph;
use crate::Result;
use std::collections::HashMap;
/// A node of a logical query plan.
///
/// Plans form a tree: composite operators own their inputs through
/// `Box<QueryOp>`, while `EntityScan` and `Filter` carry no input operator.
#[derive(Debug, Clone, PartialEq)]
pub enum QueryOp {
/// Scan of all entities of one entity type.
EntityScan {
/// Entity type name; the cost model looks this key up in
/// `GraphStatistics::entities_by_type`.
entity_type: String,
},
/// A `property = value` predicate (rendered as `Filter(property=value)` in
/// explain output).
///
/// NOTE(review): the variant has no input operator, so what the predicate is
/// applied to is decided by code outside this file — confirm with the executor.
Filter {
/// Name of the property being tested.
property: String,
/// Value the property is compared against.
value: String,
},
/// Join of two sub-plans.
Join {
/// Left input of the join.
left: Box<QueryOp>,
/// Right input of the join.
right: Box<QueryOp>,
/// Kind of join to perform.
join_type: JoinType,
},
/// Graph expansion starting from the entities produced by `source`.
Neighbors {
/// Sub-plan producing the starting entities.
source: Box<QueryOp>,
/// Relation type restriction; `None` is rendered as `*` in explain output
/// (presumably "any relation type" — confirm with the executor).
relation_type: Option<String>,
/// Hop count; the cost model raises the average degree to this power.
max_hops: usize,
},
/// Combination of the results of two sub-plans.
Union {
/// First input.
left: Box<QueryOp>,
/// Second input.
right: Box<QueryOp>,
},
/// Caps the number of results taken from `source`.
Limit {
/// Sub-plan whose output is limited.
source: Box<QueryOp>,
/// Upper bound on the number of results.
count: usize,
},
}
/// Kind of join performed by `QueryOp::Join`.
///
/// The cardinality notes below describe how `QueryOptimizer::estimate_cost`
/// treats each variant.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinType {
/// Inner join; output size is estimated as the square root of the product
/// of both input cardinalities.
Inner,
/// Left outer join; output size is estimated as the left input's cardinality.
LeftOuter,
/// Cross join; output size is estimated as the product of both input
/// cardinalities.
Cross,
}
/// Estimate produced by the cost model for one plan node (including its inputs).
#[derive(Debug, Clone)]
pub struct OperationCost {
/// Estimated number of results (shown as `rows=` in explain output).
pub cardinality: usize,
/// Abstract, unit-less cost of evaluating the node and everything below it.
pub cost: f64,
/// Estimated fraction of the graph's entities that the node returns;
/// explain output multiplies it by 100 and prints it as a percentage.
pub selectivity: f64,
}
/// Summary statistics about a knowledge graph, used as input to the cost model.
#[derive(Debug, Clone)]
pub struct GraphStatistics {
/// Total number of entities in the graph.
pub total_entities: usize,
/// Number of entities per entity type name.
pub entities_by_type: HashMap<String, usize>,
/// Total number of relationships in the graph.
pub total_relationships: usize,
/// Number of relationships per relation type name.
pub relationships_by_type: HashMap<String, usize>,
/// Average number of relationship endpoints per entity; `from_graph`
/// computes it as `2 * total_relationships / total_entities` (0.0 for an
/// empty graph).
pub average_degree: f64,
}
impl GraphStatistics {
    /// Gathers statistics from `graph`.
    ///
    /// Entities and relationships are tallied in total and per type name.
    /// The average degree is `2 * relationships / entities` (each relationship
    /// is counted twice, presumably once per endpoint) and is `0.0` when the
    /// graph holds no entities.
    pub fn from_graph(graph: &KnowledgeGraph) -> Self {
        // Count entities and bucket them by type in a single pass.
        let mut total_entities: usize = 0;
        let mut entities_by_type: HashMap<String, usize> = HashMap::new();
        for entity in graph.entities() {
            total_entities += 1;
            *entities_by_type
                .entry(entity.entity_type.clone())
                .or_default() += 1;
        }

        // Same tally for relationships.
        let relationships = graph.get_all_relationships();
        let mut relationships_by_type: HashMap<String, usize> = HashMap::new();
        for rel in &relationships {
            *relationships_by_type
                .entry(rel.relation_type.clone())
                .or_default() += 1;
        }
        let total_relationships = relationships.len();

        let average_degree = match total_entities {
            0 => 0.0,
            count => (total_relationships as f64 * 2.0) / count as f64,
        };

        Self {
            total_entities,
            entities_by_type,
            total_relationships,
            relationships_by_type,
            average_degree,
        }
    }
}
/// Cost-based optimizer for `QueryOp` plans.
pub struct QueryOptimizer {
/// Graph statistics that every cardinality and cost estimate is derived from.
stats: GraphStatistics,
}
impl QueryOptimizer {
    /// Assumed fraction of all entities that satisfy a `Filter` predicate.
    const DEFAULT_FILTER_SELECTIVITY: f64 = 0.1;

    /// Creates an optimizer whose estimates are based on `stats`.
    pub fn new(stats: GraphStatistics) -> Self {
        Self { stats }
    }

    /// Optimizes `query` and returns the rewritten plan.
    ///
    /// Two passes are applied: a local rewrite that puts the smaller input of
    /// every reorderable join on the left, followed by greedy reordering of
    /// chains of same-typed reorderable joins.
    ///
    /// # Errors
    ///
    /// Propagates errors from cost estimation and join ordering.
    pub fn optimize(&self, query: QueryOp) -> Result<QueryOp> {
        let rewritten = self.rewrite_query(query)?;
        self.optimize_joins(rewritten)
    }

    /// Returns `true` when the operands of a join of this type may be swapped
    /// and re-associated without changing the result.
    ///
    /// A left outer join keeps every row of its *left* input, so exchanging or
    /// regrouping its operands would change the meaning of the plan.
    fn is_reorderable(join_type: &JoinType) -> bool {
        match join_type {
            JoinType::Inner | JoinType::Cross => true,
            JoinType::LeftOuter => false,
        }
    }

    /// Expresses `cardinality` as a fraction of all entities, clamped to
    /// `0.0..=1.0`; an empty graph yields `0.0` instead of NaN or infinity.
    fn fraction_of_graph(&self, cardinality: usize) -> f64 {
        match self.stats.total_entities {
            0 => 0.0,
            total => (cardinality as f64 / total as f64).min(1.0),
        }
    }

    /// Recursively rewrites the plan so that the input with the smaller
    /// estimated cardinality is the left operand of every reorderable join.
    ///
    /// Non-reorderable joins (`LeftOuter`) keep their operand order; only
    /// their inputs are rewritten.
    fn rewrite_query(&self, query: QueryOp) -> Result<QueryOp> {
        match query {
            // Leaves have nothing to rewrite.
            QueryOp::EntityScan { .. } | QueryOp::Filter { .. } => Ok(query),
            QueryOp::Join {
                left,
                right,
                join_type,
            } => {
                let left_opt = self.rewrite_query(*left)?;
                let right_opt = self.rewrite_query(*right)?;
                let swap = Self::is_reorderable(&join_type)
                    && self.estimate_cost(&left_opt)?.cardinality
                        > self.estimate_cost(&right_opt)?.cardinality;
                let (left_opt, right_opt) = if swap {
                    (right_opt, left_opt)
                } else {
                    (left_opt, right_opt)
                };
                Ok(QueryOp::Join {
                    left: Box::new(left_opt),
                    right: Box::new(right_opt),
                    join_type,
                })
            }
            QueryOp::Neighbors {
                source,
                relation_type,
                max_hops,
            } => Ok(QueryOp::Neighbors {
                source: Box::new(self.rewrite_query(*source)?),
                relation_type,
                max_hops,
            }),
            QueryOp::Union { left, right } => Ok(QueryOp::Union {
                left: Box::new(self.rewrite_query(*left)?),
                right: Box::new(self.rewrite_query(*right)?),
            }),
            QueryOp::Limit { source, count } => Ok(QueryOp::Limit {
                source: Box::new(self.rewrite_query(*source)?),
                count,
            }),
        }
    }

    /// Recursively reorders chains of joins.
    ///
    /// For a reorderable join, all directly nested joins of the *same* type
    /// are flattened into one operand list and rebuilt by
    /// `find_optimal_join_order`. Nested joins of a different type are kept
    /// intact and treated as a single operand, so no join type is ever lost.
    fn optimize_joins(&self, query: QueryOp) -> Result<QueryOp> {
        match query {
            QueryOp::Join {
                left,
                right,
                join_type,
            } => {
                let left_opt = self.optimize_joins(*left)?;
                let right_opt = self.optimize_joins(*right)?;
                if !Self::is_reorderable(&join_type) {
                    return Ok(QueryOp::Join {
                        left: Box::new(left_opt),
                        right: Box::new(right_opt),
                        join_type,
                    });
                }
                let mut operands = Vec::new();
                Self::collect_join_operands(left_opt, &join_type, &mut operands);
                Self::collect_join_operands(right_opt, &join_type, &mut operands);
                self.find_optimal_join_order(operands, join_type)
            }
            QueryOp::Neighbors {
                source,
                relation_type,
                max_hops,
            } => Ok(QueryOp::Neighbors {
                source: Box::new(self.optimize_joins(*source)?),
                relation_type,
                max_hops,
            }),
            QueryOp::Union { left, right } => Ok(QueryOp::Union {
                left: Box::new(self.optimize_joins(*left)?),
                right: Box::new(self.optimize_joins(*right)?),
            }),
            QueryOp::Limit { source, count } => Ok(QueryOp::Limit {
                source: Box::new(self.optimize_joins(*source)?),
                count,
            }),
            QueryOp::EntityScan { .. } | QueryOp::Filter { .. } => Ok(query),
        }
    }

    /// Flattens `op` into `operands`, descending only through joins whose type
    /// equals `join_type`; any other node is pushed as one opaque operand.
    fn collect_join_operands(op: QueryOp, join_type: &JoinType, operands: &mut Vec<QueryOp>) {
        match op {
            QueryOp::Join {
                left,
                right,
                join_type: nested,
            } if nested == *join_type => {
                Self::collect_join_operands(*left, join_type, operands);
                Self::collect_join_operands(*right, join_type, operands);
            }
            other => operands.push(other),
        }
    }

    /// Greedily combines `operands` into a tree of `join_type` joins.
    ///
    /// Each round joins the pair whose cardinality product is smallest and
    /// puts the smaller of the two on the left, matching `rewrite_query`.
    /// Callers must only pass reorderable join types.
    ///
    /// # Errors
    ///
    /// Returns a `Validation` error when `operands` is empty.
    fn find_optimal_join_order(
        &self,
        mut operands: Vec<QueryOp>,
        join_type: JoinType,
    ) -> Result<QueryOp> {
        if operands.is_empty() {
            return Err(crate::core::GraphRAGError::Validation {
                message: "No operands for join".to_string(),
            });
        }

        // cardinalities[i] caches the estimate for operands[i], so each
        // operand is costed once instead of once per candidate pair.
        let mut cardinalities = operands
            .iter()
            .map(|op| self.estimate_cost(op).map(|c| c.cardinality as f64))
            .collect::<Result<Vec<f64>>>()?;

        while operands.len() > 1 {
            let mut min_cost = f64::INFINITY;
            let mut best = (0, 1);
            for i in 0..operands.len() {
                for j in (i + 1)..operands.len() {
                    let join_cost = cardinalities[i] * cardinalities[j];
                    if join_cost < min_cost {
                        min_cost = join_cost;
                        best = (i, j);
                    }
                }
            }

            // `j > i`, so removing `j` first keeps index `i` valid.
            let (i, j) = best;
            let right = operands.remove(j);
            let right_rows = cardinalities.remove(j);
            let left = operands.remove(i);
            let left_rows = cardinalities.remove(i);
            let (left, right) = if left_rows > right_rows {
                (right, left)
            } else {
                (left, right)
            };

            let joined = QueryOp::Join {
                left: Box::new(left),
                right: Box::new(right),
                join_type: join_type.clone(),
            };
            let joined_rows = self.estimate_cost(&joined)?.cardinality as f64;
            operands.push(joined);
            cardinalities.push(joined_rows);
        }

        Ok(operands
            .pop()
            .expect("operands was non-empty and the loop leaves exactly one element"))
    }

    /// Estimates cardinality, cost and selectivity of `op` including its inputs.
    ///
    /// The model is heuristic:
    /// * `EntityScan` – the per-type entity count (0 for unknown types).
    /// * `Filter` – a fixed 10% of all entities, at the cost of a full scan.
    /// * `Join` – `sqrt(l * r)` rows for inner, `l` for left outer and `l * r`
    ///   (saturating) for cross joins; the cost always includes `l * r`.
    /// * `Neighbors` – source rows times `average_degree ^ max_hops`, capped
    ///   at the total entity count.
    /// * `Union` – 90% of the summed input rows (assumed overlap).
    /// * `Limit` – the smaller of `count` and the source rows.
    ///
    /// Selectivities are always finite and at most `1.0`; an empty graph
    /// yields `0.0`.
    ///
    /// # Errors
    ///
    /// The current model never fails; the `Result` is kept for API stability.
    pub fn estimate_cost(&self, op: &QueryOp) -> Result<OperationCost> {
        let total_entities = self.stats.total_entities as f64;
        let estimate = match op {
            QueryOp::EntityScan { entity_type } => {
                let cardinality = self
                    .stats
                    .entities_by_type
                    .get(entity_type)
                    .copied()
                    .unwrap_or(0);
                OperationCost {
                    cardinality,
                    cost: cardinality as f64,
                    selectivity: self.fraction_of_graph(cardinality),
                }
            }
            QueryOp::Filter { .. } => {
                let selectivity = Self::DEFAULT_FILTER_SELECTIVITY;
                OperationCost {
                    cardinality: (total_entities * selectivity) as usize,
                    cost: total_entities,
                    selectivity,
                }
            }
            QueryOp::Join {
                left,
                right,
                join_type,
            } => {
                let left_cost = self.estimate_cost(left)?;
                let right_cost = self.estimate_cost(right)?;
                let pair_count = left_cost.cardinality as f64 * right_cost.cardinality as f64;
                let cardinality = match join_type {
                    JoinType::Inner => pair_count.sqrt() as usize,
                    JoinType::LeftOuter => left_cost.cardinality,
                    // Saturate instead of overflowing on very large inputs.
                    JoinType::Cross => left_cost
                        .cardinality
                        .saturating_mul(right_cost.cardinality),
                };
                OperationCost {
                    cardinality,
                    cost: left_cost.cost + right_cost.cost + pair_count,
                    selectivity: left_cost.selectivity * right_cost.selectivity,
                }
            }
            QueryOp::Neighbors {
                source, max_hops, ..
            } => {
                let source_cost = self.estimate_cost(source)?;
                // Clamp before narrowing so huge hop counts cannot wrap.
                let hops = (*max_hops).min(i32::MAX as usize) as i32;
                let expansion_factor = self.stats.average_degree.powi(hops);
                let cardinality =
                    (source_cost.cardinality as f64 * expansion_factor).min(total_entities) as usize;
                OperationCost {
                    cardinality,
                    cost: source_cost.cost + cardinality as f64,
                    selectivity: self.fraction_of_graph(cardinality),
                }
            }
            QueryOp::Union { left, right } => {
                let left_cost = self.estimate_cost(left)?;
                let right_cost = self.estimate_cost(right)?;
                let combined = left_cost
                    .cardinality
                    .saturating_add(right_cost.cardinality);
                // floor(combined * 9 / 10) without the overflowing multiplication.
                let cardinality = combined / 10 * 9 + combined % 10 * 9 / 10;
                OperationCost {
                    cardinality,
                    cost: left_cost.cost + right_cost.cost,
                    selectivity: (left_cost.selectivity + right_cost.selectivity).min(1.0),
                }
            }
            QueryOp::Limit { source, count } => {
                let source_cost = self.estimate_cost(source)?;
                let cardinality = (*count).min(source_cost.cardinality);
                OperationCost {
                    cardinality,
                    cost: source_cost.cost,
                    // Based on the rows actually returned, not on the raw limit.
                    selectivity: self.fraction_of_graph(cardinality),
                }
            }
        };
        Ok(estimate)
    }

    /// Renders `op` as an indented plan tree followed by a summary of the
    /// root's estimated cost, cardinality and selectivity (as a percentage).
    ///
    /// # Errors
    ///
    /// Propagates errors from `estimate_cost`.
    pub fn explain(&self, op: &QueryOp) -> Result<String> {
        let cost = self.estimate_cost(op)?;
        let mut plan = String::new();
        self.explain_recursive(op, 0, &mut plan)?;
        plan.push_str(&format!(
            "\nEstimated Cost: {:.2}\nEstimated Cardinality: {}\nSelectivity: {:.2}%\n",
            cost.cost,
            cost.cardinality,
            cost.selectivity * 100.0
        ));
        Ok(plan)
    }

    /// Appends one line for `op` to `plan`, indented by `depth`, then recurses
    /// into its inputs at `depth + 1`.
    fn explain_recursive(&self, op: &QueryOp, depth: usize, plan: &mut String) -> Result<()> {
        let indent = " ".repeat(depth);
        let cost = self.estimate_cost(op)?;
        let (label, children): (String, Vec<&QueryOp>) = match op {
            QueryOp::EntityScan { entity_type } => {
                (format!("EntityScan({})", entity_type), Vec::new())
            }
            QueryOp::Filter { property, value } => {
                (format!("Filter({}={})", property, value), Vec::new())
            }
            QueryOp::Join {
                left,
                right,
                join_type,
            } => (format!("Join({:?})", join_type), vec![&**left, &**right]),
            QueryOp::Neighbors {
                source,
                relation_type,
                max_hops,
            } => (
                format!(
                    "Neighbors({}, hops={})",
                    relation_type.as_deref().unwrap_or("*"),
                    max_hops
                ),
                vec![&**source],
            ),
            QueryOp::Union { left, right } => ("Union".to_string(), vec![&**left, &**right]),
            QueryOp::Limit { source, count } => (format!("Limit({})", count), vec![&**source]),
        };
        plan.push_str(&format!(
            "{}{} [cost={:.0}, rows={}]\n",
            indent, label, cost.cost, cost.cardinality
        ));
        for child in children {
            self.explain_recursive(child, depth + 1, plan)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Statistics for a small graph: 180 entities of three types and 140
    /// relationships of two types.
    fn create_test_stats() -> GraphStatistics {
        let mut entities_by_type = HashMap::new();
        entities_by_type.insert("PERSON".to_string(), 100);
        entities_by_type.insert("ORGANIZATION".to_string(), 50);
        entities_by_type.insert("LOCATION".to_string(), 30);
        let mut relationships_by_type = HashMap::new();
        relationships_by_type.insert("WORKS_FOR".to_string(), 80);
        relationships_by_type.insert("LOCATED_IN".to_string(), 60);
        GraphStatistics {
            total_entities: 180,
            entities_by_type,
            total_relationships: 140,
            relationships_by_type,
            average_degree: 1.56,
        }
    }

    /// Shorthand for an `EntityScan` over `entity_type`.
    fn scan(entity_type: &str) -> QueryOp {
        QueryOp::EntityScan {
            entity_type: entity_type.to_string(),
        }
    }

    /// Shorthand for an inner join of two plans.
    fn inner_join(left: QueryOp, right: QueryOp) -> QueryOp {
        QueryOp::Join {
            left: Box::new(left),
            right: Box::new(right),
            join_type: JoinType::Inner,
        }
    }

    #[test]
    fn test_cost_estimation_scan() {
        let optimizer = QueryOptimizer::new(create_test_stats());
        let cost = optimizer.estimate_cost(&scan("PERSON")).unwrap();
        assert_eq!(cost.cardinality, 100);
        assert_eq!(cost.cost, 100.0);
    }

    #[test]
    fn test_cost_estimation_join() {
        let optimizer = QueryOptimizer::new(create_test_stats());
        let query = inner_join(scan("PERSON"), scan("ORGANIZATION"));
        let cost = optimizer.estimate_cost(&query).unwrap();
        // sqrt(100 * 50) is roughly 70.7.
        assert!(cost.cardinality > 60 && cost.cardinality < 80);
    }

    #[test]
    fn test_join_reordering() {
        let optimizer = QueryOptimizer::new(create_test_stats());
        let query = inner_join(scan("PERSON"), scan("LOCATION"));
        let optimized = optimizer.optimize(query).unwrap();
        // Compare the whole plan so that an unexpected plan shape fails the
        // test instead of silently skipping the assertion.
        assert_eq!(
            optimized,
            inner_join(scan("LOCATION"), scan("PERSON")),
            "Smaller table should be first"
        );
    }

    #[test]
    fn test_neighbors_cost() {
        let stats = create_test_stats();
        let total_entities = stats.total_entities;
        let optimizer = QueryOptimizer::new(stats);
        let query = QueryOp::Neighbors {
            source: Box::new(scan("PERSON")),
            relation_type: Some("WORKS_FOR".to_string()),
            max_hops: 2,
        };
        let cost = optimizer.estimate_cost(&query).unwrap();
        assert!(cost.cardinality > 100);
        // The expansion can never exceed the number of entities in the graph.
        assert!(cost.cardinality <= total_entities);
    }

    #[test]
    fn test_explain_plan() {
        let optimizer = QueryOptimizer::new(create_test_stats());
        let query = inner_join(scan("PERSON"), scan("ORGANIZATION"));
        let plan = optimizer.explain(&query).unwrap();
        assert!(plan.contains("Join"));
        assert!(plan.contains("EntityScan"));
        assert!(plan.contains("Estimated Cost"));
    }
}