use crate::{DiResult, Lifetime, ServiceKey};
use std::any::Any;
use std::sync::Arc;
/// Boxed, thread-safe closure that builds one type-erased service instance.
///
/// The closure receives the `ServiceProvider` it is invoked through (see
/// `ServiceDescriptor::create_instance`), presumably so it can resolve other
/// services it depends on. The boxed value may be the service itself or an
/// `Arc` of it: the typed lookups in `ServiceProviderExt` (via `downcast_arc`)
/// accept either a stored `T` or a stored `Arc<T>`.
pub type ServiceFactory =
    Box<dyn Fn(&dyn ServiceProvider) -> DiResult<Box<dyn Any + Send + Sync>> + Send + Sync>;
/// Object-safe, type-erased service lookup; used as `&dyn ServiceProvider`.
///
/// Typed convenience methods are provided by `ServiceProviderExt`, which is
/// implemented for every `ServiceProvider`.
pub trait ServiceProvider: Send + Sync {
    /// Looks up the instance registered under `key`.
    ///
    /// `Ok(None)` means nothing was found; `ServiceProviderExt` surfaces that
    /// either as `None` or as a "not registered" error for the `required`
    /// variants.
    fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>>;
}
/// Typed helpers layered on top of [`ServiceProvider::get_service_raw`].
///
/// Every method has a default body, and the blanket impl at the end of this
/// block makes them available on any `ServiceProvider`, including
/// `dyn ServiceProvider`.
pub trait ServiceProviderExt: ServiceProvider {
    /// Resolves the unnamed registration of `T`.
    ///
    /// Returns `Ok(None)` when nothing is registered.
    ///
    /// # Errors
    /// Propagates lookup errors, and fails with a type-casting error when the
    /// stored value is neither a `T` nor an `Arc<T>`.
    fn get_service<T: 'static + Send + Sync>(&self) -> DiResult<Option<Arc<T>>> {
        match self.get_service_raw(&ServiceKey::of_type::<T>())? {
            Some(erased) => self.downcast_arc::<T>(erased),
            None => Ok(None),
        }
    }

    /// Like [`get_service`](Self::get_service), but a missing registration is
    /// reported as a "service not registered" error instead of `None`.
    fn get_required_service<T: 'static + Send + Sync>(&self) -> DiResult<Arc<T>> {
        self.get_service::<T>()?
            .ok_or_else(crate::DiError::service_not_registered::<T>)
    }

    /// Resolves the registration of `T` stored under the name `key`.
    ///
    /// Returns `Ok(None)` when nothing is registered under that name.
    ///
    /// # Errors
    /// Same as [`get_service`](Self::get_service).
    fn get_keyed_service<T: 'static + Send + Sync>(&self, key: &str) -> DiResult<Option<Arc<T>>> {
        let named_key = ServiceKey::named::<T>(key);
        match self.get_service_raw(&named_key)? {
            Some(erased) => self.downcast_arc::<T>(erased),
            None => Ok(None),
        }
    }

    /// Like [`get_keyed_service`](Self::get_keyed_service), but a missing
    /// registration is reported as a "keyed service not registered" error.
    fn get_required_keyed_service<T: 'static + Send + Sync>(&self, key: &str) -> DiResult<Arc<T>> {
        self.get_keyed_service::<T>(key)?
            .ok_or_else(|| crate::DiError::keyed_service_not_registered::<T>(key))
    }

    /// Recovers an `Arc<T>` from a type-erased instance.
    ///
    /// Accepts both a stored `T` and a stored `Arc<T>`. The latter is the
    /// shape produced by factories that box an `Arc`, such as
    /// `ServiceDescriptor::from_instance`. On success the result is always
    /// `Some`.
    ///
    /// # Errors
    /// Fails with a type-casting error when the value is neither shape.
    fn downcast_arc<T: 'static + Send + Sync>(
        &self,
        any_arc: Arc<dyn Any + Send + Sync>,
    ) -> DiResult<Option<Arc<T>>> {
        // Fast path: the erased value is the service itself.
        let erased = match any_arc.downcast::<T>() {
            Ok(service) => return Ok(Some(service)),
            Err(erased) => erased,
        };
        // Fallback: the erased value is an `Arc<T>`; hand out a clone of it.
        erased
            .downcast::<Arc<T>>()
            .map(|wrapper| Some(Arc::clone(&*wrapper)))
            .map_err(|_| crate::DiError::type_casting_failed::<T>())
    }
}

