use crate::error::{Error, Result};
use crate::extract::{FromRequest, RequestCtx};
use std::any::{Any, TypeId, type_name};
use std::collections::HashMap;
use std::future::Future;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
/// Type-erased, shareable dependency value. The concrete `Arc<T>` is recovered
/// with `Arc::downcast`.
pub(crate) type AnyArc = Arc<dyn Any + Send + Sync + 'static>;

/// Boxed, `Send` future returned by a provider. It borrows the context for `'a`
/// and yields either a type-erased value or an error.
pub(crate) type ProviderFut<'a> =
    Pin<Box<dyn Future<Output = Result<AnyArc>> + Send + 'a>>;

/// Shared provider callback. The higher-ranked `for<'a>` bound lets a single
/// provider be invoked with a `&mut RequestCtx` of any lifetime, returning a
/// future tied to that same borrow.
pub(crate) type ProviderFn =
    Arc<dyn for<'a> Fn(&'a mut RequestCtx) -> ProviderFut<'a> + Send + Sync>;
/// Provider registry keyed by the dependency's `TypeId`.
///
/// Each key maps either to a ready-made shared value or to a factory that
/// builds one. The `insert_*` and `merge_from` methods remove the other map's
/// entry whenever they add a key, keeping the two maps disjoint.
#[derive(Clone, Default)]
pub struct DepEnv {
    /// Pre-built values; resolution hands out clones of the stored `Arc`.
    pub(crate) singletons: HashMap<TypeId, AnyArc>,
    /// Factories invoked on demand when a type is resolved.
    pub(crate) factories: HashMap<TypeId, ProviderFn>,
}
impl DepEnv {
pub(crate) fn insert_value<T: Send + Sync + 'static>(&mut self, value: T) {
let id = TypeId::of::<T>();
self.singletons.insert(id, Arc::new(value));
self.factories.remove(&id);
}
pub(crate) fn insert_factory<F, Args, T>(&mut self, factory: F)
where
F: DepFactory<Args, T>,
T: Send + Sync + 'static,
{
let id = TypeId::of::<T>();
self.factories.insert(id, factory.into_provider());
self.singletons.remove(&id);
}
pub(crate) fn merge_from(&mut self, inner: &DepEnv) {
for (k, v) in &inner.singletons {
self.singletons.insert(*k, v.clone());
self.factories.remove(k);
}
for (k, f) in &inner.factories {
self.factories.insert(*k, f.clone());
self.singletons.remove(k);
}
}
}
/// Resolution state for one request or task, layered over a shared `DepEnv`.
pub struct DepResolver {
    /// Registered providers; the `Arc` lets several resolvers share them.
    pub(crate) env: Arc<DepEnv>,
    /// Values consulted before `env` for the same `TypeId`.
    pub(crate) overrides: Arc<HashMap<TypeId, AnyArc>>,
    /// Values this resolver has already produced or looked up (memoization).
    pub(crate) cache: HashMap<TypeId, AnyArc>,
    /// Number of factory invocations currently nested; bounded by `MAX_RESOLVE_DEPTH`.
    pub(crate) depth: u8,
}
impl DepResolver {
pub(crate) fn new(env: Arc<DepEnv>, overrides: Arc<HashMap<TypeId, AnyArc>>) -> Self {
Self {
env,
overrides,
cache: HashMap::new(),
depth: 0,
}
}
}
const MAX_RESOLVE_DEPTH: u8 = 32;
impl RequestCtx {
pub async fn resolve<T: Send + Sync + 'static>(&mut self) -> Result<Arc<T>> {
let id = TypeId::of::<T>();
if let Some(v) = self.deps.cache.get(&id) {
return downcast::<T>(v.clone());
}
if let Some(v) = self.deps.overrides.get(&id).cloned() {
self.deps.cache.insert(id, v.clone());
return downcast::<T>(v);
}
if let Some(v) = self.deps.env.singletons.get(&id).cloned() {
self.deps.cache.insert(id, v.clone());
return downcast::<T>(v);
}
let factory = match self.deps.env.factories.get(&id) {
Some(f) => f.clone(),
None => return Err(Error::missing_dependency(type_name::<T>())),
};
self.deps.depth += 1;
if self.deps.depth > MAX_RESOLVE_DEPTH {
self.deps.depth -= 1;
return Err(Error::dependency_cycle());
}
let produced = (*factory)(self).await;
self.deps.depth -= 1;
let v = produced?;
self.deps.cache.insert(id, v.clone());
downcast::<T>(v)
}
}
fn downcast<T: Send + Sync + 'static>(v: AnyArc) -> Result<Arc<T>> {
v.downcast::<T>()
.map_err(|_| Error::internal("dependency type mismatch (provider/consumer disagree)"))
}
/// Dependency-resolution context for work that runs outside an HTTP request.
///
/// Internally wraps a `RequestCtx` built from a placeholder request head (URI
/// `/`, empty body) with `is_task` set to `true`.
pub struct TaskContext(RequestCtx);

