modelmux 1.0.0

ModelMux - high-performance Rust gateway that translates OpenAI-compatible API requests to Vertex AI (Claude), with streaming, tool calling, and production-grade reliability.
Documentation
//!
//! Authentication for LLM backends: GCP OAuth2 for Vertex AI, static Bearer tokens for other providers.
//!
//! [RequestAuth] is the unified type used by the server; it is built from the
//! provider's [crate::provider::AuthStrategy].
//!
//! Authors:
//!   Jaro <yarenty@gmail.com>
//!
//! Copyright (c) 2026 SkyCorp

/* --- uses ------------------------------------------------------------------------------------ */

use std::sync::Arc;

use hyper_util::client::legacy::connect::HttpConnector;
use tokio::sync::Mutex;
use yup_oauth2::authenticator::Authenticator;
use yup_oauth2::{ServiceAccountAuthenticator, ServiceAccountKey as OAuthKey, hyper_rustls};

use crate::config::ServiceAccountKey;
use crate::error::{ProxyError, Result};
use crate::provider::AuthStrategy;

/* --- request auth (provider-agnostic) -------------------------------------------------------- */

///
/// Unified auth for outgoing LLM requests: GCP OAuth2 or static Bearer token.
///
/// Built from [AuthStrategy]; the server uses this to attach the correct header.
/// Credentials attached to outgoing LLM requests.
///
/// Constructed from a provider's [`AuthStrategy`]; the server uses it to produce the
/// `Authorization` header for upstream calls.
pub enum RequestAuth {
    /// OAuth2 access tokens obtained through a GCP service account (Vertex AI).
    Gcp(Arc<GcpAuthProvider>),
    /// A fixed Bearer token sent unchanged (e.g. OpenAI-compatible or Mistral endpoints).
    Bearer(String),
}

impl RequestAuth {
    ///
    /// Construct a [`RequestAuth`] matching the given provider [`AuthStrategy`].
    ///
    /// `GcpOAuth2` builds a [`GcpAuthProvider`] from the service account key;
    /// `BearerToken` copies the configured token.
    ///
    /// # Errors
    ///  * any error returned by [`GcpAuthProvider::new`]
    pub async fn from_strategy(strategy: &AuthStrategy) -> Result<Self> {
        let auth = match strategy {
            AuthStrategy::BearerToken(token) => Self::Bearer(token.clone()),
            AuthStrategy::GcpOAuth2(key) => Self::Gcp(Arc::new(GcpAuthProvider::new(key).await?)),
        };
        Ok(auth)
    }

    ///
    /// Produce the `Authorization` header value, always of the form `Bearer <token>`.
    ///
    /// The GCP variant asks its provider for an access token on every call; the static
    /// variant reuses the stored token.
    ///
    /// # Errors
    ///  * any error returned by [`GcpAuthProvider::get_access_token`]
    pub async fn authorization_header_value(&self) -> Result<String> {
        let token = match self {
            Self::Bearer(static_token) => static_token.clone(),
            Self::Gcp(provider) => provider.get_access_token().await?,
        };
        Ok(format!("Bearer {}", token))
    }
}

/* --- GCP auth provider ----------------------------------------------------------------------- */

///
/// Google Cloud Platform authentication provider.
///
/// Manages OAuth2 authentication flow for accessing Vertex AI services using
/// service account credentials. Handles token generation and refresh automatically.
pub struct GcpAuthProvider {
    /// yup-oauth2 service-account authenticator. The lock is held for the full duration
    /// of a token request, so concurrent callers take turns fetching tokens.
    authenticator: Arc<Mutex<ServiceAccountAuth>>,
}

/* --- constants ------------------------------------------------------------------------------ */

/// OAuth2 scope requested for every GCP access token (broad Cloud Platform access).
const CLOUD_PLATFORM_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";

/* --- start of code -------------------------------------------------------------------------- */

/// Concrete authenticator type produced by `ServiceAccountAuthenticator::builder(..).build()`:
/// a yup-oauth2 `Authenticator` over a rustls HTTPS connector wrapping hyper's `HttpConnector`.
type ServiceAccountAuth = Authenticator<hyper_rustls::HttpsConnector<HttpConnector>>;

impl GcpAuthProvider {
    ///
    /// Create a new GCP authentication provider.
    ///
    /// Initializes the OAuth2 authenticator with the provided service account
    /// credentials. The authenticator will automatically handle token refresh
    /// when needed.
    ///
    /// # Arguments
    ///  * `service_account_key` - Google Cloud service account credentials
    ///
    /// # Returns
    ///  * New authentication provider instance
    ///  * `ProxyError::Auth` if authenticator creation fails
    pub async fn new(service_account_key: &ServiceAccountKey) -> Result<Self> {
        let oauth_key = Self::convert_service_account_key(service_account_key);
        let authenticator = Self::create_authenticator(oauth_key).await?;

        Ok(Self { authenticator: Arc::new(Mutex::new(authenticator)) })
    }

    ///
    /// Get a valid access token for Google Cloud Platform.
    ///
    /// Retrieves a fresh access token, automatically refreshing if the current
    /// token has expired. The token can be used for authenticating requests
    /// to Vertex AI services.
    ///
    /// # Returns
    ///  * Valid access token string
    ///  * `ProxyError::Auth` if token retrieval fails
    pub async fn get_access_token(&self) -> Result<String> {
        let scopes = &[CLOUD_PLATFORM_SCOPE];
        let guard = self.authenticator.lock().await;

        let token = guard
            .token(scopes)
            .await
            .map_err(|e| ProxyError::Auth(format!("Failed to get access token: {}", e)))?;

        // AccessToken has a token() method that returns Option<&str>
        token
            .token()
            .ok_or_else(|| ProxyError::Auth("Access token is missing from response".to_string()))
            .map(|s| s.to_string())
    }

    ///
    /// Convert internal service account key to OAuth2 library format.
    ///
    /// Transforms our configuration structure into the format expected by
    /// the yup-oauth2 library for service account authentication.
    ///
    /// # Arguments
    ///  * `service_account_key` - internal service account key structure
    ///
    /// # Returns
    ///  * OAuth2 library service account key structure
    fn convert_service_account_key(service_account_key: &ServiceAccountKey) -> OAuthKey {
        OAuthKey {
            key_type: Some("service_account".to_string()),
            project_id: Some(service_account_key.project_id.clone()),
            private_key_id: Some(service_account_key.private_key_id.clone()),
            private_key: service_account_key.private_key.clone(),
            client_email: service_account_key.client_email.clone(),
            client_id: Some(service_account_key.client_id.clone()),
            auth_uri: Some(service_account_key.auth_uri.clone()),
            token_uri: service_account_key.token_uri.clone(),
            auth_provider_x509_cert_url: Some(
                service_account_key.auth_provider_x509_cert_url.clone(),
            ),
            client_x509_cert_url: Some(service_account_key.client_x509_cert_url.clone()),
        }
    }

    ///
    /// Create the OAuth2 authenticator instance.
    ///
    /// Builds and configures the authenticator using the provided OAuth2 key.
    /// This handles the low-level OAuth2 flow setup.
    ///
    /// # Arguments
    ///  * `oauth_key` - OAuth2 service account key
    ///
    /// # Returns
    ///  * Configured authenticator instance
    ///  * `ProxyError::Auth` if authenticator creation fails
    async fn create_authenticator(oauth_key: OAuthKey) -> Result<ServiceAccountAuth> {
        ServiceAccountAuthenticator::builder(oauth_key)
            .build()
            .await
            .map_err(|e| ProxyError::Auth(format!("Failed to create authenticator: {}", e)))
    }
}