use reinhardt_di::resolve_context::{RESOLVE_CTX, ResolveContext};
use reinhardt_di::{InjectionContext, SingletonScope};
use rstest::*;
use std::future::Future;
use std::sync::Arc;
/// Fixture: a brand-new, empty [`SingletonScope`] behind an [`Arc`].
///
/// A new scope is constructed on every invocation, so registrations made
/// through one fixture instance are never visible through another.
#[fixture]
pub fn singleton_scope() -> Arc<SingletonScope> {
    let scope = SingletonScope::new();
    Arc::new(scope)
}
/// Fixture: an [`InjectionContext`] built on top of the supplied scope.
///
/// When used through rstest, `singleton_scope` is provided by the
/// [`singleton_scope`] fixture, so the context starts with no singletons
/// registered.
#[fixture]
pub fn injection_context(singleton_scope: Arc<SingletonScope>) -> InjectionContext {
    InjectionContext::builder(singleton_scope).build()
}
/// Builds an [`InjectionContext`] after giving the caller a chance to seed
/// the singleton scope.
///
/// `overrides` runs exactly once, with a shared borrow of the scope, before
/// the context is constructed. This is where tests register replacement
/// values (e.g. mock services) that injectables later read back as
/// singletons.
pub fn injection_context_with_overrides<F>(
    singleton_scope: Arc<SingletonScope>,
    overrides: F,
) -> InjectionContext
where
    F: FnOnce(&SingletonScope),
{
    let scope_ref: &SingletonScope = &singleton_scope;
    overrides(scope_ref);
    InjectionContext::builder(singleton_scope).build()
}
/// Fixture: an [`InjectionContext`] whose singleton scope holds a
/// `reinhardt_db::orm::connection::DatabaseConnection` to a fresh,
/// file-backed SQLite database.
///
/// The [`tempfile::NamedTempFile`] is returned alongside the context. The
/// caller must keep it alive for as long as the database is in use, because
/// dropping a `NamedTempFile` deletes the underlying file.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the temporary file cannot be created;
/// - its path is not valid UTF-8 (the path is embedded in a connection URL);
/// - the SQLite connection cannot be opened.
#[fixture]
pub async fn injection_context_with_sqlite() -> (tempfile::NamedTempFile, InjectionContext) {
    use reinhardt_db::orm::connection::DatabaseConnection;

    let temp_file = tempfile::NamedTempFile::new().expect("Failed to create temp file");
    // The path is interpolated into a URL string, so it must be representable
    // as UTF-8; state that invariant instead of a bare `unwrap`.
    let db_path = temp_file
        .path()
        .to_str()
        .expect("Temp file path is not valid UTF-8");
    // `mode=rwc` is the SQLite URI option for "read-write, create if missing".
    let database_url = format!("sqlite://{db_path}?mode=rwc");
    let db_conn = DatabaseConnection::connect_sqlite(&database_url)
        .await
        .expect("Failed to create DatabaseConnection");

    let singleton_scope = Arc::new(SingletonScope::new());
    singleton_scope.set(db_conn);
    let ctx = InjectionContext::builder(singleton_scope).build();
    (temp_file, ctx)
}
/// Connects to `database_url` and returns an [`InjectionContext`] whose
/// singleton scope contains the resulting connection.
///
/// This helper creates no temporary database file; whatever `database_url`
/// points at is used as-is.
///
/// # Panics
///
/// Panics with `"Failed to create DatabaseConnection"` if the connection
/// cannot be established.
pub async fn injection_context_with_database(database_url: &str) -> InjectionContext {
    use reinhardt_db::orm::connection::DatabaseConnection;

    let connection = DatabaseConnection::connect(database_url)
        .await
        .expect("Failed to create DatabaseConnection");
    let scope = Arc::new(SingletonScope::new());
    scope.set(connection);
    InjectionContext::builder(scope).build()
}
/// Runs `f` inside a fresh DI resolve context and returns its output.
///
/// Setup:
/// - A new [`SingletonScope`] is created on every call.
/// - `setup` receives that scope so tests can register singletons before the
///   [`InjectionContext`] is built.
/// - The built context is installed as both the `root` and the `current`
///   level of [`RESOLVE_CTX`] while the returned future runs.
///
/// Because of that installation, code inside `f` can retrieve the context
/// (e.g. with `get_di_context`).
///
/// Every call builds its own scope and context, so concurrent calls (for
/// example under `tokio::join!`) do not see each other's registrations.
/// NOTE(review): this relies on `RESOLVE_CTX` being task-local; its
/// `scope(value, future)` API matches tokio's `LocalKey::scope`. Confirm.
pub async fn with_test_di_context<F, Fut, T>(setup: impl FnOnce(&SingletonScope), f: F) -> T
where
    F: FnOnce(Arc<InjectionContext>) -> Fut,
    Fut: Future<Output = T>,
{
    let scope = Arc::new(SingletonScope::new());
    setup(&scope);
    let ctx = Arc::new(InjectionContext::builder(scope).build());
    let resolve_ctx = ResolveContext {
        root: Arc::clone(&ctx),
        current: Arc::clone(&ctx),
    };
    // Invoke `f` from inside the scoped future instead of passing `f(ctx)` as
    // an argument. Arguments are evaluated before the scope is entered, so any
    // synchronous work `f` performs before returning its future would
    // otherwise run without `RESOLVE_CTX` set.
    RESOLVE_CTX
        .scope(resolve_ctx, async move { f(ctx).await })
        .await
}
#[cfg(test)]
mod tests {
    use super::*;
    use reinhardt_di::{Depends, DiResult, Injectable};

