use std::{
future::Future,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
use tower::{Layer, Service};
use crate::{Body, Error};
/// A bearer credential together with an optional expiry deadline.
///
/// Expiry is tracked with a monotonic [`Instant`], so it is unaffected by wall-clock
/// adjustments after the token has been created.
#[derive(Clone, Debug)]
pub struct Token {
    /// The raw token string.
    pub value: String,
    /// Deadline at or after which [`Token::is_expired`] reports `true`; `None` = never expires.
    pub expires_at: Option<Instant>,
}

impl Token {
    /// Creates a token that never expires.
    pub fn permanent(value: impl Into<String>) -> Self {
        Self {
            value: value.into(),
            expires_at: None,
        }
    }

    /// Creates a token that expires `ttl` from now.
    ///
    /// If `now + ttl` cannot be represented as an [`Instant`] (e.g. `Duration::MAX`), the
    /// token is treated as never expiring rather than panicking on overflow.
    pub fn with_ttl(value: impl Into<String>, ttl: Duration) -> Self {
        Self {
            value: value.into(),
            expires_at: Instant::now().checked_add(ttl),
        }
    }

    /// Creates a token from an absolute expiry given in whole seconds since the Unix epoch.
    ///
    /// The wall-clock deadline is converted into a monotonic [`Instant`] relative to now.
    /// A deadline that is already in the past yields a token that is immediately expired.
    /// A deadline so far in the future that it overflows [`Instant`] yields a token that
    /// never expires, instead of panicking.
    pub fn with_expires_at_secs(value: impl Into<String>, expires_at_secs: u64) -> Self {
        // A system clock set before the epoch is treated as "now = 0".
        let now_secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // Past deadlines saturate to zero remaining time.
        let remaining = Duration::from_secs(expires_at_secs.saturating_sub(now_secs));
        Self {
            value: value.into(),
            expires_at: Instant::now().checked_add(remaining),
        }
    }

    /// Returns `true` once the expiry deadline has been reached; permanent tokens never expire.
    pub fn is_expired(&self) -> bool {
        self.expires_at.is_some_and(|exp| Instant::now() >= exp)
    }
}
/// Source of bearer tokens that may need to be obtained asynchronously.
///
/// The `Send + Sync` supertraits allow a provider to be shared as
/// `Arc<dyn BearerTokenProvider>`, which is how `AuthMethod::BearerProvider` stores it.
#[async_trait::async_trait]
pub trait BearerTokenProvider: Send + Sync {
/// Returns the token to use, or an error if none could be obtained.
async fn get_token(&self) -> Result<Token, Error>;
}
/// Provider that always hands out the same, never-expiring token.
struct StaticTokenProvider {
    /// The fixed token string returned on every call.
    token: String,
}

#[async_trait::async_trait]
impl BearerTokenProvider for StaticTokenProvider {
    /// Returns a permanent [`Token`] wrapping a copy of the stored string; never fails.
    async fn get_token(&self) -> Result<Token, Error> {
        let fixed: &str = &self.token;
        Ok(Token::permanent(fixed))
    }
}
/// An API-key header name together with its value.
///
/// NOTE(review): this type is not constructed or referenced anywhere in the visible code.
/// It has no public constructor or accessors, and `AuthMethod::ApiKey` carries the same two
/// fields inline. Confirm whether it is used elsewhere, or remove it.
pub struct ApiKeyConfig {
/// Name of the header the key is sent in.
header_name: HeaderName,
/// The key itself, stored unvalidated as a plain `String`.
header_value: String,
}
/// How [`AuthService`] authenticates each outgoing request.
pub enum AuthMethod {
/// Static bearer token, sent as `Authorization: Bearer <token>`.
Bearer {
token: String,
},
/// Asynchronous token provider.
///
/// The synchronous auth layer cannot await a provider, so requests using this variant are
/// rejected with an error (see `apply_auth`).
BearerProvider {
provider: Arc<dyn BearerTokenProvider>,
},
/// Arbitrary header carrying an API key, inserted as `header_name: header_value`.
ApiKey {
header_name: HeaderName,
header_value: String,
},
/// Caller-supplied closure that edits the request headers directly; an `Err` it returns
/// aborts the request.
Custom {
#[allow(clippy::type_complexity)]
applier: Arc<dyn Fn(&mut HeaderMap) -> Result<(), Error> + Send + Sync>,
},
}
impl AuthMethod {
    /// Authenticates with a static bearer token (`Authorization: Bearer <token>`).
    pub fn bearer(token: impl Into<String>) -> Self {
        let token = token.into();
        Self::Bearer { token }
    }

    /// Wraps an asynchronous token provider.
    ///
    /// Requests using the resulting method are rejected by the synchronous auth layer.
    #[deprecated(
        since = "2.5.0",
        note = "BearerProvider cannot be used synchronously. Use `cached_token_provider()` or hooks instead."
    )]
    pub fn bearer_provider(provider: Arc<dyn BearerTokenProvider>) -> Self {
        Self::BearerProvider { provider }
    }

    /// Builds an API-key method, returning `None` if `name` is not a valid header name.
    pub fn try_api_key(name: impl TryInto<HeaderName>, value: impl Into<String>) -> Option<Self> {
        let Ok(header_name) = name.try_into() else {
            return None;
        };
        Some(Self::api_key_with_name(header_name, value))
    }

    /// Builds an API-key method from an already-validated [`HeaderName`].
    pub fn api_key_with_name(name: HeaderName, value: impl Into<String>) -> Self {
        let header_value = value.into();
        Self::ApiKey {
            header_name: name,
            header_value,
        }
    }

    /// Builds an API-key method from a header name given as a string.
    ///
    /// # Panics
    ///
    /// Panics if `name` is not a valid HTTP header name. Use
    /// [`AuthMethod::try_api_key`] to handle that case without panicking.
    pub fn api_key(name: impl Into<String>, value: impl Into<String>) -> Self {
        let header_name =
            HeaderName::try_from(name.into()).expect("invalid header name for API key");
        Self::api_key_with_name(header_name, value)
    }

    /// Authenticates by running `f` against the request's headers before each request.
    pub fn custom<F>(f: F) -> Self
    where
        F: Fn(&mut HeaderMap) -> Result<(), Error> + Send + Sync + 'static,
    {
        let applier = Arc::new(f);
        Self::Custom { applier }
    }
}
/// Tower [`Layer`] that wraps services in an [`AuthService`].
///
/// Every service produced by the layer shares one reference-counted [`AuthMethod`].
#[derive(Clone)]
pub struct AuthLayer {
    /// Shared auth configuration handed to each wrapped service.
    auth: Arc<AuthMethod>,
}

