1use std::collections::HashMap;
16use std::time::{Duration, SystemTime};
17
18use jwt::{Header, PKeyWithDigest, RegisteredClaims, Token as JwtToken, VerifyWithKey};
19use openssl::bn::BigNum;
20use openssl::hash::MessageDigest;
21use openssl::pkey::PKey;
22use openssl::rsa::Rsa;
23use serde::{Deserialize, Serialize};
24use serde_json;
25use serde_with::{
26 base64::{Base64, UrlSafe},
27 serde_as,
28};
29
30use crate::token::Token;
31
32type Error = Box<dyn std::error::Error>;
33type Claims = HashMap<String, serde_json::value::Value>;
34
35const EXPIRES_LEEWAY: Duration = Duration::from_secs(5);
36
37pub fn validate_token(token: &Token, endpoint: &str) -> Result<Claims, Error> {
38 let jwt: JwtToken<Header, HashMap<String, serde_json::value::Value>, _> =
39 JwtToken::parse_unverified(&token.access_token).expect("Unable to parse given token");
40 let key_id = jwt
41 .header()
42 .key_id
43 .as_ref()
44 .expect("Token has no signing Key ID!");
45
46 let keys = retrieve_keys(endpoint)?.keys;
48 let key = keys
49 .iter()
50 .find(|&k| k.kid == *key_id)
51 .expect("No signing key found for token key id");
52
53 let rsa_key = Rsa::from_public_components(
54 BigNum::from_slice(&key.n).unwrap(),
55 BigNum::from_slice(&key.e).unwrap(),
56 )
57 .unwrap();
58
59 let rs256_verifier = PKeyWithDigest {
61 digest: MessageDigest::sha256(),
62 key: PKey::from_rsa(rsa_key).unwrap(),
63 };
64
65 let reg_claims: RegisteredClaims = token.access_token.verify_with_key(&rs256_verifier)?;
67
68 _validate_iss(®_claims)?;
70 _validate_iat(®_claims)?;
71 _validate_exp(®_claims, EXPIRES_LEEWAY)?;
72
73 Ok(jwt.claims().clone())
75}
76
/// Error describing why a token was rejected.
///
/// Its `Display` output is exactly the stored message.
#[derive(Debug, Clone)]
pub struct InvalidTokenError {
    /// Human-readable reason for the rejection.
    message: String,
}

impl std::fmt::Display for InvalidTokenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for InvalidTokenError {}
89
90fn _validate_iss(claims: &RegisteredClaims) -> Result<(), Error> {
91 let er = InvalidTokenError {
93 message: "Issuer must start with 'https://iam'".to_string(),
94 };
95 let iss = claims.issuer.as_ref().ok_or(er.clone())?;
96 if !iss.starts_with("https://iam") {
97 return Err(er.into());
98 }
99 Ok(())
100}
101
/// Returns the current time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn unix_now() -> u64 {
    let since_epoch = SystemTime::UNIX_EPOCH.elapsed().unwrap();
    since_epoch.as_secs()
}
108
109fn _validate_iat(claims: &RegisteredClaims) -> Result<(), Error> {
110 let er = InvalidTokenError {
112 message: "Issued At is None or in the future".to_string(),
113 };
114 let iat = claims.issued_at.ok_or(er.clone())?;
115
116 if iat > unix_now() {
117 return Err(er.into());
118 }
119
120 Ok(())
121}
122
123fn _validate_exp(claims: &RegisteredClaims, leeway: std::time::Duration) -> Result<(), Error> {
124 let er = InvalidTokenError {
126 message: "Expiration is None or in the past".to_string(),
127 };
128 let exp = claims.expiration.ok_or(er.clone())?;
129
130 if (exp + leeway.as_secs()) < unix_now() {
131 return Err(er.into());
132 }
133
134 Ok(())
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138struct KeysResponse {
139 keys: Vec<Key>,
140}
141
142#[serde_as]
143#[derive(Debug, Clone, Serialize, Deserialize)]
144struct Key {
145 kty: String,
146 kid: String,
147 alg: String,
148 #[serde_as(as = "Base64<UrlSafe>")]
149 n: Vec<u8>,
150 #[serde_as(as = "Base64<UrlSafe>")]
151 e: Vec<u8>,
152}
153
154fn retrieve_keys(endpoint: &str) -> Result<KeysResponse, Error> {
155 let c = reqwest::blocking::Client::new();
156
157 let resp = c
158 .get(format!("{}/identity/keys", endpoint))
159 .header("Accept", "application/json")
160 .send()
161 .expect("Retrieving IAM public keys failed");
162
163 let text = resp.text().expect("Getting body text failed");
164 Ok(serde_json::from_str(&text)?)
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time (seconds) plus `secs`.
    fn secs_from_now(secs: u64) -> u64 {
        unix_now() + secs
    }

    /// Current Unix time (seconds) minus `secs`.
    fn secs_ago(secs: u64) -> u64 {
        unix_now() - secs
    }

    #[test]
    fn test_validate_iss() {
        let cases = [
            (None, false),
            (Some("https://notiam"), false),
            (Some("https://iam.test.cloud.ibm.com"), true),
        ];
        let mut claims = RegisteredClaims::default();
        for (issuer, should_pass) in cases {
            claims.issuer = issuer.map(String::from);
            assert_eq!(
                _validate_iss(&claims).is_ok(),
                should_pass,
                "issuer: {:?}",
                issuer
            );
        }
    }

    #[test]
    fn test_validate_iat() {
        let mut claims = RegisteredClaims::default();
        claims.issued_at = None;
        assert!(_validate_iat(&claims).is_err());

        claims.issued_at = Some(secs_from_now(15));
        assert!(_validate_iat(&claims).is_err());

        claims.issued_at = Some(secs_ago(15));
        assert!(_validate_iat(&claims).is_ok());
    }

    #[test]
    fn test_validate_exp() {
        let mut claims = RegisteredClaims::default();
        claims.expiration = None;
        assert!(_validate_exp(&claims, EXPIRES_LEEWAY).is_err());

        claims.expiration = Some(secs_ago(15));
        assert!(_validate_exp(&claims, EXPIRES_LEEWAY).is_err());
    }

    #[test]
    fn test_validate_exp_expired_but_within_leeway() {
        let mut claims = RegisteredClaims::default();
        claims.expiration = Some(secs_ago(15));
        assert!(_validate_exp(&claims, Duration::from_secs(20)).is_ok());
    }

    #[test]
    fn test_validate_exp_token_not_expired() {
        let mut claims = RegisteredClaims::default();
        claims.expiration = Some(secs_from_now(15));
        assert!(_validate_exp(&claims, EXPIRES_LEEWAY).is_ok());
    }
}