1use async_trait::async_trait;
2use std::time::{Duration, Instant};
3use crate::{
4 engine::{ExecutionPlan, PlanNode},
5 connectors::ConnectorRegistry,
6 utils::{
7 types::{QueryResult, Row, Value, ColumnMetadata, DataType, InternalQuery, QueryOperation, ConnectorQuery},
8 error::{NirvResult, NirvError},
9 },
10};
11
12#[async_trait]
14pub trait QueryExecutor: Send + Sync {
15 async fn execute_plan(&self, plan: &ExecutionPlan) -> NirvResult<QueryResult>;
17
18 async fn execute_node(&self, node: &PlanNode) -> NirvResult<QueryResult>;
20
21 fn set_connector_registry(&mut self, registry: ConnectorRegistry);
23}
24
25pub struct DefaultQueryExecutor {
27 connector_registry: Option<ConnectorRegistry>,
29}
30
31impl DefaultQueryExecutor {
32 pub fn new() -> Self {
34 Self {
35 connector_registry: None,
36 }
37 }
38
39 pub fn with_connector_registry(registry: ConnectorRegistry) -> Self {
41 Self {
42 connector_registry: Some(registry),
43 }
44 }
45
46 fn get_connector_registry(&self) -> NirvResult<&ConnectorRegistry> {
48 self.connector_registry.as_ref().ok_or_else(|| {
49 NirvError::Internal("No connector registry configured".to_string())
50 })
51 }
52
53 async fn execute_table_scan(
55 &self,
56 source: &crate::utils::types::DataSource,
57 projections: &[crate::utils::types::Column],
58 predicates: &[crate::utils::types::Predicate],
59 ) -> NirvResult<QueryResult> {
60 let registry = self.get_connector_registry()?;
61
62 let possible_names = vec![
64 source.object_type.clone(),
65 format!("{}_{}", source.object_type, 0),
66 format!("{}_connector", source.object_type),
67 ];
68
69 let mut connector = None;
70 for name in &possible_names {
71 if let Some(c) = registry.get(name) {
72 connector = Some(c);
73 break;
74 }
75 }
76
77 let connector = connector.ok_or_else(|| {
78 NirvError::Internal(format!("No connector found for type: {}", source.object_type))
79 })?;
80
81 let mut internal_query = InternalQuery::new(QueryOperation::Select);
83 internal_query.sources.push(source.clone());
84 internal_query.projections = projections.to_vec();
85 internal_query.predicates = predicates.to_vec();
86
87 let connector_query = ConnectorQuery {
88 connector_type: connector.get_connector_type(),
89 query: internal_query,
90 connection_params: std::collections::HashMap::new(),
91 };
92
93 connector.execute_query(connector_query).await
95 }
96
97 fn apply_limit(&self, mut result: QueryResult, count: u64) -> QueryResult {
99 let limit = count as usize;
100 if result.rows.len() > limit {
101 result.rows.truncate(limit);
102 }
103 result
104 }
105
106 fn apply_sort(&self, mut result: QueryResult, order_by: &crate::utils::types::OrderBy) -> NirvResult<QueryResult> {
108 if order_by.columns.is_empty() {
109 return Ok(result);
110 }
111
112 let sort_column = &order_by.columns[0];
114
115 let column_index = result.columns.iter()
117 .position(|col| col.name == sort_column.column)
118 .ok_or_else(|| {
119 NirvError::Internal(format!("Sort column '{}' not found in result", sort_column.column))
120 })?;
121
122 result.rows.sort_by(|a, b| {
124 let val_a = a.get(column_index).unwrap_or(&Value::Null);
125 let val_b = b.get(column_index).unwrap_or(&Value::Null);
126
127 let comparison = self.compare_values(val_a, val_b);
128
129 match sort_column.direction {
130 crate::utils::types::OrderDirection::Ascending => comparison,
131 crate::utils::types::OrderDirection::Descending => comparison.reverse(),
132 }
133 });
134
135 Ok(result)
136 }
137
138 fn compare_values(&self, a: &Value, b: &Value) -> std::cmp::Ordering {
140 use std::cmp::Ordering;
141
142 match (a, b) {
143 (Value::Null, Value::Null) => Ordering::Equal,
144 (Value::Null, _) => Ordering::Less,
145 (_, Value::Null) => Ordering::Greater,
146 (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
147 (Value::Float(a), Value::Float(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
148 (Value::Text(a), Value::Text(b)) => a.cmp(b),
149 (Value::Boolean(a), Value::Boolean(b)) => a.cmp(b),
150 (Value::Date(a), Value::Date(b)) => a.cmp(b),
151 (Value::DateTime(a), Value::DateTime(b)) => a.cmp(b),
152 _ => format!("{:?}", a).cmp(&format!("{:?}", b)),
154 }
155 }
156
157 fn apply_projection(&self, result: QueryResult, columns: &[crate::utils::types::Column]) -> NirvResult<QueryResult> {
159 if columns.is_empty() {
160 return Ok(result);
161 }
162
163 Ok(result)
166 }
167
168 fn aggregate_results(&self, results: Vec<QueryResult>) -> NirvResult<QueryResult> {
170 if results.is_empty() {
171 return Ok(QueryResult::new());
172 }
173
174 if results.len() == 1 {
175 return Ok(results.into_iter().next().unwrap());
176 }
177
178 Ok(results.into_iter().next().unwrap())
181 }
182
183 fn format_result(&self, mut result: QueryResult, execution_time: Duration) -> QueryResult {
185 result.execution_time = execution_time;
186
187 if result.columns.is_empty() && !result.rows.is_empty() {
189 let first_row = &result.rows[0];
190 for (i, value) in first_row.values.iter().enumerate() {
191 let data_type = match value {
192 Value::Integer(_) => DataType::Integer,
193 Value::Float(_) => DataType::Float,
194 Value::Text(_) => DataType::Text,
195 Value::Boolean(_) => DataType::Boolean,
196 Value::Date(_) => DataType::Date,
197 Value::DateTime(_) => DataType::DateTime,
198 Value::Json(_) => DataType::Json,
199 Value::Binary(_) => DataType::Binary,
200 Value::Null => DataType::Text, };
202
203 result.columns.push(ColumnMetadata {
204 name: format!("column_{}", i),
205 data_type,
206 nullable: true,
207 });
208 }
209 }
210
211 result
212 }
213}
214
215impl Default for DefaultQueryExecutor {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
221#[async_trait]
222impl QueryExecutor for DefaultQueryExecutor {
223 async fn execute_plan(&self, plan: &ExecutionPlan) -> NirvResult<QueryResult> {
224 let start_time = Instant::now();
225
226 if plan.is_empty() {
227 let execution_time = start_time.elapsed();
228 return Ok(self.format_result(QueryResult::new(), execution_time));
229 }
230
231 let root_node = plan.root_node().ok_or_else(|| {
234 NirvError::Internal("No root node found in execution plan".to_string())
235 })?;
236
237 let final_result = self.execute_node(root_node).await?;
238
239 let execution_time = start_time.elapsed();
240 Ok(self.format_result(final_result, execution_time))
241 }
242
243 async fn execute_node(&self, node: &PlanNode) -> NirvResult<QueryResult> {
244 match node {
245 PlanNode::TableScan { source, projections, predicates } => {
246 self.execute_table_scan(source, projections, predicates).await
247 }
248 PlanNode::Limit { count, input } => {
249 let input_result = self.execute_node(input).await?;
250 Ok(self.apply_limit(input_result, *count))
251 }
252 PlanNode::Sort { order_by, input } => {
253 let input_result = self.execute_node(input).await?;
254 self.apply_sort(input_result, order_by)
255 }
256 PlanNode::Projection { columns, input } => {
257 let input_result = self.execute_node(input).await?;
258 self.apply_projection(input_result, columns)
259 }
260 }
261 }
262
263 fn set_connector_registry(&mut self, registry: ConnectorRegistry) {
264 self.connector_registry = Some(registry);
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        connectors::ConnectorRegistry,
        engine::{ExecutionPlan, PlanNode},
        utils::types::{DataSource, OrderBy, OrderColumn, OrderDirection},
    };

    #[test]
    fn test_default_query_executor_creation() {
        let executor = DefaultQueryExecutor::new();

        assert!(executor.get_connector_registry().is_err());
    }

    #[test]
    fn test_query_executor_with_connector_registry() {
        let registry = ConnectorRegistry::new();
        let executor = DefaultQueryExecutor::with_connector_registry(registry);

        assert!(executor.get_connector_registry().is_ok());
    }

    #[test]
    fn test_query_executor_set_connector_registry() {
        let mut executor = DefaultQueryExecutor::new();
        let registry = ConnectorRegistry::new();

        executor.set_connector_registry(registry);

        assert!(executor.get_connector_registry().is_ok());
    }

    #[tokio::test]
    async fn test_query_executor_empty_plan() {
        let executor = DefaultQueryExecutor::new();
        let plan = ExecutionPlan::new();

        let result = executor.execute_plan(&plan).await;
        assert!(result.is_ok());

        let query_result = result.unwrap();
        assert!(query_result.is_empty());
        // No assertion on `execution_time`. An empty plan does almost no
        // work, and `Instant` resolution is platform-dependent, so the
        // measured duration can legitimately be zero. Asserting `> 0` is
        // flaky.
    }

    #[tokio::test]
    async fn test_query_executor_no_connector_registry() {
        let executor = DefaultQueryExecutor::new();

        let plan = ExecutionPlan {
            nodes: vec![PlanNode::TableScan {
                source: DataSource {
                    object_type: "mock".to_string(),
                    identifier: "test".to_string(),
                    alias: None,
                },
                projections: vec![],
                predicates: vec![],
            }],
            estimated_cost: 1.0,
        };

        let result = executor.execute_plan(&plan).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Internal(msg) => {
                assert!(msg.contains("No connector registry"));
            }
            _ => panic!("Expected Internal error"),
        }
    }

    #[test]
    fn test_apply_limit() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.rows = vec![
            Row::new(vec![Value::Integer(1)]),
            Row::new(vec![Value::Integer(2)]),
            Row::new(vec![Value::Integer(3)]),
            Row::new(vec![Value::Integer(4)]),
            Row::new(vec![Value::Integer(5)]),
        ];

        let limited_result = executor.apply_limit(result, 3);
        assert_eq!(limited_result.row_count(), 3);

        assert_eq!(limited_result.rows[0].get(0), Some(&Value::Integer(1)));
        assert_eq!(limited_result.rows[1].get(0), Some(&Value::Integer(2)));
        assert_eq!(limited_result.rows[2].get(0), Some(&Value::Integer(3)));
    }

    #[test]
    fn test_apply_limit_no_truncation() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.rows = vec![
            Row::new(vec![Value::Integer(1)]),
            Row::new(vec![Value::Integer(2)]),
        ];

        let limited_result = executor.apply_limit(result, 5);
        assert_eq!(limited_result.row_count(), 2);
    }

    #[test]
    fn test_apply_limit_zero() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.rows = vec![Row::new(vec![Value::Integer(1)])];

        let limited_result = executor.apply_limit(result, 0);
        assert_eq!(limited_result.row_count(), 0);
    }

    #[test]
    fn test_compare_values() {
        let executor = DefaultQueryExecutor::new();

        assert_eq!(
            executor.compare_values(&Value::Integer(1), &Value::Integer(2)),
            std::cmp::Ordering::Less
        );

        assert_eq!(
            executor.compare_values(
                &Value::Text("apple".to_string()),
                &Value::Text("banana".to_string())
            ),
            std::cmp::Ordering::Less
        );

        assert_eq!(
            executor.compare_values(&Value::Null, &Value::Integer(1)),
            std::cmp::Ordering::Less
        );

        assert_eq!(
            executor.compare_values(&Value::Integer(5), &Value::Integer(5)),
            std::cmp::Ordering::Equal
        );
    }

    #[test]
    fn test_compare_values_floats() {
        let executor = DefaultQueryExecutor::new();

        assert_eq!(
            executor.compare_values(&Value::Float(1.5), &Value::Float(2.5)),
            std::cmp::Ordering::Less
        );
        assert_eq!(
            executor.compare_values(&Value::Float(2.5), &Value::Null),
            std::cmp::Ordering::Greater
        );
    }

    #[test]
    fn test_apply_sort_ascending() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.columns = vec![ColumnMetadata {
            name: "value".to_string(),
            data_type: DataType::Integer,
            nullable: false,
        }];
        result.rows = vec![
            Row::new(vec![Value::Integer(3)]),
            Row::new(vec![Value::Integer(1)]),
            Row::new(vec![Value::Integer(2)]),
        ];

        let order_by = OrderBy {
            columns: vec![OrderColumn {
                column: "value".to_string(),
                direction: OrderDirection::Ascending,
            }],
        };

        let sorted_result = executor.apply_sort(result, &order_by).unwrap();

        assert_eq!(sorted_result.rows[0].get(0), Some(&Value::Integer(1)));
        assert_eq!(sorted_result.rows[1].get(0), Some(&Value::Integer(2)));
        assert_eq!(sorted_result.rows[2].get(0), Some(&Value::Integer(3)));
    }

    #[test]
    fn test_apply_sort_descending() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.columns = vec![ColumnMetadata {
            name: "name".to_string(),
            data_type: DataType::Text,
            nullable: false,
        }];
        result.rows = vec![
            Row::new(vec![Value::Text("Alice".to_string())]),
            Row::new(vec![Value::Text("Charlie".to_string())]),
            Row::new(vec![Value::Text("Bob".to_string())]),
        ];

        let order_by = OrderBy {
            columns: vec![OrderColumn {
                column: "name".to_string(),
                direction: OrderDirection::Descending,
            }],
        };

        let sorted_result = executor.apply_sort(result, &order_by).unwrap();

        assert_eq!(sorted_result.rows[0].get(0), Some(&Value::Text("Charlie".to_string())));
        assert_eq!(sorted_result.rows[1].get(0), Some(&Value::Text("Bob".to_string())));
        assert_eq!(sorted_result.rows[2].get(0), Some(&Value::Text("Alice".to_string())));
    }

    #[test]
    fn test_apply_sort_nonexistent_column() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.columns = vec![ColumnMetadata {
            name: "value".to_string(),
            data_type: DataType::Integer,
            nullable: false,
        }];
        result.rows = vec![Row::new(vec![Value::Integer(1)])];

        let order_by = OrderBy {
            columns: vec![OrderColumn {
                column: "nonexistent".to_string(),
                direction: OrderDirection::Ascending,
            }],
        };

        let result = executor.apply_sort(result, &order_by);
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Internal(msg) => {
                assert!(msg.contains("Sort column 'nonexistent' not found"));
            }
            _ => panic!("Expected Internal error"),
        }
    }

    #[test]
    fn test_format_result() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.rows = vec![
            Row::new(vec![Value::Integer(1), Value::Text("Alice".to_string())]),
            Row::new(vec![Value::Integer(2), Value::Text("Bob".to_string())]),
        ];

        let execution_time = Duration::from_millis(100);
        let formatted_result = executor.format_result(result, execution_time);

        assert_eq!(formatted_result.execution_time, execution_time);
        assert_eq!(formatted_result.columns.len(), 2);
        assert_eq!(formatted_result.columns[0].name, "column_0");
        assert_eq!(formatted_result.columns[0].data_type, DataType::Integer);
        assert_eq!(formatted_result.columns[1].name, "column_1");
        assert_eq!(formatted_result.columns[1].data_type, DataType::Text);
    }

    #[test]
    fn test_format_result_with_existing_columns() {
        let executor = DefaultQueryExecutor::new();

        let mut result = QueryResult::new();
        result.columns = vec![ColumnMetadata {
            name: "id".to_string(),
            data_type: DataType::Integer,
            nullable: false,
        }];
        result.rows = vec![Row::new(vec![Value::Integer(1)])];

        let execution_time = Duration::from_millis(50);
        let formatted_result = executor.format_result(result, execution_time);

        assert_eq!(formatted_result.execution_time, execution_time);
        assert_eq!(formatted_result.columns.len(), 1);
        assert_eq!(formatted_result.columns[0].name, "id");
    }

    #[test]
    fn test_aggregate_results_empty() {
        let executor = DefaultQueryExecutor::new();

        let result = executor.aggregate_results(vec![]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_aggregate_results_single() {
        let executor = DefaultQueryExecutor::new();

        let mut query_result = QueryResult::new();
        query_result.rows = vec![Row::new(vec![Value::Integer(1)])];

        let result = executor.aggregate_results(vec![query_result]).unwrap();
        assert_eq!(result.row_count(), 1);
    }

    #[test]
    fn test_aggregate_results_multiple() {
        let executor = DefaultQueryExecutor::new();

        let mut result1 = QueryResult::new();
        result1.rows = vec![Row::new(vec![Value::Integer(1)])];

        let mut result2 = QueryResult::new();
        result2.rows = vec![Row::new(vec![Value::Integer(2)])];

        // Merging is not implemented: only the first result is kept.
        let result = executor.aggregate_results(vec![result1, result2]).unwrap();
        assert_eq!(result.row_count(), 1);
        assert_eq!(result.rows[0].get(0), Some(&Value::Integer(1)));
    }
}