use std::any::Any;
use std::future::Future;
use std::sync::Arc;
use reinhardt_di::{
DependencyRegistry, DependencyScope, DiResult, InjectionContext, OverrideGuard, SingletonScope,
global_registry,
};
/// RAII handle that keeps the factory overrides registered through
/// [`DiOverrideBuilder::factory`] installed for as long as it is alive.
///
/// Dropping this value drops every held [`OverrideGuard`]. The guards are
/// released in reverse registration order (last registered, first released),
/// so repeated overrides of the same type unwind like a stack.
///
/// Keep it bound to a named variable (e.g. `_di`). Destructuring it into a
/// bare `_` pattern drops it at the end of that statement.
#[must_use = "dropping DiOverrides immediately releases its factory overrides"]
pub struct DiOverrides {
    _guards: Vec<OverrideGuard>,
}

impl Drop for DiOverrides {
    fn drop(&mut self) {
        // A `Vec` drops its elements front-to-back. Pop from the back instead,
        // so the most recently installed override is released first.
        // NOTE(review): whether `OverrideGuard` restores the registration it
        // replaced is not visible here. LIFO is correct in either case.
        while let Some(guard) = self._guards.pop() {
            drop(guard);
        }
    }
}
/// Deferred, one-shot action that writes a request-scoped value into an
/// [`InjectionContext`] once that context has been built.
///
/// It is boxed so that seeds for values of different types can share one
/// `Vec`. The boxed trait object carries the default `'static` bound.
type RequestSeedFn = Box<dyn for<'ctx> FnOnce(&'ctx InjectionContext) + Send>;
/// Collects dependency overrides for [`injection_context_with_di_overrides`].
///
/// Each kind of override is handled differently:
/// - Singleton values are written straight into the borrowed scope.
/// - Factory overrides are registered on the registry immediately; their
///   guards are kept here until they are handed over to a [`DiOverrides`].
/// - Request-scoped values are queued until the context exists.
pub struct DiOverrideBuilder<'scope> {
    /// Registry on which factory overrides are registered.
    registry: Arc<DependencyRegistry>,
    /// Singleton scope that receives values passed to `singleton`.
    scope: &'scope SingletonScope,
    /// Guards for every factory override registered so far.
    guards: Vec<OverrideGuard>,
    /// Request-scoped seeds that are applied after the context has been built.
    request_seeds: Vec<RequestSeedFn>,
}
impl<'a> DiOverrideBuilder<'a> {
fn new(registry: Arc<DependencyRegistry>, scope: &'a SingletonScope) -> Self {
Self {
registry,
scope,
guards: Vec::new(),
request_seeds: Vec::new(),
}
}
pub fn singleton<T: Any + Send + Sync + 'static>(&mut self, value: T) {
self.scope.set(value);
}
pub fn request_value<T: Any + Send + Sync + 'static>(&mut self, value: T) {
self.request_seeds.push(Box::new(move |ctx| {
ctx.set_request::<T>(value);
}));
}
pub fn factory<T, F, Fut>(&mut self, scope: DependencyScope, factory: F)
where
T: Any + Send + Sync + 'static,
F: Fn(Arc<InjectionContext>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = DiResult<T>> + Send + 'static,
{
let guard = self.registry.register_override::<T, _, _>(scope, factory);
self.guards.push(guard);
}
fn into_overrides(self) -> DiOverrides {
DiOverrides {
_guards: self.guards,
}
}
}
/// Builds an [`InjectionContext`] on a fresh [`SingletonScope`], with the
/// overrides configured by `setup` applied.
///
/// `setup` receives the new scope and a [`DiOverrideBuilder`]. Overrides are
/// handled as follows:
/// - Factory overrides are registered on the registry obtained from
///   `global_registry()`. They stay installed only while the returned
///   [`DiOverrides`] is alive.
/// - Request-scoped values are written into the context after it has been
///   built.
///
/// Bind the returned `DiOverrides` to a named variable (e.g. `_di`).
/// Destructuring it into `_` drops it at the end of that statement.
pub async fn injection_context_with_di_overrides<F>(setup: F) -> (InjectionContext, DiOverrides)
where
    F: FnOnce(&SingletonScope, &mut DiOverrideBuilder<'_>),
{
    let singletons = Arc::new(SingletonScope::new());

    // The builder borrows `singletons`. Confine it to this block so the Arc
    // can be moved into the context builder afterwards.
    let (overrides, pending_seeds) = {
        let mut builder = DiOverrideBuilder::new(global_registry().clone(), &singletons);
        setup(&singletons, &mut builder);
        let seeds = std::mem::take(&mut builder.request_seeds);
        (builder.into_overrides(), seeds)
    };

    // Request-scoped values can only be written once the context exists.
    let ctx = InjectionContext::builder(singletons).build();
    pending_seeds.into_iter().for_each(|seed| seed(&ctx));

    (ctx, overrides)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;
    use serial_test::serial;

    // Factory overrides are installed on the process-wide registry returned by
    // `global_registry()`. Every test here is therefore tagged
    // `#[serial(di_registry)]` so they do not run concurrently with each other.

    // Plain value type used as a singleton override.
    #[derive(Clone, Debug, PartialEq)]
    struct Cfg {
        key: &'static str,
    }

    // A value passed to `builder.singleton` is readable via `get_singleton`
    // on the returned context.
    #[rstest]
    #[serial(di_registry)]
    #[tokio::test]
    async fn singleton_override_is_visible_via_get_singleton() {
        let (ctx, _di) = injection_context_with_di_overrides(|_scope, builder| {
            builder.singleton(Cfg { key: "test" });
        })
        .await;
        let value: Option<Arc<Cfg>> = ctx.get_singleton();
        assert_eq!(value.unwrap().key, "test");
    }

    // Dependency type that is only ever provided by the overriding factory below.
    #[derive(Clone, Debug, PartialEq)]
    struct Counter(u32);

    // `resolve` goes through the overriding factory while `_di` (which holds
    // the override guard) is still alive.
    #[rstest]
    #[serial(di_registry)]
    #[tokio::test]
    async fn factory_override_returns_mock_value() {
        let (ctx, _di) = injection_context_with_di_overrides(|_scope, builder| {
            builder.factory::<Counter, _, _>(DependencyScope::Transient, |_ctx| async {
                Ok::<_, reinhardt_di::DiError>(Counter(7))
            });
        })
        .await;
        let value: Arc<Counter> = ctx.resolve::<Counter>().await.unwrap();
        assert_eq!(*value, Counter(7));
    }

    // Value type used only as a request-scoped override.
    #[derive(Clone, Debug, PartialEq)]
    struct RequestCfg {
        token: &'static str,
    }

    // `request_value` writes into the request scope only; the singleton scope
    // must not see the value.
    #[rstest]
    #[serial(di_registry)]
    #[tokio::test]
    async fn request_value_lands_in_request_scope_not_singleton() {
        let (ctx, _di) = injection_context_with_di_overrides(|_scope, builder| {
            builder.request_value(RequestCfg { token: "req-only" });
        })
        .await;
        let from_request: Option<Arc<RequestCfg>> = ctx.get_request();
        let from_singleton: Option<Arc<RequestCfg>> = ctx.get_singleton();
        assert_eq!(from_request.unwrap().token, "req-only");
        assert!(from_singleton.is_none());
    }

    // Once the `DiOverrides` handle is dropped, resolution falls back to the
    // baseline factory registered on the global registry.
    #[rstest]
    #[serial(di_registry)]
    #[tokio::test]
    async fn factory_override_reverts_after_di_overrides_drop() {
        let registry = reinhardt_di::global_registry().clone();
        #[derive(Clone, Debug, PartialEq)]
        struct DropProbe(u32);
        // Install the baseline factory (yielding DropProbe(1)) only if none is
        // registered yet; the registry is global and outlives individual tests.
        if !registry.is_registered::<DropProbe>() {
            registry.register_async::<DropProbe, _, _>(DependencyScope::Transient, |_ctx| async {
                Ok(DropProbe(1))
            });
        }
        {
            // While `di` is alive, the override (DropProbe(99)) shadows the baseline.
            let (ctx, di) = injection_context_with_di_overrides(|_scope, builder| {
                builder.factory::<DropProbe, _, _>(DependencyScope::Transient, |_ctx| async {
                    Ok::<_, reinhardt_di::DiError>(DropProbe(99))
                });
            })
            .await;
            let v: Arc<DropProbe> = ctx.resolve::<DropProbe>().await.unwrap();
            assert_eq!(*v, DropProbe(99));
            // Release the override guard explicitly.
            drop(di);
        }
        // A brand-new context, built after the drop, resolves via the baseline again.
        let ctx = InjectionContext::builder(Arc::new(SingletonScope::new())).build();
        let v: Arc<DropProbe> = ctx.resolve::<DropProbe>().await.unwrap();
        assert_eq!(*v, DropProbe(1));
    }
}