use serde::{Deserialize, Deserializer, Serialize, Serializer};
use sha2::digest::{Digest, Output};
use core::fmt;
use crate::{
alg::SecretBytes,
alloc::{Cow, String, ToString, Vec},
};
/// Type of a [`JsonWebKey`], mirroring the `kty` tag of its serialized form.
///
/// The `Display` representation of each variant is the corresponding `kty` string:
/// `RSA`, `EC`, `oct` or `OKP`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum KeyType {
    /// RSA key (`kty: "RSA"`).
    Rsa,
    /// Elliptic-curve key (`kty: "EC"`).
    EllipticCurve,
    /// Symmetric key (`kty: "oct"`).
    Symmetric,
    /// Generic asymmetric key pair (`kty: "OKP"`).
    KeyPair,
}
impl fmt::Display for KeyType {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(match self {
Self::Rsa => "RSA",
Self::EllipticCurve => "EC",
Self::Symmetric => "oct",
Self::KeyPair => "OKP",
})
}
}
/// Errors that can occur when interpreting or validating the contents of a [`JsonWebKey`].
#[derive(Debug)]
#[non_exhaustive]
pub enum JwkError {
    /// A required field is absent from the JWK; the payload is the field name.
    NoField(String),
    /// The JWK has a key type (`kty`) other than the expected one.
    UnexpectedKeyType {
        /// Expected key type.
        expected: KeyType,
        /// Actual key type of the JWK.
        actual: KeyType,
    },
    /// A JWK field has a value different from the expected one
    /// (e.g., `crv` naming another curve).
    UnexpectedValue {
        /// Name of the field.
        field: String,
        /// Expected field value.
        expected: String,
        /// Actual field value.
        actual: String,
    },
    /// A byte field of the JWK has an unexpected length.
    UnexpectedLen {
        /// Name of the field.
        field: String,
        /// Expected length in bytes.
        expected: usize,
        /// Actual length in bytes.
        actual: usize,
    },
    /// The private and public key parts encoded in the JWK do not match each other.
    MismatchedKeys,
    /// Any other error; construct it with [`JwkError::custom()`].
    Custom(anyhow::Error),
}
impl fmt::Display for JwkError {
    /// Renders a human-readable description of the error.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NoField(name) => write!(formatter, "field `{}` is absent from JWK", name),
            Self::MismatchedKeys => {
                formatter.write_str("private and public keys encoded in JWK do not match")
            }
            Self::UnexpectedKeyType {
                expected: wanted,
                actual: found,
            } => write!(
                formatter,
                "unexpected key type: {} (expected {})",
                found, wanted
            ),
            Self::UnexpectedValue {
                field: name,
                expected: wanted,
                actual: found,
            } => write!(
                formatter,
                "field `{}` has unexpected value (expected: {}, got: {})",
                name, wanted, found
            ),
            Self::UnexpectedLen {
                field: name,
                expected: wanted,
                actual: found,
            } => write!(
                formatter,
                "field `{}` has unexpected length (expected: {}, got: {})",
                name, wanted, found
            ),
            // Pass the caller's formatter through so its flags apply to the wrapped error.
            Self::Custom(inner) => fmt::Display::fmt(inner, formatter),
        }
    }
}
#[cfg(feature = "std")]
impl std::error::Error for JwkError {
    /// Exposes the wrapped error of [`JwkError::Custom`]; every other variant is a leaf error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Custom(inner) = self {
            Some(inner.as_ref())
        } else {
            None
        }
    }
}
impl JwkError {
pub fn custom(err: impl Into<anyhow::Error>) -> Self {
Self::Custom(err.into())
}
pub(crate) fn key_type(jwk: &JsonWebKey<'_>, expected: KeyType) -> Self {
let actual = jwk.key_type();
debug_assert_ne!(actual, expected);
Self::UnexpectedKeyType { actual, expected }
}
}
impl Serialize for SecretBytes<'_> {
    /// Encodes the secret with the `base64url` helper: an unpadded base64url string in
    /// human-readable formats, raw bytes otherwise.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let raw: &[u8] = self.as_ref();
        base64url::serialize(raw, serializer)
    }
}
impl<'de> Deserialize<'de> for SecretBytes<'_> {
    /// Decodes the secret with the `base64url` helper (the inverse of `Serialize`).
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let bytes = base64url::deserialize(deserializer)?;
        Ok(SecretBytes::new(bytes))
    }
}
/// JSON Web Key (JWK): key material that can be serialized and deserialized in the JWK format.
///
/// The enum is internally tagged by the `kty` field. Byte fields are written as unpadded
/// base64url strings in human-readable formats (such as JSON) and as raw byte buffers in
/// binary formats.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kty")]
#[non_exhaustive]
pub enum JsonWebKey<'a> {
    /// RSA key (`kty: "RSA"`).
    #[serde(rename = "RSA")]
    Rsa {
        /// Key modulus (`n`).
        #[serde(rename = "n", with = "base64url")]
        modulus: Cow<'a, [u8]>,
        /// Public exponent (`e`).
        #[serde(rename = "e", with = "base64url")]
        public_exponent: Cow<'a, [u8]>,
        /// Private parameters, flattened into the same object as the public ones.
        /// `None` for a public (verifying) key.
        #[serde(flatten)]
        private_parts: Option<RsaPrivateParts<'a>>,
    },
    /// Elliptic-curve key (`kty: "EC"`).
    #[serde(rename = "EC")]
    EllipticCurve {
        /// Curve name (`crv`).
        #[serde(rename = "crv")]
        curve: Cow<'a, str>,
        /// `x` component of the public key.
        #[serde(with = "base64url")]
        x: Cow<'a, [u8]>,
        /// `y` component of the public key.
        #[serde(with = "base64url")]
        y: Cow<'a, [u8]>,
        /// Secret key (`d`). Its presence makes this a signing key. It is omitted from output
        /// when `None` and defaults to `None` when absent from input.
        #[serde(rename = "d", default, skip_serializing_if = "Option::is_none")]
        secret: Option<SecretBytes<'a>>,
    },
    /// Symmetric key (`kty: "oct"`).
    #[serde(rename = "oct")]
    Symmetric {
        /// Key bytes (`k`).
        #[serde(rename = "k")]
        secret: SecretBytes<'a>,
    },
    /// Generic asymmetric key pair (`kty: "OKP"`), e.g. an `Ed25519` key.
    #[serde(rename = "OKP")]
    KeyPair {
        /// Curve name (`crv`).
        #[serde(rename = "crv")]
        curve: Cow<'a, str>,
        /// Public key bytes (`x`).
        #[serde(with = "base64url")]
        x: Cow<'a, [u8]>,
        /// Secret key (`d`). Its presence makes this a signing key. It is omitted from output
        /// when `None` and defaults to `None` when absent from input.
        #[serde(rename = "d", default, skip_serializing_if = "Option::is_none")]
        secret: Option<SecretBytes<'a>>,
    },
}
impl JsonWebKey<'_> {
    /// Returns the [`KeyType`] corresponding to this key's `kty` variant.
    pub fn key_type(&self) -> KeyType {
        match self {
            Self::Symmetric { .. } => KeyType::Symmetric,
            Self::Rsa { .. } => KeyType::Rsa,
            Self::EllipticCurve { .. } => KeyType::EllipticCurve,
            Self::KeyPair { .. } => KeyType::KeyPair,
        }
    }

    /// Returns `true` if this key carries secret material usable for signing.
    ///
    /// This is always the case for symmetric keys. RSA keys qualify when their private
    /// parts are present; EC and OKP keys qualify when their secret (`d`) is present.
    pub fn is_signing_key(&self) -> bool {
        match self {
            Self::Symmetric { .. } => true,
            Self::Rsa { private_parts, .. } => private_parts.is_some(),
            Self::EllipticCurve { secret, .. } => secret.is_some(),
            Self::KeyPair { secret, .. } => secret.is_some(),
        }
    }

    /// Returns a copy of this key with all private parts removed.
    ///
    /// A symmetric key has no public counterpart, so it is returned as an unchanged clone.
    pub fn to_verifying_key(&self) -> Self {
        match self {
            Self::Symmetric { .. } => self.clone(),
            Self::Rsa {
                modulus,
                public_exponent,
                ..
            } => Self::Rsa {
                modulus: modulus.clone(),
                public_exponent: public_exponent.clone(),
                private_parts: None,
            },
            Self::EllipticCurve { curve, x, y, .. } => Self::EllipticCurve {
                curve: curve.clone(),
                x: x.clone(),
                y: y.clone(),
                secret: None,
            },
            Self::KeyPair { curve, x, .. } => Self::KeyPair {
                curve: curve.clone(),
                x: x.clone(),
                secret: None,
            },
        }
    }

    /// Computes the thumbprint of this key: digest `D` over the key's canonical JSON form
    /// (its `Display` output, with members sorted by name and no whitespace).
    ///
    /// Secret material is stripped first via [`Self::to_verifying_key()`], so a signing key and
    /// its verifying key share a thumbprint. NOTE(review): this has the shape of an RFC 7638
    /// thumbprint; re-check against the RFC if new members are ever added to a variant.
    pub fn thumbprint<D: Digest>(&self) -> Output<D> {
        let canonical_json = if self.is_signing_key() {
            self.to_verifying_key().to_string()
        } else {
            self.to_string()
        };
        D::digest(canonical_json.as_bytes())
    }
}
impl fmt::Display for JsonWebKey<'_> {
    /// Writes the key as compact JSON with top-level members sorted by name.
    ///
    /// Sorting is done explicitly so the output does not depend on the key order of
    /// `serde_json`'s map type. This canonical form is what `thumbprint()` hashes.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let json = serde_json::to_value(self).expect("Cannot convert JsonWebKey to JSON");
        let object = json.as_object().unwrap();
        let mut fields: Vec<_> = object.iter().collect();
        fields.sort_unstable_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));

        formatter.write_str("{")?;
        for (idx, (name, value)) in fields.into_iter().enumerate() {
            if idx > 0 {
                formatter.write_str(",")?;
            }
            write!(formatter, "\"{}\":{}", name, value)?;
        }
        formatter.write_str("}")
    }
}
/// Private parameters of an RSA key. They are serialized as additional members of the RSA
/// JWK object (see [`JsonWebKey::Rsa`]).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RsaPrivateParts<'a> {
    /// Private exponent (`d`).
    #[serde(rename = "d")]
    pub private_exponent: SecretBytes<'a>,
    /// First prime factor (`p`).
    #[serde(rename = "p")]
    pub prime_factor_p: SecretBytes<'a>,
    /// Second prime factor (`q`).
    #[serde(rename = "q")]
    pub prime_factor_q: SecretBytes<'a>,
    /// CRT exponent of the first factor (`dp`). Optional: omitted from output when `None`.
    #[serde(rename = "dp", default, skip_serializing_if = "Option::is_none")]
    pub p_crt_exponent: Option<SecretBytes<'a>>,
    /// CRT exponent of the second factor (`dq`). Optional: omitted from output when `None`.
    #[serde(rename = "dq", default, skip_serializing_if = "Option::is_none")]
    pub q_crt_exponent: Option<SecretBytes<'a>>,
    /// CRT coefficient of the second factor (`qi`). Optional: omitted from output when `None`.
    #[serde(rename = "qi", default, skip_serializing_if = "Option::is_none")]
    pub q_crt_coefficient: Option<SecretBytes<'a>>,
    /// Additional prime factors (`oth`). Omitted from output when empty; defaults to empty
    /// when absent from input.
    #[serde(rename = "oth", default, skip_serializing_if = "Vec::is_empty")]
    pub other_prime_factors: Vec<RsaPrimeFactor<'a>>,
}
/// Additional prime factor of an RSA key: one element of the `oth` array in
/// [`RsaPrivateParts`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RsaPrimeFactor<'a> {
    /// Prime factor (`r`).
    #[serde(rename = "r")]
    pub factor: SecretBytes<'a>,
    /// CRT exponent of the factor (`d`). Optional: omitted from output when `None`.
    #[serde(rename = "d", default, skip_serializing_if = "Option::is_none")]
    pub crt_exponent: Option<SecretBytes<'a>>,
    /// CRT coefficient of the factor (`t`). Optional: omitted from output when `None`.
    #[serde(rename = "t", default, skip_serializing_if = "Option::is_none")]
    pub crt_coefficient: Option<SecretBytes<'a>>,
}
#[cfg(any(
    feature = "es256k",
    feature = "k256",
    feature = "exonum-crypto",
    feature = "ed25519-dalek",
    feature = "ed25519-compact"
))]
mod helpers {
    //! Validation helpers used when converting JWKs into backend-specific key types.

