1use jsonwebtoken::jwk::JwkSet;
2use log::{debug, info};
3use reqwest::header::HeaderValue;
4use std::{
5 collections::HashMap,
6 sync::{
7 atomic::{AtomicBool, AtomicU64, Ordering},
8 Arc,
9 },
10 time::{Duration, Instant},
11};
12
13use crate::{
14 util::{current_time, decode_jwk},
15 DecodingInfo, JwkSetFetch, ValidationSettings,
16};
17
/// Cache timings taken from an HTTP `Cache-Control` header.
///
/// Built by `Settings::from_header_val`. Any directive that is absent or malformed keeps
/// the value from `Settings::default()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Settings {
    /// Value of the `max-age` directive (1 second by default).
    pub max_age: Duration,
    /// Value of the `stale-while-revalidate` directive (`Some(1s)` by default).
    pub stale_while_revalidate: Option<Duration>,
    /// Value of the `stale-if-error` directive (`Some(60s)` by default).
    pub stale_if_error: Option<Duration>,
}
31
32impl Settings {
33 pub fn from_header_val(value: Option<&HeaderValue>) -> Self {
34 let mut config = Self::default();
36
37 if let Some(value) = value {
38 if let Ok(value) = value.to_str() {
39 config.parse_str(value);
40 }
41 }
42 config
43 }
44
45 fn parse_str(&mut self, value: &str) {
46 for token in value.split(',') {
48 let (key, val) = {
50 let mut split = token.split('=').map(str::trim);
51 (split.next(), split.next())
52 };
53 match (key, val) {
56 (Some("max-age"), Some(val)) => {
57 if let Ok(secs) = val.parse::<u64>() {
58 self.max_age = Duration::from_secs(secs);
59 }
60 }
61 (Some("stale-while-revalidate"), Some(val)) => {
62 if let Ok(secs) = val.parse::<u64>() {
63 self.stale_while_revalidate = Some(Duration::from_secs(secs));
64 }
65 }
66 (Some("stale-if-error"), Some(val)) => {
67 if let Ok(secs) = val.parse::<u64>() {
68 self.stale_if_error = Some(Duration::from_secs(secs));
69 }
70 }
71 _ => continue,
72 };
73 }
74 }
75}
76
77impl Default for Settings {
78 fn default() -> Self {
79 Self {
80 max_age: Duration::from_secs(1),
81 stale_while_revalidate: Some(Duration::from_secs(1)),
82 stale_if_error: Some(Duration::from_secs(60)),
83 }
84 }
85}
86
/// How cache [`Settings`] are chosen for a JWKS source.
///
/// NOTE(review): neither variant is consumed in this file; the per-variant descriptions
/// below are inferred from the names and should be confirmed against the fetch logic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Strategy {
    /// Presumably derives `Settings` from each response's `Cache-Control` header
    /// (cf. `Settings::from_header_val`) — confirm.
    Automatic,
    /// Uses the supplied `Settings`; presumably response headers are ignored — confirm.
    Manual(Settings),
}
/// Outcome of `JwkSetStore::update_fetch`: which parts of the store a fetch changed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdateAction {
    /// Neither the key set nor the cache policy differed from the stored ones.
    NoUpdate,
    /// Only the key set changed; the decoded keys were rebuilt.
    JwksUpdate,
    /// Only the cache policy changed; carries the newly stored policy.
    CacheUpdate(Settings),
    /// Both changed; carries the newly stored cache policy.
    JwksAndCacheUpdate(Settings),
}
110
/// Errors related to key lookup and decoding.
///
/// NOTE(review): these variants are not constructed in this file; the meanings below are
/// inferred from their names — confirm at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Presumably: the token did not carry a `kid` to look up a key with.
    MissingKid,
    /// Presumably: a token or key could not be decoded.
    DecodeError,
}
/// Shared refresh bookkeeping stored in atomics, so it can be read and updated through `&self`.
#[derive(Debug)]
pub(crate) struct State {
    // Timestamp of the last update, initialised from `current_time()` on construction.
    // NOTE(review): the unit of `current_time()` (likely seconds since the epoch) is defined
    // in `util` — confirm before comparing it with other clocks.
    last_update: AtomicU64,
    // Whether a revalidation is currently flagged as running.
    is_revalidating: AtomicBool,
    // Error flag; presumably set when the most recent refresh failed — confirm with callers.
    is_error: AtomicBool,
}
124
125impl State {
126 pub fn new() -> Self {
127 Self {
128 last_update: AtomicU64::new(current_time()),
129 is_revalidating: AtomicBool::new(false),
130 is_error: AtomicBool::new(false),
131 }
132 }
133 pub fn is_error(&self) -> bool {
134 self.is_error.load(Ordering::SeqCst)
135 }
136 pub fn set_is_error(&self, value: bool) {
137 self.is_error.store(value, Ordering::SeqCst);
138 }
139
140 pub fn last_update(&self) -> u64 {
141 self.last_update.load(Ordering::SeqCst)
142 }
143 pub fn set_last_update(&self, timestamp: u64) {
144 self.last_update.store(timestamp, Ordering::SeqCst);
145 }
146
147 pub fn is_revalidating(&self) -> bool {
148 self.is_revalidating.load(Ordering::SeqCst)
149 }
150
151 pub fn set_is_revalidating(&self, value: bool) {
152 self.is_revalidating.store(value, Ordering::SeqCst);
153 }
154}
155
156impl Default for State {
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
/// The currently applied JWKS, together with its decoded keys and cache policy.
pub struct JwkSetStore {
    /// The key set most recently applied to the store.
    pub jwks: JwkSet,
    // Decoded keys looked up by `get_key(kid)`; entries come from `decode_jwk`, and keys
    // that fail to decode are simply absent.
    decoding_map: HashMap<String, Arc<DecodingInfo>>,
    /// Cache settings currently in effect; replaced when a fetch carries a different policy.
    pub cache_policy: Settings,
    // Validation settings handed to `decode_jwk` whenever keys are decoded.
    validation: ValidationSettings,
}
169
170impl JwkSetStore {
171 pub fn new(jwks: JwkSet, cache_config: Settings, validation: ValidationSettings) -> Self {
172 Self {
173 jwks,
174 decoding_map: HashMap::new(),
175 cache_policy: cache_config,
176 validation,
177 }
178 }
179
180 fn update_jwks(&mut self, new_jwks: JwkSet) {
181 self.jwks = new_jwks;
182 let keys = self
183 .jwks
184 .keys
185 .iter()
186 .filter_map(|i| decode_jwk(i, &self.validation).ok());
187 self.decoding_map.clear();
189 for key in keys {
191 self.decoding_map.insert(key.0, Arc::new(key.1));
192 }
193 }
194
195 pub fn get_key(&self, kid: &str) -> Option<Arc<DecodingInfo>> {
196 self.decoding_map.get(kid).cloned()
197 }
198
199 pub(crate) fn update_fetch(&mut self, fetch: JwkSetFetch) -> UpdateAction {
200 debug!("Decoding JWKS");
201 let time = Instant::now();
202 let new_jwks = fetch.jwks;
203 let cache_policy = fetch.cache_policy.unwrap_or(self.cache_policy);
206 let result = match (self.jwks == new_jwks, self.cache_policy == cache_policy) {
207 (true, true) => {
209 debug!("JWKS Content has not changed since last update");
210 UpdateAction::NoUpdate
211 }
212 (false, true) => {
214 info!("JWKS Content has changed since last update");
215 self.update_jwks(new_jwks);
216 UpdateAction::JwksUpdate
217 }
218 (true, false) => {
220 self.cache_policy = cache_policy;
221 UpdateAction::CacheUpdate(cache_policy)
222 }
223 (false, false) => {
225 info!("cache-control header and JWKS content has changed since last update");
226 self.update_jwks(new_jwks);
227 self.cache_policy = cache_policy;
228 UpdateAction::JwksAndCacheUpdate(cache_policy)
229 }
230 };
231 let elapsed = time.elapsed();
232 debug!("Decoded and parsed JWKS in {:#?}", elapsed);
233 result
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::Settings;
    use std::time::Duration;

    /// Parses a raw `Cache-Control` string the same way `Settings::from_header_val` does.
    fn parse(header: &str) -> Settings {
        let mut settings = Settings::default();
        settings.parse_str(header);
        settings
    }

    #[test]
    fn missing_or_opaque_header_yields_defaults() {
        assert_eq!(Settings::from_header_val(None), Settings::default());
        // Not visible ASCII, so `to_str` fails and nothing is parsed.
        let opaque = reqwest::header::HeaderValue::from_bytes(b"max-age=5\xff").unwrap();
        assert_eq!(Settings::from_header_val(Some(&opaque)), Settings::default());
    }

    #[test]
    fn header_value_is_parsed() {
        let value = reqwest::header::HeaderValue::from_static("max-age=30");
        assert_eq!(
            Settings::from_header_val(Some(&value)).max_age,
            Duration::from_secs(30)
        );
    }

    #[test]
    fn validate_headers() {
        let defaults = Settings::default();
        let secs = Duration::from_secs;
        let week = Settings {
            max_age: secs(604_800),
            ..defaults
        };
        let cases = [
            ("max-age=604800", week),
            ("no-cache", defaults),
            ("max-age=604800, must-revalidate", week),
            ("no-store", defaults),
            ("public, max-age=604800, immutable", week),
            (
                "max-age=604800, stale-while-revalidate=86400",
                Settings {
                    stale_while_revalidate: Some(secs(86_400)),
                    ..week
                },
            ),
            (
                "max-age=604800, stale-if-error=86400",
                Settings {
                    stale_if_error: Some(secs(86_400)),
                    ..week
                },
            ),
        ];
        for (header, expected) in cases {
            assert_eq!(parse(header), expected, "header: {:?}", header);
        }
    }
}