use std::{sync::Arc, time::SystemTime};
use async_trait::async_trait;
use reqwest::{
header::{HeaderName, HeaderValue, AUTHORIZATION},
Request, Response,
};
use reqwest_middleware::Next;
use crate::{digest, Extensions, Middleware};
/// Middleware that signs each outgoing request before passing it down the chain.
///
/// The signer is looked up in the per-request `Extensions` as an `Arc<dyn ApiSignature>`;
/// when none is registered the request is forwarded untouched.
#[derive(Default)]
pub(crate) struct SignatureMiddleware;

#[async_trait]
impl Middleware for SignatureMiddleware {
    /// Signs `req` when a signer is registered, then hands it to `next`.
    ///
    /// # Errors
    /// Returns the signer's error without calling `next`; otherwise returns whatever the
    /// rest of the middleware chain returns.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> Result<Response, reqwest_middleware::Error> {
        let req = match extensions.get::<Arc<dyn ApiSignature>>() {
            Some(signature) => signature.sign(req).await?,
            None => req,
        };
        next.run(req, extensions).await
    }
}
/// Source of authentication tokens for outgoing requests.
///
/// Any `Fn() -> Result<T, reqwest_middleware::Error>` closure with `T: ToString` also
/// implements this trait, so simple providers can be supplied as closures.
#[async_trait]
pub trait TokenProvider: Send + Sync + 'static {
    /// Produces the token to attach to the request currently being signed.
    ///
    /// # Errors
    /// Implementation-defined; when used by the signing middleware, an error aborts the
    /// request before it is sent.
    async fn generate_token(&self) -> Result<String, reqwest_middleware::Error>;
}
/// Lets a plain closure act as a [`TokenProvider`].
///
/// The closure is invoked once per token request; its `Ok` value is converted with
/// `ToString` and its error is returned unchanged.
#[async_trait]
impl<F, T> TokenProvider for F
where
    F: Fn() -> Result<T, reqwest_middleware::Error>,
    F: Send + Sync + 'static,
    T: ToString,
{
    async fn generate_token(&self) -> Result<String, reqwest_middleware::Error> {
        let token = self()?;
        Ok(token.to_string())
    }
}
#[async_trait]
pub trait ApiSignature: TokenProvider + std::fmt::Debug {
fn get_carrier(&self) -> &Carrier;
async fn sign(&self, req: Request) -> Result<Request, reqwest_middleware::Error> {
let token = self.generate_token().await?;
Ok(self.get_carrier().apply(req, token))
}
}
pub trait SignatureTrait {
fn with_header_name(self, name: impl ToString) -> Self;
fn with_query_param(self, name: impl ToString) -> Self;
}
#[derive(Debug)]
pub enum Carrier {
BearerAuth,
Header(String),
QueryParam(String),
}
impl Default for Carrier {
fn default() -> Self {
Self::BearerAuth
}
}
impl Carrier {
fn apply(&self, req: Request, token: impl ToString) -> Request {
let mut req = req;
let token = token.to_string();
match self {
Carrier::BearerAuth => {
req.headers_mut().remove(AUTHORIZATION);
req.headers_mut().append(
AUTHORIZATION,
HeaderValue::try_from(format!("Bearer {}", token)).unwrap(),
);
}
Carrier::Header(name) => {
req.headers_mut().append(
HeaderName::try_from(name.as_str()).unwrap(),
HeaderValue::try_from(token).unwrap(),
);
}
Carrier::QueryParam(name) => {
req.url_mut()
.query_pairs_mut()
.append_pair(name.as_str(), &token);
}
}
req
}
}
/// The source of an access-token signer's token.
pub enum AccessToken {
    /// A token fixed at construction time.
    Fixed(String),
    /// A provider asked for a token each time one is generated.
    Dynamic(Arc<dyn TokenProvider>),
}

