1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
use jsonwebtoken::jwk::JwkSet;
use log::{debug, info};
use reqwest::header::HeaderValue;
use std::{
    collections::HashMap,
    sync::{
        atomic::{AtomicBool, AtomicU64, Ordering},
        Arc,
    },
    time::{Duration, Instant},
};

use crate::{
    util::{current_time, decode_jwk},
    DecodingInfo, JwkSetFetch, ValidationSettings,
};

/// Determines settings about updating the cached JWKS data.
/// The JWKS will be lazily revalidated every time [validate](crate::Validator) validates a token.
///
/// The three fields correspond to the `max-age`, `stale-while-revalidate` and
/// `stale-if-error` cache-control directives read by `parse_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Settings {
    /// Time in seconds after which the JWKS is refreshed from the OIDC provider.
    /// Default value: 1 second.
    /// NOTE(review): this was also documented as the minimum value, but no lower
    /// bound is applied in this file — confirm where (or whether) it is enforced.
    pub max_age: Duration,
    /// Value of the `stale-while-revalidate` directive.
    /// Presumably the amount of time stale JWKS data may still be used while a
    /// revalidation is in progress — the code consuming it is not in this file, TODO confirm.
    /// Default value: 1 second.
    pub stale_while_revalidate: Option<Duration>,
    /// The amount of time the stale JWKS data should be valid for if we are unable to re-validate it from the URL.
    /// Default value: 60 seconds.
    /// NOTE(review): this was also documented as the minimum value, but no lower
    /// bound is applied in this file — confirm where (or whether) it is enforced.
    pub stale_if_error: Option<Duration>,
}

impl Settings {
    /// Builds [`Settings`] from an optional `cache-control` header value.
    ///
    /// Starts from [`Settings::default`] and lets the header override individual
    /// fields. When the header is absent, or `to_str` cannot represent it as a
    /// string, the defaults are returned untouched.
    pub fn from_header_val(value: Option<&HeaderValue>) -> Self {
        let mut config = Self::default();

        if let Some(raw) = value.and_then(|header| header.to_str().ok()) {
            config.parse_str(raw);
        }
        config
    }

    /// Applies the directives found in a raw `cache-control` string to `self`.
    ///
    /// Only `max-age`, `stale-while-revalidate` and `stale-if-error` are
    /// honoured; every other directive (and any directive whose argument is
    /// not a non-negative integer number of seconds) is ignored and leaves the
    /// current field value in place. Values are stored exactly as received —
    /// no lower or upper bound is applied here.
    fn parse_str(&mut self, value: &str) {
        // Directives are comma separated: `name` or `name=argument`.
        for directive in value.split(',') {
            // Directives without an argument (e.g. `no-cache`) carry nothing we use.
            let (name, arg) = match directive.split_once('=') {
                Some((name, arg)) => (name.trim(), arg.trim()),
                None => continue,
            };

            // Recipients are expected to accept both the token form (`max-age=5`)
            // and the quoted-string form (`max-age="5"`), so drop one pair of
            // surrounding double quotes when present.
            let arg = arg
                .strip_prefix('"')
                .and_then(|rest| rest.strip_suffix('"'))
                .unwrap_or(arg);

            let duration = match arg.parse::<u64>() {
                Ok(secs) => Duration::from_secs(secs),
                Err(_) => continue,
            };

            // Directive names are compared case-insensitively.
            if name.eq_ignore_ascii_case("max-age") {
                self.max_age = duration;
            } else if name.eq_ignore_ascii_case("stale-while-revalidate") {
                self.stale_while_revalidate = Some(duration);
            } else if name.eq_ignore_ascii_case("stale-if-error") {
                self.stale_if_error = Some(duration);
            }
        }
    }
}

impl Default for Settings {
    fn default() -> Self {
        Self {
            max_age: Duration::from_secs(1),
            stale_while_revalidate: Some(Duration::from_secs(1)),
            stale_if_error: Some(Duration::from_secs(60)),
        }
    }
}

