1use axum::{
49 extract::{Request, State},
50 http::{HeaderMap, StatusCode},
51 middleware::Next,
52 response::Response,
53};
54use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
55use log::{debug, error, info, warn};
56use serde::{Deserialize, Serialize};
57use std::{
58 collections::HashMap,
59 sync::Arc,
60 time::{Duration, SystemTime, UNIX_EPOCH},
61};
62use tokio::sync::RwLock;
63use uuid::Uuid;
64
/// Google endpoint publishing the X.509 certificates used to verify Firebase ID
/// token signatures, as a JSON object keyed by key ID (`kid`).
const GOOGLE_CERTS_URL: &str =
    "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com";

/// Default freshness window of the public-key cache, in seconds (one hour).
const DEFAULT_CACHE_DURATION: u64 = 60 * 60;

/// Largest bearer token accepted, in bytes.
const MAX_TOKEN_LENGTH: usize = 4 * 1024;

/// Timeout applied to each certificate download request.
const HTTP_TIMEOUT: Duration = Duration::from_secs(10);

/// Number of download attempts made per key refresh.
const MAX_RETRIES: u32 = 3;
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct FirebaseClaims {
78 pub iss: String,
80 pub aud: String,
82 pub auth_time: i64,
84 pub user_id: String,
86 pub sub: String,
88 pub iat: i64,
90 pub exp: i64,
92 pub email: Option<String>,
94 pub email_verified: Option<bool>,
96 pub firebase: FirebaseAuthProvider,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct FirebaseAuthProvider {
103 pub identities: HashMap<String, Vec<String>>,
105 pub sign_in_provider: String,
107}
108
109#[derive(Debug, thiserror::Error)]
111pub enum FirebaseAuthError {
112 #[error("Invalid token format: {0}")]
113 InvalidTokenFormat(String),
114
115 #[error("Token validation failed: {0}")]
116 ValidationFailed(String),
117
118 #[error("Public key fetch failed: {0}")]
119 KeyFetchFailed(String),
120
121 #[error("Token expired or invalid timing")]
122 InvalidTiming,
123
124 #[error("Missing required claims")]
125 MissingClaims,
126
127 #[error("Configuration error: {0}")]
128 ConfigError(String),
129
130 #[error("Rate limit exceeded")]
131 RateLimited,
132}
133
134#[derive(Clone)]
139pub struct PublicKeyCache {
140 keys: Arc<RwLock<HashMap<String, DecodingKey>>>,
141 last_updated: Arc<RwLock<SystemTime>>,
142 cache_duration: Duration,
143 http_client: reqwest::Client,
144 retry_count: Arc<RwLock<u32>>,
145}
146
147impl PublicKeyCache {
148 pub fn new(cache_duration_seconds: u64) -> Result<Self, FirebaseAuthError> {
156 let http_client = reqwest::Client::builder()
157 .timeout(HTTP_TIMEOUT)
158 .user_agent("Firebase-JWT-Validator/1.0")
159 .https_only(true)
160 .build()
161 .map_err(|e| {
162 FirebaseAuthError::ConfigError(format!("HTTP client creation failed: {}", e))
163 })?;
164
165 Ok(Self {
166 keys: Arc::new(RwLock::new(HashMap::new())),
167 last_updated: Arc::new(RwLock::new(UNIX_EPOCH)),
168 cache_duration: Duration::from_secs(cache_duration_seconds),
169 http_client,
170 retry_count: Arc::new(RwLock::new(0)),
171 })
172 }
173
174 pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, FirebaseAuthError> {
184 if kid.is_empty() || kid.len() > 128 {
185 return Err(FirebaseAuthError::InvalidTokenFormat(
186 "Invalid key ID".to_string(),
187 ));
188 }
189
190 let last_updated = *self.last_updated.read().await;
191 let now = SystemTime::now();
192
193 let needs_refresh =
194 now.duration_since(last_updated).unwrap_or(Duration::MAX) > self.cache_duration;
195
196 if needs_refresh {
197 self.refresh_keys().await?;
198 }
199
200 if let Some(key) = self.keys.read().await.get(kid).cloned() {
201 debug!("Public key cache hit for kid: {}", kid);
202 return Ok(key);
203 }
204
205 if !needs_refresh {
206 warn!("Key {} not found in fresh cache, forcing refresh", kid);
207 self.refresh_keys().await?;
208
209 if let Some(key) = self.keys.read().await.get(kid).cloned() {
210 return Ok(key);
211 }
212 }
213
214 Err(FirebaseAuthError::KeyFetchFailed(format!(
215 "Public key not found for kid: {}",
216 kid
217 )))
218 }
219
220 async fn refresh_keys(&self) -> Result<(), FirebaseAuthError> {
222 let mut retry_count = *self.retry_count.read().await;
223 let mut delay = Duration::from_millis(100);
224
225 for attempt in 0..MAX_RETRIES {
226 match self.fetch_keys().await {
227 Ok(()) => {
228 *self.retry_count.write().await = 0;
229 info!("Successfully refreshed Firebase public keys");
230 return Ok(());
231 }
232 Err(e) => {
233 retry_count += 1;
234 *self.retry_count.write().await = retry_count;
235
236 if attempt < MAX_RETRIES - 1 {
237 warn!(
238 "Failed to fetch keys (attempt {}): {}. Retrying in {:?}",
239 attempt + 1,
240 e,
241 delay
242 );
243 tokio::time::sleep(delay).await;
244 delay *= 2;
245 } else {
246 error!("Failed to fetch keys after {} attempts: {}", MAX_RETRIES, e);
247 return Err(e);
248 }
249 }
250 }
251 }
252
253 unreachable!()
254 }
255
256 async fn fetch_keys(&self) -> Result<(), FirebaseAuthError> {
258 let response = self
259 .http_client
260 .get(GOOGLE_CERTS_URL)
261 .send()
262 .await
263 .map_err(|e| {
264 FirebaseAuthError::KeyFetchFailed(format!("HTTP request failed: {}", e))
265 })?;
266
267 if !response.status().is_success() {
268 return Err(FirebaseAuthError::KeyFetchFailed(format!(
269 "HTTP {} from Google certificates endpoint",
270 response.status()
271 )));
272 }
273
274 let certs: HashMap<String, String> = response.json().await.map_err(|e| {
275 FirebaseAuthError::KeyFetchFailed(format!("Invalid JSON response: {}", e))
276 })?;
277
278 if certs.is_empty() {
279 return Err(FirebaseAuthError::KeyFetchFailed(
280 "Empty certificates response".to_string(),
281 ));
282 }
283
284 let mut keys = HashMap::new();
285 let mut parse_errors = 0;
286
287 for (kid, cert) in certs {
288 if !cert.starts_with("-----BEGIN CERTIFICATE-----") {
289 warn!("Invalid certificate format for kid: {}", kid);
290 parse_errors += 1;
291 continue;
292 }
293
294 match DecodingKey::from_rsa_pem(cert.as_bytes()) {
295 Ok(key) => {
296 keys.insert(kid.clone(), key);
297 debug!("Successfully parsed certificate for kid: {}", kid);
298 }
299 Err(e) => {
300 warn!("Failed to parse certificate for kid {}: {}", kid, e);
301 parse_errors += 1;
302 }
303 }
304 }
305
306 if keys.is_empty() {
307 return Err(FirebaseAuthError::KeyFetchFailed(
308 "No valid certificates found".to_string(),
309 ));
310 }
311
312 if parse_errors > 0 {
313 warn!(
314 "Failed to parse {} out of {} certificates",
315 parse_errors,
316 keys.len() + parse_errors
317 );
318 }
319
320 *self.keys.write().await = keys;
321 *self.last_updated.write().await = SystemTime::now();
322
323 Ok(())
324 }
325}
326
327#[derive(Clone)]
332pub struct FirebaseConfig {
333 pub project_id: String,
335 pub key_cache: PublicKeyCache,
337 pub max_token_age: Duration,
339 pub allowed_algorithms: Vec<Algorithm>,
341}
342
343impl FirebaseConfig {
344 pub fn new(project_id: String) -> Result<Self, FirebaseAuthError> {
360 if project_id.is_empty() {
361 return Err(FirebaseAuthError::ConfigError(
362 "Project ID cannot be empty".to_string(),
363 ));
364 }
365
366 if !project_id.chars().all(|c| c.is_alphanumeric() || c == '-') || project_id.len() > 30 {
367 return Err(FirebaseAuthError::ConfigError(
368 "Invalid project ID format".to_string(),
369 ));
370 }
371
372 let key_cache = PublicKeyCache::new(DEFAULT_CACHE_DURATION)?;
373
374 Ok(Self {
375 project_id,
376 key_cache,
377 max_token_age: Duration::from_secs(24 * 3600),
378 allowed_algorithms: vec![Algorithm::RS256],
379 })
380 }
381
382 pub fn with_cache_duration(mut self, seconds: u64) -> Result<Self, FirebaseAuthError> {
387 self.key_cache = PublicKeyCache::new(seconds)?;
388 Ok(self)
389 }
390
391 pub fn with_max_token_age(mut self, duration: Duration) -> Self {
398 self.max_token_age = duration;
399 self
400 }
401}
402
403fn extract_bearer_token(headers: &HeaderMap) -> Result<String, FirebaseAuthError> {
408 let auth_header = headers.get("authorization").ok_or_else(|| {
409 FirebaseAuthError::InvalidTokenFormat("Missing Authorization header".to_string())
410 })?;
411
412 let auth_str = auth_header.to_str().map_err(|_| {
413 FirebaseAuthError::InvalidTokenFormat("Invalid Authorization header encoding".to_string())
414 })?;
415
416 if !auth_str.starts_with("Bearer ") {
417 return Err(FirebaseAuthError::InvalidTokenFormat(
418 "Authorization header must use Bearer scheme".to_string(),
419 ));
420 }
421
422 let token = &auth_str[7..];
423
424 if token.is_empty() {
425 return Err(FirebaseAuthError::InvalidTokenFormat(
426 "Empty token".to_string(),
427 ));
428 }
429
430 if token.len() > MAX_TOKEN_LENGTH {
431 return Err(FirebaseAuthError::InvalidTokenFormat(
432 "Token too long".to_string(),
433 ));
434 }
435
436 let parts: Vec<&str> = token.split('.').collect();
437 if parts.len() != 3 {
438 return Err(FirebaseAuthError::InvalidTokenFormat(
439 "Invalid JWT format".to_string(),
440 ));
441 }
442
443 if token.contains('\0') || token.contains('\n') || token.contains('\r') {
444 return Err(FirebaseAuthError::InvalidTokenFormat(
445 "Token contains invalid characters".to_string(),
446 ));
447 }
448
449 Ok(token.to_string())
450}
451
452async fn validate_firebase_token(
460 token: &str,
461 config: &FirebaseConfig,
462) -> Result<FirebaseClaims, FirebaseAuthError> {
463 let header = decode_header(token).map_err(|e| {
464 FirebaseAuthError::InvalidTokenFormat(format!("Invalid token header: {}", e))
465 })?;
466
467 let algorithm = header.alg;
468 if !config.allowed_algorithms.contains(&algorithm) {
469 return Err(FirebaseAuthError::ValidationFailed(format!(
470 "Algorithm {:?} not allowed",
471 algorithm
472 )));
473 }
474
475 let kid = header.kid.ok_or_else(|| {
476 FirebaseAuthError::InvalidTokenFormat("Missing key ID in token header".to_string())
477 })?;
478
479 let decoding_key = config.key_cache.get_key(&kid).await?;
480
481 let mut validation = Validation::new(algorithm);
482 validation.set_audience(&[&config.project_id]);
483 validation.set_issuer(&[&format!(
484 "https://securetoken.google.com/{}",
485 config.project_id
486 )]);
487 validation.validate_exp = true;
488 validation.validate_nbf = false;
489 validation.validate_aud = true;
490 validation.leeway = 60;
491 validation.reject_tokens_expiring_in_less_than = 0;
492
493 let token_data = decode::<FirebaseClaims>(token, &decoding_key, &validation).map_err(|e| {
494 FirebaseAuthError::ValidationFailed(format!("Token validation failed: {}", e))
495 })?;
496
497 let claims = token_data.claims;
498
499 if claims.sub.is_empty() || claims.sub.len() > 128 {
500 return Err(FirebaseAuthError::MissingClaims);
501 }
502
503 if claims.sub != claims.user_id {
504 return Err(FirebaseAuthError::ValidationFailed(
505 "Subject and user_id mismatch".to_string(),
506 ));
507 }
508
509 let now = SystemTime::now()
510 .duration_since(UNIX_EPOCH)
511 .unwrap()
512 .as_secs() as i64;
513
514 if claims.auth_time > now + 60 {
515 return Err(FirebaseAuthError::InvalidTiming);
516 }
517
518 let token_age = Duration::from_secs((now - claims.iat) as u64);
519 if token_age > config.max_token_age {
520 return Err(FirebaseAuthError::InvalidTiming);
521 }
522
523 let expected_issuer = format!("https://securetoken.google.com/{}", config.project_id);
524 if claims.iss != expected_issuer {
525 return Err(FirebaseAuthError::ValidationFailed(
526 "Invalid issuer".to_string(),
527 ));
528 }
529
530 if claims.aud != config.project_id {
531 return Err(FirebaseAuthError::ValidationFailed(
532 "Invalid audience".to_string(),
533 ));
534 }
535
536 let auth_age = Duration::from_secs((now - claims.auth_time) as u64);
537 if auth_age > Duration::from_secs(7 * 24 * 3600) {
538 return Err(FirebaseAuthError::InvalidTiming);
539 }
540
541 debug!(
542 "Successfully validated Firebase token for user: {}",
543 claims.user_id
544 );
545 Ok(claims)
546}
547
548pub async fn firebase_auth_middleware(
584 State(config): State<FirebaseConfig>,
585 mut request: Request,
586 next: Next,
587) -> Result<Response, StatusCode> {
588 if let Some(content_length) = request.headers().get("content-length") {
589 if let Ok(length_str) = content_length.to_str() {
590 if let Ok(length) = length_str.parse::<usize>() {
591 if length > 10_485_760 {
592 warn!("Request body too large: {} bytes", length);
593 return Err(StatusCode::PAYLOAD_TOO_LARGE);
594 }
595 }
596 }
597 }
598
599 let token = match extract_bearer_token(request.headers()) {
600 Ok(token) => token,
601 Err(e) => {
602 warn!("Token extraction failed: {}", e);
603 return Err(StatusCode::UNAUTHORIZED);
604 }
605 };
606
607 let claims = match validate_firebase_token(&token, &config).await {
608 Ok(claims) => {
609 debug!("Successfully authenticated user: {}", claims.user_id);
610 claims
611 }
612 Err(e) => match e {
613 FirebaseAuthError::InvalidTokenFormat(_) | FirebaseAuthError::MissingClaims => {
614 warn!("Invalid token format: {}", e);
615 return Err(StatusCode::UNAUTHORIZED);
616 }
617 FirebaseAuthError::ValidationFailed(_) | FirebaseAuthError::InvalidTiming => {
618 warn!("Token validation failed: {}", e);
619 return Err(StatusCode::UNAUTHORIZED);
620 }
621 FirebaseAuthError::KeyFetchFailed(_) => {
622 error!("Key fetch failed: {}", e);
623 return Err(StatusCode::SERVICE_UNAVAILABLE);
624 }
625 FirebaseAuthError::RateLimited => {
626 warn!("Rate limit exceeded");
627 return Err(StatusCode::TOO_MANY_REQUESTS);
628 }
629 FirebaseAuthError::ConfigError(_) => {
630 error!("Configuration error: {}", e);
631 return Err(StatusCode::INTERNAL_SERVER_ERROR);
632 }
633 },
634 };
635
636 request.extensions_mut().insert(claims);
637
638 if request.extensions().get::<String>().is_none() {
639 let request_id = Uuid::new_v4().to_string();
640 request.extensions_mut().insert(request_id);
641 }
642
643 Ok(next.run(request).await)
644}
645
#[cfg(test)]
mod tests {
    //! Unit tests for token extraction and configuration, plus router-level tests for
    //! the authentication middleware.

    use super::*;
    use axum::body::Body;
    use axum::extract::FromRef;
    use axum::http::{HeaderValue, Request};
    use axum::middleware::from_fn_with_state;
    use axum::response::IntoResponse;
    use axum::routing::get;
    use axum::{Extension, Json, Router};
    use tower::ServiceExt;

    #[derive(Clone, FromRef)]
    struct AppStateMock {
        fb: FirebaseConfig,
    }

    /// Health endpoint for the test router; performs a live certificate download.
    async fn health_check(
        State(config): State<FirebaseConfig>,
    ) -> Result<Json<serde_json::Value>, StatusCode> {
        match config.key_cache.fetch_keys().await {
            Ok(()) => Ok(Json(serde_json::json!({
                "status": "healthy",
                "firebase_keys": "accessible",
                "timestamp": SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_secs()
            }))),
            Err(_) => Err(StatusCode::SERVICE_UNAVAILABLE),
        }
    }

    async fn protected_handler(Extension(claims): Extension<FirebaseClaims>) -> impl IntoResponse {
        Json(serde_json::json!({
            "message": "Successfully authenticated",
            "user_id": claims.user_id,
            "email": claims.email
        }))
    }

    async fn create_route() -> Router {
        let app_state = AppStateMock {
            fb: FirebaseConfig::new("test-project-id".to_string()).unwrap(),
        };

        Router::new()
            .route("/health", get(health_check))
            .nest(
                "/api/v1",
                Router::new()
                    .route("/protected", get(protected_handler))
                    .route_layer(from_fn_with_state(
                        app_state.fb.clone(),
                        firebase_auth_middleware,
                    )),
            )
            .with_state(app_state)
    }

    #[test]
    fn test_extract_bearer_token() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer test.token.123".parse().unwrap());
        let token = extract_bearer_token(&headers).unwrap();
        assert_eq!(token, "test.token.123");

        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Basic invalid".parse().unwrap());
        let result = extract_bearer_token(&headers);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            FirebaseAuthError::InvalidTokenFormat(_)
        ));

        let headers = HeaderMap::new();
        let result = extract_bearer_token(&headers);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            FirebaseAuthError::InvalidTokenFormat(_)
        ));

        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer ".parse().unwrap());
        let result = extract_bearer_token(&headers);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            FirebaseAuthError::InvalidTokenFormat(_)
        ));

        let mut headers = HeaderMap::new();
        let long_token = "Bearer ".to_string() + &"a".repeat(MAX_TOKEN_LENGTH + 1);
        headers.insert("authorization", long_token.parse().unwrap());
        let result = extract_bearer_token(&headers);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            FirebaseAuthError::InvalidTokenFormat(_)
        ));

        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer part1.part2".parse().unwrap());
        let result = extract_bearer_token(&headers);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            FirebaseAuthError::InvalidTokenFormat(_)
        ));

        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer part1.part2.part3".parse().unwrap());
        let result = extract_bearer_token(&headers);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "part1.part2.part3");

        // NUL/CR/LF can never reach `extract_bearer_token`: `HeaderValue` rejects them,
        // so the function's own check is defence in depth. (The previous loop here only
        // asserted facts about string literals and exercised no code.)
        for invalid in ["Bearer a.b.c\0", "Bearer a.b.c\n", "Bearer a.b\r.c"] {
            assert!(
                HeaderValue::from_str(invalid).is_err(),
                "HeaderValue unexpectedly accepted {:?}",
                invalid
            );
        }
    }

    #[tokio::test]
    async fn test_public_key_cache_creation() {
        let cache = PublicKeyCache::new(3600);
        assert!(cache.is_ok());
        assert!(cache.unwrap().keys.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_firebase_config_creation() {
        let project_id = "test-project-id".to_string();
        let config = FirebaseConfig::new(project_id.clone());
        assert!(config.is_ok());

        let config = config.unwrap();
        assert_eq!(config.project_id, project_id);
        assert_eq!(config.allowed_algorithms, vec![Algorithm::RS256]);
    }

    #[tokio::test]
    async fn test_firebase_auth_middleware_no_token() {
        let app = create_route().await;

        let request = Request::builder()
            .uri("/api/v1/protected")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_firebase_auth_middleware_invalid_token() {
        let app = create_route().await;

        let request = Request::builder()
            .uri("/api/v1/protected")
            .header("Authorization", "Bearer invalid.token.format")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    // `/health` downloads Google's certificates, so this test depends on network access
    // and on an external service; run it explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access to Google's certificate endpoint"]
    async fn test_firebase_auth_without_middleware() {
        let app = create_route().await;

        let request = Request::builder()
            .uri("/health")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}