/// Prints only the variant name, so the token itself never ends up in logs.
impl std::fmt::Debug for AccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::Fixed(_) => "Fixed",
            Self::Dynamic(_) => "Dynamic",
        };
        f.debug_tuple(variant).finish()
    }
}
/// Signs requests with an access token, either fixed or produced on demand.
///
/// Both constructors start with the default carrier (`Authorization: Bearer <token>`).
#[derive(Debug)]
pub struct AccessTokenSignature {
    access_token: AccessToken,
    carrier: Carrier,
}

impl AccessTokenSignature {
    /// Creates a signer that always sends `access_token`.
    pub fn new(access_token: impl ToString) -> Self {
        let token = AccessToken::Fixed(access_token.to_string());
        Self::from_access_token(token)
    }

    /// Creates a signer that asks `provider` for a token every time one is generated.
    pub fn new_dynamic(provider: impl TokenProvider) -> Self {
        let token = AccessToken::Dynamic(Arc::new(provider));
        Self::from_access_token(token)
    }

    /// Pairs a token source with the default carrier.
    fn from_access_token(access_token: AccessToken) -> Self {
        Self {
            access_token,
            carrier: Carrier::default(),
        }
    }
}
#[async_trait]
impl ApiSignature for AccessTokenSignature {
    /// Returns the configured carrier (`BearerAuth` unless changed via `SignatureTrait`).
    fn get_carrier(&self) -> &Carrier {
        let Self { carrier, .. } = self;
        carrier
    }
}
#[async_trait]
impl TokenProvider for AccessTokenSignature {
    /// Returns a copy of the fixed token, or delegates to the dynamic provider.
    async fn generate_token(&self) -> Result<String, reqwest_middleware::Error> {
        let provider = match &self.access_token {
            AccessToken::Fixed(token) => return Ok(token.to_owned()),
            AccessToken::Dynamic(provider) => provider,
        };
        provider.generate_token().await
    }
}
impl SignatureTrait for AccessTokenSignature {
    /// Switches to sending the raw token in header `name`; the token source is kept.
    fn with_header_name(mut self, name: impl ToString) -> Self {
        self.carrier = Carrier::Header(name.to_string());
        self
    }

    /// Switches to sending the raw token in query parameter `name`; the token source is kept.
    fn with_query_param(mut self, name: impl ToString) -> Self {
        self.carrier = Carrier::QueryParam(name.to_string());
        self
    }
}
/// Digest algorithm used by the hashed-token signer.
#[derive(Debug)]
pub enum HashAlgorithm {
    /// Dispatches to `digest::md5`.
    Md5,
    /// Dispatches to `digest::sha1`.
    Sha1,
    /// Dispatches to `digest::sha256`.
    Sha256,
}

impl HashAlgorithm {
    /// Hashes `input` with the selected algorithm.
    ///
    /// Returns whatever string the matching `digest` helper produces
    /// (NOTE(review): presumably a hex digest — confirm in `digest`).
    pub fn apply(&self, input: impl AsRef<[u8]>) -> String {
        match *self {
            HashAlgorithm::Md5 => digest::md5(input),
            HashAlgorithm::Sha1 => digest::sha1(input),
            HashAlgorithm::Sha256 => digest::sha256(input),
        }
    }
}
impl From<String> for HashAlgorithm {
fn from(s: String) -> Self {
s.as_str().into()
}
}
impl From<&str> for HashAlgorithm {
    /// Parses an algorithm name, ignoring case.
    ///
    /// Recognises `"md5"`, `"sha1"` and `"sha256"`. Any other input falls back to
    /// [`HashAlgorithm::Sha1`], kept for backward compatibility with earlier behaviour.
    fn from(s: &str) -> Self {
        match s.to_lowercase().as_str() {
            "md5" => Self::Md5,
            "sha1" => Self::Sha1,
            // Previously missing: "sha256" silently fell through to Sha1.
            "sha256" => Self::Sha256,
            _ => Self::Sha1,
        }
    }
}
/// Signs requests with a timestamped hash of an application id and secret.
///
/// The token is built by `generate_token_at`. It is sent as `Authorization: Bearer <token>`
/// unless a different carrier is configured.
pub struct HashedTokenSignature {
    /// Optional client id embedded in the token; `None` when absent or given as empty.
    client_id: Option<String>,
    app_id: String,
    /// Shared secret mixed into the hash; never printed by `Debug`.
    app_secret: String,
    algorithm: HashAlgorithm,
    carrier: Carrier,
}

