firebase_verifyid/middleware.rs

1use super::{FirebaseClaims, TokenVerifier, ALG, TOKEN_SIG_TYPE};
2use axum::{
3    extract::Request,
4    http::{header, StatusCode},
5    response::{IntoResponse, Response},
6    Json,
7};
8use futures_util::future::BoxFuture;
9use jwt_simple::{
10    claims::{JWTClaims, NoCustomClaims},
11    prelude::EdDSAPublicKeyLike,
12    token::Token,
13};
14use std::task::{Context, Poll};
15use tower::{Layer, Service};
16
17#[derive(Clone)]
18pub struct FirebaseAuthLayer {
19    verifier: TokenVerifier,
20}
21
22impl FirebaseAuthLayer {
23    pub fn new(verifier: TokenVerifier) -> Self {
24        Self { verifier }
25    }
26}
27
28impl<S> Layer<S> for FirebaseAuthLayer {
29    type Service = FirebaseAuthService<S>;
30
31    fn layer(&self, inner: S) -> Self::Service {
32        FirebaseAuthService {
33            inner,
34            verifier: self.verifier.clone(),
35        }
36    }
37}
38
39#[derive(Clone)]
40pub struct FirebaseAuthService<S> {
41    inner: S,
42    verifier: TokenVerifier,
43}
44
45impl<S> FirebaseAuthService<S> {
46    fn token_auth(&self, req: &mut Request) -> Result<(), (StatusCode, Json<serde_json::Value>)> {
47        let token = req
48            .headers()
49            .get(header::AUTHORIZATION)
50            .and_then(|auth_header| auth_header.to_str().ok())
51            .and_then(|auth_value| auth_value.strip_prefix("Bearer "))
52            .ok_or_else(|| {
53                metrics::counter!("firebase-token-auth-rejected", "reason" => "missing-token")
54                    .increment(1);
55                tracing::debug!("request missing required firebase auth token");
56                error_response(StatusCode::FORBIDDEN)
57            })?;
58
59        if let Some(ref bearer_verifier) = self.verifier.bearer_verifier {
60            if let Ok(claims) = bearer_verifier.verify_token::<NoCustomClaims>(token, None) {
61                let sub = if let Some(ref bearer) = claims.subject {
62                    bearer
63                } else {
64                    "unknown"
65                };
66                tracing::info!(subject = %sub, "bearer request authorized");
67                req.extensions_mut().insert(claims);
68                return Ok(());
69            }
70        }
71
72        let metadata = Token::decode_metadata(token).map_err(|_| {
73            metrics::counter!("firebase-token-auth-rejected", "reason" => "missing-metadata")
74                .increment(1);
75            tracing::debug!(token, "token missing metadata");
76            error_response(StatusCode::UNAUTHORIZED)
77        })?;
78
79        // Check token header `alg` and `typ` field match the expected values
80        if metadata.algorithm() != ALG || metadata.signature_type() != Some(TOKEN_SIG_TYPE) {
81            metrics::counter!("firebase-token-auth-rejected", "reason" => "invalid-algorithm")
82                .increment(1);
83            tracing::debug!(
84                alg = metadata.algorithm(),
85                typ = metadata.signature_type(),
86                "invalid token metadata headers",
87            );
88            return Err(error_response(StatusCode::UNAUTHORIZED));
89        }
90
91        let Some(key_id) = metadata.key_id() else {
92            metrics::counter!("firebase-token-auth-rejected", "reason" => "missing-kid")
93                .increment(1);
94            tracing::debug!("token missing kid metadata header");
95            return Err(error_response(StatusCode::UNAUTHORIZED));
96        };
97
98        // Validates the token signature and that the expiry (+tolerance) is within the limit
99        // automatically. Also incorporates validation of issuer and audience (firebase project id)
100        let claims: JWTClaims<FirebaseClaims> =
101            self.verifier.verify_token(key_id, token).map_err(|_| {
102                metrics::counter!("firebase-token-auth-rejected", "reason" => "invalid-token")
103                    .increment(1);
104                tracing::debug!(token, key_id, "invalid firebase auth id token");
105                error_response(StatusCode::UNAUTHORIZED)
106            })?;
107
108        metrics::counter!("firebase-request-authorized").increment(1);
109        req.extensions_mut().insert(claims);
110
111        Ok(())
112    }
113}
114
115impl<S> Service<Request> for FirebaseAuthService<S>
116where
117    S: Service<Request, Response = Response> + Clone + Send + 'static,
118    S::Future: Send + 'static,
119{
120    type Response = S::Response;
121    type Error = S::Error;
122    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
123
124    #[inline]
125    fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
126        self.inner.poll_ready(ctx)
127    }
128
129    fn call(&mut self, mut req: Request) -> Self::Future {
130        let not_ready_inner = self.inner.clone();
131        let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
132        let auth_result = self.token_auth(&mut req);
133
134        Box::pin(async move {
135            match auth_result {
136                Ok(_) => ready_inner.call(req).await,
137                Err(err) => Ok(err.into_response()),
138            }
139        })
140    }
141}
142
143fn error_response(status_code: StatusCode) -> (StatusCode, Json<serde_json::Value>) {
144    let err_resp = serde_json::json!({
145        "status": "error",
146        "message": "request not authorized",
147    });
148    (status_code, Json(err_resp))
149}