use std::{
any::{Any, TypeId},
collections::{HashMap, HashSet, VecDeque},
future::Future,
pin::Pin,
sync::{
Arc, Mutex, OnceLock,
atomic::{AtomicBool, Ordering},
},
};
/// Type-erased, thread-safe shared handle produced by provider factories; the
/// container downcasts it back to `Arc<T>` when resolving.
pub type DynArc = Arc<dyn Any + Send + Sync>;
/// Heap-allocated, pinned, `Send` future yielding `T` that may borrow for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
// Provider constructor: receives the container and the current resolution
// context and asynchronously produces the type-erased instance.
type Factory =
for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
// Teardown hook: receives the type-erased instance and asynchronously disposes of it.
type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
/// Lifetime policy that decides how often a provider's factory is invoked.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Scope {
/// One instance per container, cached after the first successful construction.
Singleton,
/// A fresh instance on every resolution; nothing is cached.
Prototype,
/// One instance per request context; resolving without one fails with
/// `DiError::RequestScopeUnavailable`.
Request,
}
/// Static registration record for one provider, collected through `inventory`
/// and turned into a `RuntimeProvider` when a `Container` is built.
#[doc(hidden)]
pub struct ProviderDescriptor {
// Returns the `TypeId` under which the provider is indexed.
type_id: fn() -> TypeId,
// Returns the human-readable type name used in error messages.
type_name: fn() -> &'static str,
// Asynchronous constructor for the provided value.
factory: Factory,
/// Optional qualifier matched by named resolution.
pub name: Option<&'static str>,
/// Marks the provider chosen when several providers match an unnamed lookup.
pub primary: bool,
/// Caching policy applied to instances built by `factory`.
pub scope: Scope,
/// When set, the provider is constructed by `Container::initialize_eager`.
pub eager: bool,
/// Profile that must be active for this provider to be registered.
pub profile: Option<&'static str>,
/// Name of an environment variable that gates registration.
pub condition_key: Option<&'static str>,
/// Required value of `condition_key`; when `None`, the variable only has to be set.
pub condition_value: Option<&'static str>,
/// Hook run by `Container::shutdown` on an already-constructed singleton.
pub destroy: Option<Destroy>,
}
impl ProviderDescriptor {
    /// Creates a descriptor with default settings: unnamed, non-primary,
    /// singleton-scoped, lazy, with no profile, condition or destroy hook.
    #[doc(hidden)]
    pub const fn new(
        type_id: fn() -> TypeId,
        type_name: fn() -> &'static str,
        factory: Factory,
    ) -> Self {
        Self {
            type_id,
            type_name,
            factory,
            name: None,
            primary: false,
            scope: Scope::Singleton,
            eager: false,
            profile: None,
            condition_key: None,
            condition_value: None,
            destroy: None,
        }
    }

    /// Creates a descriptor with every setting supplied explicitly.
    #[allow(clippy::too_many_arguments)]
    #[doc(hidden)]
    pub const fn configured(
        type_id: fn() -> TypeId,
        type_name: fn() -> &'static str,
        factory: Factory,
        name: Option<&'static str>,
        primary: bool,
        scope: Scope,
        eager: bool,
        profile: Option<&'static str>,
        condition_key: Option<&'static str>,
        condition_value: Option<&'static str>,
        destroy: Option<Destroy>,
    ) -> Self {
        Self {
            // Identity and construction.
            type_id,
            type_name,
            factory,
            // Selection and lifetime.
            name,
            primary,
            scope,
            eager,
            // Activation gates.
            profile,
            condition_key,
            condition_value,
            // Teardown.
            destroy,
        }
    }

