injection 0.1.1

A lightweight dependency injection container for Rust applications
Documentation
use crate::config::GlobalConfig;
use crate::definition::{InjectionBean, InjectionBeanDefinition};
use crate::utils::json_merge_patch;
use crate::{Component, Configure};
use std::any::{Any, TypeId};
/// 仅在启用 "config" 特性时使用 HashMap
#[cfg(feature = "config")]
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::OnceCell;

/// 实例键类型,由类型 ID 和组件名称组成
pub type InstanceKey = (TypeId, &'static str);
/// 全局容器实例,使用 OnceCell 确保线程安全的延迟初始化
pub(crate) static CONTAINER: OnceCell<Arc<InjectionContainer>> = OnceCell::const_new();

/// 依赖注入容器
/// 负责管理所有已注册组件的生命周期和依赖关系
pub struct InjectionContainer {
    /// 组件注册表,存储所有已注册的 Bean 实例
    registry: HashMap<InstanceKey, Arc<InjectionBean>>,
    /// 全局配置对象
    config: GlobalConfig,
}

impl InjectionContainer {
    /// 创建新的容器实例
    /// 会自动扫描并注册所有通过 inventory 宏收集的 Bean 定义
    fn new() -> Self {
        let mut registry = HashMap::new();
        // 遍历所有已注册的 Bean 定义,创建对应的 Bean 实例包装器
        for definition in inventory::iter::<InjectionBeanDefinition> {
            registry.insert(
                (definition.type_id, definition.name),
                Arc::new(InjectionBean::new(definition.init_fn)),
            );
        }

        Self {
            registry,
            config: GlobalConfig::new(),
        }
    }

    /// 初始化容器(异步)
    /// 这是容器的入口点,必须在应用程序启动时调用
    /// 会创建全局单例容器并预热所有 Bean
    pub async fn init<F: FnOnce()>(init_fn: impl Into<Option<F>>) {
        let container = CONTAINER
            .get_or_init(|| async { Arc::new(InjectionContainer::new()) })
            .await;
        // 执行可选的初始化回调函数
        if let Some(f) =init_fn.into(){
            f();
        }
        // 预热所有 Bean,确保它们在首次请求前已经初始化完成
        container.prewarm_all().await;
    }

    /// 获取容器实例(同步方法)
    /// 返回全局单例容器的 Arc 引用
    /// 注意:必须先调用 init() 初始化容器才能使用此方法
    pub fn instance() -> &'static Arc<InjectionContainer> {
        CONTAINER
            .get()
            .expect("Container not initialized. Call InjectionContainer::init() first in async context.")
    }

    /// 获取配置(同步方法)
    /// 仅在同时启用 "config" 和 "serde" 特性时可用
    /// 自动从配置文件中加载对应前缀的配置数据
    #[cfg(all(feature = "config", feature = "serde"))]
    pub fn get_config<T: Configure>() -> T {
        Self::get_config_by_key(T::prefix())
    }

    /// 根据键获取配置(同步方法)
    /// 仅在同时启用 "config" 和 "serde" 特性时可用
    /// 支持 JSON Merge Patch 合并策略,将配置文件中的值与默认值合并
    #[cfg(all(feature = "config", feature = "serde"))]
    pub fn get_config_by_key<T: Configure>(key: &str) -> T {
        let default = T::default();
        // 尝试从配置中获取指定键的值
        if let Ok( config) = Self::instance().config.get(key){
            // 将默认值和配置值转换为 JSON 进行合并
            if let Ok(mut default_value) = serde_json::to_value(&default) && let Ok(config_value) = serde_json::to_value(&config){
                json_merge_patch(&mut default_value, &config_value);
                // 将合并后的 JSON 转换回目标类型
                if let Ok(merge_value) = serde_json::from_value(default_value) {
                    return merge_value;
                }
            }
            config
        }else{
            // 如果配置中不存在,返回默认值
            default
        }
    }

    /// 获取命名组件(异步版本)
    /// 从容器中获取指定类型的命名组件实例
    /// 如果组件尚未初始化,会等待其初始化完成
    async fn get_named_component_async<T: Component>(&self, name: &str) -> &'static T {
        let entry = self.get_injection_bean(TypeId::of::<T>(), name);
        let injection_instance = entry.get().await;
        injection_instance
            .downcast_ref()
            .expect("Bean can not downcast.")
    }
    
    /// 获取 Bean 实例的内部引用
    /// 根据类型 ID 和名称从注册表中查找对应的 Bean
    fn get_injection_bean<'a>(&'a self, type_id: TypeId, name: &'a str) -> &'a Arc<InjectionBean>{
        self
            .registry
            .get(&(type_id, name))
            .expect(&format!(
                "No dependency entry found for with name {}",
                name
            ))
    }
    
    /// 获取命名组件(同步版本)
    /// 从容器中获取指定类型的命名组件实例
    /// 注意:此方法要求组件已经初始化完成,否则会 panic
    fn get_named_component<T: Component>(&self, name: &str) -> &'static T {
        let entry = self.get_injection_bean(TypeId::of::<T>(), name);
        let injection_instance = entry.get_inner();
        injection_instance
            .downcast_ref()
            .expect("Bean can not downcast.")
    }

    /// 获取组件(公共静态方法)
    /// 从全局容器中获取指定类型的组件实例
    /// 这是同步版本,要求组件已经完成初始化
    pub fn get<T: Component>() -> &'static T {
        Self::instance().resolve()
    }

    /// 获取组件(异步公共静态方法)
    /// 从全局容器中获取指定类型的组件实例
    /// 这是异步版本,可以等待未初始化的组件完成初始化
    pub async fn get_async<T: Component>() -> &'static T {
        Self::instance().resolve_async().await
    }
    
    /// 解析组件(同步方法)
    /// 内部方法,根据组件名称解析并返回组件实例
    fn resolve<T: Component>(&self) -> &'static T {
        self.get_named_component(T::component_name())
    }

    /// 解析组件(异步方法)
    /// 内部方法,根据组件名称解析并返回组件实例
    /// 如果组件未初始化,会等待其初始化完成
    async fn resolve_async<T: Component>(&self) -> &'static T {
        self.get_named_component_async(T::component_name()).await
    }

    /// 预热所有 Bean(异步方法,仅在初始化时调用)
    /// 并发启动所有 Bean 的初始化过程
    /// 每个 Bean 都会在独立的 tokio 任务中初始化,并记录调试日志
    async fn prewarm_all(&self) {
        for (key, value) in self.registry.iter() {
            let value_clone = value.clone();
            let key_name = key.1;
            tokio::spawn(async move {
                log::debug!("Prewarming bean start: {}", key_name);
                let _ = value_clone.get().await;
                log::debug!("Prewarming bean end: {}", key_name);
            });
        }
    }
}

impl Default for InjectionContainer {
    fn default() -> Self {
        Self::new()
    }
}

/// ArcAnyExt trait 扩展
/// 提供将 Arc<dyn Any + Send + Sync> 转换为具体类型的 Arc 的能力
pub trait ArcAnyExt {
    /// 尝试将 Arc 下转为指定类型 T
    /// 如果类型匹配,返回 Some(Arc<T>),否则返回 None
    fn downcast_arc<T: Any + Send + Sync>(self: Arc<Self>) -> Option<Arc<T>>
    where
        Self: Send + Sync + 'static;
}

impl ArcAnyExt for dyn Any + Send + Sync {
    /// 实现 Arc 的下转操作
    /// 使用 unsafe 代码避免不必要的运行时检查
    fn downcast_arc<T: Any + Send + Sync>(self: Arc<Self>) -> Option<Arc<T>>
    where
        Self: Send + Sync + 'static,
    {
        if self.is::<T>() {
            // 类型匹配,将原始指针转换为目标类型并重新包装为 Arc
            let raw = Arc::into_raw(self) as *const T;
            Some(unsafe { Arc::from_raw(raw) })
        } else {
            None
        }
    }
}