extern crate self as auto_di;
use std::{
any::{Any, TypeId},
collections::HashMap,
future::Future,
marker::PhantomData,
pin::Pin,
sync::{Arc, Mutex, OnceLock},
};
pub use auto_di_macros::{
application, bean, beans, component, configuration, configuration_properties, qualifier,
repository, service, singleton,
};
/// Type-erased, thread-safe shared handle to an instance built by a provider.
pub type DynArc = Arc<dyn Any + Send + Sync>;
/// Heap-allocated, pinned, `Send` future that may borrow for `'a` and yields `T`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Signature of a provider's constructor: given the container and the current
/// resolution context, it asynchronously produces the type-erased instance.
type Factory =
for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
/// Signature of a provider's teardown hook; `Container::shutdown` calls it with
/// the cached instance.
type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
/// Lifetime policy applied by the container when resolving a provider.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Scope {
/// One instance per container, cached in the container's instance map.
Singleton,
/// The factory is invoked on every resolution; nothing is cached.
Prototype,
/// One instance per `RequestContext`; resolving outside a request context
/// fails with `DiError::RequestScopeUnavailable`.
Request,
}
/// Static registration record for one provider, collected through `inventory`.
///
/// NOTE(review): instances are presumably emitted by the `auto_di_macros`
/// attribute macros; nothing in this file constructs one directly.
#[doc(hidden)]
pub struct ProviderDescriptor {
// Returns the `TypeId` under which the provider is indexed in the container.
type_id: fn() -> TypeId,
// Returns the type name used in error messages and in the cycle-detection chain.
type_name: fn() -> &'static str,
// Constructor invoked by the container to build the instance.
factory: Factory,
/// Optional name matched by named resolution.
pub name: Option<&'static str>,
/// Preferred candidate when several providers match and no name is given.
pub primary: bool,
/// Caching policy applied to instances produced by `factory`.
pub scope: Scope,
/// Whether `Container::initialize_eager` should build this provider.
pub eager: bool,
/// Profile that must be active for the provider to be registered.
pub profile: Option<&'static str>,
/// Environment variable that must be set for the provider to be registered.
pub condition_key: Option<&'static str>,
/// Exact value `condition_key` must hold; when `None`, being set is enough.
pub condition_value: Option<&'static str>,
/// Teardown hook run by `Container::shutdown` for initialized singletons.
pub destroy: Option<Destroy>,
}
impl ProviderDescriptor {
    /// Creates a descriptor with default metadata: unnamed, non-primary,
    /// singleton-scoped, lazily initialized, with no profile, no condition and
    /// no destroy hook.
    #[doc(hidden)]
    pub const fn new(
        type_id: fn() -> TypeId,
        type_name: fn() -> &'static str,
        factory: Factory,
    ) -> Self {
        Self {
            type_id,
            type_name,
            factory,
            name: None,
            primary: false,
            scope: Scope::Singleton,
            eager: false,
            profile: None,
            condition_key: None,
            condition_value: None,
            destroy: None,
        }
    }

    /// Creates a descriptor with every piece of metadata supplied explicitly.
    // One positional argument per field keeps this usable as a `const fn`
    // initializer, hence the long parameter list.
    #[allow(clippy::too_many_arguments)]
    #[doc(hidden)]
    pub const fn configured(
        type_id: fn() -> TypeId,
        type_name: fn() -> &'static str,
        factory: Factory,
        name: Option<&'static str>,
        primary: bool,
        scope: Scope,
        eager: bool,
        profile: Option<&'static str>,
        condition_key: Option<&'static str>,
        condition_value: Option<&'static str>,
        destroy: Option<Destroy>,
    ) -> Self {
        Self {
            // Identity and construction.
            type_id,
            type_name,
            factory,
            // Selection metadata.
            name,
            primary,
            // Lifecycle metadata.
            scope,
            eager,
            destroy,
            // Activation metadata.
            profile,
            condition_key,
            condition_value,
        }
    }

