sentinel_rocket/
lib.rs

1#![doc(html_logo_url = "https://avatars.githubusercontent.com/u/43955412")]
//! This crate provides the [sentinel](https://docs.rs/sentinel-core) middleware for [Rocket](https://docs.rs/rocket).
3//! See [examples](https://github.com/sentinel-group/sentinel-rust/tree/main/middleware) for help.
4//!
5
6use rocket::{
7    fairing::{self, Fairing, Info, Kind},
8    http::{self, uri::Origin, Method, Status},
9    request::{self, FromRequest},
10    route, Build, Data, Request, Rocket, Route,
11};
12use sentinel_core::EntryBuilder;
13use std::sync::Mutex;
14
15pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
16
/// Extracts a Sentinel resource name from an incoming request.
/// When no extractor is configured, the path of the request URI is used as the resource name.
19pub type Extractor = fn(&Request<'_>) -> String;
20
21/// The fallback function when service is rejected by sentinel.
22pub type Fallback<R> = fn(&Request<'_>, sentinel_core::Error) -> R;
23
24fn default_extractor(req: &Request<'_>) -> String {
25    req.uri().path().to_string()
26}
27
28fn default_fallback_for_guard(
29    _request: &Request<'_>,
30    err: sentinel_core::Error,
31) -> request::Outcome<SentinelGuard, BoxError> {
32    request::Outcome::Failure((Status::TooManyRequests, err.into()))
33}
34
35pub type SentinelConfigForGuard = SentinelConfig<request::Outcome<SentinelGuard, BoxError>>;
36pub type SentinelConfigForFairing = SentinelConfig<()>;
37
38/// When using [`SentinelGuard`](SentinelGuard), we can only use [managed state][managed_state]
39/// to configure the [`SentinelGuard`](SentinelGuard). That is
40/// ```rust
41/// rocket::build().manage(SentinelConfig { ... });
42/// ```
43/// For [SentinelFairing](SentinelFairing), the configuration in [managed state][managed_state] is
44/// with lower priority than the `SentinelConfig` in it.
45///
46/// [managed_state]: https://rocket.rs/v0.5-rc/guide/state/#managed-state
47pub struct SentinelConfig<R> {
48    pub extractor: Option<Extractor>,
49    pub fallback: Option<Fallback<R>>,
50}
51
52impl<R> SentinelConfig<R> {
53    pub fn with_extractor(mut self, extractor: Extractor) -> Self {
54        self.extractor = Some(extractor);
55        self
56    }
57
58    pub fn with_fallback(mut self, fallback: Fallback<R>) -> Self {
59        self.fallback = Some(fallback);
60        self
61    }
62}
63
64// rustc cannot derive `Clone` trait for function pointers correctly,
65// implement it by hands
66impl<R> Clone for SentinelConfig<R> {
67    fn clone(&self) -> Self {
68        Self {
69            extractor: self.extractor.clone(),
70            fallback: self.fallback.clone(),
71        }
72    }
73}
74
75// rustc cannot derive `Default` trait for function pointers correctly,
76// implement it by hands
77impl<R> Default for SentinelConfig<R> {
78    fn default() -> Self {
79        Self {
80            extractor: None,
81            fallback: None,
82        }
83    }
84}
85
/// The Rocket request guard, which is the recommended way in [Rocket documentation](https://rocket.rs/v0.5-rc/guide/requests/#request-guards).
/// To use this guard, simply add it to the arguments of a handler. By default, it extracts [the path in the Request::uri()](https://docs.rs/rocket/0.5.0-rc.2/rocket/struct.Request.html#method.uri) as the sentinel resource name.
/// Blocked requests return the [status 429](https://docs.rs/rocket/0.5.0-rc.2/rocket/http/struct.Status.html#associatedconstant.TooManyRequests).
/// ```rust,ignore
/// #[get("/use_sentinel")]
/// fn use_sentinel(_sentinel: SentinelGuard) { /* .. */ }
/// ```
/// We can use [`SentinelConfig`](SentinelConfig) to configure the guard.
///
/// The guard itself carries no data: the Sentinel entry is exited as soon as
/// it has been built, before the guard value is handed to the handler.
#[derive(Debug)]
pub struct SentinelGuard;
96
97#[rocket::async_trait]
98impl<'r> FromRequest<'r> for SentinelGuard {
99    type Error = BoxError;
100
101    async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
102        let empty_config = SentinelConfig::default();
103        let config = req
104            .rocket()
105            // The type `R` in `SentinelConfig<R>` here is the same as `default_fallback_for_guard`
106            .state::<SentinelConfig<request::Outcome<SentinelGuard, BoxError>>>()
107            .unwrap_or(&empty_config);
108        let extractor = config.extractor.unwrap_or(default_extractor);
109        let fallback = config.fallback.unwrap_or(default_fallback_for_guard);
110
111        let resource = extractor(req);
112        let entry_builder = EntryBuilder::new(resource)
113            .with_traffic_type(sentinel_core::base::TrafficType::Inbound);
114
115        match entry_builder.build() {
116            Ok(entry) => {
117                entry.exit();
118                request::Outcome::Success(SentinelGuard {})
119            }
120            Err(err) => fallback(req, err),
121        }
122    }
123}
124
/// The [managed state][managed_state] read by the handler that [SentinelFairing](SentinelFairing) mounts.
///
/// [managed_state]: https://rocket.rs/v0.5-rc/guide/state/#managed-state
#[derive(Debug)]
pub struct SentinelFairingState {
    /// Description of a blocked request; starts out empty.
    pub msg: Mutex<String>,
    /// The URI blocked requests are forwarded to. It is kept in Rocket's managed
    /// state because `Fairing::on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>)`
    /// does not relate the lifetime of `&self` to the lifetime of the `Request`.
    /// See the source code of [Fairing](https://docs.rs/rocket/0.5.0-rc.2/rocket/fairing/trait.Fairing.html) for details.
    pub uri: String,
}

impl SentinelFairingState {
    /// Creates a state for the forward target `uri` with an empty message.
    pub fn new(uri: String) -> Self {
        let msg = Mutex::default();
        Self { msg, uri }
    }
}
146
147type FairingHandler = for<'r> fn(&'r Request<'_>, Data<'r>) -> route::Outcome<'r>;
148
149#[derive(Clone, Default)]
150pub struct SentinelFairingHandler(Option<FairingHandler>);
151
152impl SentinelFairingHandler {
153    pub fn new(h: FairingHandler) -> Self {
154        Self(Some(h))
155    }
156}
157
158#[rocket::async_trait]
159impl route::Handler for SentinelFairingHandler {
160    async fn handle<'r>(&self, req: &'r Request<'_>, data: Data<'r>) -> route::Outcome<'r> {
161        fn default_handler<'r>(req: &'r Request<'_>, _data: Data<'r>) -> route::Outcome<'r> {
162            match req.rocket().state::<SentinelFairingState>() {
163                Some(_) => route::Outcome::Failure(Status::TooManyRequests),
164                None => route::Outcome::Failure(Status::InternalServerError),
165            }
166        }
167
168        let h = self.0.unwrap_or(default_handler);
169        h(req, data)
170    }
171}
172
173impl Into<Vec<Route>> for SentinelFairingHandler {
174    fn into(self) -> Vec<Route> {
175        vec![Route::new(Method::Get, "/", self)]
176    }
177}
178
179/// The Rocket Fairing. The [SentinelConfig](SentinelConfig) in
180/// SentinelFairing is with higher priority than the one in global [managed state][managed_state].
181///
182/// [managed_state]: https://rocket.rs/v0.5-rc/guide/state/#managed-state
183#[derive(Default)]
184pub struct SentinelFairing {
185    /// the forwarded page when blocked by sentinel,
186    /// which will be handled by the `handler`
187    uri: String,
188    /// a lightweight handler, which handles all the requests blocked by Sentinel.
189    handler: SentinelFairingHandler,
190    /// config for `SentinelFairing` itself, which is with higher priority than the `SentinelConfig` in global [managed state][managed_state].
191    ///
192    /// [managed_state]: https://rocket.rs/v0.5-rc/guide/state/#managed-state
193    config: SentinelConfig<()>,
194}
195
196impl SentinelFairing {
197    pub fn new(uri: &'static str) -> Result<Self, http::uri::Error> {
198        Ok(SentinelFairing::default().with_uri(uri)?)
199    }
200
201    pub fn with_extractor(mut self, extractor: Extractor) -> Self {
202        self.config = self.config.with_extractor(extractor);
203        self
204    }
205
206    pub fn with_fallback(mut self, fallback: Fallback<()>) -> Self {
207        self.config = self.config.with_fallback(fallback);
208        self
209    }
210
211    pub fn with_handler(mut self, h: FairingHandler) -> Self {
212        self.handler = SentinelFairingHandler::new(h);
213        self
214    }
215
216    pub fn with_uri(mut self, uri: &'static str) -> Result<Self, http::uri::Error> {
217        let origin = Origin::parse(uri)?;
218        self.uri = origin.path().to_string();
219        Ok(self)
220    }
221}
222
223#[rocket::async_trait]
224impl Fairing for SentinelFairing {
225    fn info(&self) -> Info {
226        Info {
227            name: "Sentinel Fairing",
228            kind: Kind::Ignite | Kind::Request,
229        }
230    }
231
232    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
233        let handler = self.handler.clone();
234        Ok(rocket
235            .manage(SentinelFairingState::new(self.uri.clone()))
236            .mount(self.uri.clone(), handler))
237    }
238
239    async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
240        let empty_config = SentinelConfig::default();
241        let config = req
242            .rocket()
243            .state::<SentinelConfig<()>>()
244            .unwrap_or(&empty_config);
245        let extractor = self
246            .config
247            .extractor
248            .unwrap_or(config.extractor.unwrap_or(default_extractor));
249        let fallback = self.config.fallback.or(config.fallback);
250
251        let resource = extractor(&req);
252        let entry_builder = EntryBuilder::new(resource)
253            .with_traffic_type(sentinel_core::base::TrafficType::Inbound);
254
255        match entry_builder.build() {
256            Ok(entry) => {
257                entry.exit();
258            }
259            Err(err) => {
260                match fallback {
261                    Some(fallback) => fallback(req, err),
262                    None => {
263                        if let Some(state) = req.rocket().state::<SentinelFairingState>() {
264                            if let Ok(mut msg) = state.msg.lock() {
265                                *msg = format!(
266                                    "Request to {:?} blocked by sentinel: {:?}",
267                                    req.uri().path(),
268                                    err
269                                );
270                            }
271                            // this `unwrap` call will never fail
272                            req.set_uri(Origin::parse(&state.uri).unwrap());
273                        }
274                    }
275                }
276            }
277        };
278    }
279}