1use std::marker::PhantomData;
2
3use anyhow::{anyhow, Result};
4
5use crate::{framework_log_event, Container};
6
/// Entry point for building providers.
///
/// `Provider` carries no data; it only namespaces the constructors
/// [`Provider::value`], [`Provider::factory`], [`Provider::request_factory`]
/// and [`Provider::transient_factory`]. Each returns a provider that can be
/// handed to [`register_provider`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Provider;
8
/// Provider that registers an already-constructed value as-is.
///
/// Created by [`Provider::value`].
#[derive(Debug, Clone)]
pub struct ValueProvider<T> {
    value: T,
}
12
/// Provider whose factory is run once, at registration time, to build the
/// value that gets registered.
///
/// Created by [`Provider::factory`].
pub struct FactoryProvider<T, F> {
    // `fn() -> T` ties `T` to the struct without claiming ownership of a `T`,
    // so the marker never affects `Send`/`Sync` or drop checking.
    _marker: PhantomData<fn() -> T>,
    factory: F,
}
17
/// Provider whose factory is handed to the container and invoked to build
/// request-scoped values.
///
/// Created by [`Provider::request_factory`].
pub struct RequestFactoryProvider<T, F> {
    // `fn() -> T` ties `T` to the struct without claiming ownership of a `T`,
    // so the marker never affects `Send`/`Sync` or drop checking.
    _marker: PhantomData<fn() -> T>,
    factory: F,
}
22
/// Provider whose factory is handed to the container and invoked to build a
/// fresh value each time the type is resolved.
///
/// Created by [`Provider::transient_factory`].
pub struct TransientFactoryProvider<T, F> {
    // `fn() -> T` ties `T` to the struct without claiming ownership of a `T`,
    // so the marker never affects `Send`/`Sync` or drop checking.
    _marker: PhantomData<fn() -> T>,
    factory: F,
}
27
28impl Provider {
29 pub fn value<T>(value: T) -> ValueProvider<T>
30 where
31 T: Send + Sync + 'static,
32 {
33 ValueProvider { value }
34 }
35
36 pub fn factory<T, F>(factory: F) -> FactoryProvider<T, F>
37 where
38 T: Send + Sync + 'static,
39 F: FnOnce(&Container) -> Result<T> + Send + 'static,
40 {
41 FactoryProvider {
42 factory,
43 _marker: PhantomData,
44 }
45 }
46
47 pub fn request_factory<T, F>(factory: F) -> RequestFactoryProvider<T, F>
48 where
49 T: Send + Sync + 'static,
50 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
51 {
52 RequestFactoryProvider {
53 factory,
54 _marker: PhantomData,
55 }
56 }
57
58 pub fn transient_factory<T, F>(factory: F) -> TransientFactoryProvider<T, F>
59 where
60 T: Send + Sync + 'static,
61 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
62 {
63 TransientFactoryProvider {
64 factory,
65 _marker: PhantomData,
66 }
67 }
68}
69
/// A provider that knows how to install itself into a [`Container`].
///
/// Implemented in this module for every provider type returned by the
/// [`Provider`] constructors; [`register_provider`] forwards to it.
pub trait RegisterProvider {
    /// Consumes the provider and registers what it provides with `container`.
    ///
    /// # Errors
    ///
    /// Returns an error if the provided value cannot be built or registered.
    fn register(self, container: &Container) -> Result<()>;
}
73
74impl<T> RegisterProvider for ValueProvider<T>
75where
76 T: Send + Sync + 'static,
77{
78 fn register(self, container: &Container) -> Result<()> {
79 framework_log_event(
80 "provider_register",
81 &[("type", std::any::type_name::<T>().to_string())],
82 );
83 container.register(self.value)?;
84 Ok(())
85 }
86}
87
88impl<T, F> RegisterProvider for FactoryProvider<T, F>
89where
90 T: Send + Sync + 'static,
91 F: FnOnce(&Container) -> Result<T> + Send + 'static,
92{
93 fn register(self, container: &Container) -> Result<()> {
94 framework_log_event(
95 "provider_register_factory",
96 &[("type", std::any::type_name::<T>().to_string())],
97 );
98 let value = (self.factory)(container).map_err(|err| {
99 anyhow!(
100 "Failed to build provider `{}`: {}",
101 std::any::type_name::<T>(),
102 err
103 )
104 })?;
105 container.register(value)?;
106 Ok(())
107 }
108}
109
110impl<T, F> RegisterProvider for RequestFactoryProvider<T, F>
111where
112 T: Send + Sync + 'static,
113 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
114{
115 fn register(self, container: &Container) -> Result<()> {
116 framework_log_event(
117 "provider_register_request_factory",
118 &[("type", std::any::type_name::<T>().to_string())],
119 );
120 container
121 .register_request_factory::<T, _>(move |container| {
122 (self.factory)(container).map_err(|err| {
123 anyhow!(
124 "Failed to build request-scoped provider `{}`: {}",
125 std::any::type_name::<T>(),
126 err
127 )
128 })
129 })
130 .map_err(|err| anyhow!("Failed to register request-scoped provider: {err}"))?;
131 Ok(())
132 }
133}
134
135impl<T, F> RegisterProvider for TransientFactoryProvider<T, F>
136where
137 T: Send + Sync + 'static,
138 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
139{
140 fn register(self, container: &Container) -> Result<()> {
141 framework_log_event(
142 "provider_register_transient_factory",
143 &[("type", std::any::type_name::<T>().to_string())],
144 );
145 container
146 .register_transient_factory::<T, _>(move |container| {
147 (self.factory)(container).map_err(|err| {
148 anyhow!(
149 "Failed to build transient provider `{}`: {}",
150 std::any::type_name::<T>(),
151 err
152 )
153 })
154 })
155 .map_err(|err| anyhow!("Failed to register transient provider: {err}"))?;
156 Ok(())
157 }
158}
159
160pub fn register_provider<P>(container: &Container, provider: P) -> Result<()>
161where
162 P: RegisterProvider,
163{
164 provider.register(container)
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone)]
    struct AppConfig {
        app_name: &'static str,
    }

    struct AppService {
        config_name: &'static str,
    }

    #[test]
    fn registers_value_provider() {
        let container = Container::new();
        // `expect` (rather than `assert!(is_ok())`) prints the underlying
        // error if registration ever fails.
        register_provider(
            &container,
            Provider::value(AppConfig {
                app_name: "nestforge",
            }),
        )
        .expect("value provider registration should succeed");

        let config = container
            .resolve::<AppConfig>()
            .expect("config should be registered");
        assert_eq!(config.app_name, "nestforge");
    }

    #[test]
    fn registers_factory_provider() {
        let container = Container::new();
        register_provider(
            &container,
            Provider::value(AppConfig {
                app_name: "nestforge",
            }),
        )
        .expect("seed config");

        register_provider(
            &container,
            Provider::factory(|c| {
                let cfg = c.resolve::<AppConfig>()?;
                Ok(AppService {
                    config_name: cfg.app_name,
                })
            }),
        )
        .expect("factory provider registration should succeed");

        let service = container
            .resolve::<AppService>()
            .expect("service should be registered");
        assert_eq!(service.config_name, "nestforge");
    }

    #[test]
    fn factory_error_includes_type_name() {
        let container = Container::new();
        let err = register_provider(
            &container,
            Provider::factory::<AppService, _>(|_| Err(anyhow!("boom"))),
        )
        .expect_err("factory should fail");

        let message = err.to_string();
        assert!(message.contains("AppService"), "missing type name: {message}");
        assert!(message.contains("boom"), "missing underlying cause: {message}");
        assert!(
            container.resolve::<AppService>().is_err(),
            "a failed factory must not register a value"
        );
    }

    #[test]
    fn registers_request_factory_provider() {
        #[derive(Clone)]
        struct RequestId(&'static str);

        struct RequestService(&'static str);

        let container = Container::new();
        register_provider(
            &container,
            Provider::request_factory(|c| {
                let request_id = c.resolve::<RequestId>()?;
                Ok(RequestService(request_id.0))
            }),
        )
        .expect("request factory should register");

        let scoped = container.scoped();
        scoped
            .override_value(RequestId("req-42"))
            .expect("request id should be set");

        let service = scoped
            .resolve::<RequestService>()
            .expect("request service should resolve");
        assert_eq!(service.0, "req-42");
    }

    #[test]
    fn registers_transient_factory_provider() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct TransientService(usize);

        let container = Container::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_for_factory = Arc::clone(&counter);

        register_provider(
            &container,
            Provider::transient_factory(move |_| {
                let value = counter_for_factory.fetch_add(1, Ordering::Relaxed) + 1;
                Ok(TransientService(value))
            }),
        )
        .expect("transient factory should register");

        let first = container
            .resolve::<TransientService>()
            .expect("first transient should resolve");
        let second = container
            .resolve::<TransientService>()
            .expect("second transient should resolve");

        assert_eq!(first.0, 1);
        assert_eq!(second.0, 2);
    }
}