use reqwest::{Request, Response, header};
use reqwest_middleware::{Middleware, Next};
use tokio::sync::OnceCell;
use yup_oauth2::{ServiceAccountAuthenticator, ServiceAccountKey};
use yup_oauth2::authenticator::Authenticator;
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::HttpConnector;
use http::Extensions;
use std::sync::Arc;
/// Authenticator type stored by [`AuthMiddleware`]: the result of
/// `ServiceAccountAuthenticator::builder(..).build()` over a hyper-rustls HTTPS connector.
type AuthType = Authenticator<HttpsConnector<HttpConnector>>;

/// `reqwest_middleware` middleware that authenticates outgoing requests with an OAuth2
/// access token obtained from a service-account key.
///
/// The authenticator lives in an `Arc<OnceCell<_>>`, so every clone of a value (and every
/// value derived from it with `with_tenant`) shares a single authenticator that is built
/// at most once.
#[derive(Clone)]
pub struct AuthMiddleware {
    /// Service-account credentials the authenticator is built from on first use.
    pub key: ServiceAccountKey,
    /// Optional tenant identifier, exposed through `tenant_id()`.
    tenant_id: Option<String>,
    /// Lazily initialised authenticator shared between clones.
    authenticator: Arc<OnceCell<AuthType>>,
}
impl AuthMiddleware {
    /// OAuth2 scopes requested for every access token.
    const SCOPES: &'static [&'static str] = &[
        "https://www.googleapis.com/auth/cloud-platform",
        "https://www.googleapis.com/auth/firebase",
    ];

    /// Creates a middleware for `key` with no tenant.
    ///
    /// Nothing is built or fetched here; the authenticator is created lazily on the first
    /// token request.
    pub fn new(key: ServiceAccountKey) -> Self {
        Self {
            key,
            authenticator: Arc::new(OnceCell::new()),
            tenant_id: None,
        }
    }

    /// Returns a copy of this middleware carrying `tenant_id`.
    ///
    /// The copy shares the same authenticator cell as `self`, so an authenticator that has
    /// already been built is reused rather than rebuilt.
    pub fn with_tenant(&self, tenant_id: &str) -> Self {
        Self {
            key: self.key.clone(),
            authenticator: Arc::clone(&self.authenticator),
            tenant_id: Some(tenant_id.to_string()),
        }
    }

    /// Returns the tenant id, if one was set with [`AuthMiddleware::with_tenant`].
    pub fn tenant_id(&self) -> Option<String> {
        self.tenant_id.clone()
    }

    /// Fetches an access token for [`Self::SCOPES`].
    ///
    /// On the first call the service-account authenticator is built from `self.key` and
    /// cached in the shared cell; later calls (including from clones) reuse it. Because of
    /// that caching, changing the public `key` field after the first token request does
    /// not change which credentials are used.
    ///
    /// # Errors
    ///
    /// Fails if the authenticator cannot be built, if the token request fails, or if the
    /// returned token has no access-token string.
    async fn get_token(&self) -> Result<String, anyhow::Error> {
        let auth = self
            .authenticator
            // Clone the key only inside the initialiser, so the common already-initialised
            // path does not copy the credentials on every request.
            .get_or_try_init(|| ServiceAccountAuthenticator::builder(self.key.clone()).build())
            .await?;
        let token = auth.token(Self::SCOPES).await?;
        token
            .token()
            .map(str::to_owned)
            .ok_or_else(|| anyhow::anyhow!("No token found"))
    }
}
#[async_trait::async_trait]
impl Middleware for AuthMiddleware {
async fn handle(
&self,
mut req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> reqwest_middleware::Result<Response> {
let token = self.get_token().await.map_err(|e| {
reqwest_middleware::Error::Middleware(anyhow::anyhow!("Failed to get auth token: {}", e))
})?;
req.headers_mut().insert(
header::AUTHORIZATION,
header::HeaderValue::from_str(&format!("Bearer {}", token)).unwrap(),
);
next.run(req, extensions).await
}
}