use std::{any::type_name, num::ParseIntError, string::FromUtf8Error, sync::Arc, time::SystemTime};
use async_trait::async_trait;
use base64::DecodeError;
use http::Extensions;
use reqwest::{
header::{HeaderName, HeaderValue, AUTHORIZATION},
Request, Response,
};
use reqwest_middleware::Next;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::{
digest::{self, decode_base64},
Middleware,
};
/// Middleware that authenticates outgoing requests.
///
/// When the request extensions hold an `Arc<dyn ApiAuthenticator>`, that
/// authenticator rewrites the request before it is passed on; otherwise the
/// request continues down the chain unchanged.
#[derive(Default)]
pub(crate) struct AuthenticateMiddleware;

#[async_trait]
impl Middleware for AuthenticateMiddleware {
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> Result<Response, reqwest_middleware::Error> {
        let req = match extensions.get::<Arc<dyn ApiAuthenticator>>() {
            Some(authenticator) => authenticator.authenticate(req, extensions).await?,
            None => req,
        };
        next.run(req, extensions).await
    }
}
#[async_trait]
pub trait TokenGenerator: 'static + Send + Sync {
async fn generate_token(&self, req: &Request) -> Result<String, reqwest_middleware::Error>;
}
#[async_trait]
impl<F, T> TokenGenerator for F
where
F: 'static + Send + Sync,
F: Fn() -> Result<T, reqwest_middleware::Error>,
T: ToString,
{
async fn generate_token(&self, _req: &Request) -> Result<String, reqwest_middleware::Error> {
self().map(|t| t.to_string())
}
}
#[async_trait]
pub trait ApiAuthenticator: TokenGenerator {
fn type_name(&self) -> &str {
type_name::<Self>()
}
fn get_carrier(&self) -> &Carrier {
&Carrier::BearerAuth
}
async fn authenticate(
&self,
req: Request,
_extensions: &Extensions,
) -> Result<Request, reqwest_middleware::Error> {
let token = self.generate_token(&req).await?;
Ok(self.get_carrier().apply(req, token))
}
}
#[async_trait]
impl TokenGenerator for Box<dyn ApiAuthenticator> {
async fn generate_token(&self, req: &Request) -> Result<String, reqwest_middleware::Error> {
self.as_ref().generate_token(req).await
}
}
#[async_trait]
impl ApiAuthenticator for Box<dyn ApiAuthenticator> {
fn get_carrier(&self) -> &Carrier {
self.as_ref().get_carrier()
}
async fn authenticate(
&self,
req: Request,
extensions: &Extensions,
) -> Result<Request, reqwest_middleware::Error> {
self.as_ref().authenticate(req, extensions).await
}
}
pub trait WithCarrier {
fn with_carrier(self, carrier: Carrier) -> Self;
fn with_header_name(self, name: impl ToString) -> Self;
fn with_query_param(self, name: impl ToString) -> Self;
}
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum Carrier {
#[default]
BearerAuth,
SchemalessAuth,
Header(String),
QueryParam(String),
}
impl Carrier {
pub fn apply(&self, req: Request, token: impl ToString) -> Request {
let mut req = req;
let token = token.to_string();
match self {
Carrier::BearerAuth => {
req.headers_mut().insert(
AUTHORIZATION,
HeaderValue::try_from(format!("Bearer {}", token)).unwrap(),
);
}
Carrier::SchemalessAuth => {
req.headers_mut()
.insert(AUTHORIZATION, HeaderValue::try_from(token).unwrap());
}
Carrier::Header(name) => {
req.headers_mut().append(
HeaderName::try_from(name.as_str()).unwrap(),
HeaderValue::try_from(token).unwrap(),
);
}
Carrier::QueryParam(name) => {
req.url_mut()
.query_pairs_mut()
.append_pair(name.as_str(), &token);
}
}
req
}
}
/// Where an [`AccessTokenAuth`] gets its token from.
pub enum AccessToken {
    /// A static token, returned unchanged for every request.
    Fixed(String),
    /// A generator that is asked for a token on every request.
    Dynamic(Arc<dyn TokenGenerator>),
}

