1use crate::broker::{
11 Account, BrokerClient, BrokerError, HealthStatus, OrderFilter, Position, PositionSide,
12};
13use crate::{OrderRequest, OrderResponse, OrderSide, OrderStatus, OrderType, Symbol, TimeInForce};
14use async_trait::async_trait;
15use chrono::{DateTime, Duration as ChronoDuration, Utc};
16use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
17use reqwest::{Client, Method, StatusCode};
18use rust_decimal::Decimal;
19use serde::{Deserialize, Serialize};
20use std::num::NonZeroU32;
21use std::sync::Arc;
22use std::time::Duration;
23use tokio::sync::RwLock;
24use tracing::{debug, error, info, warn};
25use uuid;
26
/// Connection settings for the Questrade broker client.
///
/// `Debug` is implemented by hand so that the refresh token (a credential)
/// never ends up in logs or panic messages.
#[derive(Clone)]
pub struct QuestradeConfig {
    /// OAuth refresh token exchanged for an access token during authentication.
    pub refresh_token: String,
    /// When `true`, the practice login endpoint is used instead of the live one.
    pub practice: bool,
    /// Timeout applied to the underlying HTTP client.
    pub timeout: Duration,
}

impl std::fmt::Debug for QuestradeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QuestradeConfig")
            // Never print the credential itself.
            .field("refresh_token", &"<redacted>")
            .field("practice", &self.practice)
            .field("timeout", &self.timeout)
            .finish()
    }
}
37
38impl Default for QuestradeConfig {
39 fn default() -> Self {
40 Self {
41 refresh_token: String::new(),
42 practice: true,
43 timeout: Duration::from_secs(30),
44 }
45 }
46}
47
/// Token payload deserialized from the OAuth token endpoint response.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OAuthToken {
    /// Bearer token sent in the `Authorization` header of API requests.
    access_token: String,
    /// Token type as reported by the token endpoint.
    token_type: String,
    /// Lifetime of the access token, interpreted as seconds when computing `expires_at`.
    expires_in: i64,
    /// Refresh token returned alongside the access token.
    refresh_token: String,
    /// Base URL that API request paths are appended to.
    api_server: String,
    /// Absolute expiry time; not part of the wire format (skipped by serde) and
    /// filled in locally after a successful authentication.
    #[serde(skip)]
    expires_at: Option<DateTime<Utc>>,
}
59
/// Broker client that talks to the Questrade REST API.
pub struct QuestradeBroker {
    /// HTTP client used for both the token exchange and API requests.
    client: Client,
    /// Settings supplied at construction time.
    config: QuestradeConfig,
    /// Current OAuth token, or `None` before the first successful authentication.
    token: Arc<RwLock<Option<OAuthToken>>>,
    /// Limiter awaited before every API request.
    rate_limiter: DefaultDirectRateLimiter,
    /// Number of the active primary account, or `None` until it has been loaded.
    account_number: Arc<RwLock<Option<String>>>,
}
68
69impl QuestradeBroker {
70 pub fn new(config: QuestradeConfig) -> Self {
72 let client = Client::builder()
73 .timeout(config.timeout)
74 .build()
75 .expect("Failed to create HTTP client");
76
77 let quota = Quota::per_second(NonZeroU32::new(1).unwrap());
79 let rate_limiter = RateLimiter::direct(quota);
80
81 Self {
82 client,
83 config,
84 token: Arc::new(RwLock::new(None)),
85 rate_limiter,
86 account_number: Arc::new(RwLock::new(None)),
87 }
88 }
89
90 pub async fn authenticate(&self) -> Result<(), BrokerError> {
92 let url = if self.config.practice {
93 "https://practice.login.questrade.com/oauth2/token"
94 } else {
95 "https://login.questrade.com/oauth2/token"
96 };
97
98 let params = [
99 ("grant_type", "refresh_token"),
100 ("refresh_token", &self.config.refresh_token),
101 ];
102
103 debug!("Authenticating with Questrade");
104
105 let response = self
106 .client
107 .post(url)
108 .form(¶ms)
109 .send()
110 .await?;
111
112 if response.status().is_success() {
113 let mut token: OAuthToken = response.json().await?;
114 token.expires_at = Some(Utc::now() + ChronoDuration::seconds(token.expires_in));
115
116 info!("Questrade authentication successful, expires at {:?}", token.expires_at);
117 *self.token.write().await = Some(token);
118
119 Box::pin(self.load_account_number()).await?;
121
122 Ok(())
123 } else {
124 let error_text = response.text().await.unwrap_or_default();
125 error!("Questrade authentication failed: {}", error_text);
126 Err(BrokerError::Auth(format!("Authentication failed: {}", error_text)))
127 }
128 }
129
130 async fn ensure_authenticated(&self) -> Result<(), BrokerError> {
132 let token_guard = self.token.read().await;
133 if let Some(token) = token_guard.as_ref() {
134 if let Some(expires_at) = token.expires_at {
135 if Utc::now() < expires_at - ChronoDuration::minutes(5) {
136 return Ok(());
137 }
138 }
139 }
140 drop(token_guard);
141
142 Box::pin(self.authenticate()).await
144 }
145
146 async fn api_server(&self) -> Result<String, BrokerError> {
148 let token_guard = self.token.read().await;
149 token_guard
150 .as_ref()
151 .map(|t| t.api_server.clone())
152 .ok_or_else(|| BrokerError::Auth("Not authenticated".to_string()))
153 }
154
155 async fn load_account_number(&self) -> Result<(), BrokerError> {
157 #[derive(Deserialize)]
158 struct AccountsResponse {
159 accounts: Vec<QuestradeAccount>,
160 }
161
162 #[derive(Deserialize)]
163 struct QuestradeAccount {
164 number: String,
165 #[serde(rename = "type")]
166 account_type: String,
167 status: String,
168 #[serde(rename = "isPrimary")]
169 is_primary: bool,
170 }
171
172 let response: AccountsResponse = self.request_internal(Method::GET, "/v1/accounts", None::<()>).await?;
174
175 let primary_account = response
176 .accounts
177 .into_iter()
178 .find(|acc| acc.is_primary && acc.status == "Active")
179 .ok_or_else(|| BrokerError::Other(anyhow::anyhow!("No active primary account found")))?;
180
181 *self.account_number.write().await = Some(primary_account.number.clone());
182 info!("Using Questrade account: {}", primary_account.number);
183
184 Ok(())
185 }
186
187 async fn request_internal<T: serde::de::DeserializeOwned>(
189 &self,
190 method: Method,
191 path: &str,
192 body: Option<impl Serialize>,
193 ) -> Result<T, BrokerError> {
194 self.rate_limiter.until_ready().await;
195
196 let api_server = self.api_server().await?;
197 let url = format!("{}{}", api_server, path);
198
199 let token_guard = self.token.read().await;
200 let access_token = token_guard
201 .as_ref()
202 .map(|t| t.access_token.clone())
203 .ok_or_else(|| BrokerError::Auth("No access token".to_string()))?;
204 drop(token_guard);
205
206 let mut req = self
207 .client
208 .request(method.clone(), &url)
209 .header("Authorization", format!("Bearer {}", access_token));
210
211 if let Some(body) = body {
212 req = req.json(&body);
213 }
214
215 debug!("Questrade API request: {} {}", method, path);
216
217 let response = req.send().await?;
218
219 match response.status() {
220 StatusCode::OK | StatusCode::CREATED => {
221 let result = response.json().await?;
222 Ok(result)
223 }
224 StatusCode::UNAUTHORIZED => {
225 warn!("Questrade token expired, re-authenticating");
227 self.authenticate().await?;
228 Err(BrokerError::Auth("Token expired, please retry".to_string()))
229 }
230 StatusCode::TOO_MANY_REQUESTS => Err(BrokerError::RateLimit),
231 status => {
232 let error_text = response.text().await.unwrap_or_default();
233 error!("Questrade API error {}: {}", status, error_text);
234 Err(BrokerError::Other(anyhow::anyhow!("HTTP {}: {}", status, error_text)))
235 }
236 }
237 }
238
239 async fn request<T: serde::de::DeserializeOwned>(
241 &self,
242 method: Method,
243 path: &str,
244 body: Option<impl Serialize>,
245 ) -> Result<T, BrokerError> {
246 self.ensure_authenticated().await?;
247 self.request_internal(method, path, body).await
248 }
249
250 async fn get_account_number(&self) -> Result<String, BrokerError> {
252 let account_guard = self.account_number.read().await;
253 account_guard
254 .as_ref()
255 .cloned()
256 .ok_or_else(|| BrokerError::Auth("Account number not loaded".to_string()))
257 }
258}
259
260#[async_trait]
261impl BrokerClient for QuestradeBroker {
262 async fn get_account(&self) -> Result<Account, BrokerError> {
263 let account_number = self.get_account_number().await?;
264
265 #[derive(Deserialize)]
266 struct BalancesResponse {
267 #[serde(rename = "perCurrencyBalances")]
268 per_currency_balances: Vec<CurrencyBalance>,
269 }
270
271 #[derive(Deserialize)]
272 struct CurrencyBalance {
273 currency: String,
274 cash: Decimal,
275 #[serde(rename = "marketValue")]
276 market_value: Decimal,
277 #[serde(rename = "totalEquity")]
278 total_equity: Decimal,
279 #[serde(rename = "buyingPower")]
280 buying_power: Decimal,
281 }
282
283 let response: BalancesResponse = self
284 .request(
285 Method::GET,
286 &format!("/v1/accounts/{}/balances", account_number),
287 None::<()>,
288 )
289 .await?;
290
291 let cad_balance = response
293 .per_currency_balances
294 .into_iter()
295 .find(|b| b.currency == "CAD")
296 .ok_or_else(|| BrokerError::Other(anyhow::anyhow!("No CAD balance found")))?;
297
298 Ok(Account {
299 account_id: account_number,
300 cash: cad_balance.cash,
301 portfolio_value: cad_balance.total_equity,
302 buying_power: cad_balance.buying_power,
303 equity: cad_balance.total_equity,
304 last_equity: cad_balance.total_equity,
305 multiplier: "1".to_string(),
306 currency: "CAD".to_string(),
307 shorting_enabled: false,
308 long_market_value: cad_balance.market_value,
309 short_market_value: Decimal::ZERO,
310 initial_margin: Decimal::ZERO,
311 maintenance_margin: Decimal::ZERO,
312 day_trading_buying_power: cad_balance.buying_power,
313 daytrade_count: 0,
314 })
315 }
316
317 async fn get_positions(&self) -> Result<Vec<Position>, BrokerError> {
318 let account_number = self.get_account_number().await?;
319
320 #[derive(Deserialize)]
321 struct PositionsResponse {
322 positions: Vec<QuestradePosition>,
323 }
324
325 #[derive(Deserialize)]
326 struct QuestradePosition {
327 symbol: String,
328 #[serde(rename = "symbolId")]
329 symbol_id: i64,
330 #[serde(rename = "openQuantity")]
331 open_quantity: i64,
332 #[serde(rename = "currentMarketValue")]
333 current_market_value: Decimal,
334 #[serde(rename = "currentPrice")]
335 current_price: Decimal,
336 #[serde(rename = "averageEntryPrice")]
337 average_entry_price: Decimal,
338 #[serde(rename = "openPnl")]
339 open_pnl: Decimal,
340 }
341
342 let response: PositionsResponse = self
343 .request(
344 Method::GET,
345 &format!("/v1/accounts/{}/positions", account_number),
346 None::<()>,
347 )
348 .await?;
349
350 Ok(response
351 .positions
352 .into_iter()
353 .map(|p| Position {
354 symbol: Symbol::new(p.symbol.as_str()).expect("Invalid symbol from Questrade"),
355 qty: p.open_quantity,
356 side: if p.open_quantity > 0 {
357 PositionSide::Long
358 } else {
359 PositionSide::Short
360 },
361 avg_entry_price: p.average_entry_price,
362 market_value: p.current_market_value,
363 cost_basis: p.average_entry_price * Decimal::from(p.open_quantity.abs()),
364 unrealized_pl: p.open_pnl,
365 unrealized_plpc: if p.current_market_value != Decimal::ZERO {
366 (p.open_pnl / p.current_market_value.abs()) * Decimal::from(100)
367 } else {
368 Decimal::ZERO
369 },
370 current_price: p.current_price,
371 lastday_price: p.current_price,
372 change_today: Decimal::ZERO,
373 })
374 .collect())
375 }
376
377 async fn place_order(&self, order: OrderRequest) -> Result<OrderResponse, BrokerError> {
378 let account_number = self.get_account_number().await?;
379
380 #[derive(Serialize)]
381 struct QuestradeOrderRequest {
382 #[serde(rename = "accountNumber")]
383 account_number: String,
384 #[serde(rename = "symbolId")]
385 symbol_id: i64,
386 quantity: i64,
387 #[serde(rename = "orderType")]
388 order_type: String,
389 #[serde(rename = "timeInForce")]
390 time_in_force: String,
391 #[serde(rename = "action")]
392 action: String,
393 #[serde(skip_serializing_if = "Option::is_none")]
394 #[serde(rename = "limitPrice")]
395 limit_price: Option<Decimal>,
396 #[serde(skip_serializing_if = "Option::is_none")]
397 #[serde(rename = "stopPrice")]
398 stop_price: Option<Decimal>,
399 }
400
401 let symbol_id = 0; let order_type = match order.order_type {
406 OrderType::Market => "Market",
407 OrderType::Limit => "Limit",
408 OrderType::StopLoss => "Stop",
409 OrderType::StopLimit => "StopLimit",
410 };
411
412 let action = match order.side {
413 OrderSide::Buy => "Buy",
414 OrderSide::Sell => "Sell",
415 };
416
417 let time_in_force = match order.time_in_force {
418 TimeInForce::Day => "Day",
419 TimeInForce::GTC => "GoodTillCanceled",
420 _ => "Day",
421 };
422
423 let req = QuestradeOrderRequest {
424 account_number: account_number.clone(),
425 symbol_id,
426 quantity: order.quantity as i64,
427 order_type: order_type.to_string(),
428 time_in_force: time_in_force.to_string(),
429 action: action.to_string(),
430 limit_price: order.limit_price,
431 stop_price: order.stop_price,
432 };
433
434 #[derive(Deserialize)]
435 struct OrderResponse {
436 #[serde(rename = "orderId")]
437 order_id: i64,
438 }
439
440 let response: OrderResponse = self
441 .request(
442 Method::POST,
443 &format!("/v1/accounts/{}/orders", account_number),
444 Some(req),
445 )
446 .await?;
447
448 Ok(crate::OrderResponse {
449 order_id: response.order_id.to_string(),
450 client_order_id: uuid::Uuid::new_v4().to_string(),
451 status: OrderStatus::Accepted,
452 filled_qty: 0,
453 filled_avg_price: None,
454 submitted_at: Utc::now(),
455 filled_at: None,
456 })
457 }
458
459 async fn cancel_order(&self, order_id: &str) -> Result<(), BrokerError> {
460 let account_number = self.get_account_number().await?;
461
462 let _: serde_json::Value = self
463 .request(
464 Method::DELETE,
465 &format!("/v1/accounts/{}/orders/{}", account_number, order_id),
466 None::<()>,
467 )
468 .await?;
469
470 Ok(())
471 }
472
473 async fn get_order(&self, order_id: &str) -> Result<OrderResponse, BrokerError> {
474 Err(BrokerError::Other(anyhow::anyhow!("Not implemented")))
475 }
476
477 async fn list_orders(&self, _filter: OrderFilter) -> Result<Vec<OrderResponse>, BrokerError> {
478 Ok(Vec::new())
479 }
480
481 async fn health_check(&self) -> Result<HealthStatus, BrokerError> {
482 match self.ensure_authenticated().await {
483 Ok(_) => Ok(HealthStatus::Healthy),
484 Err(_) => Ok(HealthStatus::Unhealthy),
485 }
486 }
487}