arche 4.0.0

An opinionated backend foundation for Axum applications, providing batteries-included integrations for cloud services, databases, authentication, middleware, and logging.
Documentation
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use jsonwebtoken::{Algorithm, EncodingKey, Header};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;

use crate::error::AppError;

/// Token endpoint used when the service account key does not specify its own `token_uri`.
const DEFAULT_TOKEN_URI: &str = "https://oauth2.googleapis.com/token";
/// Validity window, in seconds, written into each signed JWT assertion (`exp = iat + this`).
const JWT_LIFETIME_SECS: u64 = 3600;
/// Subtracted from the server-reported token lifetime so cached tokens are refreshed early.
const EXPIRY_SAFETY_MARGIN: Duration = Duration::from_secs(60);
/// Per-attempt timeout for requests to the token endpoint.
const TOKEN_FETCH_TIMEOUT: Duration = Duration::from_secs(15);
/// Total number of attempts (first try included) made when token fetches fail transiently.
const TOKEN_FETCH_MAX_ATTEMPTS: u32 = 2;
/// Pause between a transient failure and the next attempt.
const TOKEN_FETCH_RETRY_DELAY: Duration = Duration::from_millis(200);

/// Credentials from a GCP service account key, as parsed from its JSON form.
///
/// Only the fields below are read; other keys present in the JSON are ignored by serde.
/// `Debug` is implemented by hand so the private key is never printed.
#[derive(Clone, Deserialize)]
pub struct ServiceAccountKey {
    /// Service account email; used as the JWT issuer and as part of the token cache key.
    client_email: String,
    /// RSA private key in PEM form. `new` and `from_json` convert literal `\n` escape
    /// sequences into real newlines before storing it.
    private_key: String,
    /// Optional key id, sent as the JWT `kid` header when present.
    #[serde(default)]
    private_key_id: Option<String>,
    /// Optional token endpoint override; `DEFAULT_TOKEN_URI` is used when absent.
    #[serde(default)]
    token_uri: Option<String>,
}

impl std::fmt::Debug for ServiceAccountKey {
    /// Formats every field except the private key, which is shown as `"<redacted>"` so
    /// credentials never leak into logs or panic messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            client_email,
            private_key_id,
            token_uri,
            ..
        } = self;

        let mut out = f.debug_struct("ServiceAccountKey");
        out.field("client_email", client_email);
        out.field("private_key_id", private_key_id);
        out.field("token_uri", token_uri);
        // Deliberately never touch the real key material.
        out.field("private_key", &"<redacted>");
        out.finish()
    }
}

impl ServiceAccountKey {
    /// Builds a key from a service account email and a PEM-encoded RSA private key.
    ///
    /// Literal `\n` escape sequences in `private_key` (as commonly stored in `.env` values)
    /// are turned into real newlines. No key id is set and the default token endpoint is used.
    pub fn new(client_email: impl Into<String>, private_key: impl Into<String>) -> Self {
        let client_email = client_email.into();
        let raw_private_key: String = private_key.into();
        Self {
            client_email,
            private_key: normalize_private_key(&raw_private_key),
            private_key_id: None,
            token_uri: None,
        }
    }

    /// Reads a service account JSON key file from `path` and parses it.
    ///
    /// # Errors
    /// Returns an internal [`AppError`] if the file cannot be read or its contents are not a
    /// valid service account key.
    pub async fn from_path(path: &str) -> Result<Self, AppError> {
        match tokio::fs::read_to_string(path).await {
            Ok(contents) => Self::from_json(&contents),
            Err(e) => Err(AppError::internal_error(
                format!("Failed to read service account key file at {path}: {e}"),
                None,
            )),
        }
    }

    /// Parses a service account key from its JSON text and normalizes the private key.
    fn from_json(json: &str) -> Result<Self, AppError> {
        let parsed: Self = serde_json::from_str(json).map_err(|e| {
            AppError::internal_error(format!("Failed to parse service account JSON: {e}"), None)
        })?;
        Ok(Self {
            private_key: normalize_private_key(&parsed.private_key),
            ..parsed
        })
    }

    /// Token endpoint to post assertions to: the key's own `token_uri`, or the default.
    fn token_uri(&self) -> &str {
        match &self.token_uri {
            Some(uri) => uri.as_str(),
            None => DEFAULT_TOKEN_URI,
        }
    }
}