    use super::{JsonWebKey, JwkError};
    use crate::{alg::SigningKey, alloc::ToOwned, Algorithm};

    use core::convert::TryFrom;

    impl JsonWebKey<'_> {
        /// Checks that the curve name read from a JWK (`crv`) equals `expected`.
        ///
        /// # Errors
        ///
        /// Returns [`JwkError::UnexpectedValue`] for the `crv` field on mismatch.
        pub(crate) fn ensure_curve(curve: &str, expected: &str) -> Result<(), JwkError> {
            if curve == expected {
                return Ok(());
            }
            Err(JwkError::UnexpectedValue {
                field: "crv".to_owned(),
                expected: expected.to_owned(),
                actual: curve.to_owned(),
            })
        }

        /// Checks that `bytes`, the contents of the JWK field named `field`, is exactly
        /// `expected_len` bytes long.
        ///
        /// # Errors
        ///
        /// Returns [`JwkError::UnexpectedLen`] on a length mismatch.
        pub(crate) fn ensure_len(
            field: &str,
            bytes: &[u8],
            expected_len: usize,
        ) -> Result<(), JwkError> {
            let actual = bytes.len();
            if actual == expected_len {
                return Ok(());
            }
            Err(JwkError::UnexpectedLen {
                field: field.to_owned(),
                expected: expected_len,
                actual,
            })
        }

