1use std::marker::PhantomData;
2
3use anyhow::{anyhow, Result};
4
5use crate::{framework_log_event, Container};
6
/// Namespace for the provider constructors.
///
/// `Provider` carries no data; its associated functions (`value`, `factory`,
/// `request_factory`, `transient_factory`) build the concrete provider types
/// that implement `RegisterProvider`.
#[derive(Debug, Clone, Copy)]
pub struct Provider;
27
/// Provider that registers an already-constructed value.
///
/// Built by `Provider::value`. `Debug` and `Clone` are available whenever `T`
/// implements them, so a value provider can be inspected or reused.
#[derive(Debug, Clone)]
pub struct ValueProvider<T> {
    /// The value handed to the container on registration.
    value: T,
}
38
/// Provider that builds its value from a one-shot factory.
///
/// Built by `Provider::factory`. Its `RegisterProvider` impl calls the factory
/// exactly once, during registration, and registers the result.
pub struct FactoryProvider<T, F> {
    // Ties the provider to the produced type `T` without storing a `T`.
    // `fn() -> T` is always `Send + Sync`, so the marker does not make the
    // provider's auto traits depend on `T`.
    _marker: PhantomData<fn() -> T>,
    /// Closure invoked with the container to produce the value.
    factory: F,
}
50
/// Provider whose value is produced by a reusable, request-scoped factory.
///
/// Built by `Provider::request_factory`. Its `RegisterProvider` impl hands the
/// factory to `Container::register_request_factory` rather than calling it.
pub struct RequestFactoryProvider<T, F> {
    // Ties the provider to the produced type `T` without storing a `T`.
    // `fn() -> T` is always `Send + Sync`, so the marker does not make the
    // provider's auto traits depend on `T`.
    _marker: PhantomData<fn() -> T>,
    /// Closure the container invokes to produce a value.
    factory: F,
}
62
/// Provider whose value is produced by a reusable, transient factory.
///
/// Built by `Provider::transient_factory`. Its `RegisterProvider` impl hands the
/// factory to `Container::register_transient_factory` rather than calling it.
pub struct TransientFactoryProvider<T, F> {
    // Ties the provider to the produced type `T` without storing a `T`.
    // `fn() -> T` is always `Send + Sync`, so the marker does not make the
    // provider's auto traits depend on `T`.
    _marker: PhantomData<fn() -> T>,
    /// Closure the container invokes to produce a value.
    factory: F,
}
74
75impl Provider {
76 pub fn value<T>(value: T) -> ValueProvider<T>
88 where
89 T: Send + Sync + 'static,
90 {
91 ValueProvider { value }
92 }
93
94 pub fn factory<T, F>(factory: F) -> FactoryProvider<T, F>
106 where
107 T: Send + Sync + 'static,
108 F: FnOnce(&Container) -> Result<T> + Send + 'static,
109 {
110 FactoryProvider {
111 factory,
112 _marker: PhantomData,
113 }
114 }
115
116 pub fn request_factory<T, F>(factory: F) -> RequestFactoryProvider<T, F>
127 where
128 T: Send + Sync + 'static,
129 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
130 {
131 RequestFactoryProvider {
132 factory,
133 _marker: PhantomData,
134 }
135 }
136
137 pub fn transient_factory<T, F>(factory: F) -> TransientFactoryProvider<T, F>
148 where
149 T: Send + Sync + 'static,
150 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
151 {
152 TransientFactoryProvider {
153 factory,
154 _marker: PhantomData,
155 }
156 }
157}
158
159pub trait RegisterProvider {
166 fn register(self, container: &Container) -> Result<()>;
170}
171
172impl<T> RegisterProvider for ValueProvider<T>
173where
174 T: Send + Sync + 'static,
175{
176 fn register(self, container: &Container) -> Result<()> {
177 framework_log_event(
178 "provider_register",
179 &[("type", std::any::type_name::<T>().to_string())],
180 );
181 container.register(self.value)?;
182 Ok(())
183 }
184}
185
186impl<T, F> RegisterProvider for FactoryProvider<T, F>
187where
188 T: Send + Sync + 'static,
189 F: FnOnce(&Container) -> Result<T> + Send + 'static,
190{
191 fn register(self, container: &Container) -> Result<()> {
192 framework_log_event(
193 "provider_register_factory",
194 &[("type", std::any::type_name::<T>().to_string())],
195 );
196 let value = (self.factory)(container).map_err(|err| {
197 anyhow!(
198 "Failed to build provider `{}`: {}",
199 std::any::type_name::<T>(),
200 err
201 )
202 })?;
203 container.register(value)?;
204 Ok(())
205 }
206}
207
208impl<T, F> RegisterProvider for RequestFactoryProvider<T, F>
209where
210 T: Send + Sync + 'static,
211 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
212{
213 fn register(self, container: &Container) -> Result<()> {
214 framework_log_event(
215 "provider_register_request_factory",
216 &[("type", std::any::type_name::<T>().to_string())],
217 );
218 container
219 .register_request_factory::<T, _>(move |container| {
220 (self.factory)(container).map_err(|err| {
221 anyhow!(
222 "Failed to build request-scoped provider `{}`: {}",
223 std::any::type_name::<T>(),
224 err
225 )
226 })
227 })
228 .map_err(|err| anyhow!("Failed to register request-scoped provider: {err}"))?;
229 Ok(())
230 }
231}
232
233impl<T, F> RegisterProvider for TransientFactoryProvider<T, F>
234where
235 T: Send + Sync + 'static,
236 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
237{
238 fn register(self, container: &Container) -> Result<()> {
239 framework_log_event(
240 "provider_register_transient_factory",
241 &[("type", std::any::type_name::<T>().to_string())],
242 );
243 container
244 .register_transient_factory::<T, _>(move |container| {
245 (self.factory)(container).map_err(|err| {
246 anyhow!(
247 "Failed to build transient provider `{}`: {}",
248 std::any::type_name::<T>(),
249 err
250 )
251 })
252 })
253 .map_err(|err| anyhow!("Failed to register transient provider: {err}"))?;
254 Ok(())
255 }
256}
257
258pub fn register_provider<P>(container: &Container, provider: P) -> Result<()>
260where
261 P: RegisterProvider,
262{
263 provider.register(container)
264}
265
#[cfg(test)]
mod tests {
    //! Tests for each provider kind. Registration results are unwrapped with
    //! `expect` rather than checked with `is_ok()`, so a failure prints the
    //! underlying error instead of only the assertion message.

