use std::collections::{BTreeMap, HashMap};
use crate::{
ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey, EcdsaPublicKey},
eddsa::{Ed25519PrivateKey, Ed25519PublicKey},
rsa::{RsaAlgorithm, RsaPrivateKey, RsaPublicKey},
some::SomePublicKey,
url_safe_trailing_bits, verify, verify_only, Error, Header, HeaderAndClaims, PublicKeyToJwk,
Result, SigningKey, SomePrivateKey, VerificationKey,
};
use openssl::{
bn::BigNum,
hash::{hash, MessageDigest},
pkey::PKey,
rsa::{Rsa, RsaPrivateKeyBuilder},
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value;
/// A JSON Web Key (RFC 7517) using the member names of RFC 7518 / RFC 8037.
///
/// Deserialization is lenient: every member other than `kty` is optional, and
/// the members are only checked when the key is converted into a usable key.
///
/// `Debug` is implemented by hand so that private-key members are never printed.
#[non_exhaustive]
#[derive(Clone, Deserialize, Serialize, Default)]
pub struct Jwk {
    /// Key type (`kty`), e.g. `"RSA"`, `"EC"` or `"OKP"`.
    pub kty: String,
    /// Intended public key use; serialized as the JSON member `use`.
    #[serde(rename = "use", skip_serializing_if = "Option::is_none")]
    pub use_: Option<String>,
    /// Permitted key operations (`key_ops`); empty means "not specified".
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub key_ops: Vec<String>,
    /// Algorithm name (`alg`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alg: Option<String>,
    /// Curve name (`crv`) for `EC` and `OKP` keys.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub crv: Option<String>,
    /// Key id (`kid`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kid: Option<String>,
    /// RSA modulus (base64url).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<String>,
    /// RSA public exponent (base64url).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub e: Option<String>,
    /// `x` coordinate (`EC`) or public key bytes (`OKP`), base64url.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x: Option<String>,
    /// `y` coordinate of an `EC` public key (base64url).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y: Option<String>,
    /// Private key material (base64url); secret.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub d: Option<String>,
    /// First RSA prime factor (base64url); secret.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p: Option<String>,
    /// Second RSA prime factor (base64url); secret.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub q: Option<String>,
    /// RSA CRT exponent of the first factor (base64url); secret.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dp: Option<String>,
    /// RSA CRT exponent of the second factor (base64url); secret.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dq: Option<String>,
    /// RSA CRT coefficient (base64url); secret.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub qi: Option<String>,
    /// Additional prime information for multi-prime RSA keys; secret.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub oth: Vec<Value>,
}

impl std::fmt::Debug for Jwk {
    /// Formats like a derived `Debug`, except that the private-key members
    /// (`d`, `p`, `q`, `dp`, `dq`, `qi`, `oth`) only reveal whether they are
    /// present, so a private JWK can be logged without leaking secrets.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn redact(member: &Option<String>) -> Option<&'static str> {
            member.as_ref().map(|_| "<redacted>")
        }
        let oth = if self.oth.is_empty() {
            None
        } else {
            Some("<redacted>")
        };
        f.debug_struct("Jwk")
            .field("kty", &self.kty)
            .field("use_", &self.use_)
            .field("key_ops", &self.key_ops)
            .field("alg", &self.alg)
            .field("crv", &self.crv)
            .field("kid", &self.kid)
            .field("n", &self.n)
            .field("e", &self.e)
            .field("x", &self.x)
            .field("y", &self.y)
            .field("d", &redact(&self.d))
            .field("p", &redact(&self.p))
            .field("q", &redact(&self.q))
            .field("dp", &redact(&self.dp))
            .field("dq", &redact(&self.dq))
            .field("qi", &redact(&self.qi))
            .field("oth", &oth)
            .finish()
    }
}
impl Jwk {
pub fn to_verification_key(&self) -> Result<SomePublicKey> {
if !matches!(self.use_.as_deref(), None | Some("sig")) {
return Err(Error::UnsupportedOrInvalidKey);
}
if !(self.key_ops.is_empty() || self.key_ops.iter().any(|ops| ops == "verify")) {
return Err(Error::UnsupportedOrInvalidKey);
}
#[allow(clippy::single_match)]
match &*self.kty {
"RSA" => match (self.alg.as_deref(), &self.n, &self.e) {
(alg, Some(ref n), Some(ref e)) => {
let n = base64::decode_config(n, url_safe_trailing_bits())?;
let e = base64::decode_config(e, url_safe_trailing_bits())?;
let alg = if let Some(alg) = alg {
Some(RsaAlgorithm::from_name(alg)?)
} else {
None
};
return Ok(SomePublicKey::Rsa(RsaPublicKey::from_components(
&n, &e, alg,
)?));
}
_ => {}
},
"EC" => match (self.crv.as_deref(), &self.x, &self.y) {
(Some(crv), Some(ref x), Some(ref y)) => {
let x = base64::decode_config(x, url_safe_trailing_bits())?;
let y = base64::decode_config(y, url_safe_trailing_bits())?;
let alg = EcdsaAlgorithm::from_curve_name(crv)?;
return Ok(SomePublicKey::Ecdsa(EcdsaPublicKey::from_coordinates(
&x, &y, alg,
)?));
}
_ => {}
},
"OKP" => match (self.crv.as_deref(), &self.x) {
(Some(crv), Some(ref x)) => {
let x = base64::decode_config(x, url_safe_trailing_bits())?;
match crv {
"Ed25519" => {
return Ok(SomePublicKey::Ed25519(Ed25519PublicKey::from_bytes(&x)?));
}
_ => {}
}
}
_ => {}
},
_ => {}
}
Err(Error::UnsupportedOrInvalidKey)
}
#[allow(clippy::many_single_char_names)]
pub fn to_signing_key(&self, rsa_fallback_algorithm: RsaAlgorithm) -> Result<SomePrivateKey> {
match &*self.kty {
"RSA" => {
let alg = if let Some(ref alg) = self.alg {
RsaAlgorithm::from_name(alg)?
} else {
rsa_fallback_algorithm
};
match (self.d.as_deref(), self.n.as_deref(), self.e.as_deref()) {
(Some(d), Some(n), Some(e)) => {
fn decode(x: &str) -> Result<BigNum> {
Ok(BigNum::from_slice(&base64::decode_config(
x,
url_safe_trailing_bits(),
)?)?)
}
let d = decode(d)?;
let n = decode(n)?;
let e = decode(e)?;
match (
self.p.as_deref(),
self.q.as_deref(),
self.dp.as_deref(),
self.dq.as_deref(),
self.qi.as_deref(),
self.oth.is_empty(),
) {
(None, None, None, None, None, true) => {
let rsa = RsaPrivateKeyBuilder::new(n, e, d)?.build();
let pkey = PKey::from_rsa(rsa)?;
RsaPrivateKey::from_pkey_without_check(pkey, alg).map(Into::into)
}
(Some(p), Some(q), Some(dp), Some(dq), Some(qi), true) => {
let p = decode(p)?;
let q = decode(q)?;
let dp = decode(dp)?;
let dq = decode(dq)?;
let qi = decode(qi)?;
let rsa = Rsa::from_private_components(n, e, d, p, q, dp, dq, qi)?;
let pkey = PKey::from_rsa(rsa)?;
RsaPrivateKey::from_pkey(pkey, alg).map(Into::into)
}
_ => Err(Error::UnsupportedOrInvalidKey),
}
}
_ => Err(Error::UnsupportedOrInvalidKey),
}
}
"EC" => {
match (
self.crv.as_deref(),
self.d.as_deref(),
self.x.as_deref(),
self.y.as_deref(),
) {
(Some(crv), Some(d), Some(x), Some(y)) => {
let alg = EcdsaAlgorithm::from_curve_name(crv)?;
let d = base64::decode_config(d, url_safe_trailing_bits())?;
let x = base64::decode_config(x, url_safe_trailing_bits())?;
let y = base64::decode_config(y, url_safe_trailing_bits())?;
EcdsaPrivateKey::from_private_components(alg, &d, &x, &y).map(Into::into)
}
_ => Err(Error::UnsupportedOrInvalidKey),
}
}
"OKP" => match (self.crv.as_deref(), self.d.as_deref()) {
(Some("Ed25519"), Some(d)) => {
let d = base64::decode_config(d, url_safe_trailing_bits())?;
Ed25519PrivateKey::from_bytes(&d).map(Into::into)
}
_ => Err(Error::UnsupportedOrInvalidKey),
},
_ => Err(Error::UnsupportedOrInvalidKey),
}
}
pub fn get_thumbprint_sha256(&self) -> Result<[u8; 32]> {
let as_json = match &*self.kty {
"RSA" => {
let mut v = BTreeMap::new();
v.insert(
"e",
self.e.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
);
v.insert("kty", "RSA");
v.insert(
"n",
self.n.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
);
serde_json::to_string(&v)?
}
"EC" => {
let mut v = BTreeMap::new();
v.insert(
"crv",
self.crv.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
);
v.insert("kty", "EC");
v.insert(
"x",
self.x.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
);
v.insert(
"y",
self.y.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
);
serde_json::to_string(&v)?
}
"OKP" => {
let mut v = BTreeMap::new();
v.insert(
"crv",
self.crv.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
);
v.insert("kty", "OKP");
v.insert(
"x",
self.x.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
);
serde_json::to_string(&v)?
}
_ => return Err(Error::UnsupportedOrInvalidKey),
};
let hash = hash(MessageDigest::sha256(), as_json.as_bytes())?;
let mut out = [0u8; 32];
out.copy_from_slice(&hash[..]);
Ok(out)
}
pub fn get_thumbprint_sha256_base64(&self) -> Result<String> {
Ok(base64::encode_config(
self.get_thumbprint_sha256()?,
url_safe_trailing_bits(),
))
}
}
/// A JWK Set: a JSON object whose `keys` member lists the keys.
#[derive(Serialize, Deserialize, Debug)]
pub struct JwkSet {
    /// The keys of the set, in document order.
    pub keys: Vec<Jwk>,
}
impl JwkSet {
pub fn verifier(&self) -> JwkSetVerifier {
let mut prepared = JwkSetVerifier {
keys: HashMap::new(),
require_kid: true,
};
for k in self.keys.iter() {
if let Some(ref kid) = k.kid {
if let Ok(vk) = k.to_verification_key() {
prepared.keys.insert(kid.clone(), vk);
}
}
}
prepared
}
}
/// A set of verification keys indexed by key id, ready to verify tokens.
///
/// Obtained from `JwkSet::verifier`.
pub struct JwkSetVerifier {
    /// Verification keys keyed by their `kid`.
    keys: HashMap<String, SomePublicKey>,
    /// When `true`, tokens without a `kid` header are rejected outright.
    require_kid: bool,
}
impl JwkSetVerifier {
pub fn set_require_kid(&mut self, required: bool) {
self.require_kid = required;
}
pub fn find(&self, kid: &str) -> Option<&SomePublicKey> {
if let Some(vk) = self.keys.get(kid) {
Some(vk)
} else {
None
}
}
pub fn verify<ExtraClaims: DeserializeOwned>(
&self,
token: &str,
) -> Result<HeaderAndClaims<ExtraClaims>> {
self.find_and_verify(token, verify)
}
pub fn verify_only<ExtraClaims: DeserializeOwned>(
&self,
token: &str,
) -> Result<HeaderAndClaims<ExtraClaims>> {
self.find_and_verify(token, verify_only)
}
fn find_and_verify<ExtraClaims: DeserializeOwned>(
&self,
token: &str,
verifier: fn(&str, &dyn VerificationKey) -> Result<HeaderAndClaims<ExtraClaims>>,
) -> Result<HeaderAndClaims<ExtraClaims>> {
let mut parts = token.split('.');
let mut header = parts.next().ok_or(Error::InvalidToken)?.as_bytes();
let header_r = base64::read::DecoderReader::new(&mut header, url_safe_trailing_bits());
let header: Header = serde_json::from_reader(header_r)?;
if let Some(kid) = header.kid {
let k = self.find(&kid).ok_or(Error::NoKey)?;
verifier(token, k)
} else if !self.require_kid {
if let Some(res) = self
.keys
.iter()
.map(|(_, key)| verifier(token, key))
.find_map(|res| res.ok())
{
Ok(res)
} else {
Err(Error::NoKey)
}
} else {
Err(Error::NoKey)
}
}
}
/// Wraps a key `inner` and attaches a key id (`kid`) to it.
///
/// The wrapper forwards signing/verification to `inner` and reports or embeds
/// its own `kid` (see the trait impls for `WithKid`).
#[derive(Debug)]
pub struct WithKid<S> {
// Key id reported for the wrapped key.
kid: String,
// The wrapped key.
inner: S,
}
impl<S> WithKid<S> {
pub fn new(kid: String, inner: S) -> Self {
Self { kid, inner }
}
pub fn new_with_thumbprint_id(inner: S) -> Result<Self>
where
S: PublicKeyToJwk,
{
Ok(Self {
kid: inner.public_key_to_jwk()?.get_thumbprint_sha256_base64()?,
inner,
})
}
pub fn kid(&self) -> &str {
&self.kid
}
pub fn set_kid(&mut self, kid: impl Into<String>) {
self.kid = kid.into();
}
pub fn as_inner(&self) -> &S {
&self.inner
}
pub fn into_inner(self) -> S {
self.inner
}
pub fn as_inner_mut(&mut self) -> &mut S {
&mut self.inner
}
}
impl<S: SigningKey> SigningKey for WithKid<S> {
fn kid(&self) -> Option<&str> {
Some(&self.kid)
}
fn sign(&self, v: &[u8]) -> Result<smallvec::SmallVec<[u8; 64]>> {
self.inner.sign(v)
}
fn alg(&self) -> &'static str {
self.inner.alg()
}
}
impl<S: VerificationKey> VerificationKey for WithKid<S> {
fn verify(&self, v: &[u8], sig: &[u8], alg: &str) -> Result<()> {
self.inner.verify(v, sig, alg)
}
}
/// Exports the wrapped key's public JWK with `kid` set to the wrapper's key id.
impl<K: PublicKeyToJwk> PublicKeyToJwk for WithKid<K> {
    fn public_key_to_jwk(&self) -> Result<Jwk> {
        self.inner.public_key_to_jwk().map(|mut jwk| {
            jwk.kid = Some(self.kid.clone());
            jwk
        })
    }
}
/// A fetched key set together with the instant after which it must be re-fetched.
#[cfg(feature = "remote-jwks")]
struct JwksCache {
    jwks: JwkSetVerifier,
    /// `None` when "fetch time + cache duration" is not representable as an
    /// `Instant` (e.g. `Duration::MAX`); such an entry never expires.
    valid_until: Option<std::time::Instant>,
}

