1use async_trait::async_trait;
2use std::collections::HashMap;
3use crate::utils::{
4 types::{InternalQuery, ConnectorQuery, QueryResult, DataSource},
5 error::{NirvResult, DispatcherError, NirvError},
6};
7use crate::connectors::{Connector, ConnectorRegistry};
8
9#[async_trait]
11pub trait Dispatcher: Send + Sync {
12 async fn register_connector(&mut self, object_type: &str, connector: Box<dyn Connector>) -> NirvResult<()>;
14
15 async fn route_query(&self, query: &InternalQuery) -> NirvResult<Vec<ConnectorQuery>>;
17
18 async fn execute_distributed_query(&self, queries: Vec<ConnectorQuery>) -> NirvResult<QueryResult>;
20
21 fn list_available_types(&self) -> Vec<String>;
23
24 fn is_type_registered(&self, object_type: &str) -> bool;
26
27 fn get_connector(&self, object_type: &str) -> Option<&dyn Connector>;
29}
30
31#[derive(Debug)]
33pub struct DataObjectTypeRegistry {
34 type_to_connector: HashMap<String, String>,
36 connector_capabilities: HashMap<String, ConnectorCapabilities>,
38}
39
/// Feature flags and limits recorded for a connector.
#[derive(Clone, Debug)]
pub struct ConnectorCapabilities {
    /// Whether the connector can evaluate joins.
    pub supports_joins: bool,

    /// Whether the connector can evaluate aggregations.
    pub supports_aggregations: bool,

    /// Whether the connector can evaluate subqueries.
    pub supports_subqueries: bool,

