use crate::errors::*;
use openssl::ecdsa::EcdsaSig;
use std::fmt::{self};
use std::sync::RwLock;
use std::time::{Duration, SystemTime};
use url::Url;
extern crate base64;
extern crate openssl;
extern crate serde;
extern crate serde_json;
use openssl::bn::BigNum;
use openssl::ec::{EcGroup, EcKey};
use openssl::hash::{hash, MessageDigest};
use openssl::nid::Nid;
use openssl::pkey::Id;
use openssl::pkey::{PKey, Public};
use openssl::rsa::Rsa;
use openssl::sign::Verifier;
/// Default fraction of the HTTP `Cache-Control: max-age` lifetime after which a keyset
/// loaded by `KeyStore::load_keys` is due for reloading (see `KeyStore::set_reload_factor`).
pub const RELOAD_INTERVAL_FACTOR: f64 = 0.75;
/// A public key usable for signature verification, together with the JWK metadata it was
/// built from (key id, key type, curve and algorithm).
// `crv` is stored but never read in this module, hence the `dead_code` allowance.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct BBKey {
/// The OpenSSL public key itself.
key: PKey<Public>,
/// Key id (`kid`); matched by `KeyStore::key_by_id`.
kid: Option<String>,
/// Key type (`kty`); selects the verifier in `BBKey::verifier`.
kty: KeyType,
/// Curve of EC/OKP keys, if known.
crv: Option<EcCurve>,
/// Signature algorithm dispatched on by `BBKey::verify_signature`.
alg: KeyAlgorithm,
}
/// JWK key type (`kty` member).
///
/// Variant names are deserialized verbatim (`"RSA"`, `"EC"`, `"OKP"`); any other value maps
/// to `Unsupported`.
#[derive(Clone, Debug, Deserialize)]
// Upper-case variant names mirror the `kty` strings serde matches against.
#[allow(non_camel_case_types)]
pub enum KeyType {
RSA,
EC,
OKP,
#[serde(other)]
Unsupported,
}
/// JWS signature algorithm (`alg` member).
///
/// Variant names are deserialized verbatim; unrecognized algorithm names map to `Other`.
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
pub enum KeyAlgorithm {
RS256,
RS384,
RS512,
ES256,
ES384,
ES512,
EdDSA,
#[serde(other)]
Other,
}
/// Curve (`crv` member) of an EC or OKP key.
///
/// The NIST curves are deserialized from their JOSE names (`"P-256"`, `"P-384"`, `"P-521"`),
/// the others from their variant names.
// NOTE(review): there is no `#[serde(other)]` variant, so a JWK with an unrecognized `crv`
// (e.g. an encryption-only curve) fails to deserialize, which also fails parsing of the whole
// enclosing JWKS. Confirm this strictness is intended.
#[derive(Clone, Debug, Deserialize)]
pub enum EcCurve {
#[serde(rename = "P-256")]
P256,
#[serde(rename = "secp256k1")]
SECP256K1,
#[serde(rename = "P-384")]
P384,
#[serde(rename = "P-521")]
P521,
Ed25519,
Ed448,
}
/// A single JSON Web Key as deserialized from JSON. Only the members this module uses are
/// modelled; string-valued key material (`n`, `e`, `x`, `y`) is base64url-encoded.
// NOTE(review): every field appears to be read by `pubkey_from_jwk`; this allowance may be
// unnecessary.
#[allow(dead_code)]
#[derive(Clone, Debug, Deserialize)]
pub struct JWK {
/// Key type.
pub kty: KeyType,
/// Optional signature algorithm.
pub alg: Option<KeyAlgorithm>,
/// Optional key id.
pub kid: Option<String>,
/// RSA modulus.
pub n: Option<String>,
/// RSA public exponent.
pub e: Option<String>,
/// EC/OKP curve.
pub crv: Option<EcCurve>,
/// EC x coordinate, or the raw public key for OKP keys.
pub x: Option<String>,
/// EC y coordinate.
pub y: Option<String>,
}
/// A JSON Web Key Set: the top-level object with a `keys` array.
#[derive(Clone, Debug, Deserialize)]
pub struct JWKS {
/// All keys in the set, in document order.
pub keys: Vec<JWK>,
}
/// A collection of verification keys, optionally bound to a remote JWKS URL from which it
/// can be (re)loaded.
#[derive(Debug)]
pub struct KeyStore {
/// The keys, behind a lock so they can be read and replaced through `&self`/`&mut self`.
keyset: RwLock<Vec<BBKey>>,
/// JWKS URL used by `load_keys`; `None` for stores filled manually.
url: Option<String>,
/// Time of the first successful `load_keys` call.
load_time: Option<SystemTime>,
/// Fraction of the `Cache-Control: max-age` lifetime after which to reload.
reload_factor: f64,
/// Point in time from which `should_reload` reports `Some(true)`.
reload_time: Option<SystemTime>,
}
impl JWKS {
pub fn new() -> Self {
JWKS { keys: vec![] }
}
}
impl Default for JWKS {
fn default() -> Self {
Self::new()
}
}
impl KeyAlgorithm {
    /// Returns the hash function paired with this algorithm.
    ///
    /// `None` for `EdDSA` and `Other`.
    pub fn message_digest(&self) -> Option<MessageDigest> {
        let digest = match self {
            Self::RS256 | Self::ES256 => MessageDigest::sha256(),
            Self::RS384 | Self::ES384 => MessageDigest::sha384(),
            Self::RS512 | Self::ES512 => MessageDigest::sha512(),
            Self::EdDSA | Self::Other => return None,
        };
        Some(digest)
    }