impl AuthLayer {
    /// Creates a layer that applies `auth` to every request passing through it.
    pub fn new(auth: AuthMethod) -> Self {
        let shared = Arc::new(auth);
        Self { auth: shared }
    }
}
impl<S> Layer<S> for AuthLayer {
type Service = AuthService<S>;
fn layer(&self, inner: S) -> Self::Service {
AuthService {
inner,
auth: self.auth.clone(),
}
}
}
/// Service produced by [`AuthLayer`]: adds authentication headers to each request,
/// then forwards the request to `inner`.
#[derive(Clone)]
pub struct AuthService<S> {
/// The wrapped service that receives authenticated requests.
inner: S,
/// Auth configuration shared with the originating layer.
auth: Arc<AuthMethod>,
}
/// Boxed, `Send` future used as `AuthService`'s response future.
type BoxFut<T> = std::pin::Pin<Box<dyn Future<Output = T> + Send>>;
/// Applies the configured [`AuthMethod`] to each request, then forwards the request to the
/// inner service.
///
/// Inner-service errors, and errors from applying auth, are converted into
/// `crate::error::BoxError`.
impl<S, ResBody> Service<Request<Body>> for AuthService<S>
where
    S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Error: Into<crate::error::BoxError> + Send,
    S::Future: Send + 'static,
    ResBody: Send + 'static,
{
    type Response = Response<ResBody>;
    type Error = crate::error::BoxError;
    type Future = BoxFut<Result<Response<ResBody>, crate::error::BoxError>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        if let Err(e) = apply_auth(&self.auth, &mut req) {
            trace!("Auth failed: {}", e);
            return Box::pin(async move { Err(e.into()) });
        }
        trace!("Auth applied to {} {}", req.method(), req.uri());
        // Dispatch on `self.inner` itself, because that is the instance `poll_ready` just
        // drove to readiness. Calling a fresh clone would bypass that readiness: the clone
        // may not be ready, and any capacity reserved on the original handle is never used.
        let response = self.inner.call(req);
        Box::pin(async move { response.await.map_err(Into::into) })
    }
}
fn apply_auth(auth: &AuthMethod, req: &mut Request<Body>) -> Result<(), Error> {
match auth {
AuthMethod::Bearer { token } => {
req.headers_mut().insert(
http::header::AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| {
Error::builder(Box::<dyn std::error::Error + Send + Sync>::from(e))
})?,
);
}
AuthMethod::BearerProvider { .. } => {
return Err(Error::builder(
Box::<dyn std::error::Error + Send + Sync>::from(
"BearerProvider auth requires async token refresh. \
Use hooks or implement a cached token provider with ArcSwap.",
),
));
}
AuthMethod::ApiKey {
header_name,
header_value,
} => {
req.headers_mut().insert(
header_name.clone(),
HeaderValue::from_str(header_value).map_err(|e| {
Error::builder(Box::<dyn std::error::Error + Send + Sync>::from(e))
})?,
);
}
AuthMethod::Custom { applier } => {
applier(req.headers_mut())?;
}
}
Ok(())
}
/// [`BearerTokenProvider`] that caches the most recent token.
///
/// `refresh_fn` is invoked only when no token is cached or the cached one has expired.
///
/// `Fut` appears only in the bounds on `F`, so it is recorded with a `PhantomData` marker.
/// A type parameter not used by any field is rejected by the compiler (E0392). Using
/// `fn() -> Fut` keeps the struct `Send`/`Sync` independent of `Fut`, and covariant in `Fut`.
pub struct CachedTokenProvider<F, Fut> {
    /// Produces a fresh token whenever the cache is empty or stale.
    refresh_fn: F,
    /// Last token obtained from `refresh_fn`; `None` until the first refresh.
    token: Arc<tokio::sync::RwLock<Option<Token>>>,
    /// Ties the `Fut` type parameter to the struct without storing a future.
    _future: std::marker::PhantomData<fn() -> Fut>,
}

