Skip to main content

graphrag_core/query/
optimizer.rs

//! Query Optimizer for Knowledge Graph Queries
//!
//! Optimizes query execution plans using:
//! - Join ordering optimization
//! - Cost estimation based on graph statistics
//! - Selectivity estimation
//! - Query rewriting
//!
//! This is a rule-based optimizer without ML dependencies.
10
11use crate::core::KnowledgeGraph;
12use crate::Result;
13use std::collections::HashMap;
14
/// Query operation types
///
/// A query is a tree of these nodes: `EntityScan` and `Filter` are leaves,
/// every other variant wraps one or two boxed child operations.
///
/// `Eq` and `Hash` are derived so that plan nodes can be used as keys in
/// hash-based collections (for example to memoize per-node estimates).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum QueryOp {
    /// Scan all entities of a type
    EntityScan {
        /// The entity type to scan for
        entity_type: String,
    },
    /// Filter entities by property
    Filter {
        /// Property name to filter on
        property: String,
        /// Property value to match
        value: String,
    },
    /// Join two results on entity relationships
    Join {
        /// Left operand of the join
        left: Box<QueryOp>,
        /// Right operand of the join
        right: Box<QueryOp>,
        /// Type of join operation
        join_type: JoinType,
    },
    /// Get neighbors of entities
    Neighbors {
        /// Source entities to find neighbors for
        source: Box<QueryOp>,
        /// Optional relationship type filter
        relation_type: Option<String>,
        /// Maximum number of hops (graph traversal depth)
        max_hops: usize,
    },
    /// Union of two operations
    Union {
        /// Left operand of the union
        left: Box<QueryOp>,
        /// Right operand of the union
        right: Box<QueryOp>,
    },
    /// Limit results
    Limit {
        /// Source operation to limit
        source: Box<QueryOp>,
        /// Maximum number of results
        count: usize,
    },
}

