tako-rs-plugins 2.0.0

Internal plugin and concrete-middleware implementations for tako-rs. Use the `tako-rs` umbrella crate instead.
Documentation
//! API Key authentication middleware for simple token-based access control.
//!
//! This module provides middleware for validating API keys from HTTP headers or query
//! parameters. It supports multiple key sources, custom header names, and dynamic
//! key verification functions for flexible authentication strategies.
//!
//! # Examples
//!
//! ```rust,ignore
//! use tako::middleware::api_key_auth::{ApiKeyAuth, ApiKeyLocation};
//! use tako::middleware::IntoMiddleware;
//!
//! // Single API key from header
//! let auth = ApiKeyAuth::new("secret-api-key");
//! let middleware = auth.into_middleware();
//!
//! // Multiple valid keys
//! let multi_auth = ApiKeyAuth::from_keys(["key1", "key2", "admin-key"]);
//!
//! // Custom header name
//! let custom_auth = ApiKeyAuth::new("secret")
//!     .header_name("X-Custom-Key");
//!
//! // From query parameter
//! let query_auth = ApiKeyAuth::new("secret")
//!     .location(ApiKeyLocation::Query("api_key"));
//!
//! // Dynamic verification
//! let dynamic_auth = ApiKeyAuth::with_verify(|key| {
//!     key.starts_with("valid_")
//! });
//! ```

use std::borrow::Cow;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use http::HeaderValue;
use http::StatusCode;
use http::header;
use subtle::Choice;
use subtle::ConstantTimeEq;
use tako_rs_core::body::TakoBody;
use tako_rs_core::middleware::IntoMiddleware;
use tako_rs_core::middleware::Next;
use tako_rs_core::responder::Responder;
use tako_rs_core::types::Request;
use tako_rs_core::types::Response;

/// Reports whether `input` equals any entry of `candidates`, without short-circuiting.
///
/// Every candidate is compared on every call and the per-candidate outcomes are OR-ed
/// together as [`Choice`] values, so timing does not reveal *which* key matched.
/// Byte comparison goes through `subtle::ConstantTimeEq`, so equal-length comparisons
/// do not leak contents via wall-clock. Length mismatches still finish faster than
/// equal-length comparisons, so a client may learn a key's length but not its bytes.
fn constant_time_contains(input: &[u8], candidates: &[Vec<u8>]) -> bool {
  let any_match = candidates
    .iter()
    .map(|candidate| input.ct_eq(candidate.as_slice()))
    .fold(Choice::from(0u8), |acc, hit| acc | hit);
  bool::from(any_match)
}

