1use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6use time::OffsetDateTime;
7use tokio::sync::RwLock;
8use tracing::{debug, info};
9
10use crate::error::{QuestradeError, Result};
11
/// Payload of a successful response from the login server's `/oauth2/token` endpoint.
///
/// `TokenManager` deserializes this from the JSON response body and passes the whole
/// value to the [`OnTokenRefresh`] callback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
    /// Access token that `TokenManager` caches and hands out from `get_token`.
    pub access_token: String,
    /// Token type as reported by the server (presumably `"Bearer"` — confirm against the API docs).
    pub token_type: String,
    /// Lifetime of `access_token`; `TokenManager` interprets the value as seconds.
    pub expires_in: u64,
    /// Refresh token that `TokenManager` stores and sends on its next refresh request.
    pub refresh_token: String,
    /// API server value that `TokenManager` caches and returns together with the access token.
    pub api_server: String,
}
28
/// Callback that `TokenManager` invokes with the full [`TokenResponse`] after every
/// successful token refresh.
///
/// It is called after the manager has released its internal state lock. Presumably this
/// exists so callers can persist the newly issued refresh token — confirm with the callers.
pub type OnTokenRefresh = Arc<dyn Fn(TokenResponse) + Send + Sync>;
33
/// Previously obtained access token that can be handed to `TokenManager` at construction.
///
/// The manager reuses it only while `expires_at` is still in the future; otherwise it is
/// discarded and a fresh token is requested.
pub struct CachedToken {
    /// Cached access token.
    pub access_token: String,
    /// API server value that was returned together with `access_token`.
    pub api_server: String,
    /// Instant after which the cached token is no longer reused.
    pub expires_at: OffsetDateTime,
}
44
/// Hands out access tokens and refreshes them against the login server when they expire.
///
/// The token state lives behind an `Arc<RwLock<..>>`, so clones of a manager share the
/// same state.
#[derive(Clone)]
pub struct TokenManager {
    /// Shared, lock-protected token state.
    inner: Arc<RwLock<TokenState>>,
    /// Base URL of the login server; `/oauth2/token` is appended when refreshing.
    login_url: String,
    /// Callback invoked with the server response after each successful refresh.
    on_token_refresh: OnTokenRefresh,
}
52
/// Mutable token state shared by all clones of a `TokenManager`.
struct TokenState {
    /// Current access token; empty when none is held (before the first refresh or after
    /// `force_refresh` clears it).
    access_token: String,
    /// API server value returned with the current access token.
    api_server: String,
    /// Refresh token to send on the next refresh; replaced by the one in each response.
    refresh_token: String,
    /// Instant after which `access_token` is treated as expired; `UNIX_EPOCH` forces a refresh.
    expires_at: OffsetDateTime,
}
59
60impl TokenManager {
61 pub async fn new(
69 refresh_token: String,
70 practice: bool,
71 on_token_refresh: Option<OnTokenRefresh>,
72 cached_token: Option<CachedToken>,
73 ) -> Result<Self> {
74 let login_url = if practice {
75 "https://practicelogin.questrade.com".to_string()
76 } else {
77 "https://login.questrade.com".to_string()
78 };
79 Self::new_with_login_url(refresh_token, on_token_refresh, login_url, cached_token).await
80 }
81
82 pub async fn new_with_login_url(
85 refresh_token: String,
86 on_token_refresh: Option<OnTokenRefresh>,
87 login_url: String,
88 cached_token: Option<CachedToken>,
89 ) -> Result<Self> {
90 let cb: OnTokenRefresh = on_token_refresh.unwrap_or_else(|| Arc::new(|_| {}));
91
92 let (access_token, api_server, expires_at) =
94 if let Some(ct) = cached_token.filter(|ct| OffsetDateTime::now_utc() < ct.expires_at) {
95 info!("reusing cached Questrade access token");
96 (ct.access_token, ct.api_server, ct.expires_at)
97 } else {
98 (String::new(), String::new(), OffsetDateTime::UNIX_EPOCH)
99 };
100
101 let manager = Self {
102 inner: Arc::new(RwLock::new(TokenState {
103 access_token,
104 api_server,
105 refresh_token,
106 expires_at,
107 })),
108 login_url,
109 on_token_refresh: cb,
110 };
111
112 if manager.inner.read().await.access_token.is_empty() {
114 manager.refresh().await?;
115 }
116
117 Ok(manager)
118 }
119
120 pub async fn get_token(&self) -> Result<(String, String)> {
122 {
123 let state = self.inner.read().await;
124 if OffsetDateTime::now_utc() < state.expires_at {
125 return Ok((state.access_token.clone(), state.api_server.clone()));
126 }
127 }
128 self.refresh().await
130 }
131
132 pub async fn force_refresh(&self) -> Result<(String, String)> {
137 {
138 let mut state = self.inner.write().await;
139 state.expires_at = OffsetDateTime::UNIX_EPOCH;
140 state.access_token.clear();
141 }
142 self.refresh().await
143 }
144
145 async fn refresh(&self) -> Result<(String, String)> {
146 let mut state = self.inner.write().await;
147
148 if OffsetDateTime::now_utc() < state.expires_at && !state.access_token.is_empty() {
150 return Ok((state.access_token.clone(), state.api_server.clone()));
151 }
152
153 info!("refreshing Questrade access token");
154
155 let client = reqwest::Client::builder()
156 .connect_timeout(std::time::Duration::from_secs(10))
157 .timeout(std::time::Duration::from_secs(30))
158 .build()
159 .unwrap_or_default();
160 let url = format!("{}/oauth2/token", self.login_url);
161
162 let resp = client
163 .get(&url)
164 .query(&[
165 ("grant_type", "refresh_token"),
166 ("refresh_token", state.refresh_token.as_str()),
167 ])
168 .send()
169 .await?;
170
171 if !resp.status().is_success() {
172 let status = resp.status();
173 let body = resp.text().await.unwrap_or_default();
174 return Err(QuestradeError::TokenRefresh { status, body });
175 }
176
177 let token_resp: TokenResponse = resp.json().await?;
178
179 debug!(api_server = %token_resp.api_server, "new API server");
180
181 let expires_at =
182 OffsetDateTime::now_utc() + time::Duration::seconds(token_resp.expires_in as i64 - 30); state.access_token = token_resp.access_token.clone();
185 state.api_server = token_resp.api_server.clone();
186 state.refresh_token = token_resp.refresh_token.clone();
187 state.expires_at = expires_at;
188
189 let result = (state.access_token.clone(), state.api_server.clone());
190 drop(state); (self.on_token_refresh)(token_resp);
194
195 Ok(result)
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// JSON body of a successful token response whose rotated refresh token is `refresh`.
    fn mock_token_body(refresh: &str) -> serde_json::Value {
        serde_json::json!({
            "access_token": "acc_123",
            "token_type": "Bearer",
            "expires_in": 1800,
            "refresh_token": refresh,
            "api_server": "https://api01.iq.questrade.com/"
        })
    }

    /// `200 OK` response carrying [`mock_token_body`] for `refresh`.
    fn ok_response(refresh: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(mock_token_body(refresh))
    }

    /// Starts a mock login server with `mock` mounted on it.
    async fn serve(mock: Mock) -> MockServer {
        let server = MockServer::start().await;
        mock.mount(&server).await;
        server
    }

    /// Cached token that stays valid for another 25 minutes.
    fn unexpired_cache(access_token: &str, api_server: &str) -> CachedToken {
        CachedToken {
            access_token: access_token.to_string(),
            api_server: api_server.to_string(),
            expires_at: OffsetDateTime::now_utc() + time::Duration::minutes(25),
        }
    }

    #[tokio::test]
    async fn callback_invoked_with_new_token_on_refresh() {
        let server = serve(
            Mock::given(method("GET"))
                .and(path("/oauth2/token"))
                .and(query_param("grant_type", "refresh_token"))
                .and(query_param("refresh_token", "seed_token"))
                .respond_with(ok_response("rotated")),
        )
        .await;

        // Collect every refresh token the callback is handed.
        let recorded: Arc<Mutex<Vec<String>>> = Arc::default();
        let sink = Arc::clone(&recorded);
        let on_refresh: OnTokenRefresh = Arc::new(move |resp: TokenResponse| {
            sink.lock().unwrap().push(resp.refresh_token);
        });

        TokenManager::new_with_login_url(
            "seed_token".to_string(),
            Some(on_refresh),
            server.uri(),
            None,
        )
        .await
        .unwrap();

        let rotations = recorded.lock().unwrap();
        assert_eq!(*rotations, vec!["rotated"]);
    }

    #[tokio::test]
    async fn token_with_reserved_url_characters_is_encoded() {
        // `+`, `=` and `&` only reach the server intact if the query value is URL-encoded.
        let tricky_token = "abc+def==&ghi";
        let server = serve(
            Mock::given(method("GET"))
                .and(path("/oauth2/token"))
                .and(query_param("grant_type", "refresh_token"))
                .and(query_param("refresh_token", tricky_token))
                .respond_with(ok_response("rotated")),
        )
        .await;

        let outcome =
            TokenManager::new_with_login_url(tricky_token.to_string(), None, server.uri(), None)
                .await;

        assert!(outcome.is_ok(), "token with reserved chars should succeed");
    }

    #[tokio::test]
    async fn no_callback_constructs_successfully() {
        let server = serve(
            Mock::given(method("GET"))
                .and(path("/oauth2/token"))
                .respond_with(ok_response("tok")),
        )
        .await;

        let outcome =
            TokenManager::new_with_login_url("any".to_string(), None, server.uri(), None).await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn cached_token_skips_initial_refresh() {
        let cached = unexpired_cache("cached_acc", "https://api05.iq.questrade.com/");

        // Nothing listens on this address, so any refresh attempt would fail the test.
        let unreachable_login = "http://127.0.0.1:1".to_string();

        let manager = TokenManager::new_with_login_url(
            "unused_refresh".to_string(),
            None,
            unreachable_login,
            Some(cached),
        )
        .await
        .unwrap();

        let (access_token, api_server) = manager.get_token().await.unwrap();
        assert_eq!(access_token, "cached_acc");
        assert_eq!(api_server, "https://api05.iq.questrade.com/");
    }

    #[tokio::test]
    async fn expired_cached_token_triggers_refresh() {
        let server = serve(
            Mock::given(method("GET"))
                .and(path("/oauth2/token"))
                .respond_with(ok_response("fresh"))
                .expect(1),
        )
        .await;

        let expired = CachedToken {
            access_token: "stale".to_string(),
            api_server: "https://old.example.com/".to_string(),
            expires_at: OffsetDateTime::now_utc() - time::Duration::seconds(1),
        };

        let manager =
            TokenManager::new_with_login_url("rt".to_string(), None, server.uri(), Some(expired))
                .await
                .unwrap();

        let (access_token, _) = manager.get_token().await.unwrap();
        assert_eq!(access_token, "acc_123");
    }

    #[tokio::test]
    async fn force_refresh_bypasses_valid_cached_token() {
        // Exactly one request is expected: the forced refresh, not the construction.
        let server = serve(
            Mock::given(method("GET"))
                .and(path("/oauth2/token"))
                .respond_with(ok_response("refreshed"))
                .expect(1),
        )
        .await;

        let cached = unexpired_cache("old_acc", "https://api01.iq.questrade.com/");

        let manager =
            TokenManager::new_with_login_url("rt".to_string(), None, server.uri(), Some(cached))
                .await
                .unwrap();

        // The still-valid cached token is served without contacting the server.
        let (before, _) = manager.get_token().await.unwrap();
        assert_eq!(before, "old_acc");

        // A forced refresh replaces it with the token from the mock response.
        let (after, _) = manager.force_refresh().await.unwrap();
        assert_eq!(after, "acc_123");
    }
}