use axum::{
extract::{FromRequestParts, Request, State},
http::{StatusCode, request::Parts},
middleware::Next,
response::{Html, Response},
};
use axum_extra::extract::CookieJar;
use cookie::{Cookie, SameSite};
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use std::sync::{Arc, LazyLock};
/// Lifetime of an issued JWT in hours (24 * 7, i.e. one week); added to the
/// issue time to compute the `exp` claim.
const JWT_EXPIRATION_HOURS: i64 = 24 * 7;
/// Name of the cookie that carries the JWT.
// NOTE(review): the value is spelled "boostrap" (sic). It is the on-the-wire
// cookie name, so correcting the spelling would orphan cookies already issued.
const AXUM_BOOTSTRAP_TOKEN: &str = "axum-boostrap-token";
/// Cookie that clears the auth token on the client.
///
/// It reuses the auth cookie's name and path but carries an empty value and a
/// negative `Max-Age`, so sending it asks the client to discard the stored
/// token. Use it through deref, e.g. `LOGOUT_COOKIE.clone()`.
///
/// This is a `static`, not a `const`: a `const` is inlined at every use site,
/// which would construct a brand-new `LazyLock` (and re-run the initializer)
/// each time instead of building the cookie once.
pub static LOGOUT_COOKIE: LazyLock<Cookie<'static>> = LazyLock::new(|| {
    Cookie::build((AXUM_BOOTSTRAP_TOKEN, ""))
        .path("/")
        // A negative max-age marks the cookie as already expired.
        .max_age(time::Duration::seconds(-1))
        .same_site(SameSite::Lax)
        .http_only(true)
        .build()
});
/// Key material used to sign and verify JWTs.
#[derive(Clone)]
pub struct JwtConfig {
/// Key used when encoding (signing) a token.
pub encoding_key: EncodingKey,
/// Key used when decoding (verifying) a token.
pub decoding_key: DecodingKey,
}
impl JwtConfig {
    /// Builds a configuration whose encoding and decoding keys are both
    /// derived from the bytes of the same shared `secret`.
    pub fn new(secret: &str) -> Self {
        let secret_bytes = secret.as_bytes();
        let encoding_key = EncodingKey::from_secret(secret_bytes);
        let decoding_key = DecodingKey::from_secret(secret_bytes);
        Self {
            encoding_key,
            decoding_key,
        }
    }
}
/// Claim set serialized into the JWT: an application payload plus the
/// expiry and issue timestamps.
///
/// `T` is the payload type and defaults to [`ClaimsPayload`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims<T = ClaimsPayload> {
    /// Application-defined data carried inside the token.
    pub payload: T,
    /// Expiry time as a Unix timestamp (seconds).
    pub exp: usize,
    /// Issue time as a Unix timestamp (seconds).
    pub iat: usize,
}
/// Default payload stored in [`Claims`]: just the user's name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaimsPayload {
/// Name of the user the token was issued for.
pub username: String,
}
impl<T> Claims<T> {
pub fn new(payload: T) -> Self {
let now = chrono::Utc::now();
let exp = (now + chrono::Duration::hours(JWT_EXPIRATION_HOURS)).timestamp() as usize;
let iat = now.timestamp() as usize;
Claims { payload, exp, iat }
}
pub(crate) fn encode(&self, config: &JwtConfig) -> Result<String, jsonwebtoken::errors::Error>
where
T: Serialize,
{
encode(&Header::default(), self, &config.encoding_key)
}
pub fn to_cookie<'a>(&self, jwt_config: &JwtConfig) -> Result<Cookie<'a>, jsonwebtoken::errors::Error>
where
T: Serialize,
{
let token = self.encode(jwt_config)?;
Ok(Cookie::build((AXUM_BOOTSTRAP_TOKEN, token))
.path("/")
.max_age(time::Duration::days(7))
.same_site(SameSite::Lax)
.http_only(true)
.build())
}
pub fn decode(token: &str, config: &JwtConfig) -> Result<Self, jsonwebtoken::errors::Error>
where
T: for<'de> Deserialize<'de>,
{
let validation = Validation::default();
let token_data = decode::<Claims<T>>(token, &config.decoding_key, &validation)?;
Ok(token_data.claims)
}
}
/// Middleware that authenticates a request from the JWT stored in the auth
/// cookie.
///
/// On success the decoded [`Claims<T>`] are inserted into the request
/// extensions and the request is passed on to `next`.
///
/// # Errors
///
/// Responds with `401 Unauthorized` and an HTML body when the cookie is
/// absent ("Missing token") or the token fails to decode ("Invalid token").
pub async fn jwt_auth_middleware<T>(
    State(config): State<Arc<JwtConfig>>,
    cookie_jar: CookieJar,
    mut request: Request,
    next: Next,
) -> Result<Response, (StatusCode, Html<String>)>
where
    T: for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
{
    // Borrow the token straight from the jar (no copy needed) and build the
    // rejection lazily so the success path does not allocate it.
    let token = cookie_jar
        .get(AXUM_BOOTSTRAP_TOKEN)
        .map(|cookie| cookie.value())
        .ok_or_else(|| (StatusCode::UNAUTHORIZED, Html("Missing token".to_string())))?;

    let claims = Claims::<T>::decode(token, &config).map_err(|e| {
        // Log message reads "JWT verification failed".
        log::error!("JWT验证失败: {:?}", e);
        (StatusCode::UNAUTHORIZED, Html("Invalid token".to_string()))
    })?;

    // Make the verified claims available to downstream handlers/extractors.
    request.extensions_mut().insert(claims);
    Ok(next.run(request).await)
}
/// Extractor that hands a handler the [`Claims<T>`] stored in the request
/// extensions.
///
/// It only reads the extensions; it does not parse or verify a token itself.
/// If no `Claims<T>` is present the request is rejected with
/// `401 Unauthorized`.
impl<S, T> FromRequestParts<S> for Claims<T>
where
    S: Send + Sync,
    T: Clone + Send + Sync + 'static,
{
    type Rejection = (StatusCode, Html<String>);

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        parts
            .extensions
            .get::<Claims<T>>()
            .cloned()
            // Built lazily so a successful extraction does not allocate the
            // rejection body.
            .ok_or_else(|| {
                (
                    StatusCode::UNAUTHORIZED,
                    Html("Missing or invalid token".to_string()),
                )
            })
    }
}