use std::collections::HashMap;
use std::collections::HashSet;
use std::mem;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Duration;
use futures::future::Either;
use futures::future::join_all;
use futures::future::select;
use futures::pin_mut;
use futures::stream::repeat;
use futures::stream::select_all;
use http::header::ACCEPT;
use jsonwebtoken::Algorithm;
use jsonwebtoken::jwk::Jwk;
use jsonwebtoken::jwk::JwkSet;
use mime::APPLICATION_JSON;
use serde_json::Value;
use tokio::fs::read_to_string;
use tokio::sync::oneshot;
use tower::BoxError;
use tracing_futures::Instrument;
use url::Url;
use super::CLIENT;
use super::DEFAULT_AUTHENTICATION_NETWORK_TIMEOUT;
use super::Header;
/// Holds the configured JWKS sources and the most recently downloaded key set
/// for each of them.
///
/// Clones share `jwks_map` and the drop signal (both behind `Arc`), so every
/// clone observes the same key sets; `list` is cloned per instance.
#[derive(Clone)]
pub(super) struct JwksManager {
    // Configured JWKS sources, in configuration order.
    list: Vec<JwksConfig>,
    // Latest successfully downloaded key set per source URL; updated by the
    // polling task spawned in `JwksManager::new`.
    jwks_map: Arc<RwLock<HashMap<Url, JwkSet>>>,
    // Never sent on. When the last manager clone drops this sender, the polling
    // task's receiver resolves and the task exits.
    _drop_signal: Arc<oneshot::Sender<()>>,
}
/// Configuration for a single JWKS source.
#[derive(Clone)]
pub(super) struct JwksConfig {
    // Location of the key set; `file://` URLs are read from the local filesystem,
    // anything else is fetched over HTTP.
    pub(super) url: Url,
    // Issuer associated with this key set. Only passed through to `JwkSetInfo`
    // here; presumably checked against the token by the caller — confirm there.
    pub(super) issuer: Option<String>,
    // Optional set of allowed algorithms, passed through to `JwkSetInfo`.
    pub(super) algorithms: Option<HashSet<Algorithm>>,
    // Delay the polling task waits before each refresh of this source.
    pub(super) poll_interval: Duration,
    // Extra HTTP headers sent with the request (unused for `file://` URLs).
    pub(super) headers: Vec<Header>,
}
/// A downloaded key set together with the issuer and algorithm settings of the
/// source it was fetched from, as yielded by [`Iter`].
#[derive(Clone)]
pub(super) struct JwkSetInfo {
    // The downloaded key set.
    pub(super) jwks: JwkSet,
    // Copied from the originating `JwksConfig::issuer`.
    pub(super) issuer: Option<String>,
    // Copied from the originating `JwksConfig::algorithms`.
    pub(super) algorithms: Option<HashSet<Algorithm>>,
}
impl JwksManager {
    /// Downloads every configured JWKS once, then spawns a background task that
    /// periodically re-downloads them.
    ///
    /// The initial downloads run concurrently. A source whose download fails
    /// (the failure is logged by [`get_jwks`]) is left out of the initial map;
    /// the polling task still retries it every `poll_interval`.
    ///
    /// The polling task stops once the last clone of the returned manager is
    /// dropped, because that drops the shared `oneshot::Sender`.
    ///
    /// Must be called from within a Tokio runtime (uses `tokio::task::spawn`).
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub(super) async fn new(list: Vec<JwksConfig>) -> Result<Self, BoxError> {
        use futures::FutureExt;

        // Fetch every source concurrently, each inside its own span so errors
        // logged by `get_jwks` carry the URL. Only `url` and `headers` are
        // needed, so clone just those rather than the whole config.
        let downloads = list
            .iter()
            .map(|config| {
                let url = config.url.clone();
                let span = tracing::info_span!("fetch jwks", url = %url);
                get_jwks(url.clone(), config.headers.clone())
                    .map(move |opt_jwks| opt_jwks.map(|jwks| (url, jwks)))
                    .instrument(span)
            })
            .collect::<Vec<_>>();
        // Failed downloads yield `None` and are simply skipped by `flatten`.
        let jwks_map: HashMap<_, _> = join_all(downloads).await.into_iter().flatten().collect();
        let jwks_map = Arc::new(RwLock::new(jwks_map));

        // The sender is never used to send: dropping it (with the last manager
        // clone) resolves the receiver, which tells `poll` to stop.
        let (drop_signal, drop_receiver) = oneshot::channel::<()>();
        tokio::task::spawn(poll(list.clone(), jwks_map.clone(), drop_receiver));

        Ok(JwksManager {
            list,
            jwks_map,
            _drop_signal: Arc::new(drop_signal),
        })
    }

    /// Test-only constructor: uses the given pre-populated map and spawns no
    /// polling task (the drop receiver is discarded immediately).
    #[cfg(test)]
    pub(super) fn new_test(list: Vec<JwksConfig>, jwks: HashMap<Url, JwkSet>) -> Self {
        let (_drop_signal, _) = oneshot::channel::<()>();
        JwksManager {
            list,
            jwks_map: Arc::new(RwLock::new(jwks)),
            _drop_signal: Arc::new(_drop_signal),
        }
    }

