use crate::{
descriptor::{ServiceProvider, ServiceProviderExt},
Container, DiError, DiResult, Lifetime, ServiceDescriptor, ServiceFactory,
};
use std::sync::Arc;
/// Fluent builder that collects service registrations into a [`Container`].
///
/// Every registration method consumes the builder and hands it back, so registrations are
/// chained and the chain is finished with [`ContainerBuilder::build`]. Discarding a returned
/// builder discards the registrations made on it, hence the `must_use`.
#[must_use = "builder methods consume `self`; keep the returned builder or call `build()`"]
pub struct ContainerBuilder {
    /// Container that receives every registered descriptor.
    container: Container,
}
impl ContainerBuilder {
    /// Creates a builder around a freshly constructed [`Container`].
    pub fn new() -> Self {
        Self {
            container: Container::new(),
        }
    }

    /// Adapts a typed factory into the type-erased [`ServiceFactory`] stored in descriptors.
    ///
    /// The instance returned by `factory` is boxed so every registration shares one factory
    /// signature; an `Err` from `factory` is propagated unchanged by `?`. All provider-aware
    /// `add_*` methods go through here so the erasure logic exists exactly once.
    fn erase_factory<TImplementation, F>(factory: F) -> ServiceFactory
    where
        TImplementation: Send + Sync + 'static,
        F: Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
    {
        Box::new(move |provider| {
            let instance = factory(provider)?;
            Ok(Box::new(instance))
        })
    }

    /// Registers a service with the transient lifetime, built by a provider-aware `factory`.
    ///
    /// `factory` receives the resolving [`ServiceProvider`] and may fail; its error is returned
    /// from the erased factory unchanged. Both type parameters are forwarded to
    /// [`ServiceDescriptor::transient`]; `TImplementation` is the value `factory` produces.
    pub fn add_transient<TService, TImplementation>(
        self,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        let factory = Self::erase_factory::<TImplementation, _>(factory);
        let descriptor = ServiceDescriptor::transient::<TService, TImplementation>(factory);
        self.register_descriptor(descriptor)
    }

    /// [`add_transient`](Self::add_transient) with `T` as both service and implementation type.
    pub fn add_transient_self<T>(
        self,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
    ) -> Self
    where
        T: Send + Sync + 'static,
    {
        self.add_transient::<T, T>(factory)
    }

    /// [`add_transient`](Self::add_transient) for an infallible factory that ignores the provider.
    pub fn add_transient_simple<TService, TImplementation>(
        self,
        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        self.add_transient::<TService, TImplementation>(move |_| Ok(factory()))
    }

    /// Registers a service with the scoped lifetime, built by a provider-aware `factory`.
    ///
    /// `factory` receives the resolving [`ServiceProvider`] and may fail; its error is returned
    /// from the erased factory unchanged. Both type parameters are forwarded to
    /// [`ServiceDescriptor::scoped`]; `TImplementation` is the value `factory` produces.
    pub fn add_scoped<TService, TImplementation>(
        self,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        let factory = Self::erase_factory::<TImplementation, _>(factory);
        let descriptor = ServiceDescriptor::scoped::<TService, TImplementation>(factory);
        self.register_descriptor(descriptor)
    }

    /// [`add_scoped`](Self::add_scoped) with `T` as both service and implementation type.
    pub fn add_scoped_self<T>(
        self,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
    ) -> Self
    where
        T: Send + Sync + 'static,
    {
        self.add_scoped::<T, T>(factory)
    }

    /// [`add_scoped`](Self::add_scoped) for an infallible factory that ignores the provider.
    pub fn add_scoped_simple<TService, TImplementation>(
        self,
        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        self.add_scoped::<TService, TImplementation>(move |_| Ok(factory()))
    }

    /// Registers a service with the singleton lifetime, built by a provider-aware `factory`.
    ///
    /// `factory` receives the resolving [`ServiceProvider`] and may fail; its error is returned
    /// from the erased factory unchanged. Both type parameters are forwarded to
    /// [`ServiceDescriptor::singleton`]; `TImplementation` is the value `factory` produces.
    pub fn add_singleton<TService, TImplementation>(
        self,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        let factory = Self::erase_factory::<TImplementation, _>(factory);
        let descriptor = ServiceDescriptor::singleton::<TService, TImplementation>(factory);
        self.register_descriptor(descriptor)
    }

    /// [`add_singleton`](Self::add_singleton) with `T` as both service and implementation type.
    pub fn add_singleton_self<T>(
        self,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
    ) -> Self
    where
        T: Send + Sync + 'static,
    {
        self.add_singleton::<T, T>(factory)
    }

