Skip to main content

questrade_client/
auth.rs

1//! OAuth token management for the Questrade API.
2
3use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6use time::OffsetDateTime;
7use tokio::sync::RwLock;
8use tracing::{debug, info};
9
10use crate::error::{QuestradeError, Result};
11
12/// Token response from the Questrade OAuth endpoint.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TokenResponse {
15    /// Short-lived Bearer access token used to authenticate API requests.
16    pub access_token: String,
17    /// Token type; always `"Bearer"`.
18    pub token_type: String,
19    /// Token lifetime in seconds (typically 1800).
20    pub expires_in: u64,
21    /// Single-use refresh token. **Must be persisted** after every refresh —
22    /// using an old token will result in an authentication failure.
23    pub refresh_token: String,
24    /// Base URL for API requests (e.g. `"https://api01.iq.questrade.com/"`).
25    /// May change between refreshes; always use the most recently received value.
26    pub api_server: String,
27}
28
29/// Callback invoked whenever a token refresh completes successfully.
30/// Receives the full `TokenResponse`; the caller is responsible for persisting
31/// the new `refresh_token` (Questrade refresh tokens are single-use).
32pub type OnTokenRefresh = Arc<dyn Fn(TokenResponse) + Send + Sync>;
33
34/// Pre-existing token state that can be passed to [`TokenManager::new`] to
35/// skip the initial token refresh when a valid cached token is available.
36pub struct CachedToken {
37    /// Bearer access token from a previous session.
38    pub access_token: String,
39    /// API server URL that was returned alongside this access token.
40    pub api_server: String,
41    /// When this access token expires.
42    pub expires_at: OffsetDateTime,
43}
44
45/// Manages Questrade OAuth tokens with auto-refresh.
46#[derive(Clone)]
47pub struct TokenManager {
48    inner: Arc<RwLock<TokenState>>,
49    login_url: String,
50    on_token_refresh: OnTokenRefresh,
51}
52
53struct TokenState {
54    access_token: String,
55    api_server: String,
56    refresh_token: String,
57    expires_at: OffsetDateTime,
58}
59
60impl TokenManager {
61    /// Create a new TokenManager with the given refresh token.
62    ///
63    /// `on_token_refresh` is called whenever the token is refreshed; pass `None`
64    /// for a no-op (e.g. in tests that don't need persistence).
65    ///
66    /// If `cached_token` is provided and still valid, the initial token refresh
67    /// is skipped and the cached credentials are used directly.
68    pub async fn new(
69        refresh_token: String,
70        practice: bool,
71        on_token_refresh: Option<OnTokenRefresh>,
72        cached_token: Option<CachedToken>,
73    ) -> Result<Self> {
74        let login_url = if practice {
75            "https://practicelogin.questrade.com".to_string()
76        } else {
77            "https://login.questrade.com".to_string()
78        };
79        Self::new_with_login_url(refresh_token, on_token_refresh, login_url, cached_token).await
80    }
81
82    /// Like [`Self::new`] but accepts an explicit login URL.
83    /// Used internally and in tests (e.g. to point at a wiremock server).
84    pub async fn new_with_login_url(
85        refresh_token: String,
86        on_token_refresh: Option<OnTokenRefresh>,
87        login_url: String,
88        cached_token: Option<CachedToken>,
89    ) -> Result<Self> {
90        let cb: OnTokenRefresh = on_token_refresh.unwrap_or_else(|| Arc::new(|_| {}));
91
92        // Use cached token if provided and still valid, otherwise start expired.
93        let (access_token, api_server, expires_at) =
94            if let Some(ct) = cached_token.filter(|ct| OffsetDateTime::now_utc() < ct.expires_at) {
95                info!("reusing cached Questrade access token");
96                (ct.access_token, ct.api_server, ct.expires_at)
97            } else {
98                (String::new(), String::new(), OffsetDateTime::UNIX_EPOCH)
99            };
100
101        let manager = Self {
102            inner: Arc::new(RwLock::new(TokenState {
103                access_token,
104                api_server,
105                refresh_token,
106                expires_at,
107            })),
108            login_url,
109            on_token_refresh: cb,
110        };
111
112        // Only refresh if we don't have a valid token.
113        if manager.inner.read().await.access_token.is_empty() {
114            manager.refresh().await?;
115        }
116
117        Ok(manager)
118    }
119
120    /// Get a valid access token and API server URL, refreshing if needed.
121    pub async fn get_token(&self) -> Result<(String, String)> {
122        {
123            let state = self.inner.read().await;
124            if OffsetDateTime::now_utc() < state.expires_at {
125                return Ok((state.access_token.clone(), state.api_server.clone()));
126            }
127        }
128        // Token expired, refresh
129        self.refresh().await
130    }
131
132    /// Force a token refresh even if the current token has not expired.
133    ///
134    /// Used when the server returns 401 Unauthorized, indicating the access
135    /// token was revoked server-side before its stated expiry.
136    pub async fn force_refresh(&self) -> Result<(String, String)> {
137        {
138            let mut state = self.inner.write().await;
139            state.expires_at = OffsetDateTime::UNIX_EPOCH;
140            state.access_token.clear();
141        }
142        self.refresh().await
143    }
144
145    async fn refresh(&self) -> Result<(String, String)> {
146        let mut state = self.inner.write().await;
147
148        // Double-check after acquiring write lock
149        if OffsetDateTime::now_utc() < state.expires_at && !state.access_token.is_empty() {
150            return Ok((state.access_token.clone(), state.api_server.clone()));
151        }
152
153        info!("refreshing Questrade access token");
154
155        let client = reqwest::Client::builder()
156            .connect_timeout(std::time::Duration::from_secs(10))
157            .timeout(std::time::Duration::from_secs(30))
158            .build()
159            .unwrap_or_default();
160        let url = format!("{}/oauth2/token", self.login_url);
161
162        let resp = client
163            .get(&url)
164            .query(&[
165                ("grant_type", "refresh_token"),
166                ("refresh_token", state.refresh_token.as_str()),
167            ])
168            .send()
169            .await?;
170
171        if !resp.status().is_success() {
172            let status = resp.status();
173            let body = resp.text().await.unwrap_or_default();
174            return Err(QuestradeError::TokenRefresh { status, body });
175        }
176
177        let token_resp: TokenResponse = resp.json().await?;
178
179        debug!(api_server = %token_resp.api_server, "new API server");
180
181        let expires_at =
182            OffsetDateTime::now_utc() + time::Duration::seconds(token_resp.expires_in as i64 - 30); // 30s buffer
183
184        state.access_token = token_resp.access_token.clone();
185        state.api_server = token_resp.api_server.clone();
186        state.refresh_token = token_resp.refresh_token.clone();
187        state.expires_at = expires_at;
188
189        let result = (state.access_token.clone(), state.api_server.clone());
190        drop(state); // release lock before invoking callback to prevent deadlock
191
192        // Notify caller — token persistence is the caller's responsibility.
193        (self.on_token_refresh)(token_resp);
194
195        Ok(result)
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockBuilder, MockServer, ResponseTemplate};

    /// JSON body served by the mocked OAuth endpoint; only the rotated
    /// refresh token varies between tests.
    fn mock_token_body(refresh: &str) -> serde_json::Value {
        serde_json::json!({
            "access_token": "acc_123",
            "token_type": "Bearer",
            "expires_in": 1800,
            "refresh_token": refresh,
            "api_server": "https://api01.iq.questrade.com/"
        })
    }

    /// Matcher prefix shared by every mock: a GET on the token endpoint.
    fn token_endpoint() -> MockBuilder {
        Mock::given(method("GET")).and(path("/oauth2/token"))
    }

    /// A 200 reply whose body rotates the refresh token to `refresh`.
    fn token_ok(refresh: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(mock_token_body(refresh))
    }

    /// A cached token expiring `offset` from now (negative = already expired).
    fn cached_token(access: &str, api_server: &str, offset: time::Duration) -> CachedToken {
        CachedToken {
            access_token: access.to_string(),
            api_server: api_server.to_string(),
            expires_at: OffsetDateTime::now_utc() + offset,
        }
    }

    #[tokio::test]
    async fn callback_invoked_with_new_token_on_refresh() {
        let server = MockServer::start().await;
        token_endpoint()
            .and(query_param("grant_type", "refresh_token"))
            .and(query_param("refresh_token", "seed_token"))
            .respond_with(token_ok("rotated"))
            .mount(&server)
            .await;

        let received = Arc::new(Mutex::new(Vec::<String>::new()));
        let sink = Arc::clone(&received);
        let cb: OnTokenRefresh = Arc::new(move |t: TokenResponse| {
            sink.lock().unwrap().push(t.refresh_token);
        });

        TokenManager::new_with_login_url("seed_token".to_string(), Some(cb), server.uri(), None)
            .await
            .unwrap();

        assert_eq!(*received.lock().unwrap(), vec!["rotated"]);
    }

    #[tokio::test]
    async fn token_with_reserved_url_characters_is_encoded() {
        // '+', '=' and '&' are query-string metacharacters: the refresh token
        // must travel percent-encoded and decode back intact on the server.
        const TRICKY: &str = "abc+def==&ghi";
        let server = MockServer::start().await;
        token_endpoint()
            .and(query_param("grant_type", "refresh_token"))
            .and(query_param("refresh_token", TRICKY))
            .respond_with(token_ok("rotated"))
            .mount(&server)
            .await;

        let outcome =
            TokenManager::new_with_login_url(TRICKY.to_string(), None, server.uri(), None).await;
        assert!(outcome.is_ok(), "token with reserved chars should succeed");
    }

    #[tokio::test]
    async fn no_callback_constructs_successfully() {
        let server = MockServer::start().await;
        token_endpoint().respond_with(token_ok("tok")).mount(&server).await;

        let outcome =
            TokenManager::new_with_login_url("any".to_string(), None, server.uri(), None).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn cached_token_skips_initial_refresh() {
        // Nothing listens on loopback port 1, so any refresh attempt would
        // fail to connect; success therefore proves the cache was used.
        let cached = cached_token(
            "cached_acc",
            "https://api05.iq.questrade.com/",
            time::Duration::minutes(25),
        );

        let manager = TokenManager::new_with_login_url(
            "unused_refresh".to_string(),
            None,
            "http://127.0.0.1:1".to_string(),
            Some(cached),
        )
        .await
        .unwrap();

        let (access, api_server) = manager.get_token().await.unwrap();
        assert_eq!(access, "cached_acc");
        assert_eq!(api_server, "https://api05.iq.questrade.com/");
    }

    #[tokio::test]
    async fn expired_cached_token_triggers_refresh() {
        let server = MockServer::start().await;
        token_endpoint()
            .respond_with(token_ok("fresh"))
            .expect(1)
            .mount(&server)
            .await;

        let stale = cached_token(
            "stale",
            "https://old.example.com/",
            time::Duration::seconds(-1),
        );

        let manager =
            TokenManager::new_with_login_url("rt".to_string(), None, server.uri(), Some(stale))
                .await
                .unwrap();

        let (access, _) = manager.get_token().await.unwrap();
        assert_eq!(access, "acc_123");
    }

    #[tokio::test]
    async fn force_refresh_bypasses_valid_cached_token() {
        let server = MockServer::start().await;
        // `expect(1)`: the only OAuth call allowed is the one `force_refresh` makes.
        token_endpoint()
            .respond_with(token_ok("refreshed"))
            .expect(1)
            .mount(&server)
            .await;

        // A still-valid cached token would normally suppress any refresh.
        let cached = cached_token(
            "old_acc",
            "https://api01.iq.questrade.com/",
            time::Duration::minutes(25),
        );

        let manager =
            TokenManager::new_with_login_url("rt".to_string(), None, server.uri(), Some(cached))
                .await
                .unwrap();

        // The cached token is served as-is...
        let (access, _) = manager.get_token().await.unwrap();
        assert_eq!(access, "old_acc");

        // ...until a forced refresh goes to the OAuth endpoint anyway.
        let (access, _) = manager.force_refresh().await.unwrap();
        assert_eq!(access, "acc_123");
    }
}