Skip to main content

graphrag_core/query/
optimizer.rs

//! Query Optimizer for Knowledge Graph Queries
//!
//! Optimizes query execution plans using:
//! - Join ordering optimization
//! - Cost estimation based on graph statistics
//! - Selectivity estimation
//! - Query rewriting
//!
//! This is a rule-based optimizer without ML dependencies.
10
11use crate::core::KnowledgeGraph;
12use crate::Result;
13use std::collections::HashMap;
14
15/// Query operation types
16#[derive(Debug, Clone, PartialEq)]
17pub enum QueryOp {
18    /// Scan all entities of a type
19    EntityScan {
20        /// The entity type to scan for
21        entity_type: String,
22    },
23    /// Filter entities by property
24    Filter {
25        /// Property name to filter on
26        property: String,
27        /// Property value to match
28        value: String,
29    },
30    /// Join two results on entity relationships
31    Join {
32        /// Left operand of the join
33        left: Box<QueryOp>,
34        /// Right operand of the join
35        right: Box<QueryOp>,
36        /// Type of join operation
37        join_type: JoinType,
38    },
39    /// Get neighbors of entities
40    Neighbors {
41        /// Source entities to find neighbors for
42        source: Box<QueryOp>,
43        /// Optional relationship type filter
44        relation_type: Option<String>,
45        /// Maximum number of hops (graph traversal depth)
46        max_hops: usize,
47    },
48    /// Union of two operations
49    Union {
50        /// Left operand of the union
51        left: Box<QueryOp>,
52        /// Right operand of the union
53        right: Box<QueryOp>,
54    },
55    /// Limit results
56    Limit {
57        /// Source operation to limit
58        source: Box<QueryOp>,
59        /// Maximum number of results
60        count: usize,
61    },
62}
63
/// How the two inputs of a join are combined.
///
/// The enum is a payload-free tag, so besides `PartialEq` it also derives
/// `Eq` and `Hash`; this lets a join type be used as a `HashMap`/`HashSet`
/// key (for example to keep per-join-type statistics).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Inner join (intersection).
    Inner,
    /// Left outer join. Not commutative: the left input plays a
    /// distinguished role, so the two sides must not be exchanged.
    LeftOuter,
    /// Cross product of both inputs.
    Cross,
}
74
/// Estimate attached to a single plan node.
#[derive(Clone, Debug)]
pub struct OperationCost {
    /// Number of results the node is expected to produce.
    pub cardinality: usize,
    /// Expected cost of producing them, in abstract units.
    pub cost: f64,
    /// Selectivity of the node, in the range 0.0-1.0.
    pub selectivity: f64,
}
85
/// Aggregate counts describing a knowledge graph, used as the input of
/// every cost and cardinality estimate.
#[derive(Clone, Debug)]
pub struct GraphStatistics {
    /// Number of entities in the graph.
    pub total_entities: usize,
    /// Number of entities for each entity type name.
    pub entities_by_type: HashMap<String, usize>,
    /// Number of relationships in the graph.
    pub total_relationships: usize,
    /// Number of relationships for each relationship type name.
    pub relationships_by_type: HashMap<String, usize>,
    /// Average number of edges per node.
    pub average_degree: f64,
}
100
101impl GraphStatistics {
102    /// Compute statistics from a knowledge graph
103    pub fn from_graph(graph: &KnowledgeGraph) -> Self {
104        let entities: Vec<_> = graph.entities().collect();
105        let total_entities = entities.len();
106
107        let mut entities_by_type: HashMap<String, usize> = HashMap::new();
108        for entity in &entities {
109            *entities_by_type
110                .entry(entity.entity_type.clone())
111                .or_insert(0) += 1;
112        }
113
114        let relationships = graph.get_all_relationships();
115        let total_relationships = relationships.len();
116
117        let mut relationships_by_type: HashMap<String, usize> = HashMap::new();
118        for rel in &relationships {
119            *relationships_by_type
120                .entry(rel.relation_type.clone())
121                .or_insert(0) += 1;
122        }
123
124        let average_degree = if total_entities > 0 {
125            (total_relationships as f64 * 2.0) / total_entities as f64
126        } else {
127            0.0
128        };
129
130        Self {
131            total_entities,
132            entities_by_type,
133            total_relationships,
134            relationships_by_type,
135            average_degree,
136        }
137    }
138}
139
140/// Query optimizer
141pub struct QueryOptimizer {
142    stats: GraphStatistics,
143}
144
145impl QueryOptimizer {
146    /// Create a new optimizer with graph statistics
147    pub fn new(stats: GraphStatistics) -> Self {
148        Self { stats }
149    }
150
151    /// Optimize a query operation
152    pub fn optimize(&self, query: QueryOp) -> Result<QueryOp> {
153        // Apply optimization rules
154        let rewritten = self.rewrite_query(query)?;
155        let optimized = self.optimize_joins(rewritten)?;
156        Ok(optimized)
157    }
158
159    /// Rewrite query using algebraic rules
160    fn rewrite_query(&self, query: QueryOp) -> Result<QueryOp> {
161        match query {
162            // Push filters down through joins
163            QueryOp::Filter { property, value } => {
164                Ok(QueryOp::Filter { property, value })
165            }
166
167            // Reorder joins based on selectivity
168            QueryOp::Join {
169                left,
170                right,
171                join_type,
172            } => {
173                let left_opt = self.rewrite_query(*left)?;
174                let right_opt = self.rewrite_query(*right)?;
175
176                // Estimate costs
177                let left_cost = self.estimate_cost(&left_opt)?;
178                let right_cost = self.estimate_cost(&right_opt)?;
179
180                // Put smaller (more selective) operand first for hash joins
181                if left_cost.cardinality > right_cost.cardinality {
182                    Ok(QueryOp::Join {
183                        left: Box::new(right_opt),
184                        right: Box::new(left_opt),
185                        join_type,
186                    })
187                } else {
188                    Ok(QueryOp::Join {
189                        left: Box::new(left_opt),
190                        right: Box::new(right_opt),
191                        join_type,
192                    })
193                }
194            }
195
196            // Recursively optimize subqueries
197            QueryOp::Neighbors {
198                source,
199                relation_type,
200                max_hops,
201            } => {
202                let source_opt = self.rewrite_query(*source)?;
203                Ok(QueryOp::Neighbors {
204                    source: Box::new(source_opt),
205                    relation_type,
206                    max_hops,
207                })
208            }
209
210            QueryOp::Union { left, right } => {
211                let left_opt = self.rewrite_query(*left)?;
212                let right_opt = self.rewrite_query(*right)?;
213                Ok(QueryOp::Union {
214                    left: Box::new(left_opt),
215                    right: Box::new(right_opt),
216                })
217            }
218
219            QueryOp::Limit { source, count } => {
220                let source_opt = self.rewrite_query(*source)?;
221                Ok(QueryOp::Limit {
222                    source: Box::new(source_opt),
223                    count,
224                })
225            }
226
227            // Base case: entity scans
228            QueryOp::EntityScan { entity_type } => Ok(QueryOp::EntityScan { entity_type }),
229        }
230    }
231
232    /// Optimize join ordering using dynamic programming
233    fn optimize_joins(&self, query: QueryOp) -> Result<QueryOp> {
234        match query {
235            QueryOp::Join {
236                left,
237                right,
238                join_type,
239            } => {
240                // Recursively optimize sub-queries
241                let left_opt = self.optimize_joins(*left)?;
242                let right_opt = self.optimize_joins(*right)?;
243
244                // For multi-way joins, collect all join operands
245                let mut operands = Vec::new();
246                self.collect_join_operands(&left_opt, &mut operands);
247                self.collect_join_operands(&right_opt, &mut operands);
248
249                if operands.len() > 2 {
250                    // Multi-way join: find optimal order using greedy algorithm
251                    self.find_optimal_join_order(operands, join_type)
252                } else {
253                    // Binary join: already optimized by rewrite_query
254                    Ok(QueryOp::Join {
255                        left: Box::new(left_opt),
256                        right: Box::new(right_opt),
257                        join_type,
258                    })
259                }
260            }
261
262            // Recursively process other operations
263            QueryOp::Neighbors {
264                source,
265                relation_type,
266                max_hops,
267            } => {
268                let source_opt = self.optimize_joins(*source)?;
269                Ok(QueryOp::Neighbors {
270                    source: Box::new(source_opt),
271                    relation_type,
272                    max_hops,
273                })
274            }
275
276            QueryOp::Union { left, right } => {
277                let left_opt = self.optimize_joins(*left)?;
278                let right_opt = self.optimize_joins(*right)?;
279                Ok(QueryOp::Union {
280                    left: Box::new(left_opt),
281                    right: Box::new(right_opt),
282                })
283            }
284
285            QueryOp::Limit { source, count } => {
286                let source_opt = self.optimize_joins(*source)?;
287                Ok(QueryOp::Limit {
288                    source: Box::new(source_opt),
289                    count,
290                })
291            }
292
293            // Leaf operations
294            _ => Ok(query),
295        }
296    }
297
298    /// Collect all join operands for multi-way join optimization
299    fn collect_join_operands(&self, op: &QueryOp, operands: &mut Vec<QueryOp>) {
300        match op {
301            QueryOp::Join { left, right, .. } => {
302                self.collect_join_operands(left, operands);
303                self.collect_join_operands(right, operands);
304            }
305            _ => {
306                operands.push(op.clone());
307            }
308        }
309    }
310
311    /// Find optimal join order using greedy algorithm
312    fn find_optimal_join_order(
313        &self,
314        mut operands: Vec<QueryOp>,
315        join_type: JoinType,
316    ) -> Result<QueryOp> {
317        if operands.is_empty() {
318            return Err(crate::core::GraphRAGError::Validation {
319                message: "No operands for join".to_string(),
320            });
321        }
322
323        if operands.len() == 1 {
324            return Ok(operands.pop().unwrap());
325        }
326
327        // Greedy algorithm: repeatedly pick the two cheapest operands to join
328        while operands.len() > 1 {
329            let mut min_cost = f64::MAX;
330            let mut best_i = 0;
331            let mut best_j = 1;
332
333            // Find pair with minimum join cost
334            for i in 0..operands.len() {
335                for j in (i + 1)..operands.len() {
336                    let cost_i = self.estimate_cost(&operands[i])?;
337                    let cost_j = self.estimate_cost(&operands[j])?;
338
339                    // Estimate join cost as product of cardinalities (simplified)
340                    let join_cost = (cost_i.cardinality as f64) * (cost_j.cardinality as f64);
341
342                    if join_cost < min_cost {
343                        min_cost = join_cost;
344                        best_i = i;
345                        best_j = j;
346                    }
347                }
348            }
349
350            // Create join of best pair
351            let left = operands.remove(best_i);
352            let right = operands.remove(if best_j > best_i {
353                best_j - 1
354            } else {
355                best_j
356            });
357
358            let joined = QueryOp::Join {
359                left: Box::new(left),
360                right: Box::new(right),
361                join_type: join_type.clone(),
362            };
363
364            operands.push(joined);
365        }
366
367        Ok(operands.pop().unwrap())
368    }
369
370    /// Estimate cost of an operation
371    pub fn estimate_cost(&self, op: &QueryOp) -> Result<OperationCost> {
372        match op {
373            QueryOp::EntityScan { entity_type } => {
374                let cardinality = self
375                    .stats
376                    .entities_by_type
377                    .get(entity_type)
378                    .copied()
379                    .unwrap_or(0);
380
381                Ok(OperationCost {
382                    cardinality,
383                    cost: cardinality as f64,
384                    selectivity: if self.stats.total_entities > 0 {
385                        cardinality as f64 / self.stats.total_entities as f64
386                    } else {
387                        0.0
388                    },
389                })
390            }
391
392            QueryOp::Filter { property: _, value: _ } => {
393                // Assume filter has 10% selectivity (can be improved with histograms)
394                let selectivity = 0.1;
395                let cardinality = (self.stats.total_entities as f64 * selectivity) as usize;
396
397                Ok(OperationCost {
398                    cardinality,
399                    cost: self.stats.total_entities as f64, // Must scan all
400                    selectivity,
401                })
402            }
403
404            QueryOp::Join { left, right, join_type } => {
405                let left_cost = self.estimate_cost(left)?;
406                let right_cost = self.estimate_cost(right)?;
407
408                let cardinality = match join_type {
409                    JoinType::Inner => {
410                        // Estimate as geometric mean of inputs
411                        ((left_cost.cardinality as f64) * (right_cost.cardinality as f64))
412                            .sqrt() as usize
413                    }
414                    JoinType::LeftOuter => left_cost.cardinality,
415                    JoinType::Cross => left_cost.cardinality * right_cost.cardinality,
416                };
417
418                let cost = left_cost.cost
419                    + right_cost.cost
420                    + (left_cost.cardinality as f64 * right_cost.cardinality as f64);
421
422                Ok(OperationCost {
423                    cardinality,
424                    cost,
425                    selectivity: left_cost.selectivity * right_cost.selectivity,
426                })
427            }
428
429            QueryOp::Neighbors {
430                source,
431                relation_type: _,
432                max_hops,
433            } => {
434                let source_cost = self.estimate_cost(source)?;
435
436                // Estimate neighbors as source_cardinality * avg_degree^hops
437                let expansion_factor = self.stats.average_degree.powi(*max_hops as i32);
438                let cardinality =
439                    (source_cost.cardinality as f64 * expansion_factor).min(self.stats.total_entities as f64) as usize;
440
441                Ok(OperationCost {
442                    cardinality,
443                    cost: source_cost.cost + (cardinality as f64),
444                    selectivity: cardinality as f64 / self.stats.total_entities as f64,
445                })
446            }
447
448            QueryOp::Union { left, right } => {
449                let left_cost = self.estimate_cost(left)?;
450                let right_cost = self.estimate_cost(right)?;
451
452                // Union cardinality (with some overlap assumed)
453                let cardinality = (left_cost.cardinality + right_cost.cardinality) * 9 / 10;
454
455                Ok(OperationCost {
456                    cardinality,
457                    cost: left_cost.cost + right_cost.cost,
458                    selectivity: (left_cost.selectivity + right_cost.selectivity).min(1.0),
459                })
460            }
461
462            QueryOp::Limit { source, count } => {
463                let source_cost = self.estimate_cost(source)?;
464
465                Ok(OperationCost {
466                    cardinality: (*count).min(source_cost.cardinality),
467                    cost: source_cost.cost,
468                    selectivity: (*count as f64 / self.stats.total_entities as f64).min(1.0),
469                })
470            }
471        }
472    }
473
474    /// Generate an execution plan with cost annotations
475    pub fn explain(&self, op: &QueryOp) -> Result<String> {
476        let cost = self.estimate_cost(op)?;
477        let mut plan = String::new();
478
479        self.explain_recursive(op, 0, &mut plan)?;
480
481        plan.push_str(&format!(
482            "\nEstimated Cost: {:.2}\nEstimated Cardinality: {}\nSelectivity: {:.2}%\n",
483            cost.cost,
484            cost.cardinality,
485            cost.selectivity * 100.0
486        ));
487
488        Ok(plan)
489    }
490
491    /// Recursively build execution plan string
492    fn explain_recursive(&self, op: &QueryOp, depth: usize, plan: &mut String) -> Result<()> {
493        let indent = "  ".repeat(depth);
494        let cost = self.estimate_cost(op)?;
495
496        match op {
497            QueryOp::EntityScan { entity_type } => {
498                plan.push_str(&format!(
499                    "{}EntityScan({}) [cost={:.0}, rows={}]\n",
500                    indent, entity_type, cost.cost, cost.cardinality
501                ));
502            }
503            QueryOp::Filter { property, value } => {
504                plan.push_str(&format!(
505                    "{}Filter({}={}) [cost={:.0}, rows={}]\n",
506                    indent, property, value, cost.cost, cost.cardinality
507                ));
508            }
509            QueryOp::Join {
510                left,
511                right,
512                join_type,
513            } => {
514                plan.push_str(&format!(
515                    "{}Join({:?}) [cost={:.0}, rows={}]\n",
516                    indent, join_type, cost.cost, cost.cardinality
517                ));
518                self.explain_recursive(left, depth + 1, plan)?;
519                self.explain_recursive(right, depth + 1, plan)?;
520            }
521            QueryOp::Neighbors {
522                source,
523                relation_type,
524                max_hops,
525            } => {
526                let rel_str = relation_type.as_deref().unwrap_or("*");
527                plan.push_str(&format!(
528                    "{}Neighbors({}, hops={}) [cost={:.0}, rows={}]\n",
529                    indent, rel_str, max_hops, cost.cost, cost.cardinality
530                ));
531                self.explain_recursive(source, depth + 1, plan)?;
532            }
533            QueryOp::Union { left, right } => {
534                plan.push_str(&format!(
535                    "{}Union [cost={:.0}, rows={}]\n",
536                    indent, cost.cost, cost.cardinality
537                ));
538                self.explain_recursive(left, depth + 1, plan)?;
539                self.explain_recursive(right, depth + 1, plan)?;
540            }
541            QueryOp::Limit { source, count } => {
542                plan.push_str(&format!(
543                    "{}Limit({}) [cost={:.0}, rows={}]\n",
544                    indent, count, cost.cost, cost.cardinality
545                ));
546                self.explain_recursive(source, depth + 1, plan)?;
547            }
548        }
549
550        Ok(())
551    }
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed statistics shared by all tests: 180 entities over three types
    /// and 140 relationships over two types.
    fn create_test_stats() -> GraphStatistics {
        let mut entities_by_type = HashMap::new();
        entities_by_type.insert("PERSON".to_string(), 100);
        entities_by_type.insert("ORGANIZATION".to_string(), 50);
        entities_by_type.insert("LOCATION".to_string(), 30);

        let mut relationships_by_type = HashMap::new();
        relationships_by_type.insert("WORKS_FOR".to_string(), 80);
        relationships_by_type.insert("LOCATED_IN".to_string(), 60);

        GraphStatistics {
            total_entities: 180,
            entities_by_type,
            total_relationships: 140,
            relationships_by_type,
            average_degree: 1.56,
        }
    }

    #[test]
    fn test_cost_estimation_scan() {
        let stats = create_test_stats();
        let optimizer = QueryOptimizer::new(stats);

        let query = QueryOp::EntityScan {
            entity_type: "PERSON".to_string(),
        };

        let cost = optimizer.estimate_cost(&query).unwrap();

        assert_eq!(cost.cardinality, 100);
        assert_eq!(cost.cost, 100.0);
    }

    #[test]
    fn test_cost_estimation_join() {
        let stats = create_test_stats();
        let optimizer = QueryOptimizer::new(stats);

        let query = QueryOp::Join {
            left: Box::new(QueryOp::EntityScan {
                entity_type: "PERSON".to_string(),
            }),
            right: Box::new(QueryOp::EntityScan {
                entity_type: "ORGANIZATION".to_string(),
            }),
            join_type: JoinType::Inner,
        };

        let cost = optimizer.estimate_cost(&query).unwrap();

        // Geometric mean: sqrt(100 * 50) = ~71
        assert!(cost.cardinality > 60 && cost.cardinality < 80);
    }

    #[test]
    fn test_join_reordering() {
        let stats = create_test_stats();
        let optimizer = QueryOptimizer::new(stats);

        // Join large table (PERSON=100) with small table (LOCATION=30)
        let query = QueryOp::Join {
            left: Box::new(QueryOp::EntityScan {
                entity_type: "PERSON".to_string(),
            }),
            right: Box::new(QueryOp::EntityScan {
                entity_type: "LOCATION".to_string(),
            }),
            join_type: JoinType::Inner,
        };

        let optimized = optimizer.optimize(query).unwrap();

        // Should reorder to put smaller table first. The plan shape is
        // asserted as well, so the test cannot pass vacuously when the
        // optimizer returns something other than a join of two scans.
        match optimized {
            QueryOp::Join { left, right, .. } => {
                assert_eq!(
                    *left,
                    QueryOp::EntityScan {
                        entity_type: "LOCATION".to_string(),
                    },
                    "Smaller table should be first"
                );
                assert_eq!(
                    *right,
                    QueryOp::EntityScan {
                        entity_type: "PERSON".to_string(),
                    },
                    "Larger table should be second"
                );
            }
            other => panic!("expected a Join at the plan root, got {:?}", other),
        }
    }

    #[test]
    fn test_neighbors_cost() {
        let stats = create_test_stats();
        let optimizer = QueryOptimizer::new(stats);

        let query = QueryOp::Neighbors {
            source: Box::new(QueryOp::EntityScan {
                entity_type: "PERSON".to_string(),
            }),
            relation_type: Some("WORKS_FOR".to_string()),
            max_hops: 2,
        };

        let cost = optimizer.estimate_cost(&query).unwrap();

        // Should expand based on avg_degree^hops
        assert!(cost.cardinality > 100);
    }

    #[test]
    fn test_explain_plan() {
        let stats = create_test_stats();
        let optimizer = QueryOptimizer::new(stats);

        let query = QueryOp::Join {
            left: Box::new(QueryOp::EntityScan {
                entity_type: "PERSON".to_string(),
            }),
            right: Box::new(QueryOp::EntityScan {
                entity_type: "ORGANIZATION".to_string(),
            }),
            join_type: JoinType::Inner,
        };

        let plan = optimizer.explain(&query).unwrap();

        assert!(plan.contains("Join"));
        assert!(plan.contains("EntityScan"));
        assert!(plan.contains("Estimated Cost"));
    }
}