rocket-authz 0.1.0

Casbin rocket access control middleware
Documentation
use rocket::{
    fairing::{Fairing, Info, Kind},
    http::Status,
    request::{self, FromRequest, Request},
    Data,
};

use casbin::prelude::*;
use casbin::{CachedEnforcer, CoreApi, Result as CasbinResult};
use parking_lot::RwLock;
use std::sync::Arc;

/// Identity values used as inputs for a Casbin authorization check.
///
/// `CasbinFairing::on_request` looks up a value of this type in the request-local
/// cache and uses `subject` (and `domain`, when present) as the leading arguments
/// of the enforce call.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CasbinVals {
    /// Subject of the request; `None` when no subject is known.
    pub subject: Option<String>,
    /// Optional domain; `None` when no domain applies.
    pub domain: Option<String>,
}

impl CasbinVals {
    /// Builds a `CasbinVals` from an optional subject and an optional domain.
    pub fn new(subject: Option<String>, domain: Option<String>) -> Self {
        Self { subject, domain }
    }
}

/// Request guard that reflects the authorization status stored in the
/// request-local cache.
///
/// The wrapped value is the cached status, or `None` when nothing was cached.
#[derive(Clone)]
pub struct CasbinGuard(Option<Status>);

impl<'a, 'r> FromRequest<'a, 'r> for CasbinGuard {
    type Error = ();

    /// Resolves the guard from the `CasbinGuard` stored in the request-local cache.
    ///
    /// * cached `Status::Ok` -> `Success`
    /// * any other cached status -> `Failure` carrying that status
    /// * no status available -> `Failure` with `Status::BadGateway`
    fn from_request(request: &'a Request<'r>) -> request::Outcome<CasbinGuard, ()> {
        // If nothing has been cached yet, this closure supplies the fallback value.
        // NOTE(review): `Status::from_code(0)` presumably yields `None` because 0 is
        // not a known status code - confirm against the Rocket documentation.
        let cached = request.local_cache(|| CasbinGuard(Status::from_code(0)));

        match cached.0 {
            Some(Status::Ok) => request::Outcome::Success(CasbinGuard(Some(Status::Ok))),
            Some(status) => request::Outcome::Failure((status, ())),
            None => request::Outcome::Failure((Status::BadGateway, ())),
        }
    }
}

/// Fairing that holds a shared Casbin `CachedEnforcer`.
///
/// Cloning the fairing clones the `Arc`, so every clone refers to the same
/// enforcer instance.
#[derive(Clone)]
pub struct CasbinFairing {
    /// Shared, lock-protected enforcer used for authorization checks.
    pub enforcer: Arc<RwLock<CachedEnforcer>>,
}

impl CasbinFairing {
    /// Creates a fairing from a model source `m` and an adapter source `a`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `CachedEnforcer::new` if the enforcer
    /// cannot be constructed.
    pub async fn new<M: TryIntoModel, A: TryIntoAdapter>(m: M, a: A) -> CasbinResult<Self> {
        let enforcer = CachedEnforcer::new(m, a).await?;
        Ok(Self {
            enforcer: Arc::new(RwLock::new(enforcer)),
        })
    }

    /// Returns another handle to the shared enforcer.
    ///
    /// Only the `Arc` is cloned, so a shared borrow of the fairing is sufficient.
    pub fn get_enforcer(&self) -> Arc<RwLock<CachedEnforcer>> {
        Arc::clone(&self.enforcer)
    }

    /// Builds a new fairing around an existing shared enforcer.
    ///
    /// Despite the name, this is an associated function: it does not modify an
    /// existing fairing but returns a fresh one wrapping `e`.
    pub fn set_enforcer(e: Arc<RwLock<CachedEnforcer>>) -> CasbinFairing {
        Self { enforcer: e }
    }
}

impl Fairing for CasbinFairing {
    /// Describes this fairing: its display name and the callback kinds it registers.
    fn info(&self) -> Info {
        Info {
            name: "Casbin Fairing",
            // NOTE(review): `Kind::Response` is requested although this impl defines
            // no `on_response`; kept unchanged to preserve the registered kind.
            kind: Kind::Request | Kind::Response,
        }
    }

    /// Runs the Casbin check for an incoming request and stores the resulting
    /// status as a `CasbinGuard` in the request-local cache.
    ///
    /// The subject and optional domain come from the `CasbinVals` found in the
    /// request-local cache (an empty one is inserted if none is present). The
    /// enforce arguments are `[subject, domain, path, method]` when a domain is
    /// set and `[subject, path, method]` otherwise.
    ///
    /// Cached status:
    /// * `Status::Ok` - the enforcer allowed the request
    /// * `Status::Forbidden` - the enforcer denied the request
    /// * `Status::BadGateway` - the enforcer returned an error, or no subject
    ///   was available
    fn on_request(&self, request: &mut Request, _data: &Data) {
        // Copy the values out so the borrow of the request-local cache ends here.
        let (subject, domain) = {
            let vals = request.local_cache(|| CasbinVals {
                subject: None,
                domain: None,
            });
            (vals.subject.clone(), vals.domain.clone())
        };

        let status = match subject {
            // Without a subject there is nothing to enforce; the domain is ignored.
            None => Status::BadGateway,
            Some(subject) => {
                let path = request.uri().path().to_owned();
                let action = request.method().as_str().to_owned();

                let args = match domain {
                    Some(domain) => vec![subject, domain, path, action],
                    None => vec![subject, path, action],
                };

                // The write guard is a temporary of this statement, so the lock is
                // released before the status is written to the request cache.
                let verdict = self.enforcer.write().enforce_mut(args);

                match verdict {
                    Ok(true) => Status::Ok,
                    Ok(false) => Status::Forbidden,
                    Err(_) => Status::BadGateway,
                }
            }
        };

        request.local_cache(|| CasbinGuard(Some(status)));
    }
}