use std::{collections::BTreeMap, hash::Hash, marker::PhantomData};
use base64ct::Encoding;
use serde::{
de,
ser::{self, SerializeMap},
Deserialize, Serialize,
};
use signature::Keypair;
/// Errors produced while parsing, building, or converting JSON Web Keys.
#[derive(Debug, thiserror::Error)]
pub enum JsonWebKeyError {
/// A `serde_json` error, reported transparently (the message and source
/// are those of the wrapped error). Created automatically via `From`.
#[error(transparent)]
JSON(#[from] serde_json::Error),
/// The key type carried by a JWK and the key type of the Rust key disagree.
#[error("key type mismatch: expected {expected}, got {got}")]
KeyType {
/// The key type that was expected.
expected: String,
/// The key type that was actually encountered.
got: String,
},
/// A JWK parameter needed to build the key was absent; the payload is the
/// name of the missing parameter.
#[error("missing expected jwk parameter {0}")]
MissingParameter(&'static str),
/// The key material was rejected. The first field names the algorithm and
/// the second is the underlying error, exposed as this error's source.
#[error("invalid key for algorithm {0}: {1}")]
InvalidKey(
&'static str,
#[source] Box<dyn std::error::Error + Send + Sync>,
),
}
/// Associates a Rust key type with its JWK key type (`kty`) at compile time.
pub trait JWKeyType {
    /// The key type string reported for every value of the implementing type.
    const KEY_TYPE: &'static str;
}

/// A shared reference reports the same key type as the value it points to.
impl<T: JWKeyType> JWKeyType for &T {
    const KEY_TYPE: &'static str = <T as JWKeyType>::KEY_TYPE;
}
pub trait DynJwkKeyType {
fn key_type(&self) -> &'static str;
}
impl<T> DynJwkKeyType for T
where
T: JWKeyType,
{
fn key_type(&self) -> &'static str {
T::KEY_TYPE
}
}
/// A key that can list its JWK parameters.
///
/// `JsonWebKey::build` collects the returned pairs into the JWK's parameter
/// map and takes the key type separately from `key_type()`.
pub trait SerializeJWK: DynJwkKeyType {
/// Returns the key's JWK parameters as `(name, value)` pairs.
fn parameters(&self) -> Vec<(String, serde_json::Value)>;
}
/// A key that can list the JWK parameters of its public part.
///
/// `JsonWebKey::build_public` collects the returned pairs into the JWK's
/// parameter map and takes the key type separately from `key_type()`.
pub trait SerializePublicJWK: DynJwkKeyType {
/// Returns the public JWK parameters as `(name, value)` pairs.
fn public_parameters(&self) -> Vec<(String, serde_json::Value)>;
}
impl<K> SerializePublicJWK for K
where
K: Keypair + DynJwkKeyType,
K::VerifyingKey: SerializeJWK,
{
fn public_parameters(&self) -> Vec<(String, serde_json::Value)> {
self.verifying_key().parameters()
}
}
/// A key that can be reconstructed from JWK parameters.
pub trait DeserializeJWK: DynJwkKeyType + Sized {
    /// Builds the key from a map of JWK parameters.
    ///
    /// # Errors
    ///
    /// Implementations report why the parameters could not be turned into a
    /// key.
    fn build(parameters: BTreeMap<String, serde_json::Value>) -> Result<Self, JsonWebKeyError>;

    /// Builds the key from a parsed [`JsonWebKey`].
    ///
    /// The JWK's parameters are copied and the JWK's key type is added under
    /// `"kty"` before they are handed to [`build`](Self::build).
    ///
    /// # Errors
    ///
    /// Returns whatever `build` returns, or [`JsonWebKeyError::KeyType`] when
    /// the built key's type differs from the JWK's key type.
    fn from_jwk(jwk: &JsonWebKey) -> Result<Self, JsonWebKeyError> {
        let declared = jwk.key_type();

        let mut members = jwk.parameters().clone();
        members.insert(String::from("kty"), serde_json::Value::from(declared));

        let key = Self::build(members)?;
        let actual = key.key_type();
        if actual != declared {
            return Err(JsonWebKeyError::KeyType {
                expected: actual.to_owned(),
                got: declared.to_owned(),
            });
        }
        Ok(key)
    }

    /// Builds the key from a JSON value holding a JWK.
    ///
    /// # Errors
    ///
    /// Fails when the value is not a valid JWK object, or for any reason
    /// [`from_jwk`](Self::from_jwk) fails.
    fn from_value(value: serde_json::Value) -> Result<Self, JsonWebKeyError> {
        let parsed = serde_json::from_value::<JsonWebKey>(value)?;
        Self::from_jwk(&parsed)
    }

