1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#![feature(in_band_lifetimes)]

use casbin::prelude::*;
use rocket::{
    fairing::Fairing,
    http::Status,
    request::{Outcome, FromRequest},
    Request,
};
use std::sync::{Arc, RwLock};

/// Trait implemented by rocket fairings to authorize incoming requests.
pub trait CasbinMiddleware: Fairing {
    /// get enforce rvals from request.
    /// the values returned are usually: [sub, obj, act], depend on your model.
    fn casbin_vals(&self, req: &'a Request<'r>) -> Vec<String>;

    fn cached_enforcer(&self) -> Arc<RwLock<CachedEnforcer>>;

    /// authorize request, and add result to request
    fn enforce(&self, req: &'a Request<'r>) {
        let vals = self.casbin_vals(req);
        let vals = (&vals).into_iter().map(|v| v).collect::<Vec<&String>>();

        let cloned_enforcer = self.cached_enforcer();
        let mut lock_enforcer = cloned_enforcer.write().unwrap();
        match lock_enforcer.enforce_mut(&vals) {
            Ok(true) => {
                req.local_cache(|| CasbinGuard(Some(Status::Ok)));
            }
            Ok(false) => {
                req.local_cache(|| CasbinGuard(Some(Status::Forbidden)));
            }
            Err(_) => {
                req.local_cache(|| CasbinGuard(Some(Status::BadGateway)));
            }
        }
    }
}

/// A request guard that exposes the authorization result.
///
/// `CasbinGuard` usually appears as an argument of a route handler. The inner
/// value is the status recorded by `CasbinMiddleware::enforce`, or `None` when
/// no result was recorded for the request.
///
/// # Example
/// ```ignore
/// #[get("/book/1")]
/// fn book(_g: CasbinGuard) { /* ... */ }
/// ```
#[derive(Debug)]
pub struct CasbinGuard(Option<Status>);

/// Turns the cached authorization result into a guard outcome.
///
/// - A cached `Status::Ok` yields `Outcome::Success`.
/// - Any other cached status fails the guard with that status.
/// - If nothing was cached for the request, the guard fails with
///   `Status::BadGateway`.
impl<'a, 'r> FromRequest<'a, 'r> for CasbinGuard {
    type Error = ();

    fn from_request(request: &'a Request<'r>) -> Outcome<CasbinGuard, ()> {
        // If no authorization result was cached yet, seed the cache with
        // `None` (the original `Status::from_code(0)` also evaluated to `None`).
        match request.local_cache(|| CasbinGuard(None)).0 {
            Some(Status::Ok) => Outcome::Success(CasbinGuard(Some(Status::Ok))),
            Some(err_status) => Outcome::Failure((err_status, ())),
            None => Outcome::Failure((Status::BadGateway, ())),
        }
    }
}