        /// Checks that `signing_key` corresponds to the verifying key encoded in this JWK and
        /// hands the signing key back on success.
        ///
        /// # Errors
        ///
        /// Propagates errors from converting the JWK into `Alg::VerifyingKey`. Returns
        /// [`JwkError::MismatchedKeys`] if the two verifying keys differ.
        pub(crate) fn ensure_key_match<Alg, K>(&self, signing_key: K) -> Result<K, JwkError>
        where
            Alg: Algorithm<SigningKey = K>,
            K: SigningKey<Alg>,
            Alg::VerifyingKey: for<'jwk> TryFrom<&'jwk Self, Error = JwkError> + PartialEq,
        {
            let encoded_key = <Alg::VerifyingKey>::try_from(self)?;
            let keys_match = encoded_key == signing_key.to_verifying_key();
            if keys_match {
                Ok(signing_key)
            } else {
                Err(JwkError::MismatchedKeys)
            }
        }
    }
}
mod base64url {
    //! `#[serde(with = "base64url")]` helpers for byte fields: unpadded base64url strings in
    //! human-readable formats, raw byte buffers in binary ones.

    use base64ct::{Base64UrlUnpadded, Encoding};
    use serde::{
        de::{Error as DeError, Unexpected, Visitor},
        Deserializer, Serializer,
    };

