flow_di/
lib.rs

1//! # Flow-DI
2//!
//! A Rust dependency injection framework inspired by C#'s Autofac and Microsoft.Extensions.DependencyInjection
4//!
5//! ## Features
6//!
7//! - **Three Lifetime Management Types**: Singleton, Scoped, Transient
8//! - **Keyed Service Support**: Named services and key-based service resolution
//! - **Automatic Dependency Injection**: Dependencies are resolved and injected automatically when service instances are created
10//! - **Fluent API**: Easy-to-use service registration API using builder pattern
11//! - **Circular Dependency Detection**: Automatic detection and prevention of circular dependencies
12//! - **Thread Safety**: Full support for multi-threaded environments
13//! - **Scope Management**: Nested scopes and automatic resource cleanup
14//!
15//! ## Quick Start
16//!
17//! ```rust
18//! use flow_di::{ContainerBuilder, ServiceProviderExt};
19//! use std::sync::Arc;
20//!
21//! // Define service interface
22//! trait IRepository: Send + Sync {
23//!     fn get_data(&self) -> String;
24//! }
25//!
26//! // Implement service
27//! #[derive(Debug)]
28//! struct DatabaseRepository {
29//!     connection_string: String,
30//! }
31//!
32//! impl IRepository for DatabaseRepository {
33//!     fn get_data(&self) -> String {
34//!         format!("Data from database: {}", self.connection_string)
35//!     }
36//! }
37//!
38//! // Configure and build container
39//! let provider = ContainerBuilder::new()
40//!     .add_instance("localhost:5432".to_string()) // Register configuration
41//!     .add_singleton_with_deps::<DatabaseRepository, DatabaseRepository, String>(
42//!         |connection_string| DatabaseRepository {
43//!             connection_string: (*connection_string).clone(),
44//!         }
45//!     )
46//!     .build();
47//!
48//! // Resolve service
49//! let repository = provider.get_required_service::<DatabaseRepository>().unwrap();
50//! println!("{}", repository.get_data());
51//! ```
52
53pub mod builder;
54pub mod container;
55pub mod descriptor;
56pub mod errors;
57pub mod lifetime;
58pub mod service_key;
59
// Re-export the core types at the crate root.
61pub use builder::ContainerBuilder;
62pub use container::{Container, ServiceProvider, ServiceScope};
63pub use descriptor::{
64    ServiceDescriptor, ServiceFactory, ServiceProvider as ServiceProviderTrait, ServiceProviderExt,
65};
66pub use errors::{DiError, DiResult};
67pub use lifetime::Lifetime;
68pub use service_key::ServiceKey;
69
70/// Convenient function to create a container builder
71///
72/// This is a shortcut for `ContainerBuilder::new()`.
73///
74/// # Example
75///
76/// ```rust
77/// use flow_di::{container_builder, ServiceProviderExt};
78///
79/// let provider = container_builder()
80///     .add_instance(42i32)
81///     .add_transient_simple::<String, String>(|| "hello".to_string())
82///     .build();
83///
84/// let number = provider.get_required_service::<i32>().unwrap();
85/// assert_eq!(*number, 42);
86/// ```
87pub fn container_builder() -> ContainerBuilder {
88    ContainerBuilder::new()
89}
90
91/// Convenient function to create a container
92///
93/// This is a shortcut for `Container::new()`.
94///
95/// # Example
96///
97/// ```rust
98/// use flow_di::{container, ServiceDescriptor};
99///
100/// let container = container();
101/// let descriptor = ServiceDescriptor::from_instance(42i32);
102/// container.register(descriptor).unwrap();
103/// ```
104pub fn container() -> Container {
105    Container::new()
106}
107
#[cfg(test)]
mod integration_tests {
    use super::*;
    use crate::descriptor::ServiceProviderExt;
    use std::sync::Arc;

    /// Minimal logging abstraction used by the integration fixtures.
    trait ILogger: Send + Sync {
        fn log(&self, message: &str);
        fn get_logs(&self) -> Vec<String>;
    }

