jwks_rs/
lib.rs

// Mostly adapted from https://github.com/cdriehuys/axum-jwks/blob/main/axum-jwks/src/jwks.rs
2
3use std::collections::HashMap;
4
5use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
6use jsonwebtoken::{
7    jwk::{self},
8    DecodingKey,
9};
10use serde::Deserialize;
11use thiserror::Error;
12
13/// A container for a set of JWT decoding keys.
14///
15/// The container can be used to validate any JWT that identifies a known key
16/// through the `kid` attribute in the token's header.
17#[derive(Clone)]
18#[allow(dead_code)]
19pub struct Jwks {
20    pub keys: HashMap<String, Jwk>,
21}
22
23#[derive(Deserialize)]
24struct OIDCConfig {
25    jwks_uri: String,
26}
27
28impl Jwks {
29    /// # Arguments
30    /// * `oidc_url` - The url with OpenID configuration, e.g. https://accounts.google.com/.well-known/openid-configuration
31    pub async fn from_oidc_url(oidc_url: impl Into<String>) -> Result<Self, JwksError> {
32        Self::from_oidc_url_with_client(&reqwest::Client::default(), oidc_url).await
33    }
34
35    /// A version of [`from_oidc_url`][Self::from_oidc_url] that allows for
36    /// passing in a custom [`Client`][reqwest::Client].
37    pub async fn from_oidc_url_with_client(
38        client: &reqwest::Client,
39        oidc_url: impl Into<String>,
40    ) -> Result<Self, JwksError> {
41        let oidc_config = client
42            .get(oidc_url.into())
43            .send()
44            .await?
45            .json::<OIDCConfig>()
46            .await?;
47        let jwks_uri = oidc_config.jwks_uri;
48
49        Self::from_jwks_url_with_client(&reqwest::Client::default(), &jwks_uri).await
50    }
51
52    /// # Arguments
53    /// * `jwks_url` - The url which JWKS info is pulled from, e.g. https://www.googleapis.com/oauth2/v3/certs
54    pub async fn from_jwks_url(jwks_url: impl Into<String>) -> Result<Self, JwksError> {
55        Self::from_jwks_url_with_client(&reqwest::Client::default(), jwks_url.into()).await
56    }
57
58    /// A version of [`from_jwks_url`][Self::from_jwks_url] that allows for
59    /// passing in a custom [`Client`][reqwest::Client].
60    pub async fn from_jwks_url_with_client(
61        client: &reqwest::Client,
62        jwks_url: impl Into<String>,
63    ) -> Result<Self, JwksError> {
64        let jwks: jwk::JwkSet = client.get(jwks_url.into()).send().await?.json().await?;
65
66        let mut keys = HashMap::new();
67        for jwk in jwks.keys {
68            let kid = jwk.common.key_id.ok_or(JwkError::MissingKeyId)?;
69
70            match &jwk.algorithm {
71                jwk::AlgorithmParameters::RSA(params) => {
72                    let decoding_key = DecodingKey::from_rsa_components(&params.n, &params.e)
73                        .map_err(|err| JwkError::DecodingError {
74                            key_id: kid.clone(),
75                            error: err,
76                        })?;
77
78                    keys.insert(
79                        kid,
80                        Jwk {
81                            decoding_key: decoding_key,
82                        },
83                    );
84                }
85                jwk::AlgorithmParameters::EllipticCurve(params) => {
86                    let decoding_key = DecodingKey::from_ec_components(&params.x, &params.y)
87                        .map_err(|err| JwkError::DecodingError {
88                            key_id: kid.clone(),
89                            error: err,
90                        })?;
91
92                    keys.insert(
93                        kid,
94                        Jwk {
95                            decoding_key: decoding_key,
96                        },
97                    );
98                }
99                jwk::AlgorithmParameters::OctetKeyPair(params) => {
100                    let decoding_key =
101                        DecodingKey::from_ed_components(&params.x).map_err(|err| {
102                            JwkError::DecodingError {
103                                key_id: kid.clone(),
104                                error: err,
105                            }
106                        })?;
107
108                    keys.insert(
109                        kid,
110                        Jwk {
111                            decoding_key: decoding_key,
112                        },
113                    );
114                }
115                jwk::AlgorithmParameters::OctetKey(params) => {
116                    // same as https://github.com/Keats/jsonwebtoken/blob/master/src/serialization.rs#L11
117                    let base64_decoded = URL_SAFE_NO_PAD.decode(&params.value).map_err(|err| {
118                        JwkError::DecodingError {
119                            key_id: kid.clone(),
120                            error: err.into(),
121                        }
122                    })?;
123                    let decoding_key = DecodingKey::from_secret(&base64_decoded);
124                    keys.insert(
125                        kid,
126                        Jwk {
127                            decoding_key: decoding_key,
128                        },
129                    );
130                }
131            }
132        }
133
134        Ok(Self { keys })
135    }
136}
137
138#[derive(Clone)]
139#[allow(dead_code)]
140pub struct Jwk {
141    pub decoding_key: DecodingKey,
142}
143
144/// An error with the overall set of JSON Web Keys.
145#[derive(Debug, Error)]
146pub enum JwksError {
147    /// There was an error fetching the OIDC or JWKS config from
148    /// the specified url.
149    #[error("could not fetch config from authority: {0}")]
150    FetchError(#[from] reqwest::Error),
151
152    /// An error with an individual key caused the processing of the JWKS to
153    /// fail.
154    #[error("there was an error with an individual key: {0}")]
155    KeyError(#[from] JwkError),
156}
157
158/// An error with a specific key from a JWKS.
159#[derive(Debug, Error)]
160pub enum JwkError {
161    /// There was an error constructing the decoding key from the RSA components
162    /// provided by the key.
163    #[error("could not construct a decoding key for {key_id:?}: {error:?}")]
164    DecodingError {
165        key_id: String,
166        error: jsonwebtoken::errors::Error,
167    },
168
169    /// The key does not specify an algorithm to use.
170    #[error("the key {key_id:?} does not specify an algorithm")]
171    MissingAlgorithm { key_id: String },
172
173    /// The key is missing the `kid` attribute.
174    #[error("the key is missing the `kid` attribute")]
175    MissingKeyId,
176}
177
#[cfg(test)]
mod tests {
    use httpmock::prelude::*;
    use serde_json::{json, Value};

    use super::*;

    const OIDC_CONFIG_PATH: &str = "/.well-known/openid-configuration";
    const JWKS_PATH: &str = "/oauth2/v3/certs";
    const KEY_ONE_ID: &str = "91413cf4fa0cb92a3c3f5a054509132c47660937";
    const KEY_TWO_ID: &str = "1f40f0a8ef3d880978dc82f25c3ec317c6a5b781";

    /// Two RSA signing keys, copied from <https://www.googleapis.com/oauth2/v3/certs>.
    fn google_jwks() -> Value {
        json!({
          "keys": [
            {
              "use": "sig",
              "n": "jb1Ps3fdt0oPYPbQlfZqKkCXrM1qJ5EkfBHSMrPXPzh9QLwa43WCLEdrTcf5vI8cNwbgSxDlCDS2BzHQC0hYPwFkJaD6y6NIIcwdSMcKlQPwk4-sqJbz55_gyUWjifcpXXKbXDdnd2QzSE2YipareOPJaBs3Ybuvf_EePnYoKEhXNeGm_T3546A56uOV2mNEe6e-RaIa76i8kcx_8JP3FjqxZSWRrmGYwZJhTGbeY5pfOS6v_EYpA4Up1kZANWReeC3mgh3O78f5nKEDxwPf99bIQ22fIC2779HbfzO-ybqR_EJ0zv8LlqfT7dMjZs25LH8Jw5wGWjP_9efP8emTOw",
              "kty": "RSA",
              "alg": "RS256",
              "e": "AQAB",
              "kid": KEY_ONE_ID
            },
            {
              "n": "tgkwz0K80MycaI2Dz_jHkErJ_IHUPTlx4LR_6wltAHQW_ZwhMzINNH8vbWo8P5F2YLDiIbuslF9y7Q3izsPX3XWQyt6LI8ZT4gmGXQBumYMKx2VtbmTYIysKY8AY7x5UCDO-oaAcBuKQvWc5E31kXm6d6vfaEZjrMc_KT3DsFdN0LcAkB-Q9oYcVl7YEgAN849ROKUs6onf7eukj1PHwDzIBgA9AExJaKen0wITvxQv3H_BRXB7m6hFkLbK5Jo18gl3UxJ7Em29peEwi8Psn7MuI7CwhFNchKhjZM9eaMX27tpDPqR15-I6CA5Zf94rabUGWYph5cFXKWPPr8dskQQ",
              "alg": "RS256",
              "use": "sig",
              "kid": KEY_TWO_ID,
              "e": "AQAB",
              "kty": "RSA"
            }
          ]
        })
    }

    /// An OpenID discovery document copied from
    /// <https://accounts.google.com/.well-known/openid-configuration>, with
    /// `jwks_uri` pointed at `jwks_url`.
    fn google_oidc_config(jwks_url: &str) -> Value {
        json!({
         "issuer": "https://accounts.google.com",
         "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth",
         "device_authorization_endpoint": "https://oauth2.googleapis.com/device/code",
         "token_endpoint": "https://oauth2.googleapis.com/token",
         "userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
         "revocation_endpoint": "https://oauth2.googleapis.com/revoke",
         "jwks_uri": jwks_url,
         "response_types_supported": [
          "code",
          "token",
          "id_token",
          "code token",
          "code id_token",
          "token id_token",
          "code token id_token",
          "none"
         ],
         "subject_types_supported": [
          "public"
         ],
         "id_token_signing_alg_values_supported": [
          "RS256"
         ],
         "scopes_supported": [
          "openid",
          "email",
          "profile"
         ],
         "token_endpoint_auth_methods_supported": [
          "client_secret_post",
          "client_secret_basic"
         ],
         "claims_supported": [
          "aud",
          "email",
          "email_verified",
          "exp",
          "family_name",
          "given_name",
          "iat",
          "iss",
          "locale",
          "name",
          "picture",
          "sub"
         ],
         "code_challenge_methods_supported": [
          "plain",
          "S256"
         ],
         "grant_types_supported": [
          "authorization_code",
          "refresh_token",
          "urn:ietf:params:oauth:grant-type:device_code",
          "urn:ietf:params:oauth:grant-type:jwt-bearer"
         ]
        })
    }

    /// Registers a mock on `server` that answers `GET path` with `body` as JSON.
    fn serve_json(server: &MockServer, path: &str, body: &Value) {
        let _ = server.mock(|when, then| {
            when.method(GET).path(path);
            then.status(200)
                .header("content-type", "application/json")
                .body(body.to_string());
        });
    }

    /// Asserts that `jwks` contains exactly the two keys from [`google_jwks`].
    fn assert_has_google_keys(jwks: &Jwks) {
        assert_eq!(jwks.keys.len(), 2);
        assert!(jwks.keys.contains_key(KEY_ONE_ID), "key one should be found");
        assert!(jwks.keys.contains_key(KEY_TWO_ID), "key two should be found");
    }

    #[tokio::test]
    async fn can_fetch_and_parse_jwks_from_jwks_url() {
        let server = MockServer::start();
        serve_json(&server, JWKS_PATH, &google_jwks());

        let jwks = Jwks::from_jwks_url(server.url(JWKS_PATH)).await.unwrap();

        assert_has_google_keys(&jwks);
    }

    #[tokio::test]
    async fn can_fetch_and_parse_jwks_from_oidc_config_url() {
        let server = MockServer::start();
        let jwks_url = server.url(JWKS_PATH);
        serve_json(&server, OIDC_CONFIG_PATH, &google_oidc_config(&jwks_url));
        serve_json(&server, JWKS_PATH, &google_jwks());

        let jwks = Jwks::from_oidc_url(server.url(OIDC_CONFIG_PATH)).await.unwrap();

        assert_has_google_keys(&jwks);
    }
}