// at_jet/middleware/basic_auth.rs

//! Basic HTTP Authentication middleware.
//!
//! Provides a simple username/password authentication layer for protecting admin endpoints.
//! Requests to excluded paths (such as `/health` below) bypass the credential check.
//!
//! # Example
//!
//! ```rust,ignore
//! use at_jet::middleware::BasicAuthLayer;
//!
//! let app = Router::new()
//!     .route("/admin", get(admin_handler))
//!     .route("/health", get(health_handler))
//!     .layer(BasicAuthLayer::new("admin", "secret").exclude("/health"));
//! ```
15
16use {axum::{body::Body,
17            http::{Request,
18                   StatusCode},
19            middleware::Next,
20            response::{IntoResponse,
21                       Response}},
22     std::{collections::HashSet,
23           sync::Arc},
24     tracing::warn};
25
/// Basic Authentication configuration.
///
/// Holds the single accepted username/password pair, the realm advertised in
/// the `WWW-Authenticate` challenge, and the set of request paths that bypass
/// authentication entirely (matched by exact string equality).
#[derive(Clone)]
pub struct BasicAuthConfig {
  username:       String,
  password:       String,
  realm:          String,
  excluded_paths: HashSet<String>,
}

impl BasicAuthConfig {
  /// Create a new Basic Auth configuration accepting exactly `username` / `password`.
  ///
  /// The realm defaults to `"Restricted"` and no paths are excluded.
  pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
    Self {
      username:       username.into(),
      password:       password.into(),
      realm:          "Restricted".to_string(),
      excluded_paths: HashSet::new(),
    }
  }

  /// Set the authentication realm (shown in browser prompt).
  pub fn realm(mut self, realm: impl Into<String>) -> Self {
    self.realm = realm.into();
    self
  }

  /// Exclude a path from authentication (e.g., health check endpoints).
  ///
  /// The path is compared verbatim against the request path.
  pub fn exclude(mut self, path: impl Into<String>) -> Self {
    self.excluded_paths.insert(path.into());
    self
  }

  /// Exclude multiple paths from authentication.
  pub fn exclude_paths(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
    for path in paths {
      self.excluded_paths.insert(path.into());
    }
    self
  }

  /// Check whether `path` is exactly one of the excluded paths.
  fn is_excluded(&self, path: &str) -> bool {
    self.excluded_paths.contains(path)
  }

  /// Validate credentials.
  ///
  /// Each field is compared without exiting early on the first mismatching
  /// byte. Both comparisons always run, so the time taken does not reveal how
  /// much of a guess was correct or which field was wrong. Lengths are not
  /// hidden.
  fn validate(&self, username: &str, password: &str) -> bool {
    let username_ok = constant_time_eq(self.username.as_bytes(), username.as_bytes());
    let password_ok = constant_time_eq(self.password.as_bytes(), password.as_bytes());
    // Non-short-circuiting `&` so the password comparison runs even when the username is wrong.
    username_ok & password_ok
  }
}

/// Compare two byte slices without returning early on the first differing byte.
///
/// Returns `false` immediately when the lengths differ; only the contents are
/// compared in a data-independent way. This is a best-effort guard, not a
/// formally verified constant-time primitive.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
  if a.len() != b.len() {
    return false;
  }
  // OR together the XOR of every byte pair; any difference leaves a non-zero bit.
  a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
