use std::{marker::PhantomData, sync::Arc};

use anyhow::{Context, Result};
use nestforge_core::{initialize_module_graph, Container, ContainerError, ModuleDefinition};
5
6type OverrideFn = Box<dyn Fn(&Container) -> Result<()> + Send + Sync>;
7
8pub struct TestFactory<M: ModuleDefinition> {
9 overrides: Vec<OverrideFn>,
10 _marker: PhantomData<M>,
11}
12
13impl<M: ModuleDefinition> TestFactory<M> {
14 pub fn create() -> Self {
15 Self {
16 overrides: Vec::new(),
17 _marker: PhantomData,
18 }
19 }
20
21 pub fn override_provider<T>(mut self, value: T) -> Self
22 where
23 T: Send + Sync + Clone + 'static,
24 {
25 self.overrides.push(Box::new(move |container| {
26 container.replace(value.clone())?;
27 Ok(())
28 }));
29 self
30 }
31
32 pub fn build(self) -> Result<TestingModule> {
33 let container = Container::new();
34 let _ = initialize_module_graph::<M>(&container)?;
35
36 for override_fn in self.overrides {
37 override_fn(&container)?;
38 }
39
40 Ok(TestingModule { container })
41 }
42}
43
44#[derive(Clone)]
45pub struct TestingModule {
46 container: Container,
47}
48
49impl TestingModule {
50 pub fn container(&self) -> &Container {
51 &self.container
52 }
53
54 pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
55 where
56 T: Send + Sync + 'static,
57 {
58 self.container.resolve::<T>()
59 }
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain provider value used to observe registration and overriding.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    impl AppConfig {
        fn named(app_name: &'static str) -> Self {
            Self { app_name }
        }
    }

    /// Root module that registers a single `AppConfig` named "default".
    struct AppModule;

    impl ModuleDefinition for AppModule {
        fn register(container: &Container) -> Result<()> {
            let default_config = AppConfig::named("default");
            container.register(default_config)?;
            Ok(())
        }
    }

    #[test]
    fn builds_testing_module_and_resolves_default_provider() {
        let factory = TestFactory::<AppModule>::create();
        let module = factory.build().expect("test module should build");

        let resolved: Arc<AppConfig> = module.resolve().expect("config should resolve");
        assert_eq!(*resolved, AppConfig::named("default"));
    }

    #[test]
    fn overrides_provider_value() {
        let replacement = AppConfig::named("test");
        let module = TestFactory::<AppModule>::create()
            .override_provider(replacement)
            .build()
            .expect("test module should build with overrides");

        let resolved: Arc<AppConfig> = module.resolve().expect("config should resolve");
        assert_eq!(*resolved, AppConfig::named("test"));
    }
}