    use super::*;

    #[derive(Clone)]
    struct AppConfig {
        app_name: &'static str,
    }

    struct AppService {
        config_name: &'static str,
    }

    #[test]
    fn registers_value_provider() {
        let container = Container::new();
        register_provider(
            &container,
            Provider::value(AppConfig {
                app_name: "nestforge",
            }),
        )
        .expect("value provider registration should succeed");

        let config = container
            .resolve::<AppConfig>()
            .expect("config should be registered");
        assert_eq!(config.app_name, "nestforge");
    }

    #[test]
    fn registers_factory_provider() {
        let container = Container::new();
        register_provider(
            &container,
            Provider::value(AppConfig {
                app_name: "nestforge",
            }),
        )
        .expect("seed config");

        register_provider(
            &container,
            Provider::factory(|c| {
                let cfg = c.resolve::<AppConfig>()?;
                Ok(AppService {
                    config_name: cfg.app_name,
                })
            }),
        )
        .expect("factory provider registration should succeed");

        let service = container
            .resolve::<AppService>()
            .expect("service should be registered");
        assert_eq!(service.config_name, "nestforge");
    }

    #[test]
    fn factory_error_includes_type_name() {
        let container = Container::new();
        let err = register_provider(
            &container,
            Provider::factory::<AppService, _>(|_| Err(anyhow!("boom"))),
        )
        .expect_err("factory should fail");

        let message = err.to_string();
        assert!(
            message.contains("AppService"),
            "error should name the provided type, got: {message}"
        );
    }

    #[test]
    fn registers_request_factory_provider() {
        #[derive(Clone)]
        struct RequestId(&'static str);

        struct RequestService(&'static str);

        let container = Container::new();
        register_provider(
            &container,
            Provider::request_factory(|c| {
                let request_id = c.resolve::<RequestId>()?;
                Ok(RequestService(request_id.0))
            }),
        )
        .expect("request factory should register");

        let scoped = container.scoped();
        scoped
            .override_value(RequestId("req-42"))
            .expect("request id should be set");

        let service = scoped
            .resolve::<RequestService>()
            .expect("request service should resolve");
        assert_eq!(service.0, "req-42");
    }

    #[test]
    fn registers_transient_factory_provider() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct TransientService(usize);

        let container = Container::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_for_factory = Arc::clone(&counter);

        register_provider(
            &container,
            Provider::transient_factory(move |_| {
                let value = counter_for_factory.fetch_add(1, Ordering::Relaxed) + 1;
                Ok(TransientService(value))
            }),
        )
        .expect("transient factory should register");

        let first = container
            .resolve::<TransientService>()
            .expect("first transient should resolve");
        let second = container
            .resolve::<TransientService>()
            .expect("second transient should resolve");

        assert_eq!(first.0, 1);
        assert_eq!(second.0, 2);
    }
}