1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
use std::sync::Arc;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use std::collections::HashMap;

use serde::{Deserialize};
use serde_json::value::Value;
use serde_json;

use reqwest::{self, Response};

use jsonwebtokens as jwt;
use jwt::{Algorithm, AlgorithmID, Verifier, VerifierBuilder};

mod error;
pub use error::{Error, ErrorDetails};

/// A single key entry deserialized from the remote JWKS document.
///
/// Only the members this file actually reads are declared here.
#[derive(Debug, Deserialize, Clone)]
struct RSAKey {
    /// Key ID; used as the lookup key for the cached `Algorithm`
    kid: String,
    /// Algorithm name; entries whose `alg` is not `"RS256"` are skipped when caching
    alg: String,
    /// Passed unmodified as the `n` argument of `Algorithm::new_rsa_n_e_b64_verifier`
    /// (presumably the base64-encoded RSA modulus, going by that function's name)
    n: String,
    /// Passed unmodified as the `e` argument of `Algorithm::new_rsa_n_e_b64_verifier`
    /// (presumably the base64-encoded RSA public exponent, going by that function's name)
    e: String,
}

/// The top-level shape of the downloaded JWKS JSON document.
#[derive(Debug, Deserialize)]
struct JwkSet {
    /// Every key entry listed under the document's `keys` member
    keys: Vec<RSAKey>,
}

/// Mutable state of a `KeySet`, kept behind its `RwLock`.
#[derive(Debug, Clone)]
struct Cache {
    /// When the remote JWKS was last downloaded and decoded, or `None` if
    /// that has never happened
    last_jwks_get_time: Option<Instant>,
    /// Cached verification algorithms, indexed by key ID (`kid`)
    algorithms: HashMap<String, Arc<Algorithm>>,
}

/// Abstracts a remote Amazon Cognito JWKS key set
///
/// The key set represents the public key information for one or more RSA keys that
/// Amazon Cognito uses to sign tokens. To verify a token from Cognito the token's
/// `kid` is used to look up the corresponding public key from this set which can
/// be used to verify the token's signature.
///
/// Building on top of the [Verifier](https://docs.rs/jsonwebtokens/1.0.0-alpha.8/jsonwebtokens/struct.Verifier.html)
/// API from [jsonwebtokens](https://crates.io/crates/jsonwebtokens), a KeySet provides some
/// helpers for building a [Verifier](https://docs.rs/jsonwebtokens/1.0.0-alpha.8/jsonwebtokens/struct.Verifier.html)
/// for Cognito Access token claims or ID token claims - referencing the region and
/// pool details used to construct the keyset.
///
/// Example:
/// ```no_run
/// # use jsonwebtokens_cognito::KeySet;
/// # use async_std::prelude::*;
/// # #[async_std::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let keyset = KeySet::new("eu-west-1", "my-user-pool-id")?;
/// let verifier = keyset.new_id_token_verifier(&["client-id-0", "client-id-1"])
///     .string_equals("custom_claim0", "value")
///     .string_equals("custom_claim1", "value")
///     .build()?;
/// # let token = "header.payload.signature";
/// let claims = keyset.verify(token, &verifier).await?;
/// # Ok(())
/// # }
/// ```
///
/// Internally a KeySet holds a cache of Algorithm structs (see the jsonwebtokens
/// API for further details) where each Algorithm represents one RSA public key.
///
/// Although `keyset.verify()` can be very convenient, if you need to avoid network
/// I/O when verifying tokens it's also possible to prefetch the remote JWKS key
/// set ahead of time and `try_verify()` can be used to verify a token without any
/// network I/O. This can be useful if you don't have an async context when
/// verifying tokens.
///
/// ```no_run
/// # use jsonwebtokens_cognito::KeySet;
/// # use async_std::prelude::*;
/// # #[async_std::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let keyset = KeySet::new("eu-west-1", "my-user-pool-id")?;
/// keyset.prefetch_jwks().await?;
/// let verifier = keyset.new_id_token_verifier(&["client-id-0", "client-id-1"])
///     .string_equals("custom_claim0", "value")
///     .string_equals("custom_claim1", "value")
///     .build()?;
/// # let token = "header.payload.signature";
/// let claims = keyset.try_verify(token, &verifier)?;
/// # Ok(())
/// # }
/// ```
///
/// It's also possible to perform cache lookups directly to access an Algorithm if
/// you need to use the jsonwebtokens API directly:
/// ```no_run
/// # use jsonwebtokens_cognito::KeySet;
/// # use jsonwebtokens as jwt;
/// # use serde_json::value::Value;
/// # use async_std::prelude::*;
/// # #[async_std::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let token = "header.payload.signature";
/// let keyset = KeySet::new("eu-west-1", "my-user-pool-id")?;
/// keyset.prefetch_jwks().await?;
///
/// let verifier = keyset.new_id_token_verifier(&["client-id-0", "client-id-1"])
///     .string_equals("custom_claim0", "value")
///     .string_equals("custom_claim1", "value")
///     .build()?;
///
/// let header = jwt::raw::decode_header_only(token)?;
/// if let Some(Value::String(kid)) = header.get("kid") {
///     let alg = keyset.try_cache_lookup_algorithm(kid)?;
///     let claims = verifier.verify(token, &alg)?;
///
///     // Whoop!
/// } else {
///     Err(jwt::error::Error::MalformedToken(jwt::error::ErrorDetails::new("Missing kid")))?;
/// };
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct KeySet {
    /// AWS region the key set was constructed for
    region: String,
    /// Cognito User Pool ID the key set was constructed for
    pool_id: String,
    /// URL the remote JWKS document is downloaded from
    jwks_url: String,
    /// Expected value of the `iss` claim for tokens issued by this pool
    iss: String,
    /// Cached algorithms plus the time of the last JWKS download. Held in an
    /// `Arc`, so clones of a `KeySet` share the same cache.
    cache: Arc<RwLock<Cache>>,
    /// Minimum time `verify()` waits between JWKS downloads triggered by a
    /// cache miss
    min_jwks_fetch_interval: Duration,
}

