use std::{
sync::Arc,
task::{Context, Poll},
time::{SystemTime, UNIX_EPOCH},
};
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::HeaderMap;
use axum::{
async_trait,
body::Body,
headers::{authorization::Bearer, Authorization, HeaderMapExt},
http::Request,
response::{IntoResponse, Response},
};
use futures_util::future::BoxFuture;
use jsonwebtoken::{DecodingKey, EncodingKey, Validation};
use serde::{Deserialize, Serialize};
use tower::{Layer, Service};
use crate::res::Res;
/// Axum extractor that yields the decoded JWT claims of type `T`.
///
/// Taking `Jwt<T>` as a handler argument makes the handler require a valid
/// `Authorization: Bearer <token>` header; the claims are available as `.0`.
#[must_use]
#[derive(Debug, Clone)]
pub struct Jwt<T>(pub T)
where
    T: JwtToken;
#[async_trait]
impl<T, S> FromRequestParts<S> for Jwt<T>
where
T: JwtToken,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
let claims = auth_token::<T>(&parts.headers)?;
Ok(Jwt(claims))
}
}
/// Reads the `Authorization: Bearer <token>` header and decodes it into claims `T`.
///
/// # Errors
///
/// - Returns a `Res::msg(401, …)` response when the header is absent or is not
///   a well-formed bearer credential.
/// - Returns the response built from the error of `T::decode` when the token
///   cannot be decoded/validated.
fn auth_token<T: JwtToken>(header: &HeaderMap) -> Result<T, Response> {
    // `ok_or_else` so the rejection response is only constructed on the failure
    // path instead of on every request.
    let auth = header
        .typed_get::<Authorization<Bearer>>()
        .ok_or_else(|| Res::msg(401, "请求未携带有效token").into_response())?;
    T::decode(auth.token()).map_err(|err| err.into_response())
}
/// Tower layer that authenticates requests with a bearer JWT decoded as `T`.
///
/// Requests whose URI path exactly equals one of the filter entries skip
/// authentication; all others go through `JwtAuthService`.
#[derive(Clone)]
pub struct JwtAuth<T> {
    /// Paths (compared exactly against `uri().path()`) that bypass authentication.
    filter: Arc<Vec<&'static str>>,
    /// Ties the layer to the claims type `T` without storing any claims value.
    /// `fn() -> T` keeps the type covariant in `T` and `Send + Sync` regardless of `T`.
    _claims: std::marker::PhantomData<fn() -> T>,
}

impl<T> JwtAuth<T>
where
    T: JwtToken,
{
    /// Creates the layer; `filter` lists request paths that are not authenticated.
    ///
    /// Matching is exact string equality on the request path (no prefixes or patterns).
    pub fn new(filter: Vec<&'static str>) -> Self {
        Self {
            filter: Arc::new(filter),
            _claims: std::marker::PhantomData,
        }
    }
}

impl<S, T> Layer<S> for JwtAuth<T>
where
    T: JwtToken,
{
    type Service = JwtAuthService<S, T>;

    /// Wraps `inner`, sharing this layer's filter list via a cheap `Arc` clone.
    fn layer(&self, inner: S) -> Self::Service {
        JwtAuthService {
            inner,
            filter: Arc::clone(&self.filter),
            _claims: std::marker::PhantomData,
        }
    }
}

/// Service produced by [`JwtAuth`]: authenticates each request, then delegates to `inner`.
#[derive(Clone)]
pub struct JwtAuthService<S, T> {
    /// The wrapped service that handles requests.
    inner: S,
    /// Paths (compared exactly against `uri().path()`) that bypass authentication.
    filter: Arc<Vec<&'static str>>,
    /// Ties the service to the claims type `T` without storing any claims value.
    _claims: std::marker::PhantomData<fn() -> T>,
}
impl<S, T> Service<Request<Body>> for JwtAuthService<S, T>
where
    T: JwtToken + Sync + Send + 'static,
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authenticates `req` (unless its path is in the filter list) and forwards it.
    ///
    /// On success the decoded claims are inserted into the request extensions
    /// for downstream handlers. On failure the rejection response is returned
    /// directly and the inner service is never invoked with the request.
    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        let bypass = self.filter.contains(&req.uri().path());
        if !bypass {
            match auth_token::<T>(req.headers()) {
                Ok(claims) => {
                    req.extensions_mut().insert(claims);
                }
                // Short-circuit: dispatching a rejected request to `inner.call`
                // would let any eager work in the inner service run unauthenticated.
                Err(rejection) => {
                    return Box::pin(async move { Ok::<_, S::Error>(rejection) });
                }
            }
        }
        Box::pin(self.inner.call(req))
    }
}
pub trait JwtToken
where
Self: Serialize + for<'a> Deserialize<'a>,
{
const SECRET: &'static str = "my_key";
const DURATION: u64 = 60 * 60 * 24 * 15;
fn encode(&self) -> Result<String, Res<()>> {
let res = jsonwebtoken::encode(
&jsonwebtoken::Header::default(),
self,
&EncodingKey::from_secret(Self::SECRET.as_bytes()),
);
res.map_err(|err| Res::error(err.to_string()))
}
fn decode(token: &str) -> Result<Self, Res<()>> {
let res = jsonwebtoken::decode::<Self>(
token,
&DecodingKey::from_secret(Self::SECRET.as_bytes()),
&Validation::default(),
);
match res {
Ok(res) => Ok(res.claims),
Err(err) => Err(Res::msg(401, err.to_string())),
}
}
fn expiration() -> u64 {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
timestamp + Self::DURATION
}
}