Skip to main content

nestforge_testing/
lib.rs

use std::{marker::PhantomData, sync::Arc};

use anyhow::{Context as _, Result};
use nestforge_core::{initialize_module_graph, Container, ContainerError, ModuleDefinition};
5
6type OverrideFn = Box<dyn Fn(&Container) -> Result<()> + Send + Sync>;
7
8pub struct TestFactory<M: ModuleDefinition> {
9    overrides: Vec<OverrideFn>,
10    _marker: PhantomData<M>,
11}
12
13impl<M: ModuleDefinition> TestFactory<M> {
14    pub fn create() -> Self {
15        Self {
16            overrides: Vec::new(),
17            _marker: PhantomData,
18        }
19    }
20
21    pub fn override_provider<T>(mut self, value: T) -> Self
22    where
23        T: Send + Sync + Clone + 'static,
24    {
25        self.overrides.push(Box::new(move |container| {
26            container.replace(value.clone())?;
27            Ok(())
28        }));
29        self
30    }
31
32    pub fn build(self) -> Result<TestingModule> {
33        let container = Container::new();
34        let _ = initialize_module_graph::<M>(&container)?;
35
36        for override_fn in self.overrides {
37            override_fn(&container)?;
38        }
39
40        Ok(TestingModule { container })
41    }
42}
43
44#[derive(Clone)]
45pub struct TestingModule {
46    container: Container,
47}
48
49impl TestingModule {
50    pub fn container(&self) -> &Container {
51        &self.container
52    }
53
54    pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
55    where
56        T: Send + Sync + 'static,
57    {
58        self.container.resolve::<T>()
59    }
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider whose value reveals whether an override took effect.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    /// Module that registers a single `AppConfig` named "default".
    struct AppModule;

    impl ModuleDefinition for AppModule {
        fn register(container: &Container) -> Result<()> {
            let default_config = AppConfig {
                app_name: "default",
            };
            container.register(default_config)?;
            Ok(())
        }
    }

    /// Resolves `AppConfig` from `testing`, panicking if it is not available.
    fn resolved_config(testing: &TestingModule) -> Arc<AppConfig> {
        testing
            .resolve::<AppConfig>()
            .expect("config should resolve")
    }

    #[test]
    fn builds_testing_module_and_resolves_default_provider() {
        let testing = TestFactory::<AppModule>::create()
            .build()
            .expect("test module should build");

        let expected = AppConfig {
            app_name: "default",
        };
        assert_eq!(*resolved_config(&testing), expected);
    }

    #[test]
    fn overrides_provider_value() {
        let replacement = AppConfig { app_name: "test" };
        let testing = TestFactory::<AppModule>::create()
            .override_provider(replacement.clone())
            .build()
            .expect("test module should build with overrides");

        assert_eq!(*resolved_config(&testing), replacement);
    }
}