    /// Reports whether this provider should be registered given the active
    /// `profiles` and the current process environment.
    ///
    /// A provider is active when its required profile (if any) is listed in
    /// `profiles` and its condition (if any) holds: the environment variable
    /// `condition_key` must be readable, and must equal `condition_value`
    /// when one is specified.
    fn active(&self, profiles: &[String]) -> bool {
        let profile_rejected = self
            .profile
            .is_some_and(|required| !profiles.iter().any(|candidate| candidate == required));
        if profile_rejected {
            return false;
        }
        // Without a condition key there is nothing else to check.
        let Some(key) = self.condition_key else {
            return true;
        };
        match (std::env::var(key), self.condition_value) {
            (Ok(actual), Some(expected)) => actual == expected,
            // Presence alone satisfies a value-less condition.
            (Ok(_), None) => true,
            // Unset (or non-Unicode) variables never satisfy a condition.
            (Err(_), _) => false,
        }
    }
}
// Makes `ProviderDescriptor` values collectable, so that
// `inventory::iter::<ProviderDescriptor>` can enumerate them in
// `Container::with_profiles`.
inventory::collect!(ProviderDescriptor);
/// Cache of instances keyed by `provider_key` (the descriptor's address).
/// Each entry is a shared once-cell that holds the instance once built.
type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
/// State threaded through a single resolution and handed to provider factories.
#[derive(Clone, Default)]
pub struct ResolutionContext {
// Type names of the providers currently being resolved, outermost first;
// consulted to detect circular dependencies.
chain: Vec<&'static str>,
// Cache for request-scoped instances. It is `Some` only for contexts
// created by `Container::request_context`; the default context has `None`.
request_instances: Option<Arc<Mutex<InstanceMap>>>,
}
/// Errors reported by the container. Payloads are type names unless noted.
#[derive(Debug, thiserror::Error)]
pub enum DiError {
/// No active provider exists for the requested type, or none carries the
/// requested name.
#[error("no active provider is registered for {0}")]
MissingProvider(&'static str),
/// Several providers match an unnamed lookup and none is marked primary.
#[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
AmbiguousProvider(&'static str),
/// More than one active provider of the same type is marked primary.
#[error("multiple primary providers are registered for {0}")]
MultiplePrimary(&'static str),
/// A provider was reached again while it was still being resolved; the
/// payload is the chain of type names joined with `" -> "`.
#[error("circular dependency detected: {0}")]
CircularDependency(String),
/// The instance produced by a factory could not be downcast to the
/// requested type.
#[error("provider for {0} returned an incompatible type")]
TypeMismatch(&'static str),
/// A request-scoped provider was resolved with a context that has no
/// request instance cache.
#[error("request-scoped bean {0} was resolved outside RequestContext")]
RequestScopeUnavailable(&'static str),
/// NOTE(review): not constructed in this file; presumably returned by
/// `ConfigurationProperties::from_environment` implementations — confirm.
#[error("configuration property {key} is missing or invalid: {message}")]
Configuration { key: String, message: String },
/// NOTE(review): not constructed in this file; presumably returned by
/// generated lifecycle hooks — confirm.
#[error("lifecycle hook failed for {0}")]
Lifecycle(&'static str),
}
/// Registry of active providers plus the cache of singleton instances.
pub struct Container {
// Active providers grouped by the `TypeId` they provide.
providers: HashMap<TypeId, Vec<&'static ProviderDescriptor>>,
// Singleton instances, keyed by `provider_key`.
instances: Mutex<InstanceMap>,
}
// Process-wide container, built on first use by `global_container`.
static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
/// Returns the process-wide container, building it on first use.
///
/// If several threads race on first use, each may build a candidate, but only
/// one is installed; the others are dropped and every caller receives the
/// installed container.
///
/// # Errors
///
/// Propagates the error from `Container::new` when construction fails; nothing
/// is installed in that case, so a later call tries again.
pub fn global_container() -> Result<&'static Container, DiError> {
    match GLOBAL_CONTAINER.get() {
        Some(existing) => Ok(existing),
        None => {
            // Build outside the cell so a construction failure can be
            // returned to the caller instead of being stored.
            let fresh = Container::new()?;
            Ok(GLOBAL_CONTAINER.get_or_init(|| fresh))
        }
    }
}
/// Resolves `T` from the process-wide container.
///
/// # Errors
///
/// Returns any error from `global_container` or from `Container::resolve`.
pub async fn resolve<T: Any + Send + Sync>() -> Result<Arc<T>, DiError> {
    let container = global_container()?;
    container.resolve::<T>().await
}
impl Container {
    /// Builds a container using the profiles listed in the `APP_PROFILES`
    /// environment variable.
    ///
    /// The variable is read as a comma-separated list; entries are trimmed and
    /// empty entries are dropped. When the variable cannot be read, no profile
    /// is active.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Container::with_profiles`.
    pub fn new() -> Result<Self, DiError> {
        let profiles = std::env::var("APP_PROFILES")
            .unwrap_or_default()
            .split(',')
            .map(str::trim)
            .filter(|profile| !profile.is_empty())
            .map(str::to_owned)
            .collect::<Vec<_>>();
        Self::with_profiles(profiles)
    }

    /// Builds a container from every collected provider that is active for
    /// the given `profiles` and the current environment.
    ///
    /// # Errors
    ///
    /// Returns `DiError::MultiplePrimary` when more than one active provider
    /// of the same type is marked primary.
    pub fn with_profiles(
        profiles: impl IntoIterator<Item = impl Into<String>>,
    ) -> Result<Self, DiError> {
        let profiles: Vec<String> = profiles.into_iter().map(Into::into).collect();
        let mut providers: HashMap<TypeId, Vec<&'static ProviderDescriptor>> = HashMap::new();
        for provider in inventory::iter::<ProviderDescriptor> {
            if provider.active(&profiles) {
                providers
                    .entry((provider.type_id)())
                    .or_default()
                    .push(provider);
            }
        }
        for group in providers.values() {
            if group.iter().filter(|p| p.primary).count() > 1 {
                // Groups are only created by the `push` above, so they are
                // never empty and every member shares the same type name.
                return Err(DiError::MultiplePrimary((group[0].type_name)()));
            }
        }
        Ok(Self {
            providers,
            instances: Mutex::new(HashMap::new()),
        })
    }

    /// Resolves the single (or primary) provider of `T`.
    ///
    /// # Errors
    ///
    /// `MissingProvider` when no provider of `T` is active,
    /// `AmbiguousProvider` when several are and none is primary, plus any
    /// error raised while constructing the instance or its dependencies.
    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
    where
        T: Any + Send + Sync,
    {
        self.resolve_dependency::<T>(&ResolutionContext::default())
            .await
    }

    /// Resolves the provider of `T` registered under `name`.
    ///
    /// # Errors
    ///
    /// `MissingProvider` when no active provider of `T` carries that name,
    /// plus any error raised while constructing the instance.
    pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
    where
        T: Any + Send + Sync,
    {
        self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
            .await
    }

    /// Resolves `T` if a provider for it is active, otherwise yields `None`.
    ///
    /// # Errors
    ///
    /// Every error raised while resolving a registered `T` is propagated,
    /// including `MissingProvider` for one of its dependencies.
    pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
    where
        T: Any + Send + Sync,
    {
        self.resolve_optional_dependency::<T>(&ResolutionContext::default())
            .await
    }

    /// Resolves every active provider of `T`; the result is empty when there
    /// is none.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error raised while constructing an
    /// instance.
    pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
    where
        T: Any + Send + Sync,
    {
        self.resolve_all_dependency::<T>(&ResolutionContext::default())
            .await
    }

    /// Opens a request context with its own, initially empty, cache for
    /// request-scoped instances.
    pub fn request_context(&self) -> RequestContext<'_> {
        RequestContext {
            container: self,
            context: ResolutionContext {
                chain: Vec::new(),
                request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
            },
        }
    }

    /// Constructs every active provider flagged as eager.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first construction error.
    pub async fn initialize_eager(&self) -> Result<(), DiError> {
        let eager = self
            .providers
            .values()
            .flatten()
            .copied()
            .filter(|provider| provider.eager);
        for provider in eager {
            self.resolve_provider(provider, ResolutionContext::default())
                .await?;
        }
        Ok(())
    }

    /// Runs the destroy hook of every provider whose singleton instance has
    /// been initialized in this container.
    ///
    /// Instances stay cached afterwards; request-scoped and prototype
    /// instances are not tracked here and are therefore not destroyed.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error reported by a destroy hook.
    pub async fn shutdown(&self) -> Result<(), DiError> {
        // Snapshot the initialized instances inside a block so the std mutex
        // guard is released before any hook is awaited.
        let initialized: HashMap<usize, DynArc> = {
            let instances = self.instances.lock().expect("DI instance lock poisoned");
            instances
                .iter()
                .filter_map(|(key, cell)| cell.get().map(|value| (*key, value.clone())))
                .collect()
        };
        for provider in self.providers.values().flatten().copied() {
            let Some(destroy) = provider.destroy else {
                continue;
            };
            // Keyed lookup instead of scanning the snapshot for every provider.
            if let Some(value) = initialized.get(&provider_key(provider)) {
                destroy(value.clone()).await?;
            }
        }
        Ok(())
    }

    /// Resolves the single (or primary) provider of `T` on behalf of a
    /// factory, continuing the given resolution `context`.
    #[doc(hidden)]
    pub async fn resolve_dependency<T>(
        &self,
        context: &ResolutionContext,
    ) -> Result<Arc<T>, DiError>
    where
        T: Any + Send + Sync,
    {
        self.resolve_selected::<T>(None, context).await
    }

    /// Resolves the provider of `T` registered under `name` on behalf of a
    /// factory, continuing the given resolution `context`.
    #[doc(hidden)]
    pub async fn resolve_named_dependency<T>(
        &self,
        name: &str,
        context: &ResolutionContext,
    ) -> Result<Arc<T>, DiError>
    where
        T: Any + Send + Sync,
    {
        self.resolve_selected::<T>(Some(name), context).await
    }

    /// Resolves `T` if a provider for it is active, otherwise yields `None`.
    ///
    /// Only the absence of a provider for `T` itself maps to `None`. A
    /// `MissingProvider` raised while building a registered `T` (that is, for
    /// one of its dependencies) is propagated, so a broken dependency graph
    /// is not silently reported as "not present".
    #[doc(hidden)]
    pub async fn resolve_optional_dependency<T>(
        &self,
        context: &ResolutionContext,
    ) -> Result<Option<Arc<T>>, DiError>
    where
        T: Any + Send + Sync,
    {
        if !self.providers.contains_key(&TypeId::of::<T>()) {
            return Ok(None);
        }
        self.resolve_dependency::<T>(context).await.map(Some)
    }

    /// Resolves every active provider of `T` on behalf of a factory,
    /// continuing the given resolution `context`.
    #[doc(hidden)]
    pub async fn resolve_all_dependency<T>(
        &self,
        context: &ResolutionContext,
    ) -> Result<Vec<Arc<T>>, DiError>
    where
        T: Any + Send + Sync,
    {
        let Some(providers) = self.providers.get(&TypeId::of::<T>()) else {
            return Ok(Vec::new());
        };
        let mut values = Vec::with_capacity(providers.len());
        for provider in providers.iter().copied() {
            let value = self.resolve_provider(provider, context.clone()).await?;
            values.push(
                value
                    .downcast::<T>()
                    .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
            );
        }
        Ok(values)
    }

    /// Picks the provider of `T` to use and resolves it.
    ///
    /// With a `name`, the first provider carrying that name wins. Without
    /// one, a sole provider is used as is; otherwise the primary provider is
    /// required.
    async fn resolve_selected<T>(
        &self,
        name: Option<&str>,
        context: &ResolutionContext,
    ) -> Result<Arc<T>, DiError>
    where
        T: Any + Send + Sync,
    {
        let type_name = std::any::type_name::<T>();
        let providers = self
            .providers
            .get(&TypeId::of::<T>())
            .ok_or(DiError::MissingProvider(type_name))?;
        let selected = match (name, providers.as_slice()) {
            (Some(name), candidates) => candidates
                .iter()
                .copied()
                .find(|p| p.name == Some(name))
                .ok_or(DiError::MissingProvider(type_name))?,
            (None, [only]) => *only,
            (None, candidates) => candidates
                .iter()
                .copied()
                .find(|p| p.primary)
                .ok_or(DiError::AmbiguousProvider(type_name))?,
        };
        let value = self.resolve_provider(selected, context.clone()).await?;
        value
            .downcast::<T>()
            .map_err(|_| DiError::TypeMismatch(type_name))
    }

    /// Produces the instance for `provider` according to its scope.
    ///
    /// Prototype providers run their factory on every call. Singleton and
    /// request-scoped providers are built at most once per cache (the
    /// container's, or the request's) through a shared once-cell.
    ///
    /// Cycle detection is keyed by type name, so a provider that depends on
    /// another provider of the same type is reported as a cycle as well.
    fn resolve_provider<'a>(
        &'a self,
        provider: &'static ProviderDescriptor,
        mut context: ResolutionContext,
    ) -> BoxFuture<'a, Result<DynArc, DiError>> {
        Box::pin(async move {
            let type_name = (provider.type_name)();
            let is_cycle = context.chain.contains(&type_name);
            // Push before reporting so the message shows the closing edge.
            context.chain.push(type_name);
            if is_cycle {
                return Err(DiError::CircularDependency(context.chain.join(" -> ")));
            }
            let scope = provider.scope;
            let map = match scope {
                Scope::Prototype => return (provider.factory)(self, context).await,
                Scope::Singleton => &self.instances,
                Scope::Request => context
                    .request_instances
                    .as_deref()
                    .ok_or(DiError::RequestScopeUnavailable(type_name))?,
            };
            // The std mutex only guards the map lookup; it is released before
            // the factory is awaited.
            let cell = {
                let mut instances = map.lock().expect("DI instance lock poisoned");
                instances
                    .entry(provider_key(provider))
                    .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
                    .clone()
            };
            let value = cell
                .get_or_try_init(|| async move { (provider.factory)(self, context).await })
                .await?;
            Ok(value.clone())
        })
    }
}
/// Identifies a provider by the address of its static descriptor, which is
/// used as the key of the instance caches.
fn provider_key(provider: &'static ProviderDescriptor) -> usize {
    let address: *const ProviderDescriptor = provider;
    address as usize
}
/// Handle for resolving beans within one request; created by
/// `Container::request_context`.
pub struct RequestContext<'a> {
// Container the request resolves against.
container: &'a Container,
// Resolution context carrying this request's instance cache.
context: ResolutionContext,
}
impl RequestContext<'_> {
    /// Resolves the single (or primary) provider of `T` using this request's
    /// context, so request-scoped instances are shared across calls made
    /// through the same `RequestContext`.
    ///
    /// # Errors
    ///
    /// Returns any error from `Container::resolve_dependency`.
    pub async fn resolve<T: Any + Send + Sync>(&self) -> Result<Arc<T>, DiError> {
        let Self { container, context } = self;
        container.resolve_dependency::<T>(context).await
    }
}
/// Zero-sized handle that resolves `T` each time it is asked.
///
/// The handle stores no instance; `PhantomData<fn() -> T>` only ties it to
/// the type it resolves.
pub struct Provider<T> {
    _marker: PhantomData<fn() -> T>,
}