impl<F, Fut> CachedTokenProvider<F, Fut>
where
    F: Fn() -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<Token, Error>> + Send + 'static,
{
    /// Creates a provider with an empty cache.
    ///
    /// The first `get_token` call will invoke `refresh_fn`.
    pub fn new(refresh_fn: F) -> Self {
        Self {
            refresh_fn,
            token: Arc::new(tokio::sync::RwLock::new(None)),
            _future: std::marker::PhantomData,
        }
    }
}
#[async_trait::async_trait]
impl<F, Fut> BearerTokenProvider for CachedTokenProvider<F, Fut>
where
    F: Fn() -> Fut + Send + Sync,
    Fut: Future<Output = Result<Token, Error>> + Send,
{
    /// Returns the cached token if it is still valid; otherwise refreshes it.
    ///
    /// Refreshes are serialized. Callers that find the cache stale queue on the write lock,
    /// and only the first one invokes `refresh_fn`. The rest reuse the token it stored.
    ///
    /// # Errors
    ///
    /// Propagates the error from `refresh_fn`. The cache is left unchanged in that case.
    async fn get_token(&self) -> Result<Token, Error> {
        // Fast path: a shared read lock is enough when the cached token is still valid.
        {
            let cached = self.token.read().await;
            if let Some(token) = cached.as_ref().filter(|t| !t.is_expired()) {
                return Ok(token.clone());
            }
        }

        // Slow path: take the write lock so that at most one caller refreshes at a time.
        // A tokio lock guard may safely be held across the `.await` below.
        let mut slot = self.token.write().await;

        // Re-check: another caller may have refreshed while we waited for the write lock.
        if let Some(token) = slot.as_ref().filter(|t| !t.is_expired()) {
            return Ok(token.clone());
        }

        let new_token = (self.refresh_fn)().await?;
        *slot = Some(new_token.clone());
        Ok(new_token)
    }
}
pub fn cached_token_provider<F, Fut>(refresh_fn: F) -> CachedTokenProvider<F, Fut>
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Token, Error>> + Send + 'static,
{
CachedTokenProvider::new(refresh_fn)
}
#[cfg(test)]
mod tests {
    //! Unit tests for token expiry handling and the `AuthMethod` constructors.
    use super::*;

    #[test]
    fn test_token_permanent() {
        let token = Token::permanent("test");
        assert_eq!(token.value, "test");
        assert!(!token.is_expired());
        assert!(token.expires_at.is_none());
    }

    #[test]
    fn test_token_with_ttl() {
        let token = Token::with_ttl("test", Duration::from_secs(3600));
        assert_eq!(token.value, "test");
        assert!(!token.is_expired());
        assert!(token.expires_at.is_some());
    }

    #[test]
    fn test_token_expired() {
        let token = Token {
            value: "test".to_string(),
            expires_at: Some(Instant::now() - Duration::from_secs(1)),
        };
        assert!(token.is_expired());
    }

    #[test]
    fn test_token_with_expires_at_secs_in_past_is_expired() {
        // Epoch second 0 is always in the past, so the remaining lifetime saturates to zero.
        let token = Token::with_expires_at_secs("test", 0);
        assert_eq!(token.value, "test");
        assert!(token.is_expired());
    }

    #[test]
    fn test_token_with_expires_at_secs_in_future() {
        let now_secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let token = Token::with_expires_at_secs("test", now_secs + 3600);
        assert!(!token.is_expired());
        assert!(token.expires_at.is_some());
    }

    #[test]
    fn test_auth_method_bearer() {
        let auth = AuthMethod::bearer("my-token");
        assert!(matches!(auth, AuthMethod::Bearer { .. }));
    }

    #[test]
    fn test_auth_method_api_key() {
        let auth = AuthMethod::api_key("X-API-KEY", "key123");
        assert!(matches!(auth, AuthMethod::ApiKey { .. }));
    }

    #[test]
    fn test_auth_method_api_key_stores_value() {
        let auth = AuthMethod::api_key("X-API-KEY", "key123");
        let AuthMethod::ApiKey { header_value, .. } = auth else {
            panic!("expected ApiKey variant");
        };
        assert_eq!(header_value, "key123");
    }

    #[test]
    fn test_try_api_key() {
        // A space is not a legal header-name character.
        assert!(AuthMethod::try_api_key("bad header", "key123").is_none());
        assert!(matches!(
            AuthMethod::try_api_key("X-API-KEY", "key123"),
            Some(AuthMethod::ApiKey { .. })
        ));
    }

    #[test]
    fn test_auth_method_custom() {
        let auth = AuthMethod::custom(|_| Ok(()));
        assert!(matches!(auth, AuthMethod::Custom { .. }));
    }
}