    /// Reports whether this provider should be registered.
    ///
    /// The provider is active when its required profile (if any) is listed in
    /// `profiles` and its environment condition (if any) holds: the variable
    /// must equal `condition_value`, or merely be readable when no value is
    /// required.
    fn active(&self, profiles: &[String]) -> bool {
        if let Some(required) = self.profile {
            let listed = profiles.iter().any(|candidate| candidate == required);
            if !listed {
                return false;
            }
        }

        let Some(key) = self.condition_key else {
            return true;
        };
        match (std::env::var(key), self.condition_value) {
            (Ok(actual), Some(expected)) => actual == expected,
            (Ok(_), None) => true,
            // Unset (or non-Unicode) variables never satisfy a condition.
            (Err(_), _) => false,
        }
    }
}
// Declares `ProviderDescriptor` as an `inventory` collection; the registered
// descriptors are iterated in `Container::with_profiles`.
inventory::collect!(ProviderDescriptor);
// Per-request instance cells, keyed by the descriptor address from `provider_key`.
type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
// Container-local state for one active provider.
struct RuntimeProvider {
// Static registration data shared by every container.
descriptor: &'static ProviderDescriptor,
// Cache for the singleton instance; filled at most once per container.
singleton: tokio::sync::OnceCell<DynArc>,
}
impl RuntimeProvider {
    /// Wraps `descriptor` together with an empty singleton cache.
    fn new(descriptor: &'static ProviderDescriptor) -> Self {
        let singleton = tokio::sync::OnceCell::new();
        Self {
            descriptor,
            singleton,
        }
    }
}
/// State threaded through one resolution so nested factories can detect
/// cycles, enforce scope rules and share request-scoped instances.
#[derive(Clone, Default)]
pub struct ResolutionContext {
// Type names of the providers currently being resolved, outermost first;
// used for cycle detection and for rendering error messages.
pub(crate) chain: Vec<&'static str>,
// Runtime provider keys, pushed in lockstep with `chain`.
provider_chain: Vec<usize>,
// Scopes of the providers in `chain`, pushed in lockstep with it.
scope_chain: Vec<Scope>,
// Request-scoped instance cache; `None` when resolving outside a `RequestContext`.
pub(crate) request_instances: Option<Arc<Mutex<InstanceMap>>>,
}
/// Errors reported while building a container or resolving dependencies.
#[derive(Debug, thiserror::Error)]
pub enum DiError {
/// No active provider matched the requested type (or type and name).
#[error("no active provider is registered for {0}")]
MissingProvider(&'static str),
/// Several providers matched an unnamed lookup and none is primary.
#[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
AmbiguousProvider(&'static str),
/// More than one active provider of the same type is marked primary.
#[error("multiple primary providers are registered for {0}")]
MultiplePrimary(&'static str),
/// Two active providers of the same type share a name.
#[error("duplicate provider name '{name}' is registered for {type_name}")]
DuplicateProviderName {
type_name: &'static str,
name: &'static str,
},
/// A provider depends on itself, directly or transitively; the payload is
/// the rendered dependency chain.
#[error("circular dependency detected: {0}")]
CircularDependency(String),
/// A factory produced a value that could not be downcast to the requested type.
#[error("provider for {0} returned an incompatible type")]
TypeMismatch(&'static str),
/// A request-scoped provider was resolved without a request instance cache.
#[error("request-scoped dependency {0} was resolved outside RequestContext")]
RequestScopeUnavailable(&'static str),
/// A request-scoped provider was requested while a singleton was being built.
#[error("singleton dependency {singleton} cannot capture request-scoped {request}")]
InvalidScope {
singleton: &'static str,
request: &'static str,
},
/// Resolution was attempted after `Container::shutdown`.
#[error("the dependency container has already been shut down")]
ContainerShutdown,
// NOTE(review): not constructed anywhere in this file; presumably raised by
// the configuration-properties support defined elsewhere — confirm.
#[error("configuration property {key} is missing or invalid: {message}")]
Configuration { key: String, message: String },
/// A user factory returned an error (see `__private::factory_result`).
#[error("provider for {provider} failed: {message}")]
Factory {
provider: &'static str,
message: String,
},
/// A lifecycle hook returned an error (see `__private::lifecycle_result`).
#[error("lifecycle hook failed for {provider}: {message}")]
Lifecycle {
provider: &'static str,
message: String,
},
}
// Shared state behind every clone of a `Container`.
struct ContainerInner {
// Active providers grouped by the `TypeId` they provide.
providers: HashMap<TypeId, Vec<RuntimeProvider>>,
// Observed "parent depends on dependency" edges between runtime provider
// keys; used for cycle detection and to order shutdown.
dependency_graph: Mutex<HashMap<usize, HashSet<usize>>>,
// Set once `shutdown` has started; later resolutions are rejected.
shut_down: AtomicBool,
}
/// Dependency-injection container. Cloning yields another handle to the same
/// providers and singleton cache, because the state lives behind an `Arc`.
#[derive(Clone)]
pub struct Container {
inner: Arc<ContainerInner>,
}
static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();

/// Returns the process-wide container, building it from the environment on
/// first use.
///
/// # Errors
///
/// Propagates the error from [`Container::new`] when the first construction
/// fails; nothing is cached in that case, so a later call tries again.
pub fn global_container() -> Result<&'static Container, DiError> {
    match GLOBAL_CONTAINER.get() {
        Some(existing) => Ok(existing),
        None => {
            let fresh = Container::new()?;
            // Another thread may have installed its container first; the
            // loser's instance is simply dropped.
            let _ = GLOBAL_CONTAINER.set(fresh);
            let installed = GLOBAL_CONTAINER
                .get()
                .expect("global DI container initialized");
            Ok(installed)
        }
    }
}
/// Resolves `T` from the process-wide container.
///
/// # Errors
///
/// Fails when the global container cannot be built or when resolution fails.
pub async fn resolve<T>() -> Result<Arc<T>, DiError>
where
    T: Any + Send + Sync,
{
    let container = global_container()?;
    container.resolve::<T>().await
}
impl Container {
/// Builds a container whose active profiles come from the comma-separated
/// `APP_PROFILES` environment variable (blank entries are ignored; an unset
/// variable means no profiles).
///
/// # Errors
///
/// See [`Container::with_profiles`].
pub fn new() -> Result<Self, DiError> {
    let raw = std::env::var("APP_PROFILES").unwrap_or_default();
    let profiles: Vec<String> = raw
        .split(',')
        .filter_map(|entry| {
            let entry = entry.trim();
            (!entry.is_empty()).then(|| entry.to_owned())
        })
        .collect();
    Self::with_profiles(profiles)
}
/// Builds a container from every registered provider that is active under
/// `profiles`.
///
/// # Errors
///
/// * [`DiError::MultiplePrimary`] when one type has more than one primary provider.
/// * [`DiError::DuplicateProviderName`] when two providers of one type share a name.
pub fn with_profiles(
    profiles: impl IntoIterator<Item = impl Into<String>>,
) -> Result<Self, DiError> {
    let active_profiles: Vec<String> = profiles.into_iter().map(Into::into).collect();

    let mut providers: HashMap<TypeId, Vec<RuntimeProvider>> = HashMap::new();
    for descriptor in inventory::iter::<ProviderDescriptor> {
        if !descriptor.active(&active_profiles) {
            continue;
        }
        providers
            .entry((descriptor.type_id)())
            .or_default()
            .push(RuntimeProvider::new(descriptor));
    }

    // Validate each group of same-typed providers before exposing the container.
    for group in providers.values() {
        let primaries = group
            .iter()
            .filter(|candidate| candidate.descriptor.primary)
            .count();
        if primaries > 1 {
            return Err(DiError::MultiplePrimary((group[0].descriptor.type_name)()));
        }

        let mut seen = HashSet::new();
        for candidate in group {
            let Some(name) = candidate.descriptor.name else {
                continue;
            };
            if !seen.insert(name) {
                return Err(DiError::DuplicateProviderName {
                    type_name: (candidate.descriptor.type_name)(),
                    name,
                });
            }
        }
    }

    let inner = ContainerInner {
        providers,
        dependency_graph: Mutex::new(HashMap::new()),
        shut_down: AtomicBool::new(false),
    };
    Ok(Self {
        inner: Arc::new(inner),
    })
}
/// Resolves the single (or primary) provider of `T` with a fresh resolution
/// context, i.e. outside any request scope.
pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
where
    T: Any + Send + Sync,
{
    let root = ResolutionContext::default();
    self.resolve_dependency::<T>(&root).await
}
/// Resolves the provider of `T` registered under `name`, using a fresh
/// resolution context.
pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
where
    T: Any + Send + Sync,
{
    let root = ResolutionContext::default();
    self.resolve_named_dependency::<T>(name, &root).await
}
/// Resolves `T` if a provider for it is registered, returning `Ok(None)`
/// otherwise.
///
/// Only the absence of a provider for `T` itself maps to `None`. A
/// `MissingProvider` error raised for some *other* type means a provider for
/// `T` exists but one of its dependencies cannot be satisfied; that is a
/// wiring error and is propagated instead of being silently hidden.
///
/// # Errors
///
/// Any resolution error other than "no provider for `T`".
pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
where
    T: Any + Send + Sync,
{
    match self.resolve::<T>().await {
        Ok(value) => Ok(Some(value)),
        Err(DiError::MissingProvider(missing)) if missing == std::any::type_name::<T>() => {
            Ok(None)
        }
        Err(error) => Err(error),
    }
}
/// Resolves every active provider of `T` with a fresh resolution context.
/// Returns an empty vector when none is registered.
pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
where
    T: Any + Send + Sync,
{
    let root = ResolutionContext::default();
    self.resolve_all_dependency::<T>(&root).await
}
pub fn request_context(&self) -> RequestContext<'_> {
RequestContext {
container: self,
context: ResolutionContext {
chain: vec![],
provider_chain: vec![],
scope_chain: vec![],
request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
},
}
}
/// Constructs every provider flagged as eager, stopping at the first failure.
pub async fn initialize_eager(&self) -> Result<(), DiError> {
    let eager = self
        .inner
        .providers
        .values()
        .flatten()
        .filter(|candidate| candidate.descriptor.eager);
    for provider in eager {
        self.resolve_provider(provider, ResolutionContext::default())
            .await?;
    }
    Ok(())
}
/// Constructs every singleton-scoped provider so wiring problems surface up
/// front; returns the first error encountered.
pub async fn validate(&self) -> Result<(), DiError> {
    let singletons = self
        .inner
        .providers
        .values()
        .flatten()
        .filter(|candidate| candidate.descriptor.scope == Scope::Singleton);
    for provider in singletons {
        self.resolve_provider(provider, ResolutionContext::default())
            .await?;
    }
    Ok(())
}
/// Shuts the container down and runs the destroy hooks of every singleton
/// that was actually constructed, dependents before their dependencies.
///
/// The call is idempotent: only the first invocation runs hooks, later ones
/// return `Ok(())` immediately. Because a second attempt is therefore
/// impossible, a failing hook does not stop the remaining hooks from running;
/// otherwise one bad hook would leak every resource ordered after it.
///
/// # Errors
///
/// Returns the first error reported by a destroy hook, after all hooks ran.
pub async fn shutdown(&self) -> Result<(), DiError> {
    if self.inner.shut_down.swap(true, Ordering::AcqRel) {
        return Ok(());
    }

    // Snapshot the graph so the std mutex is not held across an await point.
    let graph = {
        let guard = self
            .inner
            .dependency_graph
            .lock()
            .expect("DI dependency graph lock poisoned");
        guard.clone()
    };
    let providers = self
        .inner
        .providers
        .values()
        .flatten()
        .map(|provider| (runtime_provider_key(provider), provider))
        .collect::<HashMap<_, _>>();
    let order = shutdown_order(providers.keys().copied(), &graph);

    let mut first_error = None;
    for key in order {
        let provider = providers[&key];
        let (Some(destroy), Some(value)) =
            (provider.descriptor.destroy, provider.singleton.get())
        else {
            // No hook registered, or the singleton was never constructed.
            continue;
        };
        if let Err(error) = destroy(value.clone()).await {
            if first_error.is_none() {
                first_error = Some(error);
            }
        }
    }

    match first_error {
        Some(error) => Err(error),
        None => Ok(()),
    }
}
/// Resolves the single (or primary) provider of `T` as a dependency of
/// whatever `context` is currently resolving.
#[doc(hidden)]
pub async fn resolve_dependency<T>(
    &self,
    context: &ResolutionContext,
) -> Result<Arc<T>, DiError>
where
    T: Any + Send + Sync,
{
    let unnamed: Option<&str> = None;
    self.resolve_selected::<T>(unnamed, context).await
}
/// Resolves the provider of `T` registered under `name` as a dependency of
/// whatever `context` is currently resolving.
#[doc(hidden)]
pub async fn resolve_named_dependency<T>(
    &self,
    name: &str,
    context: &ResolutionContext,
) -> Result<Arc<T>, DiError>
where
    T: Any + Send + Sync,
{
    let qualifier = Some(name);
    self.resolve_selected::<T>(qualifier, context).await
}
/// Resolves `T` as an optional dependency: `Ok(None)` when no provider for
/// `T` is registered.
///
/// Only the absence of a provider for `T` itself maps to `None`. A
/// `MissingProvider` error raised for some *other* type means a provider for
/// `T` exists but one of its own dependencies is missing; that is a wiring
/// error and is propagated instead of being silently hidden.
///
/// # Errors
///
/// Any resolution error other than "no provider for `T`".
#[doc(hidden)]
pub async fn resolve_optional_dependency<T>(
    &self,
    context: &ResolutionContext,
) -> Result<Option<Arc<T>>, DiError>
where
    T: Any + Send + Sync,
{
    match self.resolve_dependency::<T>(context).await {
        Ok(value) => Ok(Some(value)),
        Err(DiError::MissingProvider(missing)) if missing == std::any::type_name::<T>() => {
            Ok(None)
        }
        Err(error) => Err(error),
    }
}
/// Resolves every active provider of `T` within `context`, in registration
/// order. Returns an empty vector when none is registered.
///
/// # Errors
///
/// The first resolution failure, or [`DiError::TypeMismatch`] when a factory
/// produced a value that is not a `T`.
#[doc(hidden)]
pub async fn resolve_all_dependency<T>(
    &self,
    context: &ResolutionContext,
) -> Result<Vec<Arc<T>>, DiError>
where
    T: Any + Send + Sync,
{
    let providers = match self.inner.providers.get(&TypeId::of::<T>()) {
        Some(group) => group,
        None => return Ok(vec![]),
    };

    let mut values = Vec::with_capacity(providers.len());
    for provider in providers {
        let erased = self.resolve_provider(provider, context.clone()).await?;
        match erased.downcast::<T>() {
            Ok(typed) => values.push(typed),
            Err(_) => return Err(DiError::TypeMismatch(std::any::type_name::<T>())),
        }
    }
    Ok(values)
}
/// Picks one provider of `T` and resolves it within `context`.
///
/// Selection rules:
/// * with a `name`, the provider carrying exactly that name;
/// * without a name, the only provider if there is just one;
/// * otherwise the provider marked primary.
///
/// # Errors
///
/// [`DiError::MissingProvider`] when nothing matches,
/// [`DiError::AmbiguousProvider`] when several unnamed candidates exist and
/// none is primary, [`DiError::TypeMismatch`] when the produced value is not a
/// `T`, plus anything the provider's resolution reports.
async fn resolve_selected<T>(
    &self,
    name: Option<&str>,
    context: &ResolutionContext,
) -> Result<Arc<T>, DiError>
where
    T: Any + Send + Sync,
{
    let type_name = std::any::type_name::<T>();
    let Some(candidates) = self.inner.providers.get(&TypeId::of::<T>()) else {
        return Err(DiError::MissingProvider(type_name));
    };

    let selected = match (name, candidates.as_slice()) {
        (Some(wanted), all) => all
            .iter()
            .find(|candidate| candidate.descriptor.name == Some(wanted))
            .ok_or(DiError::MissingProvider(type_name))?,
        (None, [only]) => only,
        (None, all) => all
            .iter()
            .find(|candidate| candidate.descriptor.primary)
            .ok_or(DiError::AmbiguousProvider(type_name))?,
    };

    let erased = self.resolve_provider(selected, context.clone()).await?;
    match erased.downcast::<T>() {
        Ok(typed) => Ok(typed),
        Err(_) => Err(DiError::TypeMismatch(type_name)),
    }
}
/// Resolves one concrete provider, honouring its scope.
///
/// Steps, in order: reject use after shutdown; reject a type that is already
/// on the resolution chain; record the parent -> dependency edge (which also
/// detects cycles spanning several resolutions); reject a request-scoped
/// provider reached from a singleton; then build or fetch the instance
/// according to the provider's scope.
///
/// Returns a boxed future because factories call back into the container,
/// making the resolution recursive.
fn resolve_provider<'a>(
    &'a self,
    provider: &'a RuntimeProvider,
    mut context: ResolutionContext,
) -> BoxFuture<'a, Result<DynArc, DiError>> {
    Box::pin(async move {
        if self.inner.shut_down.load(Ordering::Acquire) {
            return Err(DiError::ContainerShutdown);
        }

        let descriptor = provider.descriptor;
        let type_name = (descriptor.type_name)();

        // NOTE(review): this check compares type names, not provider identity,
        // so two distinct providers of the same type on one chain would also be
        // reported as a cycle — confirm that is intended.
        if context.chain.contains(&type_name) {
            context.chain.push(type_name);
            return Err(DiError::CircularDependency(context.chain.join(" -> ")));
        }

        let runtime_key = runtime_provider_key(provider);
        if let Some(&parent) = context.provider_chain.last() {
            self.add_dependency_edge(parent, runtime_key, &context, type_name)?;
        }

        if descriptor.scope == Scope::Request {
            // `scope_chain` and `chain` grow in lockstep, so the position of
            // the first singleton also indexes its type name.
            let capturing_singleton = context
                .scope_chain
                .iter()
                .position(|scope| *scope == Scope::Singleton);
            if let Some(position) = capturing_singleton {
                return Err(DiError::InvalidScope {
                    singleton: context.chain[position],
                    request: type_name,
                });
            }
        }

        context.chain.push(type_name);
        context.provider_chain.push(runtime_key);
        context.scope_chain.push(descriptor.scope);

        match descriptor.scope {
            // Never cached: every resolution runs the factory.
            Scope::Prototype => (descriptor.factory)(self, context).await,
            Scope::Singleton => {
                let value = provider
                    .singleton
                    .get_or_try_init(
                        || async move { (descriptor.factory)(self, context).await },
                    )
                    .await?;
                Ok(value.clone())
            }
            Scope::Request => {
                let Some(instances) = context.request_instances.as_deref() else {
                    return Err(DiError::RequestScopeUnavailable(type_name));
                };
                // Take the cell out under the lock, then release the lock
                // before awaiting the factory.
                let cell = {
                    let mut guard = instances.lock().expect("DI instance lock poisoned");
                    guard
                        .entry(provider_key(descriptor))
                        .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
                        .clone()
                };
                let value = cell
                    .get_or_try_init(
                        || async move { (descriptor.factory)(self, context).await },
                    )
                    .await?;
                Ok(value.clone())
            }
        }
    })
}
/// Records that `parent` depends on `dependency` and reports a cycle when the
/// new edge closes a loop in the graph accumulated so far.
///
/// The edge is inserted before the check and stays in the graph even when a
/// cycle is reported.
///
/// # Errors
///
/// [`DiError::CircularDependency`] carrying the current chain followed by
/// `dependency_name`.
fn add_dependency_edge(
    &self,
    parent: usize,
    dependency: usize,
    context: &ResolutionContext,
    dependency_name: &'static str,
) -> Result<(), DiError> {
    let mut graph = self
        .inner
        .dependency_graph
        .lock()
        .expect("DI dependency graph lock poisoned");
    graph.entry(parent).or_default().insert(dependency);

    let mut visited = HashSet::new();
    if !graph_path_exists(&graph, dependency, parent, &mut visited) {
        return Ok(());
    }

    let rendered = context
        .chain
        .iter()
        .copied()
        .chain(std::iter::once(dependency_name))
        .collect::<Vec<_>>()
        .join(" -> ");
    Err(DiError::CircularDependency(rendered))
}
}
/// Identifies a descriptor by the address of its static registration record.
fn provider_key(provider: &'static ProviderDescriptor) -> usize {
    let address: *const ProviderDescriptor = provider;
    address as usize
}
/// Identifies a runtime provider by its address inside the container.
fn runtime_provider_key(provider: &RuntimeProvider) -> usize {
    let address: *const RuntimeProvider = provider;
    address as usize
}
/// Depth-first search: reports whether `target` is reachable from `current`
/// by following edges in `graph`.
///
/// A node counts as reachable from itself. `visited` accumulates the nodes
/// already explored so that cycles in the graph cannot cause endless
/// recursion; callers start with an empty set.
fn graph_path_exists(
    graph: &HashMap<usize, HashSet<usize>>,
    current: usize,
    target: usize,
    visited: &mut HashSet<usize>,
) -> bool {
    if current == target {
        return true;
    }
    // `insert` returns false when the node was explored before.
    if !visited.insert(current) {
        return false;
    }
    let Some(dependencies) = graph.get(&current) else {
        return false;
    };
    dependencies
        .iter()
        .any(|next| graph_path_exists(graph, *next, target, visited))
}
/// Orders provider keys for teardown so that a provider comes before the
/// providers it depends on.
///
/// `graph` maps a parent key to the keys it depends on. Edges whose parent is
/// not in `keys`, and dependencies not in `keys`, are ignored. Keys that can
/// never become ready (they sit on a cycle) are appended at the end in
/// unspecified order, so every key appears exactly once.
fn shutdown_order(
    keys: impl IntoIterator<Item = usize>,
    graph: &HashMap<usize, HashSet<usize>>,
) -> Vec<usize> {
    let known: HashSet<usize> = keys.into_iter().collect();

    // Number of known parents that still depend on each key.
    let mut incoming: HashMap<usize, usize> = known.iter().map(|key| (*key, 0usize)).collect();
    let tracked_edges = graph
        .iter()
        .filter(|(parent, _)| known.contains(*parent))
        .flat_map(|(_, dependencies)| dependencies);
    for dependency in tracked_edges {
        if let Some(count) = incoming.get_mut(dependency) {
            *count += 1;
        }
    }

    // Keys nobody depends on can be torn down immediately.
    let mut ready: VecDeque<usize> = incoming
        .iter()
        .filter(|(_, count)| **count == 0)
        .map(|(key, _)| *key)
        .collect();

    let mut order = Vec::with_capacity(known.len());
    while let Some(parent) = ready.pop_front() {
        order.push(parent);
        for dependency in graph.get(&parent).into_iter().flatten() {
            let Some(count) = incoming.get_mut(dependency) else {
                continue;
            };
            *count -= 1;
            if *count == 0 {
                ready.push_back(*dependency);
            }
        }
    }

    // Whatever is left never reached zero; emit it anyway.
    let emitted: HashSet<usize> = order.iter().copied().collect();
    order.extend(known.into_iter().filter(|key| !emitted.contains(key)));
    order
}
/// Handle for resolving dependencies inside one request scope; created by
/// `Container::request_context`.
pub struct RequestContext<'a> {
// Container the request resolves against.
container: &'a Container,
// Resolution context carrying this request's instance cache.
context: ResolutionContext,
}
impl RequestContext<'_> {
    /// Resolves `T` within this request, so request-scoped providers reuse the
    /// instances already created through this context.
    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
    where
        T: Any + Send + Sync,
    {
        let Self { container, context } = self;
        container.resolve_dependency::<T>(context).await
    }
}
/// Support items that must be public but are not part of the supported API;
/// presumably referenced by the crate's procedural macros.
#[doc(hidden)]
pub mod __private {
    pub use inventory;
    pub use tokio;

