use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use jsonwebtoken::jwk::JwkSet;
use jsonwebtoken::EncodingKey;
use std::collections::HashMap;
use std::str::FromStr;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use uuid::Uuid;
use super::oidc_types::{JwtPayload, LoginHint, LoginHintKind};
use super::errors::{DeviceError, OidcRequirementsError};
use super::http_client::default_client;
use super::{ClientId, CoreProviderMetadata, CoreTokenType, EmptyExtraTokenFields, StandardTokenResponse};
use openidconnect::OAuth2TokenResponse;
/// One entry of [`TOKEN_CACHE`]: a token endpoint response plus the instant after which it
/// must no longer be handed out.
#[derive(Clone)]
struct TokenCache {
    /// Local deadline for reuse; entries are only served while `expires_at > Instant::now()`.
    expires_at: Instant,
    /// The full token response as returned by the token endpoint.
    token_response: StandardTokenResponse<EmptyExtraTokenFields, CoreTokenType>,
}
use std::sync::LazyLock;
/// Process-wide token cache, lazily created on first use.
///
/// Keys are the `"{client_id}:{issuer}"` strings built by `device_access_token`.
/// NOTE(review): nothing visible in this module ever removes entries; growth is bounded only by
/// the number of distinct client/issuer pairs — confirm that is acceptable for callers.
static TOKEN_CACHE: LazyLock<RwLock<HashMap<String, TokenCache>>> =
    LazyLock::new(|| RwLock::new(HashMap::default()));
/// Safety margin subtracted from a token's `expires_in` so it is refreshed before the server
/// considers it expired.
const TOKEN_REFRESH_BUFFER: Duration = Duration::from_secs(100);
/// Options controlling the claims placed in a device [`JwtPayload`].
///
/// Every `Option` field that is `None` leaves the corresponding claim unset.
#[derive(Debug, Clone)]
pub struct JwtPayloadConfig {
    /// Copied into the `scope` claim.
    pub scope: Option<String>,
    /// Copied into the `binding_message` claim.
    pub binding_message: Option<String>,
    /// Emitted as either the `login_hint` or the `login_hint_token` claim, depending on
    /// the hint's kind.
    pub login_hint: Option<LoginHint>,
    /// Copied into the `resource` claim.
    pub resource: Option<String>,
    /// Copied into the `qr_session_id` claim.
    pub qr_session_id: Option<String>,
    /// JWT lifetime in seconds; `exp` is set to the issue time plus this value.
    pub jwt_expiration_secs: u64,
}