    /// Expected byte length of a raw `r || s` ECDSA signature for this algorithm.
    ///
    /// Returns `0` for every non-ECDSA algorithm.
    pub fn signature_length(&self) -> usize {
        match self {
            Self::ES256 => 64,
            Self::ES384 => 96,
            Self::ES512 => 132,
            Self::RS256 | Self::RS384 | Self::RS512 | Self::EdDSA | Self::Other => 0,
        }
    }
}
impl Default for KeyAlgorithm {
fn default() -> Self {
KeyAlgorithm::RS256
}
}
impl EcCurve {
    /// Hash function conventionally paired with this NIST curve.
    ///
    /// `None` for `SECP256K1`, `Ed25519` and `Ed448`.
    pub fn message_digest(&self) -> Option<MessageDigest> {
        let digest = match self {
            Self::P256 => MessageDigest::sha256(),
            Self::P384 => MessageDigest::sha384(),
            Self::P521 => MessageDigest::sha512(),
            Self::SECP256K1 | Self::Ed25519 | Self::Ed448 => return None,
        };
        Some(digest)
    }

    /// OpenSSL curve identifier for the Weierstrass curves.
    ///
    /// `None` for the Edwards curves `Ed25519` and `Ed448`.
    pub fn nid(&self) -> Option<Nid> {
        let nid = match self {
            Self::SECP256K1 => Nid::SECP256K1,
            Self::P256 => Nid::X9_62_PRIME256V1,
            Self::P384 => Nid::SECP384R1,
            Self::P521 => Nid::SECP521R1,
            Self::Ed25519 | Self::Ed448 => return None,
        };
        Some(nid)
    }
}
impl fmt::Display for BBKey {
    /// Renders the key as its `kid`, or `<no_kid>` when it has none.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Borrow instead of cloning the id just to print it.
        f.write_str(self.kid.as_deref().unwrap_or("<no_kid>"))
    }
}
impl BBKey {
    /// Verifies `signature` over `payload` with this key, using the key's algorithm.
    ///
    /// The ECDSA signature is expected as raw `r || s` bytes of
    /// [`KeyAlgorithm::signature_length`] length.
    ///
    /// # Returns
    /// `Ok(true)` if the signature is valid.
    ///
    /// # Errors
    /// * `BBError::SignatureInvalid` if the signature does not verify. This applies uniformly
    ///   to RSA, ECDSA and EdDSA keys, so `?` alone is enough to reject forged signatures.
    /// * `BBError::Other` / `BBError::DecodeError` for OpenSSL failures or an unsupported
    ///   algorithm.
    pub fn verify_signature(&self, payload: &[u8], signature: &[u8]) -> BBResult<bool> {
        let valid = match self.alg {
            KeyAlgorithm::RS256 | KeyAlgorithm::RS384 | KeyAlgorithm::RS512 => {
                let mut verifier = self.verifier()?;
                verifier
                    .update(payload)
                    .map_err(|e| BBError::DecodeError(format!("{:?}", e)))?;
                verifier
                    .verify(signature)
                    .map_err(|e| BBError::Other(format!("Failed to check RSA signature: {:?}", e)))?
            }
            KeyAlgorithm::ES256 | KeyAlgorithm::ES384 | KeyAlgorithm::ES512 => {
                self.verify_ecdsa(payload, signature)?
            }
            KeyAlgorithm::EdDSA => {
                let mut verifier = Verifier::new_without_digest(&self.key)
                    .map_err(|e| BBError::Other(format!("Cannot get verifier for EdDSA: {}", e)))?;
                verifier
                    .verify_oneshot(signature, payload)
                    .map_err(|e| BBError::Other(format!("Failed to verify EdDSA signature: {}", e)))?
            }
            KeyAlgorithm::Other => {
                return Err(BBError::Other(format!(
                    "Unsupported key algorithm for key '{}'",
                    self
                )));
            }
        };
        // Report a mismatch the same way for every algorithm family. Previously EC/EdDSA
        // returned `Ok(false)` while RSA returned an error, which a caller using only `?`
        // could mistake for success.
        if valid {
            Ok(true)
        } else {
            Err(BBError::SignatureInvalid())
        }
    }

