use crate::container::autowiring::Injectable;
use crate::container::descriptor::{ServiceDescriptor, ServiceDescriptorFactoryBuilder, ServiceId};
use crate::container::scope::ServiceScope;
use crate::errors::CoreError;
/// A boxed, thread-safe (`Send + Sync`) predicate attached to a binding.
///
/// Every predicate stored in [`BindingConfig::conditions`] must return `true`
/// for [`BindingConfig::evaluate_conditions`] to succeed.
pub type ConditionFn = Box<dyn Fn() -> bool + Sync + Send>;
/// An environment-variable requirement: `(variable name, required value)`.
///
/// The variable name must be a `'static` string; the value is owned.
pub type EnvCondition = (&'static str, String);
/// Declarative description of one binding and the conditions that gate it.
///
/// Typically produced by [`AdvancedBindingBuilder::config`] and consumed by
/// [`ServiceBinder::with_implementation`].
pub struct BindingConfig {
    /// Registration name; `None` means the binding is registered unnamed.
    pub name: Option<String>,
    /// Lifetime assigned to the registered service.
    pub lifetime: ServiceScope,
    /// `(variable, value)` pairs; each variable must be set to exactly that value.
    pub env_conditions: Vec<EnvCondition>,
    /// `(feature, expected)` pairs, checked against `FEATURE_<FEATURE>` environment variables.
    pub feature_conditions: Vec<(String, bool)>,
    /// Arbitrary predicates; every one must return `true`.
    pub conditions: Vec<ConditionFn>,
    /// Marks the binding as the preferred/default one.
    /// NOTE(review): `ServiceBindings::with_implementation` does not read this flag.
    pub is_default: bool,
    /// Allowed profiles; when non-empty, the `PROFILE` variable must match one of them.
    pub profile_conditions: Vec<String>,
}
impl Default for BindingConfig {
fn default() -> Self {
Self::new()
}
}
impl BindingConfig {
    /// Creates a configuration with no name, a [`ServiceScope::Transient`] lifetime,
    /// `is_default == false`, and no conditions of any kind.
    pub fn new() -> Self {
        Self {
            name: None,
            lifetime: ServiceScope::Transient,
            env_conditions: Vec::new(),
            feature_conditions: Vec::new(),
            conditions: Vec::new(),
            is_default: false,
            profile_conditions: Vec::new(),
        }
    }

    /// Returns `true` only if every registered condition is satisfied.
    ///
    /// The checks run in this order and stop at the first failure. Once one fails,
    /// later checks (including custom closures) are not evaluated:
    ///
    /// 1. **Environment**: for each `(key, value)` in `env_conditions`, the variable
    ///    `key` must be set, be valid Unicode, and equal `value` exactly.
    /// 2. **Features**: for each `(feature, expected)` in `feature_conditions`, the
    ///    variable `FEATURE_<FEATURE>` (name upper-cased) counts as enabled when it is
    ///    set to any valid-Unicode value, including `""` or `"0"`. That flag must
    ///    equal `expected`.
    /// 3. **Profiles**: if `profile_conditions` is non-empty, the `PROFILE` variable
    ///    (or `"development"` when it is unset or not valid Unicode) must be one of them.
    /// 4. **Custom**: every closure in `conditions` must return `true`.
    ///
    /// A configuration with no conditions at all evaluates to `true`.
    pub fn evaluate_conditions(&self) -> bool {
        self.env_conditions_hold()
            && self.feature_conditions_hold()
            && self.profile_conditions_hold()
            && self.conditions.iter().all(|condition| condition())
    }

    /// Checks `env_conditions`: each variable must be present and equal to its expected value.
    fn env_conditions_hold(&self) -> bool {
        self.env_conditions.iter().all(|(key, expected)| {
            // `var` fails for both missing and non-Unicode values; both count as "not met".
            std::env::var(key).ok().as_deref() == Some(expected.as_str())
        })
    }

    /// Checks `feature_conditions` against the presence of `FEATURE_<NAME>` variables.
    fn feature_conditions_hold(&self) -> bool {
        self.feature_conditions.iter().all(|(feature, expected)| {
            let enabled = std::env::var(format!("FEATURE_{}", feature.to_uppercase())).is_ok();
            enabled == *expected
        })
    }

    /// Checks `profile_conditions`; trivially satisfied (and `PROFILE` is not read) when empty.
    fn profile_conditions_hold(&self) -> bool {
        if self.profile_conditions.is_empty() {
            return true;
        }
        let current_profile =
            std::env::var("PROFILE").unwrap_or_else(|_| "development".to_string());
        self.profile_conditions.contains(&current_profile)
    }
}
/// Fluent builder for a [`BindingConfig`] aimed at the interface `TInterface`.
///
/// The builder only collects configuration. Finish with [`config`](Self::config) and
/// pass the result to [`ServiceBinder::with_implementation`] to register a binding.
pub struct AdvancedBindingBuilder<TInterface: ?Sized + 'static> {
    config: BindingConfig,
    // `fn() -> *const TInterface` ties the builder to `TInterface` covariantly without
    // owning a value of it. A bare `*const TInterface` would additionally make the
    // builder `!Send`/`!Sync` regardless of what `BindingConfig` holds.
    _phantom: std::marker::PhantomData<fn() -> *const TInterface>,
}
impl<TInterface: ?Sized + 'static> Default for AdvancedBindingBuilder<TInterface> {
fn default() -> Self {
Self::new()
}
}
impl<TInterface: ?Sized + 'static> AdvancedBindingBuilder<TInterface> {
pub fn new() -> Self {
Self {
config: BindingConfig::new(),
_phantom: std::marker::PhantomData,
}
}
pub fn named(mut self, name: impl Into<String>) -> Self {
self.config.name = Some(name.into());
self
}
pub fn with_lifetime(mut self, lifetime: ServiceScope) -> Self {
self.config.lifetime = lifetime;
self
}
pub fn when_env(mut self, key: &'static str, value: impl Into<String>) -> Self {
self.config.env_conditions.push((key, value.into()));
self
}
pub fn when_feature(mut self, feature: impl Into<String>) -> Self {
self.config.feature_conditions.push((feature.into(), true));
self
}
pub fn when_not_feature(mut self, feature: impl Into<String>) -> Self {
self.config.feature_conditions.push((feature.into(), false));
self
}
pub fn when<F>(mut self, condition: F) -> Self
where
F: Fn() -> bool + Send + Sync + 'static,
{
self.config.conditions.push(Box::new(condition));
self
}
pub fn as_default(mut self) -> Self {
self.config.is_default = true;
self
}
pub fn in_profile(mut self, profile: impl Into<String>) -> Self {
self.config.profile_conditions.push(profile.into());
self
}
pub fn config(self) -> BindingConfig {
self.config
}
}
pub trait ServiceBinder {
fn add_service_descriptor(
&mut self,
descriptor: crate::container::descriptor::ServiceDescriptor,
) -> Result<&mut Self, crate::errors::CoreError>;
fn bind<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
&mut self,
) -> &mut Self;
fn bind_singleton<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
&mut self,
) -> &mut Self;
fn bind_transient<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
&mut self,
) -> &mut Self;
fn bind_factory<TInterface: ?Sized + 'static, F, T>(&mut self, factory: F) -> &mut Self
where
F: Fn() -> Result<T, CoreError> + Send + Sync + 'static,
T: Send + Sync + 'static;
fn bind_instance<TInterface: ?Sized + 'static, TImpl: Send + Sync + Clone + 'static>(
&mut self,
instance: TImpl,
) -> &mut Self;
fn bind_named<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
&mut self,
name: &str,
) -> &mut Self;
fn bind_injectable<T: Injectable>(&mut self) -> &mut Self;
fn bind_injectable_singleton<T: Injectable>(&mut self) -> &mut Self;
fn bind_with<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
&mut self,
) -> AdvancedBindingBuilder<TInterface>;
fn with_implementation<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
&mut self,
config: BindingConfig,
) -> &mut Self;
fn bind_lazy<TInterface: ?Sized + 'static, F, T>(&mut self, factory: F) -> &mut Self
where
F: Fn() -> T + Send + Sync + 'static,
T: Send + Sync + 'static;
fn bind_parameterized_factory<TInterface: ?Sized + 'static, P, F, T>(
&mut self,
factory: F,
) -> &mut Self
where
F: Fn(P) -> Result<T, CoreError> + Send + Sync + 'static,
T: Send + Sync + 'static,
P: Send + Sync + 'static;
fn bind_collection<TInterface: ?Sized + 'static, F>(&mut self, configure: F) -> &mut Self
where
F: FnOnce(&mut CollectionBindingBuilder<TInterface>);
}
/// Accumulates descriptors for several implementations of the interface `TInterface`.
///
/// Used through [`ServiceBinder::bind_collection`].
pub struct CollectionBindingBuilder<TInterface: ?Sized + 'static> {
    services: Vec<ServiceDescriptor>,
    // `fn() -> *const TInterface` marks the interface covariantly without owning it and,
    // unlike a bare `*const TInterface`, does not force the builder to be `!Send`/`!Sync`.
    _phantom: std::marker::PhantomData<fn() -> *const TInterface>,
}
impl<TInterface: ?Sized + 'static> Default for CollectionBindingBuilder<TInterface> {
fn default() -> Self {
Self::new()
}
}
impl<TInterface: ?Sized + 'static> CollectionBindingBuilder<TInterface> {
pub fn new() -> Self {
Self {
services: Vec::new(),
_phantom: std::marker::PhantomData,
}
}
pub fn add<TImpl: Send + Sync + Default + 'static>(&mut self) -> &mut Self {
let descriptor = ServiceDescriptor::bind::<TInterface, TImpl>()
.with_lifetime(ServiceScope::Transient)
.build();
self.services.push(descriptor);
self
}
pub fn add_named<TImpl: Send + Sync + Default + 'static>(
&mut self,
name: impl Into<String>,
) -> &mut Self {
let descriptor = ServiceDescriptor::bind_named::<TInterface, TImpl>(name)
.with_lifetime(ServiceScope::Transient)
.build();
self.services.push(descriptor);
self
}
pub fn add_singleton<TImpl: Send + Sync + Default + 'static>(&mut self) -> &mut Self {
let descriptor = ServiceDescriptor::bind::<TInterface, TImpl>()
.with_lifetime(ServiceScope::Singleton)
.build();
self.services.push(descriptor);
self
}
pub fn add_named_singleton<TImpl: Send + Sync + Default + 'static>(
&mut self,
name: impl Into<String>,
) -> &mut Self {
let descriptor = ServiceDescriptor::bind_named::<TInterface, TImpl>(name)
.with_lifetime(ServiceScope::Singleton)
.build();
self.services.push(descriptor);
self
}
pub(crate) fn into_services(self) -> Vec<ServiceDescriptor> {
self.services
}
}
/// An ordered list of [`ServiceDescriptor`]s.
///
/// It is filled via [`ServiceBindings::add_descriptor`] or the [`ServiceBinder`] methods.
/// Insertion order is kept and duplicate registrations are not rejected.
#[derive(Debug)]
pub struct ServiceBindings {
    /// Descriptors in the order they were added.
    descriptors: Vec<ServiceDescriptor>,
}
impl ServiceBindings {
pub fn new() -> Self {
Self {
descriptors: Vec::new(),
}
}
pub fn add_descriptor(&mut self, descriptor: ServiceDescriptor) {
self.descriptors.push(descriptor);
}
pub fn descriptors(&self) -> &[ServiceDescriptor] {
&self.descriptors
}
pub fn get_descriptor(&self, service_id: &ServiceId) -> Option<&ServiceDescriptor> {
self.descriptors
.iter()
.find(|d| d.service_id == *service_id)
}
pub fn get_descriptor_named<T: 'static + ?Sized>(
&self,
name: &str,
) -> Option<&ServiceDescriptor> {
self.descriptors
.iter()
.find(|d| d.service_id.matches_named::<T>(name))
}
pub fn service_ids(&self) -> Vec<ServiceId> {
self.descriptors
.iter()
.map(|d| d.service_id.clone())
.collect()
}
pub fn contains(&self, service_id: &ServiceId) -> bool {
self.descriptors.iter().any(|d| d.service_id == *service_id)
}
pub fn contains_named<T: 'static + ?Sized>(&self, name: &str) -> bool {
self.descriptors
.iter()
.any(|d| d.service_id.matches_named::<T>(name))
}
pub fn count(&self) -> usize {
self.descriptors.len()
}
pub fn into_descriptors(self) -> Vec<ServiceDescriptor> {
self.descriptors
}
}
impl Default for ServiceBindings {
fn default() -> Self {
Self::new()
}
}
impl ServiceBinder for ServiceBindings {
    /// Appends `descriptor` via [`ServiceBindings::add_descriptor`].
    ///
    /// # Errors
    /// Never returns `Err` in this implementation; the `Result` exists for the trait.
    fn add_service_descriptor(
        &mut self,
        descriptor: crate::container::descriptor::ServiceDescriptor,
    ) -> Result<&mut Self, crate::errors::CoreError> {
        self.add_descriptor(descriptor);
        Ok(self)
    }

