flow_di/
builder.rs

1use crate::{
2    descriptor::{ServiceProvider, ServiceProviderExt},
3    Container, DiError, DiResult, Lifetime, ServiceDescriptor, ServiceFactory,
4};
5use std::sync::Arc;
6
7/// Container builder - provides a fluent API for service registration
8pub struct ContainerBuilder {
9    container: Container,
10}
11
12impl ContainerBuilder {
13    /// Create a new container builder
14    pub fn new() -> Self {
15        Self {
16            container: Container::new(),
17        }
18    }
19
20    /// Register a transient service
21    pub fn add_transient<TService, TImplementation>(
22        self,
23        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
24    ) -> Self
25    where
26        TService: 'static,
27        TImplementation: Send + Sync + 'static,
28    {
29        let factory: ServiceFactory = Box::new(move |provider| {
30            let instance = factory(provider)?;
31            Ok(Box::new(instance))
32        });
33
34        let descriptor = ServiceDescriptor::transient::<TService, TImplementation>(factory);
35        self.register_descriptor(descriptor)
36    }
37
38    /// Register a transient service (self-registration)
39    pub fn add_transient_self<T>(
40        self,
41        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
42    ) -> Self
43    where
44        T: Send + Sync + 'static,
45    {
46        self.add_transient::<T, T>(factory)
47    }
48
49    /// Register a transient service (no dependencies)
50    pub fn add_transient_simple<TService, TImplementation>(
51        self,
52        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
53    ) -> Self
54    where
55        TService: 'static,
56        TImplementation: Send + Sync + 'static,
57    {
58        self.add_transient::<TService, TImplementation>(move |_| Ok(factory()))
59    }
60
61    /// Register a scoped service
62    pub fn add_scoped<TService, TImplementation>(
63        self,
64        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
65    ) -> Self
66    where
67        TService: 'static,
68        TImplementation: Send + Sync + 'static,
69    {
70        let factory: ServiceFactory = Box::new(move |provider| {
71            let instance = factory(provider)?;
72            Ok(Box::new(instance))
73        });
74
75        let descriptor = ServiceDescriptor::scoped::<TService, TImplementation>(factory);
76        self.register_descriptor(descriptor)
77    }
78
79    /// Register a scoped service (self-registration)
80    pub fn add_scoped_self<T>(
81        self,
82        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
83    ) -> Self
84    where
85        T: Send + Sync + 'static,
86    {
87        self.add_scoped::<T, T>(factory)
88    }
89
90    /// Register a scoped service (no dependencies)
91    pub fn add_scoped_simple<TService, TImplementation>(
92        self,
93        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
94    ) -> Self
95    where
96        TService: 'static,
97        TImplementation: Send + Sync + 'static,
98    {
99        self.add_scoped::<TService, TImplementation>(move |_| Ok(factory()))
100    }
101
102    /// Register a singleton service
103    pub fn add_singleton<TService, TImplementation>(
104        self,
105        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
106    ) -> Self
107    where
108        TService: 'static,
109        TImplementation: Send + Sync + 'static,
110    {
111        let factory: ServiceFactory = Box::new(move |provider| {
112            let instance = factory(provider)?;
113            Ok(Box::new(instance))
114        });
115
116        let descriptor = ServiceDescriptor::singleton::<TService, TImplementation>(factory);
117        self.register_descriptor(descriptor)
118    }
119
120    /// Register a singleton service (self-registration)
121    pub fn add_singleton_self<T>(
122        self,
123        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
124    ) -> Self
125    where
126        T: Send + Sync + 'static,
127    {
128        self.add_singleton::<T, T>(factory)
129    }
130
131    /// Register a singleton service (no dependencies)
132    pub fn add_singleton_simple<TService, TImplementation>(
133        self,
134        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
135    ) -> Self
136    where
137        TService: 'static,
138        TImplementation: Send + Sync + 'static,
139    {
140        self.add_singleton::<TService, TImplementation>(move |_| Ok(factory()))
141    }
142
143    /// Register a singleton instance
144    pub fn add_instance<T>(self, instance: T) -> Self
145    where
146        T: Send + Sync + 'static,
147    {
148        let descriptor = ServiceDescriptor::from_instance(instance);
149        self.register_descriptor(descriptor)
150    }
151
152    // ==================== Named service registration ====================
153
154    /// Register a named transient service
155    pub fn add_named_transient<TService, TImplementation>(
156        self,
157        name: impl Into<String>,
158        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
159    ) -> Self
160    where
161        TService: 'static,
162        TImplementation: Send + Sync + 'static,
163    {
164        let factory: ServiceFactory = Box::new(move |provider| {
165            let instance = factory(provider)?;
166            Ok(Box::new(instance))
167        });
168
169        let descriptor =
170            ServiceDescriptor::named_transient::<TService, TImplementation>(name, factory);
171        self.register_descriptor(descriptor)
172    }
173
174    /// Register a named transient service (self-registration)
175    pub fn add_named_transient_self<T>(
176        self,
177        name: impl Into<String>,
178        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
179    ) -> Self
180    where
181        T: Send + Sync + 'static,
182    {
183        self.add_named_transient::<T, T>(name, factory)
184    }
185
186    /// Register a named transient service (no dependencies)
187    pub fn add_named_transient_simple<TService, TImplementation>(
188        self,
189        name: impl Into<String>,
190        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
191    ) -> Self
192    where
193        TService: 'static,
194        TImplementation: Send + Sync + 'static,
195    {
196        self.add_named_transient::<TService, TImplementation>(name, move |_| Ok(factory()))
197    }
198
199    /// Register a named scoped service
200    pub fn add_named_scoped<TService, TImplementation>(
201        self,
202        name: impl Into<String>,
203        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
204    ) -> Self
205    where
206        TService: 'static,
207        TImplementation: Send + Sync + 'static,
208    {
209        let factory: ServiceFactory = Box::new(move |provider| {
210            let instance = factory(provider)?;
211            Ok(Box::new(instance))
212        });
213
214        let descriptor =
215            ServiceDescriptor::named_scoped::<TService, TImplementation>(name, factory);
216        self.register_descriptor(descriptor)
217    }
218
219    /// Register a named scoped service (self-registration)
220    pub fn add_named_scoped_self<T>(
221        self,
222        name: impl Into<String>,
223        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
224    ) -> Self
225    where
226        T: Send + Sync + 'static,
227    {
228        self.add_named_scoped::<T, T>(name, factory)
229    }
230
231    /// Register a named scoped service (no dependencies)
232    pub fn add_named_scoped_simple<TService, TImplementation>(
233        self,
234        name: impl Into<String>,
235        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
236    ) -> Self
237    where
238        TService: 'static,
239        TImplementation: Send + Sync + 'static,
240    {
241        self.add_named_scoped::<TService, TImplementation>(name, move |_| Ok(factory()))
242    }
243
244    /// Register a named singleton service
245    pub fn add_named_singleton<TService, TImplementation>(
246        self,
247        name: impl Into<String>,
248        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
249    ) -> Self
250    where
251        TService: 'static,
252        TImplementation: Send + Sync + 'static,
253    {
254        let factory: ServiceFactory = Box::new(move |provider| {
255            let instance = factory(provider)?;
256            Ok(Box::new(instance))
257        });
258
259        let descriptor =
260            ServiceDescriptor::named_singleton::<TService, TImplementation>(name, factory);
261        self.register_descriptor(descriptor)
262    }
263
264    /// Register a named singleton service (self-registration)
265    pub fn add_named_singleton_self<T>(
266        self,
267        name: impl Into<String>,
268        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
269    ) -> Self
270    where
271        T: Send + Sync + 'static,
272    {
273        self.add_named_singleton::<T, T>(name, factory)
274    }
275
276    /// Register a named singleton service (no dependencies)
277    pub fn add_named_singleton_simple<TService, TImplementation>(
278        self,
279        name: impl Into<String>,
280        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
281    ) -> Self
282    where
283        TService: 'static,
284        TImplementation: Send + Sync + 'static,
285    {
286        self.add_named_singleton::<TService, TImplementation>(name, move |_| Ok(factory()))
287    }
288
289    /// Register a named singleton instance
290    pub fn add_named_instance<T>(self, name: impl Into<String>, instance: T) -> Self
291    where
292        T: Send + Sync + 'static,
293    {
294        let descriptor = ServiceDescriptor::from_named_instance(name, instance);
295        self.register_descriptor(descriptor)
296    }
297
298    // ==================== Dependency injection helper methods ====================
299
300    /// Register a transient service with dependency injection
301    pub fn add_transient_with_deps<TService, TImplementation, TDep1>(
302        self,
303        factory: impl Fn(Arc<TDep1>) -> TImplementation + Send + Sync + 'static,
304    ) -> Self
305    where
306        TService: 'static,
307        TImplementation: Send + Sync + 'static,
308        TDep1: 'static + Send + Sync,
309    {
310        self.add_transient::<TService, TImplementation>(move |provider| {
311            let dep1 = provider.get_required_service::<TDep1>()?;
312            Ok(factory(dep1))
313        })
314    }
315
316    /// Register a transient service with two dependencies
317    pub fn add_transient_with_deps2<TService, TImplementation, TDep1, TDep2>(
318        self,
319        factory: impl Fn(Arc<TDep1>, Arc<TDep2>) -> TImplementation + Send + Sync + 'static,
320    ) -> Self
321    where
322        TService: 'static,
323        TImplementation: Send + Sync + 'static,
324        TDep1: 'static + Send + Sync,
325        TDep2: 'static + Send + Sync,
326    {
327        self.add_transient::<TService, TImplementation>(move |provider| {
328            let dep1 = provider.get_required_service::<TDep1>()?;
329            let dep2 = provider.get_required_service::<TDep2>()?;
330            Ok(factory(dep1, dep2))
331        })
332    }
333
334    /// Register a scoped service with dependency injection
335    pub fn add_scoped_with_deps<TService, TImplementation, TDep1>(
336        self,
337        factory: impl Fn(Arc<TDep1>) -> TImplementation + Send + Sync + 'static,
338    ) -> Self
339    where
340        TService: 'static,
341        TImplementation: Send + Sync + 'static,
342        TDep1: 'static + Send + Sync,
343    {
344        self.add_scoped::<TService, TImplementation>(move |provider| {
345            let dep1 = provider.get_required_service::<TDep1>()?;
346            Ok(factory(dep1))
347        })
348    }
349
350    /// Register a scoped service with two dependencies
351    pub fn add_scoped_with_deps2<TService, TImplementation, TDep1, TDep2>(
352        self,
353        factory: impl Fn(Arc<TDep1>, Arc<TDep2>) -> TImplementation + Send + Sync + 'static,
354    ) -> Self
355    where
356        TService: 'static,
357        TImplementation: Send + Sync + 'static,
358        TDep1: 'static + Send + Sync,
359        TDep2: 'static + Send + Sync,
360    {
361        self.add_scoped::<TService, TImplementation>(move |provider| {
362            let dep1 = provider.get_required_service::<TDep1>()?;
363            let dep2 = provider.get_required_service::<TDep2>()?;
364            Ok(factory(dep1, dep2))
365        })
366    }
367
368    /// Register a singleton service with dependency injection
369    pub fn add_singleton_with_deps<TService, TImplementation, TDep1>(
370        self,
371        factory: impl Fn(Arc<TDep1>) -> TImplementation + Send + Sync + 'static,
372    ) -> Self
373    where
374        TService: 'static,
375        TImplementation: Send + Sync + 'static,
376        TDep1: 'static + Send + Sync,
377    {
378        self.add_singleton::<TService, TImplementation>(move |provider| {
379            let dep1 = provider.get_required_service::<TDep1>()?;
380            Ok(factory(dep1))
381        })
382    }
383
384    /// Register a singleton service with two dependencies
385    pub fn add_singleton_with_deps2<TService, TImplementation, TDep1, TDep2>(
386        self,
387        factory: impl Fn(Arc<TDep1>, Arc<TDep2>) -> TImplementation + Send + Sync + 'static,
388    ) -> Self
389    where
390        TService: 'static,
391        TImplementation: Send + Sync + 'static,
392        TDep1: 'static + Send + Sync,
393        TDep2: 'static + Send + Sync,
394    {
395        self.add_singleton::<TService, TImplementation>(move |provider| {
396            let dep1 = provider.get_required_service::<TDep1>()?;
397            let dep2 = provider.get_required_service::<TDep2>()?;
398            Ok(factory(dep1, dep2))
399        })
400    }
401
402    // ==================== Advanced features ====================
403
404    /// Decorator pattern - decorate existing services
405    pub fn decorate<TService>(
406        self,
407        _decorator: impl Fn(&dyn ServiceProvider, Arc<TService>) -> DiResult<TService>
408            + Send
409            + Sync
410            + 'static,
411    ) -> Self
412    where
413        TService: Send + Sync + 'static,
414    {
415        // This requires a complex implementation to support decorator pattern
416        // Providing basic framework for now
417        self.add_transient_self::<TService>(move |_resolver| {
418            // This should retrieve the original service and apply the decorator
419            // But requires more complex implementation to avoid circular dependencies
420            Err(DiError::generic("Decorator pattern not fully implemented"))
421        })
422    }
423
424    /// Conditional registration - only register service when condition is met
425    pub fn add_conditional<TService, TImplementation>(
426        self,
427        condition: bool,
428        lifetime: Lifetime,
429        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
430    ) -> Self
431    where
432        TService: 'static,
433        TImplementation: Send + Sync + 'static,
434    {
435        if condition {
436            match lifetime {
437                Lifetime::Transient => self.add_transient::<TService, TImplementation>(factory),
438                Lifetime::Scoped => self.add_scoped::<TService, TImplementation>(factory),
439                Lifetime::Singleton => self.add_singleton::<TService, TImplementation>(factory),
440            }
441        } else {
442            self
443        }
444    }
445
446    /// Register multiple services at once
447    pub fn add_services(mut self, services: Vec<ServiceDescriptor>) -> Self {
448        for descriptor in services {
449            self = self.register_descriptor(descriptor);
450        }
451        self
452    }
453
454    /// Register a service descriptor
455    fn register_descriptor(self, descriptor: ServiceDescriptor) -> Self {
456        if let Err(e) = self.container.register(descriptor) {
457            eprintln!("Warning: Failed to register service: {e}");
458        }
459        self
460    }
461
462    /// Build the container
463    pub fn build(self) -> crate::ServiceProvider {
464        self.container.build()
465    }
466
467    /// Get reference to the internal container (for advanced operations)
468    pub fn container(&self) -> &Container {
469        &self.container
470    }
471}
472
473impl Default for ContainerBuilder {
474    fn default() -> Self {
475        Self::new()
476    }
477}
478
/// Shorthand for [`ContainerBuilder::new`].
///
/// `container!()` expands to `$crate::ContainerBuilder::new()`, so the path resolves correctly
/// from any crate that depends on this one.
#[macro_export]
macro_rules! container {
    () => ( $crate::ContainerBuilder::new() );
}
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::descriptor::ServiceProviderExt;
    use std::sync::Arc;

    /// Configuration registered as a plain instance.
    #[derive(Debug, Clone, PartialEq)]
    struct DatabaseConfig {
        connection_string: String,
    }

    /// Depends on [`DatabaseConfig`].
    #[derive(Debug)]
    struct Database {
        config: Arc<DatabaseConfig>,
    }

    /// Depends on [`Database`].
    #[derive(Debug)]
    struct UserService {
        database: Arc<Database>,
    }

    /// Repository abstraction implemented by both test repositories below.
    trait IRepository: Send + Sync {
        fn get_data(&self) -> String;
    }

    #[derive(Debug)]
    struct SqlRepository {
        connection: String,
    }

    impl IRepository for SqlRepository {
        fn get_data(&self) -> String {
            let connection = &self.connection;
            format!("Data from SQL: {connection}")
        }
    }

    #[derive(Debug)]
    struct InMemoryRepository;

    impl IRepository for InMemoryRepository {
        fn get_data(&self) -> String {
            String::from("Data from memory")
        }
    }

    /// Wires config -> database -> user service and checks every level sees the same config.
    #[test]
    fn test_basic_service_registration() {
        const CONN: &str = "localhost:5432";

        let provider = ContainerBuilder::new()
            .add_instance(DatabaseConfig {
                connection_string: CONN.to_string(),
            })
            .add_transient_with_deps::<Database, Database, DatabaseConfig>(|config| Database {
                config,
            })
            .add_scoped_with_deps::<UserService, UserService, Database>(|database| UserService {
                database,
            })
            .build();

        let config = provider.get_required_service::<DatabaseConfig>().unwrap();
        assert_eq!(config.connection_string, CONN);

        let database = provider.get_required_service::<Database>().unwrap();
        assert_eq!(database.config.connection_string, CONN);

        // Scoped services: resolve the same service twice from one scope.
        let mut scope = provider.create_scope().unwrap();
        let first = scope.get_required_service::<UserService>().unwrap();
        let second = scope.get_required_service::<UserService>().unwrap();

        // Both resolutions must have been wired with the registered configuration.
        for user_service in [&first, &second] {
            assert_eq!(user_service.database.config.connection_string, CONN);
        }
        scope.dispose();
    }

    /// Two repositories registered under different names resolve independently.
    #[test]
    fn test_named_services() {
        let provider = ContainerBuilder::new()
            .add_named_singleton_simple::<SqlRepository, SqlRepository>("sql", || SqlRepository {
                connection: "sql-connection".to_string(),
            })
            .add_named_singleton_simple::<InMemoryRepository, InMemoryRepository>("memory", || {
                InMemoryRepository
            })
            .build();

        let sql = provider
            .get_required_keyed_service::<SqlRepository>("sql")
            .unwrap();
        let memory = provider
            .get_required_keyed_service::<InMemoryRepository>("memory")
            .unwrap();

        assert_eq!(sql.get_data(), "Data from SQL: sql-connection");
        assert_eq!(memory.get_data(), "Data from memory");
    }

    /// Transient and singleton registrations both resolve to what their factory produced.
    #[test]
    fn test_different_lifetimes() {
        let provider = ContainerBuilder::new()
            .add_transient_simple::<String, String>(|| "transient".to_string())
            .add_singleton_simple::<i32, i32>(|| 42)
            .build();

        // Transient service: resolve twice, keeping both results alive.
        let texts = [
            provider.get_required_service::<String>().unwrap(),
            provider.get_required_service::<String>().unwrap(),
        ];
        for text in &texts {
            assert_eq!(**text, "transient");
        }

        // Singleton service: resolve twice, keeping both results alive.
        let numbers = [
            provider.get_required_service::<i32>().unwrap(),
            provider.get_required_service::<i32>().unwrap(),
        ];
        for number in &numbers {
            assert_eq!(**number, 42);
        }
    }

    /// Only the registration whose condition is true ends up in the container.
    #[test]
    fn test_conditional_registration() {
        let use_sql = true;
        let use_memory = !use_sql;

        let provider = ContainerBuilder::new()
            .add_conditional::<SqlRepository, SqlRepository>(use_sql, Lifetime::Singleton, |_| {
                Ok(SqlRepository {
                    connection: "conditional-sql".to_string(),
                })
            })
            .add_conditional::<InMemoryRepository, InMemoryRepository>(
                use_memory,
                Lifetime::Singleton,
                |_| Ok(InMemoryRepository),
            )
            .build();

        // Conditional registration: only the SQL repository should be registered.
        let sql_repo = provider.get_service::<SqlRepository>().unwrap();
        assert!(sql_repo.is_some());
        let sql_repo = sql_repo.unwrap();
        assert_eq!(sql_repo.get_data(), "Data from SQL: conditional-sql");

        // The in-memory repository must not have been registered.
        let memory_repo = provider.get_service::<InMemoryRepository>().unwrap();
        assert!(memory_repo.is_none());
    }

    /// `container!()` yields a builder usable exactly like `ContainerBuilder::new()`.
    #[test]
    fn test_macro_usage() {
        let provider = container!()
            .add_instance(42i32)
            .add_transient_simple::<String, String>(|| "hello".to_string())
            .build();

        let (number, text) = (
            provider.get_required_service::<i32>().unwrap(),
            provider.get_required_service::<String>().unwrap(),
        );

        assert_eq!(*number, 42);
        assert_eq!(*text, "hello");
    }
}