flow-di 0.1.0

A dependency injection framework for Rust, inspired by C#'s Autofac and Microsoft.Extensions.DependencyInjection.
Documentation
use crate::{DiResult, Lifetime, ServiceKey};
use std::any::Any;
use std::sync::Arc;

/// Type-erased constructor for a service instance.
///
/// The closure is handed the [`ServiceProvider`] that is performing the resolution (presumably so
/// it can resolve its own dependencies) and returns the new instance boxed as
/// `dyn Any + Send + Sync`.
pub type ServiceFactory = Box<
    dyn Fn(&dyn ServiceProvider) -> DiResult<Box<dyn Any + Send + Sync>> + Send + Sync,
>;

/// Object-safe core of service resolution.
///
/// Only the non-generic, type-erased lookup lives here so the trait can be used as
/// `dyn ServiceProvider`; the typed helpers are provided by [`ServiceProviderExt`].
pub trait ServiceProvider: Send + Sync {
    /// Looks up the type-erased instance registered under `key`.
    ///
    /// The typed helpers in [`ServiceProviderExt`] interpret `Ok(None)` as "not registered".
    fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>>;
}

/// Typed convenience layer over [`ServiceProvider`].
///
/// These methods are generic and therefore not object-safe, which is why they live in a separate
/// trait instead of on `ServiceProvider` itself.
pub trait ServiceProviderExt: ServiceProvider {
    /// Resolves the unnamed (type-keyed) registration for `T`.
    ///
    /// Returns `Ok(None)` when [`ServiceProvider::get_service_raw`] finds nothing for `T`.
    ///
    /// # Errors
    /// Propagates lookup errors, and fails when the stored value is neither `T` nor `Arc<T>`.
    fn get_service<T: 'static + Send + Sync>(&self) -> DiResult<Option<Arc<T>>> {
        let type_key = ServiceKey::of_type::<T>();
        match self.get_service_raw(&type_key)? {
            Some(erased) => self.downcast_arc::<T>(erased),
            None => Ok(None),
        }
    }

    /// Like [`ServiceProviderExt::get_service`], but an absent registration is an error.
    ///
    /// # Errors
    /// Returns `DiError::service_not_registered` when nothing is registered for `T`, in addition
    /// to any error `get_service` reports.
    fn get_required_service<T: 'static + Send + Sync>(&self) -> DiResult<Arc<T>> {
        self.get_service::<T>()?
            .ok_or_else(|| crate::DiError::service_not_registered::<T>())
    }

    /// Resolves the registration of `T` stored under the name `key`.
    ///
    /// Returns `Ok(None)` when nothing is registered under that name for `T`.
    ///
    /// # Errors
    /// Propagates lookup errors, and fails when the stored value is neither `T` nor `Arc<T>`.
    fn get_keyed_service<T: 'static + Send + Sync>(&self, key: &str) -> DiResult<Option<Arc<T>>> {
        let named_key = ServiceKey::named::<T>(key);
        match self.get_service_raw(&named_key)? {
            Some(erased) => self.downcast_arc::<T>(erased),
            None => Ok(None),
        }
    }

    /// Like [`ServiceProviderExt::get_keyed_service`], but an absent registration is an error.
    ///
    /// # Errors
    /// Returns `DiError::keyed_service_not_registered` when nothing is registered under `key`,
    /// in addition to any error `get_keyed_service` reports.
    fn get_required_keyed_service<T: 'static + Send + Sync>(&self, key: &str) -> DiResult<Arc<T>> {
        self.get_keyed_service::<T>(key)?
            .ok_or_else(|| crate::DiError::keyed_service_not_registered::<T>(key))
    }

    /// Recovers a typed `Arc<T>` from a type-erased `Arc`.
    ///
    /// Accepts either a direct `T` or an `Arc<T>` nested inside the erased `Arc`. The nested case
    /// presumably arises from factories such as `ServiceDescriptor::from_instance`, which box an
    /// `Arc<T>` rather than a `T`.
    ///
    /// # Errors
    /// Returns `DiError::type_casting_failed` when the value is neither shape.
    fn downcast_arc<T: 'static + Send + Sync>(
        &self,
        any_arc: Arc<dyn Any + Send + Sync>,
    ) -> DiResult<Option<Arc<T>>> {
        // Fast path: the erased value is a plain `T`.
        let not_plain = match any_arc.downcast::<T>() {
            Ok(plain) => return Ok(Some(plain)),
            Err(unchanged) => unchanged,
        };

        // Fallback: the erased value is an `Arc<T>`; clone the inner handle (not the outer one).
        not_plain
            .downcast::<Arc<T>>()
            .map(|wrapped| Some(Arc::clone(&*wrapped)))
            .map_err(|_| crate::DiError::type_casting_failed::<T>())
    }
}

