use std::{borrow::Cow, convert::Infallible};
use base64::prelude::*;
use bon::Builder;
use serde::Serialize;
use snafu::prelude::*;
use crate::{
crypto::signer::JwsSigner,
jwk::PublicJwk,
jwt::{
builder::jwt_builder::{SetClaims, SetExtraHeaders},
structure::{JwtClaims, JwtHeader},
},
platform::{Duration, SystemTime, SystemTimeError},
secrets::SecretString,
};
/// The contents of a JWT before signing: header options, registered claims, and
/// application-specific claims.
///
/// Create one with [`Jwt::builder`] and sign it into a compact JWS with
/// [`Jwt::to_jws_compact`].
///
/// `ExtraHeaders` and `Claims` default to `()`. The builder's `extra_headers` / `claims`
/// methods switch them to concrete serializable types.
#[derive(Debug, Clone, Builder)]
#[builder(
    start_fn(vis = "", name = "builder_internal"),
    generics(setters(name = "with_{}"))
)]
pub struct Jwt<'a, ExtraHeaders = (), Claims = ()>
where
    ExtraHeaders: Serialize + Clone,
    Claims: Serialize + Clone,
{
    /// Value of the `typ` header parameter. Defaults to `"JWT"`.
    #[builder(default = "JWT", into)]
    pub typ: Cow<'a, str>,
    /// `iss` (issuer) claim. Omitted when `None`.
    #[builder(into)]
    pub issuer: Option<Cow<'a, str>>,
    /// `sub` (subject) claim. Omitted when `None`.
    #[builder(into)]
    pub subject: Option<Cow<'a, str>>,
    /// `aud` (audience) claim values. Defaults to an empty list.
    #[builder(default, into)]
    pub audiences: Vec<String>,
    /// `iat` (issued at) claim, serialized as whole seconds since `UNIX_EPOCH`.
    pub issued_at: Option<SystemTime>,
    /// `exp` (expiration) claim, serialized as whole seconds since `UNIX_EPOCH`.
    pub expiration: Option<SystemTime>,
    /// `nbf` (not before) claim, serialized as whole seconds since `UNIX_EPOCH`.
    pub not_before: Option<SystemTime>,
    /// `jti` (JWT ID) claim. If not set, it defaults to the value of `crate::uuid::uuid_v7()`.
    // NOTE(review): `required` keeps the member typed as the full `Option<String>`,
    // presumably so callers can pass `None` explicitly to omit `jti` — confirm against
    // bon's `required` semantics.
    #[builder(required, into, default = crate::uuid::uuid_v7())]
    pub jti: Option<String>,
    /// Public key embedded as the `jwk` header parameter, if any.
    pub jwk: Option<PublicJwk>,
    /// Additional header parameters, serialized alongside the standard header fields.
    /// Set through the public `extra_headers` builder method, which also fixes the
    /// `ExtraHeaders` type.
    #[builder(setters(vis = "", name = "extra_headers_internal"))]
    pub extra_headers: Option<ExtraHeaders>,
    /// Application-specific claims, serialized alongside the registered claims.
    /// Set through the public `claims` builder method, which also fixes the `Claims`
    /// type. The member has no default, so it must be set before building.
    #[builder(setters(vis = "", name = "claims_internal"))]
    pub claims: Claims,
}
impl<'a> Jwt<'a, (), ()> {
pub fn builder() -> JwtBuilder<'a, (), ()> {
Jwt::<(), ()>::builder_internal()
}
}
impl<'a, ExtraHeaders, Claims, S: jwt_builder::State> JwtBuilder<'a, ExtraHeaders, Claims, S>
where
    ExtraHeaders: Serialize + Clone,
    Claims: Serialize + Clone,
{
    /// Sets the audience list to exactly one entry.
    ///
    /// This fills the same builder slot as `audiences`, so the two cannot be combined.
    pub fn audience(
        self,
        audience: impl Into<String>,
    ) -> JwtBuilder<'a, ExtraHeaders, Claims, jwt_builder::SetAudiences<S>>
    where
        S::Audiences: jwt_builder::IsUnset,
    {
        let sole_audience: String = audience.into();
        self.audiences(vec![sole_audience])
    }

    /// Stamps `issued_at` with the current platform time.
    pub fn issued_now(self) -> JwtBuilder<'a, ExtraHeaders, Claims, jwt_builder::SetIssuedAt<S>>
    where
        S::IssuedAt: jwt_builder::IsUnset,
    {
        let current = crate::platform::SystemTime::now();
        self.issued_at(current)
    }

    /// Stamps `issued_at` with the current time and sets `expiration` to that same instant
    /// plus `after`.
    ///
    /// # Panics
    ///
    /// `SystemTime + Duration` may panic if the sum is not representable (std's
    /// implementation does). NOTE(review): confirm for the `crate::platform` time type.
    pub fn issued_now_expires_after(
        self,
        after: Duration,
    ) -> JwtBuilder<'a, ExtraHeaders, Claims, jwt_builder::SetExpiration<jwt_builder::SetIssuedAt<S>>>
    where
        S::IssuedAt: jwt_builder::IsUnset,
        S::Expiration: jwt_builder::IsUnset,
    {
        // A single clock read keeps `exp - iat` exactly equal to `after`.
        let issued = crate::platform::SystemTime::now();
        let expires = issued + after;
        self.issued_at(issued).expiration(expires)
    }

    /// Supplies the custom claims, switching the builder's `Claims` type to `E2`.
    pub fn claims<E2>(self, claims: E2) -> JwtBuilder<'a, ExtraHeaders, E2, SetClaims<S>>
    where
        E2: Serialize + Clone,
        S::Claims: jwt_builder::IsUnset,
    {
        let retyped = self.with_claims::<E2>();
        retyped.claims_internal(claims)
    }

    /// Supplies extra header parameters, switching the builder's `ExtraHeaders` type to `E2`.
    pub fn extra_headers<E2>(self, headers: E2) -> JwtBuilder<'a, E2, Claims, SetExtraHeaders<S>>
    where
        E2: Serialize + Clone,
        S::ExtraHeaders: jwt_builder::IsUnset,
    {
        let retyped = self.with_extra_headers::<E2>();
        retyped.extra_headers_internal(headers)
    }
}
/// Errors raised while building the JWS signing input
/// (`BASE64URL(header JSON) "." BASE64URL(claims JSON)`).
#[derive(Debug, Snafu)]
pub enum JwsSigningInputError {
    // NOTE: variant notes below are plain `//` comments on purpose — snafu can use `///`
    // doc comments on variants as the `Display` message, so converting them would change
    // the error text.

