1use async_trait::async_trait;
2use crate::utils::{
3 types::{InternalQuery, DataSource, Column, Predicate, OrderBy},
4 error::{NirvResult, NirvError},
5};
6
/// A single operator in a query execution plan.
///
/// Every non-leaf variant owns its input as a boxed child, so any node can be
/// read as the root of a self-contained operator tree.
#[derive(Debug, Clone)]
pub enum PlanNode {
    /// Leaf operator that reads from one data source.
    TableScan {
        /// The data source being scanned.
        source: DataSource,
        /// Columns to produce. `DefaultQueryPlanner` fills in a single `*`
        /// column when the query lists no projections.
        projections: Vec<Column>,
        /// Predicates attached to the scan (copied from the query's predicates
        /// by `DefaultQueryPlanner`).
        predicates: Vec<Predicate>,
    },
    /// Limit operator over `input`.
    Limit {
        /// Row limit, taken from the query's `limit`.
        count: u64,
        /// The operator whose output is limited.
        input: Box<PlanNode>,
    },
    /// Sort operator over `input`.
    Sort {
        /// Ordering specification, taken from the query's `ordering`.
        order_by: OrderBy,
        /// The operator whose output is sorted.
        input: Box<PlanNode>,
    },
    /// Projection operator over `input`.
    /// NOTE(review): `DefaultQueryPlanner` never emits this variant; projections
    /// are carried on `TableScan` instead.
    Projection {
        /// Columns to keep.
        columns: Vec<Column>,
        /// The operator whose output is projected.
        input: Box<PlanNode>,
    },
}
32
/// An execution plan: a list of operator nodes plus its estimated cost.
#[derive(Debug, Clone)]
pub struct ExecutionPlan {
    /// Nodes in insertion order; the last entry is the root (see `root_node`).
    /// With `DefaultQueryPlanner`, each later node also holds a boxed copy of
    /// the node before it as its `input`.
    pub nodes: Vec<PlanNode>,
    /// Estimated execution cost, as computed by the planner's cost model
    /// (a unitless number).
    pub estimated_cost: f64,
}
39
40impl ExecutionPlan {
41 pub fn new() -> Self {
43 Self {
44 nodes: Vec::new(),
45 estimated_cost: 0.0,
46 }
47 }
48
49 pub fn add_node(&mut self, node: PlanNode) {
51 self.nodes.push(node);
52 }
53
54 pub fn set_estimated_cost(&mut self, cost: f64) {
56 self.estimated_cost = cost;
57 }
58
59 pub fn root_node(&self) -> Option<&PlanNode> {
61 self.nodes.last()
62 }
63
64 pub fn is_empty(&self) -> bool {
66 self.nodes.is_empty()
67 }
68}
69
70impl Default for ExecutionPlan {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
/// Turns an `InternalQuery` into an `ExecutionPlan`.
///
/// Implementors must be `Send + Sync`, so a planner can be shared across
/// threads and used from async tasks.
#[async_trait]
pub trait QueryPlanner: Send + Sync {
    /// Builds an execution plan for `query`.
    ///
    /// # Errors
    /// Returns an error if the query cannot be planned by this implementation.
    async fn create_execution_plan(&self, query: &InternalQuery) -> NirvResult<ExecutionPlan>;

    /// Returns the estimated cost of executing `query`.
    ///
    /// # Errors
    /// Returns an error if the query cannot be planned by this implementation.
    async fn estimate_cost(&self, query: &InternalQuery) -> NirvResult<f64>;

    /// Returns a (possibly rewritten) version of `plan`.
    /// `DefaultQueryPlanner` returns the plan unchanged.
    async fn optimize_plan(&self, plan: ExecutionPlan) -> NirvResult<ExecutionPlan>;
}
88
/// Additive cost-model planner for single-source queries.
///
/// Estimated cost = `base_scan_cost`
/// + `predicate_cost_multiplier` × number of predicates
/// + `sort_cost` if the query has an ordering
/// + `limit_cost` if the query has a limit.
///
/// `Debug`, `Clone` and `PartialEq` are derived so planner configurations can be
/// logged, copied and compared (all fields are plain `f64`s).
#[derive(Debug, Clone, PartialEq)]
pub struct DefaultQueryPlanner {
    /// Fixed cost charged for scanning the single data source.
    base_scan_cost: f64,
    /// Cost added for each predicate in the query.
    predicate_cost_multiplier: f64,
    /// Cost added when the query has an ordering.
    sort_cost: f64,
    /// Cost added when the query has a limit.
    limit_cost: f64,
}
100
101impl DefaultQueryPlanner {
102 pub fn new() -> Self {
104 Self {
105 base_scan_cost: 1.0,
106 predicate_cost_multiplier: 0.1,
107 sort_cost: 0.5,
108 limit_cost: 0.1,
109 }
110 }
111
112 pub fn with_costs(
114 base_scan_cost: f64,
115 predicate_cost_multiplier: f64,
116 sort_cost: f64,
117 limit_cost: f64,
118 ) -> Self {
119 Self {
120 base_scan_cost,
121 predicate_cost_multiplier,
122 sort_cost,
123 limit_cost,
124 }
125 }
126
127 fn validate_query(&self, query: &InternalQuery) -> NirvResult<()> {
129 if query.sources.is_empty() {
130 return Err(NirvError::Internal(
131 "No data sources found in query".to_string()
132 ));
133 }
134
135 if query.sources.len() > 1 {
137 return Err(NirvError::Internal(
138 "Multi-source queries not supported in MVP".to_string()
139 ));
140 }
141
142 Ok(())
143 }
144
145 fn create_table_scan_node(&self, query: &InternalQuery) -> PlanNode {
147 let source = query.sources[0].clone();
148 let projections = if query.projections.is_empty() {
149 vec![Column {
151 name: "*".to_string(),
152 alias: None,
153 source: source.alias.clone(),
154 }]
155 } else {
156 query.projections.clone()
157 };
158
159 PlanNode::TableScan {
160 source,
161 projections,
162 predicates: query.predicates.clone(),
163 }
164 }
165
166 fn add_limit_node(&self, mut plan: ExecutionPlan, query: &InternalQuery) -> ExecutionPlan {
168 if let Some(limit) = query.limit {
169 if let Some(last_node) = plan.nodes.last() {
170 let limit_node = PlanNode::Limit {
171 count: limit,
172 input: Box::new(last_node.clone()),
173 };
174 plan.add_node(limit_node);
175 plan.estimated_cost += self.limit_cost;
176 }
177 }
178 plan
179 }
180
181 fn add_sort_node(&self, mut plan: ExecutionPlan, query: &InternalQuery) -> ExecutionPlan {
183 if let Some(order_by) = &query.ordering {
184 if let Some(last_node) = plan.nodes.last() {
185 let sort_node = PlanNode::Sort {
186 order_by: order_by.clone(),
187 input: Box::new(last_node.clone()),
188 };
189 plan.add_node(sort_node);
190 plan.estimated_cost += self.sort_cost;
191 }
192 }
193 plan
194 }
195
196 fn calculate_cost(&self, query: &InternalQuery) -> f64 {
198 let mut cost = self.base_scan_cost;
199
200 cost += query.predicates.len() as f64 * self.predicate_cost_multiplier;
202
203 if query.ordering.is_some() {
205 cost += self.sort_cost;
206 }
207
208 if query.limit.is_some() {
210 cost += self.limit_cost;
211 }
212
213 cost
214 }
215}
216
217impl Default for DefaultQueryPlanner {
218 fn default() -> Self {
219 Self::new()
220 }
221}
222
223#[async_trait]
224impl QueryPlanner for DefaultQueryPlanner {
225 async fn create_execution_plan(&self, query: &InternalQuery) -> NirvResult<ExecutionPlan> {
226 self.validate_query(query)?;
228
229 let mut plan = ExecutionPlan::new();
230
231 let table_scan = self.create_table_scan_node(query);
233 plan.add_node(table_scan);
234
235 plan.estimated_cost = self.calculate_cost(query);
237
238 plan = self.add_sort_node(plan, query);
240
241 plan = self.add_limit_node(plan, query);
243
244 Ok(plan)
245 }
246
247 async fn estimate_cost(&self, query: &InternalQuery) -> NirvResult<f64> {
248 self.validate_query(query)?;
249 Ok(self.calculate_cost(query))
250 }
251
252 async fn optimize_plan(&self, plan: ExecutionPlan) -> NirvResult<ExecutionPlan> {
253 Ok(plan)
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::types::{
        OrderColumn, OrderDirection, PredicateOperator, PredicateValue, QueryOperation,
    };

    /// Tolerance for cost comparisons: costs are sums of decimal constants
    /// (0.1, 0.5, ...) that are not exactly representable as `f64`.
    const COST_EPSILON: f64 = 1e-9;

    /// Asserts that `actual` is within `COST_EPSILON` of `expected`.
    fn assert_cost(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < COST_EPSILON,
            "expected cost {}, got {}",
            expected,
            actual
        );
    }

    /// Builds a `mock` data source with the given identifier and optional alias.
    fn mock_source(identifier: &str, alias: Option<&str>) -> DataSource {
        DataSource {
            object_type: "mock".to_string(),
            identifier: identifier.to_string(),
            alias: alias.map(str::to_string),
        }
    }

    /// Builds a `SELECT` query over a single un-aliased mock source.
    fn select_from(identifier: &str) -> InternalQuery {
        let mut query = InternalQuery::new(QueryOperation::Select);
        query.sources.push(mock_source(identifier, None));
        query
    }

    /// Builds a single-column ordering.
    fn order_by(column: &str, direction: OrderDirection) -> OrderBy {
        OrderBy {
            columns: vec![OrderColumn {
                column: column.to_string(),
                direction,
            }],
        }
    }

    #[test]
    fn test_execution_plan_creation() {
        let mut plan = ExecutionPlan::new();

        assert!(plan.is_empty());
        assert_cost(plan.estimated_cost, 0.0);
        assert!(plan.root_node().is_none());

        plan.add_node(PlanNode::TableScan {
            source: mock_source("test", None),
            projections: vec![],
            predicates: vec![],
        });
        plan.set_estimated_cost(1.5);

        assert!(!plan.is_empty());
        assert_cost(plan.estimated_cost, 1.5);
        assert!(plan.root_node().is_some());
    }

    #[test]
    fn test_default_query_planner_creation() {
        let planner = DefaultQueryPlanner::new();

        assert_eq!(planner.base_scan_cost, 1.0);
        assert_eq!(planner.predicate_cost_multiplier, 0.1);
        assert_eq!(planner.sort_cost, 0.5);
        assert_eq!(planner.limit_cost, 0.1);
    }

    #[test]
    fn test_query_planner_with_custom_costs() {
        let planner = DefaultQueryPlanner::with_costs(2.0, 0.2, 1.0, 0.2);

        assert_eq!(planner.base_scan_cost, 2.0);
        assert_eq!(planner.predicate_cost_multiplier, 0.2);
        assert_eq!(planner.sort_cost, 1.0);
        assert_eq!(planner.limit_cost, 0.2);
    }

    #[tokio::test]
    async fn test_query_planner_validate_empty_query() {
        let planner = DefaultQueryPlanner::new();
        let query = InternalQuery::new(QueryOperation::Select);

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Internal(msg) => {
                assert!(msg.contains("No data sources"));
            }
            _ => panic!("Expected Internal error"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_validate_multi_source_query() {
        let planner = DefaultQueryPlanner::new();

        let mut query = InternalQuery::new(QueryOperation::Select);
        query.sources.push(mock_source("table1", None));
        query.sources.push(mock_source("table2", None));

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Internal(msg) => {
                assert!(msg.contains("Multi-source queries not supported"));
            }
            _ => panic!("Expected Internal error"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_simple_select() {
        let planner = DefaultQueryPlanner::new();
        let query = select_from("users");

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        assert_eq!(plan.nodes.len(), 1);
        // Base scan cost only.
        assert_cost(plan.estimated_cost, 1.0);

        match &plan.nodes[0] {
            PlanNode::TableScan { source, projections, predicates } => {
                assert_eq!(source.object_type, "mock");
                assert_eq!(source.identifier, "users");
                assert_eq!(projections.len(), 1);
                assert_eq!(projections[0].name, "*");
                assert!(predicates.is_empty());
            }
            _ => panic!("Expected TableScan node"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_with_projections() {
        let planner = DefaultQueryPlanner::new();

        let mut query = InternalQuery::new(QueryOperation::Select);
        query.sources.push(mock_source("users", Some("u")));
        query.projections.push(Column {
            name: "name".to_string(),
            alias: Some("user_name".to_string()),
            source: Some("u".to_string()),
        });
        query.projections.push(Column {
            name: "email".to_string(),
            alias: None,
            source: Some("u".to_string()),
        });

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        match &plan.nodes[0] {
            PlanNode::TableScan { projections, .. } => {
                assert_eq!(projections.len(), 2);
                assert_eq!(projections[0].name, "name");
                assert_eq!(projections[0].alias, Some("user_name".to_string()));
                assert_eq!(projections[1].name, "email");
                assert_eq!(projections[1].alias, None);
            }
            _ => panic!("Expected TableScan node"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_with_predicates() {
        let planner = DefaultQueryPlanner::new();

        let mut query = select_from("users");
        query.predicates.push(Predicate {
            column: "age".to_string(),
            operator: PredicateOperator::GreaterThan,
            value: PredicateValue::Integer(18),
        });
        query.predicates.push(Predicate {
            column: "status".to_string(),
            operator: PredicateOperator::Equal,
            value: PredicateValue::String("active".to_string()),
        });

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        // 1.0 base + 2 predicates * 0.1
        assert_cost(plan.estimated_cost, 1.2);

        match &plan.nodes[0] {
            PlanNode::TableScan { predicates, .. } => {
                assert_eq!(predicates.len(), 2);
                assert_eq!(predicates[0].column, "age");
                assert_eq!(predicates[1].column, "status");
            }
            _ => panic!("Expected TableScan node"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_with_limit() {
        let planner = DefaultQueryPlanner::new();

        let mut query = select_from("users");
        query.limit = Some(10);

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        // TableScan + Limit
        assert_eq!(plan.nodes.len(), 2);
        // 1.0 base + 0.1 limit, counted exactly once.
        assert_cost(plan.estimated_cost, 1.1);

        match &plan.nodes[1] {
            PlanNode::Limit { count, input } => {
                assert_eq!(*count, 10);
                assert!(matches!(**input, PlanNode::TableScan { .. }));
            }
            _ => panic!("Expected Limit node"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_with_ordering() {
        let planner = DefaultQueryPlanner::new();

        let mut query = select_from("users");
        query.ordering = Some(order_by("name", OrderDirection::Ascending));

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        // TableScan + Sort
        assert_eq!(plan.nodes.len(), 2);
        // 1.0 base + 0.5 sort, counted exactly once.
        assert_cost(plan.estimated_cost, 1.5);

        match &plan.nodes[1] {
            PlanNode::Sort { order_by, input } => {
                assert_eq!(order_by.columns.len(), 1);
                assert_eq!(order_by.columns[0].column, "name");
                assert!(matches!(**input, PlanNode::TableScan { .. }));
            }
            _ => panic!("Expected Sort node"),
        }
    }

    #[tokio::test]
    async fn test_query_planner_with_ordering_and_limit() {
        let planner = DefaultQueryPlanner::new();

        let mut query = select_from("users");
        query.ordering = Some(order_by("created_at", OrderDirection::Descending));
        query.limit = Some(5);

        let result = planner.create_execution_plan(&query).await;
        assert!(result.is_ok());

        let plan = result.unwrap();
        // TableScan + Sort + Limit
        assert_eq!(plan.nodes.len(), 3);
        // 1.0 base + 0.5 sort + 0.1 limit
        assert_cost(plan.estimated_cost, 1.6);

        match &plan.nodes[1] {
            PlanNode::Sort { .. } => {}
            _ => panic!("Expected Sort node at position 1"),
        }

        match &plan.nodes[2] {
            PlanNode::Limit { count, input } => {
                assert_eq!(*count, 5);
                // The limit is applied on top of the sorted output.
                assert!(matches!(**input, PlanNode::Sort { .. }));
            }
            _ => panic!("Expected Limit node at position 2"),
        }
    }

    /// Builds the query used by the cost-estimation tests:
    /// one predicate, an ordering and a limit.
    fn costed_query() -> InternalQuery {
        let mut query = select_from("users");
        query.predicates.push(Predicate {
            column: "age".to_string(),
            operator: PredicateOperator::GreaterThan,
            value: PredicateValue::Integer(18),
        });
        query.ordering = Some(order_by("name", OrderDirection::Ascending));
        query.limit = Some(10);
        query
    }

    #[tokio::test]
    async fn test_query_planner_estimate_cost() {
        let planner = DefaultQueryPlanner::new();
        let query = costed_query();

        let result = planner.estimate_cost(&query).await;
        assert!(result.is_ok());

        // 1.0 base + 1 predicate * 0.1 + 0.5 sort + 0.1 limit = 1.7
        assert_cost(result.unwrap(), 1.7);
    }

    #[tokio::test]
    async fn test_estimate_cost_matches_plan_cost() {
        let planner = DefaultQueryPlanner::new();
        let query = costed_query();

        let estimate = planner.estimate_cost(&query).await.unwrap();
        let plan = planner.create_execution_plan(&query).await.unwrap();

        // The plan must not charge Sort/Limit a second time on top of the estimate.
        assert_cost(plan.estimated_cost, estimate);
    }

    #[tokio::test]
    async fn test_query_planner_optimize_plan() {
        let planner = DefaultQueryPlanner::new();

        let plan = ExecutionPlan {
            nodes: vec![PlanNode::TableScan {
                source: mock_source("users", None),
                projections: vec![],
                predicates: vec![],
            }],
            estimated_cost: 1.0,
        };

        let result = planner.optimize_plan(plan.clone()).await;
        assert!(result.is_ok());

        let optimized_plan = result.unwrap();
        assert_eq!(optimized_plan.nodes.len(), plan.nodes.len());
        assert_cost(optimized_plan.estimated_cost, plan.estimated_cost);
    }
}