firebase_verifyid/
middleware.rs1use super::{FirebaseClaims, TokenVerifier, ALG, TOKEN_SIG_TYPE};
2use axum::{
3 extract::Request,
4 http::{header, StatusCode},
5 response::{IntoResponse, Response},
6 Json,
7};
8use futures_util::future::BoxFuture;
9use jwt_simple::{
10 claims::{JWTClaims, NoCustomClaims},
11 prelude::EdDSAPublicKeyLike,
12 token::Token,
13};
14use std::task::{Context, Poll};
15use tower::{Layer, Service};
16
17#[derive(Clone)]
18pub struct FirebaseAuthLayer {
19 verifier: TokenVerifier,
20}
21
22impl FirebaseAuthLayer {
23 pub fn new(verifier: TokenVerifier) -> Self {
24 Self { verifier }
25 }
26}
27
28impl<S> Layer<S> for FirebaseAuthLayer {
29 type Service = FirebaseAuthService<S>;
30
31 fn layer(&self, inner: S) -> Self::Service {
32 FirebaseAuthService {
33 inner,
34 verifier: self.verifier.clone(),
35 }
36 }
37}
38
39#[derive(Clone)]
40pub struct FirebaseAuthService<S> {
41 inner: S,
42 verifier: TokenVerifier,
43}
44
45impl<S> FirebaseAuthService<S> {
46 fn token_auth(&self, req: &mut Request) -> Result<(), (StatusCode, Json<serde_json::Value>)> {
47 let token = req
48 .headers()
49 .get(header::AUTHORIZATION)
50 .and_then(|auth_header| auth_header.to_str().ok())
51 .and_then(|auth_value| auth_value.strip_prefix("Bearer "))
52 .ok_or_else(|| {
53 metrics::counter!("firebase-token-auth-rejected", "reason" => "missing-token")
54 .increment(1);
55 tracing::debug!("request missing required firebase auth token");
56 error_response(StatusCode::FORBIDDEN)
57 })?;
58
59 if let Some(ref bearer_verifier) = self.verifier.bearer_verifier {
60 if let Ok(claims) = bearer_verifier.verify_token::<NoCustomClaims>(token, None) {
61 let sub = if let Some(ref bearer) = claims.subject {
62 bearer
63 } else {
64 "unknown"
65 };
66 tracing::info!(subject = %sub, "bearer request authorized");
67 req.extensions_mut().insert(claims);
68 return Ok(());
69 }
70 }
71
72 let metadata = Token::decode_metadata(token).map_err(|_| {
73 metrics::counter!("firebase-token-auth-rejected", "reason" => "missing-metadata")
74 .increment(1);
75 tracing::debug!(token, "token missing metadata");
76 error_response(StatusCode::UNAUTHORIZED)
77 })?;
78
79 if metadata.algorithm() != ALG || metadata.signature_type() != Some(TOKEN_SIG_TYPE) {
81 metrics::counter!("firebase-token-auth-rejected", "reason" => "invalid-algorithm")
82 .increment(1);
83 tracing::debug!(
84 alg = metadata.algorithm(),
85 typ = metadata.signature_type(),
86 "invalid token metadata headers",
87 );
88 return Err(error_response(StatusCode::UNAUTHORIZED));
89 }
90
91 let Some(key_id) = metadata.key_id() else {
92 metrics::counter!("firebase-token-auth-rejected", "reason" => "missing-kid")
93 .increment(1);
94 tracing::debug!("token missing kid metadata header");
95 return Err(error_response(StatusCode::UNAUTHORIZED));
96 };
97
98 let claims: JWTClaims<FirebaseClaims> =
101 self.verifier.verify_token(key_id, token).map_err(|_| {
102 metrics::counter!("firebase-token-auth-rejected", "reason" => "invalid-token")
103 .increment(1);
104 tracing::debug!(token, key_id, "invalid firebase auth id token");
105 error_response(StatusCode::UNAUTHORIZED)
106 })?;
107
108 metrics::counter!("firebase-request-authorized").increment(1);
109 req.extensions_mut().insert(claims);
110
111 Ok(())
112 }
113}
114
115impl<S> Service<Request> for FirebaseAuthService<S>
116where
117 S: Service<Request, Response = Response> + Clone + Send + 'static,
118 S::Future: Send + 'static,
119{
120 type Response = S::Response;
121 type Error = S::Error;
122 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
123
124 #[inline]
125 fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
126 self.inner.poll_ready(ctx)
127 }
128
129 fn call(&mut self, mut req: Request) -> Self::Future {
130 let not_ready_inner = self.inner.clone();
131 let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
132 let auth_result = self.token_auth(&mut req);
133
134 Box::pin(async move {
135 match auth_result {
136 Ok(_) => ready_inner.call(req).await,
137 Err(err) => Ok(err.into_response()),
138 }
139 })
140 }
141}
142
143fn error_response(status_code: StatusCode) -> (StatusCode, Json<serde_json::Value>) {
144 let err_resp = serde_json::json!({
145 "status": "error",
146 "message": "request not authorized",
147 });
148 (status_code, Json(err_resp))
149}