// JWKS cache policy (`Settings`, `Strategy`), revalidation `State`, and the `JwkSetStore` key cache.
use jsonwebtoken::jwk::JwkSet;
use log::{debug, info};
use reqwest::header::HeaderValue;
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use crate::{
util::{current_time, decode_jwk},
DecodingInfo, JwkSetFetch, ValidationSettings,
};
/// Determines settings about updating the cached JWKS data.
/// The JWKS will be lazily revalidated every time [validate](crate::Validator) validates a token.
///
/// The three fields correspond to the `max-age`, `stale-while-revalidate` and `stale-if-error`
/// cache-control directives, which is how [`Settings::from_header_val`] fills them in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Settings {
/// Time to wait before refreshing the JWKS from the OIDC provider.
/// Default value: 1 second, which is also the intended minimum.
pub max_age: Duration,
/// Duration taken from the `stale-while-revalidate` cache-control directive.
/// Default value: 1 second.
/// NOTE(review): presumably the window in which the stale JWKS may still be used while a
/// refresh is in flight; the code consuming this field is not in this file - confirm.
pub stale_while_revalidate: Option<Duration>,
/// The amount of time the stale JWKS data should be valid for if we are unable to re-validate it from the URL.
/// Default value: 60 seconds, which is also the intended minimum.
pub stale_if_error: Option<Duration>,
}
impl Settings {
    /// Builds a cache policy from an optional `cache-control` header value.
    ///
    /// Starts from [`Settings::default`] and overrides the fields for which the header
    /// carries a usable directive. A missing header, or one that is not valid visible
    /// ASCII, yields the defaults unchanged.
    pub fn from_header_val(value: Option<&HeaderValue>) -> Self {
        // Initialize with the default config of polling every second
        let mut config = Self::default();
        if let Some(raw) = value.and_then(|header| header.to_str().ok()) {
            config.parse_str(raw);
        }
        config
    }