/// Converts literal two-character `\n` escape sequences into real line breaks.
///
/// Keys kept in `.env` files have their newlines written as the escape `\n`, but PEM parsers
/// need actual newline characters. Real newlines already present are left untouched.
fn normalize_private_key(raw: &str) -> String {
    let pieces: Vec<&str> = raw.split("\\n").collect();
    pieces.join("\n")
}

/// Claims of the self-signed JWT that is exchanged at the token endpoint for an access token.
#[derive(Serialize)]
struct JwtClaims<'a> {
    /// Issuer: the service account email.
    iss: &'a str,
    /// Requested OAuth scopes, joined with single spaces.
    scope: String,
    /// Audience: the token endpoint URI the assertion is posted to.
    aud: &'a str,
    /// Expiry time, in seconds since the UNIX epoch.
    exp: u64,
    /// Issued-at time, in seconds since the UNIX epoch.
    iat: u64,
}

/// Successful token endpoint response (only the fields this module uses).
#[derive(Deserialize)]
struct TokenResponse {
    /// Bearer token returned to callers of `TokenSource::access_token`.
    access_token: String,
    /// Token lifetime in seconds, as reported by the server.
    expires_in: u64,
}

/// JSON error body from the token endpoint; used by `parse_token_error` to build messages.
#[derive(Deserialize)]
struct TokenErrorResponse {
    /// Short error code.
    error: String,
    /// Optional human-readable explanation.
    #[serde(default)]
    error_description: Option<String>,
}

/// An access token plus the local deadline after which it is no longer served from cache.
#[derive(Clone)]
struct CachedToken {
    /// The bearer token string.
    value: String,
    /// Deadline already shortened by `EXPIRY_SAFETY_MARGIN` when the token was fetched.
    expires_at: Instant,
}

/// Cache slot identity: (service account email, sorted list of requested scopes).
type CacheKey = (String, Vec<String>);

/// Mints and caches OAuth2 access tokens for one service account using the JWT-bearer grant.
pub struct TokenSource {
    /// HTTP client used to call the token endpoint.
    http: reqwest::Client,
    /// Credentials used to sign the JWT assertions.
    key: ServiceAccountKey,
    /// Most recently fetched token for each cache key.
    cache: Mutex<HashMap<CacheKey, CachedToken>>,
    /// One lock per cache key, so concurrent requests for the same scopes trigger one fetch.
    /// NOTE(review): entries are never evicted in this file; fine while the number of distinct
    /// scope sets stays small.
    locks: Mutex<HashMap<CacheKey, Arc<Mutex<()>>>>,
}

impl TokenSource {
    pub fn new(http: reqwest::Client, key: ServiceAccountKey) -> Self {
        Self {
            http,
            key,
            cache: Mutex::new(HashMap::new()),
            locks: Mutex::new(HashMap::new()),
        }
    }

    pub async fn access_token(&self, scopes: &[&str]) -> Result<String, AppError> {
        // Sort so the same scope set in different orders shares a cache slot.
        let mut sorted_scopes: Vec<String> = scopes.iter().map(|s| s.to_string()).collect();
        sorted_scopes.sort();
        let cache_key: CacheKey = (self.key.client_email.clone(), sorted_scopes);

        if let Some(token) = self.lookup_cached(&cache_key).await {
            return Ok(token);
        }

        let lock = {
            let mut locks = self.locks.lock().await;
            locks
                .entry(cache_key.clone())
                .or_insert_with(|| Arc::new(Mutex::new(())))
                .clone()
        };
        let _guard = lock.lock().await;

        if let Some(token) = self.lookup_cached(&cache_key).await {
            return Ok(token);
        }

        let fetched = self.fetch_token(scopes).await?;
        let mut cache = self.cache.lock().await;
        cache.insert(cache_key, fetched.clone());
        Ok(fetched.value)
    }

    async fn lookup_cached(&self, key: &CacheKey) -> Option<String> {
        let cache = self.cache.lock().await;
        cache
            .get(key)
            .filter(|t| t.expires_at > Instant::now())
            .map(|t| t.value.clone())
    }