/// Prints only the variant name, so a fixed token never ends up in logs.
impl std::fmt::Debug for AccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            AccessToken::Fixed(_) => "Fixed",
            AccessToken::Dynamic(_) => "Dynamic",
        };
        f.debug_tuple(variant).finish()
    }
}
/// Authenticator that sends an access token, either fixed or produced on demand.
///
/// Starts with `Carrier::default()`; change it through `WithCarrier`.
#[derive(Debug)]
pub struct AccessTokenAuth {
    access_token: AccessToken,
    carrier: Carrier,
}

impl AccessTokenAuth {
    /// Creates an authenticator that always sends `access_token`.
    pub fn new(access_token: impl ToString) -> Self {
        Self::from_source(AccessToken::Fixed(access_token.to_string()))
    }

    /// Creates an authenticator that asks `provider` for a token on every request.
    pub fn new_dynamic(provider: impl TokenGenerator) -> Self {
        Self::from_source(AccessToken::Dynamic(Arc::new(provider)))
    }

    /// Shared constructor: pairs a token source with the default carrier.
    fn from_source(access_token: AccessToken) -> Self {
        Self {
            access_token,
            carrier: Carrier::default(),
        }
    }
}
/// The carrier is per instance and set through `WithCarrier`.
#[async_trait]
impl ApiAuthenticator for AccessTokenAuth {
    fn get_carrier(&self) -> &Carrier {
        &self.carrier
    }
}

/// Fixed tokens are cloned. Dynamic tokens are delegated to the wrapped
/// generator, which receives the same request.
#[async_trait]
impl TokenGenerator for AccessTokenAuth {
    async fn generate_token(&self, req: &Request) -> Result<String, reqwest_middleware::Error> {
        let provider = match &self.access_token {
            AccessToken::Fixed(token) => return Ok(token.to_owned()),
            AccessToken::Dynamic(provider) => provider,
        };
        provider.generate_token(req).await
    }
}
/// Each call replaces the previously configured carrier.
impl WithCarrier for AccessTokenAuth {
    fn with_carrier(mut self, carrier: Carrier) -> Self {
        self.carrier = carrier;
        self
    }

    fn with_header_name(self, name: impl ToString) -> Self {
        self.with_carrier(Carrier::Header(name.to_string()))
    }

    fn with_query_param(self, name: impl ToString) -> Self {
        self.with_carrier(Carrier::QueryParam(name.to_string()))
    }
}
/// Digest algorithm used to sign `HashedTokenAuth` tokens.
#[derive(Debug)]
pub enum HashAlgorithm {
    Md5,
    Sha1,
    Sha256,
}

impl HashAlgorithm {
    /// Hashes `input` with this algorithm using the crate's `digest` helpers.
    ///
    /// NOTE(review): the string encoding (presumably lowercase hex) is defined by
    /// `crate::digest`; confirm there.
    pub fn apply(&self, input: impl AsRef<[u8]>) -> String {
        match self {
            HashAlgorithm::Md5 => digest::md5(input),
            HashAlgorithm::Sha1 => digest::sha1(input),
            HashAlgorithm::Sha256 => digest::sha256(input),
        }
    }
}
/// Lenient conversion from an algorithm name; see the `&str` impl.
impl From<String> for HashAlgorithm {
    fn from(s: String) -> Self {
        Self::from(s.as_str())
    }
}

/// Case-insensitive lookup of "md5", "sha1" or "sha256".
/// Any unrecognised name falls back to `Sha1`.
impl From<&str> for HashAlgorithm {
    fn from(s: &str) -> Self {
        let name = s.to_lowercase();
        if name == "md5" {
            Self::Md5
        } else if name == "sha256" {
            Self::Sha256
        } else {
            // "sha1" itself, and the fallback for anything unknown.
            Self::Sha1
        }
    }
}
/// Authenticator whose token is a signature over `app_id`, `app_secret` and the
/// current Unix time, base64-encoded together with the identifying fields.
pub struct HashedTokenAuth {
    client_id: Option<String>,
    app_id: String,
    app_secret: String,
    algorithm: HashAlgorithm,
    carrier: Carrier,
}