/// Determines the JWKS caching behavior of the validator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Strategy {
    /// The recommended option.
    /// Determines [Settings] from the cache-control header on a per-request basis.
    /// Allows for dynamic updating of the cache duration during run time.
    /// If no cache-control headers are present, a lazy 1 second polling interval on the JWKS will be used.
    Automatic,
    /// Use a static [Settings] for the lifetime of the program. Ignores cache-control directives.
    /// Not recommended unless you are *really* sure that this is the correct option.
    /// This option could potentially introduce a security vulnerability if the JWKS has changed and the value was set too high.
    Manual(Settings),
}
/// Outcome reported by `JwkSetStore::update_fetch` after comparing a freshly
/// fetched JWKS and cache policy with the values currently stored.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdateAction {
    /// We checked the JWKS uri and it was the same as the last time we refreshed it so no action was taken
    NoUpdate,
    /// We checked the JWKS uri and it was different so we updated our local cache
    JwksUpdate,
    /// The JWKS Uri responded with a different cache-control header.
    /// Carries the cache policy that is now in effect.
    CacheUpdate(Settings),
    /// Both the JWKS content and the cache policy differed from the stored values.
    /// Carries the cache policy that is now in effect.
    JwksAndCacheUpdate(Settings),
}

/// Errors related to the JWKS cache.
/// NOTE(review): neither variant is produced in the part of the file visible here;
/// the descriptions below are inferred from the names — confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Presumably: no `kid` was available to look a key up with — TODO confirm.
    MissingKid,
    /// Presumably: decoding failed — TODO confirm what is being decoded.
    DecodeError,
}
/// Helper struct for determining when our cache needs to be re-validated
/// Utilizes atomics to prevent write-locking as much as possible
#[derive(Debug)]
pub(crate) struct State {
    /// Timestamp of the last update; initialised from `current_time()` in `new`.
    /// NOTE(review): unit and epoch are whatever `current_time()` returns — confirm.
    last_update: AtomicU64,
    /// Flag read and written through `is_revalidating` / `set_is_revalidating`.
    /// Presumably set while a refresh of the JWKS is in flight — TODO confirm against the caller.
    is_revalidating: AtomicBool,
    /// Flag read and written through `is_error` / `set_is_error`.
    /// Presumably records that the last refresh attempt failed — TODO confirm against the caller.
    is_error: AtomicBool,
}

impl State {
    pub fn new() -> Self {
        Self {
            last_update: AtomicU64::new(current_time()),
            is_revalidating: AtomicBool::new(false),
            is_error: AtomicBool::new(false),
        }
    }
    pub fn is_error(&self) -> bool {
        self.is_error.load(Ordering::SeqCst)
    }
    pub fn set_is_error(&self, value: bool) {
        self.is_error.store(value, Ordering::SeqCst);
    }

    pub fn last_update(&self) -> u64 {
        self.last_update.load(Ordering::SeqCst)
    }
    pub fn set_last_update(&self, timestamp: u64) {
        self.last_update.store(timestamp, Ordering::SeqCst);
    }

    pub fn is_revalidating(&self) -> bool {
        self.is_revalidating.load(Ordering::SeqCst)
    }

    pub fn set_is_revalidating(&self, value: bool) {
        self.is_revalidating.store(value, Ordering::SeqCst);
    }
}

impl Default for State {
    fn default() -> Self {
        Self::new()
    }
}

/// Helper struct for storing the most recently accepted JWKS together with the
/// decoding data derived from it, the cache policy in effect and the
/// validation settings used when decoding keys.
pub struct JwkSetStore {
    /// The JWKS as last stored; each new fetch is compared against it to detect changes.
    pub jwks: JwkSet,
    /// Decoded keys, indexed by the string `decode_jwk` returns for each key
    /// and looked up by `kid` in `get_key`.
    decoding_map: HashMap<String, Arc<DecodingInfo>>,
    /// The cache policy currently in effect.
    pub cache_policy: Settings,
    /// Passed to `decode_jwk` for every key that gets decoded.
    validation: ValidationSettings,
}

