ibmcloud_iam/
jwt.rs

1// Copyright 2022 Mathew Odden <mathewrodden@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::time::{Duration, SystemTime};
17
18use jwt::{Header, PKeyWithDigest, RegisteredClaims, Token as JwtToken, VerifyWithKey};
19use openssl::bn::BigNum;
20use openssl::hash::MessageDigest;
21use openssl::pkey::PKey;
22use openssl::rsa::Rsa;
23use serde::{Deserialize, Serialize};
24use serde_json;
25use serde_with::{
26    base64::{Base64, UrlSafe},
27    serde_as,
28};
29
30use crate::token::Token;
31
/// Error type used throughout this module; any `std::error::Error` can be boxed into it.
type Error = Box<dyn std::error::Error>;
/// Token claims as a map from claim name to its raw JSON value.
type Claims = HashMap<String, serde_json::value::Value>;

/// Grace period added to a token's `exp` before it is treated as expired
/// (passed to `_validate_exp` by `validate_token`).
const EXPIRES_LEEWAY: Duration = Duration::from_secs(5);
36
37pub fn validate_token(token: &Token, endpoint: &str) -> Result<Claims, Error> {
38    let jwt: JwtToken<Header, HashMap<String, serde_json::value::Value>, _> =
39        JwtToken::parse_unverified(&token.access_token).expect("Unable to parse given token");
40    let key_id = jwt
41        .header()
42        .key_id
43        .as_ref()
44        .expect("Token has no signing Key ID!");
45
46    // get public key from IAM
47    let keys = retrieve_keys(endpoint)?.keys;
48    let key = keys
49        .iter()
50        .find(|&k| k.kid == *key_id)
51        .expect("No signing key found for token key id");
52
53    let rsa_key = Rsa::from_public_components(
54        BigNum::from_slice(&key.n).unwrap(),
55        BigNum::from_slice(&key.e).unwrap(),
56    )
57    .unwrap();
58
59    // create verifier
60    let rs256_verifier = PKeyWithDigest {
61        digest: MessageDigest::sha256(),
62        key: PKey::from_rsa(rsa_key).unwrap(),
63    };
64
65    // verify token
66    let reg_claims: RegisteredClaims = token.access_token.verify_with_key(&rs256_verifier)?;
67
68    // verify claims
69    _validate_iss(&reg_claims)?;
70    _validate_iat(&reg_claims)?;
71    _validate_exp(&reg_claims, EXPIRES_LEEWAY)?;
72
73    // return claims
74    Ok(jwt.claims().clone())
75}
76
/// Error reported when a token's registered claims fail one of the validation checks.
///
/// Formatting with `Display` prints exactly the stored message.
#[derive(Debug, Clone)]
pub struct InvalidTokenError {
    message: String,
}

impl std::fmt::Display for InvalidTokenError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(&self.message)
    }
}

