#![deny(rust_2018_idioms, unreachable_pub)]
#![forbid(unsafe_code)]
mod byte_array;
mod byte_vec;
mod key_ops;
#[cfg(test)]
mod tests;
mod utils;
use std::{borrow::Cow, fmt};
use generic_array::typenum::U32;
use serde::{Deserialize, Serialize};
pub use byte_array::ByteArray;
pub use byte_vec::ByteVec;
pub use key_ops::KeyOps;
/// A JSON Web Key: a [`Key`] together with its optional metadata members.
///
/// The key material is flattened into the top-level JSON object (see [`Key`]).
/// The metadata members below are omitted from serialized output when they are
/// unset or empty.
///
/// Note: [`JsonWebKey::from_slice`] does not check `alg` against the key type,
/// whereas the `FromStr` impl does.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct JsonWebKey {
    /// The key material; its fields are flattened into this object.
    #[serde(flatten)]
    pub key: Box<Key>,
    /// The `use` member (intended use of the key); omitted when `None`.
    #[serde(default, rename = "use", skip_serializing_if = "Option::is_none")]
    pub key_use: Option<KeyUse>,
    /// The `key_ops` member; omitted when empty.
    #[serde(default, skip_serializing_if = "KeyOps::is_empty")]
    pub key_ops: KeyOps,
    /// The `kid` member (key identifier); omitted when `None`.
    #[serde(default, rename = "kid", skip_serializing_if = "Option::is_none")]
    pub key_id: Option<String>,
    /// The `alg` member; omitted when `None`.
    ///
    /// Prefer [`JsonWebKey::set_algorithm`], which checks the algorithm against
    /// the key type; assigning this public field directly skips that check.
    #[serde(default, rename = "alg", skip_serializing_if = "Option::is_none")]
    pub algorithm: Option<Algorithm>,
    /// X.509 parameters (`x5u`, `x5c`, `x5t`, `x5t#S256`), flattened into this
    /// object and omitted entirely when all of them are unset.
    #[serde(default, flatten, skip_serializing_if = "X509Params::is_empty")]
    pub x5: X509Params,
}
/// X.509 certificate parameters of a JWK (RFC 7517 §4.6–4.9).
///
/// Every member is optional and omitted from JSON when unset. The fields are
/// private, so code outside this crate can only obtain values through `Default`
/// (everything unset) or through deserialization.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct X509Params {
    /// `x5u`: URL referring to an X.509 certificate or certificate chain.
    #[serde(default, rename = "x5u", skip_serializing_if = "Option::is_none")]
    url: Option<String>,
    /// `x5c`: X.509 certificate chain, kept as the raw JSON strings (not decoded here).
    #[serde(default, rename = "x5c", skip_serializing_if = "Option::is_none")]
    cert_chain: Option<Vec<String>>,
    /// `x5t`: certificate thumbprint, kept as the raw JSON string (not decoded here).
    #[serde(default, rename = "x5t", skip_serializing_if = "Option::is_none")]
    thumbprint: Option<String>,
    /// `x5t#S256`: SHA-256 certificate thumbprint, kept as the raw JSON string.
    #[serde(default, rename = "x5t#S256", skip_serializing_if = "Option::is_none")]
    thumbprint_sha256: Option<String>,
}
impl X509Params {
fn is_empty(&self) -> bool {
matches!(
self,
X509Params {
url: None,
cert_chain: None,
thumbprint: None,
thumbprint_sha256: None,
}
)
}
}
impl JsonWebKey {
pub fn new(key: Key) -> Self {
Self {
key: Box::new(key),
key_use: None,
key_ops: KeyOps::empty(),
key_id: None,
algorithm: None,
x5: Default::default(),
}
}
pub fn set_algorithm(&mut self, alg: Algorithm) -> Result<(), Error> {
Self::validate_algorithm(alg, &*self.key)?;
self.algorithm = Some(alg);
Ok(())
}
pub fn from_slice(bytes: impl AsRef<[u8]>) -> Result<Self, Error> {
Ok(serde_json::from_slice(bytes.as_ref())?)
}
fn validate_algorithm(alg: Algorithm, key: &Key) -> Result<(), Error> {
use Algorithm::*;
use Key::*;
match (alg, key) {
(
ES256,
EC {
curve: Curve::P256 { .. },
},
)
| (RS256, RSA { .. })
| (HS256, Symmetric { .. }) => Ok(()),
_ => Err(Error::MismatchedAlgorithm),
}
}
}
impl std::str::FromStr for JsonWebKey {
    type Err = Error;