#[cfg(feature = "remote-jwks")]
impl JwksCache {
    /// Returns whether `slot` holds a key set that may still be served now.
    fn is_fresh(slot: &Option<Self>) -> bool {
        slot.as_ref().map_or(false, |c| match c.valid_until {
            Some(until) => std::time::Instant::now() <= until,
            None => true,
        })
    }

    /// Projects a populated cache slot onto its verifier.
    ///
    /// Only used right after `is_fresh` returned `true` or a value was stored,
    /// so the slot is always `Some` here.
    fn verifier(slot: &Option<Self>) -> &JwkSetVerifier {
        &slot
            .as_ref()
            .expect("JWKS cache slot is populated before being projected")
            .jwks
    }
}

/// Verifies tokens against a JWK set fetched over HTTP and cached for a fixed time.
#[cfg(feature = "remote-jwks")]
pub struct RemoteJwksVerifier {
    url: String,
    client: reqwest::Client,
    cache_duration: std::time::Duration,
    cache: tokio::sync::RwLock<Option<JwksCache>>,
    require_kid: bool,
}

#[cfg(feature = "remote-jwks")]
impl RemoteJwksVerifier {
    /// Creates a verifier for the JWK set served at `url`.
    ///
    /// `client` defaults to `reqwest::Client::default()`. A fetched set is
    /// reused for `cache_duration`. A duration too large to add to the current
    /// `Instant` keeps the set cached indefinitely. Tokens must carry a `kid`
    /// by default.
    pub fn new(
        url: String,
        client: Option<reqwest::Client>,
        cache_duration: std::time::Duration,
    ) -> Self {
        Self {
            url,
            client: client.unwrap_or_default(),
            cache_duration,
            cache: tokio::sync::RwLock::new(None),
            require_kid: true,
        }
    }