impl Default for JwtPayloadConfig {
    /// No optional claims and a 300-second (five-minute) JWT lifetime.
    fn default() -> Self {
        const DEFAULT_JWT_EXPIRATION_SECS: u64 = 300;

        Self {
            jwt_expiration_secs: DEFAULT_JWT_EXPIRATION_SECS,
            scope: None,
            binding_message: None,
            login_hint: None,
            resource: None,
            qr_session_id: None,
        }
    }
}
/// Requests a fresh access token from the provider's token endpoint, bypassing the cache.
///
/// Performs a `client_credentials` grant in which the client authenticates with `device_jwt`
/// as a JWT-bearer client assertion
/// (`urn:ietf:params:oauth:client-assertion-type:jwt-bearer`).
///
/// # Errors
/// * building the HTTP client fails;
/// * the provider metadata has no token endpoint
///   ([`OidcRequirementsError::MissingTokenEndpoint`], converted into [`DeviceError`]);
/// * the request fails, returns a non-success status, or the body is not a valid token response.
pub async fn request_device_access_token(
    provider_metadata: &CoreProviderMetadata,
    device_client_id: &ClientId,
    device_jwt: &str,
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, CoreTokenType>, DeviceError> {
    let http = default_client()?;
    let Some(token_endpoint) = provider_metadata.token_endpoint() else {
        return Err(OidcRequirementsError::MissingTokenEndpoint.into());
    };

    let client_id = device_client_id.to_string();
    let form: [(&str, &str); 4] = [
        ("grant_type", "client_credentials"),
        ("client_id", client_id.as_str()),
        (
            "client_assertion_type",
            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        ),
        ("client_assertion", device_jwt),
    ];

    let http_response = http
        .post(token_endpoint.to_string())
        .header("content-type", "application/x-www-form-urlencoded")
        .form(&form)
        .send()
        .await?
        .error_for_status()?;
    let token = http_response
        .json::<StandardTokenResponse<EmptyExtraTokenFields, CoreTokenType>>()
        .await?;
    Ok(token)
}
/// Returns an access token for `device_client_id`, reusing a cached one while it is fresh.
///
/// Tokens are cached per `"{client_id}:{issuer}"` in [`TOKEN_CACHE`]. On a cache miss (or a
/// stale entry) a new token is fetched with [`request_device_access_token`], using `device_jwt`
/// as the client assertion, and stored. `device_jwt` is therefore only sent on a miss.
///
/// The cached lifetime is derived from the response's `expires_in` by `token_cache_lifetime`.
///
/// # Errors
/// Propagates any error from [`request_device_access_token`]; failed fetches are not cached.
pub async fn device_access_token(
    provider_metadata: &CoreProviderMetadata,
    device_client_id: &ClientId,
    device_jwt: &str,
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, CoreTokenType>, DeviceError> {
    let cache_key = format!(
        "{}:{}",
        device_client_id.as_str(),
        provider_metadata.issuer().as_str()
    );

    // Fast path: shared read lock, no network.
    {
        let cache = TOKEN_CACHE.read().await;
        if let Some(cached) = cache.get(&cache_key) {
            if cached.expires_at > Instant::now() {
                return Ok(cached.token_response.clone());
            }
        }
    }

    // The write guard is deliberately held across the token request below, so concurrent
    // callers wait for one refresh instead of all hitting the token endpoint. The cost is that
    // every key's readers block while a refresh is in flight.
    let mut cache = TOKEN_CACHE.write().await;

    // Read the clock only after the write lock is acquired: waiting for the lock may take a
    // while, and comparing against a timestamp taken earlier could hand out an entry that
    // expired in the meantime.
    let now = Instant::now();
    if let Some(cached) = cache.get(&cache_key) {
        if cached.expires_at > now {
            // Another task refreshed this entry while we waited for the lock.
            return Ok(cached.token_response.clone());
        }
    }

    let response =
        request_device_access_token(provider_metadata, device_client_id, device_jwt).await?;

    // `now` predates the request, so the deadline errs on the early side. `checked_add` avoids
    // a panic if the server reports an `expires_in` too large to represent as an `Instant`.
    let expires_at = now
        .checked_add(token_cache_lifetime(response.expires_in()))
        .unwrap_or_else(|| now + token_cache_lifetime(None));

    cache.insert(
        cache_key,
        TokenCache {
            token_response: response.clone(),
            expires_at,
        },
    );
    Ok(response)
}

/// Computes how long a freshly fetched token may be served from the cache.
///
/// * `Some(d)` with `d > TOKEN_REFRESH_BUFFER`: `d - TOKEN_REFRESH_BUFFER`, so the token is
///   refreshed before the server-side expiry.
/// * `Some(d)` with `d <= TOKEN_REFRESH_BUFFER`: half of `d`, in whole seconds.
/// * `None` (no `expires_in` in the response): an assumed one-hour lifetime minus the buffer.
fn token_cache_lifetime(expires_in: Option<Duration>) -> Duration {
    match expires_in {
        Some(lifetime) if lifetime > TOKEN_REFRESH_BUFFER => lifetime - TOKEN_REFRESH_BUFFER,
        Some(lifetime) => Duration::from_secs(lifetime.as_secs() / 2),
        None => Duration::from_secs(3600) - TOKEN_REFRESH_BUFFER,
    }
}
/// Builds the claim set shared by every device JWT produced in this module.
///
/// Claims:
/// * `iss` is the device client id, and `sub` is that same id parsed as a UUID.
/// * `aud` is the provider's issuer URL.
/// * `jti` is a fresh random UUID on every call.
/// * `iat` and `nbf` are the current Unix time in seconds.
/// * `exp` is `iat + config.jwt_expiration_secs`, saturating at `u64::MAX`.
/// * `scope`, `binding_message`, `resource` and `qr_session_id` are copied from `config`.
/// * `config.login_hint` lands in either `login_hint` or `login_hint_token`, depending on its
///   kind.
/// * `client_id`, `username`, `user_client_id` and `idp_role` are always left unset.
///
/// # Errors
/// Fails if `device_client_id` is not a valid UUID string (the parse error is converted into
/// [`DeviceError`]).
fn create_jwt_payload_base(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
    config: &JwtPayloadConfig,
) -> Result<JwtPayload, DeviceError> {
    // Clamp a pre-1970 system clock to 0 so the i64 -> u64 cast can never wrap around.
    let now = chrono::Utc::now().timestamp().max(0) as u64;
    // Saturate rather than overflow. A caller-supplied lifetime near u64::MAX would otherwise
    // panic in debug builds and wrap to an already-expired `exp` in release builds.
    let exp = now.saturating_add(config.jwt_expiration_secs);

    let (login_hint, login_hint_token) = match &config.login_hint {
        Some(hint) => match hint.kind {
            LoginHintKind::LoginHint => (Some(hint.value.clone()), None),
            LoginHintKind::LoginHintToken => (None, Some(hint.value.clone())),
        },
        None => (None, None),
    };

    Ok(JwtPayload {
        iss: device_client_id.to_string(),
        sub: Uuid::from_str(device_client_id.as_str())?,
        aud: Some(provider_metadata.issuer().to_string()),
        jti: Some(Uuid::new_v4().to_string()),
        iat: now,
        exp,
        nbf: Some(now),
        scope: config.scope.clone(),
        binding_message: config.binding_message.clone(),
        login_hint,
        login_hint_token,
        resource: config.resource.clone(),
        client_id: None,
        username: None,
        user_client_id: None,
        idp_role: None,
        qr_session_id: config.qr_session_id.clone(),
    })
}
/// Builds the unsigned claim set for a plain device JWT using [`JwtPayloadConfig::default`]
/// (no optional claims, default lifetime).
///
/// # Errors
/// Same as the shared payload builder: fails if `device_client_id` is not a valid UUID.
pub fn make_device_jwt_base(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
) -> Result<JwtPayload, DeviceError> {
    create_jwt_payload_base(
        device_client_id,
        provider_metadata,
        &JwtPayloadConfig::default(),
    )
}
/// Like [`make_device_jwt_base`], but with a caller-chosen JWT lifetime in seconds.
///
/// All other options come from [`JwtPayloadConfig::default`].
///
/// # Errors
/// Fails if `device_client_id` is not a valid UUID.
#[allow(dead_code)]
pub fn make_device_jwt_base_with_config(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
    jwt_expiration_secs: u64,
) -> Result<JwtPayload, DeviceError> {
    create_jwt_payload_base(
        device_client_id,
        provider_metadata,
        &JwtPayloadConfig {
            jwt_expiration_secs,
            ..Default::default()
        },
    )
}
pub fn make_device_jwt(
device_client_id: &ClientId,
provider_metadata: &CoreProviderMetadata,
private_key: &EncodingKey,
) -> Result<String, DeviceError> {
let jwt_payload = make_device_jwt_base(device_client_id, provider_metadata)?;
let jwt = jsonwebtoken::encode(
&jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256),
&jwt_payload,
private_key,
)?;
Ok(jwt)
}
pub fn make_device_jwt_ciba_base(
device_client_id: &ClientId,
provider_metadata: &CoreProviderMetadata,
login_hint: &LoginHint,
scope: &str,
binding_message: &str,
resource: Option<String>,
qr_session_id: Option<String>,
) -> Result<JwtPayload, DeviceError> {
let config = JwtPayloadConfig {
scope: Some(scope.to_string()),
binding_message: Some(binding_message.to_string()),
login_hint: Some(login_hint.clone()),
resource,
qr_session_id,
..JwtPayloadConfig::default()
};
create_jwt_payload_base(device_client_id, provider_metadata, &config)
}
/// Builds the unsigned claim set for a CIBA device JWT with a caller-chosen lifetime in seconds.
///
/// `scope`, `binding_message` and `login_hint` are always set; `resource` and `qr_session_id`
/// are included when present.
///
/// # Errors
/// Fails if `device_client_id` is not a valid UUID.
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
pub fn make_device_jwt_ciba_base_with_config(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
    login_hint: &LoginHint,
    scope: &str,
    binding_message: &str,
    resource: Option<String>,
    qr_session_id: Option<String>,
    jwt_expiration_secs: u64,
) -> Result<JwtPayload, DeviceError> {
    let ciba_config = JwtPayloadConfig {
        jwt_expiration_secs,
        qr_session_id,
        resource,
        login_hint: Some(login_hint.clone()),
        binding_message: Some(binding_message.to_owned()),
        scope: Some(scope.to_owned()),
    };
    create_jwt_payload_base(device_client_id, provider_metadata, &ciba_config)
}
#[allow(clippy::too_many_arguments)]
pub fn make_device_jwt_ciba(
device_client_id: &ClientId,
provider_metadata: &CoreProviderMetadata,
login_hint: &LoginHint,
scope: &str,
binding_message: &str,
resource: Option<String>,
private_key: &EncodingKey,
qr_session_id: Option<String>,
) -> Result<String, DeviceError> {
let jwt_payload = make_device_jwt_ciba_base(
device_client_id,
provider_metadata,
login_hint,
scope,
binding_message,
resource,
qr_session_id,
)?;
let jwt = jsonwebtoken::encode(
&jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256),
&jwt_payload,
private_key,
)?;
Ok(jwt)
}
/// Fetches the provider's JSON Web Key Set from its `jwks_uri`.
///
/// # Errors
/// Returns a [`reqwest::Error`] in any of these cases:
/// * the HTTP client cannot be built or the request fails;
/// * the server answers with a non-success status;
/// * the body is not a valid JWK set.
///
/// Checking the status first makes a 4xx/5xx reply surface as a status error, instead of as
/// a confusing JSON decode error on the error page's body.
pub async fn get_jwks(provider_metadata: &CoreProviderMetadata) -> Result<JwkSet, reqwest::Error> {
    let client = default_client()?;
    client
        .get(provider_metadata.jwks_uri().to_string())
        .send()
        .await?
        .error_for_status()?
        .json::<JwkSet>()
        .await
}
#[allow(dead_code)]
pub fn sign_jwt(jwt_payload: JwtPayload, private_key: &EncodingKey) -> Result<String, DeviceError> {
let jwt = jsonwebtoken::encode(
&jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256),
&jwt_payload,
private_key,
)?;
Ok(jwt)
}
/// Produces a compact RS256 JWT whose signature is computed by a remote signing service.
///
/// Steps:
/// 1. Build the signing input `base64url(header) "." base64url(payload)` (unpadded), using the
///    header `{"alg":"RS256","typ":"JWT"}`.
/// 2. POST `{"deviceId": device_id, "message": <signing input>}` as JSON to
///    `signing_service_url`.
/// 3. Append the `signature` string field of the JSON reply.
///
/// NOTE(review): the returned signature is appended verbatim. This assumes the service already
/// returns unpadded base64url — TODO confirm, otherwise the resulting JWT is malformed.
///
/// # Errors
/// * header/payload serialization fails;
/// * the HTTP client cannot be built (mapped to `DeviceError::TokenRequest`);
/// * the request fails, returns a non-success status, or the reply is not JSON;
/// * the reply lacks a string `signature` field (`DeviceError::InvalidSignatureResponse`).
#[allow(dead_code)]
pub async fn sign_jwt_device_identity(
    jwt_payload: JwtPayload,
    device_id: String,
    signing_service_url: &str,
) -> Result<String, DeviceError> {
    let header_json = serde_json::to_string(&serde_json::json!({
        "alg": "RS256",
        "typ": "JWT"
    }))?;
    let payload_json = serde_json::to_string(&jwt_payload)?;
    let signing_input = format!(
        "{}.{}",
        URL_SAFE_NO_PAD.encode(header_json.as_bytes()),
        URL_SAFE_NO_PAD.encode(payload_json.as_bytes())
    );

    let client = default_client().map_err(DeviceError::TokenRequest)?;
    let request_body = serde_json::json!({
        "deviceId": device_id,
        "message": signing_input
    });
    let reply: serde_json::Value = client
        .post(signing_service_url)
        .json(&request_body)
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;

    let Some(signature) = reply["signature"].as_str() else {
        return Err(DeviceError::InvalidSignatureResponse);
    };
    Ok(format!("{signing_input}.{signature}"))
}