coi_rocket/
lib.rs

1use async_trait::async_trait;
2use coi::{Container, Inject};
3use rocket::{
4    http::Status,
5    outcome::IntoOutcome as _,
6    request::{FromRequest, Outcome},
7    Request, State,
8};
9use std::{marker::PhantomData, sync::Arc};
10
11pub use coi_rocket_derive::inject;
12
13#[doc(hidden)]
14pub trait ContainerKey<T>
15where
16    T: Inject + ?Sized,
17{
18    const KEY: &'static str;
19}
20
/// A resolved value `T` tagged with the marker type `K`.
///
/// `K` is carried only at the type level (through `PhantomData`); it is what selects the
/// container key when this type is used as a request guard.
#[doc(hidden)]
pub struct Injected<T, K>(pub T, pub PhantomData<K>);

impl<T, K> Injected<T, K> {
    /// Wraps `value`, filling in the zero-sized marker field.
    #[doc(hidden)]
    pub fn new(value: T) -> Self {
        let marker = PhantomData::<K>;
        Injected(value, marker)
    }
}
30
31struct ScopedContainer(Container);
32
33#[derive(Debug)]
34pub enum Error {
35    Coi(coi::Error),
36    MissingContainer,
37}
38
39// For every request that needs a container, create a scoped container that lives
40// for the duration of that request.
41#[async_trait]
42impl<'r> FromRequest<'r> for &'r ScopedContainer {
43    type Error = Error;
44
45    async fn from_request(req: &'r Request<'_>) -> Outcome<&'r ScopedContainer, Error> {
46        req.local_cache_async::<Option<ScopedContainer>, _>(async move {
47            let container = req.guard::<&State<Container>>().await.succeeded()?;
48            Some(ScopedContainer(container.scoped()))
49        })
50        .await
51        .as_ref()
52        .or_error((Status::InternalServerError, Error::MissingContainer))
53    }
54}
55
56// For every injected param, just us the local cached scoped container
57#[async_trait]
58impl<'r, T, K> FromRequest<'r> for Injected<Arc<T>, K>
59where
60    T: Inject + ?Sized,
61    K: ContainerKey<T>,
62{
63    type Error = Error;
64
65    async fn from_request(req: &'r Request<'_>) -> Outcome<Injected<Arc<T>, K>, Error> {
66        let container = match req.guard::<&ScopedContainer>().await {
67            Outcome::Success(container) => container,
68            Outcome::Error(f) => return Outcome::Error(f),
69            Outcome::Forward(f) => return Outcome::Forward(f),
70        };
71        container
72            .0
73            .resolve::<T>(<K as ContainerKey<T>>::KEY)
74            .map(Injected::new)
75            .map_err(Error::Coi)
76            .or_error(Status::InternalServerError)
77    }
78}