    /// Controls whether tokens must carry a `kid` header; also applied to an
    /// already-cached key set.
    pub fn set_require_kid(&mut self, required: bool) {
        self.require_kid = required;
        if let Some(cached) = self.cache.get_mut() {
            cached.jwks.require_kid = required;
        }
    }

    /// Returns the cached verifier, fetching the JWK set first if the cache is
    /// empty or expired.
    ///
    /// # Errors
    ///
    /// Propagates HTTP errors (including non-success status codes) and JSON
    /// decoding errors from the fetch.
    async fn get_verifier(&self) -> Result<tokio::sync::RwLockReadGuard<'_, JwkSetVerifier>> {
        // Fast path: serve a fresh cached set under the shared lock.
        let cache = self.cache.read().await;
        if JwksCache::is_fresh(&cache) {
            return Ok(tokio::sync::RwLockReadGuard::map(cache, JwksCache::verifier));
        }
        drop(cache);

        let mut cache = self.cache.write().await;
        // Another task may have refreshed the set while we waited for the write lock.
        if JwksCache::is_fresh(&cache) {
            return Ok(tokio::sync::RwLockReadGuard::map(
                cache.downgrade(),
                JwksCache::verifier,
            ));
        }

        // The write lock is held across the request, so concurrent callers wait
        // for this refresh instead of issuing their own.
        let jwks: JwkSet = self
            .client
            .get(&self.url)
            .header("accept", "application/json")
            .send()
            .await?
            // Report HTTP failures as such rather than as a JSON error on the error body.
            .error_for_status()?
            .json()
            .await?;
        let mut verifier = jwks.verifier();
        verifier.require_kid = self.require_kid;
        *cache = Some(JwksCache {
            jwks: verifier,
            // `checked_add` instead of `+`: the latter panics on overflow.
            valid_until: std::time::Instant::now().checked_add(self.cache_duration),
        });
        Ok(tokio::sync::RwLockReadGuard::map(
            cache.downgrade(),
            JwksCache::verifier,
        ))
    }

    /// Verifies `token` with `JwkSetVerifier::verify` against the (possibly
    /// refreshed) remote key set.
    pub async fn verify<E: DeserializeOwned>(&self, token: &str) -> Result<HeaderAndClaims<E>> {
        let v = self.get_verifier().await?;
        v.verify(token)
    }

    /// Verifies `token` with `JwkSetVerifier::verify_only` against the
    /// (possibly refreshed) remote key set.
    pub async fn verify_only<E: DeserializeOwned>(
        &self,
        token: &str,
    ) -> Result<HeaderAndClaims<E>> {
        let v = self.get_verifier().await?;
        v.verify_only(token)
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey},
        eddsa::Ed25519PrivateKey,
        rsa::RsaPrivateKey,
        sign,
    };

    use super::*;

    #[test]
    fn test_jwk() -> Result<()> {
        assert!(Jwk {
            kty: "RSA".to_string(),
            use_: Some("enc".into()),
            ..Default::default()
        }
        .to_verification_key()
        .is_err());
        assert!(Jwk {
            kty: "RSA".to_string(),
            key_ops: vec!["encryption".into()],
            ..Default::default()
        }
        .to_verification_key()
        .is_err());
        Ok(())
    }

    #[test]
    fn test_thumbprint() -> Result<()> {
        RsaPrivateKey::generate(2048, RsaAlgorithm::RS256)?
            .public_key_to_jwk()?
            .get_thumbprint_sha256_base64()?;
        EcdsaPrivateKey::generate(EcdsaAlgorithm::ES256)?
            .public_key_to_jwk()?
            .get_thumbprint_sha256_base64()?;
        Ed25519PrivateKey::generate()?
            .public_key_to_jwk()?
            .get_thumbprint_sha256_base64()?;
        Ok(())
    }

    /// Known-answer test from RFC 8037, Appendix A.3. The extra `kid`/`use`
    /// members must not influence the thumbprint.
    #[test]
    fn test_thumbprint_rfc8037_ed25519() -> Result<()> {
        let jwk = Jwk {
            kty: "OKP".to_string(),
            crv: Some("Ed25519".into()),
            x: Some("11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo".into()),
            kid: Some("not part of the thumbprint".into()),
            use_: Some("sig".into()),
            ..Default::default()
        };
        assert_eq!(
            jwk.get_thumbprint_sha256_base64()?,
            "kPrK_qmxVWaYVA9wwBF6Iuo3vVzz7TxHCTwXBygrS4k"
        );
        jwk.to_verification_key()?;
        Ok(())
    }

    #[derive(Serialize, Deserialize)]
    struct MyClaim {
        foo: String,
    }

    #[test]
    fn test_jwks_verify() -> Result<()> {
        let k = EcdsaPrivateKey::generate(EcdsaAlgorithm::ES512)?;
        let kk = WithKid::new("my key".into(), k.clone());
        let k_jwk = kk.public_key_to_jwk()?;
        let jwks = JwkSet { keys: vec![k_jwk] };
        let mut verifier = jwks.verifier();
        {
            let mut jwt = HeaderAndClaims::with_claims(MyClaim { foo: "bar".into() });
            jwt.set_kid("my key");
            let token = sign(&mut jwt, &k)?;
            verifier.verify_only::<MyClaim>(&token)?;
            let verified = verifier.verify::<MyClaim>(&token)?;
            assert_eq!(verified.claims.extra.foo, "bar");
        }
        {
            let mut jwt = HeaderAndClaims::with_claims(MyClaim { foo: "bar".into() });
            jwt.set_kid("my key2");
            let token = sign(&mut jwt, &k)?;
            let res = verifier.verify_only::<MyClaim>(&token);
            assert!(res.is_err());
        }
        {
            let mut jwt = HeaderAndClaims::with_claims(MyClaim { foo: "bar".into() });
            jwt.set_kid("my key2");
            let token = sign(&mut jwt, &kk)?;
            verifier.verify_only::<MyClaim>(&token)?;
            let verified = verifier.verify::<MyClaim>(&token)?;
            assert_eq!(verified.claims.extra.foo, "bar");
        }
        {
            let token = sign(
                &mut HeaderAndClaims::with_claims(MyClaim { foo: "bar".into() }),
                &k,
            )?;
            let res = verifier.verify_only::<MyClaim>(&token);
            assert!(res.is_err());
        }
        {
            let token = sign(
                &mut HeaderAndClaims::with_claims(MyClaim { foo: "bar".into() }),
                &k,
            )?;
            verifier.set_require_kid(false);
            verifier.verify::<MyClaim>(&token)?;
            let verified = verifier.verify_only::<MyClaim>(&token)?;
            assert_eq!(verified.claims.extra.foo, "bar");
        }
        Ok(())
    }
}