#![allow(missing_docs)]
use crate::{
errors::{self, Error, ErrorKind},
Algorithm,
};
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use std::{fmt, str::FromStr};
/// The JWK `use` member: what a public key is intended to be used for.
///
/// `"sig"` and `"enc"` get dedicated variants; any other string is preserved
/// verbatim in [`PublicKeyUse::Other`] so unknown values survive a round trip.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum PublicKeyUse {
    /// `"sig"`
    Signature,
    /// `"enc"`
    Encryption,
    /// Any value other than `"sig"` / `"enc"`, kept as-is.
    Other(String),
}

impl Serialize for PublicKeyUse {
    /// Writes the variant as a plain JSON string (`"sig"`, `"enc"`, or the
    /// stored custom value).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(match self {
            Self::Signature => "sig",
            Self::Encryption => "enc",
            Self::Other(custom) => custom.as_str(),
        })
    }
}

impl<'de> Deserialize<'de> for PublicKeyUse {
    /// Reads a string; unrecognised values become [`PublicKeyUse::Other`]
    /// instead of failing.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Maps a string token onto a `PublicKeyUse`.
        struct UseVisitor;

        impl<'de> de::Visitor<'de> for UseVisitor {
            type Value = PublicKeyUse;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str("a string")
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                let parsed = if value == "sig" {
                    PublicKeyUse::Signature
                } else if value == "enc" {
                    PublicKeyUse::Encryption
                } else {
                    PublicKeyUse::Other(value.to_owned())
                };
                Ok(parsed)
            }
        }

        deserializer.deserialize_string(UseVisitor)
    }
}
/// One entry of the JWK `key_ops` member: an operation the key may be used for.
///
/// The eight registered operation names map to dedicated variants; any other
/// string is preserved verbatim in [`KeyOperations::Other`].
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum KeyOperations {
    /// `"sign"`
    Sign,
    /// `"verify"`
    Verify,
    /// `"encrypt"`
    Encrypt,
    /// `"decrypt"`
    Decrypt,
    /// `"wrapKey"`
    WrapKey,
    /// `"unwrapKey"`
    UnwrapKey,
    /// `"deriveKey"`
    DeriveKey,
    /// `"deriveBits"`
    DeriveBits,
    /// Any other operation name, kept as-is.
    Other(String),
}

impl Serialize for KeyOperations {
    /// Writes the operation as its camelCase JSON string.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(match self {
            Self::Sign => "sign",
            Self::Verify => "verify",
            Self::Encrypt => "encrypt",
            Self::Decrypt => "decrypt",
            Self::WrapKey => "wrapKey",
            Self::UnwrapKey => "unwrapKey",
            Self::DeriveKey => "deriveKey",
            Self::DeriveBits => "deriveBits",
            Self::Other(custom) => custom.as_str(),
        })
    }
}

impl<'de> Deserialize<'de> for KeyOperations {
    /// Reads a string; unrecognised names become [`KeyOperations::Other`]
    /// instead of failing.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Maps a string token onto a `KeyOperations`.
        struct OpsVisitor;

        impl<'de> de::Visitor<'de> for OpsVisitor {
            type Value = KeyOperations;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str("a string")
            }

            fn visit_str<E>(self, name: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                let op = match name {
                    "sign" => KeyOperations::Sign,
                    "verify" => KeyOperations::Verify,
                    "encrypt" => KeyOperations::Encrypt,
                    "decrypt" => KeyOperations::Decrypt,
                    "wrapKey" => KeyOperations::WrapKey,
                    "unwrapKey" => KeyOperations::UnwrapKey,
                    "deriveKey" => KeyOperations::DeriveKey,
                    "deriveBits" => KeyOperations::DeriveBits,
                    unknown => KeyOperations::Other(unknown.to_owned()),
                };
                Ok(op)
            }
        }

        deserializer.deserialize_string(OpsVisitor)
    }
}
/// A JWA algorithm identifier as carried in a JWK's `alg` member.
///
/// Variant names mirror the JWA names. `RSA_OAEP` and `RSA_OAEP_256` are spelled
/// `"RSA-OAEP"` / `"RSA-OAEP-256"` on the wire because `-` cannot appear in a
/// Rust identifier.
#[allow(non_camel_case_types, clippy::upper_case_acronyms)]
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, Serialize, Deserialize)]
pub enum KeyAlgorithm {
    HS256,
    HS384,
    HS512,
    ES256,
    ES384,
    RS256,
    RS384,
    RS512,
    PS256,
    PS384,
    PS512,
    EdDSA,
    RSA1_5,
    #[serde(rename = "RSA-OAEP")]
    RSA_OAEP,
    #[serde(rename = "RSA-OAEP-256")]
    RSA_OAEP_256,
}

