use crate::control::security::util::base64_url_decode;
/// A public key extracted from one JWKS entry, ready for signature checks.
#[derive(Clone, Debug)]
pub struct VerificationKey {
    /// The `kid` of the JWK this key came from (empty when the JWK had none).
    pub kid: String,
    /// JWS algorithm name associated with the key, e.g. `"RS256"` or `"ES256"`.
    pub algorithm: String,
    /// Key material, tagged by key family / curve.
    pub key_type: KeyType,
}

/// Encoded public-key material for each supported key family.
#[derive(Clone, Debug)]
pub enum KeyType {
    /// DER-encoded RSA public key. `parse_rsa_jwk` stores a PKCS#1
    /// `RSAPublicKey`; the RSA verifier also accepts an SPKI structure.
    Rsa(Vec<u8>),
    /// SEC1-encoded NIST P-256 public point (`parse_ec_jwk` stores the
    /// uncompressed `0x04 || x || y` form).
    EcP256(Vec<u8>),
    /// SEC1-encoded NIST P-384 public point (`parse_ec_jwk` stores the
    /// uncompressed `0x04 || x || y` form).
    EcP384(Vec<u8>),
}
/// Top-level JWKS document: a JSON object whose `keys` member is an array of JWKs.
#[derive(serde::Deserialize, Debug)]
pub struct JwksResponse {
    /// Every entry of the `keys` array, in document order.
    pub keys: Vec<JwkEntry>,
}
/// A single JSON Web Key as it appears in a JWKS `keys` array.
///
/// `kty` must be present for deserialization to succeed; every other member
/// falls back to an empty string when absent.
#[derive(serde::Deserialize, Debug)]
pub struct JwkEntry {
    /// Key type; `parse_jwk` handles `"RSA"` and `"EC"`.
    pub kty: String,
    /// Key identifier.
    #[serde(default)]
    pub kid: String,
    /// Intended JWS algorithm, if declared.
    #[serde(default)]
    pub alg: String,
    /// The JSON `use` member (renamed because `use` is a Rust keyword).
    #[serde(rename = "use", default)]
    pub key_use: String,
    /// RSA modulus, base64url-encoded.
    #[serde(default)]
    pub n: String,
    /// RSA public exponent, base64url-encoded.
    #[serde(default)]
    pub e: String,
    /// EC curve name, e.g. `"P-256"`.
    #[serde(default)]
    pub crv: String,
    /// EC x coordinate, base64url-encoded.
    #[serde(default)]
    pub x: String,
    /// EC y coordinate, base64url-encoded.
    #[serde(default)]
    pub y: String,
}
pub fn parse_jwk(entry: &JwkEntry) -> Option<VerificationKey> {
if !entry.key_use.is_empty() && entry.key_use != "sig" {
return None;
}
match entry.kty.as_str() {
"RSA" => parse_rsa_jwk(entry),
"EC" => parse_ec_jwk(entry),
_ => None,
}
}
/// Builds an RSA [`VerificationKey`] from the JWK `n` (modulus) and `e`
/// (public exponent) members.
///
/// Both members are base64url-decoded and re-encoded as a DER PKCS#1
/// `RSAPublicKey` (`SEQUENCE { INTEGER n, INTEGER e }`). When the entry has
/// no `alg`, `"RS256"` is assumed.
///
/// Returns `None` in three cases:
/// - either member fails to decode;
/// - either member decodes to a zero value;
/// - either member is absent. Serde defaults it to an empty string, which
///   is either undecodable or zero.
///
/// No usable RSA key can have a zero modulus or exponent.
fn parse_rsa_jwk(entry: &JwkEntry) -> Option<VerificationKey> {
    let n_bytes = base64_url_decode(&entry.n)?;
    let e_bytes = base64_url_decode(&entry.e)?;
    // Reject empty / all-zero members up front instead of producing a key
    // that can never verify anything.
    if n_bytes.iter().all(|&b| b == 0) || e_bytes.iter().all(|&b| b == 0) {
        return None;
    }

    let n_der = to_der_integer(&n_bytes);
    let e_der = to_der_integer(&e_bytes);
    let content_len = n_der.len() + e_der.len();
    let mut pkcs1_der = Vec::with_capacity(content_len + 8);
    pkcs1_der.push(0x30); // DER SEQUENCE tag
    encode_der_length(&mut pkcs1_der, content_len);
    pkcs1_der.extend_from_slice(&n_der);
    pkcs1_der.extend_from_slice(&e_der);

    let algorithm = if entry.alg.is_empty() {
        "RS256".to_string()
    } else {
        entry.alg.clone()
    };
    Some(VerificationKey {
        kid: entry.kid.clone(),
        algorithm,
        key_type: KeyType::Rsa(pkcs1_der),
    })
}
fn parse_ec_jwk(entry: &JwkEntry) -> Option<VerificationKey> {
let x_bytes = base64_url_decode(&entry.x)?;
let y_bytes = base64_url_decode(&entry.y)?;
let mut point = Vec::with_capacity(1 + x_bytes.len() + y_bytes.len());
point.push(0x04);
point.extend_from_slice(&x_bytes);
point.extend_from_slice(&y_bytes);
let (key_type, default_alg) = match entry.crv.as_str() {
"P-256" => (KeyType::EcP256(point), "ES256"),
"P-384" => (KeyType::EcP384(point), "ES384"),
_ => return None,
};
let alg = if entry.alg.is_empty() {
default_alg.to_string()
} else {
entry.alg.clone()
};
Some(VerificationKey {
kid: entry.kid.clone(),
algorithm: alg,
key_type,
})
}
/// Verifies `signature` over `signing_input` with `key`.
///
/// RSA keys use RSASSA-PKCS1-v1_5 with the SHA-2 hash named by
/// `key.algorithm`:
/// - `RS256`, `RS384` or `RS512` select SHA-256, SHA-384 or SHA-512;
/// - an empty algorithm is treated as `RS256`, matching the default that
///   `parse_rsa_jwk` assigns;
/// - any other algorithm on an RSA key is rejected.
///
/// EC keys are checked by the `p256` / `p384` ECDSA `Verifier`, which hashes
/// the message itself.
///
/// Returns `true` only for a valid signature. Any decode or verification
/// failure yields `false`.
pub fn verify_signature(key: &VerificationKey, signing_input: &[u8], signature: &[u8]) -> bool {
    match &key.key_type {
        KeyType::Rsa(der) => verify_rsa_pkcs1v15(der, &key.algorithm, signing_input, signature),
        KeyType::EcP256(point) => verify_ec_p256(point, signing_input, signature),
        KeyType::EcP384(point) => verify_ec_p384(point, signing_input, signature),
    }
}

