#![doc(html_logo_url = "https://avatars.githubusercontent.com/u/43955412")]
use rocket::{
fairing::{self, Fairing, Info, Kind},
http::{self, uri::Origin, Method, Status},
request::{self, FromRequest},
route, Build, Data, Request, Rocket, Route,
};
use sentinel_core::EntryBuilder;
use std::sync::Mutex;
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub type Extractor = fn(&Request<'_>) -> String;
pub type Fallback<R> = fn(&Request<'_>, sentinel_core::Error) -> R;
/// Default [`Extractor`]: the request URI's path is used as the Sentinel resource name.
fn default_extractor(req: &Request<'_>) -> String {
    let uri = req.uri();
    uri.path().to_string()
}
/// Default guard fallback: fails the request with `429 Too Many Requests`,
/// carrying the Sentinel error as the boxed failure reason.
fn default_fallback_for_guard(
    _request: &Request<'_>,
    err: sentinel_core::Error,
) -> request::Outcome<SentinelGuard, BoxError> {
    let reason: BoxError = err.into();
    request::Outcome::Failure((Status::TooManyRequests, reason))
}
/// Optional customisation shared by [`SentinelGuard`] and [`SentinelFairing`].
///
/// `R` is the type returned by the fallback: a request outcome for the guard,
/// `()` for the fairing.
pub struct SentinelConfig<R> {
    /// Resource-name extractor; when `None`, the request path is used.
    pub extractor: Option<Extractor>,
    /// Called when Sentinel blocks a request; when `None`, built-in behaviour applies.
    pub fallback: Option<Fallback<R>>,
}

/// Configuration type that [`SentinelGuard`] looks up in managed state.
pub type SentinelConfigForGuard = SentinelConfig<request::Outcome<SentinelGuard, BoxError>>;
/// Configuration type that [`SentinelFairing`] looks up in managed state.
pub type SentinelConfigForFairing = SentinelConfig<()>;
impl<R> SentinelConfig<R> {
    /// Returns this configuration with `extractor` installed as the resource-name extractor.
    pub fn with_extractor(self, extractor: Extractor) -> Self {
        Self {
            extractor: Some(extractor),
            ..self
        }
    }

    /// Returns this configuration with `fallback` installed as the blocked-request callback.
    pub fn with_fallback(self, fallback: Fallback<R>) -> Self {
        Self {
            fallback: Some(fallback),
            ..self
        }
    }
}
// Written by hand because `#[derive(Clone)]` would demand `R: Clone`. Both fields are
// `Option`s of function pointers, which are `Copy`, so they are simply copied.
impl<R> Clone for SentinelConfig<R> {
    fn clone(&self) -> Self {
        Self {
            fallback: self.fallback,
            extractor: self.extractor,
        }
    }
}
// Written by hand so that `Default` does not require `R: Default`.
impl<R> Default for SentinelConfig<R> {
    /// An empty configuration: no custom extractor and no custom fallback.
    fn default() -> Self {
        Self {
            fallback: Option::None,
            extractor: Option::None,
        }
    }
}
#[derive(Debug)]
pub struct SentinelGuard;
#[rocket::async_trait]
impl<'r> FromRequest<'r> for SentinelGuard {
type Error = BoxError;
async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
let empty_config = SentinelConfig::default();
let config = req
.rocket()
.state::<SentinelConfig<request::Outcome<SentinelGuard, BoxError>>>()
.unwrap_or(&empty_config);
let extractor = config.extractor.unwrap_or(default_extractor);
let fallback = config.fallback.unwrap_or(default_fallback_for_guard);
let resource = extractor(req);
let entry_builder = EntryBuilder::new(resource)
.with_traffic_type(sentinel_core::base::TrafficType::Inbound);
match entry_builder.build() {
Ok(entry) => {
entry.exit();
request::Outcome::Success(SentinelGuard {})
}
Err(err) => fallback(req, err),
}
}
}
/// State managed by [`SentinelFairing`] when Rocket ignites.
#[derive(Debug)]
pub struct SentinelFairingState {
    /// Description of the most recently blocked request. The fairing writes it when no
    /// custom fallback is configured. One instance is shared by all requests, so it only
    /// holds the latest block.
    pub msg: Mutex<String>,
    /// Path that blocked requests are rewritten to.
    pub uri: String,
}

impl SentinelFairingState {
    /// Creates state for the blocked-request path `uri` with an empty message.
    pub fn new(uri: String) -> Self {
        let msg = Mutex::default();
        Self { uri, msg }
    }
}
/// Signature of a function that answers requests rewritten to the fairing's blocked-request URI.
type FairingHandler = for<'r> fn(&'r Request<'_>, Data<'r>) -> route::Outcome<'r>;