    /// Returns an iterator over the sources that currently have a downloaded
    /// key set, each paired with its configured issuer and algorithms.
    pub(super) fn iter_jwks(&self) -> Iter<'_> {
        Iter {
            list: self.list.clone(),
            manager: self,
        }
    }
}
async fn poll(
list: Vec<JwksConfig>,
jwks_map: Arc<RwLock<HashMap<Url, JwkSet>>>,
drop_receiver: oneshot::Receiver<()>,
) {
use futures::stream::StreamExt;
let mut streams = select_all(list.into_iter().map(move |config| {
let jwks_map = jwks_map.clone();
Box::pin(
repeat((config, jwks_map)).then(|(config, jwks_map)| async move {
tokio::time::sleep(config.poll_interval).await;
if let Some(jwks) = get_jwks(config.url.clone(), config.headers.clone()).await {
if let Ok(mut map) = jwks_map.write() {
map.insert(config.url, jwks);
}
}
}),
)
}));
pin_mut!(drop_receiver);
loop {
let next = streams.next();
pin_mut!(next);
match select(drop_receiver, next).await {
Either::Left((_res, _)) => return,
Either::Right((Some(()), receiver)) => {
drop_receiver = receiver;
}
Either::Right((None, _)) => return,
};
}
}
/// Fetches and parses the JWKS located at `url`.
///
/// `file://` URLs are read from the local filesystem (`headers` are ignored).
/// Any other URL is fetched with an HTTP GET through the shared `CLIENT`,
/// sending `Accept: application/json` plus every entry of `headers`, with
/// `DEFAULT_AUTHENTICATION_NETWORK_TIMEOUT` as the request timeout.
///
/// Every failure (URL not convertible to a path, read error, client
/// unavailable, request or body error, invalid JWKS) is logged and turned
/// into `None`.
pub(super) async fn get_jwks(url: Url, headers: Vec<Header>) -> Option<JwkSet> {
    let data = if url.scheme() == "file" {
        let path = url
            .to_file_path()
            .inspect_err(|_| {
                tracing::error!("url cannot be converted to filesystem path");
            })
            .ok()?;
        read_to_string(path)
            .await
            .inspect_err(|e| {
                tracing::error!(%e, "could not read JWKS path");
            })
            .ok()?
    } else {
        let my_client = CLIENT
            .as_ref()
            .inspect_err(|e| {
                tracing::error!(%e, "could not activate authentication feature");
            })
            .ok()?
            .clone();
        let mut builder = my_client
            .get(url)
            .header(ACCEPT, APPLICATION_JSON.essence_str());
        for header in headers {
            builder = builder.header(header.name, header.value);
        }
        builder
            .timeout(DEFAULT_AUTHENTICATION_NETWORK_TIMEOUT)
            .send()
            .await
            .inspect_err(|e| {
                tracing::error!(%e, "could not get url");
            })
            .ok()?
            .text()
            .await
            .inspect_err(|e| {
                tracing::error!(%e, "could not process url content");
            })
            .ok()?
    };
    parse_jwks(&data)
}
pub(crate) fn parse_jwks(data: &str) -> Option<JwkSet> {
tracing::debug!(data, "parsing JWKS");
let mut raw_json: Value = serde_json::from_str(data)
.map_err(|e| {
tracing::error!(%e, "could not create JSON Value from url content, enable debug logs to see content");
e
})
.ok()?;
raw_json.get_mut("keys").and_then(|keys| {
keys.as_array_mut().map(|array| {
*array = mem::take(array).into_iter().enumerate().filter(|(index, key)| {
if let Err(err) = serde_json::from_value::<Jwk>(key.clone()) {
let alg = key.get("alg").and_then(|alg|alg.as_str()).unwrap_or("<unknown>");
tracing::warn!(%err, alg, index, "ignoring a key since it is not valid, enable debug logs to full content");
return false;
}
true
}).map(|(_, key)| key).collect();
})
});
let jwks: JwkSet = serde_json::from_value(raw_json)
.map_err(|e| {
tracing::error!(%e, "could not create JWKS from url content, enable debug logs to see content");
e
})
.ok()?;
Some(jwks)
}
/// Iterator over the downloaded key sets of a [`JwksManager`], created by
/// `JwksManager::iter_jwks`.
pub(super) struct Iter<'a> {
    // Manager whose `jwks_map` is read on each call to `next`.
    manager: &'a JwksManager,
    // Remaining source configurations; `next` consumes them from the back.
    list: Vec<JwksConfig>,
}
impl Iterator for Iter<'_> {
    type Item = JwkSetInfo;

    /// Yields the next source that currently has a downloaded key set.
    ///
    /// Sources are taken from the end of the remaining list, i.e. in reverse
    /// configuration order. Sources with no key set in the map yet are skipped.
    /// If the map's lock is poisoned, iteration ends immediately.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(source) = self.list.pop() {
            let Ok(downloaded) = self.manager.jwks_map.read() else {
                return None;
            };
            if let Some(key_set) = downloaded.get(&source.url) {
                // `source` is owned (it was popped), so its settings can be moved.
                return Some(JwkSetInfo {
                    jwks: key_set.clone(),
                    issuer: source.issuer,
                    algorithms: source.algorithms,
                });
            }
        }
        None
    }
}