    /// Injectable whose `inject` always builds the same fixed value.
    #[derive(Clone, Debug, PartialEq)]
    struct TestConfig {
        value: String,
    }

    #[async_trait::async_trait]
    impl Injectable for TestConfig {
        async fn inject(_ctx: &InjectionContext) -> DiResult<Self> {
            Ok(TestConfig {
                value: "test_config".to_string(),
            })
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_singleton_scope_fixture(singleton_scope: Arc<SingletonScope>) {
        assert!(
            singleton_scope.get::<String>().is_none(),
            "Singleton scope should be empty initially"
        );
        singleton_scope.set("test".to_string());
        let value: Option<Arc<String>> = singleton_scope.get();
        assert_eq!(*value.unwrap(), "test");
    }

    #[rstest]
    #[tokio::test]
    async fn test_injection_context_fixture(injection_context: InjectionContext) {
        let config = TestConfig::inject(&injection_context).await.unwrap();
        assert_eq!(config.value, "test_config");
    }

    #[rstest]
    #[tokio::test]
    async fn test_fixture_isolation_first(injection_context: InjectionContext) {
        injection_context.set_request("first".to_string());
        let value: Option<Arc<String>> = injection_context.get_request();
        assert_eq!(*value.unwrap(), "first");
    }

    #[rstest]
    #[tokio::test]
    async fn test_fixture_isolation_second(injection_context: InjectionContext) {
        let value: Option<Arc<String>> = injection_context.get_request();
        assert!(
            value.is_none(),
            "Request scope should be isolated between tests"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_depends_with_fixtures(injection_context: InjectionContext) {
        // NOTE(review): this registers into `global_registry()`, which
        // presumably is shared by every test in this binary; the registration
        // may therefore be visible to other tests. Confirm this is intended.
        let registry = reinhardt_di::global_registry();
        registry.register_async::<TestConfig, _, _>(
            reinhardt_di::DependencyScope::Request,
            |_ctx| async {
                Ok(TestConfig {
                    value: "test_config".to_string(),
                })
            },
        );
        let config = Depends::<TestConfig>::builder()
            .resolve(&injection_context)
            .await
            .unwrap();
        assert_eq!(config.value, "test_config");
    }

    #[rstest]
    #[tokio::test]
    async fn test_request_scope_caching(injection_context: InjectionContext) {
        // NOTE(review): `TestConfig::inject` constructs a new value on every
        // call, so this only checks that two injections produce equal values.
        // It does not prove that a cached instance was reused.
        let config1 = TestConfig::inject(&injection_context).await.unwrap();
        let config2 = TestConfig::inject(&injection_context).await.unwrap();
        assert_eq!(config1, config2);
    }

    #[rstest]
    #[tokio::test]
    async fn test_singleton_scope_sharing(singleton_scope: Arc<SingletonScope>) {
        let ctx1 = InjectionContext::builder(Arc::clone(&singleton_scope)).build();
        let ctx2 = InjectionContext::builder(Arc::clone(&singleton_scope)).build();
        ctx1.set_singleton("shared_value".to_string());
        let value: Option<Arc<String>> = ctx2.get_singleton();
        assert_eq!(*value.unwrap(), "shared_value");
    }

    /// Injectable that prefers an override registered as a singleton and
    /// otherwise falls back to a "production" default.
    #[derive(Clone, Debug, PartialEq)]
    struct MockDatabase {
        url: String,
    }

    #[async_trait::async_trait]
    impl Injectable for MockDatabase {
        async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
            if let Some(db) = ctx.get_singleton::<MockDatabase>() {
                return Ok((*db).clone());
            }
            Ok(MockDatabase {
                url: "prod://database".to_string(),
            })
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_injection_context_with_overrides_basic(singleton_scope: Arc<SingletonScope>) {
        let ctx = injection_context_with_overrides(singleton_scope, |scope| {
            scope.set(MockDatabase {
                url: "test://database".to_string(),
            });
        });
        let db = MockDatabase::inject(&ctx).await.unwrap();
        assert_eq!(db.url, "test://database");
    }

    /// Second override-aware injectable, used to check multiple overrides.
    #[derive(Clone, Debug, PartialEq)]
    struct MockConfig {
        api_key: String,
    }

    #[async_trait::async_trait]
    impl Injectable for MockConfig {
        async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
            if let Some(config) = ctx.get_singleton::<MockConfig>() {
                return Ok((*config).clone());
            }
            Ok(MockConfig {
                api_key: "prod_key".to_string(),
            })
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_injection_context_with_overrides_multiple(singleton_scope: Arc<SingletonScope>) {
        let ctx = injection_context_with_overrides(singleton_scope, |scope| {
            scope.set(MockDatabase {
                url: "test://database".to_string(),
            });
            scope.set(MockConfig {
                api_key: "test_key".to_string(),
            });
        });
        let db = MockDatabase::inject(&ctx).await.unwrap();
        let config = MockConfig::inject(&ctx).await.unwrap();
        assert_eq!(db.url, "test://database");
        assert_eq!(config.api_key, "test_key");
    }

    #[rstest]
    #[tokio::test]
    async fn test_injection_context_without_overrides_uses_default(
        singleton_scope: Arc<SingletonScope>,
    ) {
        let ctx = InjectionContext::builder(singleton_scope).build();
        let db = MockDatabase::inject(&ctx).await.unwrap();
        let config = MockConfig::inject(&ctx).await.unwrap();
        assert_eq!(db.url, "prod://database");
        assert_eq!(config.api_key, "prod_key");
    }

    #[rstest]
    #[tokio::test]
    async fn test_injection_context_with_overrides_and_depends(
        singleton_scope: Arc<SingletonScope>,
    ) {
        let ctx = injection_context_with_overrides(singleton_scope, |scope| {
            scope.set(MockDatabase {
                url: "test://database".to_string(),
            });
        });
        let db = Depends::<MockDatabase>::builder()
            .resolve(&ctx)
            .await
            .unwrap();
        assert_eq!(db.url, "test://database");
    }

    #[rstest]
    #[tokio::test]
    async fn test_with_test_di_context_unique_per_call() {
        let ctx1 = with_test_di_context(|_| {}, |ctx| async move { ctx }).await;
        let ctx2 = with_test_di_context(|_| {}, |ctx| async move { ctx }).await;
        assert!(!Arc::ptr_eq(&ctx1, &ctx2));
    }

    #[rstest]
    #[tokio::test]
    async fn test_with_test_di_context_get_di_context_root() {
        use reinhardt_di::resolve_context::{ContextLevel, get_di_context};
        with_test_di_context(
            |_| {},
            |ctx| async move {
                let root = get_di_context(ContextLevel::Root);
                assert!(Arc::ptr_eq(&root, &ctx));
            },
        )
        .await;
    }

    #[rstest]
    #[tokio::test]
    async fn test_with_test_di_context_get_di_context_current() {
        use reinhardt_di::resolve_context::{ContextLevel, get_di_context};
        with_test_di_context(
            |_| {},
            |ctx| async move {
                let current = get_di_context(ContextLevel::Current);
                assert!(Arc::ptr_eq(&current, &ctx));
            },
        )
        .await;
    }

    #[rstest]
    #[tokio::test]
    async fn test_with_test_di_context_setup_registers_singletons() {
        let result = with_test_di_context(
            |scope| {
                scope.set(TestConfig {
                    value: "from_setup".to_string(),
                });
            },
            |ctx| async move {
                let config: Option<Arc<TestConfig>> = ctx.get_singleton();
                config.unwrap().value.clone()
            },
        )
        .await;
        assert_eq!(result, "from_setup");
    }

    #[rstest]
    #[tokio::test]
    async fn test_with_test_di_context_parallel_safety() {
        // Two concurrently polled contexts must each see only their own
        // singleton registration.
        let (val1, val2) = tokio::join!(
            with_test_di_context(
                |scope| {
                    scope.set("task_1".to_string());
                },
                |ctx| async move {
                    let v: Option<Arc<String>> = ctx.get_singleton();
                    v.unwrap().as_str().to_owned()
                },
            ),
            with_test_di_context(
                |scope| {
                    scope.set("task_2".to_string());
                },
                |ctx| async move {
                    let v: Option<Arc<String>> = ctx.get_singleton();
                    v.unwrap().as_str().to_owned()
                },
            ),
        );
        assert_eq!(val1, "task_1");
        assert_eq!(val2, "task_2");
    }
}