use std::{
sync::Arc,
task::{Context, Poll},
time::{SystemTime, UNIX_EPOCH},
};
use axum::{
body::Body,
headers::{authorization::Bearer, Authorization, HeaderMapExt},
http::Request,
response::{IntoResponse, Response},
};
use futures_util::future::BoxFuture;
use jsonwebtoken::{DecodingKey, EncodingKey, Validation};
use serde::{Deserialize, Serialize};
use tower::{Layer, Service};
use crate::res::Res;
/// Tower [`Layer`] that enforces JWT bearer authentication on every request
/// whose path is not listed in `filter`.
pub struct JwtAuth<T> {
    /// Request paths that bypass authentication (matched exactly).
    filter: Arc<Vec<&'static str>>,
    /// Claims instance; used as the receiver for [`JwtToken::decode`].
    claims: Arc<T>,
}

// Written by hand instead of derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, but both fields are `Arc`s and clone without it.
impl<T> Clone for JwtAuth<T> {
    fn clone(&self) -> Self {
        Self {
            filter: Arc::clone(&self.filter),
            claims: Arc::clone(&self.claims),
        }
    }
}
impl<T> JwtAuth<T>
where
T: Default + JwtToken,
{
#[allow(dead_code)]
pub fn new(filter: Vec<&'static str>) -> Self {
Self {
filter: Arc::new(filter),
claims: Arc::new(T::default()),
}
}
}
impl<S, T> Layer<S> for JwtAuth<T>
where
T: Default + JwtToken,
{
type Service = JwtAuthService<S, T>;
fn layer(&self, inner: S) -> Self::Service {
JwtAuthService {
inner,
filter: self.filter.clone(),
claims: self.claims.clone(),
}
}
}
/// Service produced by [`JwtAuth`]: validates the bearer token (unless the
/// path is public) and then forwards the request to `inner`.
pub struct JwtAuthService<S, T> {
    /// The wrapped service that handles public or authenticated requests.
    inner: S,
    /// Request paths that bypass authentication (matched exactly).
    filter: Arc<Vec<&'static str>>,
    /// Claims instance; used as the receiver for [`JwtToken::decode`].
    claims: Arc<T>,
}

// Written by hand instead of derived: `#[derive(Clone)]` would also demand
// `T: Clone`, but `claims` is an `Arc` and only `inner` needs to be cloned.
impl<S: Clone, T> Clone for JwtAuthService<S, T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            filter: Arc::clone(&self.filter),
            claims: Arc::clone(&self.claims),
        }
    }
}
impl<S, T> Service<Request<Body>> for JwtAuthService<S, T>
where
    T: JwtToken + Sync + Send + 'static,
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authenticates `req` unless its path is in the public `filter` list.
    ///
    /// The cases are:
    /// - Public path: the request is forwarded to the inner service unchanged.
    /// - No usable `Authorization: Bearer ...` header (`typed_get` returns
    ///   `None`): the service answers with `Res::auth("请求未携带token")`
    ///   ("request carries no token"). The inner service is not called.
    /// - `JwtToken::decode` fails: the service answers with the returned `Res`
    ///   error. The inner service is not called.
    /// - Otherwise the decoded claims are inserted into the request extensions
    ///   and the request is forwarded.
    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        if !self.filter.contains(&req.uri().path()) {
            let auth = match req.headers().typed_get::<Authorization<Bearer>>() {
                Some(auth) => auth,
                None => {
                    let response = Res::<()>::auth("请求未携带token").into_response();
                    return Box::pin(async move { Ok(response) });
                }
            };
            let claims = match self.claims.decode(auth.token()) {
                Ok(claims) => claims,
                Err(err) => {
                    let response = err.into_response();
                    return Box::pin(async move { Ok(response) });
                }
            };
            // Expose the claims to downstream extractors/handlers.
            req.extensions_mut().insert(claims);
        }
        // The inner future already yields `Result<Response, S::Error>`;
        // box it directly instead of re-wrapping it in another async block.
        Box::pin(self.inner.call(req))
    }
}
pub trait JwtToken
where
Self: Serialize + for<'a> Deserialize<'a>,
{
fn encode(&self) -> Result<String, Res<()>> {
let res = jsonwebtoken::encode(
&jsonwebtoken::Header::default(),
self,
&EncodingKey::from_secret(Self::secret().as_bytes()),
);
match res {
Ok(res) => Ok(res),
Err(err) => Err(Res::error(err.to_string())),
}
}
fn decode(&self, token: &str) -> Result<Self, Res<()>> {
let res = jsonwebtoken::decode::<Self>(
token,
&DecodingKey::from_secret(Self::secret().as_bytes()),
&Validation::default(),
);
match res {
Ok(res) => Ok(res.claims),
Err(err) => Err(Res::auth(err.to_string())),
}
}
fn secret() -> &'static str {
"mykey"
}
fn duration() -> u64 {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
timestamp + 60 * 60 * 24 * 15
}
}