impl TaskContext {
    /// Builds a task context that resolves through `deps`.
    pub(crate) fn new(deps: DepResolver) -> Self {
        // Only the head is kept; the `()` body is discarded.
        let (parts, ()) = http::Request::builder()
            .uri("/")
            .body(())
            .expect("static request head always builds")
            .into_parts();
        let mut inner = RequestCtx::new(parts, bytes::Bytes::new(), deps);
        // NOTE(review): presumably lets request-only extractors detect that no
        // real request exists — confirm against the `FromRequest` impls.
        inner.is_task = true;
        Self(inner)
    }

    /// Resolves `T` the same way `RequestCtx::resolve` does, memoizing within
    /// this task context.
    pub async fn resolve<T: Send + Sync + 'static>(&mut self) -> Result<Arc<T>> {
        let Self(inner) = self;
        inner.resolve::<T>().await
    }

    /// Creates a separate task context that shares this one's providers and
    /// overrides but starts with an empty resolution cache.
    pub fn fork(&self) -> TaskContext {
        let deps = &self.0.deps;
        let resolver = DepResolver::new(Arc::clone(&deps.env), Arc::clone(&deps.overrides));
        TaskContext::new(resolver)
    }
}
/// Handle to a resolved dependency of type `T`, backed by a shared `Arc`.
pub struct Dep<T: ?Sized>(pub(crate) Arc<T>);

impl<T: ?Sized> Deref for Dep<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}