    /// [`add_singleton`](Self::add_singleton) for an infallible factory that ignores the provider.
    pub fn add_singleton_simple<TService, TImplementation>(
        self,
        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        self.add_singleton::<TService, TImplementation>(move |_| Ok(factory()))
    }

    /// Registers an already-constructed `instance` via [`ServiceDescriptor::from_instance`].
    pub fn add_instance<T>(self, instance: T) -> Self
    where
        T: Send + Sync + 'static,
    {
        let descriptor = ServiceDescriptor::from_instance(instance);
        self.register_descriptor(descriptor)
    }

    /// Registers a transient service under `name`, built by a provider-aware `factory`.
    ///
    /// `name` and both type parameters are forwarded to [`ServiceDescriptor::named_transient`];
    /// named registrations are looked up by key (e.g. `get_required_keyed_service`).
    pub fn add_named_transient<TService, TImplementation>(
        self,
        name: impl Into<String>,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        let factory = Self::erase_factory::<TImplementation, _>(factory);
        let descriptor =
            ServiceDescriptor::named_transient::<TService, TImplementation>(name, factory);
        self.register_descriptor(descriptor)
    }

    /// [`add_named_transient`](Self::add_named_transient) with `T` as service and implementation.
    pub fn add_named_transient_self<T>(
        self,
        name: impl Into<String>,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
    ) -> Self
    where
        T: Send + Sync + 'static,
    {
        self.add_named_transient::<T, T>(name, factory)
    }

    /// [`add_named_transient`](Self::add_named_transient) for an infallible, provider-free factory.
    pub fn add_named_transient_simple<TService, TImplementation>(
        self,
        name: impl Into<String>,
        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        self.add_named_transient::<TService, TImplementation>(name, move |_| Ok(factory()))
    }

    /// Registers a scoped service under `name`, built by a provider-aware `factory`.
    ///
    /// `name` and both type parameters are forwarded to [`ServiceDescriptor::named_scoped`];
    /// named registrations are looked up by key (e.g. `get_required_keyed_service`).
    pub fn add_named_scoped<TService, TImplementation>(
        self,
        name: impl Into<String>,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        let factory = Self::erase_factory::<TImplementation, _>(factory);
        let descriptor =
            ServiceDescriptor::named_scoped::<TService, TImplementation>(name, factory);
        self.register_descriptor(descriptor)
    }

    /// [`add_named_scoped`](Self::add_named_scoped) with `T` as service and implementation.
    pub fn add_named_scoped_self<T>(
        self,
        name: impl Into<String>,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
    ) -> Self
    where
        T: Send + Sync + 'static,
    {
        self.add_named_scoped::<T, T>(name, factory)
    }

    /// [`add_named_scoped`](Self::add_named_scoped) for an infallible, provider-free factory.
    pub fn add_named_scoped_simple<TService, TImplementation>(
        self,
        name: impl Into<String>,
        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        self.add_named_scoped::<TService, TImplementation>(name, move |_| Ok(factory()))
    }

    /// Registers a singleton service under `name`, built by a provider-aware `factory`.
    ///
    /// `name` and both type parameters are forwarded to [`ServiceDescriptor::named_singleton`];
    /// named registrations are looked up by key (e.g. `get_required_keyed_service`).
    pub fn add_named_singleton<TService, TImplementation>(
        self,
        name: impl Into<String>,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        let factory = Self::erase_factory::<TImplementation, _>(factory);
        let descriptor =
            ServiceDescriptor::named_singleton::<TService, TImplementation>(name, factory);
        self.register_descriptor(descriptor)
    }

    /// [`add_named_singleton`](Self::add_named_singleton) with `T` as service and implementation.
    pub fn add_named_singleton_self<T>(
        self,
        name: impl Into<String>,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<T> + Send + Sync + 'static,
    ) -> Self
    where
        T: Send + Sync + 'static,
    {
        self.add_named_singleton::<T, T>(name, factory)
    }

    /// [`add_named_singleton`](Self::add_named_singleton) for an infallible, provider-free factory.
    pub fn add_named_singleton_simple<TService, TImplementation>(
        self,
        name: impl Into<String>,
        factory: impl Fn() -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        self.add_named_singleton::<TService, TImplementation>(name, move |_| Ok(factory()))
    }

    /// Registers an already-constructed `instance` under `name` via
    /// [`ServiceDescriptor::from_named_instance`].
    pub fn add_named_instance<T>(self, name: impl Into<String>, instance: T) -> Self
    where
        T: Send + Sync + 'static,
    {
        let descriptor = ServiceDescriptor::from_named_instance(name, instance);
        self.register_descriptor(descriptor)
    }