    use std::fmt::Display;

    use crate::DiError;

    /// Converts a user factory's error into [`DiError::Factory`], tagging it
    /// with the provider's name.
    pub fn factory_result<T, E: Display>(
        result: Result<T, E>,
        provider: &'static str,
    ) -> Result<T, DiError> {
        match result {
            Ok(value) => Ok(value),
            Err(error) => Err(DiError::Factory {
                provider,
                message: error.to_string(),
            }),
        }
    }

    /// Lets lifecycle hooks return either `()` or `Result<(), E>`.
    pub trait IntoLifecycleResult {
        fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError>;
    }

    // Infallible hooks always succeed.
    impl IntoLifecycleResult for () {
        fn into_lifecycle_result(self, _provider: &'static str) -> Result<(), DiError> {
            Ok(())
        }
    }

    // Fallible hooks have their error rendered into `DiError::Lifecycle`.
    impl<E: Display> IntoLifecycleResult for Result<(), E> {
        fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError> {
            match self {
                Ok(()) => Ok(()),
                Err(error) => Err(DiError::Lifecycle {
                    provider,
                    message: error.to_string(),
                }),
            }
        }
    }

    /// Normalizes the return value of a lifecycle hook into a `Result`.
    pub fn lifecycle_result<R: IntoLifecycleResult>(
        result: R,
        provider: &'static str,
    ) -> Result<(), DiError> {
        R::into_lifecycle_result(result, provider)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
Lazy, Provider, configuration_properties, injectable, injected, provider, singleton,
};
use std::sync::atomic::{AtomicUsize, Ordering};
// Counts how many times the `SyncDependency` factory ran.
static CREATIONS: AtomicUsize = AtomicUsize::new(0);
struct SyncDependency;
// Synchronous singleton factory; bumps `CREATIONS` on every invocation.
#[singleton]
fn sync_dependency() -> SyncDependency {
CREATIONS.fetch_add(1, Ordering::SeqCst);
SyncDependency
}
struct AsyncDependency {
_sync: Arc<SyncDependency>,
}
// Asynchronous singleton factory that takes `SyncDependency` as a parameter;
// the sleep widens the window in which concurrent resolutions overlap.
#[singleton]
async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
AsyncDependency { _sync: sync }
}
// 32 tasks resolve the same singleton concurrently; all must receive the same
// `Arc`, and the underlying `SyncDependency` factory must have run exactly once.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn singleton_is_concurrent_safe() {
let container = Arc::new(Container::new().unwrap());
let mut tasks = tokio::task::JoinSet::new();
for _ in 0..32 {
let c = container.clone();
tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
}
let mut values = vec![];
while let Some(v) = tasks.join_next().await {
values.push(v.unwrap());
}
assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
// NOTE(review): `CREATIONS` is process-global; this holds only while no other
// test in the binary successfully resolves `SyncDependency` — confirm.
assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
}
// Hands out a distinct number to every `PrototypeBean` that gets constructed.
static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
struct PrototypeBean(usize);
// Prototype-scoped fixture: the scope test expects two resolutions to differ.
#[singleton(scope = "prototype")]
fn prototype_bean() -> PrototypeBean {
PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
}
struct RequestBean;
// Request-scoped fixture: resolvable only through a request context.
#[singleton(scope = "request")]
fn request_bean() -> RequestBean {
RequestBean
}
// Trait-object fixture: two named providers of `Arc<dyn Greeting>`, one primary.
trait Greeting: Send + Sync {
fn text(&self) -> &'static str;
}
struct English;
impl Greeting for English {
fn text(&self) -> &'static str {
"hello"
}
}
struct Hindi;
impl Greeting for Hindi {
fn text(&self) -> &'static str {
"namaste"
}
}
// Primary provider: expected to win unnamed lookups of `Arc<dyn Greeting>`.
#[singleton(name = "english", primary)]
fn english_greeting() -> Arc<dyn Greeting> {
Arc::new(English)
}
// Secondary provider, reachable by its name/qualifier "hindi".
#[singleton(name = "hindi")]
fn hindi_greeting() -> Arc<dyn Greeting> {
Arc::new(Hindi)
}
// Deliberately has no provider, so optional injection yields `None`.
struct MissingOptional;
struct Greeter {
greeting: Arc<dyn Greeting>,
optional: Option<Arc<MissingOptional>>,
}
struct GreetingLabel(&'static str);
// Constructor-injected singleton that also exposes a `GreetingLabel` provider
// derived from the constructed instance.
#[singleton]
impl Greeter {
fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
Self { greeting, optional }
}
#[provider]
fn label(&self) -> GreetingLabel {
GreetingLabel(self.greeting.text())
}
}
struct QualifiedGreeter(Arc<dyn Greeting>);
// Requests the non-primary greeting explicitly through a qualifier.
#[singleton]
impl QualifiedGreeter {
fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
Self(greeting)
}
}
struct StaticConfig(&'static str);
struct ServiceWithStaticBean(Arc<StaticConfig>);
// The associated `config` provider takes no `self`, so it can supply the very
// dependency that the constructor of the same type requires.
#[singleton]
impl ServiceWithStaticBean {
fn new(config: Arc<StaticConfig>) -> Self {
Self(config)
}
#[provider]
fn config() -> StaticConfig {
StaticConfig("static-bean")
}
}
// Counters observed by the lifecycle assertions of the scope test.
static STARTED: AtomicUsize = AtomicUsize::new(0);
static STOPPED: AtomicUsize = AtomicUsize::new(0);
struct Managed;
// Eager singleton with lifecycle hooks: the scope test expects `start` to have
// run once after `initialize_eager` and `stop` once after `shutdown`.
#[singleton(eager, post_construct = "start", pre_destroy = "stop")]
impl Managed {
fn new() -> Self {
Self
}
async fn start(&self) {
STARTED.fetch_add(1, Ordering::SeqCst);
}
async fn stop(&self) {
STOPPED.fetch_add(1, Ordering::SeqCst);
}
}
struct ProfileBean;
// Registered only when the "test" profile is active.
#[singleton(profile = "test")]
fn profile_bean() -> ProfileBean {
ProfileBean
}
#[derive(Debug)]
struct Handler(&'static str);
// Two named providers of the same concrete type, consumed together below.
#[singleton(name = "first")]
fn first_handler() -> Handler {
Handler("first")
}
#[singleton(name = "second")]
fn second_handler() -> Handler {
Handler("second")
}
struct Pipeline(Vec<Arc<Handler>>);
// Collection injection: receives every `Handler` provider.
#[singleton]
impl Pipeline {
fn new(handlers: Vec<Arc<Handler>>) -> Self {
Self(handlers)
}
}
// The scope test sets `TESTING_DEP_PORT` and expects `port` to be read from it;
// presumably the prefix "testing_dep" maps to that variable name — confirm
// against the macro.
#[configuration_properties("testing_dep")]
struct TestProperties {
port: u16,
}
struct DeferredTarget;
#[singleton]
fn deferred_target() -> DeferredTarget {
DeferredTarget
}
// Holds deferred handles instead of the dependency itself; the scope test
// resolves them later through `.get().await`.
struct DeferredConsumer {
provider: Provider<DeferredTarget>,
lazy: Lazy<DeferredTarget>,
}
#[singleton]
impl DeferredConsumer {
fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
Self { provider, lazy }
}
}
struct ConditionalBean;
// Registered only when the environment variable `TESTING_DEP_FEATURE` equals
// "enabled" (the scope test sets it before building its container).
#[singleton(condition = "TESTING_DEP_FEATURE=enabled")]
fn conditional_bean() -> ConditionalBean {
ConditionalBean
}
struct FallibleDependency;
// Factory that always fails; the failure test expects `DiError::Factory`.
#[provider]
fn fallible_dependency() -> Result<FallibleDependency, &'static str> {
Err("expected factory failure")
}
#[derive(Clone)]
struct FieldDependency(u32);
#[singleton]
fn field_dependency() -> FieldDependency {
FieldDependency(10)
}
// Field injection: `dependency` is resolved from the container, `literal` is
// given an inline value, and `transformed` is produced by a closure that itself
// receives an injected argument.
#[injectable]
struct FieldInjectedFacade {
dependency: Arc<FieldDependency>,
#[inject(7)]
literal: u32,
#[inject(|dependency: Arc<FieldDependency>| dependency)]
transformed: Arc<FieldDependency>,
}
#[injectable]
struct TupleInjected(#[inject(123)] u32);
// Function injection: per the injection test, callers pass only `multiplier`;
// a generated `injected_total_with` variant takes every argument explicitly.
#[injected]
fn injected_total(
dependency: Arc<FieldDependency>,
#[inject(2)] offset: u32,
multiplier: u32,
) -> u32 {
(dependency.0 + offset) * multiplier
}
// `#[argument]` keeps an `Arc` parameter caller-supplied instead of injected.
#[injected]
fn caller_owned_arc(
#[argument] dependency: Arc<FieldDependency>,
#[inject(1)] offset: u32,
) -> u32 {
dependency.0 + offset
}
struct InjectedMethods;
impl InjectedMethods {
// Method injection: per the injection test, callers pass only `value`.
#[injected]
fn calculate(&self, dependency: Arc<FieldDependency>, value: u32) -> u32 {
dependency.0 + value
}
}
struct RequestLeaf;
#[singleton(scope = "request")]
fn request_leaf() -> RequestLeaf {
RequestLeaf
}
// Singleton that tries to hold a request-scoped dependency; resolving it must
// fail with `DiError::InvalidScope`.
struct CaptiveSingleton {
_request: Arc<RequestLeaf>,
}
#[singleton]
impl CaptiveSingleton {
fn new(request: Arc<RequestLeaf>) -> Self {
Self { _request: request }
}
}
// Mutually dependent singletons used to exercise cycle detection.
struct CycleA {
_b: Arc<CycleB>,
}
struct CycleB {
_a: Arc<CycleA>,
}
#[singleton]
impl CycleA {
fn new(b: Arc<CycleB>) -> Self {
Self { _b: b }
}
}
#[singleton]
impl CycleB {
fn new(a: Arc<CycleA>) -> Self {
Self { _a: a }
}
}
struct AllGreetings(Vec<Arc<dyn Greeting>>);
// Collection injection of trait objects: receives every `Arc<dyn Greeting>` provider.
#[singleton]
impl AllGreetings {
fn new(greetings: Vec<Arc<dyn Greeting>>) -> Self {
Self(greetings)
}
}
struct ProfileDeferred;
// Profile-gated target that is only reached through a deferred `Provider`.
#[singleton(profile = "test")]
fn profile_deferred() -> ProfileDeferred {
ProfileDeferred
}
struct ProfileDeferredConsumer(Provider<ProfileDeferred>);
#[singleton]
impl ProfileDeferredConsumer {
fn new(provider: Provider<ProfileDeferred>) -> Self {
Self(provider)
}
}
struct StandaloneBean;
// Free-standing `#[provider]` function with no owning type.
#[provider]
fn standalone_bean() -> StandaloneBean {
StandaloneBean
}
// End-to-end walk through scopes, trait objects, primary/named selection,
// profiles, conditions, deferred handles, collections and lifecycle hooks on a
// container built with the "test" profile.
#[tokio::test]
async fn scopes_traits_primary_profiles_and_lifecycle_work() {
// NOTE(review): `set_var` is unsafe because it must not race with other
// threads accessing the environment; other tests in this binary build
// containers (which read the environment) and may run in parallel — confirm.
unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
let container = Container::with_profiles(["test"]).unwrap();
// Prototype scope: every resolution builds a new instance.
let first = container.resolve::<PrototypeBean>().await.unwrap();
let second = container.resolve::<PrototypeBean>().await.unwrap();
assert_ne!(first.0, second.0);
// Request scope: unavailable on the bare container, shared within a request.
assert!(matches!(
container.resolve::<RequestBean>().await,
Err(DiError::RequestScopeUnavailable(_))
));
let request = container.request_context();
let request_first = request.resolve::<RequestBean>().await.unwrap();
let request_second = request.resolve::<RequestBean>().await.unwrap();
assert!(Arc::ptr_eq(&request_first, &request_second));
// Primary selection, instance-derived providers and optional injection.
let greeter = container.resolve::<Greeter>().await.unwrap();
assert_eq!(greeter.greeting.text(), "hello");
assert_eq!(
container.resolve::<GreetingLabel>().await.unwrap().0,
"hello"
);
assert!(greeter.optional.is_none());
// Qualifier-based and name-based selection of the non-primary provider.
assert_eq!(
container
.resolve::<QualifiedGreeter>()
.await
.unwrap()
.0
.text(),
"namaste"
);
let hindi = container
.resolve_named::<Arc<dyn Greeting>>("hindi")
.await
.unwrap();
assert_eq!(hindi.text(), "namaste");
let static_bean_service = container.resolve::<ServiceWithStaticBean>().await.unwrap();
assert_eq!(static_bean_service.0.0, "static-bean");
// Profile-gated provider is present because "test" is active.
container.resolve::<ProfileBean>().await.unwrap();
// Collection injection; sorted because provider order is not asserted.
let pipeline = container.resolve::<Pipeline>().await.unwrap();
let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
handler_names.sort_unstable();
assert_eq!(handler_names, ["first", "second"]);
// Configuration properties read from the environment (same caveat as above).
unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
let properties = container.resolve::<TestProperties>().await.unwrap();
assert_eq!(properties.port, 8080);
container.resolve::<ConditionalBean>().await.unwrap();
container.resolve::<StandaloneBean>().await.unwrap();
// Deferred handles resolve their targets on demand.
let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
deferred.provider.get().await.unwrap();
deferred.lazy.get().await.unwrap();
let profile_deferred = container
.resolve::<ProfileDeferredConsumer>()
.await
.unwrap();
profile_deferred.0.get().await.unwrap();
let greetings = container.resolve::<AllGreetings>().await.unwrap();
assert_eq!(greetings.0.len(), 2);
// Lifecycle: eager construction runs `start`, shutdown runs `stop`, and the
// container rejects resolutions afterwards.
container.initialize_eager().await.unwrap();
assert_eq!(STARTED.load(Ordering::SeqCst), 1);
container.shutdown().await.unwrap();
assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
assert!(matches!(
container.resolve::<SyncDependency>().await,
Err(DiError::ContainerShutdown)
));
}
// Failure paths: factory errors, a singleton capturing a request-scoped
// dependency, and a dependency cycle entered from two tasks at once.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn failures_are_reported_instead_of_hanging_or_capturing_scope() {
let container = Arc::new(Container::new().unwrap());
assert!(matches!(
container.resolve::<FallibleDependency>().await,
Err(DiError::Factory { .. })
));
// Even inside a request, a singleton must not capture a request-scoped bean.
let request = container.request_context();
let captive = request.resolve::<CaptiveSingleton>().await;
assert!(matches!(captive, Err(DiError::InvalidScope { .. })));
// Enter the A <-> B cycle from both ends concurrently.
let first = {
let container = container.clone();
tokio::spawn(async move { container.resolve::<CycleA>().await })
};
let second = {
let container = container.clone();
tokio::spawn(async move { container.resolve::<CycleB>().await })
};
// The timeout turns a deadlock into a test failure instead of a hang.
let results = tokio::time::timeout(std::time::Duration::from_secs(1), async {
(first.await.unwrap(), second.await.unwrap())
})
.await
.expect("cross-task cycle resolution must not deadlock");
// At least one side has to observe the cycle.
assert!(
matches!(results.0, Err(DiError::CircularDependency(_)))
|| matches!(results.1, Err(DiError::CircularDependency(_)))
);
}
// Exercises the `#[injectable]` / `#[injected]` fixtures through the global
// container.
#[tokio::test]
async fn field_and_function_injection_are_automatic() {
let facade = resolve::<FieldInjectedFacade>().await.unwrap();
assert_eq!(facade.dependency.0, 10);
assert_eq!(facade.literal, 7);
// The closure-injected field receives the same singleton instance.
assert!(Arc::ptr_eq(&facade.dependency, &facade.transformed));
assert_eq!(resolve::<TupleInjected>().await.unwrap().0, 123);
// (10 + 2) * 3 with the dependency and offset injected.
assert_eq!(injected_total(3).await.unwrap(), 36);
// (20 + 4) * 2 with every argument supplied by the caller.
assert_eq!(injected_total_with(Arc::new(FieldDependency(20)), 4, 2), 48);
assert_eq!(
caller_owned_arc(Arc::new(FieldDependency(20)))
.await
.unwrap(),
21
);
assert_eq!(InjectedMethods.calculate(5).await.unwrap(), 15);
}
}