rocket_dependency_injection/
lib.rs

1use std::{
2    any::type_name,
3    sync::{Arc, RwLock},
4};
5
6#[cfg(feature = "derive")]
7pub use rocket_dependency_injection_derive as derive;
8
9use rocket::{
10    fairing::{Fairing, Info, Kind},
11    Build, Rocket,
12};
13
/// Handle for looking up services while a Rocket instance is still being built.
///
/// Wraps the `Rocket<Build>` so that `unwrap` can consult both its managed state and any
/// registered `ServiceResolver`s.
pub struct ServiceProvider {
    inner: Rocket<Build>,
}
17
18impl From<Rocket<Build>> for ServiceProvider {
19    fn from(value: Rocket<Build>) -> Self {
20        ServiceProvider { inner: value }
21    }
22}
23
24impl ServiceProvider {
25    pub fn unwrap<T>(&self) -> T
26    where
27        T: Clone + Send + Sync + 'static,
28    {
29        let type_name = type_name::<T>();
30        match self.inner.state::<T>() {
31            None => self
32                .inner
33                .state::<Arc<ServiceResolver<_>>>()
34                .map(|resolver| resolver.resolve(&self)),
35            other => other.map(|item| item.clone()),
36        }
37        .expect(format!("Failed to resolve service of type {}", type_name).as_str())
38    }
39}
40
41struct ServiceResolver<TInjectedItem> {
42    injection_function: Box<dyn Fn(&ServiceProvider) -> TInjectedItem + Send + Sync + 'static>,
43    item: RwLock<Option<TInjectedItem>>,
44}
45
46impl<TInjectedItem> ServiceResolver<TInjectedItem>
47where
48    TInjectedItem: Clone + Send + Sync + 'static,
49{
50    pub fn new<
51        TResolutionFunction: Fn(&ServiceProvider) -> TInjectedItem + Send + Sync + 'static,
52    >(
53        injection_function: TResolutionFunction,
54    ) -> Self {
55        Self {
56            injection_function: Box::new(injection_function),
57            item: RwLock::new(None),
58        }
59    }
60
61    pub fn resolve(&self, service_provider: &ServiceProvider) -> TInjectedItem {
62        {
63            if let Some(ref item) = *self.item.read().unwrap() {
64                return item.clone();
65            }
66        }
67
68        let mut guard = self.item.write().unwrap();
69
70        let item = (self.injection_function)(service_provider);
71        *guard = Some(item.clone());
72
73        item
74    }
75}
76
77#[async_trait::async_trait]
78impl<TResolvedItem> Fairing for ServiceResolver<TResolvedItem>
79where
80    TResolvedItem: Clone + Sync + Send + 'static,
81{
82    fn info(&self) -> Info {
83        Info {
84            name: Box::leak(format!("{}_resolver", type_name::<TResolvedItem>()).into_boxed_str()),
85            kind: Kind::Ignite,
86        }
87    }
88
89    async fn on_ignite(&self, rocket: Rocket<Build>) -> Result<Rocket<Build>, Rocket<Build>> {
90        let service_provider: ServiceProvider = rocket.into();
91        let item = (self.injection_function)(&service_provider);
92        Ok(service_provider.inner.manage(item))
93    }
94}
95
/// Types that can construct themselves from other registered services.
///
/// Implementors can be registered with `RocketExtension::add`, which uses `resolve` as the
/// injection function.
pub trait Resolve {
    /// Builds an instance; `service_provider` can be used to look up dependencies via
    /// `ServiceProvider::unwrap`.
    fn resolve(service_provider: &ServiceProvider) -> Self;
}
99
100pub trait RocketExtension {
101    fn add_with<
102        TInjectedItem: Clone + Sync + Send + 'static,
103        TInjectionFunction: Fn(&ServiceProvider) -> TInjectedItem + Send + Sync + 'static,
104    >(
105        self,
106        injection_function: TInjectionFunction,
107    ) -> Self;
108
109    fn add<TResolve: Resolve + Send + Sync + Clone + 'static>(self) -> Self;
110}
111
112impl RocketExtension for Rocket<Build> {
113    fn add_with<
114        TInjectedItem: Clone + Sync + Send + 'static,
115        TInjectionFunction: Fn(&ServiceProvider) -> TInjectedItem + Send + Sync + 'static,
116    >(
117        self,
118        injection_function: TInjectionFunction,
119    ) -> Self {
120        let service_resolver = Arc::new(ServiceResolver::new(injection_function));
121
122        self.attach(service_resolver.clone())
123            .manage(service_resolver)
124    }
125
126    fn add<TResolve: Resolve + Send + Sync + Clone + 'static>(self) -> Self {
127        self.add_with(TResolve::resolve)
128    }
129}