impl FromStr for KeyAlgorithm {
    type Err = Error;

    /// Parses a JWA algorithm name.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidAlgorithmName` for any name not listed in
    /// [`KeyAlgorithm`].
    fn from_str(s: &str) -> errors::Result<Self> {
        match s {
            "HS256" => Ok(KeyAlgorithm::HS256),
            "HS384" => Ok(KeyAlgorithm::HS384),
            "HS512" => Ok(KeyAlgorithm::HS512),
            "ES256" => Ok(KeyAlgorithm::ES256),
            "ES384" => Ok(KeyAlgorithm::ES384),
            "RS256" => Ok(KeyAlgorithm::RS256),
            "RS384" => Ok(KeyAlgorithm::RS384),
            "RS512" => Ok(KeyAlgorithm::RS512),
            "PS256" => Ok(KeyAlgorithm::PS256),
            "PS384" => Ok(KeyAlgorithm::PS384),
            "PS512" => Ok(KeyAlgorithm::PS512),
            "EdDSA" => Ok(KeyAlgorithm::EdDSA),
            "RSA1_5" => Ok(KeyAlgorithm::RSA1_5),
            "RSA-OAEP" => Ok(KeyAlgorithm::RSA_OAEP),
            "RSA-OAEP-256" => Ok(KeyAlgorithm::RSA_OAEP_256),
            _ => Err(ErrorKind::InvalidAlgorithmName.into()),
        }
    }
}

impl fmt::Display for KeyAlgorithm {
    /// Writes the JWA name, so `to_string().parse::<KeyAlgorithm>()`
    /// round-trips for every variant (including the hyphenated RSA-OAEP names).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}

impl KeyAlgorithm {
    /// Returns the JWA name of this algorithm: the same string `FromStr`
    /// accepts and serde reads/writes.
    fn as_str(self) -> &'static str {
        match self {
            KeyAlgorithm::HS256 => "HS256",
            KeyAlgorithm::HS384 => "HS384",
            KeyAlgorithm::HS512 => "HS512",
            KeyAlgorithm::ES256 => "ES256",
            KeyAlgorithm::ES384 => "ES384",
            KeyAlgorithm::RS256 => "RS256",
            KeyAlgorithm::RS384 => "RS384",
            KeyAlgorithm::RS512 => "RS512",
            KeyAlgorithm::PS256 => "PS256",
            KeyAlgorithm::PS384 => "PS384",
            KeyAlgorithm::PS512 => "PS512",
            KeyAlgorithm::EdDSA => "EdDSA",
            KeyAlgorithm::RSA1_5 => "RSA1_5",
            KeyAlgorithm::RSA_OAEP => "RSA-OAEP",
            KeyAlgorithm::RSA_OAEP_256 => "RSA-OAEP-256",
        }
    }

    /// Converts to the crate's signing/verification [`Algorithm`] by name.
    ///
    /// # Errors
    ///
    /// Returns whatever `Algorithm::from_str` returns when it does not
    /// recognise this algorithm's JWA name.
    fn to_algorithm(self) -> errors::Result<Algorithm> {
        Algorithm::from_str(self.as_str())
    }
}
/// JWK members that apply to every key type (RFC 7517 §4).
///
/// Every member is optional: a missing member deserializes to `None`, and a
/// `None` member is left out when serializing.
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize, Default, Hash)]
pub struct CommonParameters {
    /// `use`: intended use of the public key.
    #[serde(rename = "use", default, skip_serializing_if = "Option::is_none")]
    pub public_key_use: Option<PublicKeyUse>,

    /// `key_ops`: operations the key is intended for.
    #[serde(rename = "key_ops", default, skip_serializing_if = "Option::is_none")]
    pub key_operations: Option<Vec<KeyOperations>>,