    /// Checks a raw `r || s` ECDSA signature over `payload`.
    ///
    /// Returns `Ok(false)` for a well-formed signature that does not verify.
    /// Returns `Err(SignatureInvalid)` when the length does not match the algorithm.
    fn verify_ecdsa(&self, payload: &[u8], signature: &[u8]) -> BBResult<bool> {
        let ec_key = self.key.ec_key().map_err(|e| {
            BBError::Other(format!("Failed to extract EC key from public key: {:?}", e))
        })?;
        // r and s are fixed-width halves of the signature, so the total length is fixed
        // by the algorithm.
        if signature.len() != self.alg.signature_length() {
            return Err(BBError::SignatureInvalid());
        }
        let (r_bytes, s_bytes) = signature.split_at(signature.len() / 2);
        let r = BigNum::from_slice(r_bytes)
            .map_err(|e| BBError::Other(format!("Bignum error: {}", e)))?;
        let s = BigNum::from_slice(s_bytes)
            .map_err(|e| BBError::Other(format!("Bignum error: {}", e)))?;
        // Despite its name, `from_private_components` just assembles a signature from r and s.
        let sig = EcdsaSig::from_private_components(r, s)
            .map_err(|e| BBError::Other(format!("Could not create Ecdsa Signature: {}", e)))?;
        let digest = self
            .alg
            .message_digest()
            .ok_or_else(|| BBError::Other("Unknown algorithm digest".to_string()))?;
        let hash = hash(digest, payload)
            .map_err(|e| BBError::Other(format!("Failed to hash payload: {}", e)))?;
        sig.verify(&hash, &ec_key)
            .map_err(|e| BBError::Other(format!("Failed to verify EC signature: {}", e)))
    }

