at-jet 0.7.2

High-performance HTTP + Protobuf API framework for mobile services
Documentation
//! Basic HTTP Authentication middleware
//!
//! Provides a simple username/password authentication layer for protecting admin endpoints.
//!
//! # Example
//!
//! ```rust,ignore
//! use at_jet::middleware::BasicAuthLayer;
//!
//! let app = Router::new()
//!     .route("/admin", get(admin_handler))
//!     .route("/health", get(health_handler))
//!     .layer(BasicAuthLayer::new("admin", "secret").exclude("/health"));
//! ```

use {axum::{body::Body,
            http::{Request,
                   StatusCode},
            middleware::Next,
            response::{IntoResponse,
                       Response}},
     std::{collections::HashSet,
           sync::Arc},
     tracing::warn};

/// Basic Authentication configuration
///
/// Holds the single accepted username/password pair, the realm advertised in
/// the `WWW-Authenticate` challenge, and the set of request paths that bypass
/// authentication.
#[derive(Clone)]
pub struct BasicAuthConfig {
  /// Username a client must present
  username:       String,
  /// Password a client must present
  password:       String,
  /// Realm placed in the `WWW-Authenticate` header of 401 responses
  realm:          String,
  /// Paths that skip authentication (matched by exact string equality)
  excluded_paths: HashSet<String>,
}

impl BasicAuthConfig {
  /// Create a new Basic Auth configuration
  ///
  /// The realm defaults to `"Restricted"` and no paths are excluded.
  pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
    Self {
      username:       username.into(),
      password:       password.into(),
      realm:          "Restricted".to_string(),
      excluded_paths: HashSet::new(),
    }
  }

  /// Set the authentication realm (shown in browser prompt)
  pub fn realm(mut self, realm: impl Into<String>) -> Self {
    self.realm = realm.into();
    self
  }

  /// Exclude a path from authentication (e.g., health check endpoints)
  ///
  /// The path is matched exactly; no prefix or pattern matching is performed.
  pub fn exclude(mut self, path: impl Into<String>) -> Self {
    self.excluded_paths.insert(path.into());
    self
  }

  /// Exclude multiple paths from authentication
  pub fn exclude_paths(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
    for path in paths {
      self.excluded_paths.insert(path.into());
    }
    self
  }

  /// Check if a path is excluded from authentication
  fn is_excluded(&self, path: &str) -> bool {
    self.excluded_paths.contains(path)
  }

  /// Validate credentials
  ///
  /// Returns `true` only when both `username` and `password` match the
  /// configured values. Both comparisons always run and neither stops at the
  /// first differing byte, so the time taken does not reveal how much of a
  /// guess was correct.
  fn validate(&self, username: &str, password: &str) -> bool {
    let username_ok = Self::constant_time_eq(self.username.as_bytes(), username.as_bytes());
    let password_ok = Self::constant_time_eq(self.password.as_bytes(), password.as_bytes());
    // Bitwise `&` on purpose: `&&` would skip the password check on a bad username.
    username_ok & password_ok
  }

  /// Compare two byte strings without an early exit on the first mismatch
  ///
  /// The loop length depends only on `provided` (which the caller already
  /// knows). This is a best-effort mitigation written with the standard
  /// library only; it is not a formally verified constant-time primitive.
  fn constant_time_eq(expected: &[u8], provided: &[u8]) -> bool {
    // A length mismatch is folded into the accumulator instead of returning early.
    let mut diff = expected.len() ^ provided.len();
    for (index, &byte) in provided.iter().enumerate() {
      let reference = expected.get(index).copied().unwrap_or(0);
      diff |= usize::from(byte ^ reference);
    }
    diff == 0
  }
}

/// Basic Authentication Layer
///
/// A Tower layer that adds HTTP Basic Authentication to routes.
///
/// The configuration is held behind an `Arc`, so cloning the layer does not
/// copy the credentials or the excluded-path set.
#[derive(Clone)]
pub struct BasicAuthLayer {
  /// Shared authentication settings
  config: Arc<BasicAuthConfig>,
}

impl BasicAuthLayer {
  /// Build a layer that accepts a single username/password pair
  ///
  /// # Example
  ///
  /// ```rust,ignore
  /// let layer = BasicAuthLayer::new("admin", "password");
  /// ```
  pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
    let settings = BasicAuthConfig::new(username, password);
    Self {
      config: Arc::new(settings),
    }
  }

  /// Build a layer around an already-prepared configuration
  pub fn from_config(config: BasicAuthConfig) -> Self {
    let config = Arc::new(config);
    Self { config }
  }

  /// Override the realm advertised in the authentication challenge
  pub fn realm(mut self, realm: impl Into<String>) -> Self {
    // `make_mut` clones the configuration first if the `Arc` is shared.
    Arc::make_mut(&mut self.config).realm = realm.into();
    self
  }

  /// Let requests for `path` through without authentication
  pub fn exclude(mut self, path: impl Into<String>) -> Self {
    Arc::make_mut(&mut self.config).excluded_paths.insert(path.into());
    self
  }

  /// Let requests for every path in `paths` through without authentication
  pub fn exclude_paths(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
    let excluded = &mut Arc::make_mut(&mut self.config).excluded_paths;
    paths.into_iter().for_each(|path| {
      excluded.insert(path.into());
    });
    self
  }

  /// Turn the layer into a cloneable closure that runs the auth check
  ///
  /// Each call clones the shared configuration handle and returns a boxed
  /// future that resolves to the final response.
  pub fn into_middleware(
    self,
  ) -> impl Fn(Request<Body>, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
  + Clone
  + Send
  + 'static {
    let shared = self.config;
    move |request: Request<Body>, next: Next| {
      let config = Arc::clone(&shared);
      Box::pin(async move { basic_auth_check(request, next, &config).await })
    }
  }
}

