use base64ct::{Base64UrlUnpadded, Encoding};
use serde::{de::DeserializeOwned, Serialize};
use core::num::NonZeroUsize;
use crate::{
alloc::{Cow, String, ToOwned, Vec},
token::CompleteHeader,
Claims, CreationError, Header, SignedToken, Token, UntrustedToken, ValidationError,
};
/// Signature produced and consumed by an [`Algorithm`].
pub trait AlgorithmSignature: Sized {
    /// Exact signature length in bytes, for algorithms whose signatures have a fixed size.
    ///
    /// When this is `Some(_)`, token validation rejects signatures of any other length
    /// before [`Self::try_from_slice()`] is invoked. The default `None` disables that check.
    const LENGTH: Option<NonZeroUsize> = None;

    /// Returns the raw bytes of this signature, as written into the token's signature segment.
    fn as_bytes(&self) -> Cow<'_, [u8]>;

    /// Parses a signature from its raw byte representation.
    ///
    /// # Errors
    ///
    /// Returns an error if `slice` is not a valid signature encoding for this algorithm.
    fn try_from_slice(slice: &[u8]) -> anyhow::Result<Self>;
}
/// Token signing algorithm.
///
/// Implementors provide the algorithm name stored in (and checked against) the token header,
/// together with the signing and signature-verification primitives.
pub trait Algorithm {
    /// Key used to create signatures.
    type SigningKey;
    /// Key used to verify signatures.
    type VerifyingKey;
    /// Signature type produced by [`Self::sign()`] and checked by [`Self::verify_signature()`].
    type Signature: AlgorithmSignature;

    /// Returns the algorithm name.
    ///
    /// The name is written into the header of created tokens and must match the
    /// algorithm recorded in a token's header for validation to proceed.
    fn name(&self) -> Cow<'static, str>;

    /// Signs `message` with `signing_key`.
    fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature;

    /// Checks whether `signature` is a valid signature of `message` under `verifying_key`.
    ///
    /// Returns `true` only if the signature is valid.
    // Ignoring this result would silently skip signature verification, hence `must_use`.
    #[must_use = "the signature is only valid if `verify_signature` returns `true`"]
    fn verify_signature(
        &self,
        signature: &Self::Signature,
        verifying_key: &Self::VerifyingKey,
        message: &[u8],
    ) -> bool;
}
/// Wrapper around an [`Algorithm`] that reports a different algorithm name.
///
/// Signing and verification are delegated unchanged to the wrapped algorithm.
#[derive(Clone, Copy, Debug)]
pub struct Renamed<A> {
    /// Algorithm that actually performs signing and verification.
    inner: A,
    /// Name reported instead of the wrapped algorithm's own name.
    name: &'static str,
}
impl<A: Algorithm> Renamed<A> {
    /// Wraps `algorithm` so that it reports `new_name` as its name.
    ///
    /// Signatures are still created and verified by `algorithm` itself.
    pub fn new(algorithm: A, new_name: &'static str) -> Self {
        let name = new_name;
        Self {
            name,
            inner: algorithm,
        }
    }
}
impl<A: Algorithm> Algorithm for Renamed<A> {
    // Keys and signatures are exactly those of the wrapped algorithm; only the name differs.
    type SigningKey = A::SigningKey;
    type VerifyingKey = A::VerifyingKey;
    type Signature = A::Signature;

    /// Returns the overriding name stored in this wrapper, not the inner algorithm's name.
    fn name(&self) -> Cow<'static, str> {
        self.name.into()
    }

    fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
        A::sign(&self.inner, signing_key, message)
    }

    fn verify_signature(
        &self,
        signature: &Self::Signature,
        verifying_key: &Self::VerifyingKey,
        message: &[u8],
    ) -> bool {
        A::verify_signature(&self.inner, signature, verifying_key, message)
    }
}
/// Token creation and validation helpers, implemented for every [`Algorithm`].
pub trait AlgorithmExt: Algorithm {
    /// Creates a token from `header` and `claims`, signs it with `signing_key`,
    /// and returns its string form. Claims are serialized as JSON.
    ///
    /// # Errors
    ///
    /// Returns an error if the header or the claims cannot be serialized.
    fn token<T: Serialize>(
        &self,
        header: Header,
        claims: &Claims<T>,
        signing_key: &Self::SigningKey,
    ) -> Result<String, CreationError>;

    /// Like [`Self::token()`], but serializes claims as CBOR and marks the header's
    /// content type as `CBOR`.
    ///
    /// # Errors
    ///
    /// Returns an error if the header or the claims cannot be serialized.
    #[cfg(feature = "serde_cbor")]
    #[cfg_attr(docsrs, doc(cfg(feature = "serde_cbor")))]
    fn compact_token<T: Serialize>(
        &self,
        header: Header,
        claims: &Claims<T>,
        signing_key: &Self::SigningKey,
    ) -> Result<String, CreationError>;

    /// Checks the token's algorithm and signature, and returns the token with deserialized
    /// claims. The parsed signature is discarded.
    ///
    /// # Errors
    ///
    /// Fails on an algorithm mismatch, a malformed or wrongly sized signature,
    /// undeserializable claims, or an invalid signature.
    fn validate_integrity<T: DeserializeOwned>(
        &self,
        token: &UntrustedToken<'_>,
        verifying_key: &Self::VerifyingKey,
    ) -> Result<Token<T>, ValidationError>;

