ones-oidc 0.3.7

ONES OpenID Connect client for Rust
Documentation
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use jsonwebtoken::jwk::JwkSet;
use jsonwebtoken::EncodingKey;
use std::collections::HashMap;
use std::str::FromStr;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use uuid::Uuid;
use super::oidc_types::{JwtPayload, LoginHint, LoginHintKind};
use super::errors::{DeviceError, OidcRequirementsError};
use super::http_client::default_client;
use super::{ClientId, CoreProviderMetadata, CoreTokenType, EmptyExtraTokenFields, StandardTokenResponse};
use openidconnect::OAuth2TokenResponse;

/// A cached device token paired with the instant at which the cache stops serving it.
#[derive(Clone)]
struct TokenCache {
    /// Instant after which this entry is considered stale and must be refreshed.
    expires_at: Instant,
    /// Token endpoint response handed back to callers while `expires_at` is in the future.
    token_response: StandardTokenResponse<EmptyExtraTokenFields, CoreTokenType>,
}

use std::sync::LazyLock;

/// Process-wide device token cache, keyed by `"<client_id>:<issuer_url>"`.
///
/// Lookups take the shared read lock; refreshes take the exclusive write lock. Nothing in
/// this module removes entries — a stale entry is overwritten when its token is refreshed.
static TOKEN_CACHE: LazyLock<RwLock<HashMap<String, TokenCache>>> =
    LazyLock::new(|| RwLock::new(HashMap::default()));

/// Safety margin (100 seconds) subtracted from a token's reported lifetime, so the cache
/// refreshes a token before it actually expires.
const TOKEN_REFRESH_BUFFER: Duration = Duration::from_millis(100_000);

/// Options used by the device JWT builders in this module to populate a [`JwtPayload`].
///
/// Every optional claim defaults to `None`; the JWT lifetime defaults to 300 seconds.
#[derive(Debug, Clone)]
pub struct JwtPayloadConfig {
    /// Value for the `scope` claim, if any.
    pub scope: Option<String>,
    /// Value for the `binding_message` claim, if any.
    pub binding_message: Option<String>,
    /// Login hint; emitted as either the `login_hint` or the `login_hint_token` claim
    /// depending on its kind.
    pub login_hint: Option<LoginHint>,
    /// Value for the `resource` claim, if any.
    pub resource: Option<String>,
    /// Identifier of the QR session that originated the request, if any.
    pub qr_session_id: Option<String>,
    /// JWT lifetime in seconds (`exp - iat`).
    pub jwt_expiration_secs: u64,
}

impl Default for JwtPayloadConfig {
    fn default() -> Self {
        // Five minutes.
        let jwt_expiration_secs = 5 * 60;
        JwtPayloadConfig {
            jwt_expiration_secs,
            qr_session_id: None,
            resource: None,
            login_hint: None,
            binding_message: None,
            scope: None,
        }
    }
}