/// Location where the API key should be extracted from.
///
/// The default is [`ApiKeyLocation::Header`] with the header name `X-API-Key`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ApiKeyLocation {
  /// Extract from HTTP header with the given name.
  Header(&'static str),
  /// Extract from query parameter with the given name.
  Query(&'static str),
  /// Try the header (first field) first, then fall back to the query parameter
  /// (second field).
  HeaderOrQuery(&'static str, &'static str),
}

impl Default for ApiKeyLocation {
  /// Returns `ApiKeyLocation::Header("X-API-Key")`.
  fn default() -> Self {
    Self::Header("X-API-Key")
  }
}

/// Custom verification closure for [`ApiKeyAuth`].
///
/// It receives the extracted API key and returns `true` to accept the request.
pub type ApiKeyVerifyFn = Arc<dyn Fn(&str) -> bool + Send + Sync + 'static>;

/// API Key authentication middleware configuration.
///
/// `ApiKeyAuth` provides flexible configuration for API key authentication,
/// supporting static keys, dynamic verification, and multiple extraction locations.
///
/// # Examples
///
/// ```rust,ignore
/// use tako::middleware::api_key_auth::{ApiKeyAuth, ApiKeyLocation};
///
/// // Simple static key
/// let auth = ApiKeyAuth::new("my-secret-key");
///
/// // Multiple keys with custom location
/// let auth = ApiKeyAuth::from_keys(["key1", "key2"])
///     .location(ApiKeyLocation::Query("apikey"));
///
/// // Dynamic verification
/// let auth = ApiKeyAuth::with_verify(|key| {
///     // Lookup in database, validate format, etc.
///     key.len() == 32 && key.chars().all(|c| c.is_ascii_hexdigit())
/// });
/// ```
#[derive(Clone)]
pub struct ApiKeyAuth {
  /// Static API keys (raw bytes, scanned in constant time).
  keys: Option<Vec<Vec<u8>>>,
  /// Custom verification function for dynamic key validation.
  verify: Option<ApiKeyVerifyFn>,
  /// Location to extract the API key from.
  location: ApiKeyLocation,
}

impl ApiKeyAuth {
  /// Creates authentication middleware with a single static API key.
  ///
  /// By default, the key is extracted from the `X-API-Key` header. An empty `key`
  /// is ignored (as in [`ApiKeyAuth::from_keys`]). That leaves no static key, so
  /// every request is rejected.
  #[must_use]
  pub fn new(key: impl Into<String>) -> Self {
    let key: String = key.into();
    Self::from_keys([key])
  }

  /// Creates authentication middleware with multiple static API keys.
  ///
  /// Empty strings are dropped. An empty configured key would otherwise accept a
  /// request that sends an empty key, such as a bare `X-API-Key:` header.
  #[must_use]
  pub fn from_keys<I>(keys: I) -> Self
  where
    I: IntoIterator,
    I::Item: Into<String>,
  {
    Self {
      keys: Some(Self::encode_keys(keys)),
      verify: None,
      location: ApiKeyLocation::default(),
    }
  }

  /// Creates authentication middleware with a custom verification function.
  ///
  /// `f` receives the extracted key and returns `true` to accept the request.
  #[must_use]
  pub fn with_verify<F>(f: F) -> Self
  where
    F: Fn(&str) -> bool + Send + Sync + 'static,
  {
    Self {
      keys: None,
      verify: Some(Arc::new(f)),
      location: ApiKeyLocation::default(),
    }
  }

  /// Creates authentication with both static keys and custom verification.
  ///
  /// Static keys are filtered exactly as in [`ApiKeyAuth::from_keys`].
  #[must_use]
  pub fn from_keys_with_verify<I, F>(keys: I, f: F) -> Self
  where
    I: IntoIterator,
    I::Item: Into<String>,
    F: Fn(&str) -> bool + Send + Sync + 'static,
  {
    Self {
      verify: Some(Arc::new(f)),
      ..Self::from_keys(keys)
    }
  }

  /// Sets the location where the API key should be extracted from.
  #[must_use]
  pub fn location(mut self, location: ApiKeyLocation) -> Self {
    self.location = location;
    self
  }

  /// Sets a custom header name for API key extraction.
  ///
  /// This is a convenience method equivalent to
  /// `.location(ApiKeyLocation::Header(name))`.
  #[must_use]
  pub fn header_name(mut self, name: &'static str) -> Self {
    self.location = ApiKeyLocation::Header(name);
    self
  }

  /// Sets a query parameter name for API key extraction.
  ///
  /// This is a convenience method equivalent to
  /// `.location(ApiKeyLocation::Query(name))`.
  #[must_use]
  pub fn query_param(mut self, name: &'static str) -> Self {
    self.location = ApiKeyLocation::Query(name);
    self
  }

  /// Converts configured keys to the raw-byte form used by the constant-time scan.
  ///
  /// Empty keys are skipped so a blank client-supplied key can never match one.
  fn encode_keys<I>(keys: I) -> Vec<Vec<u8>>
  where
    I: IntoIterator,
    I::Item: Into<String>,
  {
    keys
      .into_iter()
      .map(|k| Into::<String>::into(k).into_bytes())
      .filter(|k| !k.is_empty())
      .collect()
  }
}

/// Extracts the API key from `req` according to `location`.
///
/// Header values are trimmed of surrounding whitespace. Query values are
/// form-decoded, which is why the result is a `Cow`. An empty or whitespace-only
/// header value, or an empty query value, is treated as absent. A blank key
/// therefore never reaches key comparison, and [`ApiKeyLocation::HeaderOrQuery`]
/// falls back to the query string when the header is blank.
fn extract_api_key<'a>(req: &'a Request, location: &ApiKeyLocation) -> Option<Cow<'a, str>> {
  match location {
    ApiKeyLocation::Header(name) => key_from_header(req, name),
    ApiKeyLocation::Query(name) => key_from_query(req, name),
    ApiKeyLocation::HeaderOrQuery(header, query) => {
      key_from_header(req, header).or_else(|| key_from_query(req, query))
    }
  }
}