    /// Parses a JWK from a JSON string. If an `alg` member is present, it is
    /// also checked against the key type.
    ///
    /// # Errors
    ///
    /// * [`Error::Serde`] if the string is not a valid JWK JSON document.
    /// * [`Error::MismatchedAlgorithm`] if `alg` is present but incompatible
    ///   with the key.
    fn from_str(json: &str) -> Result<Self, Self::Err> {
        let jwk = Self::from_slice(json)?;
        if let Some(alg) = jwk.algorithm {
            Self::validate_algorithm(alg, &jwk.key)?;
        }
        Ok(jwk)
    }
}
impl std::fmt::Display for JsonWebKey {
    /// Writes the key as JSON.
    ///
    /// Output is compact by default and pretty-printed with the alternate flag
    /// (`{:#}`). Panics (via `unwrap`) if `serde_json` serialization fails.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let json = if f.alternate() {
            serde_json::to_string_pretty(self)
        } else {
            serde_json::to_string(self)
        };
        f.write_str(&json.unwrap())
    }
}
/// The key material of a JWK, internally tagged by the `kty` member.
///
/// `kty` is `"EC"`, `"RSA"`, or `"oct"` (symmetric); the variant's fields are
/// serialized alongside the tag in the same JSON object.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kty")]
// Variant names double as the `kty` values, so they must stay upper-case.
#[allow(clippy::upper_case_acronyms)]
pub enum Key {
    /// An elliptic-curve key; the curve parameters (`crv`, `x`, `y`, `d`) are flattened.
    EC {
        #[serde(flatten)]
        curve: Curve,
    },
    /// An RSA key. Public parameters are always present. When `private` is
    /// `Some`, the private parameters are serialized inline next to them.
    RSA {
        #[serde(flatten)]
        public: RsaPublic,
        #[serde(flatten, default, skip_serializing_if = "Option::is_none")]
        private: Option<RsaPrivate>,
    },
    /// A symmetric key (`"kty": "oct"`), with the key bytes in the `k` member.
    #[serde(rename = "oct")]
    Symmetric {
        #[serde(rename = "k")]
        key: ByteVec,
    },
}
impl Key {
    /// Returns `true` if this key contains private material.
    ///
    /// Symmetric keys always count as private. EC keys are private when `d` is
    /// present, and RSA keys when the private parameters are present.
    pub fn is_private(&self) -> bool {
        matches!(
            self,
            Self::Symmetric { .. }
                | Self::EC {
                    curve: Curve::P256 { d: Some(_), .. },
                    ..
                }
                | Self::RSA {
                    private: Some(_),
                    ..
                }
        )
    }

