Skip to main content

nestforge_core/
provider.rs

1use std::marker::PhantomData;
2
3use anyhow::{anyhow, Result};
4
5use crate::{framework_log_event, Container};
6
/// Namespace type whose associated functions (`Provider::value`, `Provider::factory`,
/// `Provider::request_factory`, `Provider::transient_factory`) build the provider kinds
/// defined in this module.
pub struct Provider;

/// Provider that registers an already-constructed value as-is.
///
/// Has no effect until it is passed to [`register_provider`] (or
/// [`RegisterProvider::register`]).
#[must_use = "providers do nothing until they are registered with a container"]
pub struct ValueProvider<T> {
    value: T,
}

/// Provider whose `factory` is called once, at registration time, to build the value
/// that is then registered.
///
/// Has no effect until it is passed to [`register_provider`] (or
/// [`RegisterProvider::register`]).
#[must_use = "providers do nothing until they are registered with a container"]
pub struct FactoryProvider<T, F> {
    factory: F,
    // `fn() -> T` ties `T` to this type without storing a `T`; since fn pointers are
    // always `Send + Sync`, the provider's auto traits depend only on `F`.
    _marker: PhantomData<fn() -> T>,
}

/// Provider whose `factory` is handed to the container as a request-scoped factory
/// when registered.
///
/// Has no effect until it is passed to [`register_provider`] (or
/// [`RegisterProvider::register`]).
#[must_use = "providers do nothing until they are registered with a container"]
pub struct RequestFactoryProvider<T, F> {
    factory: F,
    // Marker only; see `FactoryProvider::_marker`.
    _marker: PhantomData<fn() -> T>,
}