/// Reads header `name`, trimmed.
///
/// Returns `None` if the header is missing, is not visible ASCII, or is blank.
fn key_from_header<'a>(req: &'a Request, name: &str) -> Option<Cow<'a, str>> {
  req
    .headers()
    .get(name)
    .and_then(|v| v.to_str().ok())
    .map(str::trim)
    .filter(|s| !s.is_empty())
    .map(Cow::Borrowed)
}

/// Reads the first query parameter called `name`, form-decoded.
///
/// Returns `None` if there is no query string, the parameter is missing, or its
/// value is empty.
fn key_from_query<'a>(req: &'a Request, name: &str) -> Option<Cow<'a, str>> {
  let query = req.uri().query()?;
  url::form_urlencoded::parse(query.as_bytes())
    .find(|(k, _)| k == name)
    .map(|(_, v)| v)
    .filter(|v| !v.is_empty())
}

/// Builds the `401 Unauthorized` response returned on every authentication failure.
///
/// The response carries a `WWW-Authenticate: ApiKey` challenge and `message` as the
/// body. It is assembled field by field rather than via `http::Response::builder()`,
/// so there is no `Result` to unwrap.
fn unauthorized(message: &'static str) -> Response {
  let mut res = http::Response::new(TakoBody::from(message));
  *res.status_mut() = StatusCode::UNAUTHORIZED;
  res
    .headers_mut()
    .insert(header::WWW_AUTHENTICATE, HeaderValue::from_static("ApiKey"));
  res.into_response()
}

impl IntoMiddleware for ApiKeyAuth {
  /// Converts the API key authentication configuration into middleware.
  ///
  /// For each request the key is extracted from the configured [`ApiKeyLocation`]:
  /// - no key: `401` with body `"API key is missing"`;
  /// - key matches a static key (constant-time scan): the request is forwarded;
  /// - otherwise, a verification function is set and returns `true`: forwarded;
  /// - otherwise: `401` with body `"Invalid API key"`.
  ///
  /// The verification function is only consulted when no static key matched.
  fn into_middleware(
    self,
  ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
  + Clone
  + Send
  + Sync
  + 'static {
    // Keys are shared behind an `Arc`, so the per-request clones below are
    // refcount bumps rather than copies of the key material.
    let keys = self.keys.map(Arc::new);
    let verify = self.verify;
    let location = self.location;

    move |req: Request, next: Next| {
      let keys = keys.clone();
      let verify = verify.clone();
      let location = location.clone();

      Box::pin(async move {
        // Decide before forwarding so the key's borrow of `req` ends here.
        let verdict = match extract_api_key(&req, &location) {
          None => Err("API key is missing"),
          Some(api_key) => {
            let static_ok = keys
              .as_deref()
              .is_some_and(|set| constant_time_contains(api_key.as_bytes(), set));
            // `||` short-circuits: the custom verifier only runs if no static key matched.
            if static_ok || verify.as_ref().is_some_and(|v| v(api_key.as_ref())) {
              Ok(())
            } else {
              Err("Invalid API key")
            }
          }
        };

        match verdict {
          Ok(()) => next.run(req).await.into_response(),
          Err(message) => unauthorized(message),
        }
      })
    }
  }
}