Skip to main content

graphrag_core/core/
registry.rs

//! Service registry for dependency injection
//!
//! This module provides a dependency injection system that allows
//! components to be swapped out for testing or different implementations.
5
6use crate::core::traits::*;
7use crate::core::{GraphRAGError, Result};
8use std::any::{Any, TypeId};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// Type-erased service container
13type ServiceBox = Box<dyn Any + Send + Sync>;
14
15/// Service registry for dependency injection
16pub struct ServiceRegistry {
17    services: HashMap<TypeId, ServiceBox>,
18}
19
20impl ServiceRegistry {
21    /// Create a new empty service registry
22    pub fn new() -> Self {
23        Self {
24            services: HashMap::new(),
25        }
26    }
27
28    /// Register a service implementation
29    pub fn register<T: Any + Send + Sync>(&mut self, service: T) {
30        let type_id = TypeId::of::<T>();
31        self.services.insert(type_id, Box::new(service));
32    }
33
34    /// Get a service by type
35    pub fn get<T: Any + Send + Sync>(&self) -> Result<&T> {
36        let type_id = TypeId::of::<T>();
37
38        self.services
39            .get(&type_id)
40            .and_then(|service| service.downcast_ref::<T>())
41            .ok_or_else(|| GraphRAGError::Config {
42                message: format!("Service not registered: {}", std::any::type_name::<T>()),
43            })
44    }
45
46    /// Get a mutable service by type
47    pub fn get_mut<T: Any + Send + Sync>(&mut self) -> Result<&mut T> {
48        let type_id = TypeId::of::<T>();
49
50        self.services
51            .get_mut(&type_id)
52            .and_then(|service| service.downcast_mut::<T>())
53            .ok_or_else(|| GraphRAGError::Config {
54                message: format!("Service not registered: {}", std::any::type_name::<T>()),
55            })
56    }
57
58    /// Check if a service is registered
59    pub fn has<T: Any + Send + Sync>(&self) -> bool {
60        let type_id = TypeId::of::<T>();
61        self.services.contains_key(&type_id)
62    }
63
64    /// Remove a service
65    pub fn remove<T: Any + Send + Sync>(&mut self) -> Option<T> {
66        let type_id = TypeId::of::<T>();
67
68        self.services
69            .remove(&type_id)
70            .and_then(|service| service.downcast::<T>().ok())
71            .map(|boxed| *boxed)
72    }
73
74    /// Get the number of registered services
75    pub fn len(&self) -> usize {
76        self.services.len()
77    }
78
79    /// Check if the registry is empty
80    pub fn is_empty(&self) -> bool {
81        self.services.is_empty()
82    }
83
84    /// Clear all services
85    pub fn clear(&mut self) {
86        self.services.clear();
87    }
88}
89
90impl Default for ServiceRegistry {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96/// Builder for creating and configuring service registries
97pub struct RegistryBuilder {
98    registry: ServiceRegistry,
99}
100
101impl RegistryBuilder {
102    /// Create a new registry builder
103    pub fn new() -> Self {
104        Self {
105            registry: ServiceRegistry::new(),
106        }
107    }
108
109    /// Register a service and continue building
110    pub fn with_service<T: Any + Send + Sync>(mut self, service: T) -> Self {
111        self.registry.register(service);
112        self
113    }
114
115    /// Register a storage implementation
116    pub fn with_storage<S>(mut self, storage: S) -> Self
117    where
118        S: Storage + Any + Send + Sync,
119    {
120        self.registry.register(storage);
121        self
122    }
123
124    /// Register an embedder implementation
125    pub fn with_embedder<E>(mut self, embedder: E) -> Self
126    where
127        E: Embedder + Any + Send + Sync,
128    {
129        self.registry.register(embedder);
130        self
131    }
132
133    /// Register a vector store implementation
134    pub fn with_vector_store<V>(mut self, vector_store: V) -> Self
135    where
136        V: VectorStore + Any + Send + Sync,
137    {
138        self.registry.register(vector_store);
139        self
140    }
141
142    /// Register an entity extractor implementation
143    pub fn with_entity_extractor<E>(mut self, extractor: E) -> Self
144    where
145        E: EntityExtractor + Any + Send + Sync,
146    {
147        self.registry.register(extractor);
148        self
149    }
150
151    /// Register a retriever implementation
152    pub fn with_retriever<R>(mut self, retriever: R) -> Self
153    where
154        R: Retriever + Any + Send + Sync,
155    {
156        self.registry.register(retriever);
157        self
158    }
159
160    /// Register a language model implementation
161    pub fn with_language_model<L>(mut self, language_model: L) -> Self
162    where
163        L: LanguageModel + Any + Send + Sync,
164    {
165        self.registry.register(language_model);
166        self
167    }
168
169    /// Register a graph store implementation
170    pub fn with_graph_store<G>(mut self, graph_store: G) -> Self
171    where
172        G: GraphStore + Any + Send + Sync,
173    {
174        self.registry.register(graph_store);
175        self
176    }
177
178    /// Register a function registry implementation
179    pub fn with_function_registry<F>(mut self, function_registry: F) -> Self
180    where
181        F: FunctionRegistry + Any + Send + Sync,
182    {
183        self.registry.register(function_registry);
184        self
185    }
186
187    /// Register a metrics collector implementation
188    pub fn with_metrics_collector<M>(mut self, metrics: M) -> Self
189    where
190        M: MetricsCollector + Any + Send + Sync,
191    {
192        self.registry.register(metrics);
193        self
194    }
195
196    /// Register a serializer implementation
197    pub fn with_serializer<S>(mut self, serializer: S) -> Self
198    where
199        S: Serializer + Any + Send + Sync,
200    {
201        self.registry.register(serializer);
202        self
203    }
204
205    /// Build the final registry
206    pub fn build(self) -> ServiceRegistry {
207        self.registry
208    }
209
210    /// Create a registry with default Ollama-based services
211    #[cfg(feature = "ollama")]
212    pub fn with_ollama_defaults() -> Self {
213        #[cfg(feature = "memory-storage")]
214        use crate::storage::MemoryStorage;
215
216        let mut builder = Self::new();
217
218        #[cfg(feature = "memory-storage")]
219        {
220            builder = builder.with_storage(MemoryStorage::new());
221        }
222
223        // Add other service implementations based on available features
224        #[cfg(feature = "parallel-processing")]
225        {
226            use crate::parallel::ParallelProcessor;
227
228            // Auto-detect number of threads (0 means use default)
229            let num_threads = num_cpus::get();
230            let parallel_processor = ParallelProcessor::new(num_threads);
231            builder = builder.with_service(parallel_processor);
232        }
233
234        #[cfg(feature = "vector-hnsw")]
235        {
236            use crate::vector::VectorIndex;
237            builder = builder.with_service(VectorIndex::new());
238        }
239
240        #[cfg(feature = "caching")]
241        {
242            // Add caching services when available
243            // Note: Specific cache implementations would be added here
244        }
245        builder
246    }
247
248    /// Create a registry with memory-only services for testing
249    #[cfg(feature = "memory-storage")]
250    pub fn with_test_defaults() -> Self {
251        use crate::storage::MemoryStorage;
252        // TODO: Add mock implementations when test_utils module is created
253
254        Self::new().with_storage(MemoryStorage::new())
255    }
256}
257
258impl Default for RegistryBuilder {
259    fn default() -> Self {
260        Self::new()
261    }
262}
263
264/// Context object that provides access to services
265#[derive(Clone)]
266pub struct ServiceContext {
267    registry: Arc<ServiceRegistry>,
268}
269
270impl ServiceContext {
271    /// Create a new service context
272    pub fn new(registry: ServiceRegistry) -> Self {
273        Self {
274            registry: Arc::new(registry),
275        }
276    }
277
278    /// Get a service by type
279    pub fn get<T: Any + Send + Sync>(&self) -> Result<&T> {
280        // Safety: This is safe because we're getting an immutable reference
281        // from an Arc, which ensures the registry stays alive
282        unsafe {
283            let ptr = self.registry.as_ref() as *const ServiceRegistry;
284            (*ptr).get::<T>()
285        }
286    }
287}
288
/// Settings that describe how services should be created.
#[derive(Debug, Clone)]
pub struct ServiceConfig {
    /// Base URL of the Ollama API server, if one is configured.
    pub ollama_base_url: Option<String>,
    /// Name of the model used to produce text embeddings.
    pub embedding_model: Option<String>,
    /// Name of the model used for text generation.
    pub language_model: Option<String>,
    /// Number of dimensions in an embedding vector.
    pub vector_dimension: Option<usize>,
    /// Lowest confidence accepted during entity extraction.
    pub entity_confidence_threshold: Option<f32>,
    /// Whether batch operations may run in parallel.
    pub enable_parallel_processing: bool,
    /// Whether function calling is turned on.
    pub enable_function_calling: bool,
    /// Whether monitoring and metrics collection are turned on.
    pub enable_monitoring: bool,
}

impl Default for ServiceConfig {
    /// Returns the out-of-the-box configuration: a local Ollama endpoint with
    /// default model names, parallel processing on, and both function calling
    /// and monitoring off.
    fn default() -> Self {
        let ollama_base_url = Some(String::from("http://localhost:11434"));
        let embedding_model = Some(String::from("nomic-embed-text:latest"));
        let language_model = Some(String::from("llama3.2:latest"));

        Self {
            ollama_base_url,
            embedding_model,
            language_model,
            // NOTE(review): confirm that 384 matches the vector size produced
            // by the default embedding model named above.
            vector_dimension: Some(384),
            entity_confidence_threshold: Some(0.7),
            enable_parallel_processing: true,
            enable_function_calling: false,
            enable_monitoring: false,
        }
    }
}
324
325impl ServiceConfig {
326    /// Create a registry builder from this configuration
327    pub fn build_registry(&self) -> RegistryBuilder {
328        let mut builder = RegistryBuilder::new();
329
330        #[cfg(feature = "memory-storage")]
331        {
332            use crate::storage::MemoryStorage;
333            builder = builder.with_storage(MemoryStorage::new());
334        }
335
336        // TODO: Add other service implementations when they're available
337
338        builder
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal service type used to exercise the registry.
    #[derive(Debug)]
    struct TestService {
        value: String,
    }

    impl TestService {
        fn new(value: String) -> Self {
            TestService { value }
        }
    }

    /// Registers, looks up and removes a service, checking the bookkeeping
    /// after each step.
    #[test]
    fn test_registry_basic_operations() {
        let mut services = ServiceRegistry::new();

        // Registration makes the type visible and counted.
        services.register(TestService::new(String::from("test")));
        assert!(services.has::<TestService>());
        assert_eq!(services.len(), 1);

        // Lookup returns the registered value.
        assert_eq!(services.get::<TestService>().unwrap().value, "test");

        // Removal hands the value back and empties the registry.
        let taken = services.remove::<TestService>().unwrap();
        assert_eq!(taken.value, "test");
        assert!(!services.has::<TestService>());
        assert!(services.is_empty());
    }

    /// A service added through the builder is present in the built registry.
    #[test]
    fn test_registry_builder() {
        let builder = RegistryBuilder::new().with_service(TestService::new(String::from("builder")));
        let services = builder.build();

        assert!(services.has::<TestService>());
        assert_eq!(services.get::<TestService>().unwrap().value, "builder");
    }

    /// A context and its clone both resolve the same registered service.
    #[test]
    fn test_service_context() {
        let mut services = ServiceRegistry::new();
        services.register(TestService::new(String::from("context")));

        let original = ServiceContext::new(services);
        assert_eq!(original.get::<TestService>().unwrap().value, "context");

        // A cloned context must see the same service.
        let duplicate = original.clone();
        assert_eq!(duplicate.get::<TestService>().unwrap().value, "context");
    }

    /// The default configuration fills every optional field and enables
    /// parallel processing.
    #[test]
    fn test_service_config_default() {
        let defaults = ServiceConfig::default();

        assert!(defaults.ollama_base_url.is_some());
        assert!(defaults.embedding_model.is_some());
        assert!(defaults.language_model.is_some());
        assert!(defaults.vector_dimension.is_some());
        assert!(defaults.entity_confidence_threshold.is_some());
        assert!(defaults.enable_parallel_processing);
    }
}