firebase_admin_sdk/core/
middleware.rs1use reqwest::{Request, Response, header};
2use reqwest_middleware::{Middleware, Next};
3use tokio::sync::OnceCell;
4use yup_oauth2::{ServiceAccountAuthenticator, ServiceAccountKey};
5use yup_oauth2::authenticator::Authenticator;
6use hyper_rustls::HttpsConnector;
7use hyper_util::client::legacy::connect::HttpConnector;
8use http::Extensions;
9use std::sync::Arc;
10
11type AuthType = Authenticator<HttpsConnector<HttpConnector>>;
13
14#[derive(Clone)]
24pub struct AuthMiddleware {
25 pub key: ServiceAccountKey,
27 authenticator: Arc<OnceCell<AuthType>>,
29 tenant_id: Option<String>,
31}
32
33impl AuthMiddleware {
34 pub fn new(key: ServiceAccountKey) -> Self {
40 Self {
41 key,
42 authenticator: Arc::new(OnceCell::new()),
43 tenant_id: None,
44 }
45 }
46
47 pub fn with_tenant(&self, tenant_id: &str) -> Self {
49 Self {
50 key: self.key.clone(),
51 authenticator: self.authenticator.clone(),
52 tenant_id: Some(tenant_id.to_string()),
53 }
54 }
55
56 pub fn tenant_id(&self) -> Option<String> {
58 self.tenant_id.clone()
59 }
60
61 async fn get_token(&self) -> Result<String, anyhow::Error> {
63 let key = self.key.clone();
64 let auth = self.authenticator.get_or_try_init(|| async move {
65 ServiceAccountAuthenticator::builder(key)
66 .build()
67 .await
68 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
69 }).await?;
70
71 let scopes = &["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/firebase"];
72
73 let token = auth.token(scopes).await?;
74
75 Ok(token.token().ok_or_else(|| anyhow::anyhow!("No token found"))?.to_string())
76 }
77}
78
79#[async_trait::async_trait]
80impl Middleware for AuthMiddleware {
81 async fn handle(
83 &self,
84 mut req: Request,
85 extensions: &mut Extensions,
86 next: Next<'_>,
87 ) -> reqwest_middleware::Result<Response> {
88
89 let token = self.get_token().await.map_err(|e| {
90 reqwest_middleware::Error::Middleware(anyhow::anyhow!("Failed to get auth token: {}", e))
91 })?;
92
93 req.headers_mut().insert(
94 header::AUTHORIZATION,
95 header::HeaderValue::from_str(&format!("Bearer {}", token)).unwrap(),
96 );
97
98 next.run(req, extensions).await
99 }
100}