impl<P: ServiceProvider + ?Sized> ServiceProviderExt for P {}
/// Registration record: the key a service is stored under, its `Lifetime`,
/// the factory that builds it, and the `TypeId`s it was registered with.
#[derive(Clone)]
pub struct ServiceDescriptor {
    /// Key the descriptor is registered under; compared by `matches_key`.
    pub service_key: ServiceKey,
    /// Lifetime given at construction (the constructors use `Transient`,
    /// `Scoped` or `Singleton`).
    pub lifetime: Lifetime,
    /// Factory that builds a type-erased instance. It is wrapped in `Arc`
    /// because the boxed closure is not `Clone`; cloning the descriptor shares it.
    pub factory: Arc<ServiceFactory>,
    /// `TypeId` of the `TService` parameter given to `ServiceDescriptor::new`.
    pub service_type: std::any::TypeId,
    /// `TypeId` of the `TImplementation` parameter given to `ServiceDescriptor::new`.
    pub implementation_type: std::any::TypeId,
}
impl ServiceDescriptor {
    /// Creates a descriptor from an explicit key, lifetime and factory.
    ///
    /// The type parameters do not influence the key. They only supply the
    /// `TypeId`s recorded in `service_type` and `implementation_type`.
    pub fn new<TService, TImplementation>(
        service_key: ServiceKey,
        lifetime: Lifetime,
        factory: ServiceFactory,
    ) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let service_type = std::any::TypeId::of::<TService>();
        let implementation_type = std::any::TypeId::of::<TImplementation>();
        Self {
            service_key,
            lifetime,
            factory: Arc::new(factory),
            service_type,
            implementation_type,
        }
    }

    /// Unnamed registration of `TService` with `Lifetime::Transient`.
    pub fn transient<TService, TImplementation>(factory: ServiceFactory) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::of_type::<TService>();
        Self::new::<TService, TImplementation>(key, Lifetime::Transient, factory)
    }

    /// Unnamed registration of `TService` with `Lifetime::Scoped`.
    pub fn scoped<TService, TImplementation>(factory: ServiceFactory) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::of_type::<TService>();
        Self::new::<TService, TImplementation>(key, Lifetime::Scoped, factory)
    }

    /// Unnamed registration of `TService` with `Lifetime::Singleton`.
    pub fn singleton<TService, TImplementation>(factory: ServiceFactory) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::of_type::<TService>();
        Self::new::<TService, TImplementation>(key, Lifetime::Singleton, factory)
    }

    /// Registration of `TService` under `name` with `Lifetime::Transient`.
    pub fn named_transient<TService, TImplementation>(
        name: impl Into<String>,
        factory: ServiceFactory,
    ) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::named::<TService>(name);
        Self::new::<TService, TImplementation>(key, Lifetime::Transient, factory)
    }

    /// Registration of `TService` under `name` with `Lifetime::Scoped`.
    pub fn named_scoped<TService, TImplementation>(
        name: impl Into<String>,
        factory: ServiceFactory,
    ) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::named::<TService>(name);
        Self::new::<TService, TImplementation>(key, Lifetime::Scoped, factory)
    }

    /// Registration of `TService` under `name` with `Lifetime::Singleton`.
    pub fn named_singleton<TService, TImplementation>(
        name: impl Into<String>,
        factory: ServiceFactory,
    ) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::named::<TService>(name);
        Self::new::<TService, TImplementation>(key, Lifetime::Singleton, factory)
    }

    /// Unnamed singleton wrapping an already-built `instance`.
    ///
    /// Each factory call boxes a clone of the same `Arc<TService>`, so every
    /// caller observes the one instance.
    pub fn from_instance<TService>(instance: TService) -> Self
    where
        TService: Send + Sync + 'static,
    {
        let shared = Arc::new(instance);
        let factory: ServiceFactory = Box::new(move |_| Ok(Box::new(Arc::clone(&shared))));
        Self::singleton::<TService, TService>(factory)
    }

    /// Named singleton wrapping an already-built `instance`; see `from_instance`.
    pub fn from_named_instance<TService>(name: impl Into<String>, instance: TService) -> Self
    where
        TService: Send + Sync + 'static,
    {
        let shared = Arc::new(instance);
        let factory: ServiceFactory = Box::new(move |_| Ok(Box::new(Arc::clone(&shared))));
        Self::named_singleton::<TService, TService>(name, factory)
    }

    /// Returns `true` when this descriptor is registered under exactly `key`.
    pub fn matches_key(&self, key: &ServiceKey) -> bool {
        self.service_key == *key
    }

    /// Returns `true` when `T` is the service type this descriptor was built for.
    pub fn matches_service_type<T: 'static>(&self) -> bool {
        let wanted = std::any::TypeId::of::<T>();
        self.service_type == wanted
    }

    /// Runs the factory once, handing it `provider`.
    ///
    /// # Errors
    /// Returns whatever error the factory produces.
    pub fn create_instance(
        &self,
        provider: &dyn ServiceProvider,
    ) -> DiResult<Box<dyn Any + Send + Sync>> {
        let factory: &ServiceFactory = &self.factory;
        factory(provider)
    }
}
/// Manual `Debug` impl, because the boxed `factory` closure has no `Debug`
/// representation.
///
/// The factory is left out, and the output ends in `..`
/// (`finish_non_exhaustive`) so readers can see that a field was omitted.
impl std::fmt::Debug for ServiceDescriptor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServiceDescriptor")
            .field("service_key", &self.service_key)
            .field("lifetime", &self.lifetime)
            .field("service_type", &self.service_type)
            .field("implementation_type", &self.implementation_type)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone)]
    struct TestService {
        value: i32,
    }

    /// Provider that never finds anything; used to drive factories and the
    /// `ServiceProviderExt` helpers without a real container.
    struct MockServiceProvider;

    impl ServiceProvider for MockServiceProvider {
        fn get_service_raw(
            &self,
            _key: &ServiceKey,
        ) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
            Ok(None)
        }
    }

    /// Runs `descriptor`'s factory and unwraps the `Arc<T>` it is expected to box.
    fn create_shared<T: Send + Sync + 'static>(
        descriptor: &ServiceDescriptor,
        provider: &dyn ServiceProvider,
    ) -> Arc<T> {
        // Match instead of `unwrap` so the test does not require `DiError: Debug`.
        let boxed = match descriptor.create_instance(provider) {
            Ok(boxed) => boxed,
            Err(_) => panic!("factory unexpectedly failed"),
        };
        *boxed
            .downcast::<Arc<T>>()
            .expect("factory should box an Arc<T>")
    }

    #[test]
    fn test_descriptor_creation() {
        let descriptor = ServiceDescriptor::transient::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 42 }))
        }));
        assert_eq!(descriptor.lifetime, Lifetime::Transient);
        assert!(descriptor.matches_service_type::<TestService>());
        assert!(descriptor.matches_key(&ServiceKey::of_type::<TestService>()));
    }

    #[test]
    fn test_from_instance() {
        let descriptor = ServiceDescriptor::from_instance(TestService { value: 100 });
        assert_eq!(descriptor.lifetime, Lifetime::Singleton);
        assert!(descriptor.matches_service_type::<TestService>());
        assert!(descriptor.matches_key(&ServiceKey::of_type::<TestService>()));

        let provider = MockServiceProvider;
        let first = create_shared::<TestService>(&descriptor, &provider);
        let second = create_shared::<TestService>(&descriptor, &provider);
        assert_eq!(first.value, 100);
        // Every factory call must hand out the same underlying instance.
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_from_named_instance() {
        let descriptor =
            ServiceDescriptor::from_named_instance("primary", TestService { value: 7 });
        assert_eq!(descriptor.lifetime, Lifetime::Singleton);
        assert!(descriptor.matches_key(&ServiceKey::named::<TestService>("primary")));
        assert!(!descriptor.matches_key(&ServiceKey::of_type::<TestService>()));

        let provider = MockServiceProvider;
        assert_eq!(create_shared::<TestService>(&descriptor, &provider).value, 7);
    }

    #[test]
    fn test_named_descriptor() {
        let descriptor = ServiceDescriptor::named_scoped::<TestService, TestService>(
            "test-service",
            Box::new(|_| Ok(Box::new(TestService { value: 200 }))),
        );
        assert_eq!(descriptor.lifetime, Lifetime::Scoped);
        assert!(descriptor.matches_key(&ServiceKey::named::<TestService>("test-service")));
        assert!(!descriptor.matches_key(&ServiceKey::of_type::<TestService>()));
    }

    #[test]
    fn test_missing_service_resolution() {
        let provider = MockServiceProvider;
        assert!(matches!(provider.get_service::<TestService>(), Ok(None)));
        assert!(matches!(
            provider.get_keyed_service::<TestService>("missing"),
            Ok(None)
        ));
        assert!(provider.get_required_service::<TestService>().is_err());
        assert!(provider
            .get_required_keyed_service::<TestService>("missing")
            .is_err());
    }

    #[test]
    fn test_downcast_arc_accepts_value_and_nested_arc() {
        let provider = MockServiceProvider;

        let direct: Arc<dyn Any + Send + Sync> = Arc::new(TestService { value: 1 });
        let resolved = provider.downcast_arc::<TestService>(direct);
        assert!(matches!(resolved, Ok(Some(ref s)) if s.value == 1));

        let nested: Arc<dyn Any + Send + Sync> = Arc::new(Arc::new(TestService { value: 2 }));
        let resolved = provider.downcast_arc::<TestService>(nested);
        assert!(matches!(resolved, Ok(Some(ref s)) if s.value == 2));

        let wrong: Arc<dyn Any + Send + Sync> = Arc::new(5_u8);
        assert!(provider.downcast_arc::<TestService>(wrong).is_err());
    }
}