//! salvo-casbin 0.73.3
//!
//! Casbin access control hoop (middleware) for Salvo.
use casbin::IEnforcer;
use salvo::prelude::*;
use std::{
    ops::{Deref, DerefMut},
    sync::Arc,
};
use tokio::sync::RwLock;

/// Identity values extracted from a request and fed to the Casbin enforcer.
///
/// The hoop builds the enforcement request as `[subject, path, method]`, or as
/// `[subject, domain, path, method]` when a domain is present.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CasbinVals {
    /// Who is making the request. An empty subject is rejected by the hoop
    /// with `401 Unauthorized` before the enforcer is consulted.
    pub subject: String,
    /// Optional domain; when `Some`, it is placed right after the subject in
    /// the enforcement request.
    pub domain: Option<String>,
}

/// Salvo handler that authorizes each request against a Casbin enforcer.
///
/// * `E` is the enforcer type, shared behind `Arc<RwLock<_>>`.
/// * `F` is the callback that extracts [`CasbinVals`] from the request.
pub struct CasbinHoop<E, F> {
    /// Shared enforcer; cloning the hoop shares this same instance.
    enforcer: Arc<RwLock<E>>,
    /// When `true`, enforcement takes the write lock and calls `enforce_mut`;
    /// otherwise it takes the read lock and calls `enforce`.
    use_enforcer_mut: bool,
    /// Extracts the subject (and optional domain) for the current request.
    get_casbin_vals: F,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add an
// `E: Clone` bound, although the enforcer is only ever cloned through its
// `Arc`. Only the extractor `F` actually needs to be `Clone`.
impl<E, F: Clone> Clone for CasbinHoop<E, F> {
    fn clone(&self) -> Self {
        Self {
            enforcer: Arc::clone(&self.enforcer),
            use_enforcer_mut: self.use_enforcer_mut,
            get_casbin_vals: self.get_casbin_vals.clone(),
        }
    }
}

/// Lets a `CasbinHoop` be used wherever a reference to its shared enforcer
/// handle (`Arc<RwLock<E>>`) is expected.
impl<E, F> Deref for CasbinHoop<E, F> {
    type Target = Arc<RwLock<E>>;

    /// Borrows the shared enforcer handle stored in the hoop.
    fn deref(&self) -> &Self::Target {
        let Self { enforcer, .. } = self;
        enforcer
    }
}

/// Mutable counterpart of the `Deref` impl: gives mutable access to the
/// hoop's `Arc<RwLock<E>>` handle itself (not to the enforcer inside it).
impl<E, F> DerefMut for CasbinHoop<E, F> {
    /// Mutably borrows the shared enforcer handle stored in the hoop.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { enforcer, .. } = self;
        enforcer
    }
}

#[handler]
impl<E, F> CasbinHoop<E, F>
where
    E: IEnforcer + 'static,
    F: Fn(&mut Request, &mut Depot) -> Result<Option<CasbinVals>, StatusError>
        + Send
        + Sync
        + 'static,
{
    pub fn new(enforcer: E, use_enforcer_mut: bool, get_casbin_vals: F) -> Self {
        CasbinHoop {
            enforcer: Arc::new(RwLock::new(enforcer)),
            use_enforcer_mut,
            get_casbin_vals,
        }
    }

    pub fn get_enforcer(&self) -> Arc<RwLock<E>> {
        self.enforcer.clone()
    }

    async fn handle(
        &self,
        req: &mut Request,
        depot: &mut Depot,
        res: &mut Response,
        _ctrl: &mut FlowCtrl,
    ) {
        let vals = match (self.get_casbin_vals)(req, depot) {
            Ok(option_vals) => {
                let Some(vals) = option_vals else {
                    res.render(StatusError::unauthorized());
                    return;
                };
                vals
            }
            Err(err) => {
                res.render(err);
                return;
            }
        };

        let path = req.uri().path().to_string();
        let action = req.method().as_str().to_string();

        if vals.subject.is_empty() {
            res.render(StatusError::unauthorized());
            return;
        }

        let rvals = if let Some(domain) = vals.domain {
            vec![vals.subject, domain, path, action]
        } else {
            vec![vals.subject, path, action]
        };
        let r = if self.use_enforcer_mut {
            self.enforcer.write().await.enforce_mut(rvals)
        } else {
            self.enforcer.read().await.enforce(rvals)
        };
        match r {
            Ok(true) => {}
            Ok(false) => {
                res.render(StatusError::forbidden());
            }
            Err(_) => {
                res.render(StatusError::bad_gateway());
            }
        }
    }
}