// nirv_engine/engine/dispatcher.rs

use async_trait::async_trait;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use crate::utils::{
    types::{InternalQuery, ConnectorQuery, QueryResult, DataSource},
    error::{NirvResult, DispatcherError, NirvError},
};
use crate::connectors::{Connector, ConnectorRegistry};
8
9/// Central routing component that manages data object type resolution and connector selection
10#[async_trait]
11pub trait Dispatcher: Send + Sync {
12    /// Register a connector for a specific data object type
13    async fn register_connector(&mut self, object_type: &str, connector: Box<dyn Connector>) -> NirvResult<()>;
14    
15    /// Route a query to appropriate connectors based on data object types
16    async fn route_query(&self, query: &InternalQuery) -> NirvResult<Vec<ConnectorQuery>>;
17    
18    /// Execute a distributed query across multiple connectors
19    async fn execute_distributed_query(&self, queries: Vec<ConnectorQuery>) -> NirvResult<QueryResult>;
20    
21    /// List all available data object types
22    fn list_available_types(&self) -> Vec<String>;
23    
24    /// Check if a data object type is registered
25    fn is_type_registered(&self, object_type: &str) -> bool;
26    
27    /// Get connector for a specific data object type
28    fn get_connector(&self, object_type: &str) -> Option<&dyn Connector>;
29}
30
31/// Data object type registry that maps types to their corresponding connectors
32#[derive(Debug)]
33pub struct DataObjectTypeRegistry {
34    /// Maps data object type names to connector names
35    type_to_connector: HashMap<String, String>,
36    /// Maps connector names to their capabilities
37    connector_capabilities: HashMap<String, ConnectorCapabilities>,
38}
39
/// Capabilities of a connector that the dispatcher consults when making routing decisions.
///
/// `PartialEq`/`Eq` are derived so capability sets can be compared directly
/// (e.g. in tests or when checking whether a re-registration changes anything).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectorCapabilities {
    /// Whether the connector can evaluate joins itself.
    pub supports_joins: bool,
    /// Whether the connector can evaluate aggregations itself.
    pub supports_aggregations: bool,
    /// Whether the connector can evaluate subqueries itself.
    pub supports_subqueries: bool,
    /// Limit on concurrently running queries reported by the connector.
    /// NOTE(review): `None` presumably means "no limit reported" — confirm with connector docs.
    pub max_concurrent_queries: Option<u32>,
}
48
49impl DataObjectTypeRegistry {
50    /// Create a new empty registry
51    pub fn new() -> Self {
52        Self {
53            type_to_connector: HashMap::new(),
54            connector_capabilities: HashMap::new(),
55        }
56    }
57    
58    /// Register a data object type with its connector
59    pub fn register_type(&mut self, object_type: &str, connector_name: &str, capabilities: ConnectorCapabilities) -> NirvResult<()> {
60        if self.type_to_connector.contains_key(object_type) {
61            return Err(NirvError::Dispatcher(DispatcherError::RegistrationFailed(
62                format!("Data object type '{}' is already registered", object_type)
63            )));
64        }
65        
66        self.type_to_connector.insert(object_type.to_string(), connector_name.to_string());
67        self.connector_capabilities.insert(connector_name.to_string(), capabilities);
68        Ok(())
69    }
70    
71    /// Get the connector name for a data object type
72    pub fn get_connector_for_type(&self, object_type: &str) -> Option<&String> {
73        self.type_to_connector.get(object_type)
74    }
75    
76    /// Get capabilities for a connector
77    pub fn get_connector_capabilities(&self, connector_name: &str) -> Option<&ConnectorCapabilities> {
78        self.connector_capabilities.get(connector_name)
79    }
80    
81    /// List all registered data object types
82    pub fn list_types(&self) -> Vec<String> {
83        self.type_to_connector.keys().cloned().collect()
84    }
85    
86    /// Check if a type is registered
87    pub fn is_type_registered(&self, object_type: &str) -> bool {
88        self.type_to_connector.contains_key(object_type)
89    }
90    
91    /// Unregister a data object type
92    pub fn unregister_type(&mut self, object_type: &str) -> Option<String> {
93        self.type_to_connector.remove(object_type)
94    }
95}
96
97impl Default for DataObjectTypeRegistry {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103/// Default implementation of the Dispatcher trait
104pub struct DefaultDispatcher {
105    /// Registry for managing connectors
106    connector_registry: ConnectorRegistry,
107    /// Registry for mapping data object types to connectors
108    type_registry: DataObjectTypeRegistry,
109}
110
111impl DefaultDispatcher {
112    /// Create a new dispatcher with empty registries
113    pub fn new() -> Self {
114        Self {
115            connector_registry: ConnectorRegistry::new(),
116            type_registry: DataObjectTypeRegistry::new(),
117        }
118    }
119    
120    /// Create a dispatcher with existing registries
121    pub fn with_registries(connector_registry: ConnectorRegistry, type_registry: DataObjectTypeRegistry) -> Self {
122        Self {
123            connector_registry,
124            type_registry,
125        }
126    }
127    
128    /// Extract data sources from a query
129    fn extract_data_sources<'a>(&self, query: &'a InternalQuery) -> Vec<&'a DataSource> {
130        query.sources.iter().collect()
131    }
132    
133    /// Validate that all data sources in a query are registered
134    fn validate_data_sources(&self, sources: &[&DataSource]) -> NirvResult<()> {
135        for source in sources {
136            if !self.type_registry.is_type_registered(&source.object_type) {
137                return Err(NirvError::Dispatcher(DispatcherError::UnregisteredObjectType(
138                    format!("Data object type '{}' is not registered. Available types: {:?}", 
139                           source.object_type, 
140                           self.type_registry.list_types())
141                )));
142            }
143        }
144        Ok(())
145    }
146    
147    /// Create connector queries for single-source routing
148    fn create_connector_queries(&self, query: &InternalQuery, sources: &[&DataSource]) -> NirvResult<Vec<ConnectorQuery>> {
149        let mut connector_queries = Vec::new();
150        
151        for source in sources {
152            let connector_name = self.type_registry
153                .get_connector_for_type(&source.object_type)
154                .ok_or_else(|| NirvError::Dispatcher(DispatcherError::UnregisteredObjectType(
155                    source.object_type.clone()
156                )))?;
157            
158            let connector = self.connector_registry
159                .get(connector_name)
160                .ok_or_else(|| NirvError::Dispatcher(DispatcherError::NoSuitableConnector))?;
161            
162            let connector_query = ConnectorQuery {
163                connector_type: connector.get_connector_type(),
164                query: query.clone(),
165                connection_params: HashMap::new(),
166            };
167            
168            connector_queries.push(connector_query);
169        }
170        
171        Ok(connector_queries)
172    }
173}
174
175impl Default for DefaultDispatcher {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
181#[async_trait]
182impl Dispatcher for DefaultDispatcher {
183    async fn register_connector(&mut self, object_type: &str, connector: Box<dyn Connector>) -> NirvResult<()> {
184        let connector_name = format!("{}_{}", object_type, self.connector_registry.len());
185        let capabilities = ConnectorCapabilities {
186            supports_joins: connector.get_capabilities().supports_joins,
187            supports_aggregations: connector.get_capabilities().supports_aggregations,
188            supports_subqueries: connector.get_capabilities().supports_subqueries,
189            max_concurrent_queries: connector.get_capabilities().max_concurrent_queries,
190        };
191        
192        // Register the connector in the connector registry
193        self.connector_registry.register(connector_name.clone(), connector)?;
194        
195        // Register the data object type mapping
196        self.type_registry.register_type(object_type, &connector_name, capabilities)?;
197        
198        Ok(())
199    }
200    
201    async fn route_query(&self, query: &InternalQuery) -> NirvResult<Vec<ConnectorQuery>> {
202        // Extract data sources from the query
203        let sources = self.extract_data_sources(query);
204        
205        if sources.is_empty() {
206            return Err(NirvError::Dispatcher(DispatcherError::RoutingFailed(
207                "No data sources found in query".to_string()
208            )));
209        }
210        
211        // Validate that all data sources are registered
212        self.validate_data_sources(&sources)?;
213        
214        // For MVP, we only support single-source queries
215        if sources.len() > 1 {
216            return Err(NirvError::Dispatcher(DispatcherError::CrossConnectorJoinUnsupported));
217        }
218        
219        // Create connector queries for routing
220        self.create_connector_queries(query, &sources)
221    }
222    
223    async fn execute_distributed_query(&self, queries: Vec<ConnectorQuery>) -> NirvResult<QueryResult> {
224        if queries.is_empty() {
225            return Ok(QueryResult::new());
226        }
227        
228        // For MVP, we only handle single connector queries
229        if queries.len() > 1 {
230            return Err(NirvError::Dispatcher(DispatcherError::CrossConnectorJoinUnsupported));
231        }
232        
233        let connector_query = &queries[0];
234        let connector_name = self.type_registry
235            .get_connector_for_type(&connector_query.query.sources[0].object_type)
236            .ok_or_else(|| NirvError::Dispatcher(DispatcherError::UnregisteredObjectType(
237                connector_query.query.sources[0].object_type.clone()
238            )))?;
239        
240        let connector = self.connector_registry
241            .get(connector_name)
242            .ok_or_else(|| NirvError::Dispatcher(DispatcherError::NoSuitableConnector))?;
243        
244        connector.execute_query(connector_query.clone()).await
245    }
246    
247    fn list_available_types(&self) -> Vec<String> {
248        self.type_registry.list_types()
249    }
250    
251    fn is_type_registered(&self, object_type: &str) -> bool {
252        self.type_registry.is_type_registered(object_type)
253    }
254    
255    fn get_connector(&self, object_type: &str) -> Option<&dyn Connector> {
256        let connector_name = self.type_registry.get_connector_for_type(object_type)?;
257        self.connector_registry.get(connector_name)
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::types::{QueryOperation, ConnectorType, Schema, ColumnMetadata, DataType};
    use crate::connectors::{ConnectorInitConfig, ConnectorCapabilities as ConnectorTraitCapabilities};
    use std::time::Duration;

    // Mock connector for testing
    struct TestConnector {
        connector_type: ConnectorType,
        connected: bool,
        capabilities: ConnectorTraitCapabilities,
    }

    impl TestConnector {
        fn new(connector_type: ConnectorType) -> Self {
            Self {
                connector_type,
                connected: false,
                capabilities: ConnectorTraitCapabilities::default(),
            }
        }

        fn with_capabilities(mut self, capabilities: ConnectorTraitCapabilities) -> Self {
            self.capabilities = capabilities;
            self
        }
    }

    #[async_trait]
    impl Connector for TestConnector {
        async fn connect(&mut self, _config: ConnectorInitConfig) -> NirvResult<()> {
            self.connected = true;
            Ok(())
        }

        async fn execute_query(&self, _query: ConnectorQuery) -> NirvResult<QueryResult> {
            let mut result = QueryResult::new();
            result.execution_time = Duration::from_millis(10);
            Ok(result)
        }

        async fn get_schema(&self, object_name: &str) -> NirvResult<Schema> {
            Ok(Schema {
                name: object_name.to_string(),
                columns: vec![
                    ColumnMetadata {
                        name: "id".to_string(),
                        data_type: DataType::Integer,
                        nullable: false,
                    },
                    ColumnMetadata {
                        name: "name".to_string(),
                        data_type: DataType::Text,
                        nullable: true,
                    },
                ],
                primary_key: Some(vec!["id".to_string()]),
                indexes: vec![],
            })
        }

        async fn disconnect(&mut self) -> NirvResult<()> {
            self.connected = false;
            Ok(())
        }

        fn get_connector_type(&self) -> ConnectorType {
            self.connector_type.clone()
        }

        fn supports_transactions(&self) -> bool {
            self.capabilities.supports_transactions
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        fn get_capabilities(&self) -> ConnectorTraitCapabilities {
            self.capabilities.clone()
        }
    }

    #[test]
    fn test_data_object_type_registry_creation() {
        let registry = DataObjectTypeRegistry::new();

        assert!(registry.list_types().is_empty());
        assert!(!registry.is_type_registered("test_type"));
    }

    #[test]
    fn test_data_object_type_registry_register_type() {
        let mut registry = DataObjectTypeRegistry::new();
        let capabilities = ConnectorCapabilities {
            supports_joins: true,
            supports_aggregations: false,
            supports_subqueries: true,
            max_concurrent_queries: Some(5),
        };

        let result = registry.register_type("postgres", "postgres_connector", capabilities.clone());
        assert!(result.is_ok());

        assert!(registry.is_type_registered("postgres"));
        assert_eq!(registry.get_connector_for_type("postgres"), Some(&"postgres_connector".to_string()));

        let retrieved_capabilities = registry.get_connector_capabilities("postgres_connector");
        assert!(retrieved_capabilities.is_some());
        assert!(retrieved_capabilities.unwrap().supports_joins);
        assert!(!retrieved_capabilities.unwrap().supports_aggregations);
    }

    #[test]
    fn test_data_object_type_registry_duplicate_registration() {
        let mut registry = DataObjectTypeRegistry::new();
        let capabilities = ConnectorCapabilities {
            supports_joins: false,
            supports_aggregations: false,
            supports_subqueries: false,
            max_concurrent_queries: Some(1),
        };

        // First registration should succeed
        let result1 = registry.register_type("postgres", "connector1", capabilities.clone());
        assert!(result1.is_ok());

        // Second registration with same type should fail
        let result2 = registry.register_type("postgres", "connector2", capabilities);
        assert!(result2.is_err());

        match result2.unwrap_err() {
            NirvError::Dispatcher(DispatcherError::RegistrationFailed(msg)) => {
                assert!(msg.contains("already registered"));
            }
            _ => panic!("Expected RegistrationFailed error"),
        }
    }

    #[test]
    fn test_data_object_type_registry_list_types() {
        let mut registry = DataObjectTypeRegistry::new();
        let capabilities = ConnectorCapabilities {
            supports_joins: false,
            supports_aggregations: false,
            supports_subqueries: false,
            max_concurrent_queries: Some(1),
        };

        registry.register_type("postgres", "pg_connector", capabilities.clone()).unwrap();
        registry.register_type("mysql", "mysql_connector", capabilities.clone()).unwrap();
        registry.register_type("file", "file_connector", capabilities).unwrap();

        let types = registry.list_types();
        assert_eq!(types.len(), 3);
        assert!(types.contains(&"postgres".to_string()));
        assert!(types.contains(&"mysql".to_string()));
        assert!(types.contains(&"file".to_string()));
    }

    #[test]
    fn test_data_object_type_registry_unregister_type() {
        let mut registry = DataObjectTypeRegistry::new();
        let capabilities = ConnectorCapabilities {
            supports_joins: false,
            supports_aggregations: false,
            supports_subqueries: false,
            max_concurrent_queries: Some(1),
        };

        registry.register_type("postgres", "pg_connector", capabilities).unwrap();
        assert!(registry.is_type_registered("postgres"));

        let removed = registry.unregister_type("postgres");
        assert_eq!(removed, Some("pg_connector".to_string()));
        assert!(!registry.is_type_registered("postgres"));

        // Try to unregister non-existent type
        let non_existent = registry.unregister_type("non_existent");
        assert_eq!(non_existent, None);
    }

    #[test]
    fn test_default_dispatcher_creation() {
        let dispatcher = DefaultDispatcher::new();

        assert!(dispatcher.list_available_types().is_empty());
        assert!(!dispatcher.is_type_registered("test_type"));
    }

    #[tokio::test]
    async fn test_dispatcher_register_connector() {
        let mut dispatcher = DefaultDispatcher::new();
        let connector = Box::new(TestConnector::new(ConnectorType::Mock));

        let result = dispatcher.register_connector("mock", connector).await;
        assert!(result.is_ok());

        assert!(dispatcher.is_type_registered("mock"));
        assert_eq!(dispatcher.list_available_types(), vec!["mock".to_string()]);
    }

    #[tokio::test]
    async fn test_dispatcher_register_connector_propagates_capabilities() {
        // Use non-default values so the propagation is actually observable.
        let mut connector_caps = ConnectorTraitCapabilities::default();
        connector_caps.supports_joins = true;
        connector_caps.max_concurrent_queries = Some(3);
        let expected = connector_caps.clone();

        let connector = Box::new(TestConnector::new(ConnectorType::Mock).with_capabilities(connector_caps));
        let mut dispatcher = DefaultDispatcher::new();
        dispatcher.register_connector("mock", connector).await.unwrap();

        let connector_name = dispatcher
            .type_registry
            .get_connector_for_type("mock")
            .expect("type was just registered")
            .clone();
        let recorded = dispatcher
            .type_registry
            .get_connector_capabilities(&connector_name)
            .expect("capabilities are recorded on registration");

        assert_eq!(recorded.supports_joins, expected.supports_joins);
        assert_eq!(recorded.supports_aggregations, expected.supports_aggregations);
        assert_eq!(recorded.supports_subqueries, expected.supports_subqueries);
        assert_eq!(recorded.max_concurrent_queries, expected.max_concurrent_queries);
    }

    #[tokio::test]
    async fn test_dispatcher_register_connector_duplicate_type() {
        let mut dispatcher = DefaultDispatcher::new();

        dispatcher
            .register_connector("mock", Box::new(TestConnector::new(ConnectorType::Mock)))
            .await
            .unwrap();

        let result = dispatcher
            .register_connector("mock", Box::new(TestConnector::new(ConnectorType::Mock)))
            .await;

        match result.unwrap_err() {
            NirvError::Dispatcher(DispatcherError::RegistrationFailed(msg)) => {
                assert!(msg.contains("already registered"));
            }
            _ => panic!("Expected RegistrationFailed error"),
        }
        assert_eq!(dispatcher.list_available_types(), vec!["mock".to_string()]);
    }

    #[tokio::test]
    async fn test_dispatcher_register_multiple_connectors() {
        let mut dispatcher = DefaultDispatcher::new();

        let mock_connector = Box::new(TestConnector::new(ConnectorType::Mock));
        let postgres_connector = Box::new(TestConnector::new(ConnectorType::PostgreSQL));

        dispatcher.register_connector("mock", mock_connector).await.unwrap();
        dispatcher.register_connector("postgres", postgres_connector).await.unwrap();

        let types = dispatcher.list_available_types();
        assert_eq!(types.len(), 2);
        assert!(types.contains(&"mock".to_string()));
        assert!(types.contains(&"postgres".to_string()));
    }

    #[tokio::test]
    async fn test_dispatcher_get_connector() {
        let mut dispatcher = DefaultDispatcher::new();
        let connector = Box::new(TestConnector::new(ConnectorType::Mock));

        dispatcher.register_connector("mock", connector).await.unwrap();

        let retrieved = dispatcher.get_connector("mock");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().get_connector_type(), ConnectorType::Mock);

        let non_existent = dispatcher.get_connector("non_existent");
        assert!(non_existent.is_none());
    }

    #[tokio::test]
    async fn test_dispatcher_route_query_single_source() {
        let mut dispatcher = DefaultDispatcher::new();
        let connector = Box::new(TestConnector::new(ConnectorType::Mock));

        dispatcher.register_connector("mock", connector).await.unwrap();

        let mut query = InternalQuery::new(QueryOperation::Select);
        query.sources.push(DataSource {
            object_type: "mock".to_string(),
            identifier: "test_table".to_string(),
            alias: None,
        });

        let result = dispatcher.route_query(&query).await;
        assert!(result.is_ok());

        let connector_queries = result.unwrap();
        assert_eq!(connector_queries.len(), 1);
        assert_eq!(connector_queries[0].connector_type, ConnectorType::Mock);
    }

    #[tokio::test]
    async fn test_dispatcher_route_query_unregistered_type() {
        let dispatcher = DefaultDispatcher::new();

        let mut query = InternalQuery::new(QueryOperation::Select);
        query.sources.push(DataSource {
            object_type: "unregistered".to_string(),
            identifier: "test_table".to_string(),
            alias: None,
        });

        let result = dispatcher.route_query(&query).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Dispatcher(DispatcherError::UnregisteredObjectType(msg)) => {
                assert!(msg.contains("unregistered"));
                assert!(msg.contains("not registered"));
            }
            _ => panic!("Expected UnregisteredObjectType error"),
        }
    }

    #[tokio::test]
    async fn test_dispatcher_route_query_no_sources() {
        let dispatcher = DefaultDispatcher::new();
        let query = InternalQuery::new(QueryOperation::Select);

        let result = dispatcher.route_query(&query).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Dispatcher(DispatcherError::RoutingFailed(msg)) => {
                assert!(msg.contains("No data sources found"));
            }
            _ => panic!("Expected RoutingFailed error"),
        }
    }

    #[tokio::test]
    async fn test_dispatcher_route_query_multiple_sources_unsupported() {
        let mut dispatcher = DefaultDispatcher::new();
        let connector = Box::new(TestConnector::new(ConnectorType::Mock));

        dispatcher.register_connector("mock", connector).await.unwrap();

        let mut query = InternalQuery::new(QueryOperation::Select);
        query.sources.push(DataSource {
            object_type: "mock".to_string(),
            identifier: "table1".to_string(),
            alias: None,
        });
        query.sources.push(DataSource {
            object_type: "mock".to_string(),
            identifier: "table2".to_string(),
            alias: None,
        });

        let result = dispatcher.route_query(&query).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Dispatcher(DispatcherError::CrossConnectorJoinUnsupported) => {}
            _ => panic!("Expected CrossConnectorJoinUnsupported error"),
        }
    }

    #[tokio::test]
    async fn test_dispatcher_execute_distributed_query() {
        let mut dispatcher = DefaultDispatcher::new();
        let connector = Box::new(TestConnector::new(ConnectorType::Mock));

        dispatcher.register_connector("mock", connector).await.unwrap();

        let mut query = InternalQuery::new(QueryOperation::Select);
        query.sources.push(DataSource {
            object_type: "mock".to_string(),
            identifier: "test_table".to_string(),
            alias: None,
        });

        let connector_query = ConnectorQuery {
            connector_type: ConnectorType::Mock,
            query,
            connection_params: HashMap::new(),
        };

        let result = dispatcher.execute_distributed_query(vec![connector_query]).await;
        assert!(result.is_ok());

        let query_result = result.unwrap();
        assert!(query_result.execution_time > Duration::from_millis(0));
    }

    #[tokio::test]
    async fn test_dispatcher_execute_distributed_query_empty() {
        let dispatcher = DefaultDispatcher::new();

        let result = dispatcher.execute_distributed_query(vec![]).await;
        assert!(result.is_ok());

        let query_result = result.unwrap();
        assert_eq!(query_result.row_count(), 0);
    }

    #[tokio::test]
    async fn test_dispatcher_execute_distributed_query_multiple_unsupported() {
        let dispatcher = DefaultDispatcher::new();

        let query1 = ConnectorQuery {
            connector_type: ConnectorType::Mock,
            query: InternalQuery::new(QueryOperation::Select),
            connection_params: HashMap::new(),
        };

        let query2 = ConnectorQuery {
            connector_type: ConnectorType::PostgreSQL,
            query: InternalQuery::new(QueryOperation::Select),
            connection_params: HashMap::new(),
        };

        let result = dispatcher.execute_distributed_query(vec![query1, query2]).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            NirvError::Dispatcher(DispatcherError::CrossConnectorJoinUnsupported) => {}
            _ => panic!("Expected CrossConnectorJoinUnsupported error"),
        }
    }

    #[test]
    fn test_connector_capabilities_creation() {
        let capabilities = ConnectorCapabilities {
            supports_joins: true,
            supports_aggregations: false,
            supports_subqueries: true,
            max_concurrent_queries: Some(10),
        };

        assert!(capabilities.supports_joins);
        assert!(!capabilities.supports_aggregations);
        assert!(capabilities.supports_subqueries);
        assert_eq!(capabilities.max_concurrent_queries, Some(10));
    }
}