1use crate::core::traits::*;
7use crate::core::{GraphRAGError, Result};
8use std::any::{Any, TypeId};
9use std::collections::HashMap;
10use std::sync::Arc;
11
/// Type-erased slot holding one registered service.
///
/// Values are stored as `dyn Any` so a single map can hold services of
/// unrelated concrete types; the `Send + Sync` bounds keep the registry itself
/// shareable across threads (it is wrapped in an `Arc` by `ServiceContext`).
type ServiceBox = Box<dyn Any + Sync + Send>;
14
15pub struct ServiceRegistry {
17 services: HashMap<TypeId, ServiceBox>,
18}
19
20impl ServiceRegistry {
21 pub fn new() -> Self {
23 Self {
24 services: HashMap::new(),
25 }
26 }
27
28 pub fn register<T: Any + Send + Sync>(&mut self, service: T) {
30 let type_id = TypeId::of::<T>();
31 self.services.insert(type_id, Box::new(service));
32 }
33
34 pub fn get<T: Any + Send + Sync>(&self) -> Result<&T> {
36 let type_id = TypeId::of::<T>();
37
38 self.services
39 .get(&type_id)
40 .and_then(|service| service.downcast_ref::<T>())
41 .ok_or_else(|| GraphRAGError::Config {
42 message: format!("Service not registered: {}", std::any::type_name::<T>()),
43 })
44 }
45
46 pub fn get_mut<T: Any + Send + Sync>(&mut self) -> Result<&mut T> {
48 let type_id = TypeId::of::<T>();
49
50 self.services
51 .get_mut(&type_id)
52 .and_then(|service| service.downcast_mut::<T>())
53 .ok_or_else(|| GraphRAGError::Config {
54 message: format!("Service not registered: {}", std::any::type_name::<T>()),
55 })
56 }
57
58 pub fn has<T: Any + Send + Sync>(&self) -> bool {
60 let type_id = TypeId::of::<T>();
61 self.services.contains_key(&type_id)
62 }
63
64 pub fn remove<T: Any + Send + Sync>(&mut self) -> Option<T> {
66 let type_id = TypeId::of::<T>();
67
68 self.services
69 .remove(&type_id)
70 .and_then(|service| service.downcast::<T>().ok())
71 .map(|boxed| *boxed)
72 }
73
74 pub fn len(&self) -> usize {
76 self.services.len()
77 }
78
79 pub fn is_empty(&self) -> bool {
81 self.services.is_empty()
82 }
83
84 pub fn clear(&mut self) {
86 self.services.clear();
87 }
88}
89
90impl Default for ServiceRegistry {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
96pub struct RegistryBuilder {
98 registry: ServiceRegistry,
99}
100
101impl RegistryBuilder {
102 pub fn new() -> Self {
104 Self {
105 registry: ServiceRegistry::new(),
106 }
107 }
108
109 pub fn with_service<T: Any + Send + Sync>(mut self, service: T) -> Self {
111 self.registry.register(service);
112 self
113 }
114
115 pub fn with_storage<S>(mut self, storage: S) -> Self
117 where
118 S: Storage + Any + Send + Sync,
119 {
120 self.registry.register(storage);
121 self
122 }
123
124 pub fn with_embedder<E>(mut self, embedder: E) -> Self
126 where
127 E: Embedder + Any + Send + Sync,
128 {
129 self.registry.register(embedder);
130 self
131 }
132
133 pub fn with_vector_store<V>(mut self, vector_store: V) -> Self
135 where
136 V: VectorStore + Any + Send + Sync,
137 {
138 self.registry.register(vector_store);
139 self
140 }
141
142 pub fn with_entity_extractor<E>(mut self, extractor: E) -> Self
144 where
145 E: EntityExtractor + Any + Send + Sync,
146 {
147 self.registry.register(extractor);
148 self
149 }
150
151 pub fn with_retriever<R>(mut self, retriever: R) -> Self
153 where
154 R: Retriever + Any + Send + Sync,
155 {
156 self.registry.register(retriever);
157 self
158 }
159
160 pub fn with_language_model<L>(mut self, language_model: L) -> Self
162 where
163 L: LanguageModel + Any + Send + Sync,
164 {
165 self.registry.register(language_model);
166 self
167 }
168
169 pub fn with_graph_store<G>(mut self, graph_store: G) -> Self
171 where
172 G: GraphStore + Any + Send + Sync,
173 {
174 self.registry.register(graph_store);
175 self
176 }
177
178 pub fn with_function_registry<F>(mut self, function_registry: F) -> Self
180 where
181 F: FunctionRegistry + Any + Send + Sync,
182 {
183 self.registry.register(function_registry);
184 self
185 }
186
187 pub fn with_metrics_collector<M>(mut self, metrics: M) -> Self
189 where
190 M: MetricsCollector + Any + Send + Sync,
191 {
192 self.registry.register(metrics);
193 self
194 }
195
196 pub fn with_serializer<S>(mut self, serializer: S) -> Self
198 where
199 S: Serializer + Any + Send + Sync,
200 {
201 self.registry.register(serializer);
202 self
203 }
204
205 pub fn build(self) -> ServiceRegistry {
207 self.registry
208 }
209
210 #[cfg(feature = "ollama")]
212 pub fn with_ollama_defaults() -> Self {
213 #[cfg(feature = "memory-storage")]
214 use crate::storage::MemoryStorage;
215
216 let mut builder = Self::new();
217
218 #[cfg(feature = "memory-storage")]
219 {
220 builder = builder.with_storage(MemoryStorage::new());
221 }
222
223 #[cfg(feature = "parallel-processing")]
225 {
226 use crate::parallel::ParallelProcessor;
227
228 let num_threads = num_cpus::get();
230 let parallel_processor = ParallelProcessor::new(num_threads);
231 builder = builder.with_service(parallel_processor);
232 }
233
234 #[cfg(feature = "vector-hnsw")]
235 {
236 use crate::vector::VectorIndex;
237 builder = builder.with_service(VectorIndex::new());
238 }
239
240 #[cfg(feature = "caching")]
241 {
242 }
245 builder
246 }
247
248 #[cfg(feature = "memory-storage")]
257 pub fn with_test_defaults() -> Self {
258 use crate::core::test_utils::{
259 MockEmbedder, MockLanguageModel, MockRetriever, MockVectorStore,
260 };
261 use crate::storage::MemoryStorage;
262
263 Self::new()
264 .with_storage(MemoryStorage::new())
265 .with_service(MockEmbedder::new(128))
266 .with_service(MockLanguageModel::new())
267 .with_service(MockVectorStore::new(128))
268 .with_service(MockRetriever::new())
269 }
270}
271
272impl Default for RegistryBuilder {
273 fn default() -> Self {
274 Self::new()
275 }
276}
277
278#[derive(Clone)]
280pub struct ServiceContext {
281 registry: Arc<ServiceRegistry>,
282}
283
284impl ServiceContext {
285 pub fn new(registry: ServiceRegistry) -> Self {
287 Self {
288 registry: Arc::new(registry),
289 }
290 }
291
292 pub fn get<T: Any + Send + Sync>(&self) -> Result<&T> {
294 unsafe {
297 let ptr = self.registry.as_ref() as *const ServiceRegistry;
298 (*ptr).get::<T>()
299 }
300 }
301}
302
/// User-facing settings consumed by [`ServiceConfig::build_registry`].
///
/// Most fields only take effect when the corresponding cargo feature is
/// enabled; a `None` value skips registration of the service it configures.
#[derive(Debug, Clone)]
pub struct ServiceConfig {
    /// Ollama server URL (e.g. `http://localhost:11434`); split into host and
    /// port for the Ollama language model adapter.
    pub ollama_base_url: Option<String>,
    /// Ollama embedding model; only used together with `vector_dimension`.
    pub embedding_model: Option<String>,
    /// Ollama chat model; only used together with `ollama_base_url`.
    pub language_model: Option<String>,
    /// Embedding dimension handed to the Ollama embedder; `Some` also enables
    /// the in-memory vector store under the `vector-memory` feature.
    pub vector_dimension: Option<usize>,
    /// Confidence threshold applied to the GraphIndexer entity extractor.
    pub entity_confidence_threshold: Option<f32>,
    /// NOTE(review): not read by `build_registry` in this file.
    pub enable_parallel_processing: bool,
    /// Under the `function-calling` feature this only produces a log message
    /// in `build_registry`; no service is registered.
    pub enable_function_calling: bool,
    /// Registers a `MetricsCollector` (`monitoring` + `dashmap` features).
    pub enable_monitoring: bool,
}
323
324impl Default for ServiceConfig {
325 fn default() -> Self {
326 Self {
327 ollama_base_url: Some("http://localhost:11434".to_string()),
328 embedding_model: Some("nomic-embed-text:latest".to_string()),
329 language_model: Some("llama3.2:latest".to_string()),
330 vector_dimension: Some(384),
331 entity_confidence_threshold: Some(0.7),
332 enable_parallel_processing: true,
333 enable_function_calling: false,
334 enable_monitoring: false,
335 }
336 }
337}
338
339impl ServiceConfig {
340 pub fn build_registry(&self) -> RegistryBuilder {
362 let mut builder = RegistryBuilder::new();
363
364 #[cfg(feature = "memory-storage")]
366 {
367 use crate::storage::MemoryStorage;
368 builder = builder.with_storage(MemoryStorage::new());
369 }
370
371 #[cfg(feature = "vector-memory")]
397 {
398 if let Some(_dimension) = self.vector_dimension {
399 use crate::vector::memory_store::MemoryVectorStore;
400 let vector_store = MemoryVectorStore::new();
401 builder = builder.with_service(vector_store);
402
403 #[cfg(feature = "tracing")]
404 tracing::info!("Registered MemoryVectorStore (dimension: {})", _dimension);
405 }
406 }
407
408 #[cfg(feature = "ollama")]
411 {
412 if let Some(model) = &self.embedding_model {
413 if let Some(dimension) = self.vector_dimension {
414 use crate::core::ollama_adapters::OllamaEmbedderAdapter;
415
416 let embedder = OllamaEmbedderAdapter::new(model.clone(), dimension);
417 builder = builder.with_service(embedder);
418
419 #[cfg(feature = "tracing")]
420 tracing::info!(
421 "Registered Ollama embedder with model: {}, dimension: {}",
422 model,
423 dimension
424 );
425 }
426 }
427 }
428
429 #[cfg(all(feature = "async", feature = "lightrag"))]
432 {
433 if let Some(threshold) = self.entity_confidence_threshold {
434 use crate::core::entity_adapters::GraphIndexerAdapter;
435
436 let entity_types = vec![
438 "person".to_string(),
439 "organization".to_string(),
440 "location".to_string(),
441 ];
442 let extractor = GraphIndexerAdapter::new(entity_types, 3)
443 .map(|adapter| adapter.with_confidence_threshold(threshold));
444
445 if let Ok(extractor) = extractor {
446 builder = builder.with_service(extractor);
447
448 #[cfg(feature = "tracing")]
449 tracing::info!(
450 "Registered GraphIndexer entity extractor with threshold: {}",
451 threshold
452 );
453 }
454 }
455 }
456
457 #[cfg(all(feature = "async", feature = "basic-retrieval"))]
460 {
461 use crate::config::Config;
462 use crate::core::retrieval_adapters::RetrievalSystemAdapter;
463 use crate::retrieval::RetrievalSystem;
464
465 let config = Config::default();
467 if let Ok(system) = RetrievalSystem::new(&config) {
468 let retriever = RetrievalSystemAdapter::new(system);
469 builder = builder.with_service(retriever);
470
471 #[cfg(feature = "tracing")]
472 tracing::info!("Registered RetrievalSystem");
473 }
474 }
475
476 #[cfg(feature = "ollama")]
479 {
480 if let (Some(base_url), Some(model)) = (&self.ollama_base_url, &self.language_model) {
481 use crate::core::ollama_adapters::OllamaLanguageModelAdapter;
482 use crate::ollama::OllamaConfig;
483
484 let mut ollama_config = OllamaConfig::default();
486 if let Some(url_parts) = base_url.split("://").nth(1) {
488 let parts: Vec<&str> = url_parts.split(':').collect();
489 if parts.len() >= 2 {
490 ollama_config.host = format!("http://{}", parts[0]);
491 if let Ok(port) = parts[1].parse::<u16>() {
492 ollama_config.port = port;
493 }
494 }
495 }
496 ollama_config.chat_model = model.clone();
497 ollama_config.enabled = true;
498
499 let language_model = OllamaLanguageModelAdapter::new(ollama_config);
500 builder = builder.with_service(language_model);
501
502 #[cfg(feature = "tracing")]
503 tracing::info!(
504 "Registered Ollama language model: {} at {}",
505 model,
506 base_url
507 );
508 }
509 }
510
511 #[cfg(all(feature = "monitoring", feature = "dashmap"))]
514 {
515 if self.enable_monitoring {
516 use crate::monitoring::MetricsCollector;
517
518 let metrics = MetricsCollector::new();
519 builder = builder.with_service(metrics);
520
521 #[cfg(feature = "tracing")]
522 tracing::info!("Registered MetricsCollector");
523 }
524 }
525
526 #[cfg(feature = "function-calling")]
552 {
553 if self.enable_function_calling {
554 #[cfg(feature = "tracing")]
555 tracing::info!(
556 "Function calling enabled - use function_calling::FunctionCaller directly"
557 );
558 }
559 }
560
561 builder
562 }
563}
564
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug)]
    struct TestService {
        value: String,
    }

    impl TestService {
        fn new(value: String) -> Self {
            Self { value }
        }
    }

    /// Config with every optional service input disabled except `vector_dimension`.
    fn minimal_config(vector_dimension: Option<usize>) -> ServiceConfig {
        ServiceConfig {
            ollama_base_url: None,
            embedding_model: None,
            language_model: None,
            vector_dimension,
            entity_confidence_threshold: None,
            enable_parallel_processing: false,
            enable_function_calling: false,
            enable_monitoring: false,
        }
    }

    #[test]
    fn test_registry_basic_operations() {
        let mut registry = ServiceRegistry::new();

        registry.register(TestService::new("test".to_string()));
        assert!(registry.has::<TestService>());
        assert_eq!(registry.len(), 1);

        let service = registry.get::<TestService>().unwrap();
        assert_eq!(service.value, "test");

        let removed = registry.remove::<TestService>().unwrap();
        assert_eq!(removed.value, "test");
        assert!(!registry.has::<TestService>());
        assert!(registry.is_empty());
    }

    #[test]
    fn test_registry_get_missing_service_is_config_error() {
        let registry = ServiceRegistry::new();
        assert!(!registry.has::<TestService>());

        match registry.get::<TestService>() {
            Err(GraphRAGError::Config { message }) => {
                assert!(message.starts_with("Service not registered: "));
                assert!(message.contains("TestService"));
            }
            _ => panic!("expected a Config error for an unregistered service"),
        }
    }

    #[test]
    fn test_registry_get_mut_and_replace() {
        let mut registry = ServiceRegistry::new();
        registry.register(TestService::new("first".to_string()));

        registry
            .get_mut::<TestService>()
            .unwrap()
            .value
            .push_str("-edited");
        assert_eq!(registry.get::<TestService>().unwrap().value, "first-edited");

        // Registering the same type again replaces the previous instance.
        registry.register(TestService::new("second".to_string()));
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.get::<TestService>().unwrap().value, "second");

        assert!(registry.get_mut::<u32>().is_err());
    }

    #[test]
    fn test_registry_remove_missing_and_clear() {
        let mut registry = ServiceRegistry::default();
        assert!(registry.remove::<TestService>().is_none());

        registry.register(TestService::new("a".to_string()));
        registry.register(42_u32);
        assert_eq!(registry.len(), 2);

        registry.clear();
        assert!(registry.is_empty());
        assert!(!registry.has::<u32>());
    }

    #[test]
    fn test_registry_builder() {
        let registry = RegistryBuilder::new()
            .with_service(TestService::new("builder".to_string()))
            .build();

        assert!(registry.has::<TestService>());
        let service = registry.get::<TestService>().unwrap();
        assert_eq!(service.value, "builder");
    }

    #[test]
    fn test_service_context() {
        let mut registry = ServiceRegistry::new();
        registry.register(TestService::new("context".to_string()));

        let context = ServiceContext::new(registry);
        let service = context.get::<TestService>().unwrap();
        assert_eq!(service.value, "context");

        // Clones share the same registry.
        let cloned_context = context.clone();
        let service2 = cloned_context.get::<TestService>().unwrap();
        assert_eq!(service2.value, "context");

        assert!(context.get::<u32>().is_err());
    }

    #[test]
    fn test_service_config_default() {
        let config = ServiceConfig::default();
        assert!(config.ollama_base_url.is_some());
        assert!(config.embedding_model.is_some());
        assert!(config.language_model.is_some());
        assert!(config.vector_dimension.is_some());
        assert!(config.entity_confidence_threshold.is_some());
        assert!(config.enable_parallel_processing);
    }

    #[test]
    #[cfg(feature = "ollama")]
    fn test_service_config_build_with_ollama() {
        let config = ServiceConfig {
            ollama_base_url: Some("http://localhost:11434".to_string()),
            embedding_model: Some("nomic-embed-text".to_string()),
            language_model: Some("llama3.2".to_string()),
            vector_dimension: Some(768),
            entity_confidence_threshold: Some(0.7),
            enable_parallel_processing: true,
            enable_function_calling: false,
            enable_monitoring: false,
        };

        let registry = config.build_registry().build();

        #[cfg(feature = "memory-storage")]
        {
            use crate::storage::MemoryStorage;
            assert!(registry.has::<MemoryStorage>());
        }

        assert!(!registry.is_empty());
    }

    #[test]
    #[cfg(feature = "vector-memory")]
    fn test_registry_with_vector_memory() {
        use crate::vector::memory_store::MemoryVectorStore;

        let registry = minimal_config(Some(384)).build_registry().build();

        assert!(
            registry.has::<MemoryVectorStore>(),
            "MemoryVectorStore should be registered when vector-memory feature is enabled"
        );

        let vector_store = registry.get::<MemoryVectorStore>();
        assert!(
            vector_store.is_ok(),
            "Should be able to retrieve registered MemoryVectorStore"
        );
    }

    #[test]
    #[cfg(not(feature = "vector-memory"))]
    fn test_registry_without_vector_memory() {
        // Without `vector-memory` (and with no embedding model), setting
        // `vector_dimension` alone must not register any additional service.
        let with_dimension = minimal_config(Some(384)).build_registry().build();
        let without_dimension = minimal_config(None).build_registry().build();

        assert_eq!(with_dimension.len(), without_dimension.len());
    }
}