impl KeySet {

    /// Creates a `KeySet` for the remote Json Web Key Set that Amazon publishes
    /// for the given AWS region and Cognito User Pool ID.
    ///
    /// The JWKS download URL and the expected token issuer (`iss`) are both
    /// derived from `region` and `pool_id`. No network I/O is performed here:
    /// the cache starts out empty and the minimum interval between JWKS
    /// fetches defaults to 60 seconds.
    ///
    /// The current implementation always returns `Ok`.
    pub fn new(
        region: impl Into<String>,
        pool_id: impl Into<String>,
    ) -> Result<Self, Error> {
        let region: String = region.into();
        let pool_id: String = pool_id.into();

        let iss = format!("https://cognito-idp.{}.amazonaws.com/{}", region, pool_id);
        let jwks_url = format!(
            "https://cognito-idp.{}.amazonaws.com/{}/.well-known/jwks.json",
            region, pool_id
        );

        // Nothing is cached until the key set has been fetched at least once
        let empty_cache = Cache {
            last_jwks_get_time: None,
            algorithms: HashMap::new(),
        };

        Ok(KeySet {
            region,
            pool_id,
            jwks_url,
            iss,
            cache: Arc::new(RwLock::new(empty_cache)),
            min_jwks_fetch_interval: Duration::from_secs(60),
        })
    }

    /// Creates a `VerifierBuilder` pre-configured for AWS Cognito ID tokens.
    ///
    /// The builder requires that `iss` equals this key set's issuer, that
    /// `aud` equals one of `client_ids` and that `token_use` equals `"id"`.
    /// Checks for other custom claims can be added before calling `.build()`
    /// to create a `Verifier`.
    pub fn new_id_token_verifier(&self, client_ids: &[&str]) -> VerifierBuilder {
        let mut id_token_builder = Verifier::create();

        id_token_builder.string_equals("iss", &self.iss);
        id_token_builder.string_equals_one_of("aud", client_ids);
        id_token_builder.string_equals("token_use", "id");

        id_token_builder
    }

    /// Sets the minimum time between attempts to fetch the remote JWKS key set
    ///
    /// By default this is one minute, to throttle requests in case there is a
    /// transient network problem.
    ///
    /// `verify()` consults this interval before it downloads the key set in
    /// response to a cache miss.
    pub fn set_min_jwks_fetch_interval(&mut self, interval: Duration) {
        self.min_jwks_fetch_interval = interval;
    }

    /// Gets the minimum time between attempts to fetch the remote JWKS key set
    ///
    /// This only reads a `Copy` field, so a shared reference is sufficient;
    /// callers do not need mutable access to the `KeySet` to query it.
    pub fn min_jwks_fetch_interval(&self) -> Duration {
        self.min_jwks_fetch_interval
    }

    /// Creates a `VerifierBuilder` pre-configured for AWS Cognito access tokens.
    ///
    /// The builder requires that `iss` equals this key set's issuer, that
    /// `client_id` equals one of `client_ids` and that `token_use` equals
    /// `"access"`. Checks for other custom claims can be added before calling
    /// `.build()` to create a `Verifier`.
    pub fn new_access_token_verifier(&self, client_ids: &[&str]) -> VerifierBuilder {
        let mut access_token_builder = Verifier::create();

        access_token_builder.string_equals("iss", &self.iss);
        access_token_builder.string_equals_one_of("client_id", client_ids);
        access_token_builder.string_equals("token_use", "access");

        access_token_builder
    }