    /// Performs the same checks as [`Self::validate_integrity()`], but keeps the parsed
    /// signature alongside the token.
    ///
    /// # Errors
    ///
    /// Same as [`Self::validate_integrity()`].
    fn validate_for_signed_token<T: DeserializeOwned>(
        &self,
        token: &UntrustedToken<'_>,
        verifying_key: &Self::VerifyingKey,
    ) -> Result<SignedToken<Self, T>, ValidationError>;
}
impl<A: Algorithm> AlgorithmExt for A {
    fn token<T>(
        &self,
        header: Header,
        claims: &Claims<T>,
        signing_key: &Self::SigningKey,
    ) -> Result<String, CreationError>
    where
        T: Serialize,
    {
        let complete_header = CompleteHeader {
            algorithm: self.name(),
            content_type: None,
            inner: header,
        };
        let header = serde_json::to_string(&complete_header).map_err(CreationError::Header)?;
        let claims = serde_json::to_string(claims).map_err(CreationError::Claims)?;
        Ok(sign_and_encode(
            self,
            header.as_bytes(),
            claims.as_bytes(),
            signing_key,
        ))
    }

    #[cfg(feature = "serde_cbor")]
    fn compact_token<T>(
        &self,
        header: Header,
        claims: &Claims<T>,
        signing_key: &Self::SigningKey,
    ) -> Result<String, CreationError>
    where
        T: Serialize,
    {
        let complete_header = CompleteHeader {
            algorithm: self.name(),
            // Marks the claims segment as CBOR rather than JSON.
            content_type: Some("CBOR".to_owned()),
            inner: header,
        };
        let header = serde_json::to_string(&complete_header).map_err(CreationError::Header)?;
        let claims = serde_cbor::to_vec(claims).map_err(CreationError::CborClaims)?;
        Ok(sign_and_encode(self, header.as_bytes(), &claims, signing_key))
    }

    fn validate_integrity<T>(
        &self,
        token: &UntrustedToken<'_>,
        verifying_key: &Self::VerifyingKey,
    ) -> Result<Token<T>, ValidationError>
    where
        T: DeserializeOwned,
    {
        self.validate_for_signed_token(token, verifying_key)
            .map(|signed| signed.token)
    }

    fn validate_for_signed_token<T>(
        &self,
        token: &UntrustedToken<'_>,
        verifying_key: &Self::VerifyingKey,
    ) -> Result<SignedToken<Self, T>, ValidationError>
    where
        T: DeserializeOwned,
    {
        // Reject tokens declaring a different algorithm before any signature processing.
        let expected_alg = self.name();
        if expected_alg != token.algorithm() {
            return Err(ValidationError::AlgorithmMismatch {
                expected: expected_alg.into_owned(),
                actual: token.algorithm().to_owned(),
            });
        }

        let signature = token.signature_bytes();
        // Fixed-size signatures get a length check before they are parsed.
        if let Some(expected_len) = Self::Signature::LENGTH {
            if signature.len() != expected_len.get() {
                return Err(ValidationError::InvalidSignatureLen {
                    expected: expected_len.get(),
                    actual: signature.len(),
                });
            }
        }
        let signature = Self::Signature::try_from_slice(signature)
            .map_err(ValidationError::MalformedSignature)?;

        // NOTE(review): claims are deserialized *before* the signature is verified. A token
        // whose claims and signature are both bad therefore reports a claims error. This is
        // presumably intentional (parsing is assumed cheaper than verification). Confirm
        // before reordering, because it changes which error callers observe.
        let claims = token.deserialize_claims_unchecked::<T>()?;
        if !self.verify_signature(&signature, verifying_key, &*token.signed_data) {
            return Err(ValidationError::InvalidSignature);
        }
        Ok(SignedToken {
            signature,
            token: Token::new(token.header().clone(), claims),
        })
    }
}

/// Builds `b64(header).b64(claims).b64(signature)`, where `b64` is unpadded base64url.
///
/// The signature is computed by `algorithm` over the first two dot-separated segments
/// exactly as they appear in the output.
fn sign_and_encode<A: Algorithm>(
    algorithm: &A,
    header: &[u8],
    claims: &[u8],
    signing_key: &A::SigningKey,
) -> String {
    let mut buffer = Vec::new();
    encode_base64_buf(header, &mut buffer);
    buffer.push(b'.');
    encode_base64_buf(claims, &mut buffer);
    let signature = algorithm.sign(signing_key, &buffer);
    buffer.push(b'.');
    encode_base64_buf(signature.as_bytes(), &mut buffer);
    // SAFETY: `buffer` contains only output of `encode_base64_buf` (the base64url alphabet
    // `A-Z`, `a-z`, `0-9`, `-`, `_`) and ASCII `.` separators. ASCII is valid UTF-8.
    unsafe { String::from_utf8_unchecked(buffer) }
}
/// Appends the unpadded base64url encoding of `source` to the end of `buffer`.
///
/// Bytes already in `buffer` are left untouched.
fn encode_base64_buf(source: impl AsRef<[u8]>, buffer: &mut Vec<u8>) {
    let input = source.as_ref();
    let start = buffer.len();
    let end = start + Base64UrlUnpadded::encoded_len(input);
    // Grow with zero placeholders; the encoder overwrites the whole new tail.
    buffer.resize(end, 0);
    let tail = &mut buffer[start..end];
    Base64UrlUnpadded::encode(input, tail)
        .expect("miscalculated base64-encoded length; this should never happen");
}