/// Route handler mounted by [`SentinelFairing`].
///
/// It holds an optional custom handler function; `None` selects the built-in response.
#[derive(Clone, Default)]
pub struct SentinelFairingHandler(Option<FairingHandler>);

impl SentinelFairingHandler {
    /// Wraps the custom handler function `h`.
    pub fn new(h: FairingHandler) -> Self {
        SentinelFairingHandler(Option::Some(h))
    }
}
#[rocket::async_trait]
impl route::Handler for SentinelFairingHandler {
    /// Answers a request that was rewritten to the blocked-request URI.
    ///
    /// If a custom handler was supplied, the request is delegated to it. Otherwise the
    /// response is `429 Too Many Requests` when `SentinelFairingState` is managed, and
    /// `500 Internal Server Error` when it is not.
    async fn handle<'r>(&self, req: &'r Request<'_>, data: Data<'r>) -> route::Outcome<'r> {
        if let Some(custom) = self.0 {
            return custom(req, data);
        }
        let status = if req.rocket().state::<SentinelFairingState>().is_some() {
            Status::TooManyRequests
        } else {
            Status::InternalServerError
        };
        route::Outcome::Failure(status)
    }
}
/// Converts the handler into the route list given to `Rocket::mount`: a single
/// `GET /` route, relative to the mount base.
///
/// Implemented as `From` rather than `Into` (clippy `from_over_into`). The standard
/// blanket impl still provides `Into<Vec<Route>>`, so existing `.into()` callers and
/// `mount(base, handler)` keep working.
impl From<SentinelFairingHandler> for Vec<Route> {
    fn from(handler: SentinelFairingHandler) -> Self {
        vec![Route::new(Method::Get, "/", handler)]
    }
}
#[derive(Default)]
pub struct SentinelFairing {
uri: String,
handler: SentinelFairingHandler,
config: SentinelConfig<()>,
}
impl SentinelFairing {
pub fn new(uri: &'static str) -> Result<Self, http::uri::Error> {
Ok(SentinelFairing::default().with_uri(uri)?)
}
pub fn with_extractor(mut self, extractor: Extractor) -> Self {
self.config = self.config.with_extractor(extractor);
self
}
pub fn with_fallback(mut self, fallback: Fallback<()>) -> Self {
self.config = self.config.with_fallback(fallback);
self
}
pub fn with_handler(mut self, h: FairingHandler) -> Self {
self.handler = SentinelFairingHandler::new(h);
self
}
pub fn with_uri(mut self, uri: &'static str) -> Result<Self, http::uri::Error> {
let origin = Origin::parse(uri)?;
self.uri = origin.path().to_string();
Ok(self)
}
}
#[rocket::async_trait]
impl Fairing for SentinelFairing {
fn info(&self) -> Info {
Info {
name: "Sentinel Fairing",
kind: Kind::Ignite | Kind::Request,
}
}
async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
let handler = self.handler.clone();
Ok(rocket
.manage(SentinelFairingState::new(self.uri.clone()))
.mount(self.uri.clone(), handler))
}
async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
let empty_config = SentinelConfig::default();
let config = req
.rocket()
.state::<SentinelConfig<()>>()
.unwrap_or(&empty_config);
let extractor = self
.config
.extractor
.unwrap_or(config.extractor.unwrap_or(default_extractor));
let fallback = self.config.fallback.or(config.fallback);
let resource = extractor(&req);
let entry_builder = EntryBuilder::new(resource)
.with_traffic_type(sentinel_core::base::TrafficType::Inbound);
match entry_builder.build() {
Ok(entry) => {
entry.exit();
}
Err(err) => {
match fallback {
Some(fallback) => fallback(req, err),
None => {
if let Some(state) = req.rocket().state::<SentinelFairingState>() {
if let Ok(mut msg) = state.msg.lock() {
*msg = format!(
"Request to {:?} blocked by sentinel: {:?}",
req.uri().path(),
err
);
}
req.set_uri(Origin::parse(&state.uri).unwrap());
}
}
}
}
};
}
}