/// Blanket impl: every `ServiceProvider`, sized or not (so `dyn ServiceProvider` included),
/// gets the typed helpers for free.
impl<P: ServiceProvider + ?Sized> ServiceProviderExt for P {}

/// Registration record: how to build a service and which lifetime policy it carries.
///
/// The factory sits behind an `Arc`, so cloned descriptors share the same closure.
#[derive(Clone)]
pub struct ServiceDescriptor {
    /// Key the service is registered under (type-based or named).
    pub service_key: ServiceKey,

    /// Lifetime policy for instances (e.g. transient, scoped, singleton).
    pub lifetime: Lifetime,

    /// Shared, type-erased constructor invoked by `create_instance`.
    pub factory: Arc<ServiceFactory>,

    /// `TypeId` of the type the service is resolved as.
    pub service_type: std::any::TypeId,

    /// `TypeId` of the concrete implementation; may differ from `service_type`.
    pub implementation_type: std::any::TypeId,
}

impl ServiceDescriptor {
    /// Builds a descriptor from an explicit key, lifetime and factory.
    ///
    /// `TService` and `TImplementation` are used only to record their `TypeId`s. The factory is
    /// moved into an `Arc` so that clones of the descriptor share it.
    pub fn new<TService, TImplementation>(
        service_key: ServiceKey,
        lifetime: Lifetime,
        factory: ServiceFactory,
    ) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let service_type = std::any::TypeId::of::<TService>();
        let implementation_type = std::any::TypeId::of::<TImplementation>();
        Self {
            service_key,
            lifetime,
            factory: Arc::new(factory),
            service_type,
            implementation_type,
        }
    }

    /// Transient descriptor keyed by `TService`'s type.
    pub fn transient<TService, TImplementation>(factory: ServiceFactory) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::of_type::<TService>();
        Self::new::<TService, TImplementation>(key, Lifetime::Transient, factory)
    }

    /// Scoped descriptor keyed by `TService`'s type.
    pub fn scoped<TService, TImplementation>(factory: ServiceFactory) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::of_type::<TService>();
        Self::new::<TService, TImplementation>(key, Lifetime::Scoped, factory)
    }

    /// Singleton descriptor keyed by `TService`'s type.
    pub fn singleton<TService, TImplementation>(factory: ServiceFactory) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::of_type::<TService>();
        Self::new::<TService, TImplementation>(key, Lifetime::Singleton, factory)
    }

    /// Transient descriptor registered under `name` for `TService`.
    pub fn named_transient<TService, TImplementation>(
        name: impl Into<String>,
        factory: ServiceFactory,
    ) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::named::<TService>(name);
        Self::new::<TService, TImplementation>(key, Lifetime::Transient, factory)
    }

    /// Scoped descriptor registered under `name` for `TService`.
    pub fn named_scoped<TService, TImplementation>(
        name: impl Into<String>,
        factory: ServiceFactory,
    ) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::named::<TService>(name);
        Self::new::<TService, TImplementation>(key, Lifetime::Scoped, factory)
    }

    /// Singleton descriptor registered under `name` for `TService`.
    pub fn named_singleton<TService, TImplementation>(
        name: impl Into<String>,
        factory: ServiceFactory,
    ) -> Self
    where
        TService: 'static,
        TImplementation: 'static,
    {
        let key = ServiceKey::named::<TService>(name);
        Self::new::<TService, TImplementation>(key, Lifetime::Singleton, factory)
    }

    /// Singleton descriptor that always hands out the given, pre-built `instance`.
    ///
    /// The instance is wrapped in an `Arc` once. Every factory call returns a boxed clone of that
    /// `Arc` (an `Arc<TService>`, not a bare `TService`).
    pub fn from_instance<TService>(instance: TService) -> Self
    where
        TService: Send + Sync + 'static,
    {
        let shared = Arc::new(instance);
        Self::singleton::<TService, TService>(Box::new(move |_| {
            Ok(Box::new(Arc::clone(&shared)))
        }))
    }

    /// Named variant of `from_instance`: a singleton under `name` that hands out `instance`.
    ///
    /// Like `from_instance`, each factory call returns a boxed clone of one shared `Arc<TService>`.
    pub fn from_named_instance<TService>(name: impl Into<String>, instance: TService) -> Self
    where
        TService: Send + Sync + 'static,
    {
        let shared = Arc::new(instance);
        Self::named_singleton::<TService, TService>(
            name,
            Box::new(move |_| Ok(Box::new(Arc::clone(&shared)))),
        )
    }

    /// Returns `true` when this descriptor is registered under exactly `key`.
    pub fn matches_key(&self, key: &ServiceKey) -> bool {
        self.service_key == *key
    }

    /// Returns `true` when the recorded service type is `T`.
    pub fn matches_service_type<T: 'static>(&self) -> bool {
        std::any::TypeId::of::<T>() == self.service_type
    }

    /// Runs the factory against `provider` and returns the new type-erased instance.
    ///
    /// # Errors
    /// Whatever error the factory itself returns.
    pub fn create_instance(
        &self,
        provider: &dyn ServiceProvider,
    ) -> DiResult<Box<dyn Any + Send + Sync>> {
        let factory: &ServiceFactory = &self.factory;
        factory(provider)
    }
}