    use core::fmt;

    use crate::alloc::{Cow, Vec};

    /// Serializes `value` as an unpadded base64url string if the serializer is
    /// human-readable, or as raw bytes otherwise.
    pub fn serialize<S>(value: &[u8], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if !serializer.is_human_readable() {
            return serializer.serialize_bytes(value);
        }
        let encoded = Base64UrlUnpadded::encode_string(value);
        serializer.serialize_str(&encoded)
    }

    /// Deserializes bytes produced by [`serialize()`], always returning an owned buffer.
    ///
    /// # Errors
    ///
    /// In human-readable formats, fails on strings that are not valid unpadded base64url.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Cow<'static, [u8]>, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Decodes base64url strings. Raw byte input is also accepted and taken verbatim.
        struct EncodedVisitor;

        impl Visitor<'_> for EncodedVisitor {
            type Value = Vec<u8>;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str("base64url-encoded data")
            }

            fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
                match Base64UrlUnpadded::decode_vec(value) {
                    Ok(decoded) => Ok(decoded),
                    Err(_) => Err(E::invalid_value(Unexpected::Str(value), &self)),
                }
            }

            fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
                Ok(value.to_vec())
            }

            fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
                Ok(value)
            }
        }

        /// Accepts raw byte buffers only.
        struct RawVisitor;

        impl<'de> Visitor<'de> for RawVisitor {
            type Value = Vec<u8>;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str("byte buffer")
            }

            fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
                Ok(value.to_vec())
            }

            fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
                Ok(value)
            }
        }

        let bytes = if deserializer.is_human_readable() {
            deserializer.deserialize_str(EncodedVisitor)?
        } else {
            deserializer.deserialize_bytes(RawVisitor)?
        };
        Ok(Cow::Owned(bytes))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alg::Hs256Key;

    use assert_matches::assert_matches;
    use core::convert::TryFrom;

    /// Public OKP key with dummy key material.
    fn create_jwk() -> JsonWebKey<'static> {
        JsonWebKey::KeyPair {
            curve: Cow::Borrowed("Ed25519"),
            x: Cow::Borrowed(b"test"),
            secret: None,
        }
    }

    /// Parses `json` as a JWK, expecting failure, and returns the error message.
    fn deserialization_error(json: &str) -> String {
        serde_json::from_str::<JsonWebKey<'_>>(json)
            .unwrap_err()
            .to_string()
    }

    /// Checks whether `needle` occurs as a contiguous subsequence of `haystack`.
    fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[test]
    fn serializing_jwk() {
        let original = create_jwk();
        let expected = serde_json::json!({ "crv": "Ed25519", "kty": "OKP", "x": "dGVzdA" });

        let serialized = serde_json::to_value(&original).unwrap();
        assert_eq!(serialized, expected);

        let roundtripped: JsonWebKey<'_> = serde_json::from_value(serialized).unwrap();
        assert_eq!(roundtripped, original);
    }

    #[test]
    fn jwk_deserialization_errors() {
        let err = deserialization_error(r#"{"crv":"Ed25519"}"#);
        assert!(err.contains("missing field `kty`"), "{}", err);

        let err = deserialization_error(r#"{"crv":"Ed25519","kty":"OKP","x":"??"}"#);
        for needle in &["invalid value: string \"??\"", "base64url-encoded data"] {
            assert!(err.contains(*needle), "{}", err);
        }
    }

    #[test]
    fn extra_jwk_fields() {
        /// JWK with application-specific members flattened alongside the key fields.
        #[derive(Debug, Serialize, Deserialize)]
        struct ExtendedJsonWebKey<'a, T> {
            #[serde(flatten)]
            base: JsonWebKey<'a>,
            #[serde(flatten)]
            extra: T,
        }

        #[derive(Debug, Deserialize)]
        struct Extra {
            #[serde(rename = "kid")]
            key_id: String,
            #[serde(rename = "use")]
            key_use: KeyUse,
        }

        #[derive(Debug, Deserialize, PartialEq)]
        enum KeyUse {
            #[serde(rename = "sig")]
            Signature,
            #[serde(rename = "enc")]
            Encryption,
        }

        let json_str =
            "\n{ \"kty\": \"oct\", \"kid\": \"my-unique-key\", \"k\": \"dGVzdA\", \"use\": \"sig\" }\n";
        let parsed: ExtendedJsonWebKey<'_, Extra> = serde_json::from_str(json_str).unwrap();
        assert_matches!(
            &parsed.base,
            JsonWebKey::Symmetric { secret } if secret.as_ref() == b"test"
        );
        assert_eq!(parsed.extra.key_id, "my-unique-key");
        assert_eq!(parsed.extra.key_use, KeyUse::Signature);

        // Round-trip through a concrete key type and back to a JWK.
        let hs256_key = Hs256Key::try_from(&parsed.base).unwrap();
        let reconstructed = JsonWebKey::from(&hs256_key);
        assert_matches!(
            reconstructed,
            JsonWebKey::Symmetric { secret } if secret.as_ref() == b"test"
        );
    }

    #[test]
    #[cfg(feature = "serde_cbor")]
    fn jwk_with_cbor() {
        let key = JsonWebKey::KeyPair {
            curve: Cow::Borrowed("Ed25519"),
            x: Cow::Borrowed(b"public"),
            secret: Some(SecretBytes::borrowed(b"private")),
        };

        // Binary formats must store key material as raw bytes, not base64url strings.
        let encoded = serde_cbor::to_vec(&key).unwrap();
        assert!(contains_bytes(&encoded, b"public"));
        assert!(contains_bytes(&encoded, b"private"));

        let decoded: JsonWebKey<'_> = serde_cbor::from_slice(&encoded).unwrap();
        assert_eq!(decoded, key);
    }
}