/// Request a device access token from the IDP token endpoint (uncached).
///
/// Performs a `client_credentials` grant authenticated with `device_jwt` as a JWT-bearer
/// client assertion. Every call hits the network; [`device_access_token`] is the cached
/// variant.
///
/// # Errors
/// Fails if the HTTP client cannot be built, the provider metadata has no token endpoint,
/// the request fails or returns a non-success status, or the body is not a token response.
pub async fn request_device_access_token(
    provider_metadata: &CoreProviderMetadata,
    device_client_id: &ClientId,
    device_jwt: &str,
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, CoreTokenType>, DeviceError> {
    let http = default_client()?;

    let token_endpoint = provider_metadata
        .token_endpoint()
        .ok_or(OidcRequirementsError::MissingTokenEndpoint)?;

    let client_id = device_client_id.to_string();
    let form = [
        ("grant_type", "client_credentials"),
        ("client_id", client_id.as_str()),
        ("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
        ("client_assertion", device_jwt),
    ];

    let http_response = http
        .post(token_endpoint.to_string())
        .header("content-type", "application/x-www-form-urlencoded")
        .form(&form)
        .send()
        .await?;

    let token = http_response
        .error_for_status()?
        .json::<StandardTokenResponse<EmptyExtraTokenFields, CoreTokenType>>()
        .await?;

    Ok(token)
}

/// Request a device access token from the IDP token endpoint, reusing a cached one when
/// possible.
///
/// Tokens are cached process-wide per `(client_id, issuer)` pair. A cached token is served
/// until `TOKEN_REFRESH_BUFFER` (100 s) before the expiry reported by the IDP. Tokens with a
/// reported lifetime of 100 s or less are cached for half that lifetime. Tokens without
/// `expires_in` are cached for one hour minus the buffer.
///
/// `device_jwt` is not part of the cache key, so a cache hit returns a token obtained with
/// an earlier assertion.
///
/// NOTE(review): the global write lock is held while the token request is in flight. This
/// prevents duplicate refreshes for the same key, but it also blocks lookups for every other
/// client/issuer until the request completes.
///
/// # Errors
/// Propagates any error from [`request_device_access_token`]; the cache is left unchanged
/// on failure.
pub async fn device_access_token(
    provider_metadata: &CoreProviderMetadata,
    device_client_id: &ClientId,
    device_jwt: &str,
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, CoreTokenType>, DeviceError> {
    let cache_key = format!(
        "{}:{}",
        device_client_id.as_str(),
        provider_metadata.issuer().as_str()
    );

    // Fast path under the shared read lock.
    {
        let cache = TOKEN_CACHE.read().await;
        if let Some(cached) = cache.get(&cache_key) {
            if cached.expires_at > Instant::now() {
                return Ok(cached.token_response.clone());
            }
        }
    }

    let mut cache = TOKEN_CACHE.write().await;

    // Read the clock only after acquiring the write lock. Waiting for the lock may have taken
    // as long as another task's token request. A timestamp taken before the wait could be
    // stale enough to accept an entry that has already expired.
    let now = Instant::now();

    // Another task may have refreshed this entry while we waited for the write lock.
    if let Some(cached) = cache.get(&cache_key) {
        if cached.expires_at > now {
            return Ok(cached.token_response.clone());
        }
    }

    let response =
        request_device_access_token(provider_metadata, device_client_id, device_jwt).await?;

    // `now` predates the request, so the cached lifetime errs on the short side.
    let expires_at = token_cache_deadline(now, response.expires_in());

    cache.insert(
        cache_key,
        TokenCache {
            token_response: response.clone(),
            expires_at,
        },
    );

    Ok(response)
}

/// Compute when a token fetched at `requested_at` must stop being served from the cache.
fn token_cache_deadline(requested_at: Instant, expires_in: Option<Duration>) -> Instant {
    // Lifetime assumed when the IDP does not report `expires_in`.
    const DEFAULT_TOKEN_LIFETIME: Duration = Duration::from_secs(3600);
    let default_cache_lifetime = DEFAULT_TOKEN_LIFETIME - TOKEN_REFRESH_BUFFER;

    let cache_lifetime = match expires_in {
        // Refresh a buffer's worth of time before the token actually expires.
        Some(lifetime) if lifetime > TOKEN_REFRESH_BUFFER => lifetime - TOKEN_REFRESH_BUFFER,
        // Too short for the buffer: use half the lifetime instead.
        Some(lifetime) => Duration::from_secs(lifetime.as_secs() / 2),
        None => default_cache_lifetime,
    };

    // `expires_in` comes from the server. An absurdly large value must not panic on
    // `Instant` overflow; fall back to the default cache lifetime instead.
    requested_at
        .checked_add(cache_lifetime)
        .unwrap_or_else(|| requested_at + default_cache_lifetime)
}

/// Build a [`JwtPayload`] for the device identified by `device_client_id`.
///
/// Standard claims:
/// - `iss` is the client ID.
/// - `sub` is the client ID parsed as a UUID.
/// - `aud` is the provider's issuer URL.
/// - `jti` is a fresh random UUID.
/// - `iat`/`nbf` are the current Unix time, with `exp = iat + config.jwt_expiration_secs`
///   (saturating).
///
/// Optional claims (scope, binding message, resource, QR session id) are copied from
/// `config`. `config.login_hint` goes to either the `login_hint` or the `login_hint_token`
/// claim according to its kind. The `client_id`, `username`, `user_client_id` and
/// `idp_role` claims are left unset.
///
/// # Errors
/// Fails if `device_client_id` is not a valid UUID.
fn create_jwt_payload_base(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
    config: &JwtPayloadConfig,
) -> Result<JwtPayload, DeviceError> {
    let now = chrono::Utc::now().timestamp() as u64;

    let (login_hint, login_hint_token) = match &config.login_hint {
        Some(hint) => match hint.kind {
            LoginHintKind::LoginHint => (Some(hint.value.clone()), None),
            LoginHintKind::LoginHintToken => (None, Some(hint.value.clone())),
        },
        None => (None, None),
    };

    Ok(JwtPayload {
        iss: device_client_id.to_string(),
        sub: Uuid::from_str(device_client_id.as_str())?,
        aud: Some(provider_metadata.issuer().to_string()),
        jti: Some(Uuid::new_v4().to_string()),
        iat: now,
        // The lifetime is caller-supplied. Saturate so an oversized value cannot panic
        // (debug builds) or wrap `exp` into the past (release builds).
        exp: now.saturating_add(config.jwt_expiration_secs),
        nbf: Some(now),
        scope: config.scope.clone(),
        binding_message: config.binding_message.clone(),
        login_hint,
        login_hint_token,
        resource: config.resource.clone(),
        client_id: None,
        username: None,
        user_client_id: None,
        idp_role: None,
        qr_session_id: config.qr_session_id.clone(),
    })
}

/// Build the generic device JWT claims (no CIBA fields) with the default 300-second
/// lifetime.
///
/// # Errors
/// Fails if `device_client_id` is not a valid UUID.
pub fn make_device_jwt_base(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
) -> Result<JwtPayload, DeviceError> {
    create_jwt_payload_base(device_client_id, provider_metadata, &JwtPayloadConfig::default())
}

/// Build the generic device JWT claims (no CIBA fields) with a caller-chosen lifetime.
///
/// `jwt_expiration_secs` is the number of seconds between `iat` and `exp`.
///
/// # Errors
/// Fails if `device_client_id` is not a valid UUID.
#[allow(dead_code)]
pub fn make_device_jwt_base_with_config(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
    jwt_expiration_secs: u64,
) -> Result<JwtPayload, DeviceError> {
    create_jwt_payload_base(
        device_client_id,
        provider_metadata,
        &JwtPayloadConfig {
            jwt_expiration_secs,
            ..Default::default()
        },
    )
}

/// Device JWT generic (CIBA status and others)
pub fn make_device_jwt(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
    private_key: &EncodingKey,
) -> Result<String, DeviceError> {
    let jwt_payload = make_device_jwt_base(device_client_id, provider_metadata)?;
    let jwt = jsonwebtoken::encode(
        &jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256),
        &jwt_payload,
        private_key,
    )?;
    Ok(jwt)
}

/// Device JWT content for CIBA request (using default expiration)
pub fn make_device_jwt_ciba_base(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
    login_hint: &LoginHint,
    scope: &str,
    binding_message: &str,
    resource: Option<String>,
    // QR Session Request Origin
    qr_session_id: Option<String>,
) -> Result<JwtPayload, DeviceError> {
    let config = JwtPayloadConfig {
        scope: Some(scope.to_string()),
        binding_message: Some(binding_message.to_string()),
        login_hint: Some(login_hint.clone()),
        resource,
        qr_session_id,
        ..JwtPayloadConfig::default()
    };
    create_jwt_payload_base(device_client_id, provider_metadata, &config)
}

/// Build the device JWT claims for a CIBA authentication request with a caller-chosen
/// lifetime.
///
/// `jwt_expiration_secs` is the number of seconds between `iat` and `exp`.
///
/// # Errors
/// Fails if `device_client_id` is not a valid UUID.
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
pub fn make_device_jwt_ciba_base_with_config(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
    login_hint: &LoginHint,
    scope: &str,
    binding_message: &str,
    resource: Option<String>,
    qr_session_id: Option<String>,
    jwt_expiration_secs: u64,
) -> Result<JwtPayload, DeviceError> {
    let ciba_config = JwtPayloadConfig {
        jwt_expiration_secs,
        qr_session_id,
        resource,
        login_hint: Some(login_hint.clone()),
        binding_message: Some(binding_message.to_owned()),
        scope: Some(scope.to_owned()),
    };
    create_jwt_payload_base(device_client_id, provider_metadata, &ciba_config)
}

/// Device JWT for CIBA request
#[allow(clippy::too_many_arguments)]
pub fn make_device_jwt_ciba(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
    login_hint: &LoginHint,
    scope: &str,
    binding_message: &str,
    resource: Option<String>,
    private_key: &EncodingKey,
    // QR Session Request Origin
    qr_session_id: Option<String>,
) -> Result<String, DeviceError> {
    let jwt_payload = make_device_jwt_ciba_base(
        device_client_id,
        provider_metadata,
        login_hint,
        scope,
        binding_message,
        resource,
        qr_session_id,
    )?;
    let jwt = jsonwebtoken::encode(
        &jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256),
        &jwt_payload,
        private_key,
    )?;
    Ok(jwt)
}

/// Fetch the IDP's JSON Web Key Set from the provider's `jwks_uri`.
///
/// # Errors
/// Fails if the HTTP client cannot be built, the request fails, the server responds with a
/// non-success status, or the body is not a valid JWK set.
pub async fn get_jwks(provider_metadata: &CoreProviderMetadata) -> Result<JwkSet, reqwest::Error> {
    let client = default_client()?;
    let jwks = client
        .get(provider_metadata.jwks_uri().to_string())
        .send()
        .await?
        // Report HTTP failures (404, 5xx, ...) as status errors rather than as a confusing
        // JSON decode error on the error page body.
        .error_for_status()?
        .json::<JwkSet>()
        .await?;
    Ok(jwks)
}

#[allow(dead_code)]
pub fn sign_jwt(jwt_payload: JwtPayload, private_key: &EncodingKey) -> Result<String, DeviceError> {
    let jwt = jsonwebtoken::encode(
        &jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256),
        &jwt_payload,
        private_key,
    )?;
    Ok(jwt)
}

