// oidc_verifier/lib.rs

1use core::fmt;
2use std::{
3    sync::Arc,
4    time::{Duration, Instant},
5};
6
7use jsonwebtoken::{DecodingKey, Validation, decode, decode_header};
8use parking_lot::Mutex;
9use serde::Deserialize;
10use serde_json::Value;
11use tokio::task::JoinHandle;
12use tracing::{debug, error};
13
14#[derive(thiserror::Error, Debug)]
15pub enum Error {
16    #[error("{0}")]
17    AccessDenied(String),
18    #[error("{0}")]
19    InvalidData(String),
20    #[error("{0}")]
21    NotReady(String),
22    #[error("{0}")]
23    Failed(String),
24    #[error("{0}")]
25    Unsupported(String),
26    #[error("{0}")]
27    Io(String),
28    #[error("Timed out")]
29    Timeout,
30}
31
32impl Error {
33    pub fn access<T: fmt::Display>(msg: T) -> Self {
34        Error::AccessDenied(msg.to_string())
35    }
36    pub fn invalid_data<T: fmt::Display>(msg: T) -> Self {
37        Error::InvalidData(msg.to_string())
38    }
39    pub fn not_ready<T: fmt::Display>(msg: T) -> Self {
40        Error::NotReady(msg.to_string())
41    }
42    pub fn failed<T: fmt::Display>(msg: T) -> Self {
43        Error::Failed(msg.to_string())
44    }
45    pub fn unsupported<T: fmt::Display>(msg: T) -> Self {
46        Error::Unsupported(msg.to_string())
47    }
48}
49
50impl From<std::io::Error> for Error {
51    fn from(err: std::io::Error) -> Self {
52        Error::Io(err.to_string())
53    }
54}
55
56impl From<tokio::time::error::Elapsed> for Error {
57    fn from(_: tokio::time::error::Elapsed) -> Self {
58        Error::Timeout
59    }
60}
61
62pub type Result<T> = std::result::Result<T, Error>;
63
64async fn fetch_jwks(path: &str) -> Result<JWKeys> {
65    if path.starts_with("http://") || path.starts_with("https://") {
66        let res = reqwest::get(path).await.map_err(Error::access)?;
67        let status = res.status();
68        if !status.is_success() {
69            return Err(Error::access(format!(
70                "Failed to fetch JWKs: HTTP {}",
71                status
72            )));
73        }
74        let data = res.text().await.map_err(Error::access)?;
75        let jwks: JWKeys = serde_json::from_str(&data).map_err(Error::invalid_data)?;
76        return Ok(jwks);
77    }
78    let data = tokio::fs::read_to_string(path).await?;
79    let jwks: JWKeys = serde_json::from_str(&data).map_err(Error::invalid_data)?;
80    Ok(jwks)
81}
82
83async fn safe_fetch_jwks(path: &str, timeout: Duration) -> Result<JWKeys> {
84    tokio::time::timeout(timeout, fetch_jwks(path)).await?
85}
86
87async fn fetcher(
88    jwks: Arc<Mutex<Option<JWKeys>>>,
89    path: String,
90    timeout: Duration,
91    refresh_interval: Duration,
92    retry_delay: Duration,
93    failed_after: Duration,
94) {
95    let mut int = tokio::time::interval(refresh_interval);
96    int.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
97    loop {
98        int.tick().await;
99        let ts = Instant::now();
100        loop {
101            match safe_fetch_jwks(&path, timeout).await {
102                Ok(v) => {
103                    let mut guard = jwks.lock();
104                    *guard = Some(v);
105                    debug!(path, "Successfully updated JWKs");
106                    break;
107                }
108                Err(e) => {
109                    if failed_after > Duration::ZERO && ts.elapsed() > failed_after {
110                        jwks.lock().take();
111                    }
112                    error!(path, error=%e, "Failed to fetch JWKs");
113                    tokio::time::sleep(retry_delay).await;
114                }
115            }
116        }
117    }
118}
119
120pub struct Verifier {
121    sub_field: String,
122    jwk: Arc<Mutex<Option<JWKeys>>>,
123    fetcher: JoinHandle<()>,
124}
125
126impl Verifier {
127    pub fn create(
128        path: &str,
129        timeout: Duration,
130        refresh_interval: Duration,
131        retry_delay: Duration,
132        failed_after: Duration,
133        sub_field: &str,
134    ) -> Self {
135        let jwk = Arc::new(Mutex::new(None));
136        let handle = tokio::spawn({
137            let jwk = jwk.clone();
138            let path = path.to_string();
139            async move {
140                fetcher(
141                    jwk,
142                    path,
143                    timeout,
144                    refresh_interval,
145                    retry_delay,
146                    failed_after,
147                )
148                .await
149            }
150        });
151        Self {
152            sub_field: sub_field.to_string(),
153            jwk,
154            fetcher: handle,
155        }
156    }
157    pub fn verify(&self, token: &str) -> Result<String> {
158        let guard = self.jwk.lock();
159        let jwk = guard
160            .as_ref()
161            .ok_or_else(|| Error::not_ready("JWKs not loaded"))?;
162        check(jwk, token, &self.sub_field).ok_or_else(|| Error::access("JWT verification failed"))
163    }
164}
165
166impl Drop for Verifier {
167    fn drop(&mut self) {
168        self.fetcher.abort();
169    }
170}
171
172#[derive(Deserialize)]
173#[serde(untagged)]
174enum JWKeys {
175    Multiple { keys: Vec<Value> },
176    Single(Value),
177}
178
179fn check(keys: &JWKeys, token: &str, sub_field: &str) -> Option<String> {
180    match keys {
181        JWKeys::Multiple { keys } => {
182            for jwk in keys {
183                match check_jwt(jwk, token, sub_field) {
184                    Ok(sub) => return Some(sub),
185                    Err(e) => {
186                        debug!("JWK {} did not validate token: {}", jwk, e);
187                    }
188                }
189            }
190            None
191        }
192        JWKeys::Single(jwk) => match check_jwt(jwk, token, sub_field) {
193            Ok(sub) => Some(sub),
194            Err(e) => {
195                debug!("JWK did not validate token: {}", e);
196                None
197            }
198        },
199    }
200}
201
202fn check_jwt(jwk: &Value, token: &str, sub_field: &str) -> Result<String> {
203    let header = decode_header(token).map_err(Error::invalid_data)?;
204    let alg = header.alg;
205
206    let decoding_key = match jwk["kty"].as_str() {
207        Some("RSA") => {
208            let n = jwk["n"]
209                .as_str()
210                .ok_or_else(|| Error::invalid_data("JWK missing n"))?;
211            let e = jwk["e"]
212                .as_str()
213                .ok_or_else(|| Error::invalid_data("JWK missing e"))?;
214            DecodingKey::from_rsa_components(n, e).map_err(Error::failed)?
215        }
216        Some("EC") => {
217            let crv = jwk["crv"]
218                .as_str()
219                .ok_or_else(|| Error::invalid_data("JWK missing crv"))?;
220            let x = jwk["x"]
221                .as_str()
222                .ok_or_else(|| Error::invalid_data("JWK missing x"))?;
223            let y = jwk["y"]
224                .as_str()
225                .ok_or_else(|| Error::invalid_data("JWK missing y"))?;
226            match crv {
227                "P-256" | "P-384" => {
228                    DecodingKey::from_ec_components(x, y).map_err(Error::failed)?
229                }
230                _ => return Err(Error::failed(format!("Unsupported EC curve: {}", crv))),
231            }
232        }
233        Some("oct") => {
234            let k = jwk["k"]
235                .as_str()
236                .ok_or_else(|| Error::invalid_data("JWK missing k"))?;
237            DecodingKey::from_base64_secret(k).map_err(Error::failed)?
238        }
239        Some("OKP") => {
240            let crv = jwk["crv"]
241                .as_str()
242                .ok_or_else(|| Error::invalid_data("JWK missing crv"))?;
243            let x = jwk["x"]
244                .as_str()
245                .ok_or_else(|| Error::invalid_data("JWK missing x"))?;
246            match crv {
247                "Ed25519" => DecodingKey::from_ed_der(
248                    base64_url::decode(x)
249                        .map_err(Error::invalid_data)?
250                        .as_slice(),
251                ),
252                _ => {
253                    return Err(Error::invalid_data(format!(
254                        "Unsupported OKP curve: {}",
255                        crv
256                    )));
257                }
258            }
259        }
260        other => {
261            return Err(Error::unsupported(format!(
262                "Unsupported key type: {:?}",
263                other
264            )));
265        }
266    };
267
268    let validation = Validation::new(alg);
269
270    let claims = decode::<Value>(token, &decoding_key, &validation)
271        .map_err(Error::access)?
272        .claims;
273    let Value::Object(obj) = claims else {
274        return Err(Error::invalid_data("Claims is not an object"));
275    };
276    let sub_value = obj
277        .get(sub_field)
278        .ok_or_else(|| Error::invalid_data(format!("Missing field: {}", sub_field)))?;
279    let sub = String::deserialize(sub_value)
280        .map_err(|_| Error::invalid_data(format!("Field {} is not a string", sub_field)))?;
281    Ok(sub)
282}