use anyhow::{Context, Result};
use jsonwebtoken::DecodingKey;
use serde::Deserialize;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Jwk {
kid: String,
n: String,
e: String,
alg: String,
kty: String,
#[serde(rename = "use")]
use_: String,
}
#[derive(Debug, Deserialize)]
struct JwkSet {
keys: Vec<Jwk>,
}
#[derive(Clone)]
pub struct JwksStore {
url: String,
keys: Arc<RwLock<HashMap<String, DecodingKey>>>,
}
impl JwksStore {
pub async fn new(url: String) -> Result<Self> {
let keys = Self::fetch_keys(&url).await?;
Ok(Self {
url,
keys: Arc::new(RwLock::new(keys)),
})
}
pub async fn refresh(&self) -> Result<()> {
let keys = Self::fetch_keys(&self.url).await?;
let mut w = self.keys.write().await;
*w = keys;
Ok(())
}
async fn fetch_keys(url: &str) -> Result<HashMap<String, DecodingKey>> {
let body = reqwest::get(url).await?.text().await?;
let jwks: JwkSet = serde_json::from_str(&body)?;
let mut map = HashMap::new();
for key in jwks.keys {
if key.kty == "RSA" && key.use_ == "sig" {
let decoding_key = DecodingKey::from_rsa_components(&key.n, &key.e)?;
map.insert(key.kid.clone(), decoding_key);
}
}
Ok(map)
}
pub async fn get_key(&self, kid: &str) -> Option<DecodingKey> {
let r = self.keys.read().await;
r.get(kid).cloned()
}
}