    /// Registers a transient service whose factory is handed one resolved dependency.
    ///
    /// `TDep1` is fetched with `get_required_service` from the provider passed to the erased
    /// factory; if that lookup fails its error is returned and `factory` is not called.
    pub fn add_transient_with_deps<TService, TImplementation, TDep1>(
        self,
        factory: impl Fn(Arc<TDep1>) -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
        TDep1: 'static + Send + Sync,
    {
        self.add_transient::<TService, TImplementation>(move |provider| {
            let dep1 = provider.get_required_service::<TDep1>()?;
            Ok(factory(dep1))
        })
    }

    /// Registers a transient service whose factory is handed two resolved dependencies.
    ///
    /// `TDep1` then `TDep2` are fetched with `get_required_service`; the first failing lookup's
    /// error is returned and `factory` is not called.
    pub fn add_transient_with_deps2<TService, TImplementation, TDep1, TDep2>(
        self,
        factory: impl Fn(Arc<TDep1>, Arc<TDep2>) -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
        TDep1: 'static + Send + Sync,
        TDep2: 'static + Send + Sync,
    {
        self.add_transient::<TService, TImplementation>(move |provider| {
            let dep1 = provider.get_required_service::<TDep1>()?;
            let dep2 = provider.get_required_service::<TDep2>()?;
            Ok(factory(dep1, dep2))
        })
    }

    /// Registers a scoped service whose factory is handed one resolved dependency.
    ///
    /// `TDep1` is fetched with `get_required_service` from the provider passed to the erased
    /// factory; if that lookup fails its error is returned and `factory` is not called.
    pub fn add_scoped_with_deps<TService, TImplementation, TDep1>(
        self,
        factory: impl Fn(Arc<TDep1>) -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
        TDep1: 'static + Send + Sync,
    {
        self.add_scoped::<TService, TImplementation>(move |provider| {
            let dep1 = provider.get_required_service::<TDep1>()?;
            Ok(factory(dep1))
        })
    }

    /// Registers a scoped service whose factory is handed two resolved dependencies.
    ///
    /// `TDep1` then `TDep2` are fetched with `get_required_service`; the first failing lookup's
    /// error is returned and `factory` is not called.
    pub fn add_scoped_with_deps2<TService, TImplementation, TDep1, TDep2>(
        self,
        factory: impl Fn(Arc<TDep1>, Arc<TDep2>) -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
        TDep1: 'static + Send + Sync,
        TDep2: 'static + Send + Sync,
    {
        self.add_scoped::<TService, TImplementation>(move |provider| {
            let dep1 = provider.get_required_service::<TDep1>()?;
            let dep2 = provider.get_required_service::<TDep2>()?;
            Ok(factory(dep1, dep2))
        })
    }

    /// Registers a singleton service whose factory is handed one resolved dependency.
    ///
    /// `TDep1` is fetched with `get_required_service` from the provider passed to the erased
    /// factory; if that lookup fails its error is returned and `factory` is not called.
    pub fn add_singleton_with_deps<TService, TImplementation, TDep1>(
        self,
        factory: impl Fn(Arc<TDep1>) -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
        TDep1: 'static + Send + Sync,
    {
        self.add_singleton::<TService, TImplementation>(move |provider| {
            let dep1 = provider.get_required_service::<TDep1>()?;
            Ok(factory(dep1))
        })
    }

    /// Registers a singleton service whose factory is handed two resolved dependencies.
    ///
    /// `TDep1` then `TDep2` are fetched with `get_required_service`; the first failing lookup's
    /// error is returned and `factory` is not called.
    pub fn add_singleton_with_deps2<TService, TImplementation, TDep1, TDep2>(
        self,
        factory: impl Fn(Arc<TDep1>, Arc<TDep2>) -> TImplementation + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
        TDep1: 'static + Send + Sync,
        TDep2: 'static + Send + Sync,
    {
        self.add_singleton::<TService, TImplementation>(move |provider| {
            let dep1 = provider.get_required_service::<TDep1>()?;
            let dep2 = provider.get_required_service::<TDep2>()?;
            Ok(factory(dep1, dep2))
        })
    }

    /// **Not implemented yet.** The decorator is ignored.
    ///
    /// `_decorator` is never called. Instead, a transient factory for `TService` is registered
    /// that always fails with `DiError::generic("Decorator pattern not fully implemented")`.
    /// NOTE(review): whether that registration shadows, or is rejected next to, an existing
    /// `TService` registration depends on `Container::register`, which is not visible here —
    /// confirm before relying on this method.
    pub fn decorate<TService>(
        self,
        _decorator: impl Fn(&dyn ServiceProvider, Arc<TService>) -> DiResult<TService>
            + Send
            + Sync
            + 'static,
    ) -> Self
    where
        TService: Send + Sync + 'static,
    {
        self.add_transient_self::<TService>(move |_resolver| {
            Err(DiError::generic("Decorator pattern not fully implemented"))
        })
    }

    /// Registers `factory` with the given `lifetime` only when `condition` is `true`.
    ///
    /// When `condition` is `false` the builder is returned untouched and `factory` is dropped.
    pub fn add_conditional<TService, TImplementation>(
        self,
        condition: bool,
        lifetime: Lifetime,
        factory: impl Fn(&dyn ServiceProvider) -> DiResult<TImplementation> + Send + Sync + 'static,
    ) -> Self
    where
        TService: 'static,
        TImplementation: Send + Sync + 'static,
    {
        if !condition {
            return self;
        }
        match lifetime {
            Lifetime::Transient => self.add_transient::<TService, TImplementation>(factory),
            Lifetime::Scoped => self.add_scoped::<TService, TImplementation>(factory),
            Lifetime::Singleton => self.add_singleton::<TService, TImplementation>(factory),
        }
    }

    /// Registers every descriptor in `services`, in iteration order.
    ///
    /// Accepts any iterable of [`ServiceDescriptor`]s (a `Vec`, an array, an iterator chain).
    /// Each descriptor takes the same best-effort path as the other `add_*` methods, so a failed
    /// registration is reported on stderr and does not stop the remaining ones.
    pub fn add_services(self, services: impl IntoIterator<Item = ServiceDescriptor>) -> Self {
        services.into_iter().fold(self, Self::register_descriptor)
    }

    /// Hands `descriptor` to the container, reporting rather than propagating a failure.
    ///
    /// The fluent API returns `Self` instead of `Result`, so a registration error is written to
    /// stderr as a warning and the builder is returned either way.
    fn register_descriptor(self, descriptor: ServiceDescriptor) -> Self {
        if let Err(e) = self.container.register(descriptor) {
            eprintln!("Warning: Failed to register service: {e}");
        }
        self
    }

    /// Consumes the builder and returns the crate-level `ServiceProvider` produced by
    /// `Container::build`. Note this is `crate::ServiceProvider`, distinct from the
    /// `descriptor::ServiceProvider` trait that factories receive.
    pub fn build(self) -> crate::ServiceProvider {
        self.container.build()
    }

    /// Borrows the underlying [`Container`] without consuming the builder.
    pub fn container(&self) -> &Container {
        &self.container
    }
}
impl Default for ContainerBuilder {
fn default() -> Self {
Self::new()
}
}
/// Shorthand for `ContainerBuilder::new()`, convenient as the head of a registration chain,
/// e.g. `container!().add_instance(42i32).build()`.
#[macro_export]
macro_rules! container {
    () => {
        <$crate::ContainerBuilder>::new()
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::descriptor::ServiceProviderExt;
    use std::sync::Arc;
    #[derive(Debug, Clone, PartialEq)]
    struct DatabaseConfig {
        connection_string: String,
    }
    #[derive(Debug)]
    struct Database {
        config: Arc<DatabaseConfig>,
    }
    #[derive(Debug)]
    struct UserService {
        database: Arc<Database>,
    }
    trait IRepository: Send + Sync {
        fn get_data(&self) -> String;
    }
    #[derive(Debug)]
    struct SqlRepository {
        connection: String,
    }
    impl IRepository for SqlRepository {
        fn get_data(&self) -> String {
            format!("Data from SQL: {}", self.connection)
        }
    }
    #[derive(Debug)]
    struct InMemoryRepository;
    impl IRepository for InMemoryRepository {
        fn get_data(&self) -> String {
            "Data from memory".to_string()
        }
    }
    /// Service assembled from two injected dependencies; exercises the `*_with_deps2` helpers.
    #[derive(Debug)]
    struct Greeting {
        repeat: Arc<i32>,
        word: Arc<String>,
    }
    #[test]
    fn test_basic_service_registration() {
        let provider = ContainerBuilder::new()
            .add_instance(DatabaseConfig {
                connection_string: "localhost:5432".to_string(),
            })
            .add_transient_with_deps::<Database, Database, DatabaseConfig>(|config| Database {
                config,
            })
            .add_scoped_with_deps::<UserService, UserService, Database>(|database| UserService {
                database,
            })
            .build();
        let config = provider.get_required_service::<DatabaseConfig>().unwrap();
        assert_eq!(config.connection_string, "localhost:5432");
        let database = provider.get_required_service::<Database>().unwrap();
        assert_eq!(database.config.connection_string, "localhost:5432");
        let mut scope = provider.create_scope().unwrap();
        let user_service1 = scope.get_required_service::<UserService>().unwrap();
        let user_service2 = scope.get_required_service::<UserService>().unwrap();
        assert_eq!(
            user_service1.database.config.connection_string,
            "localhost:5432"
        );
        assert_eq!(
            user_service2.database.config.connection_string,
            "localhost:5432"
        );
        scope.dispose();
    }
    #[test]
    fn test_named_services() {
        let provider = ContainerBuilder::new()
            .add_named_singleton_simple::<SqlRepository, SqlRepository>("sql", || SqlRepository {
                connection: "sql-connection".to_string(),
            })
            .add_named_singleton_simple::<InMemoryRepository, InMemoryRepository>("memory", || {
                InMemoryRepository
            })
            .build();
        let sql_repo = provider
            .get_required_keyed_service::<SqlRepository>("sql")
            .unwrap();
        let memory_repo = provider
            .get_required_keyed_service::<InMemoryRepository>("memory")
            .unwrap();
        assert_eq!(sql_repo.get_data(), "Data from SQL: sql-connection");
        assert_eq!(memory_repo.get_data(), "Data from memory");
    }
    #[test]
    fn test_different_lifetimes() {
        let provider = ContainerBuilder::new()
            .add_transient_simple::<String, String>(|| "transient".to_string())
            .add_singleton_simple::<i32, i32>(|| 42)
            .build();
        let str1 = provider.get_required_service::<String>().unwrap();
        let str2 = provider.get_required_service::<String>().unwrap();
        assert_eq!(*str1, "transient");
        assert_eq!(*str2, "transient");
        let int1 = provider.get_required_service::<i32>().unwrap();
        let int2 = provider.get_required_service::<i32>().unwrap();
        assert_eq!(*int1, 42);
        assert_eq!(*int2, 42);
    }
    #[test]
    fn test_conditional_registration() {
        let use_sql = true;
        let provider = ContainerBuilder::new()
            .add_conditional::<SqlRepository, SqlRepository>(use_sql, Lifetime::Singleton, |_| {
                Ok(SqlRepository {
                    connection: "conditional-sql".to_string(),
                })
            })
            .add_conditional::<InMemoryRepository, InMemoryRepository>(
                !use_sql,
                Lifetime::Singleton,
                |_| Ok(InMemoryRepository),
            )
            .build();
        let sql_repo = provider.get_service::<SqlRepository>().unwrap();
        assert!(sql_repo.is_some());
        assert_eq!(
            sql_repo.unwrap().get_data(),
            "Data from SQL: conditional-sql"
        );
        let memory_repo = provider.get_service::<InMemoryRepository>().unwrap();
        assert!(memory_repo.is_none());
    }
    #[test]
    fn test_conditional_registration_skipped_for_transient() {
        let provider = ContainerBuilder::new()
            .add_conditional::<String, String>(false, Lifetime::Transient, |_| {
                Ok("never".to_string())
            })
            .build();
        assert!(provider.get_service::<String>().unwrap().is_none());
    }
    #[test]
    fn test_two_dependency_registration() {
        let provider = ContainerBuilder::new()
            .add_instance(3i32)
            .add_instance("hi".to_string())
            .add_singleton_with_deps2::<Greeting, Greeting, i32, String>(|repeat, word| {
                Greeting { repeat, word }
            })
            .build();
        let greeting = provider.get_required_service::<Greeting>().unwrap();
        assert_eq!(*greeting.repeat, 3);
        assert_eq!(*greeting.word, "hi");
    }
    #[test]
    fn test_add_services_with_empty_batch() {
        let provider = ContainerBuilder::new()
            .add_services(Vec::<ServiceDescriptor>::new())
            .add_instance(7i32)
            .build();
        assert_eq!(*provider.get_required_service::<i32>().unwrap(), 7);
    }
    #[test]
    fn test_macro_usage() {
        let provider = container!()
            .add_instance(42i32)
            .add_transient_simple::<String, String>(|| "hello".to_string())
            .build();
        let number = provider.get_required_service::<i32>().unwrap();
        let text = provider.get_required_service::<String>().unwrap();
        assert_eq!(*number, 42);
        assert_eq!(*text, "hello");
    }
}