    // Serializing the claims set to JSON failed.
    EncodeClaims {
        source: serde_json::Error,
    },
    // Serializing the header to JSON failed.
    EncodeHeader {
        source: serde_json::Error,
    },
    // An `iat`/`exp`/`nbf` timestamp could not be measured as a duration since
    // `UNIX_EPOCH` (e.g. it lies before the epoch).
    Time {
        source: SystemTimeError,
    },
}
/// Errors from turning a [`Jwt`] into a signed JWS.
///
/// `SgnErr` is the signer's error type and defaults to [`Infallible`].
#[derive(Debug, Snafu)]
pub enum JwsSerializationError<SgnErr: crate::Error + 'static = Infallible> {
    // NOTE: variant notes below are plain `//` comments on purpose — snafu can use `///`
    // doc comments on variants as the `Display` message.

    // Building the `header.claims` signing input failed.
    GenerateSigningInput {
        source: JwsSigningInputError,
    },
    // The signer returned an error while producing the signature.
    Sign {
        source: SgnErr,
    },
    // Wraps an `http::Error`. Not constructed in the code visible here — presumably
    // raised by URI normalization elsewhere (TODO confirm).
    NormalizeUri {
        source: http::Error,
    },
    // A key thumbprint was required but missing. Not constructed in the code visible here.
    NoThumbprint,
    // No key matched the requested thumbprint. Not constructed in the code visible here.
    NoMatchingKeyForThumbprint,
}
impl<SgnErr: crate::Error> crate::Error for JwsSerializationError<SgnErr> {
    /// Delegates to the signer's error for `Sign`; every other variant reports `false`.
    ///
    /// The match is kept exhaustive (no `_` arm) so that adding a variant forces a
    /// retryability decision here.
    fn is_retryable(&self) -> bool {
        match self {
            Self::Sign { source } => source.is_retryable(),
            Self::GenerateSigningInput { .. }
            | Self::NormalizeUri { .. }
            | Self::NoThumbprint
            | Self::NoMatchingKeyForThumbprint => false,
        }
    }
}
impl<ExtraHeaders, Claims> Jwt<'_, ExtraHeaders, Claims>
where
ExtraHeaders: Serialize + Clone,
Claims: Serialize + Clone,
{
pub async fn to_jws_compact<Sgn: JwsSigner>(
&self,
signer: &Sgn,
) -> Result<SecretString, JwsSerializationError<Sgn::Error>> {
let signing_input = self
.generate_jwt_signing_input(&signer.jws_algorithm(), signer.key_id().as_deref())
.context(GenerateSigningInputSnafu)?;
let signature = signer
.sign(signing_input.as_bytes())
.await
.context(SignSnafu)?;
let signature_b64 = BASE64_URL_SAFE_NO_PAD.encode(&signature);
let result = [signing_input, signature_b64].join(".");
Ok(SecretString::new(result))
}
fn generate_jwt_signing_input(
&self,
alg: &str,
kid: Option<&str>,
) -> Result<String, JwsSigningInputError> {
let jwt_header = JwtHeader {
alg: Cow::Borrowed(alg),
typ: Some(Cow::Borrowed(&self.typ)),
kid: kid.map(Cow::Borrowed),
crit: Vec::new(),
jwk: self.jwk.clone(),
extra_headers: self.extra_headers.as_ref().map(Cow::Borrowed),
};
let iat = self
.issued_at
.map(|iat| {
iat.duration_since(SystemTime::UNIX_EPOCH)
.map(|dur| dur.as_secs())
})
.transpose()
.context(TimeSnafu)?;
let exp = self
.expiration
.map(|exp| {
exp.duration_since(SystemTime::UNIX_EPOCH)
.map(|dur| dur.as_secs())
})
.transpose()
.context(TimeSnafu)?;
let nbf = self
.not_before
.map(|nbf| {
nbf.duration_since(SystemTime::UNIX_EPOCH)
.map(|dur| dur.as_secs())
})
.transpose()
.context(TimeSnafu)?;
let jwt_claims = JwtClaims {
iss: self.issuer.as_deref().map(Cow::Borrowed),
sub: self.subject.as_deref().map(Cow::Borrowed),
aud: self.audiences.clone(),
iat,
exp,
nbf,
jti: self.jti.as_deref().map(Cow::Borrowed),
cnf: None,
claims: Cow::Borrowed(&self.claims),
};
let jwt_header_json = serde_json::to_vec(&jwt_header).context(EncodeHeaderSnafu)?;
let jwt_header_b64 = BASE64_URL_SAFE_NO_PAD.encode(&jwt_header_json);
let jwt_claims_json = serde_json::to_vec(&jwt_claims).context(EncodeClaimsSnafu)?;
let jwt_claims_b64 = BASE64_URL_SAFE_NO_PAD.encode(&jwt_claims_json);
Ok([jwt_header_b64, jwt_claims_b64].join("."))
}
}