    /// `alg`: algorithm intended for use with the key.
    #[serde(rename = "alg", default, skip_serializing_if = "Option::is_none")]
    pub key_algorithm: Option<KeyAlgorithm>,

    /// `kid`: key identifier, used by [`JwkSet::find`].
    #[serde(rename = "kid", default, skip_serializing_if = "Option::is_none")]
    pub key_id: Option<String>,

    /// `x5u`: X.509 URL member, kept as an unvalidated string.
    #[serde(rename = "x5u", default, skip_serializing_if = "Option::is_none")]
    pub x509_url: Option<String>,

    /// `x5c`: X.509 certificate chain member, one string per entry.
    #[serde(rename = "x5c", default, skip_serializing_if = "Option::is_none")]
    pub x509_chain: Option<Vec<String>>,

    /// `x5t`: X.509 SHA-1 thumbprint member, kept as an unvalidated string.
    #[serde(rename = "x5t", default, skip_serializing_if = "Option::is_none")]
    pub x509_sha1_fingerprint: Option<String>,

    /// `x5t#S256`: X.509 SHA-256 thumbprint member, kept as an unvalidated string.
    #[serde(rename = "x5t#S256", default, skip_serializing_if = "Option::is_none")]
    pub x509_sha256_fingerprint: Option<String>,
}
/// The `kty` value of an elliptic-curve JWK.
///
/// Having a single variant means it only (de)serializes as `"EC"`. That is what
/// lets the untagged [`AlgorithmParameters`] reject non-EC keys for the EC variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub enum EllipticCurveKeyType {
#[default]
EC,
}
/// The `crv` member, shared by EC keys ([`EllipticCurveKeyParameters`]) and
/// OKP keys ([`OctetKeyPairParameters`]).
///
/// Uses the wire names from the `rename` attributes below; defaults to `P256`.
#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub enum EllipticCurve {
/// `"P-256"`
#[serde(rename = "P-256")]
#[default]
P256,
/// `"P-384"`
#[serde(rename = "P-384")]
P384,
/// `"P-521"`
#[serde(rename = "P-521")]
P521,
/// `"Ed25519"`
#[serde(rename = "Ed25519")]
Ed25519,
}
/// Key-type-specific members of an elliptic-curve JWK (`kty` = `"EC"`).
///
/// `x` and `y` hold the raw strings from the JSON; this type does no decoding
/// or validation. NOTE(review): presumably base64url-encoded coordinates per
/// RFC 7518 §6.2.1; confirm at the point of use.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Default, Hash)]
pub struct EllipticCurveKeyParameters {
#[serde(rename = "kty")]
pub key_type: EllipticCurveKeyType,
#[serde(rename = "crv")]
pub curve: EllipticCurve,
pub x: String,
pub y: String,
}
/// The `kty` value of an RSA JWK. Its single variant only (de)serializes as `"RSA"`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub enum RSAKeyType {
#[default]
RSA,
}
/// Key-type-specific members of an RSA public JWK (`kty` = `"RSA"`).
///
/// `n` and `e` are the RFC 7518 §6.3.1 modulus and exponent members. They are
/// kept as the raw JSON strings, with no decoding or validation here.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Default, Hash)]
pub struct RSAKeyParameters {
#[serde(rename = "kty")]
pub key_type: RSAKeyType,
pub n: String,
pub e: String,
}
/// The `kty` value of an octet-sequence (symmetric) JWK.
/// Its single variant only (de)serializes as `"oct"`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub enum OctetKeyType {
#[serde(rename = "oct")]
#[default]
Octet,
}
/// Key-type-specific members of an octet-sequence JWK (`kty` = `"oct"`).
///
/// The JSON member `k` is exposed as `value` and kept as the raw string
/// (not decoded here).
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Default, Hash)]
pub struct OctetKeyParameters {
#[serde(rename = "kty")]
pub key_type: OctetKeyType,
#[serde(rename = "k")]
pub value: String,
}
/// The `kty` value of an octet key pair JWK.
/// Its single variant only (de)serializes as `"OKP"`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub enum OctetKeyPairType {
#[serde(rename = "OKP")]
#[default]
OctetKeyPair,
}
/// Key-type-specific members of an octet key pair JWK (`kty` = `"OKP"`).
///
/// Reuses [`EllipticCurve`] for `crv`. `x` is kept as the raw JSON string.
/// NOTE(review): presumably the encoded public key per RFC 8037; confirm at
/// the point of use.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Default, Hash)]
pub struct OctetKeyPairParameters {
#[serde(rename = "kty")]
pub key_type: OctetKeyPairType,
#[serde(rename = "crv")]
pub curve: EllipticCurve,
pub x: String,
}
/// The key-type-specific half of a [`Jwk`].
///
/// This is an untagged enum: serde tries the variants in declaration order and
/// keeps the first one that deserializes. Each variant's `kty` field is a
/// single-variant enum, so the JSON `kty` string effectively decides the match.
/// A `kty` not covered here makes deserialization fail. Variant order is
/// observable, so reorder with care.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(untagged)]
pub enum AlgorithmParameters {
EllipticCurve(EllipticCurveKeyParameters),
RSA(RSAKeyParameters),
OctetKey(OctetKeyParameters),
OctetKeyPair(OctetKeyPairParameters),
}
/// A single JSON Web Key.
///
/// Both halves are `flatten`ed, so common members (`kid`, `alg`, ...) and
/// key-type-specific members (`kty`, `n`, `k`, ...) share one JSON object.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub struct Jwk {
#[serde(flatten)]
pub common: CommonParameters,
#[serde(flatten)]
pub algorithm: AlgorithmParameters,
}
impl Jwk {
    /// Returns `true` when the key declares an `alg` that converts to a crate
    /// [`Algorithm`].
    ///
    /// `alg` is optional in a JWK. A key without it is reported as unsupported
    /// instead of panicking.
    pub fn is_supported(&self) -> bool {
        match self.common.key_algorithm {
            Some(alg) => alg.to_algorithm().is_ok(),
            None => false,
        }
    }
}
/// A JWK Set: a JSON object whose `keys` member is an array of [`Jwk`]s.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct JwkSet {
pub keys: Vec<Jwk>,
}
impl JwkSet {
    /// Returns the first key whose `kid` equals `kid`, or `None` if no key
    /// matches. Keys without a `kid` never match.
    pub fn find(&self, kid: &str) -> Option<&Jwk> {
        self.keys
            .iter()
            .find(|jwk| jwk.common.key_id.as_deref() == Some(kid))
    }
}
#[cfg(test)]
mod tests {
    use crate::jwk::{AlgorithmParameters, JwkSet, OctetKeyType};
    use crate::serialization::b64_encode;
    use crate::Algorithm;
    use serde_json::json;
    use wasm_bindgen_test::wasm_bindgen_test;