    /// Creates an OpenSSL [`Verifier`] for this key, selected by key type.
    ///
    /// RSA keys get a digest-based verifier using the algorithm's digest; OKP keys get a
    /// verifier without a digest.
    ///
    /// # Errors
    /// `BBError::Other` for EC keys (verified via ECDSA instead), unsupported key types,
    /// a missing digest, or an OpenSSL failure.
    pub fn verifier(&self) -> BBResult<Verifier> {
        let verifier = match self.kty {
            KeyType::RSA => {
                let message_digest = self.alg.message_digest().ok_or_else(|| {
                    BBError::Other(format!("Failed to get message digest for key '{}'.", &self))
                })?;
                Verifier::new(message_digest, &self.key).map_err(|e| {
                    BBError::Other(format!(
                        "Failed to create verifier for RSA key '{}': {:?}",
                        &self, e
                    ))
                })?
            }
            KeyType::EC => {
                return Err(BBError::Other("EC key has no verifier".to_string()));
            }
            KeyType::OKP => Verifier::new_without_digest(&self.key).map_err(|e| {
                BBError::Other(format!(
                    "Failed to create verifier for Ed key '{}': {:?}",
                    &self, e
                ))
            })?,
            KeyType::Unsupported => {
                return Err(BBError::Other(format!(
                    "Unsupported key type for key '{}'",
                    &self
                )));
            }
        };
        Ok(verifier)
    }
}
/// Base64 settings used for JWK members: URL-safe alphabet without padding.
///
/// Non-zero trailing bits in the final symbol are tolerated instead of rejected.
pub fn base64_config() -> base64::Config {
    let config = base64::URL_SAFE_NO_PAD;
    config.decode_allow_trailing_bits(true)
}
/// Decodes a base64url string (see [`base64_config`]) into a big-endian [`BigNum`].
///
/// `error_context` names the decoded value (e.g. `"RSA n"`) in error messages.
///
/// # Errors
/// `BBError::DecodeError` if decoding or number construction fails.
fn bignum_from_base64(b64: &str, error_context: &str) -> BBResult<BigNum> {
    let decoded = match base64::decode_config(b64, base64_config()) {
        Ok(bytes) => bytes,
        Err(e) => return Err(BBError::DecodeError(format!("{error_context}: '{:?}'", e))),
    };
    BigNum::from_slice(&decoded).map_err(|e| {
        BBError::DecodeError(format!(
            "Failed to create number from b64 string ({error_context}): {}",
            e
        ))
    })
}
fn pubkey_from_jwk(jwk: &JWK) -> BBResult<BBKey> {
let kid = if jwk.kid.is_some() {
jwk.kid.as_ref().unwrap()
} else {
"<no_kid>"
};
let key = match jwk.kty {
KeyType::EC => {
let nid = if jwk.crv.is_some() {
jwk.crv.as_ref().unwrap().nid()
} else {
None
}
.ok_or_else(|| {
BBError::JWKInvalid(format!(
"Missing or unsupported 'crv' field for EC key '{kid}'"
))
})?;
let group = EcGroup::from_curve_name(nid).map_err(|e| {
BBError::JWKInvalid(format!(
"Cannot create EcGroup from nid {:?} for key {kid}: {}",
nid, e
))
})?;
if jwk.x.is_none() || jwk.y.is_none() {
return Err(BBError::JWKInvalid(format!(
"Missing x or y for EC key '{kid}'"
)));
}
let x = bignum_from_base64(jwk.x.as_ref().unwrap(), "EC x")?;
let y = bignum_from_base64(jwk.y.as_ref().unwrap(), "EC y")?;
let ec_key = EcKey::from_public_key_affine_coordinates(&group, &x, &y)
.map_err(|e| BBError::JWKInvalid(format!("Failed to create EcKey for {kid}': {}", e)))?;
PKey::from_ec_key(ec_key)
.map_err(|e| BBError::JWKInvalid(format!("Failed to create PKey/EC for {kid}': {}", e)))?
}
KeyType::RSA => {
if jwk.n.is_none() || jwk.e.is_none() {
return Err(BBError::JWKInvalid(format!(
"Missing n or e for RSA key '{kid}'"
)));
}
let n = bignum_from_base64(jwk.n.as_ref().unwrap(), "RSA n")?;
let e = bignum_from_base64(jwk.e.as_ref().unwrap(), "RSA e")?;
let rsa_key = Rsa::from_public_components(n, e)
.map_err(|e| BBError::JWKInvalid(format!("Failed to create RSA key from {kid}: {}", e)))?;
PKey::from_rsa(rsa_key)
.map_err(|e| BBError::JWKInvalid(format!("Failed to create PKey/RSA from {kid}: {}", e)))?
}
KeyType::OKP => {
if jwk.x.is_none() {
return Err(BBError::JWKInvalid(format!(
"Missing x for OKP key '{kid}'"
)));
}
let bytes = base64::decode_config(jwk.x.as_ref().unwrap(), base64_config())
.map_err(|e| BBError::DecodeError(format!("Failed to decode x for {kid}: {}", e)))?;
let curve_id = match jwk.crv {
Some(EcCurve::Ed25519) => Id::ED25519,
Some(EcCurve::Ed448) => Id::ED448,
None => Id::ED25519,
_ => {
return Err(BBError::JWKInvalid(format!(
"Invalid curve for OKP key {kid}"
)));
}
};
PKey::public_key_from_raw_bytes(&bytes, curve_id)
.map_err(|e| BBError::JWKInvalid(format!("Failed to read EdDSA key for {kid}: {}", e)))?
}
_ => {
return Err(BBError::JWKInvalid(format!(
"Unsupported keytype for {kid}"
)));
}
};
Ok(BBKey {
kid: jwk.kid.clone(),
key,
kty: jwk.kty.clone(),
crv: jwk.crv.clone(),
alg: jwk.alg.clone().unwrap_or_default(),
})
}
#[allow(dead_code)]
impl KeyStore {
pub async fn new() -> BBResult<Self> {
Ok(KeyStore {
keyset: RwLock::new(Vec::new()),
url: None,
load_time: None,
reload_factor: RELOAD_INTERVAL_FACTOR,
reload_time: None,
})
}
pub async fn new_from_url(surl: &str) -> BBResult<Self> {
let url = Url::parse(surl)
.map_err(|e| BBError::URLInvalid(format!("Invalid keyset URL '{surl}: {:?}", e)))?;
let host = url
.host_str()
.ok_or_else(|| BBError::URLInvalid(format!("No host in keyset URL '{surl}")))?;
if !["localhost", "127.0.0.1"].contains(&host) && url.scheme() != "https" {
return Err(BBError::URLInvalid(
"Public keysets must be loaded via https.".to_string(),
));
}
let mut ks = KeyStore {
keyset: RwLock::new(Vec::new()),
url: Some(url.to_string()),
load_time: None,
reload_factor: RELOAD_INTERVAL_FACTOR,
reload_time: None,
};
ks.load_keys().await?;
Ok(ks)
}
pub fn keyset(&self) -> BBResult<Vec<BBKey>> {
if let Ok(keyset) = self.keyset.read() {
Ok(keyset.clone())
} else {
Err(BBError::Fatal("Keyset lock is poisoned".to_string()))
}
}
pub fn keys_len(&self) -> usize {
if let Ok(keyset) = self.keyset.read() {
keyset.len()
} else {
0
}
}
pub fn add_key(&mut self, key_json: &str) -> BBResult<()> {
let key: JWK = serde_json::from_str(key_json)
.map_err(|e| BBError::Other(format!("Failed to parse key JSON: {:?}", e)))?;
let mut keyset = self
.keyset
.write()
.map_err(|e| BBError::Other(format!("Failed to get write lock on keyset: {:?}", e)))?;
keyset.push(pubkey_from_jwk(&key)?);
Ok(())
}
pub fn add_rsa_pem_key(&self, pem: &str, kid: Option<&str>, alg: KeyAlgorithm) -> BBResult<()> {
let rsa = openssl::rsa::Rsa::public_key_from_pem(pem.as_bytes())
.map_err(|e| BBError::Other(format!("Could not read RSA pem: {:?}", e)))?;
let bbkey = BBKey {
kid: kid.map(|v| v.to_string()),
key: PKey::from_rsa(rsa)
.map_err(|e| BBError::JWKInvalid(format!("Failed to create PKey/RSA from PEM: {}", e)))?,
kty: KeyType::RSA,
crv: None,
alg,
};
let mut keyset = self
.keyset
.write()
.map_err(|e| BBError::Other(format!("Failed to get write lock on keyset: {:?}", e)))?;
keyset.push(bbkey);
Ok(())
}
pub fn add_ec_pem_key(
&self,
pem: &str,
kid: Option<&str>,
curve: EcCurve,
alg: KeyAlgorithm,
) -> BBResult<()> {
let key = PKey::public_key_from_pem(pem.as_bytes())
.map_err(|e| BBError::Other(format!("Failed to read PEM EdDSA pub key: {}", e)))?;
let kty = match alg {
KeyAlgorithm::ES256 | KeyAlgorithm::ES384 | KeyAlgorithm::ES512 => KeyType::EC,
KeyAlgorithm::EdDSA => KeyType::OKP,
_ => {
return Err(BBError::Other("Invalid algorithm for ec key".to_string()));
}
};
let bbkey = BBKey {
kid: kid.map(|v| v.to_string()),
key,
kty,
alg,
crv: Some(curve),
};
let mut keyset = self
.keyset
.write()
.map_err(|e| BBError::Other(format!("Failed to get write lock on keyset: {:?}", e)))?;
keyset.push(bbkey);
Ok(())
}
pub fn key_by_id(&self, kid: Option<&str>) -> BBResult<BBKey> {
let keyset = self
.keyset
.read()
.map_err(|_e| BBError::Fatal("The keyset lock is poisoned".to_string()))?;
let key = if let Some(kid) = kid {
let key = keyset.iter().find(|k: &&BBKey| {
if let Some(this_kid) = &k.kid {
this_kid.eq(kid)
} else {
false
}
});
key.ok_or_else(|| BBError::Other(format!("Could not find kid '{kid}' in keyset.")))?
} else {
keyset
.first()
.ok_or_else(|| BBError::Other("No keys in keyset".to_string()))?
};
Ok(key.clone())
}
pub fn set_reload_factor(&mut self, interval: f64) {
self.reload_factor = interval;
}
pub fn reload_factor(&self) -> f64 {
self.reload_factor
}
pub fn load_time(&self) -> Option<SystemTime> {
self.load_time
}
pub fn reload_time(&self) -> Option<SystemTime> {
self.reload_time
}
pub fn should_reload_time(&self, time: SystemTime) -> Option<bool> {
self.reload_time.map(|reload_time| reload_time <= time)
}
pub fn should_reload(&self) -> Option<bool> {
self.should_reload_time(SystemTime::now())
}
#[allow(clippy::await_holding_lock)]
pub async fn load_keys(&mut self) -> BBResult<()> {
let url = self
.url
.clone()
.ok_or_else(|| BBError::Other("No load URL for keyset provided.".to_string()))?;
let mut keys = self
.keyset
.write()
.map_err(|e| BBError::Fatal(format!("Keyset write lock is poisoned: {}", e)))?;
keys.clear();
drop(keys);
let mut response = reqwest::get(&url)
.await
.map_err(|e| BBError::Other(format!("Failed to load IdP keyset: {:?}", e)))?;
let lifetime = KeyStore::get_key_expiration_time(&mut response);
let json = response
.text()
.await
.map_err(|e| BBError::NetworkError(format!("Failed to load public key set: {:?}", e)))?;
let keyset: JWKS = serde_json::from_str(&json)
.map_err(|e| BBError::Other(format!("Failed to parse IdP public key set: {:?}", e)))?;
let mut keys = self
.keyset
.write()
.map_err(|e| BBError::Fatal(format!("Keyset write lock is poisoned: {}", e)))?;
for key in keyset.keys {
keys.push(pubkey_from_jwk(&key)?);
}
let load_time = SystemTime::now();
if let Ok(value) = lifetime {
let seconds: u64 = (value as f64 * self.reload_factor) as u64;
self.reload_time = Some(load_time + Duration::new(seconds, 0));
}
if self.load_time.is_none() {
self.load_time = Some(load_time);
}
Ok(())
}
fn get_key_expiration_time(response: &mut reqwest::Response) -> Result<u64, ()> {
let header = response.headers().get("cache-control").ok_or(())?;
let cache_control = header.to_str().map_err(|_| ())?.to_lowercase();
assigned_header_value(&cache_control, "max-age")
}
pub async fn idp_certs_url(idp_discovery_url: &str) -> BBResult<String> {
let info_json = reqwest::get(idp_discovery_url)
.await
.map_err(|e| {
BBError::NetworkError(format!(
"Failed to load IdP discovery info JSON from {idp_discovery_url}: {:?}",
e
))
})?
.text()
.await
.map_err(|e| {
BBError::NetworkError(format!("Failed to get IdP discovery info JSON: {:?}", e))
})?;
let info: serde_json::Value = serde_json::from_str(&info_json).map_err(|e| {
BBError::Other(format!(
"Invalid JSON from IdP discovery info url '{idp_discovery_url}': {:?}",
e
))
})?;
if let serde_json::Value::String(jwks_uri) = &info["jwks_uri"] {
Ok(jwks_uri.to_string())
} else {
Err(BBError::Other(
"No jwks_uri in IdP discovery info found".to_string(),
))
}
}
pub fn keycloak_discovery_url(host: &str, realm: &str) -> BBResult<String> {
let mut info_url = Url::parse(host).map_err(|e| {
BBError::Other(format!(
"Invalid base URL for Keycloak discovery endpoint: {:?}",
e
))
})?;
info_url
.path_segments_mut()
.map_err(|_| BBError::Other(format!("Invalid IdP URL '{host}'")))?
.push("realms")
.push(realm)
.push(".well-known")
.push("openid-configuration");
Ok(info_url.to_string())
}
}
/// Parses the numeric value of the directive `name` in a header such as
/// `Cache-Control: public, max-age=3600`.
///
/// Matching rules:
/// * The directive must start at a token boundary: the start of the string, whitespace,
///   or `,`. This means `max-age` does not match inside `s-max-age`.
/// * Only optional whitespace may separate the name, `=` and the value.
/// * The value may be wrapped in a leading double quote.
/// * Parsing stops at the first non-ASCII-digit character, e.g. `22-dude` yields 22.
///
/// If several occurrences match, the first that parses wins.
///
/// Returns `Err(())` if no occurrence carries a parseable `u64` value.
fn assigned_header_value(hdr_value: &str, name: &str) -> Result<u64, ()> {
    hdr_value
        .match_indices(name)
        .filter(|&(start, _)| match hdr_value[..start].chars().next_back() {
            None => true,
            Some(c) => c == ',' || c.is_whitespace(),
        })
        .find_map(|(start, _)| parse_assigned_number(&hdr_value[start + name.len()..]))
        .ok_or(())
}