    /// Builds the key from JSON text holding a JWK.
    ///
    /// # Errors
    ///
    /// Fails when the text is not a valid JWK object, or for any reason
    /// [`from_jwk`](Self::from_jwk) fails.
    fn from_str(s: &str) -> Result<Self, JsonWebKeyError> {
        let parsed = serde_json::from_str::<JsonWebKey>(s)?;
        Self::from_jwk(&parsed)
    }
}
/// Derives a value of the implementing type from a borrowed key.
///
/// In this file it is implemented for `JsonWebKey` and `Thumbprint`.
pub trait BuildFromKey<Key: ?Sized> {
/// Derives `Self` from `key`, or reports why that was not possible.
fn derive_from_key(key: &Key) -> Result<Self, JsonWebKeyError>
where
Self: Sized;
}
impl<Key> BuildFromKey<Key> for JsonWebKey
where
Key: SerializePublicJWK + ?Sized,
{
fn derive_from_key(key: &Key) -> Result<Self, JsonWebKeyError> {
Ok(JsonWebKey {
key_type: key.key_type().into(),
parameters: key.public_parameters().into_iter().collect(),
})
}
}
/// A JSON Web Key: a key type plus a map of named JSON parameters.
///
/// When deserialized, the `kty` member populates `key_type` and every other
/// member is gathered into `parameters`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct JsonWebKey {
// The JWK `kty` member.
#[serde(rename = "kty")]
key_type: String,
// All remaining members, kept sorted by name by the `BTreeMap`.
#[serde(flatten)]
parameters: BTreeMap<String, serde_json::Value>,
}
impl JsonWebKey {
    /// Creates a JWK from an explicit key type and parameter map.
    pub fn new(key_type: String, parameters: BTreeMap<String, serde_json::Value>) -> Self {
        Self {
            key_type,
            parameters,
        }
    }

    /// Creates a JWK from the key type and *public* parameters of `key`.
    pub fn build_public<K: SerializePublicJWK + ?Sized>(key: &K) -> Self {
        Self {
            key_type: key.key_type().into(),
            parameters: key.public_parameters().into_iter().collect(),
        }
    }

    /// Creates a JWK from the key type and parameters of `key`.
    pub fn build<K: SerializeJWK + ?Sized>(key: &K) -> Self {
        Self {
            key_type: key.key_type().into(),
            parameters: key.parameters().into_iter().collect(),
        }
    }

    /// Returns the key type (`kty`) of this JWK.
    pub fn key_type(&self) -> &str {
        &self.key_type
    }

    /// Returns every parameter of this JWK other than the key type.
    pub fn parameters(&self) -> &BTreeMap<String, serde_json::Value> {
        &self.parameters
    }

    /// Builds a key of type `K` from this JWK.
    ///
    /// A copy of the stored parameter map is passed to `K::build` unchanged
    /// (unlike `DeserializeJWK::from_jwk`, no `"kty"` entry is added).
    ///
    /// # Errors
    ///
    /// Returns whatever `K::build` returns, or [`JsonWebKeyError::KeyType`]
    /// when the built key's type differs from this JWK's key type. In that
    /// error `expected` is the key type of `K` and `got` is the key type this
    /// JWK carries, matching the orientation used by
    /// `DeserializeJWK::from_jwk`.
    pub fn deserialize_key<K: DeserializeJWK>(&self) -> Result<K, JsonWebKeyError> {
        let key = K::build(self.parameters.clone())?;
        if key.key_type() == self.key_type {
            Ok(key)
        } else {
            Err(JsonWebKeyError::KeyType {
                // The target key type dictates what is expected; the JWK is
                // what was actually supplied.
                expected: key.key_type().into(),
                got: self.key_type.clone(),
            })
        }
    }
}
/// Serializes the JWK as a single map whose entries are ordered by member
/// name, with the key type emitted under `"kty"`.
impl Serialize for JsonWebKey {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let key_type_value = serde_json::Value::String(self.key_type.clone());

        // Gather borrowed members into a sorted map. `kty` goes in last so it
        // replaces any parameter that happens to share that name.
        let mut members: BTreeMap<&str, &serde_json::Value> = BTreeMap::new();
        for (name, value) in &self.parameters {
            members.insert(name.as_str(), value);
        }
        members.insert("kty", &key_type_value);

