use crate::client::http_middleware::{
HttpMiddleware, HttpMiddlewareContext, HttpRequest, HttpResponse,
};
use crate::error::{Error, Result};
use async_trait::async_trait;
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
/// An OAuth bearer credential together with an optional absolute expiry time.
///
/// `Debug` is implemented by hand (rather than derived) so that the secret
/// `token` is never written out by `{:?}` formatting (logs, panic messages,
/// debug dumps of enclosing structs). Use [`BearerToken::to_header_value`]
/// when the real credential is needed.
#[derive(Clone)]
pub struct BearerToken {
    /// Raw access-token string. Secret.
    pub token: String,
    /// Scheme written before the token in the `Authorization` header
    /// (both constructors set `"Bearer"`).
    pub token_type: String,
    /// Instant at which the token stops being valid; `None` means no known expiry.
    pub expires_at: Option<SystemTime>,
}

impl std::fmt::Debug for BearerToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BearerToken")
            // Never print the credential itself.
            .field("token", &"<redacted>")
            .field("token_type", &self.token_type)
            .field("expires_at", &self.expires_at)
            .finish()
    }
}

impl BearerToken {
    /// Creates a `Bearer` token with no expiry.
    pub fn new(token: String) -> Self {
        Self {
            token,
            token_type: "Bearer".to_string(),
            expires_at: None,
        }
    }

    /// Creates a `Bearer` token that expires `expires_in` from now.
    ///
    /// `expires_in` may come from a server response (e.g. an OAuth `expires_in`
    /// field), so it is not trusted. If `now + expires_in` cannot be represented
    /// as a `SystemTime`, the token is given no expiry (`expires_at == None`).
    /// The plain `+` operator would panic in that case.
    pub fn with_expiry(token: String, expires_in: Duration) -> Self {
        Self {
            token,
            token_type: "Bearer".to_string(),
            expires_at: SystemTime::now().checked_add(expires_in),
        }
    }

    /// Returns `true` once the current time has reached `expires_at`.
    ///
    /// Tokens without an expiry never expire.
    pub fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(expires_at) => SystemTime::now() >= expires_at,
            None => false,
        }
    }

    /// Returns `true` if the time remaining before `expires_at` is less than `threshold`.
    ///
    /// Also returns `true` when `expires_at` is already in the past. Tokens
    /// without an expiry return `false`.
    pub fn expires_soon(&self, threshold: Duration) -> bool {
        match self.expires_at {
            None => false,
            Some(expires_at) => match expires_at.duration_since(SystemTime::now()) {
                Ok(remaining) => remaining < threshold,
                // `duration_since` errors when `expires_at` is earlier than now.
                Err(_) => true,
            },
        }
    }

    /// Formats the value for an `Authorization` header, e.g. `"Bearer abc123"`.
    pub fn to_header_value(&self) -> String {
        format!("{} {}", self.token_type, self.token)
    }
}
/// HTTP client middleware that attaches a bearer token to outgoing requests.
///
/// The token lives behind `Arc<RwLock<_>>`, so `update_token` can replace it
/// through a shared `&self` reference.
pub struct OAuthClientMiddleware {
    /// When `true`, requests are rejected instead of sending a token that is
    /// expired or within `refresh_threshold` of expiring.
    check_expiry: bool,
    /// How close to its expiry a token may get before it is treated as needing refresh.
    refresh_threshold: Duration,
    /// The credential currently injected into requests.
    token: Arc<RwLock<BearerToken>>,
}
impl OAuthClientMiddleware {
pub fn new(token: BearerToken) -> Self {
Self {
token: Arc::new(RwLock::new(token)),
check_expiry: true,
refresh_threshold: Duration::from_secs(60), }
}
pub fn without_expiry_check(token: BearerToken) -> Self {
Self {
token: Arc::new(RwLock::new(token)),
check_expiry: false,
refresh_threshold: Duration::from_secs(60),
}
}
pub fn with_refresh_threshold(mut self, threshold: Duration) -> Self {
self.refresh_threshold = threshold;
self
}
pub fn update_token(&self, token: BearerToken) {
*self.token.write() = token;
}
pub fn get_token(&self) -> BearerToken {
self.token.read().clone()
}
fn needs_refresh(&self) -> bool {
if !self.check_expiry {
return false;
}
let token = self.token.read();
token.is_expired() || token.expires_soon(self.refresh_threshold)
}
}
impl std::fmt::Debug for OAuthClientMiddleware {
    /// Prints the configuration and whether the current token is expired.
    ///
    /// The token value itself is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Take the read lock only long enough to read the expiry flag.
        let token_expired = self.token.read().is_expired();

