use core::fmt;
use std::{
    sync::Arc,
    time::{Duration, Instant},
};

use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use parking_lot::Mutex;
use serde::Deserialize;
use serde_json::Value;
use tokio::task::JoinHandle;
use tracing::{debug, error};
13
14#[derive(thiserror::Error, Debug)]
15pub enum Error {
16 #[error("{0}")]
17 AccessDenied(String),
18 #[error("{0}")]
19 InvalidData(String),
20 #[error("{0}")]
21 NotReady(String),
22 #[error("{0}")]
23 Failed(String),
24 #[error("{0}")]
25 Unsupported(String),
26 #[error("{0}")]
27 Io(String),
28 #[error("Timed out")]
29 Timeout,
30}
31
32impl Error {
33 pub fn access<T: fmt::Display>(msg: T) -> Self {
34 Error::AccessDenied(msg.to_string())
35 }
36 pub fn invalid_data<T: fmt::Display>(msg: T) -> Self {
37 Error::InvalidData(msg.to_string())
38 }
39 pub fn not_ready<T: fmt::Display>(msg: T) -> Self {
40 Error::NotReady(msg.to_string())
41 }
42 pub fn failed<T: fmt::Display>(msg: T) -> Self {
43 Error::Failed(msg.to_string())
44 }
45 pub fn unsupported<T: fmt::Display>(msg: T) -> Self {
46 Error::Unsupported(msg.to_string())
47 }
48}
49
50impl From<std::io::Error> for Error {
51 fn from(err: std::io::Error) -> Self {
52 Error::Io(err.to_string())
53 }
54}
55
56impl From<tokio::time::error::Elapsed> for Error {
57 fn from(_: tokio::time::error::Elapsed) -> Self {
58 Error::Timeout
59 }
60}
61
62pub type Result<T> = std::result::Result<T, Error>;
63
64async fn fetch_jwks(path: &str) -> Result<JWKeys> {
65 if path.starts_with("http://") || path.starts_with("https://") {
66 let res = reqwest::get(path).await.map_err(Error::access)?;
67 let status = res.status();
68 if !status.is_success() {
69 return Err(Error::access(format!(
70 "Failed to fetch JWKs: HTTP {}",
71 status
72 )));
73 }
74 let data = res.text().await.map_err(Error::access)?;
75 let jwks: JWKeys = serde_json::from_str(&data).map_err(Error::invalid_data)?;
76 return Ok(jwks);
77 }
78 let data = tokio::fs::read_to_string(path).await?;
79 let jwks: JWKeys = serde_json::from_str(&data).map_err(Error::invalid_data)?;
80 Ok(jwks)
81}
82
83async fn safe_fetch_jwks(path: &str, timeout: Duration) -> Result<JWKeys> {
84 tokio::time::timeout(timeout, fetch_jwks(path)).await?
85}
86
87async fn fetcher(
88 jwks: Arc<Mutex<Option<JWKeys>>>,
89 path: String,
90 timeout: Duration,
91 refresh_interval: Duration,
92 retry_delay: Duration,
93 failed_after: Duration,
94) {
95 let mut int = tokio::time::interval(refresh_interval);
96 int.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
97 loop {
98 int.tick().await;
99 let ts = Instant::now();
100 loop {
101 match safe_fetch_jwks(&path, timeout).await {
102 Ok(v) => {
103 let mut guard = jwks.lock();
104 *guard = Some(v);
105 debug!(path, "Successfully updated JWKs");
106 break;
107 }
108 Err(e) => {
109 if failed_after > Duration::ZERO && ts.elapsed() > failed_after {
110 jwks.lock().take();
111 }
112 error!(path, error=%e, "Failed to fetch JWKs");
113 tokio::time::sleep(retry_delay).await;
114 }
115 }
116 }
117 }
118}
119
120pub struct Verifier {
121 sub_field: String,
122 jwk: Arc<Mutex<Option<JWKeys>>>,
123 fetcher: JoinHandle<()>,
124}
125
126impl Verifier {
127 pub fn create(
128 path: &str,
129 timeout: Duration,
130 refresh_interval: Duration,
131 retry_delay: Duration,
132 failed_after: Duration,
133 sub_field: &str,
134 ) -> Self {
135 let jwk = Arc::new(Mutex::new(None));
136 let handle = tokio::spawn({
137 let jwk = jwk.clone();
138 let path = path.to_string();
139 async move {
140 fetcher(
141 jwk,
142 path,
143 timeout,
144 refresh_interval,
145 retry_delay,
146 failed_after,
147 )
148 .await
149 }
150 });
151 Self {
152 sub_field: sub_field.to_string(),
153 jwk,
154 fetcher: handle,
155 }
156 }
157 pub fn verify(&self, token: &str) -> Result<String> {
158 let guard = self.jwk.lock();
159 let jwk = guard
160 .as_ref()
161 .ok_or_else(|| Error::not_ready("JWKs not loaded"))?;
162 check(jwk, token, &self.sub_field).ok_or_else(|| Error::access("JWT verification failed"))
163 }
164}
165
166impl Drop for Verifier {
167 fn drop(&mut self) {
168 self.fetcher.abort();
169 }
170}
171
172#[derive(Deserialize)]
173#[serde(untagged)]
174enum JWKeys {
175 Multiple { keys: Vec<Value> },
176 Single(Value),
177}
178
179fn check(keys: &JWKeys, token: &str, sub_field: &str) -> Option<String> {
180 match keys {
181 JWKeys::Multiple { keys } => {
182 for jwk in keys {
183 match check_jwt(jwk, token, sub_field) {
184 Ok(sub) => return Some(sub),
185 Err(e) => {
186 debug!("JWK {} did not validate token: {}", jwk, e);
187 }
188 }
189 }
190 None
191 }
192 JWKeys::Single(jwk) => match check_jwt(jwk, token, sub_field) {
193 Ok(sub) => Some(sub),
194 Err(e) => {
195 debug!("JWK did not validate token: {}", e);
196 None
197 }
198 },
199 }
200}
201
202fn check_jwt(jwk: &Value, token: &str, sub_field: &str) -> Result<String> {
203 let header = decode_header(token).map_err(Error::invalid_data)?;
204 let alg = header.alg;
205
206 let decoding_key = match jwk["kty"].as_str() {
207 Some("RSA") => {
208 let n = jwk["n"]
209 .as_str()
210 .ok_or_else(|| Error::invalid_data("JWK missing n"))?;
211 let e = jwk["e"]
212 .as_str()
213 .ok_or_else(|| Error::invalid_data("JWK missing e"))?;
214 DecodingKey::from_rsa_components(n, e).map_err(Error::failed)?
215 }
216 Some("EC") => {
217 let crv = jwk["crv"]
218 .as_str()
219 .ok_or_else(|| Error::invalid_data("JWK missing crv"))?;
220 let x = jwk["x"]
221 .as_str()
222 .ok_or_else(|| Error::invalid_data("JWK missing x"))?;
223 let y = jwk["y"]
224 .as_str()
225 .ok_or_else(|| Error::invalid_data("JWK missing y"))?;
226 match crv {
227 "P-256" | "P-384" => {
228 DecodingKey::from_ec_components(x, y).map_err(Error::failed)?
229 }
230 _ => return Err(Error::failed(format!("Unsupported EC curve: {}", crv))),
231 }
232 }
233 Some("oct") => {
234 let k = jwk["k"]
235 .as_str()
236 .ok_or_else(|| Error::invalid_data("JWK missing k"))?;
237 DecodingKey::from_base64_secret(k).map_err(Error::failed)?
238 }
239 Some("OKP") => {
240 let crv = jwk["crv"]
241 .as_str()
242 .ok_or_else(|| Error::invalid_data("JWK missing crv"))?;
243 let x = jwk["x"]
244 .as_str()
245 .ok_or_else(|| Error::invalid_data("JWK missing x"))?;
246 match crv {
247 "Ed25519" => DecodingKey::from_ed_der(
248 base64_url::decode(x)
249 .map_err(Error::invalid_data)?
250 .as_slice(),
251 ),
252 _ => {
253 return Err(Error::invalid_data(format!(
254 "Unsupported OKP curve: {}",
255 crv
256 )));
257 }
258 }
259 }
260 other => {
261 return Err(Error::unsupported(format!(
262 "Unsupported key type: {:?}",
263 other
264 )));
265 }
266 };
267
268 let validation = Validation::new(alg);
269
270 let claims = decode::<Value>(token, &decoding_key, &validation)
271 .map_err(Error::access)?
272 .claims;
273 let Value::Object(obj) = claims else {
274 return Err(Error::invalid_data("Claims is not an object"));
275 };
276 let sub_value = obj
277 .get(sub_field)
278 .ok_or_else(|| Error::invalid_data(format!("Missing field: {}", sub_field)))?;
279 let sub = String::deserialize(sub_value)
280 .map_err(|_| Error::invalid_data(format!("Field {} is not a string", sub_field)))?;
281 Ok(sub)
282}