use axum::{
extract::Request,
response::Response,
body::Body,
};
use axum::http::StatusCode;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::collections::HashSet;
use tower::{Layer, Service};
use crate::response::{Res, codes};
use super::{UserId, AuthClaims, TokenClaims};
#[derive(Clone)]
pub struct AuthLayer {
pub jwt_secret: String,
pub ignore_paths: Vec<String>,
#[cfg(feature = "cache")]
pub cache: Option<alun_cache::SharedCache>,
}
impl<S> Layer<S> for AuthLayer {
type Service = AuthService<S>;
fn layer(&self, inner: S) -> Self::Service {
AuthService {
inner,
jwt_secret: self.jwt_secret.clone(),
ignore_paths: self.ignore_paths.iter().cloned().collect(),
#[cfg(feature = "cache")]
cache: self.cache.clone(),
}
}
}
#[derive(Clone)]
pub struct AuthService<S> {
inner: S,
jwt_secret: String,
ignore_paths: HashSet<String>,
#[cfg(feature = "cache")]
cache: Option<alun_cache::SharedCache>,
}
impl<S> Service<Request<Body>> for AuthService<S>
where
S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let path = req.uri().path().to_string();
let is_ignore_path = self.ignore_paths.contains(&path);
let token_opt: Option<&str> = req.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
match token_opt {
Some(token) => match validate_and_extract_claims(&self.jwt_secret, token) {
Ok(claims) => {
#[cfg(feature = "cache")]
let cache = self.cache.clone();
let mut inner = self.inner.clone();
#[allow(unused_variables)]
let is_ignore = is_ignore_path;
Box::pin(async move {
#[cfg(feature = "cache")]
{
if !is_ignore {
if let (Some(ref c), Some(ref jti)) = (&cache, &claims.jti) {
let key = format!("token:blacklist:{}", jti);
if let Ok(Some(_)) = alun_cache::Cache::get::<serde_json::Value>(c, &key).await {
let body = serde_json::to_string(&Res::<()>::fail(
codes::UNAUTHORIZED, "Token 已登出,请重新登录"
)).unwrap_or_else(|_| r#"{"code":401,"msg":"Token 已登出,请重新登录"}"#.to_string());
return Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("Content-Type", "application/json; charset=utf-8")
.body(Body::from(body))
.expect("response body build failed"));
}
}
}
}
req.extensions_mut().insert(UserId(claims.sub.clone()));
req.extensions_mut().insert(AuthClaims(claims.clone()));
let mut response = inner.call(req).await?;
response.extensions_mut().insert(AuthClaims(claims));
Ok(response)
})
}
Err(_) => {
if is_ignore_path {
let mut inner = self.inner.clone();
Box::pin(async move { inner.call(req).await })
} else {
let body = serde_json::to_string(&Res::<()>::fail(
codes::UNAUTHORIZED, "Token 无效或已过期"
)).unwrap_or_else(|_| r#"{"code":401,"msg":"Token 无效或已过期"}"#.to_string());
Box::pin(async move {
Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("Content-Type", "application/json; charset=utf-8")
.body(Body::from(body))
.expect("response body build failed"))
})
}
}
},
None => {
if is_ignore_path {
let mut inner = self.inner.clone();
Box::pin(async move { inner.call(req).await })
} else {
let body = serde_json::to_string(&Res::<()>::fail(
codes::UNAUTHORIZED, "未授权访问,请先登录"
)).unwrap_or_else(|_| r#"{"code":401,"msg":"未授权访问,请先登录"}"#.to_string());
Box::pin(async move {
Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("Content-Type", "application/json; charset=utf-8")
.body(Body::from(body))
.expect("response body build failed"))
})
}
}
}
}
}
fn validate_and_extract_claims(secret: &str, token: &str) -> Result<TokenClaims, String> {
use jsonwebtoken::{decode, DecodingKey, Validation};
let token_data = decode::<TokenClaims>(
token,
&DecodingKey::from_secret(secret.as_bytes()),
&Validation::default(),
)
.map_err(|e| format!("Token 验证失败: {}", e))?;
Ok(token_data.claims)
}