/// Hand-written `Debug`: the boxed factory closure (`dyn Fn`) has no `Debug` impl, so it is
/// left out of the output.
impl std::fmt::Debug for ServiceDescriptor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServiceDescriptor")
            .field("service_key", &self.service_key)
            .field("lifetime", &self.lifetime)
            .field("service_type", &self.service_type)
            .field("implementation_type", &self.implementation_type)
            // `factory` is omitted; render `..` so readers can tell the output is not exhaustive.
            .finish_non_exhaustive()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Concrete type used as both the service and the implementation in these tests.
    #[derive(Clone)]
    struct TestService {
        value: i32,
    }

    /// Provider stub that reports every key as unregistered (`Ok(None)`).
    struct MockServiceProvider;

    impl ServiceProvider for MockServiceProvider {
        fn get_service_raw(
            &self,
            _key: &ServiceKey,
        ) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
            Ok(None)
        }
    }

    /// Runs `descriptor`'s factory against the mock provider and downcasts the result to `T`,
    /// panicking with a descriptive message on either failure.
    fn build<T: 'static>(descriptor: &ServiceDescriptor) -> Box<T> {
        let erased = match descriptor.create_instance(&MockServiceProvider) {
            Ok(erased) => erased,
            Err(_) => panic!("factory returned an error"),
        };
        match erased.downcast::<T>() {
            Ok(typed) => typed,
            Err(_) => panic!("factory produced an unexpected type"),
        }
    }

    #[test]
    fn test_descriptor_creation() {
        let descriptor = ServiceDescriptor::transient::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 42 }))
        }));

        assert_eq!(descriptor.lifetime, Lifetime::Transient);
        assert!(descriptor.matches_service_type::<TestService>());
        assert!(descriptor.matches_key(&ServiceKey::of_type::<TestService>()));
        assert_eq!(build::<TestService>(&descriptor).value, 42);
    }

    #[test]
    fn test_from_instance() {
        let descriptor = ServiceDescriptor::from_instance(TestService { value: 100 });

        assert_eq!(descriptor.lifetime, Lifetime::Singleton);
        assert!(descriptor.matches_service_type::<TestService>());

        // `from_instance` boxes clones of one shared `Arc`, so every build sees the same object.
        let first = build::<Arc<TestService>>(&descriptor);
        let second = build::<Arc<TestService>>(&descriptor);
        assert_eq!(first.value, 100);
        assert!(Arc::ptr_eq(&*first, &*second));
    }

    #[test]
    fn test_named_descriptor() {
        let descriptor = ServiceDescriptor::named_scoped::<TestService, TestService>(
            "test-service",
            Box::new(|_| Ok(Box::new(TestService { value: 200 }))),
        );

        assert_eq!(descriptor.lifetime, Lifetime::Scoped);
        assert!(descriptor.matches_key(&ServiceKey::named::<TestService>("test-service")));
        assert!(!descriptor.matches_key(&ServiceKey::of_type::<TestService>()));
        assert_eq!(build::<TestService>(&descriptor).value, 200);
    }

    #[test]
    fn test_provider_ext_reports_missing_services() {
        let provider = MockServiceProvider;

        assert!(matches!(provider.get_service::<TestService>(), Ok(None)));
        assert!(provider.get_required_service::<TestService>().is_err());
        assert!(matches!(
            provider.get_keyed_service::<TestService>("absent"),
            Ok(None)
        ));
        assert!(provider
            .get_required_keyed_service::<TestService>("absent")
            .is_err());
    }

    #[test]
    fn test_downcast_arc_handles_plain_and_wrapped_values() {
        let provider = MockServiceProvider;

        let plain: Arc<dyn Any + Send + Sync> = Arc::new(TestService { value: 1 });
        match provider.downcast_arc::<TestService>(plain) {
            Ok(Some(service)) => assert_eq!(service.value, 1),
            _ => panic!("a plain T should downcast"),
        }

        let wrapped: Arc<dyn Any + Send + Sync> = Arc::new(Arc::new(TestService { value: 2 }));
        match provider.downcast_arc::<TestService>(wrapped) {
            Ok(Some(service)) => assert_eq!(service.value, 2),
            _ => panic!("an Arc<T> nested in the erased Arc should downcast"),
        }

        let wrong: Arc<dyn Any + Send + Sync> = Arc::new(7_u8);
        assert!(provider.downcast_arc::<TestService>(wrong).is_err());
    }
}