snap_control/server/
auth.rs1use std::{
17 fmt::Display,
18 future::Future,
19 pin::Pin,
20 task::{Context, Poll},
21 time::SystemTime,
22};
23
24use axum::body::Body;
25use http::{Request, Response};
26use jsonwebtoken::DecodingKey;
27use scion_sdk_token_validator::validator::{TokenValidator, Validator};
28use snap_tokens::snap_token::SnapTokenClaims;
29use thiserror::Error;
30use tower::{BoxError, Layer, Service};
31
32#[derive(Clone)]
33pub(crate) struct AuthMiddlewareLayer {
34 validator: Validator<SnapTokenClaims>,
35}
36
37impl AuthMiddlewareLayer {
38 pub(crate) fn new(dec: DecodingKey) -> Self {
39 Self {
40 validator: Validator::new(dec, Some(&["snap"])),
41 }
42 }
43}
44
45impl<S> Layer<S> for AuthMiddlewareLayer {
46 type Service = AuthMiddleware<S>;
47
48 fn layer(&self, inner: S) -> Self::Service {
49 AuthMiddleware::new(inner, self.validator.clone())
50 }
51}
52
53#[derive(Clone)]
54pub(crate) struct AuthMiddleware<S> {
55 inner: S,
56 validator: Validator<SnapTokenClaims>,
57}
58
59impl<S> AuthMiddleware<S> {
60 pub(crate) fn new(inner: S, validator: Validator<SnapTokenClaims>) -> Self {
61 Self { inner, validator }
62 }
63}
64
65impl<S> Service<Request<Body>> for AuthMiddleware<S>
66where
67 S: Service<Request<Body>, Response = Response<Body>> + Send + Clone + 'static,
68 S::Error: Into<BoxError>,
69 S::Future: Send + 'static,
70{
71 type Response = Response<Body>;
72 type Error = BoxError;
73 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
74
75 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76 self.inner.poll_ready(cx).map_err(Into::into)
77 }
78
79 fn call(&mut self, mut request: Request<Body>) -> Self::Future {
80 let token = match extract_bearer_token(&request) {
81 Ok(token) => token,
82 Err(err) => {
83 tracing::debug!(%err, "Extract bearer token");
84 return Box::pin(async { Ok(build_unauthorized_response(err)) });
85 }
86 };
87
88 match self.validator.validate(SystemTime::now(), token.as_str()) {
89 Ok(token_claims) => {
90 request.extensions_mut().insert(token_claims);
91 let mut inner = self.inner.clone();
92 Box::pin(async move { inner.call(request).await.map_err(Into::into) })
93 }
94 Err(err) => {
95 tracing::debug!(%err, "Invalid Token");
96 Box::pin(async { Ok(build_unauthorized_response(err)) })
97 }
98 }
99 }
100}
101
102fn build_unauthorized_response<E: Display>(err: E) -> Response<Body> {
103 Response::builder()
104 .status(http::StatusCode::UNAUTHORIZED)
105 .body(Body::from(format!("SNAP Token validation failed: {err}")))
106 .expect("no fail")
107}
108
109pub fn extract_bearer_token(req: &Request<Body>) -> Result<String, ExtractBearerTokenError> {
111 let auth_header = match req.headers().get("authorization") {
112 Some(header) => header,
113 None => return Err(ExtractBearerTokenError::AuthHeaderMissing),
114 };
115
116 let auth_str = match auth_header.to_str() {
117 Ok(str) => str,
118 Err(_) => return Err(ExtractBearerTokenError::AuthHeaderInvalidUtf8),
119 };
120
121 match auth_str.strip_prefix("Bearer ") {
122 Some(token) => Ok(token.to_string()),
123 None => Err(ExtractBearerTokenError::AuthHeaderNotBearer),
124 }
125}
126
127#[derive(Debug, Error)]
129pub enum ExtractBearerTokenError {
130 #[error("authorization header is missing")]
132 AuthHeaderMissing,
133 #[error("authorization header is not valid UTF-8")]
135 AuthHeaderInvalidUtf8,
136 #[error("authorization header is not a bearer token")]
138 AuthHeaderNotBearer,
139}