impl JwkSetStore {
    /// Creates a store holding `jwks`, using `cache_config` as the initial
    /// cache policy and `validation` for decoding keys.
    ///
    /// The keys in `jwks` are decoded immediately, so [`get_key`](Self::get_key)
    /// can resolve them right away. Without this, a store seeded with a
    /// non-empty JWKS would never serve a key if later fetches returned the
    /// same JWKS, because `update_fetch` only re-decodes when the content changes.
    pub fn new(jwks: JwkSet, cache_config: Settings, validation: ValidationSettings) -> Self {
        let mut store = Self {
            jwks,
            decoding_map: HashMap::new(),
            cache_policy: cache_config,
            validation,
        };
        store.rebuild_decoding_map();
        store
    }

    /// Rebuilds `decoding_map` from the JWKS currently stored in `self.jwks`.
    ///
    /// Keys for which `decode_jwk` returns an error are skipped silently.
    fn rebuild_decoding_map(&mut self) {
        // Borrow the fields individually so the closure does not capture `self`.
        let validation = &self.validation;
        let decoded = self
            .jwks
            .keys
            .iter()
            .filter_map(|jwk| decode_jwk(jwk, validation).ok())
            .map(|key| (key.0, Arc::new(key.1)));
        // Drop every previously decoded key before loading the new ones.
        self.decoding_map.clear();
        self.decoding_map.extend(decoded);
    }

    /// Replaces the stored JWKS with `new_jwks` and re-decodes all of its keys.
    fn update_jwks(&mut self, new_jwks: JwkSet) {
        self.jwks = new_jwks;
        self.rebuild_decoding_map();
    }

    /// Returns the decoding data stored under `kid`, or `None` when no decoded
    /// key is registered under that id.
    pub fn get_key(&self, kid: &str) -> Option<Arc<DecodingInfo>> {
        self.decoding_map.get(kid).cloned()
    }

    /// Merges the result of a JWKS fetch into the store and reports what changed.
    ///
    /// The fetched JWKS and cache policy are compared with the stored ones; only
    /// the parts that differ are replaced. When the fetch carries no cache
    /// policy, the stored policy is kept. The time spent is logged at debug level.
    pub(crate) fn update_fetch(&mut self, fetch: JwkSetFetch) -> UpdateAction {
        debug!("Decoding JWKS");
        let started = Instant::now();
        let new_jwks = fetch.jwks;
        // If we didn't parse out a cache policy from this request,
        // assume that it's the same as the last one.
        let cache_policy = fetch.cache_policy.unwrap_or(self.cache_policy);

        let jwks_changed = self.jwks != new_jwks;
        let policy_changed = self.cache_policy != cache_policy;

        let result = match (jwks_changed, policy_changed) {
            // Everything is the same
            (false, false) => {
                debug!("JWKS Content has not changed since last update");
                UpdateAction::NoUpdate
            }
            // The JWKS changed but the cache policy hasn't
            (true, false) => {
                info!("JWKS Content has changed since last update");
                self.update_jwks(new_jwks);
                UpdateAction::JwksUpdate
            }
            // The cache policy changed, but the JWKS hasn't
            (false, true) => {
                self.cache_policy = cache_policy;
                UpdateAction::CacheUpdate(cache_policy)
            }
            // Both the cache policy and the JWKS have changed
            (true, true) => {
                info!("cache-control header and JWKS content has changed since last update");
                self.update_jwks(new_jwks);
                self.cache_policy = cache_policy;
                UpdateAction::JwksAndCacheUpdate(cache_policy)
            }
        };

        let elapsed = started.elapsed();
        debug!("Decoded and parsed JWKS in {:#?}", elapsed);
        result
    }
}

#[cfg(test)]
mod tests {
    /// Lists representative `cache-control` header values.
    ///
    /// NOTE(review): the values are only constructed — nothing is parsed or
    /// asserted, so this test cannot currently fail. It should feed each value
    /// through the header parsing and check the resulting settings.
    #[test]
    fn validate_headers() {
        let _input: Vec<&str> = [
            "max-age=604800",
            "no-cache",
            "max-age=604800, must-revalidate",
            "no-store",
            "public, max-age=604800, immutable",
            "max-age=604800, stale-while-revalidate=86400",
            "max-age=604800, stale-if-error=86400",
        ]
        .to_vec();
    }
}