/// Manual `Debug` so `app_secret` is never printed. A derived impl would leak
/// the secret into logs, whereas `AccessToken` already redacts its token.
impl std::fmt::Debug for HashedTokenAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HashedTokenAuth")
            .field("client_id", &self.client_id)
            .field("app_id", &self.app_id)
            .field("app_secret", &format_args!("<redacted>"))
            .field("algorithm", &self.algorithm)
            .field("carrier", &self.carrier)
            .finish()
    }
}

impl HashedTokenAuth {
    /// Creates an authenticator using [`HashAlgorithm::Sha1`] and no client id.
    pub fn new<S: ToString>(app_id: S, app_secret: S) -> Self {
        Self::new_with_algorithm(app_id, app_secret, HashAlgorithm::Sha1)
    }

    /// Creates an authenticator with the given algorithm and no client id.
    pub fn new_with_algorithm<S: ToString>(
        app_id: S,
        app_secret: S,
        algorithm: HashAlgorithm,
    ) -> Self {
        Self {
            client_id: None,
            app_id: app_id.to_string(),
            app_secret: app_secret.to_string(),
            algorithm,
            carrier: Carrier::default(),
        }
    }

    /// Creates an authenticator whose tokens are prefixed with `client_id`.
    /// An empty `client_id` is treated as "no client id".
    pub fn new_with_client_id<S: ToString>(
        client_id: S,
        app_id: S,
        app_secret: S,
        algorithm: HashAlgorithm,
    ) -> Self {
        Self {
            client_id: Some(client_id.to_string()).filter(|id| !id.is_empty()),
            app_id: app_id.to_string(),
            app_secret: app_secret.to_string(),
            algorithm,
            carrier: Carrier::default(),
        }
    }

    /// Builds the token for the given Unix `timestamp` (seconds).
    ///
    /// The pre-encoding layout is `[client_id,]app_id,timestamp,sign`, where
    /// `sign = algorithm(app_id ++ app_secret ++ timestamp)`. The whole string is
    /// then base64-encoded.
    ///
    /// NOTE(review): fields are comma-separated without escaping, so an
    /// `app_id`/`client_id` containing ',' produces a token that
    /// `ParsedHashedToken::parse` cannot split correctly.
    fn generate_token_at(&self, timestamp: u64) -> String {
        let plain = format!("{}{}{}", &self.app_id, &self.app_secret, timestamp);
        let sign = self.algorithm.apply(plain);
        let compose = match &self.client_id {
            Some(client_id) => format!("{},{},{},{}", client_id, &self.app_id, timestamp, sign),
            None => format!("{},{},{}", &self.app_id, timestamp, sign),
        };
        digest::encode_base64(compose)
    }
}
#[async_trait]
impl ApiAuthenticator for HashedTokenAuth {
fn get_carrier(&self) -> &Carrier {
&self.carrier
}
}
#[async_trait]
impl TokenGenerator for HashedTokenAuth {
async fn generate_token(&self, _req: &Request) -> Result<String, reqwest_middleware::Error> {
let timestamp = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
Ok(self.generate_token_at(timestamp))
}
}
/// Each call replaces the previously configured carrier.
impl WithCarrier for HashedTokenAuth {
    fn with_carrier(mut self, carrier: Carrier) -> Self {
        self.carrier = carrier;
        self
    }

    fn with_header_name(self, name: impl ToString) -> Self {
        self.with_carrier(Carrier::Header(name.to_string()))
    }