        let mut state = serializer.serialize_map(Some(members.len()))?;
        for (name, value) in &members {
            state.serialize_entry(name, value)?;
        }
        state.end()
    }
}
/// A key thumbprint held as a string, tagged at the type level with the
/// digest it belongs to.
///
/// `Thumbprint::from_jwk` stores the unpadded base64url encoding of the
/// digest; `Thumbprint::new` stores whatever string it is given.
#[derive(Debug, zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
pub struct Thumbprint<Digest> {
// The thumbprint text.
thumbprint: String,
// Zero-sized marker recording the digest type; holds no data.
digest: PhantomData<Digest>,
}
impl<Digest> Clone for Thumbprint<Digest> {
fn clone(&self) -> Self {
Self {
thumbprint: self.thumbprint.clone(),
digest: PhantomData,
}
}
}
/// Hashes exactly like the thumbprint string; the digest marker contributes
/// nothing.
impl<Digest> Hash for Thumbprint<Digest> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        <String as Hash>::hash(&self.thumbprint, state)
    }
}
/// Two thumbprints of the same digest type are equal when their strings are.
impl<Digest> PartialEq for Thumbprint<Digest> {
    fn eq(&self, other: &Self) -> bool {
        // The `PhantomData` marker carries no runtime state and always
        // compares equal, so only the strings decide the result.
        self.thumbprint.as_str() == other.thumbprint.as_str()
    }
}
// `Thumbprint` equality reduces to `String` equality (the `PhantomData`
// marker always compares equal), so it is a full equivalence relation.
impl<Digest> Eq for Thumbprint<Digest> {}
impl<Digest> Thumbprint<Digest>
where
Digest: digest::Digest,
{
pub fn new(thumbprint: String) -> Self {
Self {
thumbprint,
digest: PhantomData,
}
}
pub fn from_jwk(jwk: &JsonWebKey) -> Result<Self, JsonWebKeyError> {
let thumb = serde_json::to_vec(&jwk)?;
let mut hasher = Digest::new();
hasher.update(&thumb);
let digest = hasher.finalize();
Ok(Self::new(base64ct::Base64UrlUnpadded::encode_string(
&digest,
)))
}
}
/// Derives the thumbprint of a key by way of its public JWK.
impl<Digest, Key> BuildFromKey<Key> for Thumbprint<Digest>
where
    Key: SerializePublicJWK + ?Sized,
    Digest: digest::Digest,
{
    /// Derives the public JWK of `key`, then computes that JWK's thumbprint.
    fn derive_from_key(key: &Key) -> Result<Thumbprint<Digest>, JsonWebKeyError> {
        JsonWebKey::derive_from_key(key).and_then(|jwk| Self::from_jwk(&jwk))
    }
}
impl<Digest> ser::Serialize for Thumbprint<Digest> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(self.thumbprint.as_str())
}
}
/// Serde visitor that turns a string into a `Thumbprint<D>`.
struct ThumbprintVisitor<D>(PhantomData<D>);

impl<'de, D> de::Visitor<'de> for ThumbprintVisitor<D>
where
    D: digest::Digest,
{
    type Value = Thumbprint<D>;

    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        formatter.write_str("thumbprint digest as base64url string")
    }

    /// Accepts any string, whatever its lifetime.
    ///
    /// Serde's default `visit_borrowed_str` and `visit_string` forward here,
    /// whereas the default `visit_str` rejects the input. Implementing only
    /// `visit_borrowed_str` would therefore fail for deserializers that hand
    /// out transient or owned strings.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(Thumbprint::new(v.to_owned()))
    }

    /// Takes ownership of an already-allocated string without copying it.
    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(Thumbprint::new(v))
    }
}
impl<'de, Digest> de::Deserialize<'de> for Thumbprint<Digest>
where
Digest: digest::Digest,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_str(ThumbprintVisitor(PhantomData))
}
}
/// Displays the thumbprint string verbatim.
impl<Digest> std::fmt::Display for Thumbprint<Digest> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = self.thumbprint.as_str();
        f.write_str(text)
    }
}
/// Lets a thumbprint be used wherever a `&str` is expected.
impl<Digest> std::ops::Deref for Thumbprint<Digest> {
    type Target = str;

    fn deref(&self) -> &str {
        self.thumbprint.as_str()
    }
}
#[cfg(test)]
mod test {
    use super::*;

    use static_assertions as sa;

    // `SerializeJWK` must remain usable as a trait object.
    sa::assert_obj_safe!(SerializeJWK);

    #[cfg(feature = "rsa")]
    mod rsa {
        use super::super::*;

        use serde_json::json;

        /// Checks the SHA-256 thumbprint of the RSA example key from
        /// RFC 7638 (JSON Web Key Thumbprint), section 3.1.
        #[test]
        fn rfc7638_example() {
            // NOTE(review): the "n" literal below spans several source lines,
            // so the string contains embedded line breaks; confirm the RSA
            // `build` implementation tolerates whitespace in base64url input.
            let key = rsa::RsaPublicKey::from_value(json!({
                "kty": "RSA",
                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAt
VT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn6
4tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FD
W2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n9
1CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINH
aQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
                "e": "AQAB",
                "alg": "RS256",
                "kid": "2011-04-29"
            }
            ))
            .unwrap();

            // Rebuilding the JWK from the parsed key keeps only the members
            // the key itself reports, so "alg" and "kid" are not hashed.
            let thumb: Thumbprint<sha2::Sha256> =
                Thumbprint::from_jwk(&JsonWebKey::build(&key)).unwrap();

            assert_eq!(&*thumb, "NzbLsXh8uDCcd-6MNwXF4W_7noWXFZAfHkxZsRGC9Xs");
        }
    }
}