// Implemented by hand so that no `T: Clone` bound is required.
impl<T> Clone for Provider<T> {
    fn clone(&self) -> Self {
        Provider {
            _marker: self._marker,
        }
    }
}

// Implemented by hand so that no `T: Default` bound is required.
impl<T> Default for Provider<T> {
    fn default() -> Self {
        Provider {
            _marker: PhantomData,
        }
    }
}

impl<T: Any + Send + Sync> Provider<T> {
    /// Resolves `T` through the process-wide container (see
    /// `global_container`), whichever container built the owner of this
    /// handle.
    ///
    /// # Errors
    ///
    /// Returns any error from the free function `resolve`.
    pub async fn get(&self) -> Result<Arc<T>, DiError> {
        let resolved = resolve::<T>().await?;
        Ok(resolved)
    }
}
pub struct Lazy<T> {
cell: tokio::sync::OnceCell<Arc<T>>,
}
impl<T> Default for Lazy<T> {
fn default() -> Self {
Self {
cell: tokio::sync::OnceCell::new(),
}
}
}
impl<T> Lazy<T>
where
T: Any + Send + Sync,
{
pub async fn get(&self) -> Result<&Arc<T>, DiError> {
self.cell
.get_or_try_init(|| async { resolve::<T>().await })
.await
}
}
/// Types that can be built from the process environment.
///
/// NOTE(review): presumably implemented by the `configuration_properties`
/// attribute macro; nothing in this file implements it directly.
pub trait ConfigurationProperties: Sized {
/// Builds the value from the environment, or reports why it cannot be built.
fn from_environment() -> Result<Self, DiError>;
}
// Re-exports of the crates this file depends on.
// NOTE(review): presumably so macro-generated code can reach them through
// `auto_di::__private` without the user depending on them directly — confirm.
#[doc(hidden)]
pub mod __private {
pub use inventory;
pub use tokio;
}
// End-to-end tests that register providers with the attribute macros and
// resolve them through real containers.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
// Number of times the `SyncDependency` factory has run.
static CREATIONS: AtomicUsize = AtomicUsize::new(0);
struct SyncDependency;
#[singleton]
fn sync_dependency() -> SyncDependency {
CREATIONS.fetch_add(1, Ordering::SeqCst);
SyncDependency
}
struct AsyncDependency {
_sync: Arc<SyncDependency>,
}
// The sleep widens the window in which concurrent resolutions overlap.
#[singleton]
async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
AsyncDependency { _sync: sync }
}
// 32 tasks resolve the same singleton concurrently; all must receive the same
// instance and the dependency's factory must have run exactly once.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn singleton_is_concurrent_safe() {
let container = Arc::new(Container::new().unwrap());
let mut tasks = tokio::task::JoinSet::new();
for _ in 0..32 {
let c = container.clone();
tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
}
let mut values = vec![];
while let Some(v) = tasks.join_next().await {
values.push(v.unwrap());
}
assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
}
// Hands out a distinct number to every `PrototypeBean` that is created.
static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
struct PrototypeBean(usize);
#[component(scope = "prototype")]
fn prototype_bean() -> PrototypeBean {
PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
}
struct RequestBean;
#[component(scope = "request")]
fn request_bean() -> RequestBean {
RequestBean
}
// Trait-object beans: two named providers of `Arc<dyn Greeting>`, one primary.
trait Greeting: Send + Sync {
fn text(&self) -> &'static str;
}
struct English;
impl Greeting for English {
fn text(&self) -> &'static str {
"hello"
}
}
struct Hindi;
impl Greeting for Hindi {
fn text(&self) -> &'static str {
"namaste"
}
}
#[repository(name = "english", primary)]
fn english_greeting() -> Arc<dyn Greeting> {
Arc::new(English)
}
#[repository(name = "hindi")]
fn hindi_greeting() -> Arc<dyn Greeting> {
Arc::new(Hindi)
}
// No provider is registered for this type, so the optional dependency on it
// below must resolve to `None`.
struct MissingOptional;
struct Greeter {
greeting: Arc<dyn Greeting>,
optional: Option<Arc<MissingOptional>>,
}
struct GreetingLabel(&'static str);
#[service]
impl Greeter {
fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
Self { greeting, optional }
}
#[bean]
fn label(&self) -> GreetingLabel {
GreetingLabel(self.greeting.text())
}
}
// Selects the non-primary greeting through a qualifier.
struct QualifiedGreeter(Arc<dyn Greeting>);
#[service]
impl QualifiedGreeter {
fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
Self(greeting)
}
}
// Counters bumped by the lifecycle hooks of `Managed`.
static STARTED: AtomicUsize = AtomicUsize::new(0);
static STOPPED: AtomicUsize = AtomicUsize::new(0);
struct Managed;
#[service(eager, post_construct = "start", pre_destroy = "stop")]
impl Managed {
fn new() -> Self {
Self
}
async fn start(&self) {
STARTED.fetch_add(1, Ordering::SeqCst);
}
async fn stop(&self) {
STOPPED.fetch_add(1, Ordering::SeqCst);
}
}
// Only registered when the "test" profile is active.
struct ProfileBean;
#[component(profile = "test")]
fn profile_bean() -> ProfileBean {
ProfileBean
}
// Two named providers of the same type, collected as a `Vec` by `Pipeline`.
#[derive(Debug)]
struct Handler(&'static str);
#[component(name = "first")]
fn first_handler() -> Handler {
Handler("first")
}
#[component(name = "second")]
fn second_handler() -> Handler {
Handler("second")
}
struct Pipeline(Vec<Arc<Handler>>);
#[service]
impl Pipeline {
fn new(handlers: Vec<Arc<Handler>>) -> Self {
Self(handlers)
}
}
// NOTE(review): the test below sets `TESTING_DEP_PORT` before resolving this,
// so the macro presumably maps prefix + field name to that variable — confirm.
#[configuration_properties("testing_dep")]
struct TestProperties {
port: u16,
}
struct DeferredTarget;
#[component]
fn deferred_target() -> DeferredTarget {
DeferredTarget
}
// Depends on `DeferredTarget` through the deferred `Provider` / `Lazy` handles.
struct DeferredConsumer {
provider: Provider<DeferredTarget>,
lazy: Lazy<DeferredTarget>,
}
#[service]
impl DeferredConsumer {
fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
Self { provider, lazy }
}
}
// Only registered when `TESTING_DEP_FEATURE` is set to `enabled`.
struct ConditionalBean;
#[component(condition = "TESTING_DEP_FEATURE=enabled")]
fn conditional_bean() -> ConditionalBean {
ConditionalBean
}
struct StandaloneBean;
#[bean]
fn standalone_bean() -> StandaloneBean {
StandaloneBean
}
#[derive(Default)]
struct OrdinaryFactory;
struct BeanFromOrdinaryImpl;
#[beans]
impl OrdinaryFactory {
#[bean]
fn dependency(&self) -> BeanFromOrdinaryImpl {
BeanFromOrdinaryImpl
}
}
// Walks through scopes, trait objects, primary/qualified selection, profiles,
// collections, configuration properties, conditions, deferred handles and
// lifecycle hooks against a container with the "test" profile active.
#[tokio::test]
async fn scopes_traits_primary_profiles_and_lifecycle_work() {
// The condition is evaluated while the container is built, so the variable
// must be set before `with_profiles` runs.
// NOTE(review): mutating the environment is only sound while no other thread
// reads or writes it; the other test in this module calls `Container::new`,
// which reads the environment, and tests run in parallel by default —
// confirm this cannot race.
unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
let container = Container::with_profiles(["test"]).unwrap();
// Prototype scope: every resolution yields a fresh instance.
let first = container.resolve::<PrototypeBean>().await.unwrap();
let second = container.resolve::<PrototypeBean>().await.unwrap();
assert_ne!(first.0, second.0);
// Request scope: unavailable from the container, shared within one request.
assert!(matches!(
container.resolve::<RequestBean>().await,
Err(DiError::RequestScopeUnavailable(_))
));
let request = container.request_context();
let request_first = request.resolve::<RequestBean>().await.unwrap();
let request_second = request.resolve::<RequestBean>().await.unwrap();
assert!(Arc::ptr_eq(&request_first, &request_second));
// Unnamed injection picks the primary greeting; the missing optional is `None`.
let greeter = container.resolve::<Greeter>().await.unwrap();
assert_eq!(greeter.greeting.text(), "hello");
assert_eq!(
container.resolve::<GreetingLabel>().await.unwrap().0,
"hello"
);
assert!(greeter.optional.is_none());
// Qualified and named lookups pick the non-primary greeting.
assert_eq!(
container
.resolve::<QualifiedGreeter>()
.await
.unwrap()
.0
.text(),
"namaste"
);
let hindi = container
.resolve_named::<Arc<dyn Greeting>>("hindi")
.await
.unwrap();
assert_eq!(hindi.text(), "namaste");
container.resolve::<ProfileBean>().await.unwrap();
// Collection injection; sorted because provider order is not guaranteed.
let pipeline = container.resolve::<Pipeline>().await.unwrap();
let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
handler_names.sort_unstable();
assert_eq!(handler_names, ["first", "second"]);
// NOTE(review): same environment-mutation caveat as at the top of this test.
unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
let properties = container.resolve::<TestProperties>().await.unwrap();
assert_eq!(properties.port, 8080);
container.resolve::<ConditionalBean>().await.unwrap();
container.resolve::<StandaloneBean>().await.unwrap();
container.resolve::<BeanFromOrdinaryImpl>().await.unwrap();
// `Provider::get` and `Lazy::get` resolve through the global container rather
// than through `container`.
let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
deferred.provider.get().await.unwrap();
deferred.lazy.get().await.unwrap();
// Lifecycle: the eager service starts once and is stopped once on shutdown.
container.initialize_eager().await.unwrap();
assert_eq!(STARTED.load(Ordering::SeqCst), 1);
container.shutdown().await.unwrap();
assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
}
}