use super::oidc_types::{JwtPayload, LoginHint, LoginHintKind};
use crate::errors::{DeviceError, OidcRequirementsError};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use jsonwebtoken::jwk::JwkSet;
use jsonwebtoken::EncodingKey;
use openidconnect::{
core::{CoreProviderMetadata, CoreTokenType},
ClientId, EmptyExtraTokenFields, StandardTokenResponse,
};
use std::str::FromStr;
use uuid::Uuid;
/// Exchanges a signed device assertion for an access token at the provider's token endpoint.
///
/// Sends a `client_credentials` grant authenticated with a JWT-bearer client assertion
/// (`urn:ietf:params:oauth:client-assertion-type:jwt-bearer`). It is sent as an
/// `application/x-www-form-urlencoded` POST with a 3-second timeout.
///
/// # Errors
///
/// Returns a [`DeviceError`] in any of these cases:
/// - the provider metadata has no token endpoint;
/// - the request fails or times out;
/// - the endpoint answers with a non-success status;
/// - the body cannot be decoded as a token response.
pub async fn request_device_access_token(
    provider_metadata: &CoreProviderMetadata,
    device_client_id: &ClientId,
    device_jwt: &str,
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, CoreTokenType>, DeviceError> {
    let token_endpoint = provider_metadata
        .token_endpoint()
        .ok_or(OidcRequirementsError::MissingTokenEndpoint)?;

    let client_id = device_client_id.to_string();
    let form = [
        ("grant_type", "client_credentials"),
        ("client_id", client_id.as_str()),
        (
            "client_assertion_type",
            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        ),
        ("client_assertion", device_jwt),
    ];

    let request = reqwest::Client::new()
        .post(token_endpoint.to_string())
        .header("content-type", "application/x-www-form-urlencoded")
        .timeout(std::time::Duration::from_secs(3))
        .form(&form);

    let token = request
        .send()
        .await?
        .error_for_status()?
        .json::<StandardTokenResponse<EmptyExtraTokenFields, CoreTokenType>>()
        .await?;
    Ok(token)
}
/// Convenience wrapper around [`request_device_access_token`] that erases the error type.
///
/// # Errors
///
/// Any [`DeviceError`] from the underlying request, boxed as `Box<dyn std::error::Error>`.
pub async fn device_access_token(
    provider_metadata: &CoreProviderMetadata,
    device_client_id: &ClientId,
    device_jwt: &str,
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, CoreTokenType>, Box<dyn std::error::Error>>
{
    request_device_access_token(provider_metadata, device_client_id, device_jwt)
        .await
        .map_err(Into::into)
}
/// Builds the base claim set for a device client-assertion JWT.
///
/// The claims are set as follows:
/// - `iss` is the client id;
/// - `sub` is the client id parsed as a UUID;
/// - `aud` is the provider's issuer;
/// - `jti` is a fresh random UUID;
/// - `iat` and `nbf` are the current Unix time, and `exp` is 300 seconds later.
///
/// Every other optional claim is left as `None`.
///
/// # Panics
///
/// Panics if `device_client_id` is not a valid UUID string.
pub fn make_device_jwt_base(
    device_client_id: &ClientId,
    provider_metadata: &CoreProviderMetadata,
) -> JwtPayload {
    /// Validity window of the assertion, in seconds.
    const LIFETIME_SECS: u64 = 300;

    // `timestamp()` is an i64; a clock set before the Unix epoch would wrap on this cast.
    let now = chrono::Utc::now().timestamp() as u64;
    JwtPayload {
        iss: device_client_id.to_string(),
        sub: Uuid::from_str(device_client_id.as_str())
            .expect("device client id must be a UUID (it is used as the JWT `sub` claim)"),
        aud: Some(provider_metadata.issuer().to_string()),
        jti: Some(Uuid::new_v4().to_string()),
        iat: now,
        exp: now + LIFETIME_SECS,
        nbf: Some(now),
        scope: None,
        binding_message: None,
        login_hint: None,
        login_hint_token: None,
        resource: None,
        client_id: None,
        username: None,
        user_client_id: None,
        idp_role: None,
        qr_session_id: None,
    }
}
pub fn make_device_jwt(
device_client_id: &ClientId,
provider_metadata: &CoreProviderMetadata,
private_key: &EncodingKey,
) -> String {
let jwt_payload = make_device_jwt_base(device_client_id, provider_metadata);
let jwt = jsonwebtoken::encode(
&jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256),
&jwt_payload,
&private_key,
)
.unwrap();
jwt
}
pub fn make_device_jwt_ciba_base(
device_client_id: &ClientId,
provider_metadata: &CoreProviderMetadata,
login_hint: &LoginHint,
scope: &str,
binding_message: &str,
resource: Option<String>,
qr_session_id: Option<String>,
) -> JwtPayload {
let now = chrono::Utc::now().timestamp() as u64;
JwtPayload {
iss: device_client_id.to_string(),
sub: Uuid::from_str(device_client_id.as_str()).unwrap(),
aud: Some(provider_metadata.issuer().to_string()),
jti: Some(Uuid::new_v4().to_string()),
iat: now,
exp: now + 300,
nbf: Some(now),
scope: Some(scope.to_string()),
binding_message: Some(binding_message.to_string()),
login_hint: match login_hint.kind {
LoginHintKind::LoginHint => Some(login_hint.value.clone()),
_ => None,
},
login_hint_token: match login_hint.kind {
LoginHintKind::LoginHintToken => Some(login_hint.value.clone()),
_ => None,
},
resource,
client_id: None,
username: None,
user_client_id: None,
idp_role: None,
qr_session_id,
}
}
pub fn make_device_jwt_ciba(
device_client_id: &ClientId,
provider_metadata: &CoreProviderMetadata,
login_hint: &LoginHint,
scope: &str,
binding_message: &str,
resource: Option<String>,
private_key: &EncodingKey,
qr_session_id: Option<String>,
) -> String {
let jwt_payload = make_device_jwt_ciba_base(
device_client_id,
provider_metadata,
login_hint,
scope,
binding_message,
resource,
qr_session_id,
);
let jwt = jsonwebtoken::encode(
&jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256),
&jwt_payload,
&private_key,
)
.unwrap();
jwt
}
/// Fetches the provider's JSON Web Key Set from its `jwks_uri`.
///
/// # Errors
///
/// Returns a [`reqwest::Error`] if any of these happen:
/// - the request fails;
/// - the server responds with a non-success HTTP status;
/// - the body cannot be decoded as a [`JwkSet`].
pub async fn get_jwks(provider_metadata: &CoreProviderMetadata) -> Result<JwkSet, reqwest::Error> {
    reqwest::Client::new()
        .get(provider_metadata.jwks_uri().to_string())
        .send()
        .await?
        // Report 4xx/5xx as status errors rather than as confusing JSON decode failures.
        .error_for_status()?
        .json::<JwkSet>()
        .await
}
/// Signs `jwt_payload` with RS256 using `private_key` and returns the compact JWT string.
///
/// # Panics
///
/// Panics if [`jsonwebtoken::encode`] fails, for example when `private_key` is not usable
/// with RS256 or the payload cannot be serialized.
pub fn sign_jwt(jwt_payload: JwtPayload, private_key: &EncodingKey) -> String {
    let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
    jsonwebtoken::encode(&header, &jwt_payload, private_key)
        .expect("RS256 signing of JWT failed; is the private key an RSA key?")
}
/// Produces an RS256 JWT whose signature is computed by an external signing service.
///
/// The steps are:
/// 1. JSON-serialize the header `{"alg":"RS256","typ":"JWT"}` and `jwt_payload`.
/// 2. Base64url-encode each without padding and join them as `header.payload`.
/// 3. POST that string as `{"deviceId": ..., "message": ...}` to
///    `http://localhost:8000/signature`.
/// 4. Append the reply's `signature` string verbatim as the third segment.
///
/// # Errors
///
/// Returns a [`DeviceError`] in any of these cases:
/// - serialization fails;
/// - the request fails or returns a non-success status;
/// - the reply is not JSON;
/// - the reply has no string `signature` field.
pub async fn sign_jwt_device_identity(
    jwt_payload: JwtPayload,
    device_id: String,
) -> Result<String, DeviceError> {
    const SIGNATURE_ENDPOINT: &str = "http://localhost:8000/signature";

    let header_json = serde_json::to_string(&serde_json::json!({
        "alg": "RS256",
        "typ": "JWT"
    }))?;
    let payload_json = serde_json::to_string(&jwt_payload)?;
    let signing_input = format!(
        "{}.{}",
        URL_SAFE_NO_PAD.encode(header_json.as_bytes()),
        URL_SAFE_NO_PAD.encode(payload_json.as_bytes())
    );

    let request_body = serde_json::json!({
        "deviceId": device_id,
        "message": signing_input
    });
    let reply: serde_json::Value = reqwest::Client::new()
        .post(SIGNATURE_ENDPOINT)
        .json(&request_body)
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;

    // NOTE(review): the service's signature is assumed to already be base64url — confirm.
    let signature = reply["signature"]
        .as_str()
        .ok_or(DeviceError::InvalidSignatureResponse)?;
    Ok([signing_input.as_str(), signature].join("."))
}