impl<S> tower::Layer<S> for BasicAuthLayer {
  type Service = BasicAuthMiddleware<S>;

  fn layer(&self, inner: S) -> Self::Service {
    BasicAuthMiddleware {
      inner,
      config: self.config.clone(),
    }
  }
}

/// Basic Authentication middleware service
///
/// Pairs an inner service `S` with the shared [`BasicAuthConfig`] used to
/// decide whether a request may reach it.
#[derive(Clone)]
pub struct BasicAuthMiddleware<S> {
  /// Wrapped service that handles requests once they pass the auth check
  inner:  S,
  /// Shared authentication settings
  config: Arc<BasicAuthConfig>,
}

impl<S> tower::Service<Request<Body>> for BasicAuthMiddleware<S>
where
  S: tower::Service<Request<Body>, Response = Response> + Clone + Send + 'static,
  S::Future: Send,
{
  type Error = S::Error;
  type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
  type Response = S::Response;

  /// Readiness is delegated to the wrapped service
  fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
    self.inner.poll_ready(cx)
  }

  /// Forward the request to the inner service when its path is excluded or it
  /// carries valid Basic credentials; otherwise answer with `401 Unauthorized`
  fn call(&mut self, request: Request<Body>) -> Self::Future {
    let config = self.config.clone();

    // `poll_ready` was driven on `self.inner`, so that exact instance must
    // handle the request; a fresh clone is not guaranteed to be ready. Move
    // the ready service into the future and leave the clone behind.
    let replacement = self.inner.clone();
    let mut inner = std::mem::replace(&mut self.inner, replacement);

    Box::pin(async move {
      let authorized =
        config.is_excluded(request.uri().path()) || validate_basic_auth(&request, &config) == Some(true);

      if authorized {
        inner.call(request).await
      } else {
        Ok(unauthorized_response(&config.realm))
      }
    })
  }
}

/// Validate Basic Auth credentials from request
///
/// Returns:
/// - `None` when the `Authorization` header is absent, is not valid header
///   text, or its payload cannot be decoded into a `user:password` pair
/// - `Some(false)` when the header uses a scheme other than Basic, or the
///   credentials do not match `config`
/// - `Some(true)` when the credentials match
///
/// The scheme name is compared case-insensitively and any run of spaces
/// between the scheme and the payload is tolerated, as HTTP permits.
fn validate_basic_auth<B>(request: &Request<B>, config: &BasicAuthConfig) -> Option<bool> {
  let auth_header = request.headers().get("Authorization").and_then(|h| h.to_str().ok())?;

  let (scheme, payload) = match auth_header.split_once(' ') {
    | Some(parts) => parts,
    | None => return Some(false),
  };
  if !scheme.eq_ignore_ascii_case("Basic") {
    return Some(false);
  }

  let encoded = payload.trim();
  let decoded = data_encoding::BASE64.decode(encoded.as_bytes()).ok()?;
  let credentials = String::from_utf8(decoded).ok()?;
  // Split on the first ':' only: the password may itself contain colons.
  let (username, password) = credentials.split_once(':')?;

  Some(config.validate(username, password))
}

/// Create an unauthorized response with WWW-Authenticate header
///
/// The realm is emitted as an HTTP quoted-string: `"` and `\` are
/// backslash-escaped so they cannot terminate or corrupt the parameter, and
/// control characters (such as CR/LF, which are not valid in a header value)
/// are dropped. Realms without such characters produce the same header as
/// plain interpolation would.
fn unauthorized_response(realm: &str) -> Response {
  let mut challenge = String::with_capacity(realm.len() + 16);
  challenge.push_str("Basic realm=\"");
  for ch in realm.chars() {
    match ch {
      | '"' | '\\' => {
        challenge.push('\\');
        challenge.push(ch);
      }
      | ch if ch.is_control() => {}
      | ch => challenge.push(ch),
    }
  }
  challenge.push('"');

  (
    StatusCode::UNAUTHORIZED,
    [("WWW-Authenticate", challenge)],
    "Unauthorized",
  )
    .into_response()
}

/// Function-style middleware for `axum::middleware::from_fn_with_state`
///
/// The shared [`BasicAuthConfig`] is taken from router state and the request
/// is passed through the same check used by the other entry points in this
/// module.
///
/// # Example
///
/// ```rust,ignore
/// use at_jet::middleware::{BasicAuthConfig, basic_auth_middleware};
///
/// let config = BasicAuthConfig::new("admin", "password")
///     .realm("Admin Area")
///     .exclude("/health");
///
/// let app = Router::new()
///     .route("/admin", get(handler))
///     .layer(axum::middleware::from_fn_with_state(
///         Arc::new(config),
///         basic_auth_middleware,
///     ));
/// ```
pub async fn basic_auth_middleware(
  axum::extract::State(config): axum::extract::State<Arc<BasicAuthConfig>>,
  request: Request<Body>,
  next: Next,
) -> Response {
  let settings: &BasicAuthConfig = &config;
  basic_auth_check(request, next, settings).await
}

/// Shared gatekeeper behind the function-style middleware entry points
///
/// Runs the rest of the stack when the request path is excluded or the
/// request carries valid Basic credentials. Otherwise it logs a warning with
/// the path and returns a `401` challenge for the configured realm.
async fn basic_auth_check(request: Request<Body>, next: Next, config: &BasicAuthConfig) -> Response {
  // Credentials are only inspected when the path is not excluded.
  let allowed =
    config.is_excluded(request.uri().path()) || matches!(validate_basic_auth(&request, config), Some(true));

  if !allowed {
    warn!(path = %request.uri().path(), "Unauthorized access attempt");
    return unauthorized_response(&config.realm);
  }

  next.run(request).await
}