/// RSASSA-PKCS1-v1_5 verification for the JWS `RS*` algorithm `algorithm`.
///
/// `der` may be a PKCS#1 `RSAPublicKey` or an SPKI `SubjectPublicKeyInfo`;
/// both encodings are tried.
fn verify_rsa_pkcs1v15(der: &[u8], algorithm: &str, message: &[u8], signature: &[u8]) -> bool {
    use rsa::Pkcs1v15Sign;
    use sha2::Digest;

    // Pick the hash from the key's declared algorithm. Unknown algorithms
    // (e.g. PS256, RSA-OAEP) are refused instead of falling back to SHA-256.
    let (scheme, digest) = match algorithm {
        "" | "RS256" => (
            Pkcs1v15Sign::new::<sha2::Sha256>(),
            sha2::Sha256::digest(message).to_vec(),
        ),
        "RS384" => (
            Pkcs1v15Sign::new::<sha2::Sha384>(),
            sha2::Sha384::digest(message).to_vec(),
        ),
        "RS512" => (
            Pkcs1v15Sign::new::<sha2::Sha512>(),
            sha2::Sha512::digest(message).to_vec(),
        ),
        _ => return false,
    };

    let decoded = <rsa::RsaPublicKey as rsa::pkcs1::DecodeRsaPublicKey>::from_pkcs1_der(der)
        .ok()
        .or_else(|| {
            <rsa::RsaPublicKey as rsa::pkcs8::DecodePublicKey>::from_public_key_der(der).ok()
        });
    let rsa_key = match decoded {
        Some(key) => key,
        None => return false,
    };
    rsa_key.verify(scheme, &digest, signature).is_ok()
}
/// Verifies an ECDSA P-256 signature over `message`.
///
/// `sec1_point` is a SEC1-encoded public point. `signature` is the
/// fixed-width `r || s` form accepted by `Signature::from_slice`. Any decode
/// failure, or a signature that does not verify, yields `false`.
fn verify_ec_p256(sec1_point: &[u8], message: &[u8], signature: &[u8]) -> bool {
    use p256::EncodedPoint;
    use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};

    let verifying_key = EncodedPoint::from_bytes(sec1_point)
        .ok()
        .and_then(|encoded| VerifyingKey::from_encoded_point(&encoded).ok());
    let parsed_signature = Signature::from_slice(signature).ok();

    match (verifying_key, parsed_signature) {
        (Some(vk), Some(sig)) => vk.verify(message, &sig).is_ok(),
        _ => false,
    }
}
/// Verifies an ECDSA P-384 signature over `message`.
///
/// `sec1_point` is a SEC1-encoded public point. `signature` is the
/// fixed-width `r || s` form accepted by `Signature::from_slice`. Any decode
/// failure, or a signature that does not verify, yields `false`.
fn verify_ec_p384(sec1_point: &[u8], message: &[u8], signature: &[u8]) -> bool {
    use p384::EncodedPoint;
    use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier};

    // `?` inside the closure short-circuits every decode failure to `None`.
    let check = || -> Option<bool> {
        let encoded = EncodedPoint::from_bytes(sec1_point).ok()?;
        let verifying_key = VerifyingKey::from_encoded_point(&encoded).ok()?;
        let parsed_signature = Signature::from_slice(signature).ok()?;
        Some(verifying_key.verify(message, &parsed_signature).is_ok())
    };
    check().unwrap_or(false)
}
/// Encodes a big-endian unsigned integer as a DER `INTEGER` (tag `0x02`).
///
/// DER requires the minimal two's-complement encoding, so this function:
/// - strips leading zero bytes;
/// - prepends a single `0x00` when the first remaining byte has its top bit
///   set, since the value would otherwise read as negative;
/// - encodes a zero value (empty or all-zero input) as one `0x00` content
///   byte, because a zero-length INTEGER is invalid.
fn to_der_integer(bytes: &[u8]) -> Vec<u8> {
    let start = bytes.iter().position(|&b| b != 0).unwrap_or(bytes.len());
    let magnitude = &bytes[start..];
    let needs_pad = match magnitude.first() {
        None => true,
        Some(&first) => first & 0x80 != 0,
    };
    let content_len = magnitude.len() + usize::from(needs_pad);

    let mut out = Vec::with_capacity(content_len + 10);
    out.push(0x02);
    encode_der_length(&mut out, content_len);
    if needs_pad {
        out.push(0x00);
    }
    out.extend_from_slice(magnitude);
    out
}

