Skip to main content

jerrycan_core/
dep.rs

1//! Dependency injection (spec §4.3) — async, nested, per-request memoized,
2//! override-able in tests. Resolution order: cache → overrides → singletons → factories.
3//! Singletons and factories are disjoint by construction (`insert_value`/`insert_factory`
4//! each remove the opposite key), so there is no singleton-vs-factory tiebreak to define.
5
6use crate::error::{Error, Result};
7use crate::extract::{FromRequest, RequestCtx};
8use std::any::{Any, TypeId, type_name};
9use std::collections::HashMap;
10use std::future::Future;
11use std::ops::Deref;
12use std::pin::Pin;
13use std::sync::Arc;
14
15pub(crate) type AnyArc = Arc<dyn Any + Send + Sync>;
16pub(crate) type ProviderFut<'a> = Pin<Box<dyn Future<Output = Result<AnyArc>> + Send + 'a>>;
17pub(crate) type ProviderFn =
18    Arc<dyn for<'a> Fn(&'a mut RequestCtx) -> ProviderFut<'a> + Send + Sync>;
19
20/// The provider set effective for a route: app providers merged with the
21/// route's module chain — inner module wins (spec §4.2 scoping).
22#[derive(Default, Clone)]
23pub struct DepEnv {
24    pub(crate) singletons: HashMap<TypeId, AnyArc>,
25    pub(crate) factories: HashMap<TypeId, ProviderFn>,
26}
27
28impl DepEnv {
29    /// Register an already-built value; shared by every request (singleton scope).
30    pub(crate) fn insert_value<T: Send + Sync + 'static>(&mut self, value: T) {
31        let id = TypeId::of::<T>();
32        self.singletons.insert(id, Arc::new(value));
33        self.factories.remove(&id);
34    }
35
36    /// Register an async factory; runs at most once per request (request scope).
37    pub(crate) fn insert_factory<F, Args, T>(&mut self, factory: F)
38    where
39        F: DepFactory<Args, T>,
40        T: Send + Sync + 'static,
41    {
42        let id = TypeId::of::<T>();
43        self.factories.insert(id, factory.into_provider());
44        self.singletons.remove(&id);
45    }
46
47    /// Later entries shadow earlier ones — used to layer module envs over the app env.
48    pub(crate) fn merge_from(&mut self, inner: &DepEnv) {
49        for (k, v) in &inner.singletons {
50            self.singletons.insert(*k, v.clone());
51            self.factories.remove(k);
52        }
53        for (k, f) in &inner.factories {
54            self.factories.insert(*k, f.clone());
55            self.singletons.remove(k);
56        }
57    }
58}
59
60/// Per-request resolution state. Cheap to create; memoizes by `TypeId`.
61pub struct DepResolver {
62    pub(crate) env: Arc<DepEnv>,
63    pub(crate) overrides: Arc<HashMap<TypeId, AnyArc>>,
64    pub(crate) cache: HashMap<TypeId, AnyArc>,
65    pub(crate) depth: u8,
66}
67
68impl DepResolver {
69    pub(crate) fn new(env: Arc<DepEnv>, overrides: Arc<HashMap<TypeId, AnyArc>>) -> Self {
70        Self {
71            env,
72            overrides,
73            cache: HashMap::new(),
74            depth: 0,
75        }
76    }
77}
78
79const MAX_RESOLVE_DEPTH: u8 = 32;
80
81impl RequestCtx {
82    /// Resolve a dependency by type, memoized for this request (spec §4.3).
83    pub async fn resolve<T: Send + Sync + 'static>(&mut self) -> Result<Arc<T>> {
84        let id = TypeId::of::<T>();
85        if let Some(v) = self.deps.cache.get(&id) {
86            return downcast::<T>(v.clone());
87        }
88        if let Some(v) = self.deps.overrides.get(&id).cloned() {
89            self.deps.cache.insert(id, v.clone());
90            return downcast::<T>(v);
91        }
92        if let Some(v) = self.deps.env.singletons.get(&id).cloned() {
93            self.deps.cache.insert(id, v.clone());
94            return downcast::<T>(v);
95        }
96        let factory = match self.deps.env.factories.get(&id) {
97            Some(f) => f.clone(),
98            None => return Err(Error::missing_dependency(type_name::<T>())),
99        };
100        self.deps.depth += 1;
101        if self.deps.depth > MAX_RESOLVE_DEPTH {
102            self.deps.depth -= 1;
103            return Err(Error::dependency_cycle());
104        }
105        let produced = (*factory)(self).await;
106        self.deps.depth -= 1;
107        let v = produced?;
108        self.deps.cache.insert(id, v.clone());
109        downcast::<T>(v)
110    }
111}
112
113fn downcast<T: Send + Sync + 'static>(v: AnyArc) -> Result<Arc<T>> {
114    v.downcast::<T>()
115        .map_err(|_| Error::internal("dependency type mismatch (provider/consumer disagree)"))
116}
117
118/// Resolve dependencies OUTSIDE an HTTP request — background jobs, startup
119/// wiring, CLI commands. Built from [`BuiltApp::task_context`](crate::BuiltApp::task_context).
120///
121/// Resolution semantics mirror a request (memoized per context, honors
122/// overrides and singletons), but only **app-level** providers are in scope,
123/// and any factory reaching for an HTTP extractor (`Json`/`Path`/`Query`/
124/// `Headers`) fails with `JC1003`.
125pub struct TaskContext(RequestCtx);
126
127impl TaskContext {
128    /// Wrap a resolver in a synthetic request marked as a task context. The
129    /// parts/body/params are placeholders: HTTP-coupled extractors reject a
130    /// task context before reading them, and `Dep<T>` never touches them.
131    pub(crate) fn new(deps: DepResolver) -> Self {
132        let req = http::Request::builder()
133            .uri("/")
134            .body(())
135            .expect("static request head always builds");
136        let (parts, ()) = req.into_parts();
137        let mut ctx = RequestCtx::new(parts, bytes::Bytes::new(), deps);
138        ctx.is_task = true;
139        TaskContext(ctx)
140    }
141
142    /// Resolve a dependency by type, memoized for this task context. Mirrors
143    /// [`RequestCtx::resolve`]; returns the same `Arc<T>` `Dep<T>` would carry.
144    pub async fn resolve<T: Send + Sync + 'static>(&mut self) -> Result<Arc<T>> {
145        self.0.resolve::<T>().await
146    }
147
148    /// A sibling task context sharing the app-level providers but with a fresh
149    /// dependency-resolution cache — the job worker forks one per job so cached
150    /// deps never leak between jobs.
151    ///
152    /// The new context reuses the same `Arc<DepEnv>` (singletons + factories)
153    /// and the same `Arc` of overrides, so it resolves identical singletons; it
154    /// gets an empty cache, so request-scope factories run afresh and per-job
155    /// state is isolated.
156    pub fn fork(&self) -> TaskContext {
157        TaskContext::new(DepResolver::new(
158            self.0.deps.env.clone(),
159            self.0.deps.overrides.clone(),
160        ))
161    }
162}
163
/// A resolved dependency handle.
///
/// Dereferences to `T`; cloning only bumps the inner `Arc`'s reference count.
pub struct Dep<T: ?Sized>(pub(crate) Arc<T>);