        let mut out = f.debug_struct("OAuthClientMiddleware");
        out.field("check_expiry", &self.check_expiry);
        out.field("refresh_threshold", &self.refresh_threshold);
        out.field("token_expired", &token_expired);
        out.finish()
    }
}
#[async_trait]
impl HttpMiddleware for OAuthClientMiddleware {
async fn on_request(
&self,
request: &mut HttpRequest,
context: &HttpMiddlewareContext,
) -> Result<()> {
if context.get_metadata("auth_already_set").is_some() {
tracing::debug!(
"Skipping OAuth middleware - auth already set by transport auth_provider"
);
return Ok(());
}
if request.has_header("Authorization") {
tracing::warn!(
"Authorization header already present - skipping OAuth middleware injection. \
Check for duplicate auth configuration."
);
return Ok(());
}
if self.needs_refresh() {
return Err(Error::authentication(
"OAuth token expired or expiring soon - refresh required",
));
}
let token = self.token.read();
request.add_header("Authorization", &token.to_header_value());
tracing::trace!("OAuth token injected into Authorization header");
Ok(())
}
async fn on_response(
&self,
response: &mut HttpResponse,
context: &HttpMiddlewareContext,
) -> Result<()> {
if response.status == 401 || response.status == 403 {
context.set_metadata("auth_failure".to_string(), "true".to_string());
context.set_metadata("status_code".to_string(), response.status.to_string());
if context.get_metadata("oauth.retry_used").is_some() {
tracing::warn!("Authentication failed after OAuth retry - token may be invalid");
}
return Err(Error::authentication(format!(
"Authentication failed with status {}",
response.status
)));
}
Ok(())
}
async fn on_error(&self, error: &Error, context: &HttpMiddlewareContext) -> Result<()> {
if matches!(error, Error::Authentication(_)) {
tracing::error!(
"OAuth authentication error for {} {}: {}",
context.method,
context.url,
error
);
if self.token.read().is_expired() {
tracing::error!("OAuth token was expired at time of error");
}
}
Ok(())
}
fn priority(&self) -> i32 {
10 }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty request against the shared test URL.
    fn test_request(method: &str) -> HttpRequest {
        HttpRequest::new(method.to_string(), "http://example.com".to_string(), vec![])
    }

    /// Builds a middleware context against the shared test URL.
    fn test_context(method: &str) -> HttpMiddlewareContext {
        HttpMiddlewareContext::new("http://example.com".to_string(), method.to_string())
    }

    /// A token whose expiry is already in the past.
    fn expired_token() -> BearerToken {
        BearerToken {
            token: "stale-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(SystemTime::now() - Duration::from_secs(10)),
        }
    }

    #[test]
    fn test_bearer_token_creation() {
        let token = BearerToken::new("test-token-123".to_string());
        assert_eq!(token.token, "test-token-123");
        assert_eq!(token.token_type, "Bearer");
        assert!(token.expires_at.is_none());
        assert!(!token.is_expired());
    }

    #[test]
    fn test_bearer_token_with_expiry() {
        let token = BearerToken::with_expiry("test-token".to_string(), Duration::from_secs(3600));
        assert!(!token.is_expired());
        assert!(!token.expires_soon(Duration::from_secs(120)));
    }

    #[test]
    fn test_bearer_token_past_expiry() {
        let token = expired_token();
        assert!(token.is_expired());
        assert!(token.expires_soon(Duration::from_secs(1)));
    }

    #[test]
    fn test_bearer_token_header_value() {
        let token = BearerToken::new("abc123".to_string());
        assert_eq!(token.to_header_value(), "Bearer abc123");
    }

    #[test]
    fn test_oauth_middleware_creation() {
        let token = BearerToken::new("test-token".to_string());
        let middleware = OAuthClientMiddleware::new(token);
        assert!(middleware.check_expiry);
    }

    #[test]
    fn test_oauth_middleware_token_update() {
        let token1 = BearerToken::new("token1".to_string());
        let middleware = OAuthClientMiddleware::new(token1);
        let token2 = BearerToken::new("token2".to_string());
        middleware.update_token(token2);
        let current = middleware.get_token();
        assert_eq!(current.token, "token2");
    }