/// Join type
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Inner join (intersection)
    Inner,
    /// Left outer join
    LeftOuter,
    /// Cross product
    Cross,
}
74
/// Cost statistics for an operation
///
/// A small plain-data value; it is `Copy` so estimates can be passed around
/// freely, and `PartialEq` so they can be compared directly (only partial
/// equality is available because the struct contains `f64` fields).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct OperationCost {
    /// Estimated number of results
    pub cardinality: usize,
    /// Estimated cost (abstract units)
    pub cost: f64,
    /// Selectivity (0.0-1.0)
    pub selectivity: f64,
}
85
/// Graph statistics for cost estimation
///
/// `Default` yields the statistics of an empty graph (all counts zero, empty
/// per-type maps, average degree `0.0`).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct GraphStatistics {
    /// Total entities
    pub total_entities: usize,
    /// Entities per type
    pub entities_by_type: HashMap<String, usize>,
    /// Total relationships
    pub total_relationships: usize,
    /// Relationships per type
    pub relationships_by_type: HashMap<String, usize>,
    /// Average degree (edges per node)
    pub average_degree: f64,
}
100
101impl GraphStatistics {
102    /// Compute statistics from a knowledge graph
103    pub fn from_graph(graph: &KnowledgeGraph) -> Self {
104        let entities: Vec<_> = graph.entities().collect();
105        let total_entities = entities.len();
106
107        let mut entities_by_type: HashMap<String, usize> = HashMap::new();
108        for entity in &entities {
109            *entities_by_type
110                .entry(entity.entity_type.clone())
111                .or_insert(0) += 1;
112        }
113
114        let relationships = graph.get_all_relationships();
115        let total_relationships = relationships.len();
116
117        let mut relationships_by_type: HashMap<String, usize> = HashMap::new();
118        for rel in &relationships {
119            *relationships_by_type
120                .entry(rel.relation_type.clone())
121                .or_insert(0) += 1;
122        }
123
124        let average_degree = if total_entities > 0 {
125            (total_relationships as f64 * 2.0) / total_entities as f64
126        } else {
127            0.0
128        };
129
130        Self {
131            total_entities,
132            entities_by_type,
133            total_relationships,
134            relationships_by_type,
135            average_degree,
136        }
137    }
138}
139
140/// Query optimizer
141pub struct QueryOptimizer {
142    stats: GraphStatistics,
143}
144
145impl QueryOptimizer {
146    /// Create a new optimizer with graph statistics
147    pub fn new(stats: GraphStatistics) -> Self {
148        Self { stats }
149    }
150
151    /// Optimize a query operation
152    pub fn optimize(&self, query: QueryOp) -> Result<QueryOp> {
153        // Apply optimization rules
154        let rewritten = self.rewrite_query(query)?;
155        let optimized = self.optimize_joins(rewritten)?;
156        Ok(optimized)
157    }
158
159    /// Rewrite query using algebraic rules
160    fn rewrite_query(&self, query: QueryOp) -> Result<QueryOp> {
161        match query {
162            // Push filters down through joins
163            QueryOp::Filter { property, value } => Ok(QueryOp::Filter { property, value }),
164
165            // Reorder joins based on selectivity
166            QueryOp::Join {
167                left,
168                right,
169                join_type,
170            } => {
171                let left_opt = self.rewrite_query(*left)?;
172                let right_opt = self.rewrite_query(*right)?;
173
174                // Estimate costs
175                let left_cost = self.estimate_cost(&left_opt)?;
176                let right_cost = self.estimate_cost(&right_opt)?;
177
178                // Put smaller (more selective) operand first for hash joins
179                if left_cost.cardinality > right_cost.cardinality {
180                    Ok(QueryOp::Join {
181                        left: Box::new(right_opt),
182                        right: Box::new(left_opt),
183                        join_type,
184                    })
185                } else {
186                    Ok(QueryOp::Join {
187                        left: Box::new(left_opt),
188                        right: Box::new(right_opt),
189                        join_type,
190                    })
191                }
192            },
193
194            // Recursively optimize subqueries
195            QueryOp::Neighbors {
196                source,
197                relation_type,
198                max_hops,
199            } => {
200                let source_opt = self.rewrite_query(*source)?;
201                Ok(QueryOp::Neighbors {
202                    source: Box::new(source_opt),
203                    relation_type,
204                    max_hops,
205                })
206            },
207
208            QueryOp::Union { left, right } => {
209                let left_opt = self.rewrite_query(*left)?;
210                let right_opt = self.rewrite_query(*right)?;
211                Ok(QueryOp::Union {
212                    left: Box::new(left_opt),
213                    right: Box::new(right_opt),
214                })
215            },
216
217            QueryOp::Limit { source, count } => {
218                let source_opt = self.rewrite_query(*source)?;
219                Ok(QueryOp::Limit {
220                    source: Box::new(source_opt),
221                    count,
222                })
223            },
224
225            // Base case: entity scans
226            QueryOp::EntityScan { entity_type } => Ok(QueryOp::EntityScan { entity_type }),
227        }
228    }
229
230    /// Optimize join ordering using dynamic programming
231    fn optimize_joins(&self, query: QueryOp) -> Result<QueryOp> {
232        match query {
233            QueryOp::Join {
234                left,
235                right,
236                join_type,
237            } => {
238                // Recursively optimize sub-queries
239                let left_opt = self.optimize_joins(*left)?;
240                let right_opt = self.optimize_joins(*right)?;
241
242                // For multi-way joins, collect all join operands
243                let mut operands = Vec::new();
244                Self::collect_join_operands(&left_opt, &mut operands);
245                Self::collect_join_operands(&right_opt, &mut operands);
246
247                if operands.len() > 2 {
248                    // Multi-way join: find optimal order using greedy algorithm
249                    self.find_optimal_join_order(operands, join_type)
250                } else {
251                    // Binary join: already optimized by rewrite_query
252                    Ok(QueryOp::Join {
253                        left: Box::new(left_opt),
254                        right: Box::new(right_opt),
255                        join_type,
256                    })
257                }
258            },
259
260            // Recursively process other operations
261            QueryOp::Neighbors {
262                source,
263                relation_type,
264                max_hops,
265            } => {
266                let source_opt = self.optimize_joins(*source)?;
267                Ok(QueryOp::Neighbors {
268                    source: Box::new(source_opt),
269                    relation_type,
270                    max_hops,
271                })
272            },
273
274            QueryOp::Union { left, right } => {
275                let left_opt = self.optimize_joins(*left)?;
276                let right_opt = self.optimize_joins(*right)?;
277                Ok(QueryOp::Union {
278                    left: Box::new(left_opt),
279                    right: Box::new(right_opt),
280                })
281            },
282
283            QueryOp::Limit { source, count } => {
284                let source_opt = self.optimize_joins(*source)?;
285                Ok(QueryOp::Limit {
286                    source: Box::new(source_opt),
287                    count,
288                })
289            },
290
291            // Leaf operations
292            _ => Ok(query),
293        }
294    }
295
296    /// Collect all join operands for multi-way join optimization
297    fn collect_join_operands(op: &QueryOp, operands: &mut Vec<QueryOp>) {
298        match op {
299            QueryOp::Join { left, right, .. } => {
300                Self::collect_join_operands(left, operands);
301                Self::collect_join_operands(right, operands);
302            },
303            _ => {
304                operands.push(op.clone());
305            },
306        }
307    }
308
309    /// Find optimal join order using greedy algorithm
310    fn find_optimal_join_order(
311        &self,
312        mut operands: Vec<QueryOp>,
313        join_type: JoinType,
314    ) -> Result<QueryOp> {
315        if operands.is_empty() {
316            return Err(crate::core::GraphRAGError::Validation {
317                message: "No operands for join".to_string(),
318            });
319        }
320
321        if operands.len() == 1 {
322            return Ok(operands.pop().unwrap());
323        }
324
325        // Greedy algorithm: repeatedly pick the two cheapest operands to join
326        while operands.len() > 1 {
327            let mut min_cost = f64::MAX;
328            let mut best_i = 0;
329            let mut best_j = 1;
330
331            // Find pair with minimum join cost
332            for i in 0..operands.len() {
333                for j in (i + 1)..operands.len() {
334                    let cost_i = self.estimate_cost(&operands[i])?;
335                    let cost_j = self.estimate_cost(&operands[j])?;
336
337                    // Estimate join cost as product of cardinalities (simplified)
338                    let join_cost = (cost_i.cardinality as f64) * (cost_j.cardinality as f64);
339
340                    if join_cost < min_cost {
341                        min_cost = join_cost;
342                        best_i = i;
343                        best_j = j;
344                    }
345                }
346            }
347
348            // Create join of best pair
349            let left = operands.remove(best_i);
350            let right = operands.remove(if best_j > best_i { best_j - 1 } else { best_j });
351
352            let joined = QueryOp::Join {
353                left: Box::new(left),
354                right: Box::new(right),
355                join_type: join_type.clone(),
356            };
357
358            operands.push(joined);
359        }
360
361        Ok(operands.pop().unwrap())
362    }
363
364    /// Estimate cost of an operation
365    pub fn estimate_cost(&self, op: &QueryOp) -> Result<OperationCost> {
366        match op {
367            QueryOp::EntityScan { entity_type } => {
368                let cardinality = self
369                    .stats
370                    .entities_by_type
371                    .get(entity_type)
372                    .copied()
373                    .unwrap_or(0);
374
375                Ok(OperationCost {
376                    cardinality,
377                    cost: cardinality as f64,
378                    selectivity: if self.stats.total_entities > 0 {
379                        cardinality as f64 / self.stats.total_entities as f64
380                    } else {
381                        0.0
382                    },
383                })
384            },
385
386            QueryOp::Filter {
387                property: _,
388                value: _,
389            } => {
390                // Assume filter has 10% selectivity (can be improved with histograms)
391                let selectivity = 0.1;
392                let cardinality = (self.stats.total_entities as f64 * selectivity) as usize;
393
394                Ok(OperationCost {
395                    cardinality,
396                    cost: self.stats.total_entities as f64, // Must scan all
397                    selectivity,
398                })
399            },
400
401            QueryOp::Join {
402                left,
403                right,
404                join_type,
405            } => {
406                let left_cost = self.estimate_cost(left)?;
407                let right_cost = self.estimate_cost(right)?;
408
409                let cardinality = match join_type {
410                    JoinType::Inner => {
411                        // Estimate as geometric mean of inputs
412                        ((left_cost.cardinality as f64) * (right_cost.cardinality as f64)).sqrt()
413                            as usize
414                    },
415                    JoinType::LeftOuter => left_cost.cardinality,
416                    JoinType::Cross => left_cost.cardinality * right_cost.cardinality,
417                };
418
419                let cost = left_cost.cost
420                    + right_cost.cost
421                    + (left_cost.cardinality as f64 * right_cost.cardinality as f64);
422
423                Ok(OperationCost {
424                    cardinality,
425                    cost,
426                    selectivity: left_cost.selectivity * right_cost.selectivity,
427                })
428            },
429
430            QueryOp::Neighbors {
431                source,
432                relation_type: _,
433                max_hops,
434            } => {
435                let source_cost = self.estimate_cost(source)?;
436
437                // Estimate neighbors as source_cardinality * avg_degree^hops
438                let expansion_factor = self.stats.average_degree.powi(*max_hops as i32);
439                let cardinality = (source_cost.cardinality as f64 * expansion_factor)
440                    .min(self.stats.total_entities as f64)
441                    as usize;
442
443                Ok(OperationCost {
444                    cardinality,
445                    cost: source_cost.cost + (cardinality as f64),
446                    selectivity: cardinality as f64 / self.stats.total_entities as f64,
447                })
448            },
449
450            QueryOp::Union { left, right } => {
451                let left_cost = self.estimate_cost(left)?;
452                let right_cost = self.estimate_cost(right)?;
453
454                // Union cardinality (with some overlap assumed)
455                let cardinality = (left_cost.cardinality + right_cost.cardinality) * 9 / 10;
456
457                Ok(OperationCost {
458                    cardinality,
459                    cost: left_cost.cost + right_cost.cost,
460                    selectivity: (left_cost.selectivity + right_cost.selectivity).min(1.0),
461                })
462            },
463
464            QueryOp::Limit { source, count } => {
465                let source_cost = self.estimate_cost(source)?;
466
467                Ok(OperationCost {
468                    cardinality: (*count).min(source_cost.cardinality),
469                    cost: source_cost.cost,
470                    selectivity: (*count as f64 / self.stats.total_entities as f64).min(1.0),
471                })
472            },
473        }
474    }
475
476    /// Generate an execution plan with cost annotations
477    pub fn explain(&self, op: &QueryOp) -> Result<String> {
478        let cost = self.estimate_cost(op)?;
479        let mut plan = String::new();
480
481        self.explain_recursive(op, 0, &mut plan)?;
482
483        plan.push_str(&format!(
484            "\nEstimated Cost: {:.2}\nEstimated Cardinality: {}\nSelectivity: {:.2}%\n",
485            cost.cost,
486            cost.cardinality,
487            cost.selectivity * 100.0
488        ));
489
490        Ok(plan)
491    }
492
493    /// Recursively build execution plan string
494    fn explain_recursive(&self, op: &QueryOp, depth: usize, plan: &mut String) -> Result<()> {
495        let indent = "  ".repeat(depth);
496        let cost = self.estimate_cost(op)?;
497
498        match op {
499            QueryOp::EntityScan { entity_type } => {
500                plan.push_str(&format!(
501                    "{}EntityScan({}) [cost={:.0}, rows={}]\n",
502                    indent, entity_type, cost.cost, cost.cardinality
503                ));
504            },
505            QueryOp::Filter { property, value } => {
506                plan.push_str(&format!(
507                    "{}Filter({}={}) [cost={:.0}, rows={}]\n",
508                    indent, property, value, cost.cost, cost.cardinality
509                ));
510            },
511            QueryOp::Join {
512                left,
513                right,
514                join_type,
515            } => {
516                plan.push_str(&format!(
517                    "{}Join({:?}) [cost={:.0}, rows={}]\n",
518                    indent, join_type, cost.cost, cost.cardinality
519                ));
520                self.explain_recursive(left, depth + 1, plan)?;
521                self.explain_recursive(right, depth + 1, plan)?;
522            },
523            QueryOp::Neighbors {
524                source,
525                relation_type,
526                max_hops,
527            } => {
528                let rel_str = relation_type.as_deref().unwrap_or("*");
529                plan.push_str(&format!(
530                    "{}Neighbors({}, hops={}) [cost={:.0}, rows={}]\n",
531                    indent, rel_str, max_hops, cost.cost, cost.cardinality
532                ));
533                self.explain_recursive(source, depth + 1, plan)?;
534            },
535            QueryOp::Union { left, right } => {
536                plan.push_str(&format!(
537                    "{}Union [cost={:.0}, rows={}]\n",
538                    indent, cost.cost, cost.cardinality
539                ));
540                self.explain_recursive(left, depth + 1, plan)?;
541                self.explain_recursive(right, depth + 1, plan)?;
542            },
543            QueryOp::Limit { source, count } => {
544                plan.push_str(&format!(
545                    "{}Limit({}) [cost={:.0}, rows={}]\n",
546                    indent, count, cost.cost, cost.cardinality
547                ));
548                self.explain_recursive(source, depth + 1, plan)?;
549            },
550        }
551
552        Ok(())
553    }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Statistics for a small synthetic graph: 180 entities of three types
    /// and 140 relationships of two types.
    fn create_test_stats() -> GraphStatistics {
        let mut entities_by_type = HashMap::new();
        entities_by_type.insert("PERSON".to_string(), 100);
        entities_by_type.insert("ORGANIZATION".to_string(), 50);
        entities_by_type.insert("LOCATION".to_string(), 30);

        let mut relationships_by_type = HashMap::new();
        relationships_by_type.insert("WORKS_FOR".to_string(), 80);
        relationships_by_type.insert("LOCATED_IN".to_string(), 60);

        GraphStatistics {
            total_entities: 180,
            entities_by_type,
            total_relationships: 140,
            relationships_by_type,
            average_degree: 1.56,
        }
    }

    #[test]
    fn test_cost_estimation_scan() {
        let stats = create_test_stats();
        let optimizer = QueryOptimizer::new(stats);

        let query = QueryOp::EntityScan {
            entity_type: "PERSON".to_string(),
        };

        let cost = optimizer.estimate_cost(&query).unwrap();

        assert_eq!(cost.cardinality, 100);
        assert_eq!(cost.cost, 100.0);
    }

    #[test]
    fn test_cost_estimation_join() {
        let stats = create_test_stats();
        let optimizer = QueryOptimizer::new(stats);

        let query = QueryOp::Join {
            left: Box::new(QueryOp::EntityScan {
                entity_type: "PERSON".to_string(),
            }),
            right: Box::new(QueryOp::EntityScan {
                entity_type: "ORGANIZATION".to_string(),
            }),
            join_type: JoinType::Inner,
        };

        let cost = optimizer.estimate_cost(&query).unwrap();

        // Geometric mean: sqrt(100 * 50) = ~71
        assert!(cost.cardinality > 60 && cost.cardinality < 80);
    }

    #[test]
    fn test_join_reordering() {
        let stats = create_test_stats();
        let optimizer = QueryOptimizer::new(stats);

        // Join large table (PERSON=100) with small table (LOCATION=30)
        let query = QueryOp::Join {
            left: Box::new(QueryOp::EntityScan {
                entity_type: "PERSON".to_string(),
            }),
            right: Box::new(QueryOp::EntityScan {
                entity_type: "LOCATION".to_string(),
            }),
            join_type: JoinType::Inner,
        };

        let optimized = optimizer.optimize(query).unwrap();

        // Should reorder to put smaller table first. Fail loudly if the plan
        // shape is not what we expect, instead of silently skipping the
        // assertion.
        match optimized {
            QueryOp::Join { left, .. } => match *left {
                QueryOp::EntityScan { entity_type } => {
                    assert_eq!(entity_type, "LOCATION", "Smaller table should be first");
                },
                other => panic!("expected an EntityScan on the left, got {:?}", other),
            },
            other => panic!("expected a Join at the root, got {:?}", other),
        }
    }

    #[test]
    fn test_neighbors_cost() {
        let stats = create_test_stats();
        let optimizer = QueryOptimizer::new(stats);

        let query = QueryOp::Neighbors {
            source: Box::new(QueryOp::EntityScan {
                entity_type: "PERSON".to_string(),
            }),
            relation_type: Some("WORKS_FOR".to_string()),
            max_hops: 2,
        };

        let cost = optimizer.estimate_cost(&query).unwrap();

        // Should expand based on avg_degree^hops
        assert!(cost.cardinality > 100);
    }

    #[test]
    fn test_explain_plan() {
        let stats = create_test_stats();
        let optimizer = QueryOptimizer::new(stats);

        let query = QueryOp::Join {
            left: Box::new(QueryOp::EntityScan {
                entity_type: "PERSON".to_string(),
            }),
            right: Box::new(QueryOp::EntityScan {
                entity_type: "ORGANIZATION".to_string(),
            }),
            join_type: JoinType::Inner,
        };

        let plan = optimizer.explain(&query).unwrap();

        assert!(plan.contains("Join"));
        assert!(plan.contains("EntityScan"));
        assert!(plan.contains("Estimated Cost"));
    }
}