graphrag_core/core/
registry.rs1use crate::core::traits::*;
7use crate::core::{GraphRAGError, Result};
8use std::any::{Any, TypeId};
9use std::collections::HashMap;
10use std::sync::Arc;
/// Type-erased, thread-safe slot holding one registered service.
type ServiceBox = Box<dyn Any + Sync + Send>;

/// Type-keyed container that holds at most one service per concrete Rust type.
///
/// Services are boxed as [`ServiceBox`] and keyed by the [`TypeId`] of their
/// concrete type, so lookups must name exactly the type that was registered.
pub struct ServiceRegistry {
    /// Registered services, keyed by the `TypeId` of each service's concrete type.
    services: HashMap<TypeId, ServiceBox>,
}
20impl ServiceRegistry {
21 pub fn new() -> Self {
23 Self {
24 services: HashMap::new(),
25 }
26 }
27
28 pub fn register<T: Any + Send + Sync>(&mut self, service: T) {
30 let type_id = TypeId::of::<T>();
31 self.services.insert(type_id, Box::new(service));
32 }
33
34 pub fn get<T: Any + Send + Sync>(&self) -> Result<&T> {
36 let type_id = TypeId::of::<T>();
37
38 self.services
39 .get(&type_id)
40 .and_then(|service| service.downcast_ref::<T>())
41 .ok_or_else(|| GraphRAGError::Config {
42 message: format!("Service not registered: {}", std::any::type_name::<T>()),
43 })
44 }
45
46 pub fn get_mut<T: Any + Send + Sync>(&mut self) -> Result<&mut T> {
48 let type_id = TypeId::of::<T>();
49
50 self.services
51 .get_mut(&type_id)
52 .and_then(|service| service.downcast_mut::<T>())
53 .ok_or_else(|| GraphRAGError::Config {
54 message: format!("Service not registered: {}", std::any::type_name::<T>()),
55 })
56 }
57
58 pub fn has<T: Any + Send + Sync>(&self) -> bool {
60 let type_id = TypeId::of::<T>();
61 self.services.contains_key(&type_id)
62 }
63
64 pub fn remove<T: Any + Send + Sync>(&mut self) -> Option<T> {
66 let type_id = TypeId::of::<T>();
67
68 self.services
69 .remove(&type_id)
70 .and_then(|service| service.downcast::<T>().ok())
71 .map(|boxed| *boxed)
72 }
73
74 pub fn len(&self) -> usize {
76 self.services.len()
77 }
78
79 pub fn is_empty(&self) -> bool {
81 self.services.is_empty()
82 }
83
84 pub fn clear(&mut self) {
86 self.services.clear();
87 }
88}
89
90impl Default for ServiceRegistry {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
96pub struct RegistryBuilder {
98 registry: ServiceRegistry,
99}
100
101impl RegistryBuilder {
102 pub fn new() -> Self {
104 Self {
105 registry: ServiceRegistry::new(),
106 }
107 }
108
109 pub fn with_service<T: Any + Send + Sync>(mut self, service: T) -> Self {
111 self.registry.register(service);
112 self
113 }
114
115 pub fn with_storage<S>(mut self, storage: S) -> Self
117 where
118 S: Storage + Any + Send + Sync,
119 {
120 self.registry.register(storage);
121 self
122 }
123
124 pub fn with_embedder<E>(mut self, embedder: E) -> Self
126 where
127 E: Embedder + Any + Send + Sync,
128 {
129 self.registry.register(embedder);
130 self
131 }
132
133 pub fn with_vector_store<V>(mut self, vector_store: V) -> Self
135 where
136 V: VectorStore + Any + Send + Sync,
137 {
138 self.registry.register(vector_store);
139 self
140 }
141
142 pub fn with_entity_extractor<E>(mut self, extractor: E) -> Self
144 where
145 E: EntityExtractor + Any + Send + Sync,
146 {
147 self.registry.register(extractor);
148 self
149 }
150
151 pub fn with_retriever<R>(mut self, retriever: R) -> Self
153 where
154 R: Retriever + Any + Send + Sync,
155 {
156 self.registry.register(retriever);
157 self
158 }
159
160 pub fn with_language_model<L>(mut self, language_model: L) -> Self
162 where
163 L: LanguageModel + Any + Send + Sync,
164 {
165 self.registry.register(language_model);
166 self
167 }
168
169 pub fn with_graph_store<G>(mut self, graph_store: G) -> Self
171 where
172 G: GraphStore + Any + Send + Sync,
173 {
174 self.registry.register(graph_store);
175 self
176 }
177
178 pub fn with_function_registry<F>(mut self, function_registry: F) -> Self
180 where
181 F: FunctionRegistry + Any + Send + Sync,
182 {
183 self.registry.register(function_registry);
184 self
185 }
186
187 pub fn with_metrics_collector<M>(mut self, metrics: M) -> Self
189 where
190 M: MetricsCollector + Any + Send + Sync,
191 {
192 self.registry.register(metrics);
193 self
194 }
195
196 pub fn with_serializer<S>(mut self, serializer: S) -> Self
198 where
199 S: Serializer + Any + Send + Sync,
200 {
201 self.registry.register(serializer);
202 self
203 }
204
205 pub fn build(self) -> ServiceRegistry {
207 self.registry
208 }
209
210 #[cfg(feature = "ollama")]
212 pub fn with_ollama_defaults() -> Self {
213 #[cfg(feature = "memory-storage")]
214 use crate::storage::MemoryStorage;
215
216 let mut builder = Self::new();
217
218 #[cfg(feature = "memory-storage")]
219 {
220 builder = builder.with_storage(MemoryStorage::new());
221 }
222
223 #[cfg(feature = "parallel-processing")]
225 {
226 use crate::parallel::ParallelProcessor;
227
228 let num_threads = num_cpus::get();
230 let parallel_processor = ParallelProcessor::new(num_threads);
231 builder = builder.with_service(parallel_processor);
232 }
233
234 #[cfg(feature = "vector-hnsw")]
235 {
236 use crate::vector::VectorIndex;
237 builder = builder.with_service(VectorIndex::new());
238 }
239
240 #[cfg(feature = "caching")]
241 {
242 }
245 builder
246 }
247
248 #[cfg(feature = "memory-storage")]
250 pub fn with_test_defaults() -> Self {
251 use crate::storage::MemoryStorage;
252 Self::new().with_storage(MemoryStorage::new())
255 }
256}
257
258impl Default for RegistryBuilder {
259 fn default() -> Self {
260 Self::new()
261 }
262}
263
264#[derive(Clone)]
266pub struct ServiceContext {
267 registry: Arc<ServiceRegistry>,
268}
269
270impl ServiceContext {
271 pub fn new(registry: ServiceRegistry) -> Self {
273 Self {
274 registry: Arc::new(registry),
275 }
276 }
277
278 pub fn get<T: Any + Send + Sync>(&self) -> Result<&T> {
280 unsafe {
283 let ptr = self.registry.as_ref() as *const ServiceRegistry;
284 (*ptr).get::<T>()
285 }
286 }
287}
288
/// User-facing settings describing which services to wire up and how.
#[derive(Debug, Clone)]
pub struct ServiceConfig {
    /// Base URL of the Ollama endpoint (default `http://localhost:11434`).
    pub ollama_base_url: Option<String>,
    /// Embedding model identifier (default `nomic-embed-text:latest`).
    pub embedding_model: Option<String>,
    /// Language model identifier (default `llama3.2:latest`).
    pub language_model: Option<String>,
    /// Embedding vector dimension (default `384`).
    pub vector_dimension: Option<usize>,
    /// Entity confidence threshold (default `0.7`); presumably the minimum
    /// score for accepting an extracted entity — confirm against consumers.
    pub entity_confidence_threshold: Option<f32>,
    /// Whether parallel processing is enabled (default `true`).
    pub enable_parallel_processing: bool,
    /// Whether function calling is enabled (default `false`).
    pub enable_function_calling: bool,
    /// Whether monitoring is enabled (default `false`).
    pub enable_monitoring: bool,
}

impl Default for ServiceConfig {
    /// Local-Ollama defaults with parallel processing on and optional features off.
    fn default() -> Self {
        Self {
            ollama_base_url: Some(String::from("http://localhost:11434")),
            embedding_model: Some(String::from("nomic-embed-text:latest")),
            language_model: Some(String::from("llama3.2:latest")),
            // NOTE(review): `nomic-embed-text` is commonly documented as emitting
            // 768-dimensional embeddings; confirm 384 matches the deployed model.
            vector_dimension: Some(384),
            entity_confidence_threshold: Some(0.7),
            enable_parallel_processing: true,
            enable_function_calling: false,
            enable_monitoring: false,
        }
    }
}
325impl ServiceConfig {
326 pub fn build_registry(&self) -> RegistryBuilder {
328 let mut builder = RegistryBuilder::new();
329
330 #[cfg(feature = "memory-storage")]
331 {
332 use crate::storage::MemoryStorage;
333 builder = builder.with_storage(MemoryStorage::new());
334 }
335
336 builder
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal service used to exercise the registry.
    #[derive(Debug)]
    struct TestService {
        value: String,
    }

    impl TestService {
        fn new(value: String) -> Self {
            Self { value }
        }
    }

    #[test]
    fn test_registry_basic_operations() {
        let mut registry = ServiceRegistry::new();

        registry.register(TestService::new("test".to_string()));
        assert!(registry.has::<TestService>());
        assert_eq!(registry.len(), 1);

        let service = registry.get::<TestService>().unwrap();
        assert_eq!(service.value, "test");

        let removed = registry.remove::<TestService>().unwrap();
        assert_eq!(removed.value, "test");
        assert!(!registry.has::<TestService>());
        assert!(registry.is_empty());
    }

    #[test]
    fn test_registry_builder() {
        let registry = RegistryBuilder::new()
            .with_service(TestService::new("builder".to_string()))
            .build();

        assert!(registry.has::<TestService>());
        let service = registry.get::<TestService>().unwrap();
        assert_eq!(service.value, "builder");
    }

    #[test]
    fn test_service_context() {
        let mut registry = ServiceRegistry::new();
        registry.register(TestService::new("context".to_string()));

        let context = ServiceContext::new(registry);
        let service = context.get::<TestService>().unwrap();
        assert_eq!(service.value, "context");

        // Clones share the same underlying registry.
        let cloned_context = context.clone();
        let service2 = cloned_context.get::<TestService>().unwrap();
        assert_eq!(service2.value, "context");
    }

    #[test]
    fn test_service_config_default() {
        let config = ServiceConfig::default();
        assert!(config.ollama_base_url.is_some());
        assert!(config.embedding_model.is_some());
        assert!(config.language_model.is_some());
        assert!(config.vector_dimension.is_some());
        assert!(config.entity_confidence_threshold.is_some());
        assert!(config.enable_parallel_processing);
    }

    #[test]
    fn test_get_unregistered_service_is_error() {
        let registry = ServiceRegistry::new();
        assert!(!registry.has::<TestService>());
        assert!(registry.get::<TestService>().is_err());

        let context = ServiceContext::new(ServiceRegistry::new());
        assert!(context.get::<TestService>().is_err());
    }

    #[test]
    fn test_get_mut_updates_in_place() {
        let mut registry = ServiceRegistry::new();
        registry.register(TestService::new("before".to_string()));

        registry.get_mut::<TestService>().unwrap().value = "after".to_string();
        assert_eq!(registry.get::<TestService>().unwrap().value, "after");

        let mut empty = ServiceRegistry::new();
        assert!(empty.get_mut::<TestService>().is_err());
    }

    #[test]
    fn test_register_replaces_same_type() {
        let mut registry = ServiceRegistry::new();
        registry.register(TestService::new("first".to_string()));
        registry.register(TestService::new("second".to_string()));

        assert_eq!(registry.len(), 1);
        assert_eq!(registry.get::<TestService>().unwrap().value, "second");
    }

    #[test]
    fn test_remove_missing_and_clear() {
        let mut registry = ServiceRegistry::new();
        assert!(registry.remove::<TestService>().is_none());

        // Distinct concrete types occupy distinct slots.
        registry.register(TestService::new("x".to_string()));
        registry.register(42u32);
        assert_eq!(registry.len(), 2);
        assert_eq!(*registry.get::<u32>().unwrap(), 42);

        registry.clear();
        assert!(registry.is_empty());
        assert!(!registry.has::<u32>());
    }
}