    /// Binds `TImpl` to `TInterface` with a transient lifetime.
    ///
    /// Identical to [`bind_transient`](ServiceBinder::bind_transient), to which it delegates.
    fn bind<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
        &mut self,
    ) -> &mut Self {
        self.bind_transient::<TInterface, TImpl>()
    }

    /// Binds `TImpl` to `TInterface` with a [`ServiceScope::Singleton`] lifetime.
    fn bind_singleton<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
        &mut self,
    ) -> &mut Self {
        let descriptor = ServiceDescriptor::bind::<TInterface, TImpl>()
            .with_lifetime(ServiceScope::Singleton)
            .build();
        self.add_descriptor(descriptor);
        self
    }

    /// Binds `TImpl` to `TInterface` with a [`ServiceScope::Transient`] lifetime.
    fn bind_transient<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
        &mut self,
    ) -> &mut Self {
        let descriptor = ServiceDescriptor::bind::<TInterface, TImpl>()
            .with_lifetime(ServiceScope::Transient)
            .build();
        self.add_descriptor(descriptor);
        self
    }

    /// Registers `factory` as the constructor for `TInterface`.
    ///
    /// # Panics
    /// Panics if the factory descriptor builder's `build()` returns an error.
    fn bind_factory<TInterface: ?Sized + 'static, F, T>(&mut self, factory: F) -> &mut Self
    where
        F: Fn() -> Result<T, CoreError> + Send + Sync + 'static,
        T: Send + Sync + 'static,
    {
        let descriptor = ServiceDescriptorFactoryBuilder::<TInterface>::new()
            .with_factory(factory)
            .build()
            .expect("Failed to build factory descriptor");
        self.add_descriptor(descriptor);
        self
    }

    /// Registers a [`ServiceScope::Singleton`] whose factory returns a clone of `instance`.
    ///
    /// # Panics
    /// Panics if the factory descriptor builder's `build()` returns an error.
    fn bind_instance<TInterface: ?Sized + 'static, TImpl: Send + Sync + Clone + 'static>(
        &mut self,
        instance: TImpl,
    ) -> &mut Self {
        // `instance` is already owned, so it moves straight into the factory closure.
        // Each factory call hands out a fresh clone; no extra up-front clone is needed.
        let descriptor = ServiceDescriptorFactoryBuilder::<TInterface>::new()
            .with_lifetime(ServiceScope::Singleton)
            .with_factory(move || -> Result<TImpl, CoreError> { Ok(instance.clone()) })
            .build()
            .expect("Failed to build instance descriptor");
        self.add_descriptor(descriptor);
        self
    }

    /// Binds `TImpl` to `TInterface` under `name` with a transient lifetime.
    fn bind_named<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
        &mut self,
        name: &str,
    ) -> &mut Self {
        let descriptor = ServiceDescriptor::bind_named::<TInterface, TImpl>(name)
            .with_lifetime(ServiceScope::Transient)
            .build();
        self.add_descriptor(descriptor);
        self
    }

    /// Registers `T` as autowired, using the list returned by `T::dependencies()`.
    fn bind_injectable<T: Injectable>(&mut self) -> &mut Self {
        let dependencies = T::dependencies();
        let descriptor = ServiceDescriptor::autowired::<T>(dependencies);
        self.add_descriptor(descriptor);
        self
    }

    /// Registers `T` as an autowired singleton, using the list returned by `T::dependencies()`.
    fn bind_injectable_singleton<T: Injectable>(&mut self) -> &mut Self {
        let dependencies = T::dependencies();
        let descriptor = ServiceDescriptor::autowired_singleton::<T>(dependencies);
        self.add_descriptor(descriptor);
        self
    }

    /// Returns a fresh [`AdvancedBindingBuilder`].
    ///
    /// Nothing is registered here and `TImpl` is unused. Pass the builder's `config()`
    /// to [`with_implementation`](ServiceBinder::with_implementation) to register a binding.
    fn bind_with<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
        &mut self,
    ) -> AdvancedBindingBuilder<TInterface> {
        AdvancedBindingBuilder::new()
    }

    /// Registers `TImpl` for `TInterface` as described by `config`.
    ///
    /// Registration happens only if [`BindingConfig::evaluate_conditions`] returns `true`;
    /// otherwise this is a no-op. The binding is named when `config.name` is set and
    /// uses `config.lifetime`.
    ///
    /// NOTE(review): `config.is_default` is not consulted here. Confirm whether default
    /// bindings need special handling.
    fn with_implementation<TInterface: ?Sized + 'static, TImpl: Send + Sync + Default + 'static>(
        &mut self,
        config: BindingConfig,
    ) -> &mut Self {
        if !config.evaluate_conditions() {
            return self;
        }
        // `config` is owned, so its name can be moved out rather than cloned.
        let builder = match config.name {
            Some(name) => ServiceDescriptor::bind_named::<TInterface, TImpl>(name),
            None => ServiceDescriptor::bind::<TInterface, TImpl>(),
        };
        let descriptor = builder.with_lifetime(config.lifetime).build();
        self.add_descriptor(descriptor);
        self
    }

    /// Wraps the infallible `factory` in `Ok(..)` and registers it as the constructor
    /// for `TInterface`.
    ///
    /// # Panics
    /// Panics if the factory descriptor builder's `build()` returns an error.
    fn bind_lazy<TInterface: ?Sized + 'static, F, T>(&mut self, factory: F) -> &mut Self
    where
        F: Fn() -> T + Send + Sync + 'static,
        T: Send + Sync + 'static,
    {
        let lazy_factory = move || -> Result<T, CoreError> { Ok(factory()) };
        let descriptor = ServiceDescriptorFactoryBuilder::<TInterface>::new()
            .with_factory(lazy_factory)
            .build()
            .expect("Failed to build lazy factory descriptor");
        self.add_descriptor(descriptor);
        self
    }

    /// Registers a placeholder for a parameterized factory.
    ///
    /// NOTE(review): `_factory` is discarded. The registered factory always returns
    /// `CoreError::ServiceNotFound` stating that runtime parameter resolution is required,
    /// so resolving this binding currently always fails.
    ///
    /// # Panics
    /// Panics if the factory descriptor builder's `build()` returns an error.
    fn bind_parameterized_factory<TInterface: ?Sized + 'static, P, F, T>(
        &mut self,
        _factory: F,
    ) -> &mut Self
    where
        F: Fn(P) -> Result<T, CoreError> + Send + Sync + 'static,
        T: Send + Sync + 'static,
        P: Send + Sync + 'static,
    {
        let descriptor = ServiceDescriptorFactoryBuilder::<TInterface>::new()
            .with_factory(move || -> Result<T, CoreError> {
                Err(CoreError::ServiceNotFound {
                    service_type: format!(
                        "Parameterized factory for {} requires runtime parameter resolution",
                        std::any::type_name::<TInterface>()
                    ),
                })
            })
            .build()
            .expect("Failed to build parameterized factory descriptor");
        self.add_descriptor(descriptor);
        self
    }

    /// Runs `configure` against a fresh [`CollectionBindingBuilder`], then appends every
    /// descriptor it collected, in order.
    fn bind_collection<TInterface: ?Sized + 'static, F>(&mut self, configure: F) -> &mut Self
    where
        F: FnOnce(&mut CollectionBindingBuilder<TInterface>),
    {
        let mut builder = CollectionBindingBuilder::<TInterface>::new();
        configure(&mut builder);
        for descriptor in builder.into_services() {
            self.add_descriptor(descriptor);
        }
        self
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `ServiceBindings` and the conditional-binding builders.
    // Tests that set/remove process-wide environment variables are marked `#[serial]`
    // (from `serial_test`) so they do not interleave with one another.
    // NOTE(review): under Rust edition 2024, `std::env::set_var` / `remove_var` are
    // `unsafe fn`s; these call sites would need updating if the crate migrates.
    use super::*;
    use serial_test::serial;

    // Test-only interface used as a `dyn` binding target.
    #[allow(dead_code)]
    trait TestRepository: Send + Sync {
        fn find(&self, id: u32) -> Option<String>;
    }

    #[derive(Default)]
    struct PostgresRepository;
    // NOTE(review): redundant. A field-less struct is already `Send + Sync` automatically,
    // and these `unsafe impl`s carry no SAFETY justification; they could be removed.
    unsafe impl Send for PostgresRepository {}
    unsafe impl Sync for PostgresRepository {}
    impl TestRepository for PostgresRepository {
        fn find(&self, _id: u32) -> Option<String> {
            Some("postgres".to_string())
        }
    }

    // Second test-only interface, so tests can bind two distinct `dyn` targets.
    #[allow(dead_code)]
    trait TestService: Send + Sync {
        fn get_data(&self) -> String;
    }

    #[derive(Default)]
    struct UserService;
    // NOTE(review): redundant for the same reason as on `PostgresRepository`.
    unsafe impl Send for UserService {}
    unsafe impl Sync for UserService {}
    impl TestService for UserService {
        fn get_data(&self) -> String {
            "user_data".to_string()
        }
    }

    // Plain, singleton and named bindings each add one descriptor and are all findable.
    #[test]
    fn test_service_bindings() {
        let mut bindings = ServiceBindings::new();
        bindings
            .bind::<PostgresRepository, PostgresRepository>()
            .bind_singleton::<UserService, UserService>()
            .bind_named::<PostgresRepository, PostgresRepository>("postgres");
        assert_eq!(bindings.count(), 3);
        let service_ids = bindings.service_ids();
        assert_eq!(service_ids.len(), 3);
        assert!(bindings.contains(&ServiceId::of::<PostgresRepository>()));
        assert!(bindings.contains(&ServiceId::of::<UserService>()));
        assert!(bindings.contains(&ServiceId::named::<PostgresRepository>("postgres")));
    }

    // A factory binding is registered under the interface type.
    #[test]
    fn test_factory_binding() {
        let mut bindings = ServiceBindings::new();
        bindings.bind_factory::<UserService, _, _>(|| Ok(UserService::default()));
        assert_eq!(bindings.count(), 1);
        assert!(bindings.contains(&ServiceId::of::<UserService>()));
    }

    // An env condition that matches lets the named binding through.
    #[test]
    #[serial]
    fn test_advanced_binding_with_environment_conditions() {
        let mut bindings = ServiceBindings::new();
        std::env::set_var("CACHE_PROVIDER", "redis");
        let config = AdvancedBindingBuilder::<dyn TestRepository>::new()
            .named("redis")
            .when_env("CACHE_PROVIDER", "redis")
            .with_lifetime(ServiceScope::Singleton)
            .config();
        bindings.with_implementation::<dyn TestRepository, PostgresRepository>(config);
        assert_eq!(bindings.count(), 1);
        assert!(bindings.contains_named::<dyn TestRepository>("redis"));
        std::env::remove_var("CACHE_PROVIDER");
    }

    // An env condition on an unset variable suppresses the binding.
    // Not `#[serial]`: no test sets `NON_EXISTENT_VAR`.
    #[test]
    fn test_conditional_binding_not_met() {
        let mut bindings = ServiceBindings::new();
        let config = AdvancedBindingBuilder::<dyn TestRepository>::new()
            .named("nonexistent")
            .when_env("NON_EXISTENT_VAR", "value")
            .config();
        bindings.with_implementation::<dyn TestRepository, PostgresRepository>(config);
        assert_eq!(bindings.count(), 0);
    }

    // `when_feature("advanced_cache")` is satisfied by `FEATURE_ADVANCED_CACHE` being set.
    #[test]
    #[serial]
    fn test_feature_flag_conditions() {
        let mut bindings = ServiceBindings::new();
        std::env::set_var("FEATURE_ADVANCED_CACHE", "1");
        let config = AdvancedBindingBuilder::<dyn TestRepository>::new()
            .when_feature("advanced_cache")
            .config();
        bindings.with_implementation::<dyn TestRepository, PostgresRepository>(config);
        assert_eq!(bindings.count(), 1);
        std::env::remove_var("FEATURE_ADVANCED_CACHE");
    }

    // A profile-gated binding is added under PROFILE=development and skipped under
    // PROFILE=production, so the count stays at 1 after the second attempt.
    #[test]
    #[serial]
    fn test_profile_conditions() {
        let mut bindings = ServiceBindings::new();
        std::env::set_var("PROFILE", "development");
        let config = AdvancedBindingBuilder::<dyn TestService>::new()
            .in_profile("development")
            .config();
        bindings.with_implementation::<dyn TestService, UserService>(config);
        assert_eq!(bindings.count(), 1);
        std::env::set_var("PROFILE", "production");
        let config2 = AdvancedBindingBuilder::<dyn TestRepository>::new()
            .in_profile("development")
            .config();
        bindings.with_implementation::<dyn TestRepository, PostgresRepository>(config2);
        assert_eq!(bindings.count(), 1);
        std::env::remove_var("PROFILE");
    }

    // Custom closures gate registration: `|| true` registers, `|| false` does not.
    #[test]
    fn test_custom_conditions() {
        let mut bindings = ServiceBindings::new();
        let config = AdvancedBindingBuilder::<dyn TestService>::new()
            .when(|| true)
            .config();
        bindings.with_implementation::<dyn TestService, UserService>(config);
        assert_eq!(bindings.count(), 1);
        let config2 = AdvancedBindingBuilder::<dyn TestRepository>::new()
            .when(|| false)
            .config();
        bindings.with_implementation::<dyn TestRepository, PostgresRepository>(config2);
        assert_eq!(bindings.count(), 1);
    }

    // A lazy (infallible) factory is registered under the interface type.
    #[test]
    fn test_lazy_binding() {
        let mut bindings = ServiceBindings::new();
        bindings.bind_lazy::<UserService, _, _>(|| UserService::default());
        assert_eq!(bindings.count(), 1);
        assert!(bindings.contains(&ServiceId::of::<UserService>()));
    }

    // A collection binding adds one descriptor per `add*` call, unnamed and named.
    #[test]
    fn test_collection_binding() {
        let mut bindings = ServiceBindings::new();
        bindings.bind_collection::<dyn TestService, _>(|collection| {
            collection
                .add::<UserService>()
                .add_named::<UserService>("named_user_service");
        });
        assert_eq!(bindings.count(), 2);
        assert!(bindings.contains(&ServiceId::of::<dyn TestService>()));
        assert!(bindings.contains(&ServiceId::named::<dyn TestService>("named_user_service")));
    }

    // Env, feature, profile and custom conditions combined; all are satisfied here.
    #[test]
    #[serial]
    fn test_multiple_conditions() {
        let mut bindings = ServiceBindings::new();
        std::env::set_var("ENV_VAR", "test_value");
        std::env::set_var("FEATURE_TEST", "1");
        std::env::set_var("PROFILE", "test");
        let config = AdvancedBindingBuilder::<dyn TestService>::new()
            .when_env("ENV_VAR", "test_value")
            .when_feature("test")
            .in_profile("test")
            .when(|| true)
            .named("complex_service")
            .with_lifetime(ServiceScope::Singleton)
            .config();
        bindings.with_implementation::<dyn TestService, UserService>(config);
        assert_eq!(bindings.count(), 1);
        assert!(bindings.contains_named::<dyn TestService>("complex_service"));
        std::env::remove_var("ENV_VAR");
        std::env::remove_var("FEATURE_TEST");
        std::env::remove_var("PROFILE");
    }
}