    #[tokio::test]
    async fn test_oauth_middleware_injects_header() {
        let token = BearerToken::new("my-secret-token".to_string());
        let middleware = OAuthClientMiddleware::new(token);
        let mut request = test_request("POST");
        let context = test_context("POST");
        middleware.on_request(&mut request, &context).await.unwrap();
        assert_eq!(
            request.get_header("Authorization"),
            Some("Bearer my-secret-token")
        );
    }

    #[tokio::test]
    async fn test_oauth_middleware_rejects_expired_token() {
        let middleware = OAuthClientMiddleware::new(expired_token());
        let mut request = test_request("GET");
        let context = test_context("GET");
        let result = middleware.on_request(&mut request, &context).await;
        assert!(result.is_err());
        assert!(!request.has_header("Authorization"));
    }

    #[tokio::test]
    async fn test_oauth_middleware_without_expiry_check_injects_expired_token() {
        let middleware = OAuthClientMiddleware::without_expiry_check(expired_token());
        let mut request = test_request("GET");
        let context = test_context("GET");
        middleware.on_request(&mut request, &context).await.unwrap();
        assert_eq!(
            request.get_header("Authorization"),
            Some("Bearer stale-token")
        );
    }

    #[tokio::test]
    async fn test_oauth_middleware_refresh_threshold() {
        // 30 s left is inside the default 60 s threshold -> rejected.
        let soon = BearerToken::with_expiry("soon".to_string(), Duration::from_secs(30));
        let middleware = OAuthClientMiddleware::new(soon);
        let mut request = test_request("GET");
        let context = test_context("GET");
        assert!(middleware.on_request(&mut request, &context).await.is_err());

        // The same remaining lifetime is acceptable with a 5 s threshold.
        let soon = BearerToken::with_expiry("soon".to_string(), Duration::from_secs(30));
        let middleware =
            OAuthClientMiddleware::new(soon).with_refresh_threshold(Duration::from_secs(5));
        let mut request = test_request("GET");
        middleware.on_request(&mut request, &context).await.unwrap();
        assert_eq!(request.get_header("Authorization"), Some("Bearer soon"));
    }

    #[tokio::test]
    async fn test_oauth_middleware_keeps_existing_authorization_header() {
        let middleware = OAuthClientMiddleware::new(BearerToken::new("oauth".to_string()));
        let mut request = test_request("GET");
        request.add_header("Authorization", "Basic abc");
        let context = test_context("GET");
        middleware.on_request(&mut request, &context).await.unwrap();
        assert_eq!(request.get_header("Authorization"), Some("Basic abc"));
    }

    #[tokio::test]
    async fn test_oauth_middleware_skips_when_transport_set_auth() {
        let middleware = OAuthClientMiddleware::new(BearerToken::new("oauth".to_string()));
        let mut request = test_request("GET");
        let context = test_context("GET");
        context.set_metadata("auth_already_set".to_string(), "true".to_string());
        middleware.on_request(&mut request, &context).await.unwrap();
        assert!(!request.has_header("Authorization"));
    }

    #[tokio::test]
    async fn test_oauth_middleware_detects_401() {
        let token = BearerToken::new("token".to_string());
        let middleware = OAuthClientMiddleware::new(token);
        let mut response = HttpResponse::new(401, vec![]);
        let context = test_context("GET");
        let result = middleware.on_response(&mut response, &context).await;
        assert!(result.is_err());
        assert_eq!(
            context.get_metadata("auth_failure"),
            Some("true".to_string())
        );
    }

    #[tokio::test]
    async fn test_oauth_middleware_detects_403() {
        let middleware = OAuthClientMiddleware::new(BearerToken::new("token".to_string()));
        let mut response = HttpResponse::new(403, vec![]);
        let context = test_context("GET");
        let result = middleware.on_response(&mut response, &context).await;
        assert!(result.is_err());
        assert_eq!(
            context.get_metadata("status_code"),
            Some("403".to_string())
        );
    }

    #[tokio::test]
    async fn test_oauth_middleware_passes_through_success_response() {
        let middleware = OAuthClientMiddleware::new(BearerToken::new("token".to_string()));
        let mut response = HttpResponse::new(200, vec![]);
        let context = test_context("GET");
        assert!(middleware.on_response(&mut response, &context).await.is_ok());
        assert!(context.get_metadata("auth_failure").is_none());
    }
}