oidc_jwt_validator/
cache.rs

1use jsonwebtoken::jwk::JwkSet;
2use log::{debug, info};
3use reqwest::header::HeaderValue;
4use std::{
5    collections::HashMap,
6    sync::{
7        atomic::{AtomicBool, AtomicU64, Ordering},
8        Arc,
9    },
10    time::{Duration, Instant},
11};
12
13use crate::{
14    util::{current_time, decode_jwk},
15    DecodingInfo, JwkSetFetch, ValidationSettings,
16};
17
18/// Determines settings about updating the cached JWKS data.
19/// The JWKS will be lazily revalidated every time [validate](crate::Validator) validates a token.
20#[derive(Debug, Clone, Copy, PartialEq, Eq)]
21pub struct Settings {
22    /// Time in Seconds to refresh the JWKS from the OIDC Provider
23    /// Default/Minimum value: 1 Second
24    pub max_age: Duration,
25    /// The amount of time a s
26    pub stale_while_revalidate: Option<Duration>,
27    /// The amount of time the stale JWKS data should be valid for if we are unable to re-validate it from the URL.
28    /// Minimum Value: 60 Seconds
29    pub stale_if_error: Option<Duration>,
30}
31
32impl Settings {
33    pub fn from_header_val(value: Option<&HeaderValue>) -> Self {
34        // Initalize the default config of polling every second
35        let mut config = Self::default();
36
37        if let Some(value) = value {
38            if let Ok(value) = value.to_str() {
39                config.parse_str(value);
40            }
41        }
42        config
43    }
44
45    fn parse_str(&mut self, value: &str) {
46        // Iterate over every token in the header value
47        for token in value.split(',') {
48            // split them into whitespace trimmed pairs
49            let (key, val) = {
50                let mut split = token.split('=').map(str::trim);
51                (split.next(), split.next())
52            };
53            //Modify the default config based on the values that matter
54            //Any values here would be more permisssive than the default behavior
55            match (key, val) {
56                (Some("max-age"), Some(val)) => {
57                    if let Ok(secs) = val.parse::<u64>() {
58                        self.max_age = Duration::from_secs(secs);
59                    }
60                }
61                (Some("stale-while-revalidate"), Some(val)) => {
62                    if let Ok(secs) = val.parse::<u64>() {
63                        self.stale_while_revalidate = Some(Duration::from_secs(secs));
64                    }
65                }
66                (Some("stale-if-error"), Some(val)) => {
67                    if let Ok(secs) = val.parse::<u64>() {
68                        self.stale_if_error = Some(Duration::from_secs(secs));
69                    }
70                }
71                _ => continue,
72            };
73        }
74    }
75}
76
77impl Default for Settings {
78    fn default() -> Self {
79        Self {
80            max_age: Duration::from_secs(1),
81            stale_while_revalidate: Some(Duration::from_secs(1)),
82            stale_if_error: Some(Duration::from_secs(60)),
83        }
84    }
85}
86
87/// Determines the JWKS Caching behavior of the validator
88#[derive(Debug, Clone, Copy, PartialEq, Eq)]
89pub enum Strategy {
90    /// The Reccomended Option.
91    /// Determines [Settings] from the cache-control header on a per request basis.
92    /// Allows for dynamic updating of the cache duration during run time.
93    /// If no cache-control headers are present, a lazy 1 second polling interval on the JWKS will be used.
94    Automatic,
95    /// Use a static [Settings] for the lifetime of the program. Ignores cache-control directives
96    /// Not recommended unless you are *really* sure that you know this will be the correct option
97    /// This option could potentially introduce a security vulnerability if the JWKS has changed, and the value was set too high.
98    Manual(Settings),
99}
100#[derive(Debug, Clone, Copy, PartialEq, Eq)]
101pub enum UpdateAction {
102    /// We checked the JWKS uri and it was the same as the last time we refreshed it so no action was taken
103    NoUpdate,
104    /// We checked the JWKS uri and it was different so we updated our local cache
105    JwksUpdate,
106    /// The JWKS Uri responded with a different cache-control header
107    CacheUpdate(Settings),
108    JwksAndCacheUpdate(Settings),
109}
110
/// Errors related to looking up and decoding JWKS keys.
// NOTE(review): neither variant is constructed in this file; confirm the exact conditions at
// the sites that produce them and tighten these docs accordingly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// A key ID (`kid`) was missing — presumably from the token header or from the JWKS;
    /// confirm at the call site.
    MissingKid,
    /// Something (presumably a key or token) could not be decoded.
    DecodeError,
}
116/// Helper struct for determining when our cache needs to be re-validated
117/// Utilizes atomics to prevent write-locking as much as possible
118#[derive(Debug)]
119pub(crate) struct State {
120    last_update: AtomicU64,
121    is_revalidating: AtomicBool,
122    is_error: AtomicBool,
123}
124
125impl State {
126    pub fn new() -> Self {
127        Self {
128            last_update: AtomicU64::new(current_time()),
129            is_revalidating: AtomicBool::new(false),
130            is_error: AtomicBool::new(false),
131        }
132    }
133    pub fn is_error(&self) -> bool {
134        self.is_error.load(Ordering::SeqCst)
135    }
136    pub fn set_is_error(&self, value: bool) {
137        self.is_error.store(value, Ordering::SeqCst);
138    }
139
140    pub fn last_update(&self) -> u64 {
141        self.last_update.load(Ordering::SeqCst)
142    }
143    pub fn set_last_update(&self, timestamp: u64) {
144        self.last_update.store(timestamp, Ordering::SeqCst);
145    }
146
147    pub fn is_revalidating(&self) -> bool {
148        self.is_revalidating.load(Ordering::SeqCst)
149    }
150
151    pub fn set_is_revalidating(&self, value: bool) {
152        self.is_revalidating.store(value, Ordering::SeqCst);
153    }
154}
155
156impl Default for State {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
162/// Helper Struct for storing
163pub struct JwkSetStore {
164    pub jwks: JwkSet,
165    decoding_map: HashMap<String, Arc<DecodingInfo>>,
166    pub cache_policy: Settings,
167    validation: ValidationSettings,
168}
169
170impl JwkSetStore {
171    pub fn new(jwks: JwkSet, cache_config: Settings, validation: ValidationSettings) -> Self {
172        Self {
173            jwks,
174            decoding_map: HashMap::new(),
175            cache_policy: cache_config,
176            validation,
177        }
178    }
179
180    fn update_jwks(&mut self, new_jwks: JwkSet) {
181        self.jwks = new_jwks;
182        let keys = self
183            .jwks
184            .keys
185            .iter()
186            .filter_map(|i| decode_jwk(i, &self.validation).ok());
187        // Clear our cache of decoding keys
188        self.decoding_map.clear();
189        // Load the keys back into our hashmap cache.
190        for key in keys {
191            self.decoding_map.insert(key.0, Arc::new(key.1));
192        }
193    }
194
195    pub fn get_key(&self, kid: &str) -> Option<Arc<DecodingInfo>> {
196        self.decoding_map.get(kid).cloned()
197    }
198
199    pub(crate) fn update_fetch(&mut self, fetch: JwkSetFetch) -> UpdateAction {
200        debug!("Decoding JWKS");
201        let time = Instant::now();
202        let new_jwks = fetch.jwks;
203        // If we didn't parse out a cache policy from the last request
204        // Assume that it's the same as the last
205        let cache_policy = fetch.cache_policy.unwrap_or(self.cache_policy);
206        let result = match (self.jwks == new_jwks, self.cache_policy == cache_policy) {
207            // Everything is the same
208            (true, true) => {
209                debug!("JWKS Content has not changed since last update");
210                UpdateAction::NoUpdate
211            }
212            // The JWKS changed but the cache policy hasn't
213            (false, true) => {
214                info!("JWKS Content has changed since last update");
215                self.update_jwks(new_jwks);
216                UpdateAction::JwksUpdate
217            }
218            // The cache policy changed, but the JWKS hasn't
219            (true, false) => {
220                self.cache_policy = cache_policy;
221                UpdateAction::CacheUpdate(cache_policy)
222            }
223            // Both the cache and the JWKS have changed
224            (false, false) => {
225                info!("cache-control header and JWKS content has changed since last update");
226                self.update_jwks(new_jwks);
227                self.cache_policy = cache_policy;
228                UpdateAction::JwksAndCacheUpdate(cache_policy)
229            }
230        };
231        let elapsed = time.elapsed();
232        debug!("Decoded and parsed JWKS in {:#?}", elapsed);
233        result
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::Settings;
    use std::time::Duration;

    /// Each sample `cache-control` header must override only the directives it names,
    /// leaving every other field at its default. Valueless directives such as `no-cache`
    /// and `no-store` are currently ignored, so they yield the defaults.
    #[test]
    fn validate_headers() {
        let defaults = Settings::default();
        let week = Duration::from_secs(604_800);
        let day = Duration::from_secs(86_400);
        let cases = [
            ("max-age=604800", Settings { max_age: week, ..defaults }),
            ("no-cache", defaults),
            ("max-age=604800, must-revalidate", Settings { max_age: week, ..defaults }),
            ("no-store", defaults),
            ("public, max-age=604800, immutable", Settings { max_age: week, ..defaults }),
            (
                "max-age=604800, stale-while-revalidate=86400",
                Settings {
                    max_age: week,
                    stale_while_revalidate: Some(day),
                    ..defaults
                },
            ),
            (
                "max-age=604800, stale-if-error=86400",
                Settings {
                    max_age: week,
                    stale_if_error: Some(day),
                    ..defaults
                },
            ),
        ];
        for (input, expected) in cases.iter() {
            let mut parsed = Settings::default();
            parsed.parse_str(input);
            assert_eq!(parsed, *expected, "unexpected settings for {:?}", input);
        }
    }
}