    /// Upper bound on concurrently running queries; `None` when no limit is recorded.
    pub max_concurrent_queries: Option<u32>,
}
48
49impl DataObjectTypeRegistry {
50 pub fn new() -> Self {
52 Self {
53 type_to_connector: HashMap::new(),
54 connector_capabilities: HashMap::new(),
55 }
56 }
57
58 pub fn register_type(&mut self, object_type: &str, connector_name: &str, capabilities: ConnectorCapabilities) -> NirvResult<()> {
60 if self.type_to_connector.contains_key(object_type) {
61 return Err(NirvError::Dispatcher(DispatcherError::RegistrationFailed(
62 format!("Data object type '{}' is already registered", object_type)
63 )));
64 }
65
66 self.type_to_connector.insert(object_type.to_string(), connector_name.to_string());
67 self.connector_capabilities.insert(connector_name.to_string(), capabilities);
68 Ok(())
69 }
70
71 pub fn get_connector_for_type(&self, object_type: &str) -> Option<&String> {
73 self.type_to_connector.get(object_type)
74 }
75
76 pub fn get_connector_capabilities(&self, connector_name: &str) -> Option<&ConnectorCapabilities> {
78 self.connector_capabilities.get(connector_name)
79 }
80
81 pub fn list_types(&self) -> Vec<String> {
83 self.type_to_connector.keys().cloned().collect()
84 }
85
86 pub fn is_type_registered(&self, object_type: &str) -> bool {
88 self.type_to_connector.contains_key(object_type)
89 }
90
91 pub fn unregister_type(&mut self, object_type: &str) -> Option<String> {
93 self.type_to_connector.remove(object_type)
94 }
95}
96
97impl Default for DataObjectTypeRegistry {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103pub struct DefaultDispatcher {
105 connector_registry: ConnectorRegistry,
107 type_registry: DataObjectTypeRegistry,
109}
110
111impl DefaultDispatcher {
112 pub fn new() -> Self {
114 Self {
115 connector_registry: ConnectorRegistry::new(),
116 type_registry: DataObjectTypeRegistry::new(),
117 }
118 }
119
120 pub fn with_registries(connector_registry: ConnectorRegistry, type_registry: DataObjectTypeRegistry) -> Self {
122 Self {
123 connector_registry,
124 type_registry,
125 }
126 }
127
128 fn extract_data_sources<'a>(&self, query: &'a InternalQuery) -> Vec<&'a DataSource> {
130 query.sources.iter().collect()
131 }
132
133 fn validate_data_sources(&self, sources: &[&DataSource]) -> NirvResult<()> {
135 for source in sources {
136 if !self.type_registry.is_type_registered(&source.object_type) {
137 return Err(NirvError::Dispatcher(DispatcherError::UnregisteredObjectType(
138 format!("Data object type '{}' is not registered. Available types: {:?}",
139 source.object_type,
140 self.type_registry.list_types())
141 )));
142 }
143 }
144 Ok(())
145 }
146
147 fn create_connector_queries(&self, query: &InternalQuery, sources: &[&DataSource]) -> NirvResult<Vec<ConnectorQuery>> {
149 let mut connector_queries = Vec::new();
150
151 for source in sources {
152 let connector_name = self.type_registry
153 .get_connector_for_type(&source.object_type)
154 .ok_or_else(|| NirvError::Dispatcher(DispatcherError::UnregisteredObjectType(
155 source.object_type.clone()
156 )))?;
157
158 let connector = self.connector_registry
159 .get(connector_name)
160 .ok_or_else(|| NirvError::Dispatcher(DispatcherError::NoSuitableConnector))?;
161
162 let connector_query = ConnectorQuery {
163 connector_type: connector.get_connector_type(),
164 query: query.clone(),
165 connection_params: HashMap::new(),
166 };
167
168 connector_queries.push(connector_query);
169 }
170
171 Ok(connector_queries)
172 }
173}
174
175impl Default for DefaultDispatcher {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
181#[async_trait]
182impl Dispatcher for DefaultDispatcher {
183 async fn register_connector(&mut self, object_type: &str, connector: Box<dyn Connector>) -> NirvResult<()> {
184 let connector_name = format!("{}_{}", object_type, self.connector_registry.len());
185 let capabilities = ConnectorCapabilities {
186 supports_joins: connector.get_capabilities().supports_joins,
187 supports_aggregations: connector.get_capabilities().supports_aggregations,
188 supports_subqueries: connector.get_capabilities().supports_subqueries,
189 max_concurrent_queries: connector.get_capabilities().max_concurrent_queries,
190 };
191
192 self.connector_registry.register(connector_name.clone(), connector)?;
194
195 self.type_registry.register_type(object_type, &connector_name, capabilities)?;
197
198 Ok(())
199 }
200
201 async fn route_query(&self, query: &InternalQuery) -> NirvResult<Vec<ConnectorQuery>> {
202 let sources = self.extract_data_sources(query);
204
205 if sources.is_empty() {
206 return Err(NirvError::Dispatcher(DispatcherError::RoutingFailed(
207 "No data sources found in query".to_string()
208 )));
209 }
210
211 self.validate_data_sources(&sources)?;
213
214 if sources.len() > 1 {
216 return Err(NirvError::Dispatcher(DispatcherError::CrossConnectorJoinUnsupported));
217 }
218
219 self.create_connector_queries(query, &sources)
221 }
222
223 async fn execute_distributed_query(&self, queries: Vec<ConnectorQuery>) -> NirvResult<QueryResult> {
224 if queries.is_empty() {
225 return Ok(QueryResult::new());
226 }
227
228 if queries.len() > 1 {
230 return Err(NirvError::Dispatcher(DispatcherError::CrossConnectorJoinUnsupported));
231 }
232
233 let connector_query = &queries[0];
234 let connector_name = self.type_registry
235 .get_connector_for_type(&connector_query.query.sources[0].object_type)
236 .ok_or_else(|| NirvError::Dispatcher(DispatcherError::UnregisteredObjectType(
237 connector_query.query.sources[0].object_type.clone()
238 )))?;
239
240 let connector = self.connector_registry
241 .get(connector_name)
242 .ok_or_else(|| NirvError::Dispatcher(DispatcherError::NoSuitableConnector))?;
243
244 connector.execute_query(connector_query.clone()).await
245 }
246
247 fn list_available_types(&self) -> Vec<String> {
248 self.type_registry.list_types()
249 }
250
251 fn is_type_registered(&self, object_type: &str) -> bool {
252 self.type_registry.is_type_registered(object_type)
253 }
254
255 fn get_connector(&self, object_type: &str) -> Option<&dyn Connector> {
256 let connector_name = self.type_registry.get_connector_for_type(object_type)?;
257 self.connector_registry.get(connector_name)
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::types::{QueryOperation, ConnectorType, Schema, ColumnMetadata, DataType};
    use crate::connectors::{ConnectorInitConfig, ConnectorCapabilities as ConnectorTraitCapabilities};
    use std::time::Duration;

    /// Connector stub: tracks a connected flag and answers every query with an
    /// empty result stamped with a 10 ms execution time.
    struct TestConnector {
        connector_type: ConnectorType,
        connected: bool,
        capabilities: ConnectorTraitCapabilities,
    }

    impl TestConnector {
        /// Creates a disconnected stub with default capabilities.
        fn new(connector_type: ConnectorType) -> Self {
            let capabilities = ConnectorTraitCapabilities::default();
            Self { connector_type, connected: false, capabilities }
        }

        /// Replaces the stub's capabilities, builder style.
        fn with_capabilities(self, capabilities: ConnectorTraitCapabilities) -> Self {
            Self { capabilities, ..self }
        }
    }

    #[async_trait]
    impl Connector for TestConnector {
        async fn connect(&mut self, _config: ConnectorInitConfig) -> NirvResult<()> {
            self.connected = true;
            Ok(())
        }

        async fn execute_query(&self, _query: ConnectorQuery) -> NirvResult<QueryResult> {
            let mut outcome = QueryResult::new();
            outcome.execution_time = Duration::from_millis(10);
            Ok(outcome)
        }

        async fn get_schema(&self, object_name: &str) -> NirvResult<Schema> {
            let id_column = ColumnMetadata {
                name: "id".to_string(),
                data_type: DataType::Integer,
                nullable: false,
            };
            let name_column = ColumnMetadata {
                name: "name".to_string(),
                data_type: DataType::Text,
                nullable: true,
            };

            Ok(Schema {
                name: object_name.to_string(),
                columns: vec![id_column, name_column],
                primary_key: Some(vec!["id".to_string()]),
                indexes: vec![],
            })
        }

        async fn disconnect(&mut self) -> NirvResult<()> {
            self.connected = false;
            Ok(())
        }

        fn get_connector_type(&self) -> ConnectorType {
            self.connector_type.clone()
        }

        fn supports_transactions(&self) -> bool {
            self.capabilities.supports_transactions
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        fn get_capabilities(&self) -> ConnectorTraitCapabilities {
            self.capabilities.clone()
        }
    }

    // ---- shared fixtures -------------------------------------------------

    /// Capabilities with every feature switched off and a concurrency limit of one.
    fn no_feature_capabilities() -> ConnectorCapabilities {
        ConnectorCapabilities {
            supports_joins: false,
            supports_aggregations: false,
            supports_subqueries: false,
            max_concurrent_queries: Some(1),
        }
    }

    /// Data source without an alias.
    fn source_of(object_type: &str, identifier: &str) -> DataSource {
        DataSource {
            object_type: object_type.to_string(),
            identifier: identifier.to_string(),
            alias: None,
        }
    }

    /// SELECT query over the given sources, in order.
    fn select_over(sources: Vec<DataSource>) -> InternalQuery {
        let mut query = InternalQuery::new(QueryOperation::Select);
        for source in sources {
            query.sources.push(source);
        }
        query
    }

    /// Connector query wrapping `query` with empty connection parameters.
    fn connector_query_for(connector_type: ConnectorType, query: InternalQuery) -> ConnectorQuery {
        ConnectorQuery {
            connector_type,
            query,
            connection_params: HashMap::new(),
        }
    }

    /// Dispatcher with one `TestConnector` registered for the type "mock".
    async fn dispatcher_with_mock() -> DefaultDispatcher {
        let mut dispatcher = DefaultDispatcher::new();
        let connector = Box::new(TestConnector::new(ConnectorType::Mock));
        dispatcher.register_connector("mock", connector).await.unwrap();
        dispatcher
    }

    // ---- DataObjectTypeRegistry ------------------------------------------

    #[test]
    fn test_data_object_type_registry_creation() {
        let registry = DataObjectTypeRegistry::new();

        assert!(registry.list_types().is_empty());
        assert!(!registry.is_type_registered("test_type"));
    }

    #[test]
    fn test_data_object_type_registry_register_type() {
        let mut registry = DataObjectTypeRegistry::new();
        let capabilities = ConnectorCapabilities {
            supports_joins: true,
            supports_aggregations: false,
            supports_subqueries: true,
            max_concurrent_queries: Some(5),
        };

        let outcome = registry.register_type("postgres", "postgres_connector", capabilities.clone());
        assert!(outcome.is_ok());

        assert!(registry.is_type_registered("postgres"));
        assert_eq!(registry.get_connector_for_type("postgres"), Some(&"postgres_connector".to_string()));

        let stored = registry.get_connector_capabilities("postgres_connector");
        assert!(stored.is_some());
        assert!(stored.unwrap().supports_joins);
        assert!(!stored.unwrap().supports_aggregations);
    }

    #[test]
    fn test_data_object_type_registry_duplicate_registration() {
        let mut registry = DataObjectTypeRegistry::new();

        let first = registry.register_type("postgres", "connector1", no_feature_capabilities());
        assert!(first.is_ok());

        let second = registry.register_type("postgres", "connector2", no_feature_capabilities());
        assert!(second.is_err());

        if let NirvError::Dispatcher(DispatcherError::RegistrationFailed(msg)) = second.unwrap_err() {
            assert!(msg.contains("already registered"));
        } else {
            panic!("Expected RegistrationFailed error");
        }
    }

    #[test]
    fn test_data_object_type_registry_list_types() {
        let mut registry = DataObjectTypeRegistry::new();

        let pairs = [
            ("postgres", "pg_connector"),
            ("mysql", "mysql_connector"),
            ("file", "file_connector"),
        ];
        for (object_type, connector_name) in pairs.iter() {
            registry.register_type(object_type, connector_name, no_feature_capabilities()).unwrap();
        }

        let types = registry.list_types();
        assert_eq!(types.len(), 3);
        assert!(types.contains(&"postgres".to_string()));
        assert!(types.contains(&"mysql".to_string()));
        assert!(types.contains(&"file".to_string()));
    }

    #[test]
    fn test_data_object_type_registry_unregister_type() {
        let mut registry = DataObjectTypeRegistry::new();

        registry.register_type("postgres", "pg_connector", no_feature_capabilities()).unwrap();
        assert!(registry.is_type_registered("postgres"));

        let removed = registry.unregister_type("postgres");
        assert_eq!(removed, Some("pg_connector".to_string()));
        assert!(!registry.is_type_registered("postgres"));

        let missing = registry.unregister_type("non_existent");
        assert_eq!(missing, None);
    }

    // ---- DefaultDispatcher: registration and lookup ------------------------

    #[test]
    fn test_default_dispatcher_creation() {
        let dispatcher = DefaultDispatcher::new();

        assert!(dispatcher.list_available_types().is_empty());
        assert!(!dispatcher.is_type_registered("test_type"));
    }

    #[tokio::test]
    async fn test_dispatcher_register_connector() {
        let mut dispatcher = DefaultDispatcher::new();
        let connector = Box::new(TestConnector::new(ConnectorType::Mock));

        let outcome = dispatcher.register_connector("mock", connector).await;
        assert!(outcome.is_ok());

        assert!(dispatcher.is_type_registered("mock"));
        assert_eq!(dispatcher.list_available_types(), vec!["mock".to_string()]);
    }

    #[tokio::test]
    async fn test_dispatcher_register_multiple_connectors() {
        let mut dispatcher = dispatcher_with_mock().await;

        let postgres_connector = Box::new(TestConnector::new(ConnectorType::PostgreSQL));
        dispatcher.register_connector("postgres", postgres_connector).await.unwrap();

        let types = dispatcher.list_available_types();
        assert_eq!(types.len(), 2);
        assert!(types.contains(&"mock".to_string()));
        assert!(types.contains(&"postgres".to_string()));
    }

    #[tokio::test]
    async fn test_dispatcher_get_connector() {
        let dispatcher = dispatcher_with_mock().await;

        let found = dispatcher.get_connector("mock");
        assert!(found.is_some());
        assert_eq!(found.unwrap().get_connector_type(), ConnectorType::Mock);

        let missing = dispatcher.get_connector("non_existent");
        assert!(missing.is_none());
    }

    // ---- DefaultDispatcher: routing ----------------------------------------

    #[tokio::test]
    async fn test_dispatcher_route_query_single_source() {
        let dispatcher = dispatcher_with_mock().await;
        let query = select_over(vec![source_of("mock", "test_table")]);

        let routed = dispatcher.route_query(&query).await;
        assert!(routed.is_ok());

        let connector_queries = routed.unwrap();
        assert_eq!(connector_queries.len(), 1);
        assert_eq!(connector_queries[0].connector_type, ConnectorType::Mock);
    }

    #[tokio::test]
    async fn test_dispatcher_route_query_unregistered_type() {
        let dispatcher = DefaultDispatcher::new();
        let query = select_over(vec![source_of("unregistered", "test_table")]);

        let routed = dispatcher.route_query(&query).await;
        assert!(routed.is_err());

        if let NirvError::Dispatcher(DispatcherError::UnregisteredObjectType(msg)) = routed.unwrap_err() {
            assert!(msg.contains("unregistered"));
            assert!(msg.contains("not registered"));
        } else {
            panic!("Expected UnregisteredObjectType error");
        }
    }

    #[tokio::test]
    async fn test_dispatcher_route_query_no_sources() {
        let dispatcher = DefaultDispatcher::new();
        let query = select_over(vec![]);

        let routed = dispatcher.route_query(&query).await;
        assert!(routed.is_err());

        if let NirvError::Dispatcher(DispatcherError::RoutingFailed(msg)) = routed.unwrap_err() {
            assert!(msg.contains("No data sources found"));
        } else {
            panic!("Expected RoutingFailed error");
        }
    }

    #[tokio::test]
    async fn test_dispatcher_route_query_multiple_sources_unsupported() {
        let dispatcher = dispatcher_with_mock().await;
        let query = select_over(vec![
            source_of("mock", "table1"),
            source_of("mock", "table2"),
        ]);

        let routed = dispatcher.route_query(&query).await;
        assert!(routed.is_err());

        if let NirvError::Dispatcher(DispatcherError::CrossConnectorJoinUnsupported) = routed.unwrap_err() {
            // expected variant, nothing further to check
        } else {
            panic!("Expected CrossConnectorJoinUnsupported error");
        }
    }

    // ---- DefaultDispatcher: execution --------------------------------------

    #[tokio::test]
    async fn test_dispatcher_execute_distributed_query() {
        let dispatcher = dispatcher_with_mock().await;
        let query = select_over(vec![source_of("mock", "test_table")]);
        let connector_query = connector_query_for(ConnectorType::Mock, query);

        let executed = dispatcher.execute_distributed_query(vec![connector_query]).await;
        assert!(executed.is_ok());

        let query_result = executed.unwrap();
        assert!(query_result.execution_time > Duration::from_millis(0));
    }

    #[tokio::test]
    async fn test_dispatcher_execute_distributed_query_empty() {
        let dispatcher = DefaultDispatcher::new();

        let executed = dispatcher.execute_distributed_query(vec![]).await;
        assert!(executed.is_ok());

        let query_result = executed.unwrap();
        assert_eq!(query_result.row_count(), 0);
    }

    #[tokio::test]
    async fn test_dispatcher_execute_distributed_query_multiple_unsupported() {
        let dispatcher = DefaultDispatcher::new();

        let batch = vec![
            connector_query_for(ConnectorType::Mock, select_over(vec![])),
            connector_query_for(ConnectorType::PostgreSQL, select_over(vec![])),
        ];

        let executed = dispatcher.execute_distributed_query(batch).await;
        assert!(executed.is_err());

        if let NirvError::Dispatcher(DispatcherError::CrossConnectorJoinUnsupported) = executed.unwrap_err() {
            // expected variant, nothing further to check
        } else {
            panic!("Expected CrossConnectorJoinUnsupported error");
        }
    }

    // ---- ConnectorCapabilities ---------------------------------------------

    #[test]
    fn test_connector_capabilities_creation() {
        let capabilities = ConnectorCapabilities {
            supports_joins: true,
            supports_aggregations: false,
            supports_subqueries: true,
            max_concurrent_queries: Some(10),
        };

        assert!(capabilities.supports_joins);
        assert!(!capabilities.supports_aggregations);
        assert!(capabilities.supports_subqueries);
        assert_eq!(capabilities.max_concurrent_queries, Some(10));
    }
}