    /// Returns the public part of this key.
    ///
    /// A key that is already public is returned borrowed. A private EC or RSA
    /// key is returned as an owned copy with its private parameters removed.
    /// Symmetric keys have no public part, so `None` is returned for them.
    pub fn to_public(&self) -> Option<Cow<'_, Self>> {
        if !self.is_private() {
            return Some(Cow::Borrowed(self));
        }
        Some(Cow::Owned(match self {
            Self::Symmetric { .. } => return None,
            Self::EC {
                curve: Curve::P256 { x, y, .. },
            } => Self::EC {
                curve: Curve::P256 {
                    x: x.clone(),
                    y: y.clone(),
                    d: None,
                },
            },
            Self::RSA { public, .. } => Self::RSA {
                public: public.clone(),
                private: None,
            },
        }))
    }

    /// Encodes this asymmetric key as DER.
    ///
    /// Private keys are written with `pkcs8::write_private` and public keys
    /// with `pkcs8::write_public`; both are crate-local helpers in `utils::pkcs8`.
    ///
    /// # Errors
    ///
    /// * [`ConversionError::NotAsymmetric`] for symmetric keys.
    /// * [`ConversionError::MissingRsaParams`] for a private RSA key missing any
    ///   of `p`, `q`, `dp`, `dq`, `qi`.
    #[cfg(feature = "pkcs-convert")]
    pub fn try_to_der(&self) -> Result<Vec<u8>, ConversionError> {
        use num_bigint::BigUint;
        use yasna::{models::ObjectIdentifier, DERWriter, DERWriterSeq, Tag};

        use crate::utils::pkcs8;

        Ok(match self {
            Self::Symmetric { .. } => return Err(ConversionError::NotAsymmetric),
            Self::EC {
                curve: Curve::P256 { d, x, y },
            } => {
                // id-ecPublicKey
                let ec_public_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 2, 1]);
                // prime256v1 (NIST P-256)
                let prime256v1_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 3, 1, 7]);
                let oids = &[Some(&ec_public_oid), Some(&prime256v1_oid)];

                // Public key as an uncompressed point: 0x04 || X || Y (1 + 32 + 32 bytes).
                let write_public = |writer: DERWriter<'_>| {
                    let public_bytes: Vec<u8> = [0x04 /* uncompressed */]
                        .iter()
                        .chain(x.iter())
                        .chain(y.iter())
                        .copied()
                        .collect();
                    writer.write_bitvec_bytes(&public_bytes, 8 * (32 * 2 + 1));
                };

                match d {
                    Some(private_point) => {
                        pkcs8::write_private(oids, |writer: &mut DERWriterSeq<'_>| {
                            writer.next().write_i8(1); // version
                            writer.next().write_bytes(&**private_point);
                            // The public point is embedded under context tag [1].
                            writer.next().write_tagged(Tag::context(1), write_public);
                        })
                    }
                    None => pkcs8::write_public(oids, write_public),
                }
            }
            Self::RSA { public, private } => {
                // rsaEncryption
                let rsa_encryption_oid =
                    ObjectIdentifier::from_slice(&[1, 2, 840, 113549, 1, 1, 1]);
                // No second OID for RSA. NOTE(review): presumably `pkcs8` writes
                // the appropriate algorithm parameters for `None` — confirm.
                let oids = &[Some(&rsa_encryption_oid), None];

                // Integers are stored as big-endian byte strings.
                let write_bytevec = |writer: DERWriter<'_>, vec: &ByteVec| {
                    let bigint = BigUint::from_bytes_be(vec);
                    writer.write_biguint(&bigint);
                };

                // Writes `n` followed by the (fixed) public exponent.
                let write_public = |writer: &mut DERWriterSeq<'_>| {
                    write_bytevec(writer.next(), &public.n);
                    writer.next().write_u32(PUBLIC_EXPONENT);
                };

                // Writes, in order: version, n, e, d, p, q, dp, dq, qi.
                let write_private = |writer: &mut DERWriterSeq<'_>, private: &RsaPrivate| {
                    writer.next().write_i8(0); // version
                    write_public(writer);
                    write_bytevec(writer.next(), &private.d);
                    macro_rules! write_opt_bytevecs {
                        ($($param:ident),+) => {{
                            $(write_bytevec(writer.next(), private.$param.as_ref().unwrap());)+
                        }};
                    }
                    // Only reached when the match below has checked that all of these are `Some`.
                    write_opt_bytevecs!(p, q, dp, dq, qi);
                };

                match private {
                    Some(
                        private @ RsaPrivate {
                            d: _,
                            p: Some(_),
                            q: Some(_),
                            dp: Some(_),
                            dq: Some(_),
                            qi: Some(_),
                        },
                    ) => pkcs8::write_private(oids, |writer| write_private(writer, private)),
                    Some(_) => return Err(ConversionError::MissingRsaParams),
                    None => pkcs8::write_public(oids, |writer| {
                        let body =
                            yasna::construct_der(|writer| writer.write_sequence(write_public));
                        writer.write_bitvec_bytes(&body, body.len() * 8);
                    }),
                }
            }
        })
    }

    /// Infallible variant of [`Key::try_to_der`].
    ///
    /// # Panics
    ///
    /// Panics wherever `try_to_der` would return an error: for a symmetric key,
    /// or for a private RSA key with missing CRT parameters.
    #[cfg(feature = "pkcs-convert")]
    pub fn to_der(&self) -> Vec<u8> {
        self.try_to_der().unwrap()
    }

    /// Encodes this asymmetric key as PEM.
    ///
    /// The body is the base64 of [`Key::try_to_der`], wrapped at 64 characters
    /// per line. It sits between `-----BEGIN PRIVATE KEY-----` /
    /// `-----END PRIVATE KEY-----` lines, or `PUBLIC KEY` for public keys. Every
    /// line, including the last, ends with `\n`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Key::try_to_der`].
    #[cfg(feature = "pkcs-convert")]
    pub fn try_to_pem(&self) -> Result<String, ConversionError> {
        use std::fmt::Write;
        let der_b64 = base64::encode(self.try_to_der()?);
        let key_ty = if self.is_private() {
            "PRIVATE"
        } else {
            "PUBLIC"
        };
        // Writing into a `String` cannot fail, hence the `unwrap`s below.
        let mut pem = String::new();
        writeln!(&mut pem, "-----BEGIN {} KEY-----", key_ty).unwrap();
        const MAX_LINE_LEN: usize = 64;
        // Base64 output is ASCII, so byte-offset slicing always lands on char boundaries.
        for i in (0..der_b64.len()).step_by(MAX_LINE_LEN) {
            writeln!(
                &mut pem,
                "{}",
                &der_b64[i..std::cmp::min(i + MAX_LINE_LEN, der_b64.len())]
            )
            .unwrap();
        }
        writeln!(&mut pem, "-----END {} KEY-----", key_ty).unwrap();
        Ok(pem)
    }

    /// Infallible variant of [`Key::try_to_pem`].
    ///
    /// # Panics
    ///
    /// Panics wherever `try_to_pem` would return an error.
    #[cfg(feature = "pkcs-convert")]
    pub fn to_pem(&self) -> String {
        self.try_to_pem().unwrap()
    }

    /// Generates a random symmetric key of at least `num_bits` bits using
    /// `rand::thread_rng()`.
    ///
    /// The length is `num_bits` rounded *up* to whole bytes, so a request that
    /// is not a multiple of 8 never produces a key shorter than requested.
    #[cfg(feature = "generate")]
    pub fn generate_symmetric(num_bits: usize) -> Self {
        use rand::RngCore;
        // Ceiling division, written so it cannot overflow near `usize::MAX`.
        let num_bytes = num_bits / 8 + usize::from(num_bits % 8 != 0);
        let mut bytes = vec![0; num_bytes];
        rand::thread_rng().fill_bytes(&mut bytes);
        Self::Symmetric { key: bytes.into() }
    }

    /// Generates a random P-256 key pair using `rand::thread_rng()`.
    ///
    /// The returned EC key has the private scalar `d` and the public
    /// coordinates `x`, `y` populated.
    #[cfg(feature = "generate")]
    pub fn generate_p256() -> Self {
        use p256::elliptic_curve::{self as elliptic_curve, sec1::ToEncodedPoint};
        let sk = elliptic_curve::SecretKey::random(&mut rand::thread_rng());
        let sk_scalar = p256::Scalar::from(&sk);
        let pk = p256::ProjectivePoint::generator() * sk_scalar;
        // Uncompressed encoding is 0x04 || X || Y; drop the leading tag byte.
        let pk_bytes = &pk
            .to_affine()
            .to_encoded_point(false /* compress */)
            .to_bytes()[1..];
        let (x_bytes, y_bytes) = pk_bytes.split_at(32);
        Self::EC {
            curve: Curve::P256 {
                d: Some(sk_scalar.to_bytes().into()),
                x: ByteArray::from_slice(x_bytes),
                y: ByteArray::from_slice(y_bytes),
            },
        }
    }
}
/// Elliptic-curve parameters of an EC key, internally tagged by the `crv` member.
///
/// `Debug` is implemented by hand (not derived) and omits the private scalar `d`.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "crv")]
pub enum Curve {
    /// NIST P-256 (`"crv": "P-256"`).
    #[serde(rename = "P-256")]
    P256 {
        /// Private scalar (32 bytes); `None` for a public key, and omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        d: Option<ByteArray<U32>>,
        /// Affine x coordinate of the public point (32 bytes).
        x: ByteArray<U32>,
        /// Affine y coordinate of the public point (32 bytes).
        y: ByteArray<U32>,
    },
}
impl fmt::Debug for Curve {
    /// Formats only the public coordinates; the private scalar `d` is never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self::P256 { x, y, d: _ } = self;
        let mut out = f.debug_struct("Curve:P256");
        out.field("x", x);
        out.field("y", y);
        out.finish()
    }
}
/// Public parameters of an RSA key.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct RsaPublic {
    /// Public exponent `e`; only 65537 is representable (see [`PublicExponent`]).
    pub e: PublicExponent,
    /// Modulus `n`, as big-endian bytes (DER export reads it with `BigUint::from_bytes_be`).
    pub n: ByteVec,
}
/// The only RSA public exponent supported by this crate: 65537 (2^16 + 1).
const PUBLIC_EXPONENT: u32 = (1 << 16) + 1;
/// Base64 of the big-endian bytes `[0x01, 0x00, 0x01]` (65537); emitted when serializing `e`.
const PUBLIC_EXPONENT_B64: &str = "AQAB";
/// Alternate `e` spelling that is also accepted when deserializing.
///
/// It decodes to four bytes, `[0x01, 0x00, 0x01, 0x00]`, which equals 65537
/// only when read little-endian. NOTE(review): presumably kept for interop
/// with an encoder that emits this form — confirm.
const PUBLIC_EXPONENT_B64_PADDED: &str = "AQABAA==";
/// Zero-sized marker for the fixed RSA public exponent 65537.
///
/// It carries no data: serialization always writes `"AQAB"`, and
/// deserialization accepts only `"AQAB"` or `"AQABAA=="`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PublicExponent;
impl Serialize for PublicExponent {
    /// Always emits the fixed base64 string for 65537 (`"AQAB"`).
    fn serialize<S: serde::ser::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        s.serialize_str(PUBLIC_EXPONENT_B64)
    }
}
impl<'de> Deserialize<'de> for PublicExponent {
fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let e = String::deserialize(d)?;
if e == PUBLIC_EXPONENT_B64 || e == PUBLIC_EXPONENT_B64_PADDED {
Ok(Self)
} else {
Err(serde::de::Error::custom(&format!(
"public exponent must be {}",
PUBLIC_EXPONENT
)))
}
}
}
/// Private parameters of an RSA key.
///
/// `Debug` is implemented by hand and prints only the type name, so secret
/// material does not leak into logs. PKCS#8 export (`Key::try_to_der`)
/// requires all of `p`, `q`, `dp`, `dq`, `qi` to be present; otherwise it
/// fails with `ConversionError::MissingRsaParams`.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RsaPrivate {
    /// Private exponent.
    pub d: ByteVec,
    /// First prime factor; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub p: Option<ByteVec>,
    /// Second prime factor; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub q: Option<ByteVec>,
    /// First factor CRT exponent; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp: Option<ByteVec>,
    /// Second factor CRT exponent; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dq: Option<ByteVec>,
    /// First CRT coefficient; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub qi: Option<ByteVec>,
}
impl fmt::Debug for RsaPrivate {
    /// Prints only `RsaPrivate`; none of the secret fields are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A `DebugStruct` with no fields writes just its name, in both `{:?}` and `{:#?}`.
        f.debug_struct("RsaPrivate").finish()
    }
}
/// Value of the JWK `use` member: the intended use of the key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum KeyUse {
    /// Signing / signature verification (`"sig"`).
    #[serde(rename = "sig")]
    Signing,
    /// Encryption (`"enc"`).
    #[serde(rename = "enc")]
    Encryption,
}
/// Algorithms supported for the JWK `alg` member.
///
/// Serialized by variant name (`"HS256"`, `"RS256"`, `"ES256"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
// Variant names are the literal JSON values, so they must stay upper-case.
#[allow(clippy::upper_case_acronyms)]
pub enum Algorithm {
    /// HMAC with SHA-256; only valid with [`Key::Symmetric`].
    HS256,
    /// RSASSA-PKCS1-v1_5 with SHA-256; only valid with [`Key::RSA`].
    RS256,
    /// ECDSA on P-256 with SHA-256; only valid with a P-256 [`Key::EC`].
    ES256,
}
// Conversions to `jsonwebtoken` types, compiled only with the `jwt-convert`
// feature. They live inside an anonymous `const` item so the `jwt` alias stays local.
#[cfg(feature = "jwt-convert")]
const _IMPL_JWT_CONVERSIONS: () = {
    use jsonwebtoken as jwt;

    impl From<Algorithm> for jwt::Algorithm {
        /// Maps each JWK algorithm to the `jsonwebtoken` variant of the same name.
        fn from(alg: Algorithm) -> Self {
            match alg {
                Algorithm::HS256 => Self::HS256,
                Algorithm::ES256 => Self::ES256,
                Algorithm::RS256 => Self::RS256,
            }
        }
    }

    impl Key {
        /// Builds a `jsonwebtoken::EncodingKey` (for signing) from this key.
        ///
        /// Symmetric keys are passed as raw secrets (`from_secret`). EC and RSA
        /// keys go through their PEM encoding from `try_to_pem`.
        ///
        /// # Errors
        ///
        /// * [`ConversionError::NotPrivate`] if the key has no private material.
        /// * Any error from `try_to_pem`, e.g. [`ConversionError::MissingRsaParams`].
        ///
        /// # Panics
        ///
        /// Panics if `jsonwebtoken` rejects the PEM generated by this crate.
        pub fn try_to_encoding_key(&self) -> Result<jwt::EncodingKey, ConversionError> {
            if !self.is_private() {
                return Err(ConversionError::NotPrivate);
            }
            Ok(match self {
                Self::Symmetric { key } => jwt::EncodingKey::from_secret(key),
                Self::EC { .. } => {
                    jwt::EncodingKey::from_ec_pem(self.try_to_pem()?.as_bytes()).unwrap()
                }
                Self::RSA { .. } => {
                    jwt::EncodingKey::from_rsa_pem(self.try_to_pem()?.as_bytes()).unwrap()
                }
            })
        }

        /// Infallible variant of [`Key::try_to_encoding_key`].
        ///
        /// # Panics
        ///
        /// Panics wherever `try_to_encoding_key` would return an error.
        pub fn to_encoding_key(&self) -> jwt::EncodingKey {
            self.try_to_encoding_key().unwrap()
        }

        /// Builds a `jsonwebtoken::DecodingKey` (for verification) from this key.
        ///
        /// EC and RSA keys are reduced to their public part (via
        /// [`Key::to_public`]) before PEM encoding. As a result, a private key
        /// is exported as a `PUBLIC KEY` PEM as the decoder expects, and an RSA
        /// private key lacking `p`/`q`/`dp`/`dq`/`qi` still converts.
        ///
        /// # Panics
        ///
        /// Panics if `jsonwebtoken` rejects the public-key PEM generated by this crate.
        pub fn to_decoding_key(&self) -> jwt::DecodingKey {
            match self {
                Self::Symmetric { key } => jwt::DecodingKey::from_secret(key),
                Self::EC { .. } => {
                    // `to_public` only returns `None` for symmetric keys (handled above).
                    // A public EC key always has a DER encoding, so `to_pem` cannot fail.
                    let pem = self.to_public().unwrap().to_pem();
                    jwt::DecodingKey::from_ec_pem(pem.as_bytes()).unwrap()
                }
                Self::RSA { .. } => {
                    // A public RSA key never needs the private CRT parameters, so
                    // `to_pem` cannot fail with `MissingRsaParams` here.
                    let pem = self.to_public().unwrap().to_pem();
                    jwt::DecodingKey::from_rsa_pem(pem.as_bytes()).unwrap()
                }
            }
        }
    }
};
/// Errors returned when parsing or validating a [`JsonWebKey`].
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// JSON (de)serialization failed; displays the underlying `serde_json` error.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
    /// Base64 decoding failed; displays the underlying `base64` error.
    #[error(transparent)]
    Base64Decode(#[from] base64::DecodeError),
    /// The `alg` value is not compatible with the key type.
    #[error("mismatched algorithm for key type")]
    MismatchedAlgorithm,
}
#[derive(Debug, thiserror::Error)]
pub enum ConversionError {
#[error("encoding RSA JWK as PKCS#8 requires specifing all of p, q, dp, dq, qi")]
MissingRsaParams,
#[error("a symmetric key can not be encoded using PKCS#8")]
NotAsymmetric,
#[cfg(feature = "jwt-convert")]
#[error("a public key cannot be converted to a `jsonwebtoken::EncodingKey`")]
NotPrivate,
}