1use crate::error::{Error, Result};
7use crate::extract::{FromRequest, RequestCtx};
8use std::any::{Any, TypeId, type_name};
9use std::collections::HashMap;
10use std::future::Future;
11use std::ops::Deref;
12use std::pin::Pin;
13use std::sync::Arc;
14
15pub(crate) type AnyArc = Arc<dyn Any + Send + Sync>;
16pub(crate) type ProviderFut<'a> = Pin<Box<dyn Future<Output = Result<AnyArc>> + Send + 'a>>;
17pub(crate) type ProviderFn =
18 Arc<dyn for<'a> Fn(&'a mut RequestCtx) -> ProviderFut<'a> + Send + Sync>;
19
20#[derive(Default, Clone)]
23pub struct DepEnv {
24 pub(crate) singletons: HashMap<TypeId, AnyArc>,
25 pub(crate) factories: HashMap<TypeId, ProviderFn>,
26}
27
28impl DepEnv {
29 pub(crate) fn insert_value<T: Send + Sync + 'static>(&mut self, value: T) {
31 let id = TypeId::of::<T>();
32 self.singletons.insert(id, Arc::new(value));
33 self.factories.remove(&id);
34 }
35
36 pub(crate) fn insert_factory<F, Args, T>(&mut self, factory: F)
38 where
39 F: DepFactory<Args, T>,
40 T: Send + Sync + 'static,
41 {
42 let id = TypeId::of::<T>();
43 self.factories.insert(id, factory.into_provider());
44 self.singletons.remove(&id);
45 }
46
47 pub(crate) fn merge_from(&mut self, inner: &DepEnv) {
49 for (k, v) in &inner.singletons {
50 self.singletons.insert(*k, v.clone());
51 self.factories.remove(k);
52 }
53 for (k, f) in &inner.factories {
54 self.factories.insert(*k, f.clone());
55 self.singletons.remove(k);
56 }
57 }
58}
59
60pub struct DepResolver {
62 pub(crate) env: Arc<DepEnv>,
63 pub(crate) overrides: Arc<HashMap<TypeId, AnyArc>>,
64 pub(crate) cache: HashMap<TypeId, AnyArc>,
65 pub(crate) depth: u8,
66}
67
68impl DepResolver {
69 pub(crate) fn new(env: Arc<DepEnv>, overrides: Arc<HashMap<TypeId, AnyArc>>) -> Self {
70 Self {
71 env,
72 overrides,
73 cache: HashMap::new(),
74 depth: 0,
75 }
76 }
77}
78
79const MAX_RESOLVE_DEPTH: u8 = 32;
80
81impl RequestCtx {
82 pub async fn resolve<T: Send + Sync + 'static>(&mut self) -> Result<Arc<T>> {
84 let id = TypeId::of::<T>();
85 if let Some(v) = self.deps.cache.get(&id) {
86 return downcast::<T>(v.clone());
87 }
88 if let Some(v) = self.deps.overrides.get(&id).cloned() {
89 self.deps.cache.insert(id, v.clone());
90 return downcast::<T>(v);
91 }
92 if let Some(v) = self.deps.env.singletons.get(&id).cloned() {
93 self.deps.cache.insert(id, v.clone());
94 return downcast::<T>(v);
95 }
96 let factory = match self.deps.env.factories.get(&id) {
97 Some(f) => f.clone(),
98 None => return Err(Error::missing_dependency(type_name::<T>())),
99 };
100 self.deps.depth += 1;
101 if self.deps.depth > MAX_RESOLVE_DEPTH {
102 self.deps.depth -= 1;
103 return Err(Error::dependency_cycle());
104 }
105 let produced = (*factory)(self).await;
106 self.deps.depth -= 1;
107 let v = produced?;
108 self.deps.cache.insert(id, v.clone());
109 downcast::<T>(v)
110 }
111}
112
113fn downcast<T: Send + Sync + 'static>(v: AnyArc) -> Result<Arc<T>> {
114 v.downcast::<T>()
115 .map_err(|_| Error::internal("dependency type mismatch (provider/consumer disagree)"))
116}
117
118pub struct TaskContext(RequestCtx);
126
127impl TaskContext {
128 pub(crate) fn new(deps: DepResolver) -> Self {
132 let req = http::Request::builder()
133 .uri("/")
134 .body(())
135 .expect("static request head always builds");
136 let (parts, ()) = req.into_parts();
137 let mut ctx = RequestCtx::new(parts, bytes::Bytes::new(), deps);
138 ctx.is_task = true;
139 TaskContext(ctx)
140 }
141
142 pub async fn resolve<T: Send + Sync + 'static>(&mut self) -> Result<Arc<T>> {
145 self.0.resolve::<T>().await
146 }
147
148 pub fn fork(&self) -> TaskContext {
157 TaskContext::new(DepResolver::new(
158 self.0.deps.env.clone(),
159 self.0.deps.overrides.clone(),
160 ))
161 }
162}
163
164pub struct Dep<T: ?Sized>(pub(crate) Arc<T>);
166
167impl<T: ?Sized> Deref for Dep<T> {
168 type Target = T;
169 fn deref(&self) -> &T {
170 &self.0
171 }
172}
173
174impl<T: ?Sized> Clone for Dep<T> {
175 fn clone(&self) -> Self {
176 Dep(self.0.clone())
177 }
178}
179
180impl<T: Send + Sync + 'static> FromRequest for Dep<T> {
181 async fn from_request(ctx: &mut RequestCtx) -> Result<Self> {
182 ctx.resolve::<T>().await.map(Dep)
183 }
184}
185
186pub trait DepFactory<Args, T>: Send + Sync + 'static {
189 fn into_provider(self) -> ProviderFn;
190}
191
192macro_rules! impl_dep_factory {
193 ($($A:ident),*) => {
194 impl<F, Fut, T, $($A,)*> DepFactory<($($A,)*), T> for F
195 where
196 F: Fn($($A),*) -> Fut + Clone + Send + Sync + 'static,
197 Fut: Future<Output = Result<T>> + Send,
198 T: Send + Sync + 'static,
199 $($A: FromRequest + 'static,)*
200 {
201 fn into_provider(self) -> ProviderFn {
202 #[allow(unused_variables)]
203 Arc::new(move |ctx: &mut RequestCtx| {
204 let f = self.clone();
205 Box::pin(async move {
206 #[allow(non_snake_case, unused_variables)]
207 {
208 $(let $A = <$A as FromRequest>::from_request(ctx).await?;)*
209 let value = f($($A),*).await?;
210 Ok(Arc::new(value) as AnyArc)
211 }
212 })
213 })
214 }
215 }
216 };
217}
218
219impl_dep_factory!();
220impl_dep_factory!(A1);
221impl_dep_factory!(A1, A2);
222impl_dep_factory!(A1, A2, A3);
223impl_dep_factory!(A1, A2, A3, A4);
224impl_dep_factory!(A1, A2, A3, A4, A5);
225impl_dep_factory!(A1, A2, A3, A4, A5, A6);
226impl_dep_factory!(A1, A2, A3, A4, A5, A6, A7);
227impl_dep_factory!(A1, A2, A3, A4, A5, A6, A7, A8);
228
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a request context (URI `/`, empty body) over `env` with the given overrides.
    fn ctx_with_overrides(env: DepEnv, overrides: HashMap<TypeId, AnyArc>) -> RequestCtx {
        let (parts, ()) = http::Request::new(()).into_parts();
        RequestCtx::new(
            parts,
            Bytes::new(),
            DepResolver::new(Arc::new(env), Arc::new(overrides)),
        )
    }

    /// Builds a request context over `env` with no overrides.
    pub(crate) fn test_ctx(env: DepEnv) -> RequestCtx {
        ctx_with_overrides(env, HashMap::new())
    }

    struct Config {
        name: &'static str,
    }

    #[derive(Clone)]
    struct Db {
        url: &'static str,
    }

    struct Session {
        token: String,
    }

    struct User {
        name: String,
    }

    async fn make_session() -> crate::Result<Session> {
        Ok(Session {
            token: "t-1".into(),
        })
    }

    async fn current_user(session: Dep<Session>, db: Dep<Db>) -> crate::Result<User> {
        let name = format!("{}@{}", session.token, db.url);
        Ok(User { name })
    }

    /// Environment where `User` depends on a `Session` factory and a `Db` value.
    fn nested_env() -> DepEnv {
        let mut env = DepEnv::default();
        env.insert_value(Db { url: "pg://prod" });
        env.insert_factory(make_session);
        env.insert_factory(current_user);
        env
    }

    #[tokio::test]
    async fn value_provider_resolves_and_derefs() {
        let mut env = DepEnv::default();
        env.insert_value(Config { name: "prod" });
        let mut ctx = test_ctx(env);

        let cfg = Dep::<Config>::from_request(&mut ctx).await.unwrap();
        assert_eq!(cfg.name, "prod");
    }

    #[tokio::test]
    async fn missing_provider_is_jc1001() {
        let mut ctx = test_ctx(DepEnv::default());

        let outcome = Dep::<Config>::from_request(&mut ctx).await;
        let err = outcome.err().unwrap();
        assert_eq!(err.code(), "JC1001");
        assert!(err.message().contains("Config"));
    }

    #[tokio::test]
    async fn same_request_yields_same_arc() {
        let mut env = DepEnv::default();
        env.insert_value(Config { name: "x" });
        let mut ctx = test_ctx(env);

        let first = ctx.resolve::<Config>().await.unwrap();
        let second = ctx.resolve::<Config>().await.unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[tokio::test]
    async fn factories_nest_and_resolve_async() {
        let mut ctx = test_ctx(nested_env());

        let resolved = ctx.resolve::<User>().await.unwrap();
        assert_eq!(resolved.name, "t-1@pg://prod");
    }

    #[tokio::test]
    async fn nested_deps_are_memoized_once_per_request() {
        static LOCAL_RESOLVES: AtomicUsize = AtomicUsize::new(0);

        async fn local_session() -> crate::Result<Session> {
            LOCAL_RESOLVES.fetch_add(1, Ordering::SeqCst);
            Ok(Session {
                token: "t-1".into(),
            })
        }

        fn local_env() -> DepEnv {
            let mut env = DepEnv::default();
            env.insert_value(Db { url: "pg://prod" });
            env.insert_factory(local_session);
            env.insert_factory(current_user);
            env
        }

        let resolves = || LOCAL_RESOLVES.load(Ordering::SeqCst);
        LOCAL_RESOLVES.store(0, Ordering::SeqCst);

        // Session directly, then User (which needs Session): one build only.
        let mut ctx = test_ctx(local_env());
        let _s = ctx.resolve::<Session>().await.unwrap();
        let _u = ctx.resolve::<User>().await.unwrap();
        assert_eq!(resolves(), 1, "memoized within request");

        // A new request starts from an empty cache.
        let mut other_ctx = test_ctx(local_env());
        let _u2 = other_ctx.resolve::<User>().await.unwrap();
        assert_eq!(resolves(), 2);
    }

    #[tokio::test]
    async fn self_referential_factory_hits_cycle_guard() {
        struct Loopy;

        async fn loopy(_again: Dep<Loopy>) -> crate::Result<Loopy> {
            Ok(Loopy)
        }

        let mut env = DepEnv::default();
        env.insert_factory(loopy);
        let mut ctx = test_ctx(env);

        let outcome = ctx.resolve::<Loopy>().await;
        let err = outcome.err().unwrap();
        assert_eq!(err.code(), "JC1002");
    }

    #[tokio::test]
    async fn overrides_shadow_both_values_and_factories() {
        // `Db` is registered as a value and `Session` as a factory; the
        // overrides must take precedence over both kinds.
        let mut env = nested_env();
        env.insert_value(Db { url: "pg://prod" });

        let overrides: HashMap<TypeId, AnyArc> = HashMap::from([
            (
                TypeId::of::<Db>(),
                Arc::new(Db {
                    url: "sqlite::memory:",
                }) as AnyArc,
            ),
            (
                TypeId::of::<Session>(),
                Arc::new(Session {
                    token: "fake".into(),
                }) as AnyArc,
            ),
        ]);

        let mut ctx = ctx_with_overrides(env, overrides);
        let resolved = ctx.resolve::<User>().await.unwrap();
        assert_eq!(resolved.name, "fake@sqlite::memory:");
    }

    #[tokio::test]
    async fn deps_resolve_without_a_request() {
        #[derive(Clone)]
        struct Cfg(u32);

        async fn make_cfg() -> crate::Result<Cfg> {
            Ok(Cfg(7))
        }

        let built = crate::App::new().provide_dep(make_cfg).build().unwrap();
        let mut task = built.task_context();

        let cfg = task.resolve::<Cfg>().await.unwrap();
        assert_eq!(cfg.0, 7);
    }

    #[tokio::test]
    async fn task_resolution_memoizes_and_honors_singletons() {
        static BUILDS: AtomicUsize = AtomicUsize::new(0);

        #[derive(Clone)]
        struct Singleton(&'static str);
        struct Counted;

        async fn build_counted() -> crate::Result<Counted> {
            BUILDS.fetch_add(1, Ordering::SeqCst);
            Ok(Counted)
        }

        BUILDS.store(0, Ordering::SeqCst);
        let built = crate::App::new()
            .provide(Singleton("app"))
            .provide_dep(build_counted)
            .build()
            .unwrap();

        let mut task = built.task_context();
        let singleton = task.resolve::<Singleton>().await.unwrap();
        assert_eq!(singleton.0, "app");

        let first = task.resolve::<Counted>().await.unwrap();
        let second = task.resolve::<Counted>().await.unwrap();
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(BUILDS.load(Ordering::SeqCst), 1, "memoized within the task");

        let mut fresh_task = built.task_context();
        let _ = fresh_task.resolve::<Counted>().await.unwrap();
        assert_eq!(BUILDS.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn http_extractors_reject_task_context_with_jc1003() {
        #[derive(Clone)]
        struct Whoami(#[allow(dead_code)] String);

        async fn needs_headers(_headers: crate::Headers) -> crate::Result<Whoami> {
            Ok(Whoami("x".into()))
        }

        let built = crate::App::new()
            .provide_dep(needs_headers)
            .build()
            .unwrap();
        let mut task = built.task_context();

        let outcome = task.resolve::<Whoami>().await;
        let err = outcome.err().unwrap();
        assert_eq!(err.code(), "JC1003");
        assert_eq!(err.status().as_u16(), 500);
    }

    #[test]
    fn task_context_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<TaskContext>();
    }

    #[tokio::test]
    async fn fork_shares_singletons_but_resets_the_resolution_cache() {
        static BUILDS: AtomicUsize = AtomicUsize::new(0);

        #[derive(Clone)]
        struct Singleton(&'static str);
        struct Counted;

        async fn build_counted() -> crate::Result<Counted> {
            BUILDS.fetch_add(1, Ordering::SeqCst);
            Ok(Counted)
        }

        BUILDS.store(0, Ordering::SeqCst);
        let built = crate::App::new()
            .provide(Singleton("app"))
            .provide_dep(build_counted)
            .build()
            .unwrap();
        let base = built.task_context();

        let mut left = base.fork();
        let mut right = base.fork();
        let from_left = left.resolve::<Singleton>().await.unwrap();
        let from_right = right.resolve::<Singleton>().await.unwrap();
        assert!(
            Arc::ptr_eq(&from_left, &from_right),
            "forks share the singleton provider"
        );
        assert_eq!(
            from_left.0, "app",
            "the shared singleton carries the provided value"
        );

        let _ = left.resolve::<Counted>().await.unwrap();
        let _ = right.resolve::<Counted>().await.unwrap();
        assert_eq!(
            BUILDS.load(Ordering::SeqCst),
            2,
            "each fork resolves the factory in its own fresh cache"
        );
    }

    #[tokio::test]
    async fn test_app_task_context_honors_overrides() {
        #[derive(Clone)]
        struct Cfg(u32);

        async fn make_cfg() -> crate::Result<Cfg> {
            Ok(Cfg(1))
        }

        let app = crate::App::new()
            .provide_dep(make_cfg)
            .into_test()
            .override_dep(Cfg(99));
        let mut task = app.task_context();

        let cfg = task.resolve::<Cfg>().await.unwrap();
        assert_eq!(cfg.0, 99, "override wins in a task context too");
    }
}