// nestforge_core/provider.rs

1use std::marker::PhantomData;
2
3use anyhow::{anyhow, Result};
4
5use crate::{framework_log_event, Container};
6
/**
 * Provider Helper Struct
 *
 * A unit struct that groups the constructors for the different provider kinds.
 * Use the associated functions `Provider::value()`, `Provider::factory()`,
 * `Provider::request_factory()`, and `Provider::transient_factory()` when
 * defining module providers or registering them manually, then pass the
 * result to `register_provider`.
 *
 * # Example
 *
 * The snippet is marked `ignore`: it refers to application types
 * (`AppModule`, `AppConfig`, `MyService`) and to `ModuleDefinition`, none of
 * which are declared or imported here, so it cannot compile as a doctest.
 *
 * ```rust,ignore
 * impl ModuleDefinition for AppModule {
 *     fn register(container: &Container) -> Result<()> {
 *         register_provider(container, Provider::value(AppConfig::default()))?;
 *         register_provider(container, Provider::factory(|c| Ok(MyService::new(c))))?;
 *         Ok(())
 *     }
 * }
 * ```
 */
pub struct Provider;
27
/**
 * ValueProvider
 *
 * A provider wrapping an already-constructed value. When registered, the
 * value is moved straight into the container via `Container::register`, so it
 * is stored as a singleton and shared across all resolutions.
 *
 * Build one with `Provider::value()`. It has no effect until it is passed to
 * `register_provider` (or `RegisterProvider::register`); `#[must_use]` makes
 * the compiler warn when one is built and then dropped.
 */
#[must_use = "a provider does nothing until it is passed to `register_provider`"]
pub struct ValueProvider<T> {
    // The instance that `register` hands to the container.
    value: T,
}
38
/**
 * FactoryProvider
 *
 * A provider that builds a singleton with a factory function. The factory
 * runs once, immediately during registration, and its output is stored in
 * the container as a singleton.
 *
 * Build one with `Provider::factory()`. It has no effect until it is passed
 * to `register_provider` (or `RegisterProvider::register`); `#[must_use]`
 * makes the compiler warn when one is built and then dropped.
 */
#[must_use = "a provider does nothing until it is passed to `register_provider`"]
pub struct FactoryProvider<T, F> {
    // Called exactly once by `register`.
    factory: F,
    // Ties the produced type `T` to the provider without storing a `T`.
    _marker: PhantomData<fn() -> T>,
}
50
/**
 * RequestFactoryProvider
 *
 * A provider whose factory is registered with the container as a
 * request-scoped factory: a new instance is created per request (per scoped
 * container), cached for the duration of that request, and shared within
 * that request's scope.
 *
 * Build one with `Provider::request_factory()`. It has no effect until it is
 * passed to `register_provider` (or `RegisterProvider::register`);
 * `#[must_use]` makes the compiler warn when one is built and then dropped.
 */
#[must_use = "a provider does nothing until it is passed to `register_provider`"]
pub struct RequestFactoryProvider<T, F> {
    // Handed to the container's request-factory registry by `register`.
    factory: F,
    // Ties the produced type `T` to the provider without storing a `T`.
    _marker: PhantomData<fn() -> T>,
}
62
/**
 * TransientFactoryProvider
 *
 * A provider whose factory is registered with the container as a transient
 * factory: a fresh instance is created every time the type is resolved.
 * Unlike singletons or request-scoped providers, transient instances are
 * never cached.
 *
 * Build one with `Provider::transient_factory()`. It has no effect until it
 * is passed to `register_provider` (or `RegisterProvider::register`);
 * `#[must_use]` makes the compiler warn when one is built and then dropped.
 */
