rocket_dependency_injection/
lib.rs1use std::{
2 any::type_name,
3 sync::{Arc, RwLock},
4};
5
6#[cfg(feature = "derive")]
7pub use rocket_dependency_injection_derive as derive;
8
9use rocket::{
10 fairing::{Fairing, Info, Kind},
11 Build, Rocket,
12};
13
14pub struct ServiceProvider {
15 inner: Rocket<Build>,
16}
17
18impl From<Rocket<Build>> for ServiceProvider {
19 fn from(value: Rocket<Build>) -> Self {
20 ServiceProvider { inner: value }
21 }
22}
23
24impl ServiceProvider {
25 pub fn unwrap<T>(&self) -> T
26 where
27 T: Clone + Send + Sync + 'static,
28 {
29 let type_name = type_name::<T>();
30 match self.inner.state::<T>() {
31 None => self
32 .inner
33 .state::<Arc<ServiceResolver<_>>>()
34 .map(|resolver| resolver.resolve(&self)),
35 other => other.map(|item| item.clone()),
36 }
37 .expect(format!("Failed to resolve service of type {}", type_name).as_str())
38 }
39}
40
41struct ServiceResolver<TInjectedItem> {
42 injection_function: Box<dyn Fn(&ServiceProvider) -> TInjectedItem + Send + Sync + 'static>,
43 item: RwLock<Option<TInjectedItem>>,
44}
45
46impl<TInjectedItem> ServiceResolver<TInjectedItem>
47where
48 TInjectedItem: Clone + Send + Sync + 'static,
49{
50 pub fn new<
51 TResolutionFunction: Fn(&ServiceProvider) -> TInjectedItem + Send + Sync + 'static,
52 >(
53 injection_function: TResolutionFunction,
54 ) -> Self {
55 Self {
56 injection_function: Box::new(injection_function),
57 item: RwLock::new(None),
58 }
59 }
60
61 pub fn resolve(&self, service_provider: &ServiceProvider) -> TInjectedItem {
62 {
63 if let Some(ref item) = *self.item.read().unwrap() {
64 return item.clone();
65 }
66 }
67
68 let mut guard = self.item.write().unwrap();
69
70 let item = (self.injection_function)(service_provider);
71 *guard = Some(item.clone());
72
73 item
74 }
75}
76
77#[async_trait::async_trait]
78impl<TResolvedItem> Fairing for ServiceResolver<TResolvedItem>
79where
80 TResolvedItem: Clone + Sync + Send + 'static,
81{
82 fn info(&self) -> Info {
83 Info {
84 name: Box::leak(format!("{}_resolver", type_name::<TResolvedItem>()).into_boxed_str()),
85 kind: Kind::Ignite,
86 }
87 }
88
89 async fn on_ignite(&self, rocket: Rocket<Build>) -> Result<Rocket<Build>, Rocket<Build>> {
90 let service_provider: ServiceProvider = rocket.into();
91 let item = (self.injection_function)(&service_provider);
92 Ok(service_provider.inner.manage(item))
93 }
94}
95
96pub trait Resolve {
97 fn resolve(service_provider: &ServiceProvider) -> Self;
98}
99
100pub trait RocketExtension {
101 fn add_with<
102 TInjectedItem: Clone + Sync + Send + 'static,
103 TInjectionFunction: Fn(&ServiceProvider) -> TInjectedItem + Send + Sync + 'static,
104 >(
105 self,
106 injection_function: TInjectionFunction,
107 ) -> Self;
108
109 fn add<TResolve: Resolve + Send + Sync + Clone + 'static>(self) -> Self;
110}
111
112impl RocketExtension for Rocket<Build> {
113 fn add_with<
114 TInjectedItem: Clone + Sync + Send + 'static,
115 TInjectionFunction: Fn(&ServiceProvider) -> TInjectedItem + Send + Sync + 'static,
116 >(
117 self,
118 injection_function: TInjectionFunction,
119 ) -> Self {
120 let service_resolver = Arc::new(ServiceResolver::new(injection_function));
121
122 self.attach(service_resolver.clone())
123 .manage(service_resolver)
124 }
125
126 fn add<TResolve: Resolve + Send + Sync + Clone + 'static>(self) -> Self {
127 self.add_with(TResolve::resolve)
128 }
129}