/// Parses `[ws] '=' [ws] ['"'] digits` from the text following a directive name.
fn parse_assigned_number(rest: &str) -> Option<u64> {
    let rest = rest.trim_start().strip_prefix('=')?.trim_start();
    let rest = rest.strip_prefix('"').unwrap_or(rest);
    let digits_end = rest
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(rest.len());
    // An empty or overflowing digit run fails to parse and yields `None`.
    rest[..digits_end].parse().ok()
}
#[cfg(test)]
mod tests {
use super::*;
use rand::seq::SliceRandom;
use std::env;
use std::fs::File;
use std::io::Read;
use std::path::Path;
/// Absolute path of `tests/assets/<asset_name>` inside the crate (via `CARGO_MANIFEST_DIR`).
pub fn path_to_asset_file(asset_name: &str) -> String {
let path = Path::new(
env::var("CARGO_MANIFEST_DIR")
.expect("CARGO_MANIFEST_DIR not set")
.as_str(),
)
.join(format!("tests/assets/{asset_name}"));
String::from(path.to_str().unwrap())
}
/// The discovery URL is built from host + realm path segments.
#[test]
fn test_keycloak_discovery_url() {
let url = KeyStore::keycloak_discovery_url("https://host.tld", "testing");
assert_eq!(
url.unwrap(),
"https://host.tld/realms/testing/.well-known/openid-configuration"
)
}
/// `assigned_header_value` tolerates spaces around `=` and stops at the first non-digit.
#[test]
fn test_header_value_parser() {
let test_strings = vec![
"oriuehgueohgeor depp = 3485975dd",
"depp=1,fellow",
"depp = 22-dude",
"r depp=12345678",
"xu depp=666",
];
let results: Vec<u64> = vec![3485975, 1, 22, 12345678, 666];
for i in 0..test_strings.len() {
assert!(assigned_header_value(test_strings[i], "depp").unwrap() == results[i]);
}
assert!(assigned_header_value("orihgeorgohoho", "name").is_err());
}
/// Adds 20 copies of `tests/assets/pubkey.json`, each with a distinct `kid`, then looks
/// keys up with and without a `kid`.
// Requires `tests/assets/pubkey.json` to contain the literal kid being replaced below.
#[tokio::test]
async fn test_keystore_local() {
let mut ks = KeyStore::new()
.await
.expect("Failed to create empty keystore");
let key_json_file = path_to_asset_file("pubkey.json");
let mut file = File::open(key_json_file).expect("Failed to open pubkey.json");
let mut data = String::new();
file.read_to_string(&mut data).unwrap();
for i in 1..21 {
ks.add_key(&data.replace(
"nOo3ZDrODXEK1jKWhXslHR_KXEg",
format!("bbjwt-test-{i}").as_str(),
))
.expect("Failed to add key to keystore");
}
assert_eq!(ks.keys_len(), 20);
let key1 = ks.key_by_id(None).expect("Failed to get key just added");
assert!(key1.kid.unwrap() == "bbjwt-test-1");
let k = ks
.key_by_id(Some("bbjwt-test-17"))
.expect("Failed to get key by ID");
assert_eq!(k.kid.unwrap(), "bbjwt-test-17");
}
/// A plain-http URL for a non-local host is rejected before any network access; the
/// error text mentions https.
#[tokio::test]
async fn insecure_keyset_load() {
let ret = KeyStore::new_from_url(
"http://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration",
)
.await;
assert!(format!("{:?}", ret).contains("https"));
}
/// End-to-end load of Microsoft's public keyset.
// Needs network access to login.microsoftonline.com; also relies on the response carrying
// a `Cache-Control: max-age` header (reload_time must be set) and keys having a `kid`.
#[tokio::test]
async fn test_load_keys() {
let url = "https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration";
let ks_url = KeyStore::idp_certs_url(url)
.await
.expect("Failed to get keyset URL");
let ks = KeyStore::new_from_url(&ks_url)
.await
.expect("Failed to load keystore");
assert!(ks.load_time.is_some());
assert!(ks.reload_time.is_some());
assert!(ks.reload_time.unwrap() > ks.load_time.unwrap());
println!("KeyStore: {:?}", ks);
assert!(ks.keys_len() > 0);
let keyset = ks.keyset().unwrap();
// Look up a randomly chosen key by its kid.
let key = keyset
.choose(&mut rand::thread_rng())
.expect("Failed to get random key from keyset");
let kid = key
.kid
.clone()
.expect("No kid in key; not an error, but spoils this test...");
let k = ks.key_by_id(Some(&kid)).expect("Failed to get key by id");
assert_eq!(k.kid.expect("Missing kid"), kid);
// Without a kid, `key_by_id` must return the first key of the set.
let k1 = ks
.keyset()
.unwrap()
.first()
.expect("Failed to get first key")
.clone();
let k = ks.key_by_id(None).expect("No key returned without kid");
assert_eq!(k.kid.unwrap().as_str(), k1.kid.unwrap().as_str());
}
}