impl std::error::Error for InvalidTokenError {}
89
90fn _validate_iss(claims: &RegisteredClaims) -> Result<(), Error> {
91    // assert issuer is IAM
92    let er = InvalidTokenError {
93        message: "Issuer must start with 'https://iam'".to_string(),
94    };
95    let iss = claims.issuer.as_ref().ok_or(er.clone())?;
96    if !iss.starts_with("https://iam") {
97        return Err(er.into());
98    }
99    Ok(())
100}
101
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn unix_now() -> u64 {
    let since_epoch = SystemTime::UNIX_EPOCH.elapsed().unwrap();
    since_epoch.as_secs()
}
108
109fn _validate_iat(claims: &RegisteredClaims) -> Result<(), Error> {
110    // assert issued-at not in future
111    let er = InvalidTokenError {
112        message: "Issued At is None or in the future".to_string(),
113    };
114    let iat = claims.issued_at.ok_or(er.clone())?;
115
116    if iat > unix_now() {
117        return Err(er.into());
118    }
119
120    Ok(())
121}
122
123fn _validate_exp(claims: &RegisteredClaims, leeway: std::time::Duration) -> Result<(), Error> {
124    // assert not expired with leeway
125    let er = InvalidTokenError {
126        message: "Expiration is None or in the past".to_string(),
127    };
128    let exp = claims.expiration.ok_or(er.clone())?;
129
130    if (exp + leeway.as_secs()) < unix_now() {
131        return Err(er.into());
132    }
133
134    Ok(())
135}
136
/// Body of the IAM `{endpoint}/identity/keys` response: the published public keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct KeysResponse {
    // Candidate keys; `validate_token` picks the one whose `kid` matches the token header.
    keys: Vec<Key>,
}
141
/// One public-key entry from the IAM keys response.
///
/// `validate_token` builds an RSA public key from `n` and `e` and uses it for RS256
/// signature verification.
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Key {
    // Key type as published by IAM; not read anywhere in this file (presumably "RSA" — confirm).
    kty: String,
    // Key ID; matched against the token header's `kid` to choose the verifying key.
    kid: String,
    // Algorithm as published by IAM; not read anywhere in this file.
    alg: String,
    // RSA modulus: URL-safe base64 in JSON, decoded to raw bytes for `BigNum::from_slice`.
    #[serde_as(as = "Base64<UrlSafe>")]
    n: Vec<u8>,
    // RSA public exponent: URL-safe base64 in JSON, decoded to raw bytes.
    #[serde_as(as = "Base64<UrlSafe>")]
    e: Vec<u8>,
}
153
154fn retrieve_keys(endpoint: &str) -> Result<KeysResponse, Error> {
155    let c = reqwest::blocking::Client::new();
156
157    let resp = c
158        .get(format!("{}/identity/keys", endpoint))
159        .header("Accept", "application/json")
160        .send()
161        .expect("Retrieving IAM public keys failed");
162
163    let text = resp.text().expect("Getting body text failed");
164    Ok(serde_json::from_str(&text)?)
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Wall-clock time elapsed since the Unix epoch.
    fn since_epoch() -> Duration {
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
    }

    /// Unix timestamp (whole seconds) `secs` seconds from now.
    fn secs_from_now(secs: u64) -> u64 {
        (since_epoch() + Duration::from_secs(secs)).as_secs()
    }

    /// Unix timestamp (whole seconds) `secs` seconds ago.
    fn secs_ago(secs: u64) -> u64 {
        (since_epoch() - Duration::from_secs(secs)).as_secs()
    }

    #[test]
    fn test_validate_iss() {
        let mut claims = RegisteredClaims::default();
        let cases = [
            (None, false),
            (Some("https://notiam"), false),
            (Some("https://iam.test.cloud.ibm.com"), true),
        ];
        for (issuer, expect_ok) in cases {
            claims.issuer = issuer.map(String::from);
            assert_eq!(
                _validate_iss(&claims).is_ok(),
                expect_ok,
                "unexpected result for issuer {:?}",
                issuer
            );
        }
    }

    #[test]
    fn test_validate_iat() {
        let mut claims = RegisteredClaims::default();

        claims.issued_at = None;
        assert!(_validate_iat(&claims).is_err());

        // issued in the future
        claims.issued_at = Some(secs_from_now(15));
        assert!(_validate_iat(&claims).is_err());

        // issued in the past
        claims.issued_at = Some(secs_ago(15));
        assert!(_validate_iat(&claims).is_ok());
    }

    #[test]
    fn test_validate_exp() {
        let mut claims = RegisteredClaims::default();

        claims.expiration = None;
        assert!(_validate_exp(&claims, EXPIRES_LEEWAY).is_err());

        // expired longer ago than the default leeway covers
        claims.expiration = Some(secs_ago(15));
        assert!(_validate_exp(&claims, EXPIRES_LEEWAY).is_err());
    }

    #[test]
    fn test_validate_exp_expired_but_within_leeway() {
        let mut claims = RegisteredClaims::default();
        claims.expiration = Some(secs_ago(15));
        assert!(_validate_exp(&claims, Duration::from_secs(20)).is_ok());
    }

    #[test]
    fn test_validate_exp_token_not_expired() {
        let mut claims = RegisteredClaims::default();
        claims.expiration = Some(secs_from_now(15));
        assert!(_validate_exp(&claims, EXPIRES_LEEWAY).is_ok());
    }
}