// Hand-written rather than derived so that cloning only bumps the reference
// count and does not require `T: Clone`.
impl<T: ?Sized> Clone for Dep<T> {
    fn clone(&self) -> Self {
        let Dep(shared) = self;
        Dep(Arc::clone(shared))
    }
}
impl<T: Send + Sync + 'static> FromRequest for Dep<T> {
async fn from_request(ctx: &mut RequestCtx) -> Result<Self> {
ctx.resolve::<T>().await.map(Dep)
}
}
/// Converts a provider callable into a type-erased, shareable `ProviderFn`.
///
/// `Args` is the tuple of the callable's argument types and `T` is the value
/// it produces. The blanket impls generated by `impl_dep_factory!` cover
/// callables of zero through eight arguments that return a future of
/// `Result<T>`, where every argument type implements `FromRequest`.
pub trait DepFactory<Args, T>: Send + Sync + 'static {
/// Erases `self` into a provider callback.
fn into_provider(self) -> ProviderFn;
}
// Generates a `DepFactory` impl for callables taking the listed argument types.
// Each invocation of the generated provider:
//   1. clones the factory, because the provider is a shared `Fn` while the
//      `async move` block needs its own owned copy;
//   2. extracts each argument from the context, in declaration order, via
//      `FromRequest::from_request`, returning early on the first error;
//   3. awaits the factory and boxes the produced value as an `AnyArc`.
macro_rules! impl_dep_factory {
($($A:ident),*) => {
impl<F, Fut, T, $($A,)*> DepFactory<($($A,)*), T> for F
where
F: Fn($($A),*) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<T>> + Send,
T: Send + Sync + 'static,
$($A: FromRequest + 'static,)*
{
fn into_provider(self) -> ProviderFn {
// When the macro is invoked with no arguments, `ctx` is never referenced;
// this allowance silences the resulting unused-variable lint.
#[allow(unused_variables)]
Arc::new(move |ctx: &mut RequestCtx| {
let f = self.clone();
Box::pin(async move {
// Argument bindings reuse the type-parameter names (`A1`, `A2`, ...),
// which are not snake_case.
#[allow(non_snake_case, unused_variables)]
{
$(let $A = <$A as FromRequest>::from_request(ctx).await?;)*
let value = f($($A),*).await?;
Ok(Arc::new(value) as AnyArc)
}
})
})
}
}
};
}
// Blanket impls for provider callables with 0 through 8 arguments.
impl_dep_factory!();
impl_dep_factory!(A1);
impl_dep_factory!(A1, A2);
impl_dep_factory!(A1, A2, A3);
impl_dep_factory!(A1, A2, A3, A4);
impl_dep_factory!(A1, A2, A3, A4, A5);
impl_dep_factory!(A1, A2, A3, A4, A5, A6);
impl_dep_factory!(A1, A2, A3, A4, A5, A6, A7);
impl_dep_factory!(A1, A2, A3, A4, A5, A6, A7, A8);
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a request context (URI `/`, empty body) resolving against `env`
    /// with no overrides.
    pub(crate) fn test_ctx(env: DepEnv) -> RequestCtx {
        test_ctx_with_overrides(env, HashMap::new())
    }

    /// Builds a request context (URI `/`, empty body) resolving against `env`,
    /// with `overrides` taking precedence.
    fn test_ctx_with_overrides(env: DepEnv, overrides: HashMap<TypeId, AnyArc>) -> RequestCtx {
        let req = http::Request::builder().uri("/").body(()).unwrap();
        let (parts, ()) = req.into_parts();
        RequestCtx::new(
            parts,
            Bytes::new(),
            DepResolver::new(Arc::new(env), Arc::new(overrides)),
        )
    }

    struct Config {
        name: &'static str,
    }

    #[tokio::test]
    async fn value_provider_resolves_and_derefs() {
        let mut env = DepEnv::default();
        env.insert_value(Config { name: "prod" });
        let mut ctx = test_ctx(env);
        let cfg: Dep<Config> = Dep::from_request(&mut ctx).await.unwrap();
        assert_eq!(cfg.name, "prod");
    }

    #[tokio::test]
    async fn missing_provider_is_jc1001() {
        let mut ctx = test_ctx(DepEnv::default());
        let err = Dep::<Config>::from_request(&mut ctx).await.err().unwrap();
        assert_eq!(err.code(), "JC1001");
        assert!(err.message().contains("Config"));
    }

    #[tokio::test]
    async fn same_request_yields_same_arc() {
        let mut env = DepEnv::default();
        env.insert_value(Config { name: "x" });
        let mut ctx = test_ctx(env);
        let a = ctx.resolve::<Config>().await.unwrap();
        let b = ctx.resolve::<Config>().await.unwrap();
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[derive(Clone)]
    struct Db {
        url: &'static str,
    }

    struct Session {
        token: String,
    }

    struct User {
        name: String,
    }

    async fn make_session() -> crate::Result<Session> {
        Ok(Session {
            token: "t-1".into(),
        })
    }

    async fn current_user(session: Dep<Session>, db: Dep<Db>) -> crate::Result<User> {
        Ok(User {
            name: format!("{}@{}", session.token, db.url),
        })
    }

    /// `Db` registered as a value; `Session` and `User` registered as factories.
    fn nested_env() -> DepEnv {
        let mut env = DepEnv::default();
        env.insert_value(Db { url: "pg://prod" });
        env.insert_factory(make_session);
        env.insert_factory(current_user);
        env
    }

    #[tokio::test]
    async fn factories_nest_and_resolve_async() {
        let mut ctx = test_ctx(nested_env());
        let user = ctx.resolve::<User>().await.unwrap();
        assert_eq!(user.name, "t-1@pg://prod");
    }

    #[tokio::test]
    async fn nested_deps_are_memoized_once_per_request() {
        static LOCAL_RESOLVES: AtomicUsize = AtomicUsize::new(0);
        async fn local_session() -> crate::Result<Session> {
            LOCAL_RESOLVES.fetch_add(1, Ordering::SeqCst);
            Ok(Session {
                token: "t-1".into(),
            })
        }
        fn local_env() -> DepEnv {
            let mut env = DepEnv::default();
            env.insert_value(Db { url: "pg://prod" });
            env.insert_factory(local_session);
            env.insert_factory(current_user);
            env
        }
        LOCAL_RESOLVES.store(0, Ordering::SeqCst);
        let mut ctx = test_ctx(local_env());
        let _s = ctx.resolve::<Session>().await.unwrap();
        let _u = ctx.resolve::<User>().await.unwrap();
        assert_eq!(
            LOCAL_RESOLVES.load(Ordering::SeqCst),
            1,
            "memoized within request"
        );
        let mut ctx2 = test_ctx(local_env());
        let _u2 = ctx2.resolve::<User>().await.unwrap();
        assert_eq!(LOCAL_RESOLVES.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn self_referential_factory_hits_cycle_guard() {
        struct Loopy;
        async fn loopy(_again: Dep<Loopy>) -> crate::Result<Loopy> {
            Ok(Loopy)
        }
        let mut env = DepEnv::default();
        env.insert_factory(loopy);
        let mut ctx = test_ctx(env);
        let err = ctx.resolve::<Loopy>().await.err().unwrap();
        assert_eq!(err.code(), "JC1002");
    }

    #[tokio::test]
    async fn overrides_shadow_both_values_and_factories() {
        // `nested_env` registers `Db` as a value and `Session` via a factory;
        // the overrides replace one of each.
        let env = nested_env();
        let mut overrides: HashMap<TypeId, AnyArc> = HashMap::new();
        overrides.insert(
            TypeId::of::<Db>(),
            Arc::new(Db {
                url: "sqlite::memory:",
            }),
        );
        overrides.insert(
            TypeId::of::<Session>(),
            Arc::new(Session {
                token: "fake".into(),
            }),
        );
        let mut ctx = test_ctx_with_overrides(env, overrides);
        let user = ctx.resolve::<User>().await.unwrap();
        assert_eq!(user.name, "fake@sqlite::memory:");
    }

    #[tokio::test]
    async fn deps_resolve_without_a_request() {
        #[derive(Clone)]
        struct Cfg(u32);
        async fn make_cfg() -> crate::Result<Cfg> {
            Ok(Cfg(7))
        }
        let built = crate::App::new().provide_dep(make_cfg).build().unwrap();
        let mut ctx = built.task_context();
        let cfg = ctx.resolve::<Cfg>().await.unwrap();
        assert_eq!(cfg.0, 7);
    }

    #[tokio::test]
    async fn task_resolution_memoizes_and_honors_singletons() {
        static BUILDS: AtomicUsize = AtomicUsize::new(0);
        #[derive(Clone)]
        struct Singleton(&'static str);
        struct Counted;
        async fn build_counted() -> crate::Result<Counted> {
            BUILDS.fetch_add(1, Ordering::SeqCst);
            Ok(Counted)
        }
        BUILDS.store(0, Ordering::SeqCst);
        let built = crate::App::new()
            .provide(Singleton("app"))
            .provide_dep(build_counted)
            .build()
            .unwrap();
        let mut ctx = built.task_context();
        let s = ctx.resolve::<Singleton>().await.unwrap();
        assert_eq!(s.0, "app");
        let a = ctx.resolve::<Counted>().await.unwrap();
        let b = ctx.resolve::<Counted>().await.unwrap();
        assert!(Arc::ptr_eq(&a, &b));
        assert_eq!(BUILDS.load(Ordering::SeqCst), 1, "memoized within the task");
        let mut ctx2 = built.task_context();
        let _ = ctx2.resolve::<Counted>().await.unwrap();
        assert_eq!(BUILDS.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn http_extractors_reject_task_context_with_jc1003() {
        #[derive(Clone)]
        struct Whoami(#[allow(dead_code)] String);
        async fn needs_headers(h: crate::Headers) -> crate::Result<Whoami> {
            let _ = h;
            Ok(Whoami("x".into()))
        }
        let built = crate::App::new()
            .provide_dep(needs_headers)
            .build()
            .unwrap();
        let mut ctx = built.task_context();
        let err = ctx.resolve::<Whoami>().await.err().unwrap();
        assert_eq!(err.code(), "JC1003");
        assert_eq!(err.status().as_u16(), 500);
    }

    #[test]
    fn task_context_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<TaskContext>();
    }

    #[tokio::test]
    async fn fork_shares_singletons_but_resets_the_resolution_cache() {
        static BUILDS: AtomicUsize = AtomicUsize::new(0);
        #[derive(Clone)]
        struct Singleton(&'static str);
        struct Counted;
        async fn build_counted() -> crate::Result<Counted> {
            BUILDS.fetch_add(1, Ordering::SeqCst);
            Ok(Counted)
        }
        BUILDS.store(0, Ordering::SeqCst);
        let built = crate::App::new()
            .provide(Singleton("app"))
            .provide_dep(build_counted)
            .build()
            .unwrap();
        let base = built.task_context();
        let mut a = base.fork();
        let mut b = base.fork();
        let sa = a.resolve::<Singleton>().await.unwrap();
        let sb = b.resolve::<Singleton>().await.unwrap();
        assert!(Arc::ptr_eq(&sa, &sb), "forks share the singleton provider");
        assert_eq!(
            sa.0, "app",
            "the shared singleton carries the provided value"
        );
        let _ = a.resolve::<Counted>().await.unwrap();
        let _ = b.resolve::<Counted>().await.unwrap();
        assert_eq!(
            BUILDS.load(Ordering::SeqCst),
            2,
            "each fork resolves the factory in its own fresh cache"
        );
    }

    #[tokio::test]
    async fn test_app_task_context_honors_overrides() {
        #[derive(Clone)]
        struct Cfg(u32);
        async fn make_cfg() -> crate::Result<Cfg> {
            Ok(Cfg(1))
        }
        let app = crate::App::new()
            .provide_dep(make_cfg)
            .into_test()
            .override_dep(Cfg(99));
        let mut ctx = app.task_context();
        let cfg = ctx.resolve::<Cfg>().await.unwrap();
        assert_eq!(cfg.0, 99, "override wins in a task context too");
    }
}