use std::sync::Arc;
use async_trait::async_trait;
use reqwest::RequestBuilder;
use crate::error::AuthError;
/// Applies outbound credentials to an HTTP request before it is sent.
///
/// An implementation receives the request builder and hands it back, possibly with extra
/// headers attached. The `Err` arm lets fallible providers report an [`AuthError`]; the
/// implementations in this module always return `Ok`.
#[async_trait]
pub trait OutboundAuthProvider: Send + Sync {
    /// Returns `request` with this provider's authentication applied.
    ///
    /// `audience` identifies the destination the credentials are meant for; providers that
    /// use one credential everywhere may ignore it.
    /// NOTE(review): the tests pass the target server's base URL here — confirm the expected
    /// format with real callers.
    ///
    /// # Errors
    /// Returns an [`AuthError`] when the provider cannot authorize the request.
    async fn authorize(&self, request: RequestBuilder, audience: &str)
        -> Result<RequestBuilder, AuthError>;
}
/// Provider that leaves outgoing requests untouched (no credentials are attached).
#[derive(Clone, Copy, Debug, Default)]
pub struct NoOpOutboundAuthProvider;
#[async_trait]
impl OutboundAuthProvider for NoOpOutboundAuthProvider {
    /// Hands `request` back unchanged; `audience` is deliberately unused.
    async fn authorize(
        &self,
        request: RequestBuilder,
        _audience: &str,
    ) -> Result<RequestBuilder, AuthError> {
        // Nothing to attach, so the builder goes straight back to the caller.
        let untouched = request;
        Result::Ok(untouched)
    }
}
/// Outbound auth provider that attaches one fixed bearer token to every request.
///
/// `Debug` is implemented by hand rather than derived, so that formatting the provider with
/// `{:?}` (in logs, panic messages or tracing spans) never reveals the secret token.
#[derive(Clone)]
pub struct StaticBearerOutboundAuthProvider {
    token: String,
}

impl std::fmt::Debug for StaticBearerOutboundAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticBearerOutboundAuthProvider")
            .field("token", &"<redacted>")
            .finish()
    }
}
impl StaticBearerOutboundAuthProvider {
    /// Creates a provider that sends `token` as the bearer credential on every request.
    pub fn new(token: impl Into<String>) -> Self {
        let token: String = token.into();
        Self { token }
    }
}
#[async_trait]
impl OutboundAuthProvider for StaticBearerOutboundAuthProvider {
    /// Adds `Authorization: Bearer <token>` to `request`.
    ///
    /// `audience` is ignored because the same token is used for every destination.
    /// [`RequestBuilder::bearer_auth`] builds the same header value as a manual
    /// `format!("Bearer {}")`, and it also marks the header as sensitive, so the token is not
    /// printed when the request or its headers are `Debug`-formatted.
    async fn authorize(
        &self,
        request: RequestBuilder,
        _audience: &str,
    ) -> Result<RequestBuilder, AuthError> {
        Ok(request.bearer_auth(&self.token))
    }
}
/// Chooses an outbound auth provider for an optional static token.
///
/// Leading and trailing whitespace is trimmed from `token` first. Bearer tokens cannot
/// contain whitespace, and a stray trailing newline (e.g. from a file or env var) would
/// otherwise produce an invalid `Authorization` header value that fails at send time.
///
/// Returns a [`NoOpOutboundAuthProvider`] when `token` is `None` or blank. Otherwise it
/// returns a [`StaticBearerOutboundAuthProvider`] holding the trimmed token.
pub fn provider_from_token(token: Option<&str>) -> Arc<dyn OutboundAuthProvider> {
    match token.map(str::trim).filter(|t| !t.is_empty()) {
        Some(t) => Arc::new(StaticBearerOutboundAuthProvider::new(t)),
        None => Arc::new(NoOpOutboundAuthProvider),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::Client;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Sends one `POST /y` through `provider` to a permissive mock server. Asserts the
    /// request succeeds and that the single request the server recorded carried no
    /// `Authorization` header.
    async fn assert_sends_without_authorization(provider: &dyn OutboundAuthProvider) {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/y"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let request = Client::new().post(format!("{}/y", server.uri()));
        let request = provider.authorize(request, &server.uri()).await.unwrap();
        let response = request.send().await.unwrap();
        assert!(response.status().is_success());
        let received = server.received_requests().await.unwrap();
        assert_eq!(received.len(), 1, "expected exactly one recorded request");
        assert!(received[0].headers.get("authorization").is_none());
    }

    #[tokio::test]
    async fn noop_provider_does_not_modify_request() {
        assert_sends_without_authorization(&NoOpOutboundAuthProvider).await;
    }

    #[tokio::test]
    async fn static_bearer_provider_adds_authorization_header() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/x"))
            .and(header("Authorization", "Bearer test-token"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let provider = StaticBearerOutboundAuthProvider::new("test-token");
        let client = Client::new();
        let request = client.post(format!("{}/x", server.uri()));
        let request = provider.authorize(request, &server.uri()).await.unwrap();
        let response = request.send().await.unwrap();
        assert!(
            response.status().is_success(),
            "request reached the matcher with bearer token"
        );
    }

    #[tokio::test]
    async fn static_bearer_appends_alongside_existing_headers() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/x"))
            .and(header("X-Custom", "value"))
            .and(header("Authorization", "Bearer abc"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let provider = StaticBearerOutboundAuthProvider::new("abc");
        let client = Client::new();
        let request = client
            .post(format!("{}/x", server.uri()))
            .header("X-Custom", "value");
        let request = provider.authorize(request, &server.uri()).await.unwrap();
        let response = request.send().await.unwrap();
        assert!(response.status().is_success());
    }

    #[tokio::test]
    async fn provider_from_token_returns_noop_when_none() {
        // A `None` token must not cause any Authorization header to be sent.
        let provider = provider_from_token(None);
        assert_sends_without_authorization(&*provider).await;
    }

    #[tokio::test]
    async fn provider_from_token_returns_static_when_some() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/y"))
            .and(header("Authorization", "Bearer xyz"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let provider = provider_from_token(Some("xyz"));
        let client = Client::new();
        let request = client.post(format!("{}/y", server.uri()));
        let request = provider.authorize(request, &server.uri()).await.unwrap();
        assert!(request.send().await.unwrap().status().is_success());
    }

    #[tokio::test]
    async fn provider_from_token_treats_empty_string_as_none() {
        // Empty and whitespace-only tokens must both behave like "no token".
        for blank in ["", " ", "\t\n"] {
            let provider = provider_from_token(Some(blank));
            assert_sends_without_authorization(&*provider).await;
        }
    }
}