/// Same as [`sign_jwt`], but the RS256 signature is produced by a remote signing service
/// using the device identity instead of a local private key.
///
/// Steps:
/// 1. Base64url-encode (no padding) the header `{"alg":"RS256","typ":"JWT"}`.
/// 2. Base64url-encode (no padding) the JSON-serialized `jwt_payload`.
/// 3. Form the signing input `message = <encoded header>.<encoded payload>`.
/// 4. POST `{ "deviceId": device_id, "message": message }` as JSON to
///    `signing_service_url` and read the string field `signature` from the JSON reply.
/// 5. Return `<message>.<signature>`.
///
/// NOTE(review): the signature is appended verbatim. The service is therefore assumed to
/// return it already base64url-encoded, as JWS compact serialization requires — confirm
/// with the service contract.
///
/// # Errors
/// Fails if:
/// - JSON serialization fails;
/// - the HTTP client cannot be built;
/// - the request fails or returns a non-success status;
/// - the reply is not JSON;
/// - the reply lacks a non-empty string `signature` field
///   ([`DeviceError::InvalidSignatureResponse`]).
#[allow(dead_code)]
pub async fn sign_jwt_device_identity(
    jwt_payload: JwtPayload,
    device_id: String,
    signing_service_url: &str,
) -> Result<String, DeviceError> {
    // 1. Header.
    let header = serde_json::json!({
        "alg": "RS256",
        "typ": "JWT"
    });
    let encoded_header = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header)?.as_bytes());

    // 2. Payload.
    let encoded_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&jwt_payload)?.as_bytes());

    // 3. JWS signing input.
    let message = format!("{}.{}", encoded_header, encoded_payload);

    // 4. Ask the signing service to sign the message with the device identity.
    let client = default_client().map_err(DeviceError::TokenRequest)?;

    let signature_request = serde_json::json!({
        "deviceId": device_id,
        "message": message
    });

    let signature_response = client
        .post(signing_service_url)
        .json(&signature_request)
        .send()
        .await?
        .error_for_status()?
        .json::<serde_json::Value>()
        .await?;

    // An empty signature would produce a structurally complete but unsigned token
    // ("<header>.<payload>."), so treat it as an invalid response.
    let signature = signature_response["signature"]
        .as_str()
        .filter(|s| !s.is_empty())
        .ok_or(DeviceError::InvalidSignatureResponse)?;

    // 5. Compact JWT: header.payload.signature.
    Ok(format!("{}.{}", message, signature))
}