    /// Parses a one-key HS256 set and checks every field survives,
    /// including the `k` secret compared against the input.
    #[test]
    #[wasm_bindgen_test]
    fn check_hs256() {
        // Distinct names for the encoded secret, the parsed JWK and its
        // parameters, so the final assertion compares parsed output to input.
        let encoded_key = b64_encode("abcdefghijklmnopqrstuvwxyz012345");
        let jwks_json = json!({
            "keys": [
                {
                    "kty": "oct",
                    "alg": "HS256",
                    "kid": "abc123",
                    "k": &encoded_key
                }
            ]
        });
        let set: JwkSet = serde_json::from_value(jwks_json).expect("Failed HS256 check");
        assert_eq!(set.keys.len(), 1);

        let jwk = &set.keys[0];
        assert_eq!(jwk.common.key_id, Some("abc123".to_string()));
        let algorithm = jwk
            .common
            .key_algorithm
            .expect("alg is present in the fixture")
            .to_algorithm()
            .expect("HS256 maps to a crate Algorithm");
        assert_eq!(algorithm, Algorithm::HS256);

        match &jwk.algorithm {
            AlgorithmParameters::OctetKey(params) => {
                assert_eq!(params.key_type, OctetKeyType::Octet);
                assert_eq!(params.value, encoded_key);
            }
            _ => panic!("Unexpected key algorithm"),
        }

        assert_eq!(set.find("abc123"), Some(jwk));
        assert!(set.find("missing").is_none());
    }
}