/// Provider whose `factory` is handed to the container as a transient factory when
/// registered.
///
/// Has no effect until it is passed to [`register_provider`] (or
/// [`RegisterProvider::register`]).
#[must_use = "providers do nothing until they are registered with a container"]
pub struct TransientFactoryProvider<T, F> {
    factory: F,
    // Marker only; see `FactoryProvider::_marker`.
    _marker: PhantomData<fn() -> T>,
}
27
28impl Provider {
29    pub fn value<T>(value: T) -> ValueProvider<T>
30    where
31        T: Send + Sync + 'static,
32    {
33        ValueProvider { value }
34    }
35
36    pub fn factory<T, F>(factory: F) -> FactoryProvider<T, F>
37    where
38        T: Send + Sync + 'static,
39        F: FnOnce(&Container) -> Result<T> + Send + 'static,
40    {
41        FactoryProvider {
42            factory,
43            _marker: PhantomData,
44        }
45    }
46
47    pub fn request_factory<T, F>(factory: F) -> RequestFactoryProvider<T, F>
48    where
49        T: Send + Sync + 'static,
50        F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
51    {
52        RequestFactoryProvider {
53            factory,
54            _marker: PhantomData,
55        }
56    }
57
58    pub fn transient_factory<T, F>(factory: F) -> TransientFactoryProvider<T, F>
59    where
60        T: Send + Sync + 'static,
61        F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
62    {
63        TransientFactoryProvider {
64            factory,
65            _marker: PhantomData,
66        }
67    }
68}
69
/// A provider that can install itself into a [`Container`].
///
/// Implemented in this module for [`ValueProvider`], [`FactoryProvider`],
/// [`RequestFactoryProvider`] and [`TransientFactoryProvider`].
pub trait RegisterProvider {
    /// Consumes the provider and registers it with `container`.
    ///
    /// # Errors
    ///
    /// Returns an error if the provider cannot be built or registered; the exact failure
    /// cases depend on the implementation.
    fn register(self, container: &Container) -> Result<()>;
}
73
74impl<T> RegisterProvider for ValueProvider<T>
75where
76    T: Send + Sync + 'static,
77{
78    fn register(self, container: &Container) -> Result<()> {
79        framework_log_event(
80            "provider_register",
81            &[("type", std::any::type_name::<T>().to_string())],
82        );
83        container.register(self.value)?;
84        Ok(())
85    }
86}
87
88impl<T, F> RegisterProvider for FactoryProvider<T, F>
89where
90    T: Send + Sync + 'static,
91    F: FnOnce(&Container) -> Result<T> + Send + 'static,
92{
93    fn register(self, container: &Container) -> Result<()> {
94        framework_log_event(
95            "provider_register_factory",
96            &[("type", std::any::type_name::<T>().to_string())],
97        );
98        let value = (self.factory)(container).map_err(|err| {
99            anyhow!(
100                "Failed to build provider `{}`: {}",
101                std::any::type_name::<T>(),
102                err
103            )
104        })?;
105        container.register(value)?;
106        Ok(())
107    }
108}
109
110impl<T, F> RegisterProvider for RequestFactoryProvider<T, F>
111where
112    T: Send + Sync + 'static,
113    F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
114{
115    fn register(self, container: &Container) -> Result<()> {
116        framework_log_event(
117            "provider_register_request_factory",
118            &[("type", std::any::type_name::<T>().to_string())],
119        );
120        container
121            .register_request_factory::<T, _>(move |container| {
122                (self.factory)(container).map_err(|err| {
123                    anyhow!(
124                        "Failed to build request-scoped provider `{}`: {}",
125                        std::any::type_name::<T>(),
126                        err
127                    )
128                })
129            })
130            .map_err(|err| anyhow!("Failed to register request-scoped provider: {err}"))?;
131        Ok(())
132    }
133}
134
135impl<T, F> RegisterProvider for TransientFactoryProvider<T, F>
136where
137    T: Send + Sync + 'static,
138    F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
139{
140    fn register(self, container: &Container) -> Result<()> {
141        framework_log_event(
142            "provider_register_transient_factory",
143            &[("type", std::any::type_name::<T>().to_string())],
144        );
145        container
146            .register_transient_factory::<T, _>(move |container| {
147                (self.factory)(container).map_err(|err| {
148                    anyhow!(
149                        "Failed to build transient provider `{}`: {}",
150                        std::any::type_name::<T>(),
151                        err
152                    )
153                })
154            })
155            .map_err(|err| anyhow!("Failed to register transient provider: {err}"))?;
156        Ok(())
157    }
158}
159
160pub fn register_provider<P>(container: &Container, provider: P) -> Result<()>
161where
162    P: RegisterProvider,
163{
164    provider.register(container)
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    // Plain value type registered directly as a dependency in the tests below.
    #[derive(Clone)]
    struct AppConfig {
        app_name: &'static str,
    }

    // Service constructed by a factory from a resolved `AppConfig`.
    struct AppService {
        config_name: &'static str,
    }

    // A value provider stores the given instance so it can be resolved back by type.
    #[test]
    fn registers_value_provider() {
        let container = Container::new();
        let result = register_provider(
            &container,
            Provider::value(AppConfig {
                app_name: "nestforge",
            }),
        );

        assert!(result.is_ok(), "value provider registration should succeed");
        let config = container
            .resolve::<AppConfig>()
            .expect("config should be registered");
        assert_eq!(config.app_name, "nestforge");
    }

    // A factory provider can resolve previously registered dependencies while it runs,
    // and the value it returns becomes resolvable afterwards.
    #[test]
    fn registers_factory_provider() {
        let container = Container::new();
        register_provider(
            &container,
            Provider::value(AppConfig {
                app_name: "nestforge",
            }),
        )
        .expect("seed config");

        let result = register_provider(
            &container,
            Provider::factory(|c| {
                let cfg = c.resolve::<AppConfig>()?;
                Ok(AppService {
                    config_name: cfg.app_name,
                })
            }),
        );

        assert!(
            result.is_ok(),
            "factory provider registration should succeed"
        );
        let service = container
            .resolve::<AppService>()
            .expect("service should be registered");
        assert_eq!(service.config_name, "nestforge");
    }

    // A failing factory makes registration fail, and the error message names the
    // provider type that could not be built.
    #[test]
    fn factory_error_includes_type_name() {
        let container = Container::new();
        let err = register_provider(
            &container,
            Provider::factory::<AppService, _>(|_| Err(anyhow!("boom"))),
        )
        .expect_err("factory should fail");

        assert!(err.to_string().contains("AppService"));
    }

    // `RequestId` is only set on the scoped container (via `override_value`), never on
    // the root, so a successful resolve shows the request factory is evaluated against
    // the scoped container.
    #[test]
    fn registers_request_factory_provider() {
        #[derive(Clone)]
        struct RequestId(&'static str);

        struct RequestService(&'static str);

        let container = Container::new();
        register_provider(
            &container,
            Provider::request_factory(|c| {
                let request_id = c.resolve::<RequestId>()?;
                Ok(RequestService(request_id.0))
            }),
        )
        .expect("request factory should register");

        let scoped = container.scoped();
        scoped
            .override_value(RequestId("req-42"))
            .expect("request id should be set");

        let service = scoped
            .resolve::<RequestService>()
            .expect("request service should resolve");
        assert_eq!(service.0, "req-42");
    }

    // Each resolve of a transient provider calls the factory again: the shared counter
    // advances from 1 to 2 across two resolves.
    #[test]
    fn registers_transient_factory_provider() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct TransientService(usize);

        let container = Container::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_for_factory = Arc::clone(&counter);

        register_provider(
            &container,
            Provider::transient_factory(move |_| {
                let value = counter_for_factory.fetch_add(1, Ordering::Relaxed) + 1;
                Ok(TransientService(value))
            }),
        )
        .expect("transient factory should register");

        let first = container
            .resolve::<TransientService>()
            .expect("first transient should resolve");
        let second = container
            .resolve::<TransientService>()
            .expect("second transient should resolve");

        assert_eq!(first.0, 1);
        assert_eq!(second.0, 2);
    }
}