use std::fmt::{Debug, Formatter};
use base64::{
Engine,
engine::general_purpose::{STANDARD, URL_SAFE},
};
use serde::ser::Serialize;
use crate::algorithms::AlgorithmFamily;
use crate::crypto::CryptoProvider;
use crate::errors::{ErrorKind, Result, new_error};
use crate::header::Header;
#[cfg(feature = "use_pem")]
use crate::pem::decoder::PemEncodedKey;
use crate::serialization::{b64_encode, b64_encode_part};
/// Key material used when signing a token, tagged with the algorithm family it may be
/// used with (the constructors in this file produce `Hmac`, `Rsa`, `Ec` and `Ed` keys).
///
/// `Debug` is implemented by hand so the key bytes are never printed.
#[derive(Clone)]
pub struct EncodingKey {
    /// Algorithm family this key was built for; `encode` rejects headers whose
    /// algorithm belongs to a different family.
    family: AlgorithmFamily,
    /// Raw key bytes: the HMAC secret itself, the DER bytes passed to a `from_*_der`
    /// constructor, or the bytes extracted from a PEM document by the PEM decoder.
    content: Vec<u8>,
}
impl EncodingKey {
    /// Returns the [`AlgorithmFamily`] this key was constructed for.
    pub fn family(&self) -> AlgorithmFamily {
        self.family
    }

    /// Builds an HMAC key from raw secret bytes, copied verbatim.
    pub fn from_secret(secret: &[u8]) -> Self {
        Self { family: AlgorithmFamily::Hmac, content: Vec::from(secret) }
    }

    /// Builds an HMAC key from a secret written in the standard base64 alphabet
    /// (`+` and `/`), decoded with the `base64` crate's `STANDARD` engine.
    ///
    /// # Errors
    ///
    /// Propagates the base64 decoding error when `secret` is not valid for that engine.
    pub fn from_base64_secret(secret: &str) -> Result<Self> {
        let content = STANDARD.decode(secret)?;
        Ok(Self { family: AlgorithmFamily::Hmac, content })
    }

    /// Builds an HMAC key from a secret written in the URL-safe base64 alphabet
    /// (`-` and `_`), decoded with the `base64` crate's `URL_SAFE` engine.
    ///
    /// # Errors
    ///
    /// Propagates the base64 decoding error when `secret` is not valid for that engine.
    /// NOTE(review): `URL_SAFE` is the padded engine; confirm callers never pass
    /// unpadded base64url.
    pub fn from_urlsafe_base64_secret(secret: &str) -> Result<Self> {
        let content = URL_SAFE.decode(secret)?;
        Ok(Self { family: AlgorithmFamily::Hmac, content })
    }

    /// Parses a PEM document and keeps the RSA key bytes returned by
    /// `PemEncodedKey::as_rsa_key`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `PemEncodedKey::new` (unparsable PEM) or from
    /// `as_rsa_key`.
    #[cfg(feature = "use_pem")]
    pub fn from_rsa_pem(key: &[u8]) -> Result<Self> {
        let pem = PemEncodedKey::new(key)?;
        Ok(Self { family: AlgorithmFamily::Rsa, content: pem.as_rsa_key()?.to_vec() })
    }

    /// Parses a PEM document and keeps the EC private-key bytes returned by
    /// `PemEncodedKey::as_ec_private_key`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `PemEncodedKey::new` (unparsable PEM) or from
    /// `as_ec_private_key`.
    #[cfg(feature = "use_pem")]
    pub fn from_ec_pem(key: &[u8]) -> Result<Self> {
        let pem = PemEncodedKey::new(key)?;
        Ok(Self { family: AlgorithmFamily::Ec, content: pem.as_ec_private_key()?.to_vec() })
    }

    /// Parses a PEM document and keeps the Ed private-key bytes returned by
    /// `PemEncodedKey::as_ed_private_key`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `PemEncodedKey::new` (unparsable PEM) or from
    /// `as_ed_private_key`.
    #[cfg(feature = "use_pem")]
    pub fn from_ed_pem(key: &[u8]) -> Result<Self> {
        let pem = PemEncodedKey::new(key)?;
        Ok(Self { family: AlgorithmFamily::Ed, content: pem.as_ed_private_key()?.to_vec() })
    }

    /// Wraps DER-encoded RSA key bytes. The bytes are copied but not validated here.
    pub fn from_rsa_der(der: &[u8]) -> Self {
        Self { family: AlgorithmFamily::Rsa, content: der.to_owned() }
    }

    /// Wraps DER-encoded EC key bytes. The bytes are copied but not validated here.
    pub fn from_ec_der(der: &[u8]) -> Self {
        Self { family: AlgorithmFamily::Ec, content: der.to_owned() }
    }

    /// Wraps DER-encoded Ed key bytes. The bytes are copied but not validated here.
    pub fn from_ed_der(der: &[u8]) -> Self {
        Self { family: AlgorithmFamily::Ed, content: der.to_owned() }
    }

    /// Returns the raw key bytes, whatever the key's family.
    pub fn inner(&self) -> &[u8] {
        self.content.as_slice()
    }

    /// Returns the secret bytes when this is an HMAC key.
    ///
    /// # Errors
    ///
    /// `ErrorKind::InvalidKeyFormat` when the key belongs to any other family.
    pub fn try_get_hmac_secret(&self) -> Result<&[u8]> {
        if self.family != AlgorithmFamily::Hmac {
            return Err(new_error(ErrorKind::InvalidKeyFormat));
        }
        Ok(self.content.as_slice())
    }
}
/// Debug output shows only the key family; the key bytes are replaced by a placeholder
/// so secrets cannot leak through logs or panic messages.
impl Debug for EncodingKey {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("EncodingKey");
        builder.field("family", &self.family);
        // Never format `self.content`: it holds secret or private key material.
        builder.field("content", &"[redacted]");
        builder.finish()
    }
}
/// Serializes `header` and `claims`, signs them with `key` using the default
/// [`CryptoProvider`], and returns the compact `header.claims.signature` token.
///
/// # Errors
///
/// - `ErrorKind::InvalidAlgorithm` if `key` belongs to a different algorithm family than
///   `header.alg`, or if the signer built by the provider reports an algorithm other
///   than `header.alg`.
/// - Any error from the provider's signer factory, from base64-encoding the header or
///   claims, or from producing the signature.
pub fn encode<T: Serialize>(header: &Header, claims: &T, key: &EncodingKey) -> Result<String> {
    // Reject mismatched key material (e.g. an HMAC secret with an RSA algorithm)
    // before any crypto provider is involved.
    if header.alg.family() != key.family {
        return Err(new_error(ErrorKind::InvalidAlgorithm));
    }

    let signer = (CryptoProvider::get_default().signer_factory)(&header.alg, key)?;
    // Defensive: the factory must return a signer for exactly the requested algorithm.
    if signer.algorithm() != header.alg {
        return Err(new_error(ErrorKind::InvalidAlgorithm));
    }

    // Grow the token in place: first the signing input "<header>.<claims>", ...
    let mut token = b64_encode_part(header)?;
    token.push('.');
    token.push_str(&b64_encode_part(claims)?);

    // ... then sign exactly those bytes and append ".<signature>".
    let signature = b64_encode(signer.try_sign(token.as_bytes())?);
    token.push('.');
    token.push_str(&signature);
    Ok(token)
}