Skip to main content

snap_control/server/
auth.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! SNAP control plane API authentication middleware.
15
16use std::{
17    fmt::Display,
18    future::Future,
19    pin::Pin,
20    sync::Arc,
21    task::{Context, Poll},
22};
23
24use axum::body::Body;
25use http::{Request, Response};
26use thiserror::Error;
27use tower::{BoxError, Layer, Service};
28
29use crate::server::token_verifier::SnapTokenVerifier;
30
31#[derive(Clone)]
32pub(crate) struct AuthMiddlewareLayer {
33    verifier: Arc<SnapTokenVerifier>,
34}
35
36impl AuthMiddlewareLayer {
37    pub(crate) fn new(verifier: SnapTokenVerifier) -> Self {
38        Self {
39            verifier: Arc::new(verifier),
40        }
41    }
42}
43
44impl<S> Layer<S> for AuthMiddlewareLayer {
45    type Service = AuthMiddleware<S>;
46
47    fn layer(&self, inner: S) -> Self::Service {
48        AuthMiddleware::new(inner, self.verifier.clone())
49    }
50}
51
52#[derive(Clone)]
53pub(crate) struct AuthMiddleware<S> {
54    inner: S,
55    verifier: Arc<SnapTokenVerifier>,
56}
57
58impl<S> AuthMiddleware<S> {
59    pub(crate) fn new(inner: S, verifier: Arc<SnapTokenVerifier>) -> Self {
60        Self { inner, verifier }
61    }
62}
63
64impl<S> Service<Request<Body>> for AuthMiddleware<S>
65where
66    S: Service<Request<Body>, Response = Response<Body>> + Send + Clone + 'static,
67    S::Error: Into<BoxError>,
68    S::Future: Send + 'static,
69{
70    type Response = Response<Body>;
71    type Error = BoxError;
72    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
73
74    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75        self.inner.poll_ready(cx).map_err(Into::into)
76    }
77
78    fn call(&mut self, mut request: Request<Body>) -> Self::Future {
79        let token = match extract_bearer_token(&request) {
80            Ok(token) => token,
81            Err(err) => {
82                tracing::debug!(%err, "Extract bearer token");
83                return Box::pin(async { Ok(build_unauthorized_response(err)) });
84            }
85        };
86
87        let verifier = self.verifier.clone();
88        let mut inner = self.inner.clone();
89        Box::pin(async move {
90            match verifier.verify(&token).await {
91                Ok(token_claims) => {
92                    request.extensions_mut().insert(token_claims);
93                    inner.call(request).await.map_err(Into::into)
94                }
95                Err(err) => {
96                    tracing::debug!(%err, "Invalid Token");
97                    Ok(build_unauthorized_response(err))
98                }
99            }
100        })
101    }
102}
103
104fn build_unauthorized_response<E: Display>(err: E) -> Response<Body> {
105    Response::builder()
106        .status(http::StatusCode::UNAUTHORIZED)
107        .body(Body::from(format!("SNAP Token validation failed: {err}")))
108        .expect("no fail")
109}
110
111/// Extracts the bearer token from the `Authorization` header of the request.
112pub fn extract_bearer_token(req: &Request<Body>) -> Result<String, ExtractBearerTokenError> {
113    let auth_header = match req.headers().get("authorization") {
114        Some(header) => header,
115        None => return Err(ExtractBearerTokenError::AuthHeaderMissing),
116    };
117
118    let auth_str = match auth_header.to_str() {
119        Ok(str) => str,
120        Err(_) => return Err(ExtractBearerTokenError::AuthHeaderInvalidUtf8),
121    };
122
123    match auth_str.strip_prefix("Bearer ") {
124        Some(token) => Ok(token.to_string()),
125        None => Err(ExtractBearerTokenError::AuthHeaderNotBearer),
126    }
127}
128
129/// Bearer token extraction error.
130#[derive(Debug, Error)]
131pub enum ExtractBearerTokenError {
132    /// Authorization header is missing.
133    #[error("authorization header is missing")]
134    AuthHeaderMissing,
135    /// Authorization header is not valid UTF-8.
136    #[error("authorization header is not valid UTF-8")]
137    AuthHeaderInvalidUtf8,
138    /// Authorization header is not a Bearer token.
139    #[error("authorization header is not a bearer token")]
140    AuthHeaderNotBearer,
141}