alun_web/middleware/
auth.rs1use axum::{
4 extract::Request,
5 response::Response,
6 body::Body,
7};
8use axum::http::StatusCode;
9use std::future::Future;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::collections::HashSet;
13use tower::{Layer, Service};
14use crate::response::{Res, codes};
15use super::{UserId, AuthClaims, TokenClaims};
16
/// Tower layer that puts JWT bearer authentication in front of a service.
///
/// Wrapping a service with this layer produces an [`AuthService`].
#[derive(Clone)]
pub struct AuthLayer {
    /// Secret passed to `jsonwebtoken::DecodingKey::from_secret` to verify tokens.
    pub jwt_secret: String,
    /// Request paths (compared by exact string match) on which a missing or invalid
    /// token is tolerated: the request is forwarded instead of being rejected with 401.
    pub ignore_paths: Vec<String>,
    /// Optional cache consulted for the `token:blacklist:{jti}` revocation entries.
    #[cfg(feature = "cache")]
    pub cache: Option<alun_cache::SharedCache>,
}
32
33impl<S> Layer<S> for AuthLayer {
34 type Service = AuthService<S>;
35
36 fn layer(&self, inner: S) -> Self::Service {
37 AuthService {
38 inner,
39 jwt_secret: self.jwt_secret.clone(),
40 ignore_paths: self.ignore_paths.iter().cloned().collect(),
41 #[cfg(feature = "cache")]
42 cache: self.cache.clone(),
43 }
44 }
45}
46
/// Service produced by [`AuthLayer`]: authenticates each request, then forwards it to `inner`.
#[derive(Clone)]
pub struct AuthService<S> {
    /// Wrapped service that receives the requests allowed through.
    inner: S,
    /// Secret used to verify JWT signatures.
    jwt_secret: String,
    /// Exact request paths on which a missing or invalid token does not cause a 401.
    ignore_paths: HashSet<String>,
    /// Optional cache holding `token:blacklist:{jti}` revocation entries.
    #[cfg(feature = "cache")]
    cache: Option<alun_cache::SharedCache>,
}
60
61impl<S> Service<Request<Body>> for AuthService<S>
62where
63 S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
64 S::Future: Send + 'static,
65{
66 type Response = S::Response;
67 type Error = S::Error;
68 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
69
70 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
71 self.inner.poll_ready(cx)
72 }
73
74 fn call(&mut self, mut req: Request<Body>) -> Self::Future {
75 let path = req.uri().path().to_string();
76 let is_ignore_path = self.ignore_paths.contains(&path);
77
78 let token_opt: Option<&str> = req.headers()
79 .get(axum::http::header::AUTHORIZATION)
80 .and_then(|v| v.to_str().ok())
81 .and_then(|v| v.strip_prefix("Bearer "));
82
83 match token_opt {
84 Some(token) => match validate_and_extract_claims(&self.jwt_secret, token) {
85 Ok(claims) => {
86 #[cfg(feature = "cache")]
87 let cache = self.cache.clone();
88 let mut inner = self.inner.clone();
89 #[allow(unused_variables)]
90 let is_ignore = is_ignore_path;
91 Box::pin(async move {
92 #[cfg(feature = "cache")]
93 {
94 if !is_ignore {
96 if let (Some(ref c), Some(ref jti)) = (&cache, &claims.jti) {
97 let key = format!("token:blacklist:{}", jti);
98 if let Ok(Some(_)) = alun_cache::Cache::get::<serde_json::Value>(c, &key).await {
99 let body = serde_json::to_string(&Res::<()>::fail(
100 codes::UNAUTHORIZED, "Token 已登出,请重新登录"
101 )).unwrap_or_else(|_| r#"{"code":401,"msg":"Token 已登出,请重新登录"}"#.to_string());
102 return Ok(Response::builder()
103 .status(StatusCode::UNAUTHORIZED)
104 .header("Content-Type", "application/json; charset=utf-8")
105 .body(Body::from(body))
106 .expect("response body build failed"));
107 }
108 }
109 }
110 }
111 req.extensions_mut().insert(UserId(claims.sub.clone()));
113 req.extensions_mut().insert(AuthClaims(claims.clone()));
114 let mut response = inner.call(req).await?;
115 response.extensions_mut().insert(AuthClaims(claims));
116 Ok(response)
117 })
118 }
119 Err(_) => {
120 if is_ignore_path {
121 let mut inner = self.inner.clone();
123 Box::pin(async move { inner.call(req).await })
124 } else {
125 let body = serde_json::to_string(&Res::<()>::fail(
126 codes::UNAUTHORIZED, "Token 无效或已过期"
127 )).unwrap_or_else(|_| r#"{"code":401,"msg":"Token 无效或已过期"}"#.to_string());
128 Box::pin(async move {
129 Ok(Response::builder()
130 .status(StatusCode::UNAUTHORIZED)
131 .header("Content-Type", "application/json; charset=utf-8")
132 .body(Body::from(body))
133 .expect("response body build failed"))
134 })
135 }
136 }
137 },
138 None => {
139 if is_ignore_path {
140 let mut inner = self.inner.clone();
142 Box::pin(async move { inner.call(req).await })
143 } else {
144 let body = serde_json::to_string(&Res::<()>::fail(
145 codes::UNAUTHORIZED, "未授权访问,请先登录"
146 )).unwrap_or_else(|_| r#"{"code":401,"msg":"未授权访问,请先登录"}"#.to_string());
147 Box::pin(async move {
148 Ok(Response::builder()
149 .status(StatusCode::UNAUTHORIZED)
150 .header("Content-Type", "application/json; charset=utf-8")
151 .body(Body::from(body))
152 .expect("response body build failed"))
153 })
154 }
155 }
156 }
157 }
158}
159
160fn validate_and_extract_claims(secret: &str, token: &str) -> Result<TokenClaims, String> {
161 use jsonwebtoken::{decode, DecodingKey, Validation};
162
163 let token_data = decode::<TokenClaims>(
164 token,
165 &DecodingKey::from_secret(secret.as_bytes()),
166 &Validation::default(),
167 )
168 .map_err(|e| format!("Token 验证失败: {}", e))?;
169
170 Ok(token_data.claims)
171}