Skip to main content

alun_web/middleware/
auth.rs

1//! 认证中间件:JWT Bearer Token 验证 + 黑名单检查
2
3use axum::{
4    extract::Request,
5    response::Response,
6    body::Body,
7};
8use axum::http::StatusCode;
9use std::future::Future;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::collections::HashSet;
13use tower::{Layer, Service};
14use crate::response::{Res, codes};
15use super::{UserId, AuthClaims, TokenClaims};
16
/// Tower layer that enforces JWT bearer authentication.
///
/// The wrapped service reads the token from the `Authorization: Bearer <token>` header,
/// validates it with [`AuthLayer::jwt_secret`], and on success inserts the caller's
/// `UserId` and `AuthClaims` into the request extensions before forwarding it.
/// Requests that fail authentication on a protected path receive `401 Unauthorized`
/// with a JSON body.
#[derive(Clone)]
pub struct AuthLayer {
    /// Shared secret used to verify token signatures (HS256).
    pub jwt_secret: String,
    /// Paths that skip mandatory authentication (e.g. `/api/login`).
    ///
    /// A valid token sent to one of these paths is still decoded, so handlers can see
    /// who the caller is when a token is present.
    pub ignore_paths: Vec<String>,
    /// Optional cache handle, consulted for the logout blacklist (`token:blacklist:<jti>`).
    #[cfg(feature = "cache")]
    pub cache: Option<alun_cache::SharedCache>,
}
32
33impl<S> Layer<S> for AuthLayer {
34    type Service = AuthService<S>;
35
36    fn layer(&self, inner: S) -> Self::Service {
37        AuthService {
38            inner,
39            jwt_secret: self.jwt_secret.clone(),
40            ignore_paths: self.ignore_paths.iter().cloned().collect(),
41            #[cfg(feature = "cache")]
42            cache: self.cache.clone(),
43        }
44    }
45}
46
/// Service produced by [`AuthLayer`]: authenticates each request with a JWT bearer token
/// and inserts the caller's identity into the request extensions before delegating.
#[derive(Clone)]
pub struct AuthService<S> {
    /// The downstream service that handles authenticated (or ignored-path) requests.
    inner: S,
    /// Shared secret used to verify token signatures (HS256).
    jwt_secret: String,
    /// Paths exempt from mandatory authentication.
    ignore_paths: HashSet<String>,
    /// Optional cache handle, consulted for the logout blacklist.
    #[cfg(feature = "cache")]
    cache: Option<alun_cache::SharedCache>,
}
60
61impl<S> Service<Request<Body>> for AuthService<S>
62where
63    S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
64    S::Future: Send + 'static,
65{
66    type Response = S::Response;
67    type Error = S::Error;
68    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
69
70    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
71        self.inner.poll_ready(cx)
72    }
73
74    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
75        let path = req.uri().path().to_string();
76        let is_ignore_path = self.ignore_paths.contains(&path);
77
78        let token_opt: Option<&str> = req.headers()
79            .get(axum::http::header::AUTHORIZATION)
80            .and_then(|v| v.to_str().ok())
81            .and_then(|v| v.strip_prefix("Bearer "));
82
83        match token_opt {
84            Some(token) => match validate_and_extract_claims(&self.jwt_secret, token) {
85                Ok(claims) => {
86                    #[cfg(feature = "cache")]
87                    let cache = self.cache.clone();
88                    let mut inner = self.inner.clone();
89                    #[allow(unused_variables)]
90                    let is_ignore = is_ignore_path;
91                    Box::pin(async move {
92                        #[cfg(feature = "cache")]
93                        {
94                            // 黑名单检查:仅在非 ignore_path 时拒绝,ignore_path 应放行
95                            if !is_ignore {
96                                if let (Some(ref c), Some(ref jti)) = (&cache, &claims.jti) {
97                                    let key = format!("token:blacklist:{}", jti);
98                                    if let Ok(Some(_)) = alun_cache::Cache::get::<serde_json::Value>(c, &key).await {
99                                        let body = serde_json::to_string(&Res::<()>::fail(
100                                            codes::UNAUTHORIZED, "Token 已登出,请重新登录"
101                                        )).unwrap_or_else(|_| r#"{"code":401,"msg":"Token 已登出,请重新登录"}"#.to_string());
102                                        return Ok(Response::builder()
103                                            .status(StatusCode::UNAUTHORIZED)
104                                            .header("Content-Type", "application/json; charset=utf-8")
105                                            .body(Body::from(body))
106                                            .expect("response body build failed"));
107                                    }
108                                }
109                            }
110                        }
111                        // 有有效 Token,注入用户信息(ignore_path 也注入,以便业务可以获取用户信息)
112                        req.extensions_mut().insert(UserId(claims.sub.clone()));
113                        req.extensions_mut().insert(AuthClaims(claims.clone()));
114                        let mut response = inner.call(req).await?;
115                        response.extensions_mut().insert(AuthClaims(claims));
116                        Ok(response)
117                    })
118                }
119                Err(_) => {
120                    if is_ignore_path {
121                        // ignore_path 上的无效 Token,忽略错误继续处理,不注入用户信息
122                        let mut inner = self.inner.clone();
123                        Box::pin(async move { inner.call(req).await })
124                    } else {
125                        let body = serde_json::to_string(&Res::<()>::fail(
126                            codes::UNAUTHORIZED, "Token 无效或已过期"
127                        )).unwrap_or_else(|_| r#"{"code":401,"msg":"Token 无效或已过期"}"#.to_string());
128                        Box::pin(async move {
129                            Ok(Response::builder()
130                                .status(StatusCode::UNAUTHORIZED)
131                                .header("Content-Type", "application/json; charset=utf-8")
132                                .body(Body::from(body))
133                                .expect("response body build failed"))
134                        })
135                    }
136                }
137            },
138            None => {
139                if is_ignore_path {
140                    // ignore_path 上没有 Token,直接放行
141                    let mut inner = self.inner.clone();
142                    Box::pin(async move { inner.call(req).await })
143                } else {
144                    let body = serde_json::to_string(&Res::<()>::fail(
145                        codes::UNAUTHORIZED, "未授权访问,请先登录"
146                    )).unwrap_or_else(|_| r#"{"code":401,"msg":"未授权访问,请先登录"}"#.to_string());
147                    Box::pin(async move {
148                        Ok(Response::builder()
149                            .status(StatusCode::UNAUTHORIZED)
150                            .header("Content-Type", "application/json; charset=utf-8")
151                            .body(Body::from(body))
152                            .expect("response body build failed"))
153                    })
154                }
155            }
156        }
157    }
158}
159
160fn validate_and_extract_claims(secret: &str, token: &str) -> Result<TokenClaims, String> {
161    use jsonwebtoken::{decode, DecodingKey, Validation};
162
163    let token_data = decode::<TokenClaims>(
164        token,
165        &DecodingKey::from_secret(secret.as_bytes()),
166        &Validation::default(),
167    )
168    .map_err(|e| format!("Token 验证失败: {}", e))?;
169
170    Ok(token_data.claims)
171}