1#![doc(html_logo_url = "https://avatars.githubusercontent.com/u/43955412")]
2use rocket::{
7 fairing::{self, Fairing, Info, Kind},
8 http::{self, uri::Origin, Method, Status},
9 request::{self, FromRequest},
10 route, Build, Data, Request, Rocket, Route,
11};
12use sentinel_core::EntryBuilder;
13use std::sync::Mutex;
14
/// Thread-safe boxed error used as the request guard's error type.
///
/// The `'static` bound is spelled out here; it is the implicit default for `Box<dyn …>`.
pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
16
17pub type Extractor = fn(&Request<'_>) -> String;
20
21pub type Fallback<R> = fn(&Request<'_>, sentinel_core::Error) -> R;
23
24fn default_extractor(req: &Request<'_>) -> String {
25 req.uri().path().to_string()
26}
27
28fn default_fallback_for_guard(
29 _request: &Request<'_>,
30 err: sentinel_core::Error,
31) -> request::Outcome<SentinelGuard, BoxError> {
32 request::Outcome::Failure((Status::TooManyRequests, err.into()))
33}
34
35pub type SentinelConfigForGuard = SentinelConfig<request::Outcome<SentinelGuard, BoxError>>;
36pub type SentinelConfigForFairing = SentinelConfig<()>;
37
38pub struct SentinelConfig<R> {
48 pub extractor: Option<Extractor>,
49 pub fallback: Option<Fallback<R>>,
50}
51
52impl<R> SentinelConfig<R> {
53 pub fn with_extractor(mut self, extractor: Extractor) -> Self {
54 self.extractor = Some(extractor);
55 self
56 }
57
58 pub fn with_fallback(mut self, fallback: Fallback<R>) -> Self {
59 self.fallback = Some(fallback);
60 self
61 }
62}
63
64impl<R> Clone for SentinelConfig<R> {
67 fn clone(&self) -> Self {
68 Self {
69 extractor: self.extractor.clone(),
70 fallback: self.fallback.clone(),
71 }
72 }
73}
74
75impl<R> Default for SentinelConfig<R> {
78 fn default() -> Self {
79 Self {
80 extractor: None,
81 fallback: None,
82 }
83 }
84}
85
86#[derive(Debug)]
95pub struct SentinelGuard;
96
97#[rocket::async_trait]
98impl<'r> FromRequest<'r> for SentinelGuard {
99 type Error = BoxError;
100
101 async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
102 let empty_config = SentinelConfig::default();
103 let config = req
104 .rocket()
105 .state::<SentinelConfig<request::Outcome<SentinelGuard, BoxError>>>()
107 .unwrap_or(&empty_config);
108 let extractor = config.extractor.unwrap_or(default_extractor);
109 let fallback = config.fallback.unwrap_or(default_fallback_for_guard);
110
111 let resource = extractor(req);
112 let entry_builder = EntryBuilder::new(resource)
113 .with_traffic_type(sentinel_core::base::TrafficType::Inbound);
114
115 match entry_builder.build() {
116 Ok(entry) => {
117 entry.exit();
118 request::Outcome::Success(SentinelGuard {})
119 }
120 Err(err) => fallback(req, err),
121 }
122 }
123}
124
/// State managed by [`SentinelFairing`] on ignite.
#[derive(Debug)]
pub struct SentinelFairingState {
    /// Last "blocked by sentinel" message written by the fairing.
    ///
    /// It is overwritten by each blocked request that is rerouted.
    pub msg: Mutex<String>,
    /// Path that blocked requests are rerouted to, where the fairing's handler is mounted.
    pub uri: String,
}

impl SentinelFairingState {
    /// Creates state for the reroute path `uri`, with an empty message.
    pub fn new(uri: String) -> Self {
        let msg = Mutex::default();
        Self { uri, msg }
    }
}
146
147type FairingHandler = for<'r> fn(&'r Request<'_>, Data<'r>) -> route::Outcome<'r>;
148
149#[derive(Clone, Default)]
150pub struct SentinelFairingHandler(Option<FairingHandler>);
151
152impl SentinelFairingHandler {
153 pub fn new(h: FairingHandler) -> Self {
154 Self(Some(h))
155 }
156}
157
158#[rocket::async_trait]
159impl route::Handler for SentinelFairingHandler {
160 async fn handle<'r>(&self, req: &'r Request<'_>, data: Data<'r>) -> route::Outcome<'r> {
161 fn default_handler<'r>(req: &'r Request<'_>, _data: Data<'r>) -> route::Outcome<'r> {
162 match req.rocket().state::<SentinelFairingState>() {
163 Some(_) => route::Outcome::Failure(Status::TooManyRequests),
164 None => route::Outcome::Failure(Status::InternalServerError),
165 }
166 }
167
168 let h = self.0.unwrap_or(default_handler);
169 h(req, data)
170 }
171}
172
173impl Into<Vec<Route>> for SentinelFairingHandler {
174 fn into(self) -> Vec<Route> {
175 vec![Route::new(Method::Get, "/", self)]
176 }
177}
178
179#[derive(Default)]
184pub struct SentinelFairing {
185 uri: String,
188 handler: SentinelFairingHandler,
190 config: SentinelConfig<()>,
194}
195
196impl SentinelFairing {
197 pub fn new(uri: &'static str) -> Result<Self, http::uri::Error> {
198 Ok(SentinelFairing::default().with_uri(uri)?)
199 }
200
201 pub fn with_extractor(mut self, extractor: Extractor) -> Self {
202 self.config = self.config.with_extractor(extractor);
203 self
204 }
205
206 pub fn with_fallback(mut self, fallback: Fallback<()>) -> Self {
207 self.config = self.config.with_fallback(fallback);
208 self
209 }
210
211 pub fn with_handler(mut self, h: FairingHandler) -> Self {
212 self.handler = SentinelFairingHandler::new(h);
213 self
214 }
215
216 pub fn with_uri(mut self, uri: &'static str) -> Result<Self, http::uri::Error> {
217 let origin = Origin::parse(uri)?;
218 self.uri = origin.path().to_string();
219 Ok(self)
220 }
221}
222
223#[rocket::async_trait]
224impl Fairing for SentinelFairing {
225 fn info(&self) -> Info {
226 Info {
227 name: "Sentinel Fairing",
228 kind: Kind::Ignite | Kind::Request,
229 }
230 }
231
232 async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
233 let handler = self.handler.clone();
234 Ok(rocket
235 .manage(SentinelFairingState::new(self.uri.clone()))
236 .mount(self.uri.clone(), handler))
237 }
238
239 async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
240 let empty_config = SentinelConfig::default();
241 let config = req
242 .rocket()
243 .state::<SentinelConfig<()>>()
244 .unwrap_or(&empty_config);
245 let extractor = self
246 .config
247 .extractor
248 .unwrap_or(config.extractor.unwrap_or(default_extractor));
249 let fallback = self.config.fallback.or(config.fallback);
250
251 let resource = extractor(&req);
252 let entry_builder = EntryBuilder::new(resource)
253 .with_traffic_type(sentinel_core::base::TrafficType::Inbound);
254
255 match entry_builder.build() {
256 Ok(entry) => {
257 entry.exit();
258 }
259 Err(err) => {
260 match fallback {
261 Some(fallback) => fallback(req, err),
262 None => {
263 if let Some(state) = req.rocket().state::<SentinelFairingState>() {
264 if let Ok(mut msg) = state.msg.lock() {
265 *msg = format!(
266 "Request to {:?} blocked by sentinel: {:?}",
267 req.uri().path(),
268 err
269 );
270 }
271 req.set_uri(Origin::parse(&state.uri).unwrap());
273 }
274 }
275 }
276 }
277 };
278 }
279}