/// Appends a DER definite-form length to `out`.
///
/// Lengths below 128 use the one-byte short form. Larger lengths use the
/// long form: `0x80 | k`, followed by the `k` significant big-endian bytes
/// of `len`.
fn encode_der_length(out: &mut Vec<u8>, len: usize) {
    if len < 0x80 {
        out.push(len as u8);
        return;
    }
    let be = len.to_be_bytes();
    let skip = be.iter().take_while(|&&b| b == 0).count();
    let significant = &be[skip..];
    // At most size_of::<usize>() bytes, so the count always fits in 7 bits.
    out.push(0x80 | significant.len() as u8);
    out.extend_from_slice(significant);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An entry with every member empty; each test overrides only what it needs.
    fn blank_entry() -> JwkEntry {
        JwkEntry {
            kty: String::new(),
            kid: String::new(),
            alg: String::new(),
            key_use: String::new(),
            n: String::new(),
            e: String::new(),
            crv: String::new(),
            x: String::new(),
            y: String::new(),
        }
    }

    #[test]
    fn parse_rsa_jwk_entry() {
        let modulus = "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw";
        let parsed = parse_jwk(&JwkEntry {
            kty: "RSA".to_string(),
            kid: "rsa-key-1".to_string(),
            alg: "RS256".to_string(),
            key_use: "sig".to_string(),
            n: modulus.to_string(),
            e: "AQAB".to_string(),
            ..blank_entry()
        });
        let key = parsed.expect("well-formed RSA JWK should parse");
        assert_eq!(key.kid, "rsa-key-1");
        assert_eq!(key.algorithm, "RS256");
        assert!(matches!(key.key_type, KeyType::Rsa(_)));
    }

    #[test]
    fn parse_ec_p256_jwk_entry() {
        let parsed = parse_jwk(&JwkEntry {
            kty: "EC".to_string(),
            kid: "ec-key-1".to_string(),
            alg: "ES256".to_string(),
            key_use: "sig".to_string(),
            crv: "P-256".to_string(),
            x: "f83OJ3D2xF1Bg8vub9tLe1gHMzV76e8Tus9uPHvRVEU".to_string(),
            y: "x_FEzRu9m36HLN_tue659LNpXW6pCyStikYjKIWI5a0".to_string(),
            ..blank_entry()
        });
        let key = parsed.expect("well-formed P-256 JWK should parse");
        assert_eq!(key.kid, "ec-key-1");
        assert_eq!(key.algorithm, "ES256");
        assert!(matches!(key.key_type, KeyType::EcP256(_)));
    }

    #[test]
    fn skip_encryption_keys() {
        let encryption_key = JwkEntry {
            kty: "RSA".to_string(),
            kid: "enc-key".to_string(),
            alg: "RSA-OAEP".to_string(),
            key_use: "enc".to_string(),
            n: "AQAB".to_string(),
            e: "AQAB".to_string(),
            ..blank_entry()
        };
        assert!(parse_jwk(&encryption_key).is_none());
    }

    #[test]
    fn skip_unknown_curve() {
        let ed_key = JwkEntry {
            kty: "EC".to_string(),
            kid: "ed-key".to_string(),
            alg: "EdDSA".to_string(),
            key_use: "sig".to_string(),
            crv: "Ed25519".to_string(),
            x: "abc".to_string(),
            y: "def".to_string(),
            ..blank_entry()
        };
        assert!(parse_jwk(&ed_key).is_none());
    }
}