1use crate::core::KnowledgeGraph;
12use crate::Result;
13use std::collections::HashMap;
14
/// A node in a logical query plan.
///
/// Plans form a tree: composite variants own their inputs through `Box`, while
/// `EntityScan` and `Filter` are leaves with no input of their own.
#[derive(Debug, Clone, PartialEq)]
pub enum QueryOp {
    /// Leaf that selects entities of a single type.
    EntityScan {
        /// Type name; the cost model uses it as the key into the per-type
        /// entity counts and estimates zero rows for unknown types.
        entity_type: String,
    },
    /// Leaf carrying a property/value pair (rendered as `property=value` in
    /// explain output). The cost model treats it as a fixed fraction of all
    /// entities and does not inspect either field.
    Filter {
        /// Name of the property being tested.
        property: String,
        /// Value the property is compared against.
        value: String,
    },
    /// Combines the rows of two sub-plans according to `join_type`.
    Join {
        /// Left input.
        left: Box<QueryOp>,
        /// Right input.
        right: Box<QueryOp>,
        /// Join flavour; determines the estimated output cardinality.
        join_type: JoinType,
    },
    /// Expands from the rows produced by `source` over graph relationships.
    Neighbors {
        /// Sub-plan producing the starting rows.
        source: Box<QueryOp>,
        /// Optional relation type; `None` is rendered as `*` in explain
        /// output. NOTE(review): presumably means "any relation" — the cost
        /// model ignores this field, so confirm against the executor.
        relation_type: Option<String>,
        /// Hop count; the cost model raises the average degree to this power.
        max_hops: usize,
    },
    /// Combines the rows of two sub-plans.
    Union {
        /// Left input.
        left: Box<QueryOp>,
        /// Right input.
        right: Box<QueryOp>,
    },
    /// Caps the number of rows taken from `source`.
    Limit {
        /// Sub-plan whose output is truncated.
        source: Box<QueryOp>,
        /// Maximum number of rows.
        count: usize,
    },
}
63
/// Join flavour used by `QueryOp::Join`.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the type can be used
/// as a map/set key and in exhaustive equality checks.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Inner join; the cost model estimates its output as the geometric mean
    /// of the two input cardinalities.
    Inner,
    /// Left outer join; the cost model estimates one output row per left row.
    /// Operand order is significant for this variant.
    LeftOuter,
    /// Cross product; the cost model estimates `left * right` output rows.
    Cross,
}
74
/// Cost estimate for a single plan node (including its inputs).
///
/// A small plain-data value: it is `Copy` so estimates can be passed around
/// freely, and `PartialEq` so they can be compared directly in assertions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct OperationCost {
    /// Estimated number of rows produced (reported as `rows` in explain output).
    pub cardinality: usize,
    /// Abstract, unit-less cost of producing those rows.
    pub cost: f64,
    /// Estimated fraction of the graph's entities that survive, as a ratio
    /// (explain output multiplies it by 100 to print a percentage).
    pub selectivity: f64,
}
85
/// Summary counts about a knowledge graph, used as input to cost estimation.
///
/// All fields are public, so a value can be built by hand (as the unit tests
/// do) or derived from a graph with `GraphStatistics::from_graph`.
#[derive(Debug, Clone)]
pub struct GraphStatistics {
    /// Total number of entities.
    pub total_entities: usize,
    /// Number of entities per entity type name.
    pub entities_by_type: HashMap<String, usize>,
    /// Total number of relationships.
    pub total_relationships: usize,
    /// Number of relationships per relation type name.
    pub relationships_by_type: HashMap<String, usize>,
    /// Average number of relationship endpoints per entity; `from_graph`
    /// computes it as `2 * total_relationships / total_entities`, or `0.0`
    /// when there are no entities.
    pub average_degree: f64,
}
100
101impl GraphStatistics {
102 pub fn from_graph(graph: &KnowledgeGraph) -> Self {
104 let entities: Vec<_> = graph.entities().collect();
105 let total_entities = entities.len();
106
107 let mut entities_by_type: HashMap<String, usize> = HashMap::new();
108 for entity in &entities {
109 *entities_by_type
110 .entry(entity.entity_type.clone())
111 .or_insert(0) += 1;
112 }
113
114 let relationships = graph.get_all_relationships();
115 let total_relationships = relationships.len();
116
117 let mut relationships_by_type: HashMap<String, usize> = HashMap::new();
118 for rel in &relationships {
119 *relationships_by_type
120 .entry(rel.relation_type.clone())
121 .or_insert(0) += 1;
122 }
123
124 let average_degree = if total_entities > 0 {
125 (total_relationships as f64 * 2.0) / total_entities as f64
126 } else {
127 0.0
128 };
129
130 Self {
131 total_entities,
132 entities_by_type,
133 total_relationships,
134 relationships_by_type,
135 average_degree,
136 }
137 }
138}
139
140pub struct QueryOptimizer {
142 stats: GraphStatistics,
143}
144
145impl QueryOptimizer {
146 pub fn new(stats: GraphStatistics) -> Self {
148 Self { stats }
149 }
150
151 pub fn optimize(&self, query: QueryOp) -> Result<QueryOp> {
153 let rewritten = self.rewrite_query(query)?;
155 let optimized = self.optimize_joins(rewritten)?;
156 Ok(optimized)
157 }
158
159 fn rewrite_query(&self, query: QueryOp) -> Result<QueryOp> {
161 match query {
162 QueryOp::Filter { property, value } => {
164 Ok(QueryOp::Filter { property, value })
165 }
166
167 QueryOp::Join {
169 left,
170 right,
171 join_type,
172 } => {
173 let left_opt = self.rewrite_query(*left)?;
174 let right_opt = self.rewrite_query(*right)?;
175
176 let left_cost = self.estimate_cost(&left_opt)?;
178 let right_cost = self.estimate_cost(&right_opt)?;
179
180 if left_cost.cardinality > right_cost.cardinality {
182 Ok(QueryOp::Join {
183 left: Box::new(right_opt),
184 right: Box::new(left_opt),
185 join_type,
186 })
187 } else {
188 Ok(QueryOp::Join {
189 left: Box::new(left_opt),
190 right: Box::new(right_opt),
191 join_type,
192 })
193 }
194 }
195
196 QueryOp::Neighbors {
198 source,
199 relation_type,
200 max_hops,
201 } => {
202 let source_opt = self.rewrite_query(*source)?;
203 Ok(QueryOp::Neighbors {
204 source: Box::new(source_opt),
205 relation_type,
206 max_hops,
207 })
208 }
209
210 QueryOp::Union { left, right } => {
211 let left_opt = self.rewrite_query(*left)?;
212 let right_opt = self.rewrite_query(*right)?;
213 Ok(QueryOp::Union {
214 left: Box::new(left_opt),
215 right: Box::new(right_opt),
216 })
217 }
218
219 QueryOp::Limit { source, count } => {
220 let source_opt = self.rewrite_query(*source)?;
221 Ok(QueryOp::Limit {
222 source: Box::new(source_opt),
223 count,
224 })
225 }
226
227 QueryOp::EntityScan { entity_type } => Ok(QueryOp::EntityScan { entity_type }),
229 }
230 }
231
232 fn optimize_joins(&self, query: QueryOp) -> Result<QueryOp> {
234 match query {
235 QueryOp::Join {
236 left,
237 right,
238 join_type,
239 } => {
240 let left_opt = self.optimize_joins(*left)?;
242 let right_opt = self.optimize_joins(*right)?;
243
244 let mut operands = Vec::new();
246 self.collect_join_operands(&left_opt, &mut operands);
247 self.collect_join_operands(&right_opt, &mut operands);
248
249 if operands.len() > 2 {
250 self.find_optimal_join_order(operands, join_type)
252 } else {
253 Ok(QueryOp::Join {
255 left: Box::new(left_opt),
256 right: Box::new(right_opt),
257 join_type,
258 })
259 }
260 }
261
262 QueryOp::Neighbors {
264 source,
265 relation_type,
266 max_hops,
267 } => {
268 let source_opt = self.optimize_joins(*source)?;
269 Ok(QueryOp::Neighbors {
270 source: Box::new(source_opt),
271 relation_type,
272 max_hops,
273 })
274 }
275
276 QueryOp::Union { left, right } => {
277 let left_opt = self.optimize_joins(*left)?;
278 let right_opt = self.optimize_joins(*right)?;
279 Ok(QueryOp::Union {
280 left: Box::new(left_opt),
281 right: Box::new(right_opt),
282 })
283 }
284
285 QueryOp::Limit { source, count } => {
286 let source_opt = self.optimize_joins(*source)?;
287 Ok(QueryOp::Limit {
288 source: Box::new(source_opt),
289 count,
290 })
291 }
292
293 _ => Ok(query),
295 }
296 }
297
298 fn collect_join_operands(&self, op: &QueryOp, operands: &mut Vec<QueryOp>) {
300 match op {
301 QueryOp::Join { left, right, .. } => {
302 self.collect_join_operands(left, operands);
303 self.collect_join_operands(right, operands);
304 }
305 _ => {
306 operands.push(op.clone());
307 }
308 }
309 }
310
311 fn find_optimal_join_order(
313 &self,
314 mut operands: Vec<QueryOp>,
315 join_type: JoinType,
316 ) -> Result<QueryOp> {
317 if operands.is_empty() {
318 return Err(crate::core::GraphRAGError::Validation {
319 message: "No operands for join".to_string(),
320 });
321 }
322
323 if operands.len() == 1 {
324 return Ok(operands.pop().unwrap());
325 }
326
327 while operands.len() > 1 {
329 let mut min_cost = f64::MAX;
330 let mut best_i = 0;
331 let mut best_j = 1;
332
333 for i in 0..operands.len() {
335 for j in (i + 1)..operands.len() {
336 let cost_i = self.estimate_cost(&operands[i])?;
337 let cost_j = self.estimate_cost(&operands[j])?;
338
339 let join_cost = (cost_i.cardinality as f64) * (cost_j.cardinality as f64);
341
342 if join_cost < min_cost {
343 min_cost = join_cost;
344 best_i = i;
345 best_j = j;
346 }
347 }
348 }
349
350 let left = operands.remove(best_i);
352 let right = operands.remove(if best_j > best_i {
353 best_j - 1
354 } else {
355 best_j
356 });
357
358 let joined = QueryOp::Join {
359 left: Box::new(left),
360 right: Box::new(right),
361 join_type: join_type.clone(),
362 };
363
364 operands.push(joined);
365 }
366
367 Ok(operands.pop().unwrap())
368 }
369
370 pub fn estimate_cost(&self, op: &QueryOp) -> Result<OperationCost> {
372 match op {
373 QueryOp::EntityScan { entity_type } => {
374 let cardinality = self
375 .stats
376 .entities_by_type
377 .get(entity_type)
378 .copied()
379 .unwrap_or(0);
380
381 Ok(OperationCost {
382 cardinality,
383 cost: cardinality as f64,
384 selectivity: if self.stats.total_entities > 0 {
385 cardinality as f64 / self.stats.total_entities as f64
386 } else {
387 0.0
388 },
389 })
390 }
391
392 QueryOp::Filter { property: _, value: _ } => {
393 let selectivity = 0.1;
395 let cardinality = (self.stats.total_entities as f64 * selectivity) as usize;
396
397 Ok(OperationCost {
398 cardinality,
399 cost: self.stats.total_entities as f64, selectivity,
401 })
402 }
403
404 QueryOp::Join { left, right, join_type } => {
405 let left_cost = self.estimate_cost(left)?;
406 let right_cost = self.estimate_cost(right)?;
407
408 let cardinality = match join_type {
409 JoinType::Inner => {
410 ((left_cost.cardinality as f64) * (right_cost.cardinality as f64))
412 .sqrt() as usize
413 }
414 JoinType::LeftOuter => left_cost.cardinality,
415 JoinType::Cross => left_cost.cardinality * right_cost.cardinality,
416 };
417
418 let cost = left_cost.cost
419 + right_cost.cost
420 + (left_cost.cardinality as f64 * right_cost.cardinality as f64);
421
422 Ok(OperationCost {
423 cardinality,
424 cost,
425 selectivity: left_cost.selectivity * right_cost.selectivity,
426 })
427 }
428
429 QueryOp::Neighbors {
430 source,
431 relation_type: _,
432 max_hops,
433 } => {
434 let source_cost = self.estimate_cost(source)?;
435
436 let expansion_factor = self.stats.average_degree.powi(*max_hops as i32);
438 let cardinality =
439 (source_cost.cardinality as f64 * expansion_factor).min(self.stats.total_entities as f64) as usize;
440
441 Ok(OperationCost {
442 cardinality,
443 cost: source_cost.cost + (cardinality as f64),
444 selectivity: cardinality as f64 / self.stats.total_entities as f64,
445 })
446 }
447
448 QueryOp::Union { left, right } => {
449 let left_cost = self.estimate_cost(left)?;
450 let right_cost = self.estimate_cost(right)?;
451
452 let cardinality = (left_cost.cardinality + right_cost.cardinality) * 9 / 10;
454
455 Ok(OperationCost {
456 cardinality,
457 cost: left_cost.cost + right_cost.cost,
458 selectivity: (left_cost.selectivity + right_cost.selectivity).min(1.0),
459 })
460 }
461
462 QueryOp::Limit { source, count } => {
463 let source_cost = self.estimate_cost(source)?;
464
465 Ok(OperationCost {
466 cardinality: (*count).min(source_cost.cardinality),
467 cost: source_cost.cost,
468 selectivity: (*count as f64 / self.stats.total_entities as f64).min(1.0),
469 })
470 }
471 }
472 }
473
474 pub fn explain(&self, op: &QueryOp) -> Result<String> {
476 let cost = self.estimate_cost(op)?;
477 let mut plan = String::new();
478
479 self.explain_recursive(op, 0, &mut plan)?;
480
481 plan.push_str(&format!(
482 "\nEstimated Cost: {:.2}\nEstimated Cardinality: {}\nSelectivity: {:.2}%\n",
483 cost.cost,
484 cost.cardinality,
485 cost.selectivity * 100.0
486 ));
487
488 Ok(plan)
489 }
490
491 fn explain_recursive(&self, op: &QueryOp, depth: usize, plan: &mut String) -> Result<()> {
493 let indent = " ".repeat(depth);
494 let cost = self.estimate_cost(op)?;
495
496 match op {
497 QueryOp::EntityScan { entity_type } => {
498 plan.push_str(&format!(
499 "{}EntityScan({}) [cost={:.0}, rows={}]\n",
500 indent, entity_type, cost.cost, cost.cardinality
501 ));
502 }
503 QueryOp::Filter { property, value } => {
504 plan.push_str(&format!(
505 "{}Filter({}={}) [cost={:.0}, rows={}]\n",
506 indent, property, value, cost.cost, cost.cardinality
507 ));
508 }
509 QueryOp::Join {
510 left,
511 right,
512 join_type,
513 } => {
514 plan.push_str(&format!(
515 "{}Join({:?}) [cost={:.0}, rows={}]\n",
516 indent, join_type, cost.cost, cost.cardinality
517 ));
518 self.explain_recursive(left, depth + 1, plan)?;
519 self.explain_recursive(right, depth + 1, plan)?;
520 }
521 QueryOp::Neighbors {
522 source,
523 relation_type,
524 max_hops,
525 } => {
526 let rel_str = relation_type.as_deref().unwrap_or("*");
527 plan.push_str(&format!(
528 "{}Neighbors({}, hops={}) [cost={:.0}, rows={}]\n",
529 indent, rel_str, max_hops, cost.cost, cost.cardinality
530 ));
531 self.explain_recursive(source, depth + 1, plan)?;
532 }
533 QueryOp::Union { left, right } => {
534 plan.push_str(&format!(
535 "{}Union [cost={:.0}, rows={}]\n",
536 indent, cost.cost, cost.cardinality
537 ));
538 self.explain_recursive(left, depth + 1, plan)?;
539 self.explain_recursive(right, depth + 1, plan)?;
540 }
541 QueryOp::Limit { source, count } => {
542 plan.push_str(&format!(
543 "{}Limit({}) [cost={:.0}, rows={}]\n",
544 indent, count, cost.cost, cost.cardinality
545 ));
546 self.explain_recursive(source, depth + 1, plan)?;
547 }
548 }
549
550 Ok(())
551 }
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Statistics for a small graph: 100 people, 50 organizations and
    /// 30 locations connected by 140 relationships.
    fn create_test_stats() -> GraphStatistics {
        let mut entities_by_type = HashMap::new();
        entities_by_type.insert("PERSON".to_string(), 100);
        entities_by_type.insert("ORGANIZATION".to_string(), 50);
        entities_by_type.insert("LOCATION".to_string(), 30);

        let mut relationships_by_type = HashMap::new();
        relationships_by_type.insert("WORKS_FOR".to_string(), 80);
        relationships_by_type.insert("LOCATED_IN".to_string(), 60);

        GraphStatistics {
            total_entities: 180,
            entities_by_type,
            total_relationships: 140,
            relationships_by_type,
            average_degree: 1.56,
        }
    }

    /// Shorthand for an `EntityScan` over `entity_type`.
    fn scan(entity_type: &str) -> QueryOp {
        QueryOp::EntityScan {
            entity_type: entity_type.to_string(),
        }
    }

    /// Shorthand for an inner join of two plans.
    fn inner_join(left: QueryOp, right: QueryOp) -> QueryOp {
        QueryOp::Join {
            left: Box::new(left),
            right: Box::new(right),
            join_type: JoinType::Inner,
        }
    }

    #[test]
    fn test_cost_estimation_scan() {
        let optimizer = QueryOptimizer::new(create_test_stats());

        let cost = optimizer.estimate_cost(&scan("PERSON")).unwrap();

        assert_eq!(cost.cardinality, 100);
        assert_eq!(cost.cost, 100.0);
    }

    #[test]
    fn test_cost_estimation_join() {
        let optimizer = QueryOptimizer::new(create_test_stats());
        let query = inner_join(scan("PERSON"), scan("ORGANIZATION"));

        let cost = optimizer.estimate_cost(&query).unwrap();

        // Geometric mean of 100 and 50 is about 70.7.
        assert!(cost.cardinality > 60 && cost.cardinality < 80);
    }

    #[test]
    fn test_join_reordering() {
        let optimizer = QueryOptimizer::new(create_test_stats());
        let query = inner_join(scan("PERSON"), scan("LOCATION"));

        // Fail loudly if the plan is not a join of two scans; a silent
        // `if let` would let an unexpected shape pass unnoticed.
        match optimizer.optimize(query).unwrap() {
            QueryOp::Join { left, right, .. } => {
                assert_eq!(*left, scan("LOCATION"), "Smaller table should be first");
                assert_eq!(*right, scan("PERSON"), "Larger table should be second");
            }
            other => panic!("expected a join at the plan root, got {:?}", other),
        }
    }

    #[test]
    fn test_neighbors_cost() {
        let optimizer = QueryOptimizer::new(create_test_stats());
        let query = QueryOp::Neighbors {
            source: Box::new(scan("PERSON")),
            relation_type: Some("WORKS_FOR".to_string()),
            max_hops: 2,
        };

        let cost = optimizer.estimate_cost(&query).unwrap();

        // Two hops must expand the 100 starting rows.
        assert!(cost.cardinality > 100);
    }

    #[test]
    fn test_explain_plan() {
        let optimizer = QueryOptimizer::new(create_test_stats());
        let query = inner_join(scan("PERSON"), scan("ORGANIZATION"));

        let plan = optimizer.explain(&query).unwrap();

        assert!(plan.contains("Join"));
        assert!(plan.contains("EntityScan"));
        assert!(plan.contains("Estimated Cost"));
    }
}