Skip to main content

graphrag_core/core/
registry.rs

1//! Service registry for dependency injection
2//!
3//! This module provides a dependency injection system that allows
4//! components to be swapped out for testing or different implementations.
5
6use crate::core::traits::*;
7use crate::core::{GraphRAGError, Result};
8use std::any::{Any, TypeId};
9use std::collections::HashMap;
10use std::sync::Arc;
11
/// Type-erased storage slot holding one registered service.
///
/// The `Send + Sync` bounds allow a registry to be shared across threads,
/// e.g. behind the `Arc` used by `ServiceContext`.
type ServiceBox = Box<dyn Any + Send + Sync>;

/// Service registry for dependency injection.
///
/// Holds at most one service per concrete type: services are keyed by their
/// [`TypeId`] and handed back by downcasting to that same type.
pub struct ServiceRegistry {
    services: HashMap<TypeId, ServiceBox>,
}

impl std::fmt::Debug for ServiceRegistry {
    // Registered services are type-erased and not required to implement
    // `Debug`, so only the number of registered services is reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServiceRegistry")
            .field("service_count", &self.services.len())
            .finish()
    }
}
19
20impl ServiceRegistry {
21    /// Create a new empty service registry
22    pub fn new() -> Self {
23        Self {
24            services: HashMap::new(),
25        }
26    }
27
28    /// Register a service implementation
29    pub fn register<T: Any + Send + Sync>(&mut self, service: T) {
30        let type_id = TypeId::of::<T>();
31        self.services.insert(type_id, Box::new(service));
32    }
33
34    /// Get a service by type
35    pub fn get<T: Any + Send + Sync>(&self) -> Result<&T> {
36        let type_id = TypeId::of::<T>();
37
38        self.services
39            .get(&type_id)
40            .and_then(|service| service.downcast_ref::<T>())
41            .ok_or_else(|| GraphRAGError::Config {
42                message: format!("Service not registered: {}", std::any::type_name::<T>()),
43            })
44    }
45
46    /// Get a mutable service by type
47    pub fn get_mut<T: Any + Send + Sync>(&mut self) -> Result<&mut T> {
48        let type_id = TypeId::of::<T>();
49
50        self.services
51            .get_mut(&type_id)
52            .and_then(|service| service.downcast_mut::<T>())
53            .ok_or_else(|| GraphRAGError::Config {
54                message: format!("Service not registered: {}", std::any::type_name::<T>()),
55            })
56    }
57
58    /// Check if a service is registered
59    pub fn has<T: Any + Send + Sync>(&self) -> bool {
60        let type_id = TypeId::of::<T>();
61        self.services.contains_key(&type_id)
62    }
63
64    /// Remove a service
65    pub fn remove<T: Any + Send + Sync>(&mut self) -> Option<T> {
66        let type_id = TypeId::of::<T>();
67
68        self.services
69            .remove(&type_id)
70            .and_then(|service| service.downcast::<T>().ok())
71            .map(|boxed| *boxed)
72    }
73
74    /// Get the number of registered services
75    pub fn len(&self) -> usize {
76        self.services.len()
77    }
78
79    /// Check if the registry is empty
80    pub fn is_empty(&self) -> bool {
81        self.services.is_empty()
82    }
83
84    /// Clear all services
85    pub fn clear(&mut self) {
86        self.services.clear();
87    }
88}
89
90impl Default for ServiceRegistry {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96/// Builder for creating and configuring service registries
97pub struct RegistryBuilder {
98    registry: ServiceRegistry,
99}
100
101impl RegistryBuilder {
102    /// Create a new registry builder
103    pub fn new() -> Self {
104        Self {
105            registry: ServiceRegistry::new(),
106        }
107    }
108
109    /// Register a service and continue building
110    pub fn with_service<T: Any + Send + Sync>(mut self, service: T) -> Self {
111        self.registry.register(service);
112        self
113    }
114
115    /// Register a storage implementation
116    pub fn with_storage<S>(mut self, storage: S) -> Self
117    where
118        S: Storage + Any + Send + Sync,
119    {
120        self.registry.register(storage);
121        self
122    }
123
124    /// Register an embedder implementation
125    pub fn with_embedder<E>(mut self, embedder: E) -> Self
126    where
127        E: Embedder + Any + Send + Sync,
128    {
129        self.registry.register(embedder);
130        self
131    }
132
133    /// Register a vector store implementation
134    pub fn with_vector_store<V>(mut self, vector_store: V) -> Self
135    where
136        V: VectorStore + Any + Send + Sync,
137    {
138        self.registry.register(vector_store);
139        self
140    }
141
142    /// Register an entity extractor implementation
143    pub fn with_entity_extractor<E>(mut self, extractor: E) -> Self
144    where
145        E: EntityExtractor + Any + Send + Sync,
146    {
147        self.registry.register(extractor);
148        self
149    }
150
151    /// Register a retriever implementation
152    pub fn with_retriever<R>(mut self, retriever: R) -> Self
153    where
154        R: Retriever + Any + Send + Sync,
155    {
156        self.registry.register(retriever);
157        self
158    }
159
160    /// Register a language model implementation
161    pub fn with_language_model<L>(mut self, language_model: L) -> Self
162    where
163        L: LanguageModel + Any + Send + Sync,
164    {
165        self.registry.register(language_model);
166        self
167    }
168
169    /// Register a graph store implementation
170    pub fn with_graph_store<G>(mut self, graph_store: G) -> Self
171    where
172        G: GraphStore + Any + Send + Sync,
173    {
174        self.registry.register(graph_store);
175        self
176    }
177
178    /// Register a function registry implementation
179    pub fn with_function_registry<F>(mut self, function_registry: F) -> Self
180    where
181        F: FunctionRegistry + Any + Send + Sync,
182    {
183        self.registry.register(function_registry);
184        self
185    }
186
187    /// Register a metrics collector implementation
188    pub fn with_metrics_collector<M>(mut self, metrics: M) -> Self
189    where
190        M: MetricsCollector + Any + Send + Sync,
191    {
192        self.registry.register(metrics);
193        self
194    }
195
196    /// Register a serializer implementation
197    pub fn with_serializer<S>(mut self, serializer: S) -> Self
198    where
199        S: Serializer + Any + Send + Sync,
200    {
201        self.registry.register(serializer);
202        self
203    }
204
205    /// Build the final registry
206    pub fn build(self) -> ServiceRegistry {
207        self.registry
208    }
209
210    /// Create a registry with default Ollama-based services
211    #[cfg(feature = "ollama")]
212    pub fn with_ollama_defaults() -> Self {
213        #[cfg(feature = "memory-storage")]
214        use crate::storage::MemoryStorage;
215
216        let mut builder = Self::new();
217
218        #[cfg(feature = "memory-storage")]
219        {
220            builder = builder.with_storage(MemoryStorage::new());
221        }
222
223        // Add other service implementations based on available features
224        #[cfg(feature = "parallel-processing")]
225        {
226            use crate::parallel::ParallelProcessor;
227
228            // Auto-detect number of threads (0 means use default)
229            let num_threads = num_cpus::get();
230            let parallel_processor = ParallelProcessor::new(num_threads);
231            builder = builder.with_service(parallel_processor);
232        }
233
234        #[cfg(feature = "vector-hnsw")]
235        {
236            use crate::vector::VectorIndex;
237            builder = builder.with_service(VectorIndex::new());
238        }
239
240        #[cfg(feature = "caching")]
241        {
242            // Add caching services when available
243            // Note: Specific cache implementations would be added here
244        }
245        builder
246    }
247
248    /// Create a registry with memory-only services for testing
249    ///
250    /// This creates a registry with mock implementations suitable for unit testing:
251    /// - MemoryStorage for document storage
252    /// - MockEmbedder for embeddings (128-dimensional)
253    /// - MockLanguageModel for text generation
254    /// - MockVectorStore for vector similarity search
255    /// - MockRetriever for content retrieval
256    #[cfg(feature = "memory-storage")]
257    pub fn with_test_defaults() -> Self {
258        use crate::core::test_utils::{
259            MockEmbedder, MockLanguageModel, MockRetriever, MockVectorStore,
260        };
261        use crate::storage::MemoryStorage;
262
263        Self::new()
264            .with_storage(MemoryStorage::new())
265            .with_service(MockEmbedder::new(128))
266            .with_service(MockLanguageModel::new())
267            .with_service(MockVectorStore::new(128))
268            .with_service(MockRetriever::new())
269    }
270}
271
272impl Default for RegistryBuilder {
273    fn default() -> Self {
274        Self::new()
275    }
276}
277
278/// Context object that provides access to services
279#[derive(Clone)]
280pub struct ServiceContext {
281    registry: Arc<ServiceRegistry>,
282}
283
284impl ServiceContext {
285    /// Create a new service context
286    pub fn new(registry: ServiceRegistry) -> Self {
287        Self {
288            registry: Arc::new(registry),
289        }
290    }
291
292    /// Get a service by type
293    pub fn get<T: Any + Send + Sync>(&self) -> Result<&T> {
294        // Safety: This is safe because we're getting an immutable reference
295        // from an Arc, which ensures the registry stays alive
296        unsafe {
297            let ptr = self.registry.as_ref() as *const ServiceRegistry;
298            (*ptr).get::<T>()
299        }
300    }
301}
302
/// Configuration consumed by `ServiceConfig::build_registry` to decide which
/// services are created and how they are set up.
///
/// Leaving an `Option` field as `None` disables the services that need it.
#[derive(Debug, Clone)]
pub struct ServiceConfig {
    /// Base URL of the Ollama API server.
    pub ollama_base_url: Option<String>,
    /// Name of the model used to produce text embeddings.
    pub embedding_model: Option<String>,
    /// Name of the model used for text generation.
    pub language_model: Option<String>,
    /// Length of the embedding vectors.
    pub vector_dimension: Option<usize>,
    /// Lowest confidence an extracted entity must reach to be kept.
    pub entity_confidence_threshold: Option<f32>,
    /// Whether batch operations should run in parallel
    /// (`build_registry` does not currently read this flag).
    pub enable_parallel_processing: bool,
    /// Whether function-calling capabilities are enabled.
    pub enable_function_calling: bool,
    /// Whether monitoring and metrics collection are enabled.
    pub enable_monitoring: bool,
}

impl Default for ServiceConfig {
    /// Defaults targeting a local Ollama server on port 11434.
    ///
    /// Parallel processing is on; function calling and monitoring are off.
    fn default() -> Self {
        // NOTE(review): the Ollama test in this module pairs `nomic-embed-text`
        // with a dimension of 768, while this default uses 384. Confirm which
        // value matches the configured embedding model.
        let vector_dimension = Some(384);

        ServiceConfig {
            ollama_base_url: Some(String::from("http://localhost:11434")),
            embedding_model: Some(String::from("nomic-embed-text:latest")),
            language_model: Some(String::from("llama3.2:latest")),
            vector_dimension,
            entity_confidence_threshold: Some(0.7),
            enable_parallel_processing: true,
            enable_function_calling: false,
            enable_monitoring: false,
        }
    }
}
338
339impl ServiceConfig {
340    /// Create a registry builder from this configuration
341    ///
342    /// This method creates service instances based on the configuration and available features.
343    /// Services are registered in the following order:
344    ///
345    /// 1. Storage (MemoryStorage with memory-storage feature)
346    /// 2. Vector Store (when vector storage implementations are available)
347    /// 3. Embedder (when embedding providers are available)
348    /// 4. Entity Extractor (when NER models are available)
349    /// 5. Retriever (when retrieval systems are implemented)
350    /// 6. Language Model (when LLM clients are available)
351    /// 7. Metrics Collector (when monitoring is enabled)
352    ///
353    /// # Example
354    ///
355    /// ```no_run
356    /// use graphrag_core::core::registry::ServiceConfig;
357    ///
358    /// let config = ServiceConfig::default();
359    /// let registry = config.build_registry().build();
360    /// ```
361    pub fn build_registry(&self) -> RegistryBuilder {
362        let mut builder = RegistryBuilder::new();
363
364        // 1. Storage Layer
365        #[cfg(feature = "memory-storage")]
366        {
367            use crate::storage::MemoryStorage;
368            builder = builder.with_storage(MemoryStorage::new());
369        }
370
371        // 2. Vector Store
372        //
373        // Vector storage has two parallel trait hierarchies:
374        //
375        // 1. vector::store::VectorStore (local module)
376        //    - Domain-specific trait for GraphRAG vector operations
377        //    - Implemented by: MemoryVectorStore, LanceDB, Qdrant
378        //    - Used directly by retrieval and embedding systems
379        //    - Methods: store_embedding, search_similar, batch operations
380        //
381        // 2. core::traits::AsyncVectorStore (generic trait)
382        //    - Generic async interface for service registry
383        //    - Designed for dependency injection and testing
384        //    - Methods: store, search, delete, get, count, clear
385        //    - Implemented by: MockVectorStore (test_utils)
386        //
387        // Current status: Both hierarchies work independently
388        // - MemoryVectorStore works with retrieval systems ✓
389        // - MockVectorStore works with service registry ✓
390        //
391        // Future unification (optional):
392        // 2b. Vector Store (Optional)
393        // If needed, create an adapter to bridge vector::VectorStore to AsyncVectorStore.
394        // This would enable using production vector stores (LanceDB, Qdrant) through
395        // the generic registry interface.
396        #[cfg(feature = "vector-memory")]
397        {
398            if let Some(_dimension) = self.vector_dimension {
399                use crate::vector::memory_store::MemoryVectorStore;
400                let vector_store = MemoryVectorStore::new();
401                builder = builder.with_service(vector_store);
402
403                #[cfg(feature = "tracing")]
404                tracing::info!("Registered MemoryVectorStore (dimension: {})", _dimension);
405            }
406        }
407
408        // 3. Embedding Provider
409        // Create embedder based on configuration and available features
410        #[cfg(feature = "ollama")]
411        {
412            if let Some(model) = &self.embedding_model {
413                if let Some(dimension) = self.vector_dimension {
414                    use crate::core::ollama_adapters::OllamaEmbedderAdapter;
415
416                    let embedder = OllamaEmbedderAdapter::new(model.clone(), dimension);
417                    builder = builder.with_service(embedder);
418
419                    #[cfg(feature = "tracing")]
420                    tracing::info!(
421                        "Registered Ollama embedder with model: {}, dimension: {}",
422                        model,
423                        dimension
424                    );
425                }
426            }
427        }
428
429        // 4. Entity Extractor
430        // Register entity extraction service using GraphIndexer
431        #[cfg(all(feature = "async", feature = "lightrag"))]
432        {
433            if let Some(threshold) = self.entity_confidence_threshold {
434                use crate::core::entity_adapters::GraphIndexerAdapter;
435
436                // Create GraphIndexer adapter with default entity types
437                let entity_types = vec![
438                    "person".to_string(),
439                    "organization".to_string(),
440                    "location".to_string(),
441                ];
442                let extractor = GraphIndexerAdapter::new(entity_types, 3)
443                    .map(|adapter| adapter.with_confidence_threshold(threshold));
444
445                if let Ok(extractor) = extractor {
446                    builder = builder.with_service(extractor);
447
448                    #[cfg(feature = "tracing")]
449                    tracing::info!(
450                        "Registered GraphIndexer entity extractor with threshold: {}",
451                        threshold
452                    );
453                }
454            }
455        }
456
457        // 5. Retriever
458        // Register retrieval system
459        #[cfg(all(feature = "async", feature = "basic-retrieval"))]
460        {
461            use crate::config::Config;
462            use crate::core::retrieval_adapters::RetrievalSystemAdapter;
463            use crate::retrieval::RetrievalSystem;
464
465            // Create a default config for retrieval system
466            let config = Config::default();
467            if let Ok(system) = RetrievalSystem::new(&config) {
468                let retriever = RetrievalSystemAdapter::new(system);
469                builder = builder.with_service(retriever);
470
471                #[cfg(feature = "tracing")]
472                tracing::info!("Registered RetrievalSystem");
473            }
474        }
475
476        // 6. Language Model
477        // Register LLM client for text generation
478        #[cfg(feature = "ollama")]
479        {
480            if let (Some(base_url), Some(model)) = (&self.ollama_base_url, &self.language_model) {
481                use crate::core::ollama_adapters::OllamaLanguageModelAdapter;
482                use crate::ollama::OllamaConfig;
483
484                // Build OllamaConfig from ServiceConfig
485                let mut ollama_config = OllamaConfig::default();
486                // Parse host and port from base_url (format: "http://localhost:11434")
487                if let Some(url_parts) = base_url.split("://").nth(1) {
488                    let parts: Vec<&str> = url_parts.split(':').collect();
489                    if parts.len() >= 2 {
490                        ollama_config.host = format!("http://{}", parts[0]);
491                        if let Ok(port) = parts[1].parse::<u16>() {
492                            ollama_config.port = port;
493                        }
494                    }
495                }
496                ollama_config.chat_model = model.clone();
497                ollama_config.enabled = true;
498
499                let language_model = OllamaLanguageModelAdapter::new(ollama_config);
500                builder = builder.with_service(language_model);
501
502                #[cfg(feature = "tracing")]
503                tracing::info!(
504                    "Registered Ollama language model: {} at {}",
505                    model,
506                    base_url
507                );
508            }
509        }
510
511        // 7. Metrics Collector
512        // Register metrics collector when monitoring is enabled
513        #[cfg(all(feature = "monitoring", feature = "dashmap"))]
514        {
515            if self.enable_monitoring {
516                use crate::monitoring::MetricsCollector;
517
518                let metrics = MetricsCollector::new();
519                builder = builder.with_service(metrics);
520
521                #[cfg(feature = "tracing")]
522                tracing::info!("Registered MetricsCollector");
523            }
524        }
525
526        // 8. Function Registry
527        // Register function calling capabilities when enabled
528        //
529        // Note: The function_calling module provides a comprehensive FunctionCaller
530        // implementation with the following characteristics:
531        //
532        // - Requires KnowledgeGraph context for function execution
533        // - Uses json::JsonValue (json crate) instead of serde_json::Value
534        // - Provides synchronous call() methods, not async
535        // - Includes built-in function history and statistics
536        // - Supports complex function orchestration with context passing
537        //
538        // Creating an adapter for AsyncFunctionRegistry would require:
539        // 1. JSON format conversion (json::JsonValue <-> serde_json::Value)
540        // 2. Async wrapper around synchronous call methods
541        // 3. KnowledgeGraph injection mechanism (currently passed per-call)
542        // 4. Context state management for stateless async trait
543        //
544        // For applications needing function calling:
545        // - Use FunctionCaller directly from function_calling module
546        // - It provides richer functionality than the generic AsyncFunctionRegistry trait
547        // - Built-in support for GraphRAG-specific operations
548        //
549        // The AsyncFunctionRegistry trait is better suited for simpler,
550        // stateless function registries without graph context requirements.
551        #[cfg(feature = "function-calling")]
552        {
553            if self.enable_function_calling {
554                #[cfg(feature = "tracing")]
555                tracing::info!(
556                    "Function calling enabled - use function_calling::FunctionCaller directly"
557                );
558            }
559        }
560
561        builder
562    }
563}
564
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug)]
    struct TestService {
        value: String,
    }

    impl TestService {
        fn new(value: String) -> Self {
            Self { value }
        }
    }

    #[test]
    fn test_registry_basic_operations() {
        let mut registry = ServiceRegistry::new();

        // Test registration
        registry.register(TestService::new("test".to_string()));
        assert!(registry.has::<TestService>());
        assert_eq!(registry.len(), 1);

        // Test retrieval
        let service = registry.get::<TestService>().unwrap();
        assert_eq!(service.value, "test");

        // Test removal
        let removed = registry.remove::<TestService>().unwrap();
        assert_eq!(removed.value, "test");
        assert!(!registry.has::<TestService>());
        assert!(registry.is_empty());
    }

    #[test]
    fn test_unregistered_service_lookups() {
        let mut registry = ServiceRegistry::new();
        assert!(registry.get::<TestService>().is_err());
        assert!(registry.get_mut::<TestService>().is_err());
        assert!(registry.remove::<TestService>().is_none());
        assert!(!registry.has::<TestService>());
    }

    #[test]
    fn test_get_mut_replace_and_clear() {
        let mut registry = ServiceRegistry::new();
        registry.register(TestService::new("first".to_string()));

        registry
            .get_mut::<TestService>()
            .unwrap()
            .value
            .push_str("-edited");
        assert_eq!(registry.get::<TestService>().unwrap().value, "first-edited");

        // Registering the same type again replaces the previous instance.
        registry.register(TestService::new("second".to_string()));
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.get::<TestService>().unwrap().value, "second");

        registry.clear();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_registry_builder() {
        let registry = RegistryBuilder::new()
            .with_service(TestService::new("builder".to_string()))
            .build();

        assert!(registry.has::<TestService>());
        let service = registry.get::<TestService>().unwrap();
        assert_eq!(service.value, "builder");
    }

    #[test]
    fn test_service_context() {
        let mut registry = ServiceRegistry::new();
        registry.register(TestService::new("context".to_string()));

        let context = ServiceContext::new(registry);
        let service = context.get::<TestService>().unwrap();
        assert_eq!(service.value, "context");

        // Lookups for types that were never registered fail.
        assert!(context.get::<String>().is_err());

        // Test cloning
        let cloned_context = context.clone();
        let service2 = cloned_context.get::<TestService>().unwrap();
        assert_eq!(service2.value, "context");
    }

    #[test]
    fn test_service_config_default() {
        let config = ServiceConfig::default();
        assert!(config.ollama_base_url.is_some());
        assert!(config.embedding_model.is_some());
        assert!(config.language_model.is_some());
        assert!(config.vector_dimension.is_some());
        assert!(config.entity_confidence_threshold.is_some());
        assert!(config.enable_parallel_processing);
        assert!(!config.enable_function_calling);
        assert!(!config.enable_monitoring);
    }

    #[test]
    #[cfg(feature = "ollama")]
    fn test_service_config_build_with_ollama() {
        let config = ServiceConfig {
            ollama_base_url: Some("http://localhost:11434".to_string()),
            embedding_model: Some("nomic-embed-text".to_string()),
            language_model: Some("llama3.2".to_string()),
            vector_dimension: Some(768),
            entity_confidence_threshold: Some(0.7),
            enable_parallel_processing: true,
            enable_function_calling: false,
            enable_monitoring: false,
        };

        let registry = config.build_registry().build();

        // Verify services are registered
        #[cfg(feature = "memory-storage")]
        {
            use crate::storage::MemoryStorage;
            assert!(registry.has::<MemoryStorage>());
        }

        // OllamaEmbedderAdapter and OllamaLanguageModelAdapter are not checked
        // by type here. A non-empty registry shows that registration ran.
        assert!(!registry.is_empty());
    }

    #[test]
    #[cfg(feature = "vector-memory")]
    fn test_registry_with_vector_memory() {
        use crate::vector::memory_store::MemoryVectorStore;

        let config = ServiceConfig {
            ollama_base_url: None,
            embedding_model: None,
            language_model: None,
            vector_dimension: Some(384), // Set vector dimension to enable MemoryVectorStore
            entity_confidence_threshold: None,
            enable_parallel_processing: false,
            enable_function_calling: false,
            enable_monitoring: false,
        };

        let registry = config.build_registry().build();

        // When vector-memory feature is enabled and vector_dimension is set,
        // MemoryVectorStore should be registered
        assert!(
            registry.has::<MemoryVectorStore>(),
            "MemoryVectorStore should be registered when vector-memory feature is enabled"
        );

        // Verify we can retrieve it
        let vector_store = registry.get::<MemoryVectorStore>();
        assert!(
            vector_store.is_ok(),
            "Should be able to retrieve registered MemoryVectorStore"
        );
    }

    #[test]
    #[cfg(not(feature = "vector-memory"))]
    fn test_registry_without_vector_memory() {
        let config = ServiceConfig {
            ollama_base_url: None,
            embedding_model: None,
            language_model: None,
            vector_dimension: Some(384), // Even with dimension set...
            entity_confidence_threshold: None,
            enable_parallel_processing: false,
            enable_function_calling: false,
            enable_monitoring: false,
        };

        // With vector-memory disabled, MemoryVectorStore cannot be named here, so
        // its absence cannot be asserted directly. This only checks that building
        // succeeds with the #[cfg] gate off. The leading underscore avoids an
        // unused-variable warning.
        let _registry = config.build_registry().build();
    }
}