use crate::Algorithm;
use crate::Error;
use crate::Header;
#[cfg(feature = "ring")]
use jsonwebtoken::{
encode as jwt_encode, Algorithm as JwtAlgorithm, EncodingKey, Header as JwtHeader,
};
#[cfg(feature = "noring")]
use jsonwebtoken_rustcrypto::{
encode as jwt_encode,
Algorithm as JwtAlgorithm,
EncodingKey,
Header as JwtHeader,
};
#[cfg(feature = "noring")]
use rsa::{pkcs8::DecodePrivateKey, RsaPrivateKey};
use serde::Serialize;
/// Signing key consumed by [`encode`].
///
/// Wraps the active backend's `EncodingKey`. The backend is `jsonwebtoken` when the
/// `ring` feature is enabled, and `jsonwebtoken_rustcrypto` when `noring` is enabled.
/// Which constructors exist depends on the enabled feature.
#[derive(Clone)]
pub struct KeyForEncoding {
    // Backend key, passed as-is to the backend's `encode`.
    key: EncodingKey,
}

impl KeyForEncoding {
    /// Builds a key from raw shared-secret bytes via the backend's `EncodingKey::from_secret`.
    #[cfg(feature = "ring")]
    pub fn from_secret(secret: &[u8]) -> Self {
        let encoding_key = EncodingKey::from_secret(secret);
        Self { key: encoding_key }
    }

    /// Builds a key from a base64-encoded shared secret.
    ///
    /// # Errors
    ///
    /// Returns an error when the backend's `EncodingKey::from_base64_secret` fails.
    #[cfg(feature = "ring")]
    pub fn from_base64_secret(secret: &str) -> Result<Self, Error> {
        let encoding_key = EncodingKey::from_base64_secret(secret)?;
        Ok(Self { key: encoding_key })
    }

    /// Builds an RSA signing key from PEM-encoded bytes.
    ///
    /// # Errors
    ///
    /// Returns an error when the backend's `EncodingKey::from_rsa_pem` rejects `key`.
    #[cfg(feature = "ring")]
    pub fn from_rsa_pem(key: &[u8]) -> Result<Self, Error> {
        let encoding_key = EncodingKey::from_rsa_pem(key)?;
        Ok(Self { key: encoding_key })
    }

    /// Builds an RSA signing key from PEM-encoded bytes using the RustCrypto `rsa` crate.
    ///
    /// The bytes are decoded as UTF-8 and then parsed with `from_pkcs8_pem`.
    /// NOTE(review): only the PKCS#8 form is parsed here. Confirm whether PKCS#1 PEM
    /// input must also be accepted, to match the `ring` build.
    ///
    /// # Errors
    ///
    /// Returns an error when any of the following holds:
    /// - `key` is not valid UTF-8.
    /// - `key` does not parse as a PKCS#8 RSA private key.
    /// - `EncodingKey::from_rsa` rejects the parsed key.
    #[cfg(feature = "noring")]
    pub fn from_rsa_pem(key: &[u8]) -> Result<Self, Error> {
        let pem_text = std::str::from_utf8(key)?;
        let private_key = RsaPrivateKey::from_pkcs8_pem(pem_text)?;
        let encoding_key = EncodingKey::from_rsa(private_key)?;
        Ok(Self { key: encoding_key })
    }

    /// Builds an elliptic-curve signing key from PEM-encoded bytes.
    ///
    /// # Errors
    ///
    /// Returns an error when the backend's `EncodingKey::from_ec_pem` rejects `key`.
    #[cfg(feature = "ring")]
    pub fn from_ec_pem(key: &[u8]) -> Result<Self, Error> {
        let encoding_key = EncodingKey::from_ec_pem(key)?;
        Ok(Self { key: encoding_key })
    }

    /// Builds an Ed25519/EdDSA signing key from PEM-encoded bytes.
    ///
    /// # Errors
    ///
    /// Returns an error when the backend's `EncodingKey::from_ed_pem` rejects `key`.
    #[cfg(feature = "ring")]
    pub fn from_ed_pem(key: &[u8]) -> Result<Self, Error> {
        let encoding_key = EncodingKey::from_ed_pem(key)?;
        Ok(Self { key: encoding_key })
    }

    /// Builds an RSA signing key from DER bytes via the backend's `EncodingKey::from_rsa_der`.
    #[cfg(feature = "ring")]
    pub fn from_rsa_der(der: &[u8]) -> Self {
        let encoding_key = EncodingKey::from_rsa_der(der);
        Self { key: encoding_key }
    }

    /// Builds an elliptic-curve signing key from DER bytes via the backend's `EncodingKey::from_ec_der`.
    #[cfg(feature = "ring")]
    pub fn from_ec_der(der: &[u8]) -> Self {
        let encoding_key = EncodingKey::from_ec_der(der);
        Self { key: encoding_key }
    }

