// nirv_engine/engine/query_executor.rs

1use async_trait::async_trait;
2use std::time::{Duration, Instant};
3use crate::{
4    engine::{ExecutionPlan, PlanNode},
5    connectors::ConnectorRegistry,
6    utils::{
7        types::{QueryResult, Row, Value, ColumnMetadata, DataType, InternalQuery, QueryOperation, ConnectorQuery},
8        error::{NirvResult, NirvError},
9    },
10};
11
12/// Trait for query execution functionality
13#[async_trait]
14pub trait QueryExecutor: Send + Sync {
15    /// Execute an execution plan and return results
16    async fn execute_plan(&self, plan: &ExecutionPlan) -> NirvResult<QueryResult>;
17    
18    /// Execute a single plan node
19    async fn execute_node(&self, node: &PlanNode) -> NirvResult<QueryResult>;
20    
21    /// Set the connector registry for accessing data sources
22    fn set_connector_registry(&mut self, registry: ConnectorRegistry);
23}
24
25/// Default implementation of QueryExecutor
26pub struct DefaultQueryExecutor {
27    /// Registry of available connectors
28    connector_registry: Option<ConnectorRegistry>,
29}
30
31impl DefaultQueryExecutor {
32    /// Create a new query executor
33    pub fn new() -> Self {
34        Self {
35            connector_registry: None,
36        }
37    }
38    
39    /// Create a query executor with a connector registry
40    pub fn with_connector_registry(registry: ConnectorRegistry) -> Self {
41        Self {
42            connector_registry: Some(registry),
43        }
44    }
45    
46    /// Get a reference to the connector registry
47    fn get_connector_registry(&self) -> NirvResult<&ConnectorRegistry> {
48        self.connector_registry.as_ref().ok_or_else(|| {
49            NirvError::Internal("No connector registry configured".to_string())
50        })
51    }
52    
53    /// Execute a table scan operation
54    async fn execute_table_scan(
55        &self,
56        source: &crate::utils::types::DataSource,
57        projections: &[crate::utils::types::Column],
58        predicates: &[crate::utils::types::Predicate],
59    ) -> NirvResult<QueryResult> {
60        let registry = self.get_connector_registry()?;
61        
62        // Try different naming patterns to find the connector
63        let possible_names = vec![
64            source.object_type.clone(),
65            format!("{}_{}", source.object_type, 0),
66            format!("{}_connector", source.object_type),
67        ];
68        
69        let mut connector = None;
70        for name in &possible_names {
71            if let Some(c) = registry.get(name) {
72                connector = Some(c);
73                break;
74            }
75        }
76        
77        let connector = connector.ok_or_else(|| {
78            NirvError::Internal(format!("No connector found for type: {}", source.object_type))
79        })?;
80        
81        // Create a connector query
82        let mut internal_query = InternalQuery::new(QueryOperation::Select);
83        internal_query.sources.push(source.clone());
84        internal_query.projections = projections.to_vec();
85        internal_query.predicates = predicates.to_vec();
86        
87        let connector_query = ConnectorQuery {
88            connector_type: connector.get_connector_type(),
89            query: internal_query,
90            connection_params: std::collections::HashMap::new(),
91        };
92        
93        // Execute the query through the connector
94        connector.execute_query(connector_query).await
95    }
96    
97    /// Apply a limit to query results
98    fn apply_limit(&self, mut result: QueryResult, count: u64) -> QueryResult {
99        let limit = count as usize;
100        if result.rows.len() > limit {
101            result.rows.truncate(limit);
102        }
103        result
104    }
105    
106    /// Apply sorting to query results
107    fn apply_sort(&self, mut result: QueryResult, order_by: &crate::utils::types::OrderBy) -> NirvResult<QueryResult> {
108        if order_by.columns.is_empty() {
109            return Ok(result);
110        }
111        
112        // For MVP, we'll implement simple single-column sorting
113        let sort_column = &order_by.columns[0];
114        
115        // Find the column index
116        let column_index = result.columns.iter()
117            .position(|col| col.name == sort_column.column)
118            .ok_or_else(|| {
119                NirvError::Internal(format!("Sort column '{}' not found in result", sort_column.column))
120            })?;
121        
122        // Sort the rows based on the column value
123        result.rows.sort_by(|a, b| {
124            let val_a = a.get(column_index).unwrap_or(&Value::Null);
125            let val_b = b.get(column_index).unwrap_or(&Value::Null);
126            
127            let comparison = self.compare_values(val_a, val_b);
128            
129            match sort_column.direction {
130                crate::utils::types::OrderDirection::Ascending => comparison,
131                crate::utils::types::OrderDirection::Descending => comparison.reverse(),
132            }
133        });
134        
135        Ok(result)
136    }
137    
138    /// Compare two values for sorting
139    fn compare_values(&self, a: &Value, b: &Value) -> std::cmp::Ordering {
140        use std::cmp::Ordering;
141        
142        match (a, b) {
143            (Value::Null, Value::Null) => Ordering::Equal,
144            (Value::Null, _) => Ordering::Less,
145            (_, Value::Null) => Ordering::Greater,
146            (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
147            (Value::Float(a), Value::Float(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
148            (Value::Text(a), Value::Text(b)) => a.cmp(b),
149            (Value::Boolean(a), Value::Boolean(b)) => a.cmp(b),
150            (Value::Date(a), Value::Date(b)) => a.cmp(b),
151            (Value::DateTime(a), Value::DateTime(b)) => a.cmp(b),
152            // For mixed types, convert to string and compare
153            _ => format!("{:?}", a).cmp(&format!("{:?}", b)),
154        }
155    }
156    
157    /// Apply projection to query results
158    fn apply_projection(&self, result: QueryResult, columns: &[crate::utils::types::Column]) -> NirvResult<QueryResult> {
159        if columns.is_empty() {
160            return Ok(result);
161        }
162        
163        // For MVP, we'll assume projections are already handled in the table scan
164        // This is a placeholder for future enhancement
165        Ok(result)
166    }
167    
168    /// Aggregate results from multiple operations
169    fn aggregate_results(&self, results: Vec<QueryResult>) -> NirvResult<QueryResult> {
170        if results.is_empty() {
171            return Ok(QueryResult::new());
172        }
173        
174        if results.len() == 1 {
175            return Ok(results.into_iter().next().unwrap());
176        }
177        
178        // For MVP, we don't support complex aggregation
179        // Just return the first result
180        Ok(results.into_iter().next().unwrap())
181    }
182    
183    /// Format the final query result
184    fn format_result(&self, mut result: QueryResult, execution_time: Duration) -> QueryResult {
185        result.execution_time = execution_time;
186        
187        // Ensure we have proper column metadata if missing
188        if result.columns.is_empty() && !result.rows.is_empty() {
189            let first_row = &result.rows[0];
190            for (i, value) in first_row.values.iter().enumerate() {
191                let data_type = match value {
192                    Value::Integer(_) => DataType::Integer,
193                    Value::Float(_) => DataType::Float,
194                    Value::Text(_) => DataType::Text,
195                    Value::Boolean(_) => DataType::Boolean,
196                    Value::Date(_) => DataType::Date,
197                    Value::DateTime(_) => DataType::DateTime,
198                    Value::Json(_) => DataType::Json,
199                    Value::Binary(_) => DataType::Binary,
200                    Value::Null => DataType::Text, // Default for null values
201                };
202                
203                result.columns.push(ColumnMetadata {
204                    name: format!("column_{}", i),
205                    data_type,
206                    nullable: true,
207                });
208            }
209        }
210        
211        result
212    }
213}
214
215impl Default for DefaultQueryExecutor {
216    fn default() -> Self {
217        Self::new()
218    }
219}
220
221#[async_trait]
222impl QueryExecutor for DefaultQueryExecutor {
223    async fn execute_plan(&self, plan: &ExecutionPlan) -> NirvResult<QueryResult> {
224        let start_time = Instant::now();
225        
226        if plan.is_empty() {
227            let execution_time = start_time.elapsed();
228            return Ok(self.format_result(QueryResult::new(), execution_time));
229        }
230        
231        // Execute the root node (last node in the plan)
232        // The root node will recursively execute its dependencies
233        let root_node = plan.root_node().ok_or_else(|| {
234            NirvError::Internal("No root node found in execution plan".to_string())
235        })?;
236        
237        let final_result = self.execute_node(root_node).await?;
238        
239        let execution_time = start_time.elapsed();
240        Ok(self.format_result(final_result, execution_time))
241    }
242    
243    async fn execute_node(&self, node: &PlanNode) -> NirvResult<QueryResult> {
244        match node {
245            PlanNode::TableScan { source, projections, predicates } => {
246                self.execute_table_scan(source, projections, predicates).await
247            }
248            PlanNode::Limit { count, input } => {
249                let input_result = self.execute_node(input).await?;
250                Ok(self.apply_limit(input_result, *count))
251            }
252            PlanNode::Sort { order_by, input } => {
253                let input_result = self.execute_node(input).await?;
254                self.apply_sort(input_result, order_by)
255            }
256            PlanNode::Projection { columns, input } => {
257                let input_result = self.execute_node(input).await?;
258                self.apply_projection(input_result, columns)
259            }
260        }
261    }
262    
263    fn set_connector_registry(&mut self, registry: ConnectorRegistry) {
264        self.connector_registry = Some(registry);
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        engine::{ExecutionPlan, PlanNode},
        connectors::ConnectorRegistry,
        utils::types::{DataSource, OrderBy, OrderColumn, OrderDirection},
    };

    #[test]
    fn test_default_query_executor_creation() {
        let executor = DefaultQueryExecutor::new();

        // Should not have a connector registry initially
        assert!(executor.get_connector_registry().is_err());
    }

    #[test]
    fn test_query_executor_with_connector_registry() {
        let registry = ConnectorRegistry::new();
        let executor = DefaultQueryExecutor::with_connector_registry(registry);

        // Should have a connector registry
        assert!(executor.get_connector_registry().is_ok());
    }

    #[test]
    fn test_query_executor_set_connector_registry() {
        let mut executor = DefaultQueryExecutor::new();
        let registry = ConnectorRegistry::new();

        executor.set_connector_registry(registry);

        // Should now have a connector registry
        assert!(executor.get_connector_registry().is_ok());
    }

    #[tokio::test]
    async fn test_query_executor_empty_plan() {
        let executor = DefaultQueryExecutor::new();
        let plan = ExecutionPlan::new();

        let outer_start = Instant::now();
        let result = executor.execute_plan(&plan).await;
        let outer_elapsed = outer_start.elapsed();
        assert!(result.is_ok());

        let query_result = result.unwrap();
        assert!(query_result.is_empty());
        // An empty plan can finish within a single tick of a coarse clock, so
        // asserting `> 0` is flaky. Instead, check that the recorded time is
        // bounded by the time measured around the call (Instant is monotonic).
        assert!(query_result.execution_time <= outer_elapsed);
    }

    #[tokio::test]
    async fn test_query_executor_no_connector_registry() {
        let executor = DefaultQueryExecutor::new();

        let plan = ExecutionPlan {
            nodes: vec![
                PlanNode::TableScan {
                    source: DataSource {
                        object_type: "mock".to_string(),
                        identifier: "test".to_string(),
                        alias: None,
                    },
                    projections: vec![],
                    predicates: vec![],
                }
            ],
            estimated_cost: 1.0,
        };

        let result = executor.execute_plan(&plan).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Internal(msg) => {
                assert!(msg.contains("No connector registry"));
            }
            _ => panic!("Expected Internal error"),
        }
    }

    #[test]
    fn test_apply_limit() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.rows = vec![
            Row::new(vec![Value::Integer(1)]),
            Row::new(vec![Value::Integer(2)]),
            Row::new(vec![Value::Integer(3)]),
            Row::new(vec![Value::Integer(4)]),
            Row::new(vec![Value::Integer(5)]),
        ];

        let limited_result = executor.apply_limit(result, 3);
        assert_eq!(limited_result.row_count(), 3);

        // Check that the first 3 rows are preserved
        assert_eq!(limited_result.rows[0].get(0), Some(&Value::Integer(1)));
        assert_eq!(limited_result.rows[1].get(0), Some(&Value::Integer(2)));
        assert_eq!(limited_result.rows[2].get(0), Some(&Value::Integer(3)));
    }

    #[test]
    fn test_apply_limit_no_truncation() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.rows = vec![
            Row::new(vec![Value::Integer(1)]),
            Row::new(vec![Value::Integer(2)]),
        ];

        let limited_result = executor.apply_limit(result, 5);
        assert_eq!(limited_result.row_count(), 2); // No truncation needed
    }

    #[test]
    fn test_apply_limit_zero() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.rows = vec![
            Row::new(vec![Value::Integer(1)]),
            Row::new(vec![Value::Integer(2)]),
        ];

        let limited_result = executor.apply_limit(result, 0);
        assert_eq!(limited_result.row_count(), 0);
    }

    #[test]
    fn test_compare_values() {
        let executor = DefaultQueryExecutor::new();

        // Test integer comparison
        assert_eq!(
            executor.compare_values(&Value::Integer(1), &Value::Integer(2)),
            std::cmp::Ordering::Less
        );

        // Test string comparison
        assert_eq!(
            executor.compare_values(&Value::Text("apple".to_string()), &Value::Text("banana".to_string())),
            std::cmp::Ordering::Less
        );

        // Test null comparison
        assert_eq!(
            executor.compare_values(&Value::Null, &Value::Integer(1)),
            std::cmp::Ordering::Less
        );

        // Test equal values
        assert_eq!(
            executor.compare_values(&Value::Integer(5), &Value::Integer(5)),
            std::cmp::Ordering::Equal
        );
    }

    #[test]
    fn test_apply_sort_ascending() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.columns = vec![
            ColumnMetadata {
                name: "value".to_string(),
                data_type: DataType::Integer,
                nullable: false,
            }
        ];
        result.rows = vec![
            Row::new(vec![Value::Integer(3)]),
            Row::new(vec![Value::Integer(1)]),
            Row::new(vec![Value::Integer(2)]),
        ];

        let order_by = OrderBy {
            columns: vec![OrderColumn {
                column: "value".to_string(),
                direction: OrderDirection::Ascending,
            }],
        };

        let sorted_result = executor.apply_sort(result, &order_by).unwrap();

        assert_eq!(sorted_result.rows[0].get(0), Some(&Value::Integer(1)));
        assert_eq!(sorted_result.rows[1].get(0), Some(&Value::Integer(2)));
        assert_eq!(sorted_result.rows[2].get(0), Some(&Value::Integer(3)));
    }

    #[test]
    fn test_apply_sort_descending() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.columns = vec![
            ColumnMetadata {
                name: "name".to_string(),
                data_type: DataType::Text,
                nullable: false,
            }
        ];
        result.rows = vec![
            Row::new(vec![Value::Text("Alice".to_string())]),
            Row::new(vec![Value::Text("Charlie".to_string())]),
            Row::new(vec![Value::Text("Bob".to_string())]),
        ];

        let order_by = OrderBy {
            columns: vec![OrderColumn {
                column: "name".to_string(),
                direction: OrderDirection::Descending,
            }],
        };

        let sorted_result = executor.apply_sort(result, &order_by).unwrap();

        assert_eq!(sorted_result.rows[0].get(0), Some(&Value::Text("Charlie".to_string())));
        assert_eq!(sorted_result.rows[1].get(0), Some(&Value::Text("Bob".to_string())));
        assert_eq!(sorted_result.rows[2].get(0), Some(&Value::Text("Alice".to_string())));
    }

    #[test]
    fn test_apply_sort_nonexistent_column() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.columns = vec![
            ColumnMetadata {
                name: "value".to_string(),
                data_type: DataType::Integer,
                nullable: false,
            }
        ];
        result.rows = vec![Row::new(vec![Value::Integer(1)])];

        let order_by = OrderBy {
            columns: vec![OrderColumn {
                column: "nonexistent".to_string(),
                direction: OrderDirection::Ascending,
            }],
        };

        let result = executor.apply_sort(result, &order_by);
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Internal(msg) => {
                assert!(msg.contains("Sort column 'nonexistent' not found"));
            }
            _ => panic!("Expected Internal error"),
        }
    }

    #[test]
    fn test_format_result() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.rows = vec![
            Row::new(vec![Value::Integer(1), Value::Text("Alice".to_string())]),
            Row::new(vec![Value::Integer(2), Value::Text("Bob".to_string())]),
        ];

        let execution_time = Duration::from_millis(100);
        let formatted_result = executor.format_result(result, execution_time);

        assert_eq!(formatted_result.execution_time, execution_time);
        assert_eq!(formatted_result.columns.len(), 2);
        assert_eq!(formatted_result.columns[0].name, "column_0");
        assert_eq!(formatted_result.columns[0].data_type, DataType::Integer);
        assert_eq!(formatted_result.columns[1].name, "column_1");
        assert_eq!(formatted_result.columns[1].data_type, DataType::Text);
    }

    #[test]
    fn test_format_result_with_existing_columns() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.columns = vec![
            ColumnMetadata {
                name: "id".to_string(),
                data_type: DataType::Integer,
                nullable: false,
            }
        ];
        result.rows = vec![Row::new(vec![Value::Integer(1)])];

        let execution_time = Duration::from_millis(50);
        let formatted_result = executor.format_result(result, execution_time);

        assert_eq!(formatted_result.execution_time, execution_time);
        assert_eq!(formatted_result.columns.len(), 1);
        assert_eq!(formatted_result.columns[0].name, "id");
    }

    #[test]
    fn test_aggregate_results_empty() {
        let executor = DefaultQueryExecutor::new();

        let result = executor.aggregate_results(vec![]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_aggregate_results_single() {
        let executor = DefaultQueryExecutor::new();

        let mut query_result = QueryResult::new();
        query_result.rows = vec![Row::new(vec![Value::Integer(1)])];

        let result = executor.aggregate_results(vec![query_result]).unwrap();
        assert_eq!(result.row_count(), 1);
    }

    #[test]
    fn test_aggregate_results_multiple() {
        let executor = DefaultQueryExecutor::new();

        let mut result1 = QueryResult::new();
        result1.rows = vec![Row::new(vec![Value::Integer(1)])];

        let mut result2 = QueryResult::new();
        result2.rows = vec![Row::new(vec![Value::Integer(2)])];

        // For MVP, should return the first result
        let result = executor.aggregate_results(vec![result1, result2]).unwrap();
        assert_eq!(result.row_count(), 1);
        assert_eq!(result.rows[0].get(0), Some(&Value::Integer(1)));
    }
}