1use crate::core::KnowledgeGraph;
12use crate::Result;
13use std::collections::HashMap;
14
/// A node of a logical query plan over the knowledge graph.
///
/// Plans are trees: `Join`, `Neighbors`, `Union` and `Limit` own their input
/// sub-plans through `Box`, while `EntityScan` and `Filter` are leaves.
#[derive(Debug, Clone, PartialEq)]
pub enum QueryOp {
    /// Leaf: the entities of a single type. The optimizer estimates its
    /// cardinality from `GraphStatistics::entities_by_type`.
    EntityScan {
        /// Entity type to scan; looked up as a key of `entities_by_type`.
        entity_type: String,
    },
    /// Leaf: a property/value predicate, rendered as `Filter(property=value)`
    /// by `QueryOptimizer::explain`. It has no input operand.
    /// NOTE(review): presumably selects entities whose `property` equals
    /// `value`; the executor is not in this file — confirm.
    Filter {
        /// Property name the predicate tests.
        property: String,
        /// Value the property is compared against.
        value: String,
    },
    /// Combines the results of two sub-plans according to `join_type`.
    Join {
        /// Left input.
        left: Box<QueryOp>,
        /// Right input.
        right: Box<QueryOp>,
        /// Kind of join; also selects the cardinality estimate used.
        join_type: JoinType,
    },
    /// Graph expansion starting from the results of `source`.
    Neighbors {
        /// Sub-plan producing the starting entities.
        source: Box<QueryOp>,
        /// Relationship type to follow; `None` is rendered as `*` by `explain`
        /// (presumably "any relationship type" — confirm with the executor).
        relation_type: Option<String>,
        /// Number of hops; the cost model scales the source cardinality by
        /// `average_degree ^ max_hops`, capped at the total entity count.
        max_hops: usize,
    },
    /// Combines the results of two sub-plans; the cost model assumes about
    /// 10% of rows overlap.
    Union {
        /// First input.
        left: Box<QueryOp>,
        /// Second input.
        right: Box<QueryOp>,
    },
    /// Caps the output of `source` at `count` rows.
    Limit {
        /// Sub-plan whose output is truncated.
        source: Box<QueryOp>,
        /// Maximum number of rows to keep.
        count: usize,
    },
}
63
/// How the two inputs of a [`QueryOp::Join`] are combined.
///
/// Besides naming the join, the variant selects the cardinality estimate the
/// optimizer's cost model uses for the join.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Inner join; estimated as the geometric mean of the input cardinalities.
    Inner,
    /// Left outer join; estimated to produce as many rows as its left input.
    LeftOuter,
    /// Cartesian product; estimated as the product of the input cardinalities.
    Cross,
}
74
/// Estimated cost of evaluating one [`QueryOp`] sub-plan, as produced by
/// `QueryOptimizer::estimate_cost`.
///
/// All fields are plain numbers, so the type is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct OperationCost {
    /// Estimated number of rows the operation produces.
    pub cardinality: usize,
    /// Estimated total work for the sub-plan, in the cost model's abstract units.
    pub cost: f64,
    /// Estimated fraction of the graph's entities represented by the output
    /// (typically between 0 and 1).
    pub selectivity: f64,
}
85
/// Aggregate counts describing a knowledge graph, used by the query
/// optimizer's cost model.
///
/// `GraphStatistics::default()` describes an empty graph (all counts zero).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct GraphStatistics {
    /// Total number of entities.
    pub total_entities: usize,
    /// Number of entities for each entity type.
    pub entities_by_type: HashMap<String, usize>,
    /// Total number of relationships.
    pub total_relationships: usize,
    /// Number of relationships for each relation type.
    pub relationships_by_type: HashMap<String, usize>,
    /// Mean relationship endpoints per entity; `from_graph` computes it as
    /// `2 * total_relationships / total_entities` (0.0 when there are no entities).
    pub average_degree: f64,
}
100
101impl GraphStatistics {
102 pub fn from_graph(graph: &KnowledgeGraph) -> Self {
104 let entities: Vec<_> = graph.entities().collect();
105 let total_entities = entities.len();
106
107 let mut entities_by_type: HashMap<String, usize> = HashMap::new();
108 for entity in &entities {
109 *entities_by_type
110 .entry(entity.entity_type.clone())
111 .or_insert(0) += 1;
112 }
113
114 let relationships = graph.get_all_relationships();
115 let total_relationships = relationships.len();
116
117 let mut relationships_by_type: HashMap<String, usize> = HashMap::new();
118 for rel in &relationships {
119 *relationships_by_type
120 .entry(rel.relation_type.clone())
121 .or_insert(0) += 1;
122 }
123
124 let average_degree = if total_entities > 0 {
125 (total_relationships as f64 * 2.0) / total_entities as f64
126 } else {
127 0.0
128 };
129
130 Self {
131 total_entities,
132 entities_by_type,
133 total_relationships,
134 relationships_by_type,
135 average_degree,
136 }
137 }
138}
139
140pub struct QueryOptimizer {
142 stats: GraphStatistics,
143}
144
145impl QueryOptimizer {
146 pub fn new(stats: GraphStatistics) -> Self {
148 Self { stats }
149 }
150
151 pub fn optimize(&self, query: QueryOp) -> Result<QueryOp> {
153 let rewritten = self.rewrite_query(query)?;
155 let optimized = self.optimize_joins(rewritten)?;
156 Ok(optimized)
157 }
158
159 fn rewrite_query(&self, query: QueryOp) -> Result<QueryOp> {
161 match query {
162 QueryOp::Filter { property, value } => Ok(QueryOp::Filter { property, value }),
164
165 QueryOp::Join {
167 left,
168 right,
169 join_type,
170 } => {
171 let left_opt = self.rewrite_query(*left)?;
172 let right_opt = self.rewrite_query(*right)?;
173
174 let left_cost = self.estimate_cost(&left_opt)?;
176 let right_cost = self.estimate_cost(&right_opt)?;
177
178 if left_cost.cardinality > right_cost.cardinality {
180 Ok(QueryOp::Join {
181 left: Box::new(right_opt),
182 right: Box::new(left_opt),
183 join_type,
184 })
185 } else {
186 Ok(QueryOp::Join {
187 left: Box::new(left_opt),
188 right: Box::new(right_opt),
189 join_type,
190 })
191 }
192 },
193
194 QueryOp::Neighbors {
196 source,
197 relation_type,
198 max_hops,
199 } => {
200 let source_opt = self.rewrite_query(*source)?;
201 Ok(QueryOp::Neighbors {
202 source: Box::new(source_opt),
203 relation_type,
204 max_hops,
205 })
206 },
207
208 QueryOp::Union { left, right } => {
209 let left_opt = self.rewrite_query(*left)?;
210 let right_opt = self.rewrite_query(*right)?;
211 Ok(QueryOp::Union {
212 left: Box::new(left_opt),
213 right: Box::new(right_opt),
214 })
215 },
216
217 QueryOp::Limit { source, count } => {
218 let source_opt = self.rewrite_query(*source)?;
219 Ok(QueryOp::Limit {
220 source: Box::new(source_opt),
221 count,
222 })
223 },
224
225 QueryOp::EntityScan { entity_type } => Ok(QueryOp::EntityScan { entity_type }),
227 }
228 }
229
230 fn optimize_joins(&self, query: QueryOp) -> Result<QueryOp> {
232 match query {
233 QueryOp::Join {
234 left,
235 right,
236 join_type,
237 } => {
238 let left_opt = self.optimize_joins(*left)?;
240 let right_opt = self.optimize_joins(*right)?;
241
242 let mut operands = Vec::new();
244 Self::collect_join_operands(&left_opt, &mut operands);
245 Self::collect_join_operands(&right_opt, &mut operands);
246
247 if operands.len() > 2 {
248 self.find_optimal_join_order(operands, join_type)
250 } else {
251 Ok(QueryOp::Join {
253 left: Box::new(left_opt),
254 right: Box::new(right_opt),
255 join_type,
256 })
257 }
258 },
259
260 QueryOp::Neighbors {
262 source,
263 relation_type,
264 max_hops,
265 } => {
266 let source_opt = self.optimize_joins(*source)?;
267 Ok(QueryOp::Neighbors {
268 source: Box::new(source_opt),
269 relation_type,
270 max_hops,
271 })
272 },
273
274 QueryOp::Union { left, right } => {
275 let left_opt = self.optimize_joins(*left)?;
276 let right_opt = self.optimize_joins(*right)?;
277 Ok(QueryOp::Union {
278 left: Box::new(left_opt),
279 right: Box::new(right_opt),
280 })
281 },
282
283 QueryOp::Limit { source, count } => {
284 let source_opt = self.optimize_joins(*source)?;
285 Ok(QueryOp::Limit {
286 source: Box::new(source_opt),
287 count,
288 })
289 },
290
291 _ => Ok(query),
293 }
294 }
295
296 fn collect_join_operands(op: &QueryOp, operands: &mut Vec<QueryOp>) {
298 match op {
299 QueryOp::Join { left, right, .. } => {
300 Self::collect_join_operands(left, operands);
301 Self::collect_join_operands(right, operands);
302 },
303 _ => {
304 operands.push(op.clone());
305 },
306 }
307 }
308
309 fn find_optimal_join_order(
311 &self,
312 mut operands: Vec<QueryOp>,
313 join_type: JoinType,
314 ) -> Result<QueryOp> {
315 if operands.is_empty() {
316 return Err(crate::core::GraphRAGError::Validation {
317 message: "No operands for join".to_string(),
318 });
319 }
320
321 if operands.len() == 1 {
322 return Ok(operands.pop().unwrap());
323 }
324
325 while operands.len() > 1 {
327 let mut min_cost = f64::MAX;
328 let mut best_i = 0;
329 let mut best_j = 1;
330
331 for i in 0..operands.len() {
333 for j in (i + 1)..operands.len() {
334 let cost_i = self.estimate_cost(&operands[i])?;
335 let cost_j = self.estimate_cost(&operands[j])?;
336
337 let join_cost = (cost_i.cardinality as f64) * (cost_j.cardinality as f64);
339
340 if join_cost < min_cost {
341 min_cost = join_cost;
342 best_i = i;
343 best_j = j;
344 }
345 }
346 }
347
348 let left = operands.remove(best_i);
350 let right = operands.remove(if best_j > best_i { best_j - 1 } else { best_j });
351
352 let joined = QueryOp::Join {
353 left: Box::new(left),
354 right: Box::new(right),
355 join_type: join_type.clone(),
356 };
357
358 operands.push(joined);
359 }
360
361 Ok(operands.pop().unwrap())
362 }
363
364 pub fn estimate_cost(&self, op: &QueryOp) -> Result<OperationCost> {
366 match op {
367 QueryOp::EntityScan { entity_type } => {
368 let cardinality = self
369 .stats
370 .entities_by_type
371 .get(entity_type)
372 .copied()
373 .unwrap_or(0);
374
375 Ok(OperationCost {
376 cardinality,
377 cost: cardinality as f64,
378 selectivity: if self.stats.total_entities > 0 {
379 cardinality as f64 / self.stats.total_entities as f64
380 } else {
381 0.0
382 },
383 })
384 },
385
386 QueryOp::Filter {
387 property: _,
388 value: _,
389 } => {
390 let selectivity = 0.1;
392 let cardinality = (self.stats.total_entities as f64 * selectivity) as usize;
393
394 Ok(OperationCost {
395 cardinality,
396 cost: self.stats.total_entities as f64, selectivity,
398 })
399 },
400
401 QueryOp::Join {
402 left,
403 right,
404 join_type,
405 } => {
406 let left_cost = self.estimate_cost(left)?;
407 let right_cost = self.estimate_cost(right)?;
408
409 let cardinality = match join_type {
410 JoinType::Inner => {
411 ((left_cost.cardinality as f64) * (right_cost.cardinality as f64)).sqrt()
413 as usize
414 },
415 JoinType::LeftOuter => left_cost.cardinality,
416 JoinType::Cross => left_cost.cardinality * right_cost.cardinality,
417 };
418
419 let cost = left_cost.cost
420 + right_cost.cost
421 + (left_cost.cardinality as f64 * right_cost.cardinality as f64);
422
423 Ok(OperationCost {
424 cardinality,
425 cost,
426 selectivity: left_cost.selectivity * right_cost.selectivity,
427 })
428 },
429
430 QueryOp::Neighbors {
431 source,
432 relation_type: _,
433 max_hops,
434 } => {
435 let source_cost = self.estimate_cost(source)?;
436
437 let expansion_factor = self.stats.average_degree.powi(*max_hops as i32);
439 let cardinality = (source_cost.cardinality as f64 * expansion_factor)
440 .min(self.stats.total_entities as f64)
441 as usize;
442
443 Ok(OperationCost {
444 cardinality,
445 cost: source_cost.cost + (cardinality as f64),
446 selectivity: cardinality as f64 / self.stats.total_entities as f64,
447 })
448 },
449
450 QueryOp::Union { left, right } => {
451 let left_cost = self.estimate_cost(left)?;
452 let right_cost = self.estimate_cost(right)?;
453
454 let cardinality = (left_cost.cardinality + right_cost.cardinality) * 9 / 10;
456
457 Ok(OperationCost {
458 cardinality,
459 cost: left_cost.cost + right_cost.cost,
460 selectivity: (left_cost.selectivity + right_cost.selectivity).min(1.0),
461 })
462 },
463
464 QueryOp::Limit { source, count } => {
465 let source_cost = self.estimate_cost(source)?;
466
467 Ok(OperationCost {
468 cardinality: (*count).min(source_cost.cardinality),
469 cost: source_cost.cost,
470 selectivity: (*count as f64 / self.stats.total_entities as f64).min(1.0),
471 })
472 },
473 }
474 }
475
476 pub fn explain(&self, op: &QueryOp) -> Result<String> {
478 let cost = self.estimate_cost(op)?;
479 let mut plan = String::new();
480
481 self.explain_recursive(op, 0, &mut plan)?;
482
483 plan.push_str(&format!(
484 "\nEstimated Cost: {:.2}\nEstimated Cardinality: {}\nSelectivity: {:.2}%\n",
485 cost.cost,
486 cost.cardinality,
487 cost.selectivity * 100.0
488 ));
489
490 Ok(plan)
491 }
492
493 fn explain_recursive(&self, op: &QueryOp, depth: usize, plan: &mut String) -> Result<()> {
495 let indent = " ".repeat(depth);
496 let cost = self.estimate_cost(op)?;
497
498 match op {
499 QueryOp::EntityScan { entity_type } => {
500 plan.push_str(&format!(
501 "{}EntityScan({}) [cost={:.0}, rows={}]\n",
502 indent, entity_type, cost.cost, cost.cardinality
503 ));
504 },
505 QueryOp::Filter { property, value } => {
506 plan.push_str(&format!(
507 "{}Filter({}={}) [cost={:.0}, rows={}]\n",
508 indent, property, value, cost.cost, cost.cardinality
509 ));
510 },
511 QueryOp::Join {
512 left,
513 right,
514 join_type,
515 } => {
516 plan.push_str(&format!(
517 "{}Join({:?}) [cost={:.0}, rows={}]\n",
518 indent, join_type, cost.cost, cost.cardinality
519 ));
520 self.explain_recursive(left, depth + 1, plan)?;
521 self.explain_recursive(right, depth + 1, plan)?;
522 },
523 QueryOp::Neighbors {
524 source,
525 relation_type,
526 max_hops,
527 } => {
528 let rel_str = relation_type.as_deref().unwrap_or("*");
529 plan.push_str(&format!(
530 "{}Neighbors({}, hops={}) [cost={:.0}, rows={}]\n",
531 indent, rel_str, max_hops, cost.cost, cost.cardinality
532 ));
533 self.explain_recursive(source, depth + 1, plan)?;
534 },
535 QueryOp::Union { left, right } => {
536 plan.push_str(&format!(
537 "{}Union [cost={:.0}, rows={}]\n",
538 indent, cost.cost, cost.cardinality
539 ));
540 self.explain_recursive(left, depth + 1, plan)?;
541 self.explain_recursive(right, depth + 1, plan)?;
542 },
543 QueryOp::Limit { source, count } => {
544 plan.push_str(&format!(
545 "{}Limit({}) [cost={:.0}, rows={}]\n",
546 indent, count, cost.cost, cost.cardinality
547 ));
548 self.explain_recursive(source, depth + 1, plan)?;
549 },
550 }
551
552 Ok(())
553 }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture statistics: 180 entities across three types, 140 relationships.
    fn create_test_stats() -> GraphStatistics {
        let mut entities_by_type = HashMap::new();
        entities_by_type.insert("PERSON".to_string(), 100);
        entities_by_type.insert("ORGANIZATION".to_string(), 50);
        entities_by_type.insert("LOCATION".to_string(), 30);

        let mut relationships_by_type = HashMap::new();
        relationships_by_type.insert("WORKS_FOR".to_string(), 80);
        relationships_by_type.insert("LOCATED_IN".to_string(), 60);

        GraphStatistics {
            total_entities: 180,
            entities_by_type,
            total_relationships: 140,
            relationships_by_type,
            average_degree: 1.56,
        }
    }

    /// Optimizer over the fixture statistics.
    fn optimizer() -> QueryOptimizer {
        QueryOptimizer::new(create_test_stats())
    }

    /// Shorthand for an `EntityScan` leaf.
    fn scan(entity_type: &str) -> QueryOp {
        QueryOp::EntityScan {
            entity_type: entity_type.to_string(),
        }
    }

    /// Shorthand for an inner join of two sub-plans.
    fn inner_join(left: QueryOp, right: QueryOp) -> QueryOp {
        QueryOp::Join {
            left: Box::new(left),
            right: Box::new(right),
            join_type: JoinType::Inner,
        }
    }

    #[test]
    fn test_cost_estimation_scan() {
        let cost = optimizer().estimate_cost(&scan("PERSON")).unwrap();

        assert_eq!(cost.cardinality, 100);
        assert_eq!(cost.cost, 100.0);
    }

    #[test]
    fn test_cost_estimation_join() {
        let query = inner_join(scan("PERSON"), scan("ORGANIZATION"));

        let cost = optimizer().estimate_cost(&query).unwrap();

        // sqrt(100 * 50) ~= 70.7, truncated.
        assert_eq!(cost.cardinality, 70);
        // Both scans (100 + 50) plus the 100 * 50 candidate pairs.
        assert_eq!(cost.cost, 5150.0);
    }

    #[test]
    fn test_join_reordering() {
        let query = inner_join(scan("PERSON"), scan("LOCATION"));

        let optimized = optimizer().optimize(query).unwrap();

        // Fail loudly if the plan shape is unexpected instead of passing silently.
        match optimized {
            QueryOp::Join { left, .. } => {
                assert_eq!(*left, scan("LOCATION"), "Smaller table should be first")
            },
            other => panic!("expected a Join at the root, got {:?}", other),
        }
    }

    #[test]
    fn test_multiway_join_joins_smallest_pair_first() {
        let query = inner_join(
            inner_join(scan("PERSON"), scan("ORGANIZATION")),
            scan("LOCATION"),
        );

        let optimized = optimizer().optimize(query).unwrap();

        // LOCATION (30) x ORGANIZATION (50) is the cheapest pair, so it is
        // joined first and PERSON is joined last.
        let expected = inner_join(
            scan("PERSON"),
            inner_join(scan("LOCATION"), scan("ORGANIZATION")),
        );
        assert_eq!(optimized, expected);
    }

    #[test]
    fn test_neighbors_cost() {
        let query = QueryOp::Neighbors {
            source: Box::new(scan("PERSON")),
            relation_type: Some("WORKS_FOR".to_string()),
            max_hops: 2,
        };

        let cost = optimizer().estimate_cost(&query).unwrap();

        assert!(cost.cardinality > 100);
    }

    #[test]
    fn test_union_and_limit_cost() {
        let opt = optimizer();

        let union = opt
            .estimate_cost(&QueryOp::Union {
                left: Box::new(scan("PERSON")),
                right: Box::new(scan("ORGANIZATION")),
            })
            .unwrap();
        // 90% of 100 + 50.
        assert_eq!(union.cardinality, 135);
        assert_eq!(union.cost, 150.0);

        let limit = opt
            .estimate_cost(&QueryOp::Limit {
                source: Box::new(scan("PERSON")),
                count: 10,
            })
            .unwrap();
        assert_eq!(limit.cardinality, 10);
        assert_eq!(limit.cost, 100.0);
        assert!((limit.selectivity - 10.0 / 180.0).abs() < 1e-12);
    }

    #[test]
    fn test_explain_plan() {
        let query = inner_join(scan("PERSON"), scan("ORGANIZATION"));

        let plan = optimizer().explain(&query).unwrap();

        assert!(plan.contains("Join"));
        assert!(plan.contains("Join(Inner)"));
        assert!(plan.contains("EntityScan"));
        assert!(plan.contains("Estimated Cost"));
    }
}