nirv_engine/engine/
query_planner.rs

1use async_trait::async_trait;
2use crate::utils::{
3    types::{InternalQuery, DataSource, Column, Predicate, OrderBy},
4    error::{NirvResult, NirvError},
5};
6
/// Execution plan node types.
///
/// `TableScan` is the only leaf variant. `Limit`, `Sort` and `Projection`
/// each own the node they operate on through a boxed `input`, so a single
/// value of this enum can describe a whole operator chain.
#[derive(Debug, Clone)]
pub enum PlanNode {
    /// Scan a table/data source.
    TableScan {
        /// The data source to read from.
        source: DataSource,
        /// Columns requested from the source. `DefaultQueryPlanner` fills this
        /// with a single `*` column when the query lists no projections.
        projections: Vec<Column>,
        /// Predicates carried along with the scan.
        predicates: Vec<Predicate>,
    },
    /// Apply a limit to results.
    Limit {
        /// Row limit taken from the query's `limit` field.
        count: u64,
        /// The node whose output is limited.
        input: Box<PlanNode>,
    },
    /// Sort results.
    Sort {
        /// Ordering specification taken from the query's `ordering` field.
        order_by: OrderBy,
        /// The node whose output is sorted.
        input: Box<PlanNode>,
    },
    /// Project specific columns.
    ///
    /// NOTE(review): nothing in this file constructs this variant; confirm
    /// whether another module builds or executes it.
    Projection {
        /// Columns to keep.
        columns: Vec<Column>,
        /// The node whose output is projected.
        input: Box<PlanNode>,
    },
}
32
/// Complete execution plan for a query.
///
/// Both fields are public, so callers may read or modify them directly in
/// addition to using the methods on this type.
#[derive(Debug, Clone)]
pub struct ExecutionPlan {
    /// Plan nodes in the order they were added; `root_node` reports the last
    /// element of this list.
    pub nodes: Vec<PlanNode>,
    /// Estimated cost of running the plan (a relative figure; no unit is
    /// defined in this file).
    pub estimated_cost: f64,
}
39
40impl ExecutionPlan {
41    /// Create a new empty execution plan
42    pub fn new() -> Self {
43        Self {
44            nodes: Vec::new(),
45            estimated_cost: 0.0,
46        }
47    }
48    
49    /// Add a node to the execution plan
50    pub fn add_node(&mut self, node: PlanNode) {
51        self.nodes.push(node);
52    }
53    
54    /// Set the estimated cost for the plan
55    pub fn set_estimated_cost(&mut self, cost: f64) {
56        self.estimated_cost = cost;
57    }
58    
59    /// Get the root node of the plan
60    pub fn root_node(&self) -> Option<&PlanNode> {
61        self.nodes.last()
62    }
63    
64    /// Check if the plan is empty
65    pub fn is_empty(&self) -> bool {
66        self.nodes.is_empty()
67    }
68}
69
70impl Default for ExecutionPlan {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
/// Trait for query planning functionality.
///
/// Implementors must be `Send + Sync`. The `#[async_trait]` attribute is what
/// allows the methods below to be declared as `async fn`.
#[async_trait]
pub trait QueryPlanner: Send + Sync {
    /// Create an execution plan for the given query.
    ///
    /// # Errors
    /// Returns a `NirvError` when the implementation cannot plan `query`.
    async fn create_execution_plan(&self, query: &InternalQuery) -> NirvResult<ExecutionPlan>;

    /// Estimate the cost of executing a query without building a plan.
    ///
    /// # Errors
    /// Returns a `NirvError` when the implementation rejects `query`.
    async fn estimate_cost(&self, query: &InternalQuery) -> NirvResult<f64>;

    /// Optimize an execution plan.
    ///
    /// Takes `plan` by value and returns the (possibly rewritten) plan.
    async fn optimize_plan(&self, plan: ExecutionPlan) -> NirvResult<ExecutionPlan>;
}
88
/// Default implementation of `QueryPlanner`.
///
/// Holds the cost parameters used to estimate a query: a fixed scan cost, a
/// per-predicate cost, and flat costs for sorting and limiting. The type is
/// plain data (four `f64` values), so it derives `Debug` for diagnostics and
/// `Clone` so a configured planner can be duplicated cheaply.
#[derive(Debug, Clone)]
pub struct DefaultQueryPlanner {
    /// Base cost for table scans
    base_scan_cost: f64,
    /// Cost added for each predicate in the query
    predicate_cost_multiplier: f64,
    /// Cost for sorting operations
    sort_cost: f64,
    /// Cost for limit operations
    limit_cost: f64,
}
100
101impl DefaultQueryPlanner {
102    /// Create a new query planner with default cost parameters
103    pub fn new() -> Self {
104        Self {
105            base_scan_cost: 1.0,
106            predicate_cost_multiplier: 0.1,
107            sort_cost: 0.5,
108            limit_cost: 0.1,
109        }
110    }
111    
112    /// Create a query planner with custom cost parameters
113    pub fn with_costs(
114        base_scan_cost: f64,
115        predicate_cost_multiplier: f64,
116        sort_cost: f64,
117        limit_cost: f64,
118    ) -> Self {
119        Self {
120            base_scan_cost,
121            predicate_cost_multiplier,
122            sort_cost,
123            limit_cost,
124        }
125    }
126    
127    /// Validate that a query has the required components
128    fn validate_query(&self, query: &InternalQuery) -> NirvResult<()> {
129        if query.sources.is_empty() {
130            return Err(NirvError::Internal(
131                "No data sources found in query".to_string()
132            ));
133        }
134        
135        // For MVP, we only support single-source queries
136        if query.sources.len() > 1 {
137            return Err(NirvError::Internal(
138                "Multi-source queries not supported in MVP".to_string()
139            ));
140        }
141        
142        Ok(())
143    }
144    
145    /// Create a table scan node for a data source
146    fn create_table_scan_node(&self, query: &InternalQuery) -> PlanNode {
147        let source = query.sources[0].clone();
148        let projections = if query.projections.is_empty() {
149            // Default to selecting all columns
150            vec![Column {
151                name: "*".to_string(),
152                alias: None,
153                source: source.alias.clone(),
154            }]
155        } else {
156            query.projections.clone()
157        };
158        
159        PlanNode::TableScan {
160            source,
161            projections,
162            predicates: query.predicates.clone(),
163        }
164    }
165    
166    /// Add limit node if query has a limit clause
167    fn add_limit_node(&self, mut plan: ExecutionPlan, query: &InternalQuery) -> ExecutionPlan {
168        if let Some(limit) = query.limit {
169            if let Some(last_node) = plan.nodes.last() {
170                let limit_node = PlanNode::Limit {
171                    count: limit,
172                    input: Box::new(last_node.clone()),
173                };
174                plan.add_node(limit_node);
175                plan.estimated_cost += self.limit_cost;
176            }
177        }
178        plan
179    }
180    
181    /// Add sort node if query has ordering
182    fn add_sort_node(&self, mut plan: ExecutionPlan, query: &InternalQuery) -> ExecutionPlan {
183        if let Some(order_by) = &query.ordering {
184            if let Some(last_node) = plan.nodes.last() {
185                let sort_node = PlanNode::Sort {
186                    order_by: order_by.clone(),
187                    input: Box::new(last_node.clone()),
188                };
189                plan.add_node(sort_node);
190                plan.estimated_cost += self.sort_cost;
191            }
192        }
193        plan
194    }
195    
196    /// Calculate the estimated cost for a query
197    fn calculate_cost(&self, query: &InternalQuery) -> f64 {
198        let mut cost = self.base_scan_cost;
199        
200        // Add cost for predicates
201        cost += query.predicates.len() as f64 * self.predicate_cost_multiplier;
202        
203        // Add cost for sorting
204        if query.ordering.is_some() {
205            cost += self.sort_cost;
206        }
207        
208        // Add cost for limit
209        if query.limit.is_some() {
210            cost += self.limit_cost;
211        }
212        
213        cost
214    }
215}
216
217impl Default for DefaultQueryPlanner {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
223#[async_trait]
224impl QueryPlanner for DefaultQueryPlanner {
225    async fn create_execution_plan(&self, query: &InternalQuery) -> NirvResult<ExecutionPlan> {
226        // Validate the query
227        self.validate_query(query)?;
228        
229        let mut plan = ExecutionPlan::new();
230        
231        // Create the base table scan node
232        let table_scan = self.create_table_scan_node(query);
233        plan.add_node(table_scan);
234        
235        // Calculate base cost
236        plan.estimated_cost = self.calculate_cost(query);
237        
238        // Add sort node if needed (before limit)
239        plan = self.add_sort_node(plan, query);
240        
241        // Add limit node if needed (after sort)
242        plan = self.add_limit_node(plan, query);
243        
244        Ok(plan)
245    }
246    
247    async fn estimate_cost(&self, query: &InternalQuery) -> NirvResult<f64> {
248        self.validate_query(query)?;
249        Ok(self.calculate_cost(query))
250    }
251    
252    async fn optimize_plan(&self, plan: ExecutionPlan) -> NirvResult<ExecutionPlan> {
253        // For MVP, we don't implement complex optimizations
254        // Just return the plan as-is
255        Ok(plan)
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::types::{QueryOperation, PredicateOperator, PredicateValue, OrderColumn, OrderDirection};

    /// Tolerance used when comparing costs that were accumulated with `f64`
    /// arithmetic; sums such as `1.0 + 0.1 + 0.5 + 0.1` are not guaranteed to
    /// be bit-identical to the decimal literal `1.7`.
    const COST_TOLERANCE: f64 = 1e-9;

    /// Asserts that two costs are equal within `COST_TOLERANCE`.
    fn assert_cost_eq(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < COST_TOLERANCE,
            "expected cost {}, got {}",
            expected,
            actual
        );
    }

    /// Builds a data source of object type `mock`.
    fn mock_source(identifier: &str, alias: Option<&str>) -> DataSource {
        DataSource {
            object_type: "mock".to_string(),
            identifier: identifier.to_string(),
            alias: alias.map(str::to_string),
        }
    }

    /// Builds a `SELECT` query over a single un-aliased mock source.
    fn select_from(identifier: &str) -> InternalQuery {
        let mut query = InternalQuery::new(QueryOperation::Select);
        query.sources.push(mock_source(identifier, None));
        query
    }

    /// Builds an ordering over a single column.
    fn single_column_order(column: &str, direction: OrderDirection) -> OrderBy {
        OrderBy {
            columns: vec![OrderColumn {
                column: column.to_string(),
                direction,
            }],
        }
    }

    /// Builds the `age > 18` predicate shared by several tests.
    fn age_over_18() -> Predicate {
        Predicate {
            column: "age".to_string(),
            operator: PredicateOperator::GreaterThan,
            value: PredicateValue::Integer(18),
        }
    }

    #[test]
    fn test_execution_plan_creation() {
        let mut plan = ExecutionPlan::new();

        assert!(plan.is_empty());
        assert_eq!(plan.estimated_cost, 0.0);
        assert!(plan.root_node().is_none());

        plan.add_node(PlanNode::TableScan {
            source: mock_source("test", None),
            projections: vec![],
            predicates: vec![],
        });
        plan.set_estimated_cost(1.5);

        assert!(!plan.is_empty());
        assert_eq!(plan.estimated_cost, 1.5);
        assert!(plan.root_node().is_some());
    }

    #[test]
    fn test_default_query_planner_creation() {
        let planner = DefaultQueryPlanner::new();

        assert_eq!(planner.base_scan_cost, 1.0);
        assert_eq!(planner.predicate_cost_multiplier, 0.1);
        assert_eq!(planner.sort_cost, 0.5);
        assert_eq!(planner.limit_cost, 0.1);
    }

    #[test]
    fn test_query_planner_with_custom_costs() {
        let planner = DefaultQueryPlanner::with_costs(2.0, 0.2, 1.0, 0.2);

        assert_eq!(planner.base_scan_cost, 2.0);
        assert_eq!(planner.predicate_cost_multiplier, 0.2);
        assert_eq!(planner.sort_cost, 1.0);
        assert_eq!(planner.limit_cost, 0.2);
    }

    #[tokio::test]
    async fn test_query_planner_validate_empty_query() {
        let planner = DefaultQueryPlanner::new();
        let query = InternalQuery::new(QueryOperation::Select);

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Internal(msg) => {
                assert!(msg.contains("No data sources"));
            }
            _ => panic!("Expected Internal error"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_validate_multi_source_query() {
        let planner = DefaultQueryPlanner::new();

        let mut query = select_from("table1");
        query.sources.push(mock_source("table2", None));

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Internal(msg) => {
                assert!(msg.contains("Multi-source queries not supported"));
            }
            _ => panic!("Expected Internal error"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_simple_select() {
        let planner = DefaultQueryPlanner::new();
        let query = select_from("users");

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        assert_eq!(plan.nodes.len(), 1);
        assert_cost_eq(plan.estimated_cost, 1.0); // base_scan_cost

        match &plan.nodes[0] {
            PlanNode::TableScan { source, projections, predicates } => {
                assert_eq!(source.object_type, "mock");
                assert_eq!(source.identifier, "users");
                assert_eq!(projections.len(), 1);
                assert_eq!(projections[0].name, "*");
                assert!(predicates.is_empty());
            }
            _ => panic!("Expected TableScan node"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_with_projections() {
        let planner = DefaultQueryPlanner::new();

        let mut query = InternalQuery::new(QueryOperation::Select);
        query.sources.push(mock_source("users", Some("u")));
        query.projections.push(Column {
            name: "name".to_string(),
            alias: Some("user_name".to_string()),
            source: Some("u".to_string()),
        });
        query.projections.push(Column {
            name: "email".to_string(),
            alias: None,
            source: Some("u".to_string()),
        });

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        match &plan.nodes[0] {
            PlanNode::TableScan { projections, .. } => {
                assert_eq!(projections.len(), 2);
                assert_eq!(projections[0].name, "name");
                assert_eq!(projections[0].alias, Some("user_name".to_string()));
                assert_eq!(projections[1].name, "email");
                assert_eq!(projections[1].alias, None);
            }
            _ => panic!("Expected TableScan node"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_with_predicates() {
        let planner = DefaultQueryPlanner::new();

        let mut query = select_from("users");
        query.predicates.push(age_over_18());
        query.predicates.push(Predicate {
            column: "status".to_string(),
            operator: PredicateOperator::Equal,
            value: PredicateValue::String("active".to_string()),
        });

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        // base_scan_cost + 2 * predicate_cost_multiplier
        assert_cost_eq(plan.estimated_cost, 1.2);

        match &plan.nodes[0] {
            PlanNode::TableScan { predicates, .. } => {
                assert_eq!(predicates.len(), 2);
                assert_eq!(predicates[0].column, "age");
                assert_eq!(predicates[1].column, "status");
            }
            _ => panic!("Expected TableScan node"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_with_limit() {
        let planner = DefaultQueryPlanner::new();

        let mut query = select_from("users");
        query.limit = Some(10);

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        assert_eq!(plan.nodes.len(), 2); // TableScan + Limit
        assert_cost_eq(plan.estimated_cost, 1.1); // base_scan_cost + limit_cost

        match &plan.nodes[1] {
            PlanNode::Limit { count, .. } => {
                assert_eq!(*count, 10);
            }
            _ => panic!("Expected Limit node"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_with_ordering() {
        let planner = DefaultQueryPlanner::new();

        let mut query = select_from("users");
        query.ordering = Some(single_column_order("name", OrderDirection::Ascending));

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        assert_eq!(plan.nodes.len(), 2); // TableScan + Sort
        assert_cost_eq(plan.estimated_cost, 1.5); // base_scan_cost + sort_cost

        match &plan.nodes[1] {
            PlanNode::Sort { order_by, .. } => {
                assert_eq!(order_by.columns.len(), 1);
                assert_eq!(order_by.columns[0].column, "name");
            }
            _ => panic!("Expected Sort node"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_with_ordering_and_limit() {
        let planner = DefaultQueryPlanner::new();

        let mut query = select_from("users");
        query.ordering = Some(single_column_order("created_at", OrderDirection::Descending));
        query.limit = Some(5);

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        assert_eq!(plan.nodes.len(), 3); // TableScan + Sort + Limit
        // base_scan_cost + sort_cost + limit_cost
        assert_cost_eq(plan.estimated_cost, 1.6);

        // Sort should come before Limit
        match &plan.nodes[1] {
            PlanNode::Sort { .. } => {}
            _ => panic!("Expected Sort node at position 1"),
        }

        match &plan.nodes[2] {
            PlanNode::Limit { count, .. } => {
                assert_eq!(*count, 5);
            }
            _ => panic!("Expected Limit node at position 2"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_estimate_cost() {
        let planner = DefaultQueryPlanner::new();

        let mut query = select_from("users");
        query.predicates.push(age_over_18());
        query.ordering = Some(single_column_order("name", OrderDirection::Ascending));
        query.limit = Some(10);

        let result = planner.estimate_cost(&query).await;
        assert!(result.is_ok());

        let cost = result.unwrap();
        // base_scan_cost (1.0) + one predicate (0.1) + sort_cost (0.5) + limit_cost (0.1)
        assert_cost_eq(cost, 1.7);
    }

    #[tokio::test]
    async fn test_query_planner_optimize_plan() {
        let planner = DefaultQueryPlanner::new();

        let plan = ExecutionPlan {
            nodes: vec![PlanNode::TableScan {
                source: mock_source("users", None),
                projections: vec![],
                predicates: vec![],
            }],
            estimated_cost: 1.0,
        };

        let result = planner.optimize_plan(plan.clone()).await;
        assert!(result.is_ok());

        let optimized_plan = result.unwrap();
        assert_eq!(optimized_plan.nodes.len(), plan.nodes.len());
        // The plan is passed through untouched, so exact equality is expected.
        assert_eq!(optimized_plan.estimated_cost, plan.estimated_cost);
    }
}