    /// Applies the directives found in a raw `cache-control` string to `self`.
    ///
    /// Only `max-age`, `stale-while-revalidate` and `stale-if-error` are honoured; every
    /// other directive, and any directive whose argument is not a non-negative integer
    /// number of seconds, is ignored. Directive names are compared case-insensitively and
    /// the argument may be written as a bare token or as a quoted string.
    ///
    /// `max-age` and `stale-if-error` are raised to the minimums documented on the
    /// corresponding fields (1 and 60 seconds), so a header can never make the policy
    /// stricter than the default.
    fn parse_str(&mut self, value: &str) {
        const MIN_MAX_AGE: Duration = Duration::from_secs(1);
        const MIN_STALE_IF_ERROR: Duration = Duration::from_secs(60);

        // Iterate over every directive in the header value
        for directive in value.split(',') {
            // Split into a whitespace-trimmed `name=argument` pair; flag-style directives
            // such as `no-cache` have no argument and are skipped.
            let mut parts = directive.split('=').map(str::trim);
            let (name, argument) = match (parts.next(), parts.next()) {
                (Some(name), Some(argument)) => (name, argument),
                _ => continue,
            };

            // Recipients are expected to accept the quoted-string form as well as the token form.
            let argument = argument
                .strip_prefix('"')
                .and_then(|rest| rest.strip_suffix('"'))
                .unwrap_or(argument);

            let ttl = match argument.parse::<u64>() {
                Ok(seconds) => Duration::from_secs(seconds),
                Err(_) => continue,
            };

            if name.eq_ignore_ascii_case("max-age") {
                self.max_age = ttl.max(MIN_MAX_AGE);
            } else if name.eq_ignore_ascii_case("stale-while-revalidate") {
                self.stale_while_revalidate = Some(ttl);
            } else if name.eq_ignore_ascii_case("stale-if-error") {
                self.stale_if_error = Some(ttl.max(MIN_STALE_IF_ERROR));
            }
        }
    }
}
impl Default for Settings {
    /// The baseline policy: refresh after 1 second, with a 1 second
    /// `stale_while_revalidate` and a 60 second `stale_if_error`.
    fn default() -> Self {
        let one_second = Duration::from_secs(1);
        let one_minute = Duration::from_secs(60);

        Self {
            max_age: one_second,
            stale_while_revalidate: Some(one_second),
            stale_if_error: Some(one_minute),
        }
    }
}
/// Determines the JWKS Caching behavior of the validator
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Strategy {
/// The recommended option.
/// Determines [Settings] from the cache-control header on a per request basis.
/// Allows for dynamic updating of the cache duration during run time.
/// If no cache-control headers are present, a lazy 1 second polling interval on the JWKS will be used.
Automatic,
/// Use a static [Settings] for the lifetime of the program. Ignores cache-control directives.
/// Not recommended unless you are *really* sure that you know this will be the correct option.
/// This option could potentially introduce a security vulnerability if the JWKS has changed, and the value was set too high.
Manual(Settings),
}
/// Outcome of applying a freshly fetched JWKS response to a [JwkSetStore].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdateAction {
/// We checked the JWKS uri and it was the same as the last time we refreshed it so no action was taken
NoUpdate,
/// We checked the JWKS uri and it was different so we updated our local cache
JwksUpdate,
/// The JWKS Uri responded with a different cache-control header.
/// Carries the cache policy that is now in effect.
CacheUpdate(Settings),
/// Both the JWKS content and the cache policy changed.
/// Carries the cache policy that is now in effect.
JwksAndCacheUpdate(Settings),
}
/// Errors related to resolving a key from the cache.
/// NOTE(review): neither variant is produced in this file; the descriptions below are
/// inferred from the variant names - confirm against the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
/// Presumably: no `kid` (key id) was available to look a key up with.
MissingKid,
/// Presumably: the key or token could not be decoded.
DecodeError,
}
/// Helper struct for determining when our cache needs to be re-validated
/// Utilizes atomics to prevent write-locking as much as possible
#[derive(Debug)]
pub(crate) struct State {
// Timestamp of the most recent update, initialised from `current_time()`.
// NOTE(review): unit/epoch is defined by `util::current_time` - confirm.
last_update: AtomicU64,
// Flag reporting whether a revalidation is in progress; starts out `false`.
is_revalidating: AtomicBool,
// Flag reporting an error condition; starts out `false`.
// NOTE(review): presumably set when the last refresh attempt failed - confirm with callers.
is_error: AtomicBool,
}
impl State {
    /// Creates a state stamped with the current time and with both flags cleared.
    pub fn new() -> Self {
        let created_at = current_time();
        Self {
            last_update: AtomicU64::new(created_at),
            is_revalidating: AtomicBool::new(false),
            is_error: AtomicBool::new(false),
        }
    }

    /// Returns the timestamp recorded by the most recent update.
    pub fn last_update(&self) -> u64 {
        let stamp = &self.last_update;
        stamp.load(Ordering::SeqCst)
    }

    /// Records `timestamp` as the time of the most recent update.
    pub fn set_last_update(&self, timestamp: u64) {
        let stamp = &self.last_update;
        stamp.store(timestamp, Ordering::SeqCst);
    }

    /// Returns the current value of the revalidation flag.
    pub fn is_revalidating(&self) -> bool {
        let flag = &self.is_revalidating;
        flag.load(Ordering::SeqCst)
    }

    /// Sets the revalidation flag to `value`.
    pub fn set_is_revalidating(&self, value: bool) {
        let flag = &self.is_revalidating;
        flag.store(value, Ordering::SeqCst);
    }

    /// Returns the current value of the error flag.
    pub fn is_error(&self) -> bool {
        let flag = &self.is_error;
        flag.load(Ordering::SeqCst)
    }

    /// Sets the error flag to `value`.
    pub fn set_is_error(&self, value: bool) {
        let flag = &self.is_error;
        flag.store(value, Ordering::SeqCst);
    }
}
impl Default for State {
fn default() -> Self {
Self::new()
}
}
/// Helper struct for storing the most recently fetched JWKS together with the
/// decoded keys derived from it, the cache policy in effect, and the validation
/// settings used when decoding keys.
pub struct JwkSetStore {
/// The raw key set as last received.
pub jwks: JwkSet,
// Decoded keys, looked up by key id (`kid`) in `get_key`.
decoding_map: HashMap<String, Arc<DecodingInfo>>,
/// The cache policy currently in effect.
pub cache_policy: Settings,
// Passed to `decode_jwk` whenever the key set is decoded.
validation: ValidationSettings,
}
impl JwkSetStore {
    /// Creates a store holding `jwks`, using `cache_config` as the initial cache policy
    /// and `validation` for decoding keys.
    ///
    /// The keys in `jwks` are decoded immediately so that [`JwkSetStore::get_key`] can
    /// resolve them straight away. Keys that fail to decode are skipped.
    pub fn new(jwks: JwkSet, cache_config: Settings, validation: ValidationSettings) -> Self {
        let mut store = Self {
            jwks,
            decoding_map: HashMap::new(),
            cache_policy: cache_config,
            validation,
        };
        // Without this, a first fetch returning an identical key set is reported as
        // `NoUpdate` and the initial keys would never make it into the decoding map.
        store.rebuild_decoding_map();
        store
    }