    /// Looks for a cached Algorithm based on the given JWT token's key ID ('kid')
    ///
    /// This is a lower-level API in case you need to use the jsonwebtokens
    /// Algorithm API directly.
    ///
    /// Returns an `Arc<Algorithm>` corresponding to the give key ID (`kid`) or returns
    /// a `CacheMiss` error if the Algorithm / key is not cached.
    pub fn try_cache_lookup_algorithm(&self, kid: &str) -> Result<Arc<Algorithm>, Error> {

        // We unwrap, because poisoning would imply something else had gone
        // badly wrong (there should be nothing that can cause a panic while
        // holding the cache's lock)
        let readable_cache = self.cache.read().unwrap();

        let a = readable_cache.algorithms.get(kid);
        if let Some(alg) = a {
            return Ok(alg.clone());
        } else {
            return Err(Error::CacheMiss(readable_cache.last_jwks_get_time));
        }
    }

    /// Verify a token's signature and its claims
    pub async fn verify(
        &self,
        token: &str,
        verifier: &Verifier
    ) -> Result<serde_json::value::Value, Error> {

        let header = jwt::raw::decode_header_only(token)?;

        let kid = match header.get("kid") {
            Some(Value::String(kid)) => kid,
            _ => return Err(Error::NoKeyID()),
        };

        let algorithm = match self.try_cache_lookup_algorithm(kid) {
            Err(Error::CacheMiss(last_update_time)) => {
                let duration = match last_update_time {
                    Some(last_jwks_get_time) => Instant::now().duration_since(last_jwks_get_time),
                    None => self.min_jwks_fetch_interval
                };

                if duration < self.min_jwks_fetch_interval {
                    return Err(Error::NetworkError(ErrorDetails::new("Key set is currently unreachable (throttled)")))
                }

                self.prefetch_jwks().await?;
                self.try_cache_lookup_algorithm(kid)?
            },
            Err(e) => {
                // try_cache_lookup_algorithm shouldn't return any other kind of error...
                unreachable!("Unexpected error looking up JWT Algorithm for key ID: {:?}", e);
            }
            Ok(alg) => alg
        };

        let claims = verifier.verify(token, &algorithm)?;
        Ok(claims)
    }

    /// Try and verify a token's signature and claims without performing any network I/O
    ///
    /// To be able to verify a token in a synchronous context (but without blocking) this
    /// API lets you try and verify a token, and if the required Algorithm / key has not
    /// been cached yet then it will return a `CacheMiss` error.
    pub fn try_verify(
        &self,
        token: &str,
        verifier: &Verifier
    ) -> Result<serde_json::value::Value, Error> {

        let header = jwt::raw::decode_header_only(token)?;

        let kid = match header.get("kid") {
            Some(Value::String(kid)) => kid,
            _ => return Err(Error::NoKeyID()),
        };

        let alg = self.try_cache_lookup_algorithm(kid)?;
        let claims = verifier.verify(token, &alg)?;
        Ok(claims)
    }

    /// Ensure the remote Json Web Key Set is downloaded and cached
    ///
    /// Downloads the JWKS document from the key set's URL, builds an RS256
    /// verification Algorithm for every `RS256` key it lists and stores each
    /// one in the cache under its key ID (`kid`). Keys with any other `alg`
    /// are skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails, if the server answers with a
    /// non-success HTTP status, if the body cannot be decoded as a key set,
    /// or if an Algorithm cannot be constructed from one of the keys.
    pub async fn prefetch_jwks(&self) -> Result<(), Error> {
        // Reject non-2xx responses up front so that an HTTP failure is
        // reported as such rather than as a confusing JSON decode error.
        let resp: Response = reqwest::get(&self.jwks_url).await?.error_for_status()?;
        let jwks: JwkSet = resp.json().await?;

        // We unwrap, because poisoning would imply something else had gone
        // badly wrong (there should be nothing that can cause a panic while
        // holding the cache's lock).
        //
        // The lock is only taken once all network I/O has completed, so the
        // guard is never held across an `.await`.
        let mut writeable_cache = self.cache.write().unwrap();

        writeable_cache.last_jwks_get_time = Some(Instant::now());

        // NOTE(review): keys are only ever added or replaced here; a key that
        // disappears from the remote set stays cached. Confirm that is intended.
        for key in jwks.keys {
            // For now we assume AWS Cognito only ever uses RS256 keys
            if key.alg != "RS256" {
                continue;
            }
            let mut algorithm = Algorithm::new_rsa_n_e_b64_verifier(AlgorithmID::RS256, &key.n, &key.e)?;
            // By associating a kid here we will essentially be double checking
            // that we only verify a token with the key matching its associated kid
            // (once by us and jsonwebtokens will also check too)
            algorithm.set_kid(&key.kid);
            writeable_cache.algorithms.insert(key.kid, Arc::new(algorithm));
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    // TODO: placeholder, no unit tests have been written for this module yet
}