    fn with_query_param(self, name: impl ToString) -> Self {
        self.with_carrier(Carrier::QueryParam(name.to_string()))
    }
}
/// Reasons a hashed token could not be decoded by `ParsedHashedToken::parse`.
#[derive(Debug, Error)]
pub enum TokenError {
    /// The token is not valid base64.
    #[error("{0}")]
    Base64(#[from] DecodeError),
    /// The decoded bytes are not valid UTF-8.
    #[error("{0}")]
    Utf8(#[from] FromUtf8Error),
    /// Empty input, or the decoded text does not have 3 or 4 comma-separated fields.
    #[error("Invalid format")]
    Format,
    /// The timestamp field is not an unsigned integer.
    #[error("{0}")]
    Timestamp(#[from] ParseIntError),
}

/// The decoded fields of a hashed token.
#[derive(Debug)]
pub struct ParsedHashedToken {
    /// Present only in the 4-field form.
    pub client_id: Option<String>,
    pub app_id: String,
    /// Unix time (seconds) at which the token was generated.
    pub timestamp: u64,
    /// Digest string carried in the token; check it with `is_signed`.
    pub sign: String,
}
impl ParsedHashedToken {
    /// Decodes a base64 token of the form `client_id,app_id,timestamp,sign` or
    /// `app_id,timestamp,sign`.
    ///
    /// # Errors
    /// - [`TokenError::Format`] if the input is empty or does not have 3 or 4 fields.
    /// - [`TokenError::Base64`] / [`TokenError::Utf8`] if decoding fails.
    /// - [`TokenError::Timestamp`] if the timestamp field is not a `u64`.
    pub fn parse(token: impl AsRef<[u8]>) -> Result<Self, TokenError> {
        let token = token.as_ref();
        if token.is_empty() {
            return Err(TokenError::Format);
        }
        let composed = decode_base64(token)
            .map_err(TokenError::Base64)
            .and_then(|bytes| String::from_utf8(bytes).map_err(TokenError::Utf8))?;
        let fields: Vec<&str> = composed.split(',').collect();
        let (client_id, app_id, timestamp, sign) = match fields.as_slice() {
            [client_id, app_id, timestamp, sign] => {
                (Some(*client_id), *app_id, *timestamp, *sign)
            }
            [app_id, timestamp, sign] => (None, *app_id, *timestamp, *sign),
            _ => return Err(TokenError::Format),
        };
        Ok(Self {
            client_id: client_id.map(String::from),
            app_id: app_id.to_string(),
            timestamp: timestamp.parse().map_err(TokenError::Timestamp)?,
            sign: sign.to_string(),
        })
    }

    /// Returns `true` if the token should be rejected based on its timestamp.
    ///
    /// A token is rejected when it is older than `expires_in_secs`, or when it
    /// claims to come from the future. Both checks allow `deviation` seconds of
    /// clock skew (default 60).
    ///
    /// If the local clock reads earlier than the Unix epoch, "now" is taken as 0.
    /// That rejects any token stamped more than `deviation` seconds after the
    /// epoch, so it fails closed instead of panicking.
    pub fn is_expired(&self, expires_in_secs: u64, deviation: Option<u64>) -> bool {
        let now = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        self.is_expired_at(now, expires_in_secs, deviation.unwrap_or(60))
    }

    /// Expiry check against an explicit `now` (Unix seconds).
    ///
    /// Uses saturating unsigned arithmetic, so extreme inputs neither overflow
    /// nor wrap into a wrong verdict. Examples: a token timestamp above `i64::MAX`
    /// (the token is untrusted input), or `expires_in_secs = u64::MAX`.
    fn is_expired_at(&self, now: u64, expires_in_secs: u64, deviation: u64) -> bool {
        let from_future = self.timestamp > now.saturating_add(deviation);
        let too_old =
            now.saturating_sub(self.timestamp) > expires_in_secs.saturating_add(deviation);
        from_future || too_old
    }

    /// Returns `true` if `sign` matches `algorithm(app_id ++ app_secret ++ timestamp)`.
    pub fn is_signed<S, A>(&self, app_secret: S, algorithm: A) -> bool
    where
        S: std::fmt::Display,
        A: Into<HashAlgorithm>,
    {
        let plain = format!("{}{}{}", self.app_id, app_secret, self.timestamp);
        let algorithm: HashAlgorithm = algorithm.into();
        let expected = algorithm.apply(plain);
        // Constant-time comparison: a short-circuiting `==` would let response
        // timing reveal how many leading characters of a forged signature match.
        Self::constant_time_eq(expected.as_bytes(), self.sign.as_bytes())
    }

    /// Compares two byte strings in time that depends only on their length.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    }
}
/// Parses a token string; equivalent to `ParsedHashedToken::parse`.
impl TryFrom<&str> for ParsedHashedToken {
    type Error = TokenError;
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        Self::parse(value.as_bytes())
    }
}

/// Parses an owned token string; equivalent to `ParsedHashedToken::parse`.
impl TryFrom<String> for ParsedHashedToken {
    type Error = TokenError;
    fn try_from(value: String) -> Result<Self, Self::Error> {
        Self::parse(value.as_bytes())
    }
}