/// Redacts `app_secret` so signers can be logged without leaking credentials.
impl std::fmt::Debug for HashedTokenSignature {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HashedTokenSignature")
            .field("client_id", &self.client_id)
            .field("app_id", &self.app_id)
            .field("app_secret", &"<redacted>")
            .field("algorithm", &self.algorithm)
            .field("carrier", &self.carrier)
            .finish()
    }
}

impl HashedTokenSignature {
    /// Creates a signer using SHA-1 and no client id.
    pub fn new<S: ToString>(app_id: S, app_secret: S) -> Self {
        Self::new_with_algorithm(app_id, app_secret, HashAlgorithm::Sha1)
    }

    /// Creates a signer with an explicit hash algorithm and no client id.
    pub fn new_with_algorithm<S: ToString>(
        app_id: S,
        app_secret: S,
        algorithm: HashAlgorithm,
    ) -> Self {
        Self {
            client_id: None,
            app_id: app_id.to_string(),
            app_secret: app_secret.to_string(),
            algorithm,
            carrier: Carrier::default(),
        }
    }

    /// Creates a signer that also embeds `client_id` in the token.
    ///
    /// An empty `client_id` is treated as absent.
    pub fn new_with_client_id<S: ToString>(
        client_id: S,
        app_id: S,
        app_secret: S,
        algorithm: HashAlgorithm,
    ) -> Self {
        let client_id = client_id.to_string();
        Self {
            client_id: if client_id.is_empty() {
                None
            } else {
                Some(client_id)
            },
            app_id: app_id.to_string(),
            app_secret: app_secret.to_string(),
            algorithm,
            carrier: Carrier::default(),
        }
    }

    /// Builds the token for `timestamp`.
    ///
    /// `hash = algorithm(app_id ++ app_secret ++ timestamp)`. The token is
    /// `digest::encode_base64` applied to `"client_id,app_id,timestamp,hash"`, or to
    /// `"app_id,timestamp,hash"` when there is no client id.
    fn generate_token_at(&self, timestamp: u64) -> String {
        let plain = format!("{}{}{}", self.app_id, self.app_secret, timestamp);
        let sign = self.algorithm.apply(plain);
        let compose = match &self.client_id {
            Some(client_id) => format!("{},{},{},{}", client_id, self.app_id, timestamp, sign),
            None => format!("{},{},{}", self.app_id, timestamp, sign),
        };
        digest::encode_base64(compose)
    }
}
#[async_trait]
impl ApiSignature for HashedTokenSignature {
    /// Returns the configured carrier (`BearerAuth` unless changed via `SignatureTrait`).
    fn get_carrier(&self) -> &Carrier {
        let Self { carrier, .. } = self;
        carrier
    }
}
#[async_trait]
impl TokenProvider for HashedTokenSignature {
async fn generate_token(&self) -> Result<String, reqwest_middleware::Error> {
let timestamp = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
Ok(self.generate_token_at(timestamp))
}
}
impl SignatureTrait for HashedTokenSignature {
    /// Switches to sending the raw token in header `name`; credentials are kept.
    fn with_header_name(mut self, name: impl ToString) -> Self {
        self.carrier = Carrier::Header(name.to_string());
        self
    }

    /// Switches to sending the raw token in query parameter `name`; credentials are kept.
    fn with_query_param(mut self, name: impl ToString) -> Self {
        self.carrier = Carrier::QueryParam(name.to_string());
        self
    }
}