76
77/// Basic Authentication Layer
78///
79/// A Tower layer that adds HTTP Basic Authentication to routes.
80#[derive(Clone)]
81pub struct BasicAuthLayer {
82  config: Arc<BasicAuthConfig>,
83}
84
85impl BasicAuthLayer {
86  /// Create a new Basic Auth layer with username and password
87  ///
88  /// # Example
89  ///
90  /// ```rust,ignore
91  /// let layer = BasicAuthLayer::new("admin", "password");
92  /// ```
93  pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
94    Self {
95      config: Arc::new(BasicAuthConfig::new(username, password)),
96    }
97  }
98
99  /// Create from an existing configuration
100  pub fn from_config(config: BasicAuthConfig) -> Self {
101    Self { config: Arc::new(config) }
102  }
103
104  /// Set the authentication realm
105  pub fn realm(mut self, realm: impl Into<String>) -> Self {
106    let config = Arc::make_mut(&mut self.config);
107    config.realm = realm.into();
108    self
109  }
110
111  /// Exclude a path from authentication
112  pub fn exclude(mut self, path: impl Into<String>) -> Self {
113    let config = Arc::make_mut(&mut self.config);
114    config.excluded_paths.insert(path.into());
115    self
116  }
117
118  /// Exclude multiple paths from authentication
119  pub fn exclude_paths(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
120    let config = Arc::make_mut(&mut self.config);
121    for path in paths {
122      config.excluded_paths.insert(path.into());
123    }
124    self
125  }
126
127  /// Get the middleware function for use with axum
128  pub fn into_middleware(self) -> impl Fn(Request<Body>, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>> + Clone + Send + 'static {
129    let config = self.config;
130    move |request: Request<Body>, next: Next| {
131      let config = config.clone();
132      Box::pin(async move { basic_auth_check(request, next, &config).await })
133    }
134  }
135}
136
137impl<S> tower::Layer<S> for BasicAuthLayer {
138  type Service = BasicAuthMiddleware<S>;
139
140  fn layer(&self, inner: S) -> Self::Service {
141    BasicAuthMiddleware {
142      inner,
143      config: self.config.clone(),
144    }
145  }
146}
147
148/// Basic Authentication middleware service
149#[derive(Clone)]
150pub struct BasicAuthMiddleware<S> {
151  inner:  S,
152  config: Arc<BasicAuthConfig>,
153}
154
155impl<S> tower::Service<Request<Body>> for BasicAuthMiddleware<S>
156where
157  S: tower::Service<Request<Body>, Response = Response> + Clone + Send + 'static,
158  S::Future: Send,
159{
160  type Error = S::Error;
161  type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
162  type Response = S::Response;
163
164  fn poll_ready(
165    &mut self,
166    cx: &mut std::task::Context<'_>,
167  ) -> std::task::Poll<Result<(), Self::Error>> {
168    self.inner.poll_ready(cx)
169  }
170
171  fn call(&mut self, request: Request<Body>) -> Self::Future {
172    let config = self.config.clone();
173    let mut inner = self.inner.clone();
174
175    Box::pin(async move {
176      // Check if path is excluded
177      if config.is_excluded(request.uri().path()) {
178        return inner.call(request).await;
179      }
180
181      // Validate credentials
182      if let Some(auth_result) = validate_basic_auth(&request, &config) {
183        if auth_result {
184          return inner.call(request).await;
185        }
186      }
187
188      // Return 401 Unauthorized
189      Ok(unauthorized_response(&config.realm))
190    })
191  }
192}
193
194/// Validate Basic Auth credentials from request
195fn validate_basic_auth<B>(request: &Request<B>, config: &BasicAuthConfig) -> Option<bool> {
196  let auth_header = request
197    .headers()
198    .get("Authorization")
199    .and_then(|h| h.to_str().ok())?;
200
201  if !auth_header.starts_with("Basic ") {
202    return Some(false);
203  }
204
205  let encoded = &auth_header[6..];
206  let decoded = data_encoding::BASE64.decode(encoded.as_bytes()).ok()?;
207  let credentials = String::from_utf8(decoded).ok()?;
208  let (username, password) = credentials.split_once(':')?;
209
210  Some(config.validate(username, password))
211}
212
213/// Create an unauthorized response with WWW-Authenticate header
214fn unauthorized_response(realm: &str) -> Response {
215  (
216    StatusCode::UNAUTHORIZED,
217    [("WWW-Authenticate", format!("Basic realm=\"{}\"", realm))],
218    "Unauthorized",
219  )
220    .into_response()
221}
222
223/// Middleware function for use with `axum::middleware::from_fn_with_state`
224///
225/// # Example
226///
227/// ```rust,ignore
228/// use at_jet::middleware::{BasicAuthConfig, basic_auth_middleware};
229///
230/// let config = BasicAuthConfig::new("admin", "password")
231///     .realm("Admin Area")
232///     .exclude("/health");
233///
234/// let app = Router::new()
235///     .route("/admin", get(handler))
236///     .layer(axum::middleware::from_fn_with_state(
237///         Arc::new(config),
238///         basic_auth_middleware,
239///     ));
240/// ```
241pub async fn basic_auth_middleware(
242  axum::extract::State(config): axum::extract::State<Arc<BasicAuthConfig>>,
243  request: Request<Body>,
244  next: Next,
245) -> Response {
246  basic_auth_check(request, next, &config).await
247}
248
249/// Internal auth check function
250async fn basic_auth_check(request: Request<Body>, next: Next, config: &BasicAuthConfig) -> Response {
251  // Check if path is excluded
252  if config.is_excluded(request.uri().path()) {
253    return next.run(request).await;
254  }
255
256  // Validate credentials
257  match validate_basic_auth(&request, config) {
258    | Some(true) => next.run(request).await,
259    | _ => {
260      warn!(path = %request.uri().path(), "Unauthorized access attempt");
261      unauthorized_response(&config.realm)
262    }
263  }
264}