    /// Logger fixture that records messages in memory so tests can inspect them.
    struct ConsoleLogger {
        logs: std::sync::Mutex<Vec<String>>,
    }

    impl ConsoleLogger {
        fn new() -> Self {
            Self {
                logs: std::sync::Mutex::new(Vec::new()),
            }
        }
    }

    impl ILogger for ConsoleLogger {
        fn log(&self, message: &str) {
            // Panic on a poisoned lock instead of silently dropping the message,
            // so a lost log entry surfaces as a test failure (matches `get_logs`).
            self.logs
                .lock()
                .expect("logger mutex poisoned")
                .push(message.to_string());
        }

        fn get_logs(&self) -> Vec<String> {
            self.logs.lock().expect("logger mutex poisoned").clone()
        }
    }

    /// Fixture service with two injected dependencies (logger + configuration string).
    struct DatabaseService {
        logger: Arc<ConsoleLogger>,
        connection_string: Arc<String>,
    }

    impl DatabaseService {
        fn connect(&self) -> String {
            self.logger.log("Connecting to database");
            format!("Connected to {}", self.connection_string)
        }
    }

    /// Fixture service that depends on another injected service.
    struct UserService {
        database: Arc<DatabaseService>,
        logger: Arc<ConsoleLogger>,
    }

    impl UserService {
        fn get_user(&self, id: i32) -> String {
            self.logger.log(&format!("Getting user {id}"));
            let _ = self.database.connect();
            format!("User {id}")
        }
    }

    #[test]
    fn test_complete_dependency_injection_scenario() {
        let provider = ContainerBuilder::new()
            // Configuration value
            .add_instance("postgresql://localhost:5432/mydb".to_string())
            // Logger (singleton)
            .add_singleton_simple::<ConsoleLogger, ConsoleLogger>(ConsoleLogger::new)
            // Database service (scoped; depends on the logger and the configuration)
            .add_scoped_with_deps2::<DatabaseService, DatabaseService, ConsoleLogger, String>(
                |logger, connection_string| DatabaseService {
                    logger,
                    connection_string,
                },
            )
            // User service (transient; depends on the database service and the logger)
            .add_transient_with_deps2::<UserService, UserService, DatabaseService, ConsoleLogger>(
                |database, logger| UserService { database, logger },
            )
            .build();

        // Create a scope and resolve services through it.
        let mut scope = provider.create_scope().unwrap();

        let user_service = scope.get_required_service::<UserService>().unwrap();
        assert_eq!(user_service.get_user(123), "User 123");

        // The singleton logger handed out by the scope must be the one injected into the
        // services, so it has recorded the messages they logged.
        let logger = scope.get_required_service::<ConsoleLogger>().unwrap();
        assert!(Arc::ptr_eq(&logger, &user_service.logger));
        let logs = logger.get_logs();
        assert!(logs.contains(&"Getting user 123".to_string()));
        assert!(logs.contains(&"Connecting to database".to_string()));

        // Scoped: repeated resolution within one scope must return the same instance.
        // (Comparing `connection_string` alone cannot detect this, because that value
        // comes from a shared instance either way.)
        let db1 = scope.get_required_service::<DatabaseService>().unwrap();
        let db2 = scope.get_required_service::<DatabaseService>().unwrap();
        assert!(Arc::ptr_eq(&db1, &db2));
        assert_eq!(db1.connection_string, db2.connection_string);

        // Transient: each resolution creates a new instance, while the scoped
        // dependency injected into them is shared.
        let user1 = scope.get_required_service::<UserService>().unwrap();
        let user2 = scope.get_required_service::<UserService>().unwrap();
        assert!(!Arc::ptr_eq(&user1, &user2));
        assert!(Arc::ptr_eq(&user1.database, &user2.database));

        scope.dispose();
    }