    /// Builds an Ed25519/EdDSA signing key from DER bytes via the backend's `EncodingKey::from_ed_der`.
    #[cfg(feature = "ring")]
    pub fn from_ed_der(der: &[u8]) -> Self {
        let encoding_key = EncodingKey::from_ed_der(der);
        Self { key: encoding_key }
    }
}
#[cfg(feature = "ring")]
/// Translates the crate's [`Header`] into the `jsonwebtoken` header type.
///
/// Every field is copied across. The optional `jwk` value is re-deserialized from its
/// stored JSON form into the backend's typed JWK.
///
/// # Errors
///
/// Returns an error, converted from `serde_json::Error`, when `header.jwk` is present
/// but does not deserialize into the backend's JWK type.
fn build_header(header: &Header) -> Result<JwtHeader, Error> {
    // `Option<Result<_, _>>` -> `Result<Option<_>, _>`, so a bad JWK aborts via `?`.
    let parsed_jwk = header
        .jwk
        .as_ref()
        .map(|raw| serde_json::from_value(raw.clone()))
        .transpose()?;

    let backend_alg = match header.alg {
        Algorithm::HS256 => JwtAlgorithm::HS256,
        Algorithm::HS384 => JwtAlgorithm::HS384,
        Algorithm::HS512 => JwtAlgorithm::HS512,
        Algorithm::RS256 => JwtAlgorithm::RS256,
        Algorithm::RS384 => JwtAlgorithm::RS384,
        Algorithm::RS512 => JwtAlgorithm::RS512,
        Algorithm::ES256 => JwtAlgorithm::ES256,
        Algorithm::ES384 => JwtAlgorithm::ES384,
        Algorithm::PS256 => JwtAlgorithm::PS256,
        Algorithm::PS384 => JwtAlgorithm::PS384,
        Algorithm::PS512 => JwtAlgorithm::PS512,
        Algorithm::EdDSA => JwtAlgorithm::EdDSA,
    };

    Ok(JwtHeader {
        alg: backend_alg,
        typ: header.typ.clone(),
        cty: header.cty.clone(),
        jku: header.jku.clone(),
        jwk: parsed_jwk,
        kid: header.kid.clone(),
        x5u: header.x5u.clone(),
        x5c: header.x5c.clone(),
        x5t: header.x5t.clone(),
        x5t_s256: header.x5t_s256.clone(),
    })
}
#[cfg(feature = "noring")]
/// Translates the crate's [`Header`] into the `jsonwebtoken_rustcrypto` header type.
///
/// Only `alg`, `typ`, `jku`, `kid` and `cty` are carried over.
///
/// # Errors
///
/// The body currently never returns `Err`. The `Result` keeps the signature identical
/// to the `ring` build.
///
/// # Panics
///
/// Panics when `header.alg` is not one of the HS*/RS*/ES256/ES384/PS* algorithms
/// mapped below. Judging by the `ring` build's mapping, that is presumably `EdDSA`.
fn build_header(header: &Header) -> Result<JwtHeader, Error> {
    let alg = match header.alg {
        Algorithm::HS256 => JwtAlgorithm::HS256,
        Algorithm::HS384 => JwtAlgorithm::HS384,
        Algorithm::HS512 => JwtAlgorithm::HS512,
        Algorithm::RS256 => JwtAlgorithm::RS256,
        Algorithm::RS384 => JwtAlgorithm::RS384,
        Algorithm::RS512 => JwtAlgorithm::RS512,
        Algorithm::ES256 => JwtAlgorithm::ES256,
        Algorithm::ES384 => JwtAlgorithm::ES384,
        Algorithm::PS256 => JwtAlgorithm::PS256,
        Algorithm::PS384 => JwtAlgorithm::PS384,
        Algorithm::PS512 => JwtAlgorithm::PS512,
        // NOTE(review): panicking from a library is unfriendly. Consider returning an
        // `Error` for unsupported algorithms once a suitable variant exists.
        _ => unimplemented!("this signing algorithm is not supported with the `noring` feature"),
    };
    let mut jwt_header = JwtHeader::new(alg);
    jwt_header.typ = header.typ.clone();
    jwt_header.jku = header.jku.clone();
    jwt_header.kid = header.kid.clone();
    jwt_header.cty = header.cty.clone();
    // NOTE(review): unlike the `ring` build, `jwk`, `x5u`, `x5c`, `x5t` and `x5t_s256`
    // are not copied here. Confirm whether the RustCrypto backend's header supports them.
    Ok(jwt_header)
}
/// Serializes `claims` and signs them with `key` under `header`.
///
/// Returns the token string produced by the active backend's `encode`.
///
/// # Errors
///
/// Returns an error in either of these cases:
/// - The header cannot be translated. With `ring`, this happens when `header.jwk` does
///   not deserialize into the backend's JWK type.
/// - The backend fails to serialize or sign the claims.
///
/// # Panics
///
/// With the `noring` feature, header translation panics for algorithms that build
/// does not map.
pub fn encode<T: Serialize>(
    header: &Header,
    claims: &T,
    key: &KeyForEncoding,
) -> Result<String, Error> {
    let backend_header = build_header(header)?;
    let token = jwt_encode(&backend_header, claims, &key.key)?;
    Ok(token)
}