#[must_use = "a provider does nothing until it is passed to `register_provider`"]
pub struct TransientFactoryProvider<T, F> {
    // Handed to the container's transient-factory registry by `register`.
    factory: F,
    // Ties the produced type `T` to the provider without storing a `T`.
    _marker: PhantomData<fn() -> T>,
}
74
75impl Provider {
76    /**
77     * Creates a value provider from an existing value.
78     *
79     * The value will be registered as a singleton in the container.
80     *
81     * # Type Parameters
82     * - `T`: The type to register (must be Send + Sync + 'static)
83     *
84     * # Arguments
85     * - `value`: The value to register as a singleton
86     */
87    pub fn value<T>(value: T) -> ValueProvider<T>
88    where
89        T: Send + Sync + 'static,
90    {
91        ValueProvider { value }
92    }
93
94    /**
95     * Creates a factory provider.
96     *
97     * The factory receives the Container and returns Result<T>.
98     * It is executed immediately when the module registers its providers.
99     * The result is stored as a singleton.
100     *
101     * # Type Parameters
102     * - `T`: The type to create (must be Send + Sync + 'static)
103     * - `F`: The factory function type
104     */
105    pub fn factory<T, F>(factory: F) -> FactoryProvider<T, F>
106    where
107        T: Send + Sync + 'static,
108        F: FnOnce(&Container) -> Result<T> + Send + 'static,
109    {
110        FactoryProvider {
111            factory,
112            _marker: PhantomData,
113        }
114    }
115
116    /**
117     * Creates a request-scoped provider.
118     *
119     * The factory is executed once per request (per scoped container).
120     * The created instance is cached for the duration of that request.
121     *
122     * # Type Parameters
123     * - `T`: The type to create (must be Send + Sync + 'static)
124     * - `F`: The factory function type
125     */
126    pub fn request_factory<T, F>(factory: F) -> RequestFactoryProvider<T, F>
127    where
128        T: Send + Sync + 'static,
129        F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
130    {
131        RequestFactoryProvider {
132            factory,
133            _marker: PhantomData,
134        }
135    }
136
137    /**
138     * Creates a transient provider.
139     *
140     * The factory is executed every time the type is resolved via
141     * container.resolve(). A new instance is created each time.
142     *
143     * # Type Parameters
144     * - `T`: The type to create (must be Send + Sync + 'static)
145     * - `F`: The factory function type
146     */
147    pub fn transient_factory<T, F>(factory: F) -> TransientFactoryProvider<T, F>
148    where
149        T: Send + Sync + 'static,
150        F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
151    {
152        TransientFactoryProvider {
153            factory,
154            _marker: PhantomData,
155        }
156    }
157}
158
159/**
160 * RegisterProvider Trait
161 *
162 * A trait for types that can register themselves into a Container.
163 * This is implemented by all provider types returned by Provider:: methods.
164 */
165pub trait RegisterProvider {
166    /**
167     * Registers this provider into the given container.
168     */
169    fn register(self, container: &Container) -> Result<()>;
170}
171
172impl<T> RegisterProvider for ValueProvider<T>
173where
174    T: Send + Sync + 'static,
175{
176    fn register(self, container: &Container) -> Result<()> {
177        framework_log_event(
178            "provider_register",
179            &[("type", std::any::type_name::<T>().to_string())],
180        );
181        container.register(self.value)?;
182        Ok(())
183    }
184}
185
186impl<T, F> RegisterProvider for FactoryProvider<T, F>
187where
188    T: Send + Sync + 'static,
189    F: FnOnce(&Container) -> Result<T> + Send + 'static,
190{
191    fn register(self, container: &Container) -> Result<()> {
192        framework_log_event(
193            "provider_register_factory",
194            &[("type", std::any::type_name::<T>().to_string())],
195        );
196        let value = (self.factory)(container).map_err(|err| {
197            anyhow!(
198                "Failed to build provider `{}`: {}",
199                std::any::type_name::<T>(),
200                err
201            )
202        })?;
203        container.register(value)?;
204        Ok(())
205    }
206}
207
208impl<T, F> RegisterProvider for RequestFactoryProvider<T, F>
209where
210    T: Send + Sync + 'static,
211    F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
212{
213    fn register(self, container: &Container) -> Result<()> {
214        framework_log_event(
215            "provider_register_request_factory",
216            &[("type", std::any::type_name::<T>().to_string())],
217        );
218        container
219            .register_request_factory::<T, _>(move |container| {
220                (self.factory)(container).map_err(|err| {
221                    anyhow!(
222                        "Failed to build request-scoped provider `{}`: {}",
223                        std::any::type_name::<T>(),
224                        err
225                    )
226                })
227            })
228            .map_err(|err| anyhow!("Failed to register request-scoped provider: {err}"))?;
229        Ok(())
230    }
231}
232
233impl<T, F> RegisterProvider for TransientFactoryProvider<T, F>
234where
235    T: Send + Sync + 'static,
236    F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
237{
238    fn register(self, container: &Container) -> Result<()> {
239        framework_log_event(
240            "provider_register_transient_factory",
241            &[("type", std::any::type_name::<T>().to_string())],
242        );
243        container
244            .register_transient_factory::<T, _>(move |container| {
245                (self.factory)(container).map_err(|err| {
246                    anyhow!(
247                        "Failed to build transient provider `{}`: {}",
248                        std::any::type_name::<T>(),
249                        err
250                    )
251                })
252            })
253            .map_err(|err| anyhow!("Failed to register transient provider: {err}"))?;
254        Ok(())
255    }
256}
257
258/// Helper function to register a provider into a container.
259pub fn register_provider<P>(container: &Container, provider: P) -> Result<()>
260where
261    P: RegisterProvider,
262{
263    provider.register(container)
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone)]
    struct AppConfig {
        app_name: &'static str,
    }

    struct AppService {
        config_name: &'static str,
    }

    #[test]
    fn registers_value_provider() {
        let container = Container::new();
        let result = register_provider(
            &container,
            Provider::value(AppConfig {
                app_name: "nestforge",
            }),
        );

        assert!(result.is_ok(), "value provider registration should succeed");
        let config = container
            .resolve::<AppConfig>()
            .expect("config should be registered");
        assert_eq!(config.app_name, "nestforge");
    }

    #[test]
    fn registers_factory_provider() {
        let container = Container::new();
        register_provider(
            &container,
            Provider::value(AppConfig {
                app_name: "nestforge",
            }),
        )
        .expect("seed config");

        let result = register_provider(
            &container,
            Provider::factory(|c| {
                let cfg = c.resolve::<AppConfig>()?;
                Ok(AppService {
                    config_name: cfg.app_name,
                })
            }),
        );

        assert!(
            result.is_ok(),
            "factory provider registration should succeed"
        );
        let service = container
            .resolve::<AppService>()
            .expect("service should be registered");
        assert_eq!(service.config_name, "nestforge");
    }

    #[test]
    fn factory_error_includes_type_name() {
        let container = Container::new();
        let err = register_provider(
            &container,
            Provider::factory::<AppService, _>(|_| Err(anyhow!("boom"))),
        )
        .expect_err("factory should fail");

        let message = err.to_string();
        assert!(
            message.contains("AppService"),
            "error should name the provider type: {message}"
        );
        // The factory's own failure must stay visible when the full error
        // chain is rendered (alternate `{:#}` formatting).
        let rendered = format!("{err:#}");
        assert!(
            rendered.contains("boom"),
            "error should retain the factory cause: {rendered}"
        );
    }

    #[test]
    fn registers_request_factory_provider() {
        #[derive(Clone)]
        struct RequestId(&'static str);

        struct RequestService(&'static str);

        let container = Container::new();
        register_provider(
            &container,
            Provider::request_factory(|c| {
                let request_id = c.resolve::<RequestId>()?;
                Ok(RequestService(request_id.0))
            }),
        )
        .expect("request factory should register");

        let scoped = container.scoped();
        scoped
            .override_value(RequestId("req-42"))
            .expect("request id should be set");

        let service = scoped
            .resolve::<RequestService>()
            .expect("request service should resolve");
        assert_eq!(service.0, "req-42");
    }

    #[test]
    fn registers_transient_factory_provider() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct TransientService(usize);

        let container = Container::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_for_factory = Arc::clone(&counter);

        register_provider(
            &container,
            Provider::transient_factory(move |_| {
                let value = counter_for_factory.fetch_add(1, Ordering::Relaxed) + 1;
                Ok(TransientService(value))
            }),
        )
        .expect("transient factory should register");

        let first = container
            .resolve::<TransientService>()
            .expect("first transient should resolve");
        let second = container
            .resolve::<TransientService>()
            .expect("second transient should resolve");

        assert_eq!(first.0, 1);
        assert_eq!(second.0, 2);
    }
}