    async fn fetch_token(&self, scopes: &[&str]) -> Result<CachedToken, AppError> {
        let assertion = self.sign_assertion(scopes)?;

        let mut last_transient: Option<String> = None;
        for attempt in 1..=TOKEN_FETCH_MAX_ATTEMPTS {
            match self.try_fetch_token(&assertion).await {
                Ok(token) => return Ok(token),
                Err(TokenFetchError::Permanent(e)) => return Err(e),
                Err(TokenFetchError::Transient(detail)) => {
                    if attempt < TOKEN_FETCH_MAX_ATTEMPTS {
                        tracing::warn!(
                            attempt,
                            error = %detail,
                            "Transient GCP token fetch error, retrying"
                        );
                        tokio::time::sleep(TOKEN_FETCH_RETRY_DELAY).await;
                    }
                    last_transient = Some(detail);
                }
            }
        }

        Err(AppError::internal_error(
            format!(
                "Failed to fetch GCP access token after {TOKEN_FETCH_MAX_ATTEMPTS} attempts: {}",
                last_transient.unwrap_or_else(|| "unknown error".into())
            ),
            None,
        ))
    }

    fn sign_assertion(&self, scopes: &[&str]) -> Result<String, AppError> {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|e| AppError::internal_error(format!("Clock error: {e}"), None))?
            .as_secs();

        let claims = JwtClaims {
            iss: &self.key.client_email,
            scope: scopes.join(" "),
            aud: self.key.token_uri(),
            iat: now,
            exp: now + JWT_LIFETIME_SECS,
        };

        let mut header = Header::new(Algorithm::RS256);
        header.kid = self.key.private_key_id.clone();

        let encoding_key =
            EncodingKey::from_rsa_pem(self.key.private_key.as_bytes()).map_err(|e| {
                AppError::internal_error(
                    format!("Invalid GCP service account private key: {e}"),
                    None,
                )
            })?;

        jsonwebtoken::encode(&header, &claims, &encoding_key).map_err(|e| {
            AppError::internal_error(format!("Failed to sign JWT for GCP token: {e}"), None)
        })
    }

    async fn try_fetch_token(&self, assertion: &str) -> Result<CachedToken, TokenFetchError> {
        let send_fut = self
            .http
            .post(self.key.token_uri())
            .form(&[
                ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
                ("assertion", assertion),
            ])
            .send();

        let resp = match tokio::time::timeout(TOKEN_FETCH_TIMEOUT, send_fut).await {
            Ok(Ok(r)) => r,
            Ok(Err(e)) => return Err(TokenFetchError::Transient(format!("send error: {e}"))),
            Err(_) => {
                return Err(TokenFetchError::Transient(format!(
                    "timed out after {}s",
                    TOKEN_FETCH_TIMEOUT.as_secs()
                )));
            }
        };

        let status = resp.status();
        let body = resp
            .bytes()
            .await
            .map_err(|e| TokenFetchError::Transient(format!("read body error: {e}")))?;

        if status.is_server_error() {
            return Err(TokenFetchError::Transient(format!(
                "HTTP {status}: {}",
                parse_token_error(&body)
            )));
        }

        if !status.is_success() {
            return Err(TokenFetchError::Permanent(AppError::internal_error(
                format!(
                    "GCP token endpoint returned {status}: {}",
                    parse_token_error(&body)
                ),
                None,
            )));
        }

        let token: TokenResponse = serde_json::from_slice(&body).map_err(|e| {
            TokenFetchError::Permanent(AppError::internal_error(
                format!("Malformed GCP token response: {e}"),
                None,
            ))
        })?;

        let lifetime = Duration::from_secs(token.expires_in)
            .checked_sub(EXPIRY_SAFETY_MARGIN)
            .unwrap_or(Duration::ZERO);

        Ok(CachedToken {
            value: token.access_token,
            expires_at: Instant::now() + lifetime,
        })
    }
}

/// Classification of a failed single token-fetch attempt.
enum TokenFetchError {
    /// Failure worth retrying (e.g. network errors, timeouts, server errors); holds a
    /// description used in logs and in the final error message.
    Transient(String),
    /// Failure that retrying will not fix; `fetch_token` returns it to the caller immediately.
    Permanent(AppError),
}

/// Builds a human-readable message from a token endpoint error body.
///
/// A JSON body shaped like `TokenErrorResponse` becomes `"<error>: <error_description>"`, or
/// just `"<error>"` when no description is present. Any other body is returned as text, with
/// invalid UTF-8 replaced lossily.
fn parse_token_error(body: &[u8]) -> String {
    let Ok(parsed) = serde_json::from_slice::<TokenErrorResponse>(body) else {
        return String::from_utf8_lossy(body).into_owned();
    };
    match parsed.error_description {
        None => parsed.error,
        Some(description) => format!("{}: {description}", parsed.error),
    }
}