impl<T: ?Sized> Deref for Dep<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &*self.0
    }
}

impl<T: ?Sized> Clone for Dep<T> {
    fn clone(&self) -> Self {
        let Dep(shared) = self;
        Self(Arc::clone(shared))
    }
}
179
180impl<T: Send + Sync + 'static> FromRequest for Dep<T> {
181    async fn from_request(ctx: &mut RequestCtx) -> Result<Self> {
182        ctx.resolve::<T>().await.map(Dep)
183    }
184}
185
/// Async functions registrable as providers. `Args` is the tuple of extractor
/// parameters; `T` the produced dependency. Implemented for arities 0..=8.
///
/// Every parameter type must implement `FromRequest` (e.g. another `Dep<_>`),
/// which is what lets factories depend on other dependencies.
pub trait DepFactory<Args, T>: Send + Sync + 'static {
    /// Erase this function into the type-erased [`ProviderFn`] a `DepEnv` stores.
    fn into_provider(self) -> ProviderFn;
}

// Emits one `DepFactory` impl for functions taking exactly the listed
// extractor types. Each `$A` is used twice: as a generic type parameter and,
// inside the async block, as the name of the local holding that extractor's
// value (hence `non_snake_case`).
macro_rules! impl_dep_factory {
    ($($A:ident),*) => {
        impl<F, Fut, T, $($A,)*> DepFactory<($($A,)*), T> for F
        where
            F: Fn($($A),*) -> Fut + Clone + Send + Sync + 'static,
            Fut: Future<Output = Result<T>> + Send,
            T: Send + Sync + 'static,
            $($A: FromRequest + 'static,)*
        {
            fn into_provider(self) -> ProviderFn {
                // `unused_variables`: in the zero-arity expansion `ctx` is unused.
                #[allow(unused_variables)]
                Arc::new(move |ctx: &mut RequestCtx| {
                    // The provider is `Fn` and is invoked once per resolution,
                    // so each invocation clones the factory to move into its
                    // own `async move` block.
                    let f = self.clone();
                    Box::pin(async move {
                        #[allow(non_snake_case, unused_variables)]
                        {
                            // Extract arguments sequentially, in parameter
                            // order, from the same `ctx`; the first error
                            // short-circuits via `?`.
                            $(let $A = <$A as FromRequest>::from_request(ctx).await?;)*
                            let value = f($($A),*).await?;
                            // Type-erase the produced value for the resolver.
                            Ok(Arc::new(value) as AnyArc)
                        }
                    })
                })
            }
        }
    };
}

// Arities 0 through 8.
impl_dep_factory!();
impl_dep_factory!(A1);
impl_dep_factory!(A1, A2);
impl_dep_factory!(A1, A2, A3);
impl_dep_factory!(A1, A2, A3, A4);
impl_dep_factory!(A1, A2, A3, A4, A5);
impl_dep_factory!(A1, A2, A3, A4, A5, A6);
impl_dep_factory!(A1, A2, A3, A4, A5, A6, A7);
impl_dep_factory!(A1, A2, A3, A4, A5, A6, A7, A8);
228
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// A bare request context over `env` with no overrides.
    pub(crate) fn test_ctx(env: DepEnv) -> RequestCtx {
        let req = http::Request::builder().uri("/").body(()).unwrap();
        let (parts, ()) = req.into_parts();
        RequestCtx::new(
            parts,
            Bytes::new(),
            DepResolver::new(Arc::new(env), Arc::new(HashMap::new())),
        )
    }

    struct Config {
        name: &'static str,
    }

    #[tokio::test]
    async fn value_provider_resolves_and_derefs() {
        let mut env = DepEnv::default();
        env.insert_value(Config { name: "prod" });
        let mut ctx = test_ctx(env);
        let cfg: Dep<Config> = Dep::from_request(&mut ctx).await.unwrap();
        assert_eq!(cfg.name, "prod"); // Deref<Target = Config>
    }

    #[tokio::test]
    async fn missing_provider_is_jc1001() {
        let mut ctx = test_ctx(DepEnv::default());
        let err = Dep::<Config>::from_request(&mut ctx).await.err().unwrap();
        assert_eq!(err.code(), "JC1001");
        assert!(err.message().contains("Config"));
    }

    #[tokio::test]
    async fn same_request_yields_same_arc() {
        let mut env = DepEnv::default();
        env.insert_value(Config { name: "x" });
        let mut ctx = test_ctx(env);
        let a = ctx.resolve::<Config>().await.unwrap();
        let b = ctx.resolve::<Config>().await.unwrap();
        assert!(Arc::ptr_eq(&a, &b));
    }

    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Clone)]
    struct Db {
        url: &'static str,
    }
    struct Session {
        token: String,
    }
    struct User {
        name: String,
    }

    async fn make_session() -> crate::Result<Session> {
        Ok(Session {
            token: "t-1".into(),
        })
    }

    async fn current_user(session: Dep<Session>, db: Dep<Db>) -> crate::Result<User> {
        Ok(User {
            name: format!("{}@{}", session.token, db.url),
        })
    }

    fn nested_env() -> DepEnv {
        let mut env = DepEnv::default();
        env.insert_value(Db { url: "pg://prod" });
        env.insert_factory(make_session);
        env.insert_factory(current_user);
        env
    }

    #[tokio::test]
    async fn factories_nest_and_resolve_async() {
        let mut ctx = test_ctx(nested_env());
        let user = ctx.resolve::<User>().await.unwrap();
        assert_eq!(user.name, "t-1@pg://prod");
    }

    #[tokio::test]
    async fn nested_deps_are_memoized_once_per_request() {
        // A test-local counter (`LOCAL_RESOLVES`) + session factory so parallel tests
        // can't pollute these exact-count assertions.
        static LOCAL_RESOLVES: AtomicUsize = AtomicUsize::new(0);
        async fn local_session() -> crate::Result<Session> {
            LOCAL_RESOLVES.fetch_add(1, Ordering::SeqCst);
            Ok(Session {
                token: "t-1".into(),
            })
        }
        fn local_env() -> DepEnv {
            let mut env = DepEnv::default();
            env.insert_value(Db { url: "pg://prod" });
            env.insert_factory(local_session);
            env.insert_factory(current_user);
            env
        }

        LOCAL_RESOLVES.store(0, Ordering::SeqCst);
        let mut ctx = test_ctx(local_env());
        // Both of these need Session (one directly, one through current_user).
        let _s = ctx.resolve::<Session>().await.unwrap();
        let _u = ctx.resolve::<User>().await.unwrap();
        assert_eq!(
            LOCAL_RESOLVES.load(Ordering::SeqCst),
            1,
            "memoized within request"
        );

        // A new request resolves afresh — request scope, not singleton.
        let mut ctx2 = test_ctx(local_env());
        let _u2 = ctx2.resolve::<User>().await.unwrap();
        assert_eq!(LOCAL_RESOLVES.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn self_referential_factory_hits_cycle_guard() {
        struct Loopy;
        async fn loopy(_again: Dep<Loopy>) -> crate::Result<Loopy> {
            Ok(Loopy)
        }
        let mut env = DepEnv::default();
        env.insert_factory(loopy);
        let mut ctx = test_ctx(env);
        let err = ctx.resolve::<Loopy>().await.err().unwrap();
        assert_eq!(err.code(), "JC1002");
    }

    #[tokio::test]
    async fn cycle_error_unwinds_depth_so_the_request_stays_usable() {
        // Every nested level must decrement `depth` on the error path; otherwise
        // later, legitimate factory chains on the same request would spuriously
        // trip the cycle guard.
        struct Loopy;
        async fn loopy(_again: Dep<Loopy>) -> crate::Result<Loopy> {
            Ok(Loopy)
        }
        let mut env = nested_env();
        env.insert_factory(loopy);
        let mut ctx = test_ctx(env);
        let err = ctx.resolve::<Loopy>().await.err().unwrap();
        assert_eq!(err.code(), "JC1002");
        assert_eq!(ctx.deps.depth, 0, "depth restored after the failed chain");
        let user = ctx.resolve::<User>().await.unwrap();
        assert_eq!(user.name, "t-1@pg://prod");
    }

    #[tokio::test]
    async fn overrides_shadow_both_values_and_factories() {
        // Real env (from `nested_env`): Db value + Session factory.
        let env = nested_env();

        // Overrides replace them without touching the env.
        let mut overrides: HashMap<TypeId, AnyArc> = HashMap::new();
        overrides.insert(
            TypeId::of::<Db>(),
            Arc::new(Db {
                url: "sqlite::memory:",
            }),
        );
        overrides.insert(
            TypeId::of::<Session>(),
            Arc::new(Session {
                token: "fake".into(),
            }),
        );

        let req = http::Request::builder().uri("/").body(()).unwrap();
        let (parts, ()) = req.into_parts();
        let mut ctx = RequestCtx::new(
            parts,
            bytes::Bytes::new(),
            DepResolver::new(Arc::new(env), Arc::new(overrides)),
        );

        let user = ctx.resolve::<User>().await.unwrap();
        assert_eq!(user.name, "fake@sqlite::memory:");
    }

    #[tokio::test]
    async fn merge_from_lets_inner_env_shadow_outer_across_kinds() {
        // Module-level Db factory replacing the app-level Db value.
        async fn module_db() -> crate::Result<Db> {
            Ok(Db { url: "pg://module" })
        }
        // Outer (app) env: Db as a value, Session + User as factories.
        let mut outer = nested_env();
        // Inner (module) env flips both kinds: Db becomes a factory, Session a value.
        let mut inner = DepEnv::default();
        inner.insert_factory(module_db);
        inner.insert_value(Session {
            token: "inner".into(),
        });

        outer.merge_from(&inner);

        // The singleton/factory disjointness invariant survives the merge.
        assert!(!outer.singletons.contains_key(&TypeId::of::<Db>()));
        assert!(!outer.factories.contains_key(&TypeId::of::<Session>()));

        // Inner entries win; the untouched outer User factory still composes them.
        let mut ctx = test_ctx(outer);
        let user = ctx.resolve::<User>().await.unwrap();
        assert_eq!(user.name, "inner@pg://module");
    }

    // ---- Task-scoped DI (Task 5): resolve deps outside an HTTP request ----

    #[tokio::test]
    async fn deps_resolve_without_a_request() {
        #[derive(Clone)]
        struct Cfg(u32);
        async fn make_cfg() -> crate::Result<Cfg> {
            Ok(Cfg(7))
        }
        let built = crate::App::new().provide_dep(make_cfg).build().unwrap();
        let mut ctx = built.task_context();
        let cfg = ctx.resolve::<Cfg>().await.unwrap();
        assert_eq!(cfg.0, 7);
    }

    #[tokio::test]
    async fn task_resolution_memoizes_and_honors_singletons() {
        // A factory with a side-effect counter must resolve at most once per
        // task context (request-scope memoization), and a `provide()` singleton
        // must also be reachable from the same context.
        static BUILDS: AtomicUsize = AtomicUsize::new(0);
        #[derive(Clone)]
        struct Singleton(&'static str);
        struct Counted;
        async fn build_counted() -> crate::Result<Counted> {
            BUILDS.fetch_add(1, Ordering::SeqCst);
            Ok(Counted)
        }

        BUILDS.store(0, Ordering::SeqCst);
        let built = crate::App::new()
            .provide(Singleton("app"))
            .provide_dep(build_counted)
            .build()
            .unwrap();

        let mut ctx = built.task_context();
        // Singleton resolves in a task context.
        let s = ctx.resolve::<Singleton>().await.unwrap();
        assert_eq!(s.0, "app");
        // Factory runs once and is memoized across repeat resolves.
        let a = ctx.resolve::<Counted>().await.unwrap();
        let b = ctx.resolve::<Counted>().await.unwrap();
        assert!(Arc::ptr_eq(&a, &b));
        assert_eq!(BUILDS.load(Ordering::SeqCst), 1, "memoized within the task");

        // A fresh task context resolves the factory afresh (task scope, like a request).
        let mut ctx2 = built.task_context();
        let _ = ctx2.resolve::<Counted>().await.unwrap();
        assert_eq!(BUILDS.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn http_extractors_reject_task_context_with_jc1003() {
        #[derive(Clone)]
        struct Whoami(#[allow(dead_code)] String);
        async fn needs_headers(h: crate::Headers) -> crate::Result<Whoami> {
            let _ = h;
            Ok(Whoami("x".into()))
        }
        let built = crate::App::new()
            .provide_dep(needs_headers)
            .build()
            .unwrap();
        let mut ctx = built.task_context();
        let err = ctx.resolve::<Whoami>().await.err().unwrap();
        assert_eq!(err.code(), "JC1003");
        assert_eq!(err.status().as_u16(), 500);
    }

    #[test]
    fn task_context_is_send() {
        // The job worker holds an owned `TaskContext` across `.await`s and forks
        // a fresh one per job; an `on_serve` future must be `Send`. An owned
        // `TaskContext` is `Send` (its body lane is `Send` but `!Sync`) — this
        // invariant is load-bearing for the jobs engine, so lock it in.
        fn assert_send<T: Send>() {}
        assert_send::<TaskContext>();
    }

    #[tokio::test]
    async fn fork_shares_singletons_but_resets_the_resolution_cache() {
        // A side-effect counter on a factory: each forked context must resolve
        // the factory afresh (independent caches), while a `provide()` singleton
        // is the same `Arc` across forks (shared app-level providers).
        static BUILDS: AtomicUsize = AtomicUsize::new(0);
        #[derive(Clone)]
        struct Singleton(&'static str);
        struct Counted;
        async fn build_counted() -> crate::Result<Counted> {
            BUILDS.fetch_add(1, Ordering::SeqCst);
            Ok(Counted)
        }

        BUILDS.store(0, Ordering::SeqCst);
        let built = crate::App::new()
            .provide(Singleton("app"))
            .provide_dep(build_counted)
            .build()
            .unwrap();
        let base = built.task_context();

        // Two forks resolve the SAME singleton Arc (shared providers)...
        let mut a = base.fork();
        let mut b = base.fork();
        let sa = a.resolve::<Singleton>().await.unwrap();
        let sb = b.resolve::<Singleton>().await.unwrap();
        assert!(Arc::ptr_eq(&sa, &sb), "forks share the singleton provider");
        assert_eq!(
            sa.0, "app",
            "the shared singleton carries the provided value"
        );

        // ...but each fork has an independent cache: the factory runs once per
        // fork, not once total, so two forks => two builds.
        let _ = a.resolve::<Counted>().await.unwrap();
        let _ = b.resolve::<Counted>().await.unwrap();
        assert_eq!(
            BUILDS.load(Ordering::SeqCst),
            2,
            "each fork resolves the factory in its own fresh cache"
        );
    }

    #[tokio::test]
    async fn test_app_task_context_honors_overrides() {
        #[derive(Clone)]
        struct Cfg(u32);
        async fn make_cfg() -> crate::Result<Cfg> {
            Ok(Cfg(1)) // real provider value
        }
        let app = crate::App::new()
            .provide_dep(make_cfg)
            .into_test()
            .override_dep(Cfg(99)); // test fake
        let mut ctx = app.task_context();
        let cfg = ctx.resolve::<Cfg>().await.unwrap();
        assert_eq!(cfg.0, 99, "override wins in a task context too");
    }
}