1use async_trait::async_trait;
2use coi::{Container, Inject};
3use rocket::{
4 http::Status,
5 outcome::IntoOutcome as _,
6 request::{FromRequest, Outcome},
7 Request, State,
8};
9use std::{marker::PhantomData, sync::Arc};
10
11pub use coi_rocket_derive::inject;
12
13#[doc(hidden)]
14pub trait ContainerKey<T>
15where
16 T: Inject + ?Sized,
17{
18 const KEY: &'static str;
19}
20
/// Wrapper around a value resolved from the request-scoped container.
///
/// The second type parameter `K` carries no data. In the `FromRequest` impl it must implement
/// `ContainerKey<T>`, which selects the key the value is resolved under.
#[doc(hidden)]
pub struct Injected<T, K>(
    /// The resolved value.
    pub T,
    /// Zero-sized, type-level marker for the key type `K`.
    pub PhantomData<K>,
);

impl<T, K> Injected<T, K> {
    /// Wraps `injected`, pairing it with an empty `PhantomData` marker for `K`.
    #[doc(hidden)]
    pub fn new(injected: T) -> Self {
        Injected(injected, PhantomData::default())
    }
}
30
31struct ScopedContainer(Container);
32
33#[derive(Debug)]
34pub enum Error {
35 Coi(coi::Error),
36 MissingContainer,
37}
38
39#[async_trait]
42impl<'r> FromRequest<'r> for &'r ScopedContainer {
43 type Error = Error;
44
45 async fn from_request(req: &'r Request<'_>) -> Outcome<&'r ScopedContainer, Error> {
46 req.local_cache_async::<Option<ScopedContainer>, _>(async move {
47 let container = req.guard::<&State<Container>>().await.succeeded()?;
48 Some(ScopedContainer(container.scoped()))
49 })
50 .await
51 .as_ref()
52 .or_error((Status::InternalServerError, Error::MissingContainer))
53 }
54}
55
56#[async_trait]
58impl<'r, T, K> FromRequest<'r> for Injected<Arc<T>, K>
59where
60 T: Inject + ?Sized,
61 K: ContainerKey<T>,
62{
63 type Error = Error;
64
65 async fn from_request(req: &'r Request<'_>) -> Outcome<Injected<Arc<T>, K>, Error> {
66 let container = match req.guard::<&ScopedContainer>().await {
67 Outcome::Success(container) => container,
68 Outcome::Error(f) => return Outcome::Error(f),
69 Outcome::Forward(f) => return Outcome::Forward(f),
70 };
71 container
72 .0
73 .resolve::<T>(<K as ContainerKey<T>>::KEY)
74 .map(Injected::new)
75 .map_err(Error::Coi)
76 .or_error(Status::InternalServerError)
77 }
78}