    #[test]
    fn test_keyed_services_integration() {
        let provider = ContainerBuilder::new()
            .add_named_singleton_simple::<ConsoleLogger, ConsoleLogger>(
                "console",
                ConsoleLogger::new,
            )
            .add_named_singleton_simple::<ConsoleLogger, ConsoleLogger>("file", ConsoleLogger::new)
            .add_named_instance("database_url", "postgresql://localhost/db".to_string())
            .add_named_instance("cache_url", "redis://localhost/cache".to_string())
            .build();

        // Named services of the same type resolve to independent instances.
        let console_logger = provider
            .get_required_keyed_service::<ConsoleLogger>("console")
            .unwrap();
        let file_logger = provider
            .get_required_keyed_service::<ConsoleLogger>("file")
            .unwrap();

        console_logger.log("Console message");
        file_logger.log("File message");

        assert_eq!(console_logger.get_logs(), vec!["Console message"]);
        assert_eq!(file_logger.get_logs(), vec!["File message"]);

        // Named configuration values.
        let db_url = provider
            .get_required_keyed_service::<String>("database_url")
            .unwrap();
        let cache_url = provider
            .get_required_keyed_service::<String>("cache_url")
            .unwrap();

        assert_eq!(*db_url, "postgresql://localhost/db");
        assert_eq!(*cache_url, "redis://localhost/cache");
    }

    #[test]
    fn test_singleton_across_scopes() {
        let provider = ContainerBuilder::new()
            .add_singleton_simple::<ConsoleLogger, ConsoleLogger>(ConsoleLogger::new)
            .build();

        let mut scope1 = provider.create_scope().unwrap();
        let mut scope2 = provider.create_scope().unwrap();

        let logger1 = scope1.get_required_service::<ConsoleLogger>().unwrap();
        let logger2 = scope2.get_required_service::<ConsoleLogger>().unwrap();

        // A singleton is shared across scopes: both handles point at the same instance,
        // so a message logged through one is visible through the other.
        assert!(Arc::ptr_eq(&logger1, &logger2));
        logger1.log("test message");
        assert_eq!(logger2.get_logs(), vec!["test message"]);

        scope1.dispose();
        scope2.dispose();
    }

    #[test]
    fn test_error_handling() {
        let provider = ContainerBuilder::new().build();

        // Optional lookup of an unregistered service succeeds with `None`.
        assert!(matches!(provider.get_service::<String>(), Ok(None)));

        // Required lookup of an unregistered service fails.
        assert!(matches!(
            provider.get_required_service::<String>(),
            Err(DiError::ServiceNotRegistered { .. })
        ));

        // The same rules apply to keyed services.
        assert!(matches!(
            provider.get_keyed_service::<String>("nonexistent"),
            Ok(None)
        ));
        assert!(matches!(
            provider.get_required_keyed_service::<String>("nonexistent"),
            Err(DiError::KeyedServiceNotRegistered { .. })
        ));
    }

    #[test]
    fn test_dyn_compatibility() {
        let provider = ContainerBuilder::new()
            .add_singleton_simple::<ConsoleLogger, ConsoleLogger>(ConsoleLogger::new)
            .build();

        // The provider trait must be usable as a trait object.
        let dyn_provider: &dyn ServiceProviderTrait = &provider;

        // Raw (untyped) access through the trait object.
        let raw_result = dyn_provider.get_service_raw(&ServiceKey::of_type::<ConsoleLogger>());
        assert!(matches!(raw_result, Ok(Some(_))));

        // Typed access through the `ServiceProviderExt` extension trait, which must be in scope.
        let typed_result = provider.get_service::<ConsoleLogger>();
        assert!(matches!(typed_result, Ok(Some(_))));
    }

    #[test]
    fn test_trait_object_storage() {
        let provider = ContainerBuilder::new()
            .add_singleton_simple::<ConsoleLogger, ConsoleLogger>(ConsoleLogger::new)
            .build();

        // Trait objects can be stored in a collection and used uniformly.
        let providers: Vec<&dyn ServiceProviderTrait> = vec![&provider, &provider];

        for p in providers {
            let result = p.get_service_raw(&ServiceKey::of_type::<ConsoleLogger>());
            assert!(matches!(result, Ok(Some(_))));
        }
    }
}