// cool_core/middleware/authority.rs

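//! 权限中间件 — authority (JWT authentication) middleware.
//!
//! 对应 TypeScript 版本的 `middleware/authority.ts`
//! (Rust port of the TypeScript `middleware/authority.ts`).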
4
5use super::{AdminInfo, DepotExt, JwtClaims};
6use crate::error::{CoolError, CoolResponse};
7use salvo::prelude::*;
8use std::collections::HashSet;
9use std::sync::Arc;
10
/// Configuration for the authority (JWT authentication) middleware.
///
/// `Debug` is implemented by hand (instead of derived) so that `jwt_secret`
/// is redacted and the signing key cannot leak into logs or panic messages.
#[derive(Clone)]
pub struct AuthorityConfig {
    /// Secret used to verify JWT signatures.
    pub jwt_secret: String,
    /// Paths that skip token verification. An entry ending in `*` is treated
    /// by the middleware as a path prefix; any other entry must match exactly.
    pub ignore_urls: HashSet<String>,
    /// Name of the request header that carries the token.
    pub token_header: String,
}

impl std::fmt::Debug for AuthorityConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthorityConfig")
            // Never print the real signing key.
            .field("jwt_secret", &"<redacted>")
            .field("ignore_urls", &self.ignore_urls)
            .field("token_header", &self.token_header)
            .finish()
    }
}

impl Default for AuthorityConfig {
    /// Returns secret `"cool-admin-rust"`, an empty ignore list and the
    /// `Authorization` header.
    ///
    /// The default secret is a publicly known placeholder. Real deployments
    /// should supply their own secret via [`AuthorityConfig::new`].
    fn default() -> Self {
        Self {
            jwt_secret: "cool-admin-rust".to_string(),
            ignore_urls: HashSet::new(),
            token_header: "Authorization".to_string(),
        }
    }
}

impl AuthorityConfig {
    /// Creates a configuration with the given JWT secret.
    ///
    /// All other fields take their [`Default`] values.
    #[must_use]
    pub fn new(jwt_secret: impl Into<String>) -> Self {
        Self {
            jwt_secret: jwt_secret.into(),
            ..Default::default()
        }
    }

    /// Adds one URL (or `prefix*` pattern) to the ignore list. Builder style.
    #[must_use]
    pub fn ignore_url(mut self, url: impl Into<String>) -> Self {
        self.ignore_urls.insert(url.into());
        self
    }

    /// Adds every URL (or `prefix*` pattern) yielded by `urls` to the ignore
    /// list. Builder style. Duplicates are collapsed by the underlying set.
    #[must_use]
    pub fn ignore_urls<I, S>(mut self, urls: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.ignore_urls
            .extend(urls.into_iter().map(Into::<String>::into));
        self
    }
}
58
59/// 权限中间件
60pub struct AuthorityMiddleware {
61    config: Arc<AuthorityConfig>,
62}
63
64impl AuthorityMiddleware {
65    pub fn new(config: AuthorityConfig) -> Self {
66        Self {
67            config: Arc::new(config),
68        }
69    }
70}
71
72#[async_trait]
73impl Handler for AuthorityMiddleware {
74    async fn handle(
75        &self,
76        req: &mut Request,
77        depot: &mut Depot,
78        res: &mut Response,
79        ctrl: &mut FlowCtrl,
80    ) {
81        let path = req.uri().path();
82
83        // 检查是否在忽略列表中
84        if self.config.ignore_urls.contains(path) {
85            ctrl.call_next(req, depot, res).await;
86            return;
87        }
88
89        // 检查路径是否以忽略的前缀开头
90        for ignore_url in &self.config.ignore_urls {
91            if ignore_url.ends_with("*") {
92                let prefix = &ignore_url[..ignore_url.len() - 1];
93                if path.starts_with(prefix) {
94                    ctrl.call_next(req, depot, res).await;
95                    return;
96                }
97            }
98        }
99
100        // 获取 Token
101        let token = req
102            .header::<String>(&self.config.token_header)
103            .or_else(|| req.query::<String>("token"));
104
105        let token = match token {
106            Some(t) => {
107                // 安全去掉 "Bearer " 前缀
108                t.strip_prefix("Bearer ").unwrap_or(&t).to_string()
109            }
110            None => {
111                res.status_code(StatusCode::UNAUTHORIZED);
112                res.render(Json(CoolResponse::<()>::from_error(
113                    &CoolError::unauthorized(),
114                )));
115                ctrl.skip_rest();
116                return;
117            }
118        };
119
120        // 验证 Token
121        match JwtClaims::verify_token(&token, &self.config.jwt_secret) {
122            Ok(claims) => {
123                // 检查是否过期
124                let now = chrono::Utc::now().timestamp();
125                if claims.exp < now {
126                    res.status_code(StatusCode::UNAUTHORIZED);
127                    res.render(Json(CoolResponse::<()>::fail("Token 已过期")));
128                    ctrl.skip_rest();
129                    return;
130                }
131
132                // 设置用户信息到 Depot
133                depot.set_admin(AdminInfo::from(claims));
134                ctrl.call_next(req, depot, res).await;
135            }
136            Err(e) => {
137                res.status_code(StatusCode::UNAUTHORIZED);
138                res.render(Json(CoolResponse::<()>::from_error(&e)));
139                ctrl.skip_rest();
140            }
141        }
142    }
143}
144
145/// 创建权限中间件
146pub fn authority(config: AuthorityConfig) -> AuthorityMiddleware {
147    AuthorityMiddleware::new(config)
148}