#![warn(missing_docs)]
use anyhow::anyhow;
use http::Extensions;
use reqwest_middleware::reqwest::header::HeaderValue;
use reqwest_middleware::reqwest::header::AUTHORIZATION;
use reqwest_middleware::reqwest::Request;
use reqwest_middleware::reqwest::Response;
use reqwest_middleware::Error;
use reqwest_middleware::Middleware;
use reqwest_middleware::Next;
use std::sync::Arc;
use token_source::TokenSource;
/// [`Middleware`] that sets the `Authorization` header on every outgoing
/// request, using a value obtained from a [`TokenSource`].
///
/// A token is requested from the source each time a request passes through
/// the middleware; whether tokens are cached or refreshed is up to the
/// [`TokenSource`] implementation.
///
/// Construct it with [`From`] from either an `Arc<dyn TokenSource>` or a
/// `Box<dyn TokenSource>`.
#[derive(Clone)]
pub struct AuthorizationHeaderMiddleware {
    // Shared so the middleware (and its clones) can reuse one token source.
    ts: Arc<dyn TokenSource>,
}
impl From<Arc<dyn TokenSource>> for AuthorizationHeaderMiddleware {
fn from(ts: Arc<dyn TokenSource>) -> Self {
Self { ts }
}
}
impl From<Box<dyn TokenSource>> for AuthorizationHeaderMiddleware {
fn from(ts: Box<dyn TokenSource>) -> Self {
Self { ts: ts.into() }
}
}
#[async_trait::async_trait]
impl Middleware for AuthorizationHeaderMiddleware {
async fn handle(
&self,
mut req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> reqwest_middleware::Result<Response> {
let auth_token = self
.ts
.token()
.await
.map_err(|e| Error::Middleware(anyhow!(e.to_string())))?;
req.headers_mut().insert(
AUTHORIZATION,
HeaderValue::from_str(auth_token.as_str())
.map_err(|e| Error::Middleware(anyhow!(format!("Invalid auth token value: {e}"))))?,
);
next.run(req, extensions).await
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use http::Extensions;
    use reqwest_middleware::reqwest;
    use reqwest_middleware::reqwest::header::HeaderValue;
    use reqwest_middleware::reqwest::header::AUTHORIZATION;
    use reqwest_middleware::reqwest::Request;
    use reqwest_middleware::reqwest::Response;
    use reqwest_middleware::ClientBuilder;
    use reqwest_middleware::Middleware;
    use reqwest_middleware::Next;
    use token_source::{TokenSource, TokenSourceProvider};

    use super::AuthorizationHeaderMiddleware;

    // Never actually contacted: `VerificationMiddleware` answers every request.
    const TEST_URL: &str = "https://example.com/resource";

    /// Token source that always yields the same token.
    #[derive(Debug)]
    struct MyTokenSource {
        pub token: String,
    }

    #[async_trait::async_trait]
    impl TokenSource for MyTokenSource {
        async fn token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
            Ok(self.token.clone())
        }
    }

    /// Token source that always fails.
    #[derive(Debug)]
    struct FailingTokenSource;

    #[async_trait::async_trait]
    impl TokenSource for FailingTokenSource {
        async fn token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
            Err("token unavailable".into())
        }
    }

    #[derive(Debug)]
    struct MyTokenProvider {
        pub ts: Arc<MyTokenSource>,
    }

    impl TokenSourceProvider for MyTokenProvider {
        fn token_source(&self) -> Arc<dyn TokenSource> {
            self.ts.clone()
        }
    }

    /// Terminal middleware: asserts the `Authorization` header and then answers
    /// with a stub response instead of calling `next`, so no test touches the
    /// network.
    struct VerificationMiddleware {
        expected: &'static str,
    }

    #[async_trait::async_trait]
    impl Middleware for VerificationMiddleware {
        async fn handle(
            &self,
            req: Request,
            _extensions: &mut Extensions,
            _next: Next<'_>,
        ) -> reqwest_middleware::Result<Response> {
            let token_value = req
                .headers()
                .get(AUTHORIZATION)
                .expect("Authorization header should be set");
            assert_eq!(token_value, &HeaderValue::from_static(self.expected));
            Ok(Response::from(http::Response::new("")))
        }
    }

    /// Builds a client running `auth` followed by the verifying stub.
    fn client_with(
        auth: AuthorizationHeaderMiddleware,
        expected: &'static str,
    ) -> reqwest_middleware::ClientWithMiddleware {
        ClientBuilder::new(reqwest::Client::default())
            .with(auth)
            .with(VerificationMiddleware { expected })
            .build()
    }

    #[async_std::test]
    async fn test_middleware() {
        let token_value = "Bearer my-token";
        let ts_provider = MyTokenProvider {
            ts: Arc::new(MyTokenSource {
                token: token_value.to_string(),
            }),
        };
        let auth_middleware = AuthorizationHeaderMiddleware::from(ts_provider.token_source());
        let client = client_with(auth_middleware, token_value);
        client
            .get(TEST_URL)
            .send()
            .await
            .expect("stubbed request should succeed");
    }

    #[async_std::test]
    async fn test_from_boxed_token_source() {
        let token_value = "Bearer boxed-token";
        let boxed: Box<dyn TokenSource> = Box::new(MyTokenSource {
            token: token_value.to_string(),
        });
        let client = client_with(AuthorizationHeaderMiddleware::from(boxed), token_value);
        client
            .get(TEST_URL)
            .send()
            .await
            .expect("stubbed request should succeed");
    }

    #[async_std::test]
    async fn test_replaces_existing_authorization_header() {
        let token_value = "Bearer fresh-token";
        let ts: Arc<dyn TokenSource> = Arc::new(MyTokenSource {
            token: token_value.to_string(),
        });
        let client = client_with(AuthorizationHeaderMiddleware::from(ts), token_value);
        client
            .get(TEST_URL)
            .header(AUTHORIZATION, "Bearer stale-token")
            .send()
            .await
            .expect("stubbed request should succeed");
    }

    #[async_std::test]
    async fn test_invalid_token_value_is_rejected() {
        let ts: Arc<dyn TokenSource> = Arc::new(MyTokenSource {
            token: "Bearer bad\ntoken".to_string(),
        });
        let client = client_with(AuthorizationHeaderMiddleware::from(ts), "unused");
        let err = client
            .get(TEST_URL)
            .send()
            .await
            .expect_err("a token with a newline is not a valid header value");
        assert!(matches!(err, reqwest_middleware::Error::Middleware(_)));
    }

    #[async_std::test]
    async fn test_token_source_error_is_reported() {
        let ts: Arc<dyn TokenSource> = Arc::new(FailingTokenSource);
        let client = client_with(AuthorizationHeaderMiddleware::from(ts), "unused");
        let err = client
            .get(TEST_URL)
            .send()
            .await
            .expect_err("a failing token source should abort the request");
        assert!(matches!(err, reqwest_middleware::Error::Middleware(_)));
        assert!(err.to_string().contains("token unavailable"));
    }
}