    /// Re-derives the decoded-key cache from `self.jwks`, dropping all previous entries.
    fn rebuild_decoding_map(&mut self) {
        let validation = &self.validation;
        let decoded = self
            .jwks
            .keys
            .iter()
            // Keys that cannot be decoded are silently left out of the cache.
            .filter_map(|jwk| decode_jwk(jwk, validation).ok())
            .map(|key| (key.0, Arc::new(key.1)));
        // Clear our cache of decoding keys, then load the current keys back in.
        self.decoding_map.clear();
        self.decoding_map.extend(decoded);
    }

    /// Replaces the stored key set and rebuilds the decoded-key cache from it.
    fn update_jwks(&mut self, new_jwks: JwkSet) {
        self.jwks = new_jwks;
        self.rebuild_decoding_map();
    }

    /// Returns the decoded key registered under `kid`, if there is one.
    pub fn get_key(&self, kid: &str) -> Option<Arc<DecodingInfo>> {
        self.decoding_map.get(kid).cloned()
    }

    /// Applies a fetched JWKS response to the store and reports what changed.
    ///
    /// The key set is replaced (and re-decoded) only when it differs from the stored one.
    /// The cache policy is replaced only when the fetch carried a policy that differs
    /// from the current one; a fetch without a policy keeps the current policy.
    pub(crate) fn update_fetch(&mut self, fetch: JwkSetFetch) -> UpdateAction {
        debug!("Decoding JWKS");
        let started = Instant::now();
        let new_jwks = fetch.jwks;
        // If we didn't parse out a cache policy from the last request
        // assume that it's the same as the last
        let cache_policy = fetch.cache_policy.unwrap_or(self.cache_policy);
        let jwks_unchanged = self.jwks == new_jwks;
        let policy_unchanged = self.cache_policy == cache_policy;
        let result = match (jwks_unchanged, policy_unchanged) {
            // Everything is the same
            (true, true) => {
                debug!("JWKS Content has not changed since last update");
                UpdateAction::NoUpdate
            }
            // The JWKS changed but the cache policy hasn't
            (false, true) => {
                info!("JWKS Content has changed since last update");
                self.update_jwks(new_jwks);
                UpdateAction::JwksUpdate
            }
            // The cache policy changed, but the JWKS hasn't
            (true, false) => {
                info!("cache-control header has changed since last update");
                self.cache_policy = cache_policy;
                UpdateAction::CacheUpdate(cache_policy)
            }
            // Both the cache and the JWKS have changed
            (false, false) => {
                info!("cache-control header and JWKS content has changed since last update");
                self.update_jwks(new_jwks);
                self.cache_policy = cache_policy;
                UpdateAction::JwksAndCacheUpdate(cache_policy)
            }
        };
        let elapsed = started.elapsed();
        debug!("Decoded and parsed JWKS in {:#?}", elapsed);
        result
    }
}
#[cfg(test)]
mod tests {
// NOTE(review): this test currently only declares sample header values and never
// feeds them to the parser or asserts anything, so it cannot fail. It should be
// extended to parse each value and check the resulting settings.
#[test]
fn validate_headers() {
// Sample cache-control header values: some carry the directives this module
// understands (max-age, stale-while-revalidate, stale-if-error), others only
// carry directives it ignores (no-cache, no-store, public, immutable, ...).
let _input = vec![
"max-age=604800",
"no-cache",
"max-age=604800, must-revalidate",
"no-store",
"public, max-age=604800, immutable",
"max-age=604800, stale-while-revalidate=86400",
"max-age=604800, stale-if-error=86400",
];
}
}