1use crate::{
2 descriptor::{ServiceProvider, ServiceProviderExt},
3 Container, DiError, DiResult, Lifetime, ServiceDescriptor, ServiceFactory,
4};
5use std::sync::Arc;
6
7pub struct ContainerBuilder {
9 container: Container,
10}
11
12impl ContainerBuilder {
13 pub fn new() -> Self {
15 Self {
16 container: Container::new(),
17 }
18 }
19
20 pub fn add_transient<TService, TImplementation>(
22 self,
23 factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
24 ) -> Self
25 where
26 TService: 'static,
27 TImplementation: Send + Sync + 'static,
28 {
29 let factory: ServiceFactory = Box::new(move |provider| {
30 let instance = factory(provider)?;
31 Ok(Box::new(instance))
32 });
33
34 let descriptor = ServiceDescriptor::transient::<TService, TImplementation>(factory);
35 self.register_descriptor(descriptor)
36 }
37
38 pub fn add_transient_self<T>(
40 self,
41 factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
42 ) -> Self
43 where
44 T: Send + Sync + 'static,
45 {
46 self.add_transient::<T, T>(factory)
47 }
48
49 pub fn add_transient_simple<TService, TImplementation>(
51 self,
52 factory: impl Fn() -> TImplementation + Send + Sync + 'static,
53 ) -> Self
54 where
55 TService: 'static,
56 TImplementation: Send + Sync + 'static,
57 {
58 self.add_transient::<TService, TImplementation>(move |_| Ok(factory()))
59 }
60
61 pub fn add_scoped<TService, TImplementation>(
63 self,
64 factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
65 ) -> Self
66 where
67 TService: 'static,
68 TImplementation: Send + Sync + 'static,
69 {
70 let factory: ServiceFactory = Box::new(move |provider| {
71 let instance = factory(provider)?;
72 Ok(Box::new(instance))
73 });
74
75 let descriptor = ServiceDescriptor::scoped::<TService, TImplementation>(factory);
76 self.register_descriptor(descriptor)
77 }
78
79 pub fn add_scoped_self<T>(
81 self,
82 factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
83 ) -> Self
84 where
85 T: Send + Sync + 'static,
86 {
87 self.add_scoped::<T, T>(factory)
88 }
89
90 pub fn add_scoped_simple<TService, TImplementation>(
92 self,
93 factory: impl Fn() -> TImplementation + Send + Sync + 'static,
94 ) -> Self
95 where
96 TService: 'static,
97 TImplementation: Send + Sync + 'static,
98 {
99 self.add_scoped::<TService, TImplementation>(move |_| Ok(factory()))
100 }
101
102 pub fn add_singleton<TService, TImplementation>(
104 self,
105 factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
106 ) -> Self
107 where
108 TService: 'static,
109 TImplementation: Send + Sync + 'static,
110 {
111 let factory: ServiceFactory = Box::new(move |provider| {
112 let instance = factory(provider)?;
113 Ok(Box::new(instance))
114 });
115
116 let descriptor = ServiceDescriptor::singleton::<TService, TImplementation>(factory);
117 self.register_descriptor(descriptor)
118 }
119
120 pub fn add_singleton_self<T>(
122 self,
123 factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
124 ) -> Self
125 where
126 T: Send + Sync + 'static,
127 {
128 self.add_singleton::<T, T>(factory)
129 }
130
131 pub fn add_singleton_simple<TService, TImplementation>(
133 self,
134 factory: impl Fn() -> TImplementation + Send + Sync + 'static,
135 ) -> Self
136 where
137 TService: 'static,
138 TImplementation: Send + Sync + 'static,
139 {
140 self.add_singleton::<TService, TImplementation>(move |_| Ok(factory()))
141 }
142
143 pub fn add_instance<T>(self, instance: T) -> Self
145 where
146 T: Send + Sync + 'static,
147 {
148 let descriptor = ServiceDescriptor::from_instance(instance);
149 self.register_descriptor(descriptor)
150 }
151
152 pub fn add_named_transient<TService, TImplementation>(
156 self,
157 name: impl Into<String>,
158 factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
159 ) -> Self
160 where
161 TService: 'static,
162 TImplementation: Send + Sync + 'static,
163 {
164 let factory: ServiceFactory = Box::new(move |provider| {
165 let instance = factory(provider)?;
166 Ok(Box::new(instance))
167 });
168
169 let descriptor =
170 ServiceDescriptor::named_transient::<TService, TImplementation>(name, factory);
171 self.register_descriptor(descriptor)
172 }
173
174 pub fn add_named_transient_self<T>(
176 self,
177 name: impl Into<String>,
178 factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
179 ) -> Self
180 where
181 T: Send + Sync + 'static,
182 {
183 self.add_named_transient::<T, T>(name, factory)
184 }
185
186 pub fn add_named_transient_simple<TService, TImplementation>(
188 self,
189 name: impl Into<String>,
190 factory: impl Fn() -> TImplementation + Send + Sync + 'static,
191 ) -> Self
192 where
193 TService: 'static,
194 TImplementation: Send + Sync + 'static,
195 {
196 self.add_named_transient::<TService, TImplementation>(name, move |_| Ok(factory()))
197 }
198
199 pub fn add_named_scoped<TService, TImplementation>(
201 self,
202 name: impl Into<String>,
203 factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
204 ) -> Self
205 where
206 TService: 'static,
207 TImplementation: Send + Sync + 'static,
208 {
209 let factory: ServiceFactory = Box::new(move |provider| {
210 let instance = factory(provider)?;
211 Ok(Box::new(instance))
212 });
213
214 let descriptor =
215 ServiceDescriptor::named_scoped::<TService, TImplementation>(name, factory);
216 self.register_descriptor(descriptor)
217 }
218
219 pub fn add_named_scoped_self<T>(
221 self,
222 name: impl Into<String>,
223 factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
224 ) -> Self
225 where
226 T: Send + Sync + 'static,
227 {
228 self.add_named_scoped::<T, T>(name, factory)
229 }
230
231 pub fn add_named_scoped_simple<TService, TImplementation>(
233 self,
234 name: impl Into<String>,
235 factory: impl Fn() -> TImplementation + Send + Sync + 'static,
236 ) -> Self
237 where
238 TService: 'static,
239 TImplementation: Send + Sync + 'static,
240 {
241 self.add_named_scoped::<TService, TImplementation>(name, move |_| Ok(factory()))
242 }
243
244 pub fn add_named_singleton<TService, TImplementation>(
246 self,
247 name: impl Into<String>,
248 factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
249 ) -> Self
250 where
251 TService: 'static,
252 TImplementation: Send + Sync + 'static,
253 {
254 let factory: ServiceFactory = Box::new(move |provider| {
255 let instance = factory(provider)?;
256 Ok(Box::new(instance))
257 });
258
259 let descriptor =
260 ServiceDescriptor::named_singleton::<TService, TImplementation>(name, factory);
261 self.register_descriptor(descriptor)
262 }
263
264 pub fn add_named_singleton_self<T>(
266 self,
267 name: impl Into<String>,
268 factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
269 ) -> Self
270 where
271 T: Send + Sync + 'static,
272 {
273 self.add_named_singleton::<T, T>(name, factory)
274 }
275
276 pub fn add_named_singleton_simple<TService, TImplementation>(
278 self,
279 name: impl Into<String>,
280 factory: impl Fn() -> TImplementation + Send + Sync + 'static,
281 ) -> Self
282 where
283 TService: 'static,
284 TImplementation: Send + Sync + 'static,
285 {
286 self.add_named_singleton::<TService, TImplementation>(name, move |_| Ok(factory()))
287 }
288
289 pub fn add_named_instance<T>(self, name: impl Into<String>, instance: T) -> Self
291 where
292 T: Send + Sync + 'static,
293 {
294 let descriptor = ServiceDescriptor::from_named_instance(name, instance);
295 self.register_descriptor(descriptor)
296 }
297
298 pub fn add_transient_with_deps<TService, TImplementation, TDep1>(
302 self,
303 factory: impl Fn(Arc<TDep1>) -> TImplementation + Send + Sync + 'static,
304 ) -> Self
305 where
306 TService: 'static,
307 TImplementation: Send + Sync + 'static,
308 TDep1: 'static + Send + Sync,
309 {
310 self.add_transient::<TService, TImplementation>(move |provider| {
311 let dep1 = provider.get_required_service::<TDep1>()?;
312 Ok(factory(dep1))
313 })
314 }
315
316 pub fn add_transient_with_deps2<TService, TImplementation, TDep1, TDep2>(
318 self,
319 factory: impl Fn(Arc<TDep1>, Arc<TDep2>) -> TImplementation + Send + Sync + 'static,
320 ) -> Self
321 where
322 TService: 'static,
323 TImplementation: Send + Sync + 'static,
324 TDep1: 'static + Send + Sync,
325 TDep2: 'static + Send + Sync,
326 {
327 self.add_transient::<TService, TImplementation>(move |provider| {
328 let dep1 = provider.get_required_service::<TDep1>()?;
329 let dep2 = provider.get_required_service::<TDep2>()?;
330 Ok(factory(dep1, dep2))
331 })
332 }
333
334 pub fn add_scoped_with_deps<TService, TImplementation, TDep1>(
336 self,
337 factory: impl Fn(Arc<TDep1>) -> TImplementation + Send + Sync + 'static,
338 ) -> Self
339 where
340 TService: 'static,
341 TImplementation: Send + Sync + 'static,
342 TDep1: 'static + Send + Sync,
343 {
344 self.add_scoped::<TService, TImplementation>(move |provider| {
345 let dep1 = provider.get_required_service::<TDep1>()?;
346 Ok(factory(dep1))
347 })
348 }
349
350 pub fn add_scoped_with_deps2<TService, TImplementation, TDep1, TDep2>(
352 self,
353 factory: impl Fn(Arc<TDep1>, Arc<TDep2>) -> TImplementation + Send + Sync + 'static,
354 ) -> Self
355 where
356 TService: 'static,
357 TImplementation: Send + Sync + 'static,
358 TDep1: 'static + Send + Sync,
359 TDep2: 'static + Send + Sync,
360 {
361 self.add_scoped::<TService, TImplementation>(move |provider| {
362 let dep1 = provider.get_required_service::<TDep1>()?;
363 let dep2 = provider.get_required_service::<TDep2>()?;
364 Ok(factory(dep1, dep2))
365 })
366 }
367
368 pub fn add_singleton_with_deps<TService, TImplementation, TDep1>(
370 self,
371 factory: impl Fn(Arc<TDep1>) -> TImplementation + Send + Sync + 'static,
372 ) -> Self
373 where
374 TService: 'static,
375 TImplementation: Send + Sync + 'static,
376 TDep1: 'static + Send + Sync,
377 {
378 self.add_singleton::<TService, TImplementation>(move |provider| {
379 let dep1 = provider.get_required_service::<TDep1>()?;
380 Ok(factory(dep1))
381 })
382 }
383
384 pub fn add_singleton_with_deps2<TService, TImplementation, TDep1, TDep2>(
386 self,
387 factory: impl Fn(Arc<TDep1>, Arc<TDep2>) -> TImplementation + Send + Sync + 'static,
388 ) -> Self
389 where
390 TService: 'static,
391 TImplementation: Send + Sync + 'static,
392 TDep1: 'static + Send + Sync,
393 TDep2: 'static + Send + Sync,
394 {
395 self.add_singleton::<TService, TImplementation>(move |provider| {
396 let dep1 = provider.get_required_service::<TDep1>()?;
397 let dep2 = provider.get_required_service::<TDep2>()?;
398 Ok(factory(dep1, dep2))
399 })
400 }
401
402 pub fn decorate<TService>(
406 self,
407 _decorator: impl Fn(&dyn ServiceProvider, Arc<TService>) -> DiResult<TService>
408 + Send
409 + Sync
410 + 'static,
411 ) -> Self
412 where
413 TService: Send + Sync + 'static,
414 {
415 self.add_transient_self::<TService>(move |_resolver| {
418 Err(DiError::generic("Decorator pattern not fully implemented"))
421 })
422 }
423
424 pub fn add_conditional<TService, TImplementation>(
426 self,
427 condition: bool,
428 lifetime: Lifetime,
429 factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
430 ) -> Self
431 where
432 TService: 'static,
433 TImplementation: Send + Sync + 'static,
434 {
435 if condition {
436 match lifetime {
437 Lifetime::Transient => self.add_transient::<TService, TImplementation>(factory),
438 Lifetime::Scoped => self.add_scoped::<TService, TImplementation>(factory),
439 Lifetime::Singleton => self.add_singleton::<TService, TImplementation>(factory),
440 }
441 } else {
442 self
443 }
444 }
445
446 pub fn add_services(mut self, services: Vec<ServiceDescriptor>) -> Self {
448 for descriptor in services {
449 self = self.register_descriptor(descriptor);
450 }
451 self
452 }
453
454 fn register_descriptor(self, descriptor: ServiceDescriptor) -> Self {
456 if let Err(e) = self.container.register(descriptor) {
457 eprintln!("Warning: Failed to register service: {e}");
458 }
459 self
460 }
461
462 pub fn build(self) -> crate::ServiceProvider {
464 self.container.build()
465 }
466
467 pub fn container(&self) -> &Container {
469 &self.container
470 }
471}
472
473impl Default for ContainerBuilder {
474 fn default() -> Self {
475 Self::new()
476 }
477}
478
/// Shorthand for creating an empty `ContainerBuilder`.
///
/// `container!()` expands to a call to the crate's `ContainerBuilder::new`, so a
/// registration chain can start directly from it, e.g.
/// `container!().add_instance(42i32).build()`.
#[macro_export]
macro_rules! container {
    () => {
        <$crate::ContainerBuilder>::new()
    };
}
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::descriptor::ServiceProviderExt;
    use std::sync::Arc;

    /// Connection settings, registered as a plain instance.
    #[derive(Debug, Clone, PartialEq)]
    struct DatabaseConfig {
        connection_string: String,
    }

    /// Service that depends on `DatabaseConfig`.
    #[derive(Debug)]
    struct Database {
        config: Arc<DatabaseConfig>,
    }

    /// Service that depends on `Database`.
    #[derive(Debug)]
    struct UserService {
        database: Arc<Database>,
    }

    /// Common behaviour of the two repository fixtures.
    trait IRepository: Send + Sync {
        fn get_data(&self) -> String;
    }

    #[derive(Debug)]
    struct SqlRepository {
        connection: String,
    }

    impl IRepository for SqlRepository {
        fn get_data(&self) -> String {
            format!("Data from SQL: {}", self.connection)
        }
    }

    #[derive(Debug)]
    struct InMemoryRepository;

    impl IRepository for InMemoryRepository {
        fn get_data(&self) -> String {
            String::from("Data from memory")
        }
    }

    /// Instance, transient-with-deps and scoped-with-deps registrations resolve
    /// through a dependency chain back to the registered config.
    #[test]
    fn test_basic_service_registration() {
        let seed_config = DatabaseConfig {
            connection_string: "localhost:5432".to_string(),
        };
        let provider = ContainerBuilder::new()
            .add_instance(seed_config)
            .add_transient_with_deps::<Database, Database, DatabaseConfig>(|config| {
                Database { config }
            })
            .add_scoped_with_deps::<UserService, UserService, Database>(|database| {
                UserService { database }
            })
            .build();

        let resolved_config = provider.get_required_service::<DatabaseConfig>().unwrap();
        assert_eq!(resolved_config.connection_string, "localhost:5432");

        let resolved_db = provider.get_required_service::<Database>().unwrap();
        assert_eq!(resolved_db.config.connection_string, "localhost:5432");

        let mut scope = provider.create_scope().unwrap();
        let first_user_service = scope.get_required_service::<UserService>().unwrap();
        let second_user_service = scope.get_required_service::<UserService>().unwrap();

        for user_service in [&first_user_service, &second_user_service].iter() {
            assert_eq!(
                user_service.database.config.connection_string,
                "localhost:5432"
            );
        }
        scope.dispose();
    }

    /// Named singletons are retrievable by their key.
    #[test]
    fn test_named_services() {
        let provider = ContainerBuilder::new()
            .add_named_singleton_simple::<SqlRepository, SqlRepository>("sql", || {
                SqlRepository {
                    connection: "sql-connection".to_string(),
                }
            })
            .add_named_singleton_simple::<InMemoryRepository, InMemoryRepository>(
                "memory",
                || InMemoryRepository,
            )
            .build();

        let by_sql_key = provider
            .get_required_keyed_service::<SqlRepository>("sql")
            .unwrap();
        let by_memory_key = provider
            .get_required_keyed_service::<InMemoryRepository>("memory")
            .unwrap();

        assert_eq!(by_sql_key.get_data(), "Data from SQL: sql-connection");
        assert_eq!(by_memory_key.get_data(), "Data from memory");
    }

    /// Transient and singleton registrations both resolve to their factory values.
    #[test]
    fn test_different_lifetimes() {
        let provider = ContainerBuilder::new()
            .add_transient_simple::<String, String>(|| "transient".to_string())
            .add_singleton_simple::<i32, i32>(|| 42)
            .build();

        let texts = [
            provider.get_required_service::<String>().unwrap(),
            provider.get_required_service::<String>().unwrap(),
        ];
        for text in texts.iter() {
            assert_eq!(text.as_str(), "transient");
        }

        let numbers = [
            provider.get_required_service::<i32>().unwrap(),
            provider.get_required_service::<i32>().unwrap(),
        ];
        for number in numbers.iter() {
            assert_eq!(**number, 42);
        }
    }

    /// Only the registration whose condition is true ends up in the provider.
    #[test]
    fn test_conditional_registration() {
        let use_sql = true;

        let provider = ContainerBuilder::new()
            .add_conditional::<SqlRepository, SqlRepository>(use_sql, Lifetime::Singleton, |_| {
                Ok(SqlRepository {
                    connection: "conditional-sql".to_string(),
                })
            })
            .add_conditional::<InMemoryRepository, InMemoryRepository>(
                !use_sql,
                Lifetime::Singleton,
                |_| Ok(InMemoryRepository),
            )
            .build();

        let maybe_sql = provider.get_service::<SqlRepository>().unwrap();
        assert!(maybe_sql.is_some());
        let sql_repo = maybe_sql.unwrap();
        assert_eq!(sql_repo.get_data(), "Data from SQL: conditional-sql");

        let maybe_memory = provider.get_service::<InMemoryRepository>().unwrap();
        assert!(maybe_memory.is_none());
    }

    /// `container!()` starts a usable builder chain.
    #[test]
    fn test_macro_usage() {
        let provider = container!()
            .add_instance(42_i32)
            .add_transient_simple::<String, String>(|| "hello".to_string())
            .build();

        let answer = provider.get_required_service::<i32>().unwrap();
        let greeting = provider.get_required_service::<String>().unwrap();

        assert_eq!(*answer, 42);
        assert_eq!(greeting.as_str(), "hello");
    }
}