1use std::sync::{Arc, Mutex};
16use std::time::{Duration, Instant};
17
18use serde::{Deserialize, Serialize};
19use serde_json;
20use tracing::error;
21use url::form_urlencoded;
22
/// An access token obtained from the IAM token endpoint, together with the
/// local instant after which it is no longer considered valid.
///
/// NOTE: the derived `Debug` impl prints both credentials verbatim; use the
/// `Display` impl (which redacts them) when writing a token to logs.
#[derive(Debug, Clone, PartialEq)]
pub struct Token {
    /// The access token string returned by the token endpoint.
    pub access_token: String,
    /// The token type string returned by the token endpoint.
    pub token_type: String,
    /// The refresh token, or an empty string when the server did not send one.
    pub refresh_token: String,
    /// Monotonic instant at which this token stops being considered valid.
    pub expiry: Instant,
}

/// Human-readable summary of a token that never reveals the credentials.
///
/// `access_token` is always shown as `<redacted>`; `refresh_token` is shown as
/// `<redacted>` or `<none>` (when empty). The remaining lifetime is rounded
/// down to whole seconds and is `0s` once the token has expired.
impl std::fmt::Display for Token {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let remaining = self.expiry.saturating_duration_since(Instant::now());
        let refresh = if self.refresh_token.is_empty() {
            "<none>"
        } else {
            "<redacted>"
        };
        write!(
            f,
            "Token {{ token_type: {}, access_token: <redacted>, refresh_token: {}, expires_in: {}s }}",
            self.token_type,
            refresh,
            remaining.as_secs()
        )
    }
}
36
37impl From<TokenResponse> for Token {
38 fn from(tr: TokenResponse) -> Self {
39 Token {
40 access_token: tr.access_token,
41 token_type: tr.token_type,
42 refresh_token: tr.refresh_token.unwrap_or_else(|| "".to_string()),
43 expiry: Instant::now() + Duration::from_secs(tr.expires_in.unwrap_or_else(|| 1200)),
44 }
45 }
46}
47
48impl Token {
49 pub fn valid(&self) -> bool {
50 Instant::now().checked_duration_since(self.expiry).is_none()
51 }
52}
53
54#[derive(Debug, Clone, Deserialize, Serialize)]
55struct TokenResponse {
56 access_token: String,
57 token_type: String,
58 refresh_token: Option<String>,
59 expires_in: Option<u64>,
60}
61
62pub struct TokenManager {
63 api_key: String,
64 token: Arc<Mutex<Option<Token>>>,
65 endpoint: String,
66}
67
68impl TokenManager {
69 pub fn new(api_key: &str, endpoint: &str) -> Self {
70 Self {
71 api_key: api_key.to_string(),
72 token: Arc::new(Mutex::new(None)),
73 endpoint: endpoint.to_string(),
74 }
75 }
76
77 pub fn token(&self) -> Result<Token, Box<dyn std::error::Error>> {
78 let mut token = self.token.lock().unwrap();
79
80 if let Some(t) = token.clone() {
81 if t.valid() {
82 return Ok(t);
83 }
84 }
85
86 *token = Some(self.request_token());
87
88 Ok(token.as_ref().unwrap().clone())
89 }
90
91 fn request_token(&self) -> Token {
92 let encoded: String = form_urlencoded::Serializer::new(String::new())
93 .append_pair("grant_type", "urn:ibm:params:oauth:grant-type:apikey")
94 .append_pair("apikey", &self.api_key)
95 .finish();
96
97 let c = reqwest::blocking::Client::new();
98
99 let path = format!("{}/identity/token", self.endpoint);
100
101 let resp = c
102 .post(path)
103 .header("Authorization", "Basic Yng6Yng=")
104 .header("Accept", "application/json")
105 .header("Content-Type", "application/x-www-form-urlencoded")
106 .body(encoded)
107 .send()
108 .expect("Get token failed");
109
110 let text = resp.text().expect("Getting body text failed");
111
112 let token_resp = match serde_json::from_str::<TokenResponse>(&text) {
113 Ok(v) => v,
114 Err(err) => {
115 error!("Error deserializing from response: body={}", text);
116 panic!("{}", err);
117 }
118 };
119
120 token_resp.into()
121 }
122}
123
124pub const DEFAULT_IAM_ENDPOINT: &str = "https://iam.cloud.ibm.com";
125
126impl Default for TokenManager {
127 fn default() -> Self {
128 let env_key = "IBMCLOUD_API_KEY";
129 let api_key = match std::env::var(env_key) {
130 Ok(k) => k,
131 _ => {
132 panic!("'IBMCLOUD_API_KEY' not set or invalid");
133 }
134 };
135
136 Self::new(&api_key, &DEFAULT_IAM_ENDPOINT)
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    use std::thread;

    /// A placeholder token that stays valid for 1200 seconds after creation.
    fn get_test_token() -> Token {
        Token {
            access_token: String::new(),
            refresh_token: String::new(),
            token_type: "test".to_string(),
            expiry: Instant::now() + Duration::from_secs(1200),
        }
    }

    /// A manager whose cache already holds a valid token, so `token()` is
    /// served from the cache and never reaches the network.
    fn seeded_manager() -> TokenManager {
        let manager = TokenManager::new("", DEFAULT_IAM_ENDPOINT);
        *manager.token.lock().unwrap() = Some(get_test_token());
        manager
    }

    #[test]
    fn token_expiry() {
        let mut tok = get_test_token();

        tok.expiry = Instant::now() + Duration::from_secs(10);
        assert!(tok.valid());

        tok.expiry = Instant::now() - Duration::from_secs(10);
        assert!(!tok.valid());
    }

    #[test]
    fn token_caching() {
        let manager = seeded_manager();

        let first = manager.token().unwrap();
        let second = manager.token().unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn threadsafe_cache() {
        let shared = Arc::new(seeded_manager());

        let handles: Vec<_> = (0..2)
            .map(|_| {
                let manager = Arc::clone(&shared);
                thread::spawn(move || manager.token().unwrap())
            })
            .collect();

        let results: Vec<Token> = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect();

        assert_eq!(results[0], results[1]);
    }
}