snap_control/server/
auth.rs1use std::{
17 fmt::Display,
18 future::Future,
19 pin::Pin,
20 task::{Context, Poll},
21 time::SystemTime,
22};
23
24use axum::body::Body;
25use http::{Request, Response};
26use jsonwebtoken::DecodingKey;
27use scion_sdk_token_validator::validator::{TokenValidator, Validator};
28use snap_tokens::AnyClaims;
29use thiserror::Error;
30use tower::{BoxError, Layer, Service};
31
32#[derive(Clone)]
33pub(crate) struct AuthMiddlewareLayer {
34 validator: Validator<AnyClaims>,
35}
36
37impl AuthMiddlewareLayer {
38 pub(crate) fn new(dec: DecodingKey) -> Self {
39 Self {
40 validator: Validator::new(dec, Some(&["snap"])),
41 }
42 }
43}
44
45impl<S> Layer<S> for AuthMiddlewareLayer {
46 type Service = AuthMiddleware<S>;
47
48 fn layer(&self, inner: S) -> Self::Service {
49 AuthMiddleware::new(inner, self.validator.clone())
50 }
51}
52
53#[derive(Clone)]
54pub(crate) struct AuthMiddleware<S> {
55 inner: S,
56 validator: Validator<AnyClaims>,
57}
58
59impl<S> AuthMiddleware<S> {
60 pub(crate) fn new(inner: S, validator: Validator<AnyClaims>) -> Self {
61 Self { inner, validator }
62 }
63}
64
65impl<S> Service<Request<Body>> for AuthMiddleware<S>
66where
67 S: Service<Request<Body>, Response = Response<Body>> + Send + Clone + 'static,
68 S::Error: Into<BoxError>,
69 S::Future: Send + 'static,
70{
71 type Response = Response<Body>;
72 type Error = BoxError;
73 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
74
75 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76 self.inner.poll_ready(cx).map_err(Into::into)
77 }
78
79 fn call(&mut self, mut request: Request<Body>) -> Self::Future {
80 let token = match extract_bearer_token(&request) {
81 Ok(token) => token,
82 Err(err) => {
83 tracing::debug!(%err, "Extract bearer token");
84 return Box::pin(async { Ok(build_unauthorized_response(err)) });
85 }
86 };
87
88 match self.validator.validate(SystemTime::now(), token.as_str()) {
89 Ok(token_claims) => {
90 match token_claims {
91 AnyClaims::V0(claims) => {
92 request.extensions_mut().insert(claims);
93 }
94 AnyClaims::V1(claims) => {
95 request.extensions_mut().insert(claims);
96 }
97 }
98 let mut inner = self.inner.clone();
99 Box::pin(async move { inner.call(request).await.map_err(Into::into) })
100 }
101 Err(err) => {
102 tracing::debug!(%err, "Invalid Token");
103 Box::pin(async { Ok(build_unauthorized_response(err)) })
104 }
105 }
106 }
107}
108
109fn build_unauthorized_response<E: Display>(err: E) -> Response<Body> {
110 Response::builder()
111 .status(http::StatusCode::UNAUTHORIZED)
112 .body(Body::from(format!("SNAP Token validation failed: {err}")))
113 .expect("no fail")
114}
115
116pub fn extract_bearer_token(req: &Request<Body>) -> Result<String, ExtractBearerTokenError> {
118 let auth_header = match req.headers().get("authorization") {
119 Some(header) => header,
120 None => return Err(ExtractBearerTokenError::AuthHeaderMissing),
121 };
122
123 let auth_str = match auth_header.to_str() {
124 Ok(str) => str,
125 Err(_) => return Err(ExtractBearerTokenError::AuthHeaderInvalidUtf8),
126 };
127
128 match auth_str.strip_prefix("Bearer ") {
129 Some(token) => Ok(token.to_string()),
130 None => Err(ExtractBearerTokenError::AuthHeaderNotBearer),
131 }
132}
133
134#[derive(Debug, Error)]
136pub enum ExtractBearerTokenError {
137 #[error("authorization header is missing")]
139 AuthHeaderMissing,
140 #[error("authorization header is not valid UTF-8")]
142 AuthHeaderInvalidUtf8,
143 #[error("authorization header is not a bearer token")]
145 AuthHeaderNotBearer,
146}