1use std::{env, fmt, sync::Arc};
2
3use chrono::{DateTime, Utc};
4use reqwest::{header::HeaderMap, Client, RequestBuilder, Response, StatusCode};
5use tokio::sync::Mutex;
6use tracing::info;
7use urlencoding::encode;
8
9use crate::{
10 schwab::{
11 common::{SCHWAB_MARKET_DATA_API_URL, SCHWAB_TRADER_API_URL, TOKENS_FILE},
12 models::{
13 market_data::{
14 ChainsResponse, ExpirationChainResponse, InstrumentsResponse, MarketHours,
15 MarketHoursResponse, MoversResponse, PriceHistoryResponse, QuotesResponse,
16 },
17 trader::UserPreferencesResponse,
18 },
19 schwab_auth::{SchwabAuth, StoredTokenInfo},
20 },
21 util::{dedup_ordered, parse_params, time_to_epoch_ms, time_to_yyyymmdd},
22};
23
/// Contract side requested from the option-chain endpoint.
///
/// The [`Display`](fmt::Display) form is the upper-case value sent as the `contractType`
/// query parameter by `SchwabApi::get_chains`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContractType {
    /// Rendered as `CALL`.
    Call,
    /// Rendered as `PUT`.
    Put,
    /// Rendered as `ALL`.
    All,
}

impl fmt::Display for ContractType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = match self {
            Self::Call => "CALL",
            Self::Put => "PUT",
            Self::All => "ALL",
        };
        f.write_str(value)
    }
}
43
/// A field group that can be requested from the quote endpoints.
///
/// The [`Display`](fmt::Display) form is the lower-case token joined into the comma-separated
/// `fields` query parameter by `SchwabApi::get_quotes` and `SchwabApi::quote`.
#[derive(Debug, Eq, PartialEq, Hash, Clone)]
pub enum QuoteFields {
    /// Rendered as `quote`.
    Quote,
    /// Rendered as `fundamental`.
    Fundamental,
    /// Rendered as `extended`.
    Extended,
    /// Rendered as `reference`.
    Reference,
    /// Rendered as `regular`.
    Regular,
}

impl fmt::Display for QuoteFields {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let token = match self {
            Self::Quote => "quote",
            Self::Fundamental => "fundamental",
            Self::Extended => "extended",
            Self::Reference => "reference",
            Self::Regular => "regular",
        };
        f.write_str(token)
    }
}
70
/// Unit of the `period` requested from the price-history endpoint.
///
/// The [`Display`](fmt::Display) form is the value sent as the `periodType` query parameter
/// by `SchwabApi::price_history`.
#[derive(Debug, Eq, PartialEq, Hash, Clone)]
pub enum PeriodType {
    /// Rendered as `day`.
    Day,
    /// Rendered as `month`.
    Month,
    /// Rendered as `year`.
    Year,
    /// Rendered as `ytd`.
    Ytd,
}

impl fmt::Display for PeriodType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let unit = match self {
            Self::Day => "day",
            Self::Month => "month",
            Self::Year => "year",
            Self::Ytd => "ytd",
        };
        f.write_str(unit)
    }
}
94
/// Unit of the candle `frequency` requested from the price-history endpoint.
///
/// The [`Display`](fmt::Display) form is the value sent as the `frequencyType` query
/// parameter by `SchwabApi::price_history`.
#[derive(Debug, Eq, PartialEq, Hash, Clone)]
pub enum FrequencyType {
    /// Rendered as `minute`.
    Minute,
    /// Rendered as `daily`.
    Daily,
    /// Rendered as `weekly`.
    Weekly,
    /// Rendered as `monthly`.
    Monthly,
}

impl fmt::Display for FrequencyType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let unit = match self {
            Self::Minute => "minute",
            Self::Daily => "daily",
            Self::Weekly => "weekly",
            Self::Monthly => "monthly",
        };
        f.write_str(unit)
    }
}
118
/// Ordering requested from the movers endpoint.
///
/// The [`Display`](fmt::Display) form is the upper-snake-case value sent as the `sort` query
/// parameter by `SchwabApi::movers`.
#[derive(Debug, Eq, PartialEq, Hash, Clone)]
pub enum Sort {
    /// Rendered as `VOLUME`.
    Volume,
    /// Rendered as `TRADES`.
    Trades,
    /// Rendered as `PERCENT_CHANGE_UP`.
    PercentChangeUp,
    /// Rendered as `PERCENT_CHANGE_DOWN`.
    PercentChangeDown,
}

impl fmt::Display for Sort {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let key = match self {
            Self::Volume => "VOLUME",
            Self::Trades => "TRADES",
            Self::PercentChangeUp => "PERCENT_CHANGE_UP",
            Self::PercentChangeDown => "PERCENT_CHANGE_DOWN",
        };
        f.write_str(key)
    }
}
142
/// Search mode requested from the instruments endpoint.
///
/// The [`Display`](fmt::Display) form is the kebab-case value sent as the `projection` query
/// parameter by `SchwabApi::instruments`.
#[derive(Debug, Eq, PartialEq, Hash, Clone)]
pub enum Projection {
    /// Rendered as `symbol-search`.
    SymbolSearch,
    /// Rendered as `symbol-regex`.
    SymbolRegex,
    /// Rendered as `desc-search`.
    DescSearch,
    /// Rendered as `desc-regex`.
    DescRegex,
    /// Rendered as `search`.
    Search,
    /// Rendered as `fundamental`.
    Fundamental,
}

impl fmt::Display for Projection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mode = match self {
            Self::SymbolSearch => "symbol-search",
            Self::SymbolRegex => "symbol-regex",
            Self::DescSearch => "desc-search",
            Self::DescRegex => "desc-regex",
            Self::Search => "search",
            Self::Fundamental => "fundamental",
        };
        f.write_str(mode)
    }
}
172
/// Market identifier used by the market-hours endpoints.
///
/// The [`Display`](fmt::Display) form is the lower-case identifier that
/// `SchwabApi::market_hours` joins into the `markets` query parameter and that
/// `SchwabApi::market_hour` places in the URL path.
#[derive(Debug, Eq, PartialEq, Hash, Clone)]
pub enum MarketSymbol {
    /// Rendered as `equity`.
    Equity,
    /// Rendered as `option`.
    Option,
    /// Rendered as `bond`.
    Bond,
    /// Rendered as `future`.
    Future,
    /// Rendered as `forex`.
    Forex,
}

impl fmt::Display for MarketSymbol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let market = match self {
            Self::Equity => "equity",
            Self::Option => "option",
            Self::Bond => "bond",
            Self::Future => "future",
            Self::Forex => "forex",
        };
        f.write_str(market)
    }
}
199
200#[derive(Debug, Clone)]
202pub struct SchwabApi {
203 reqwest_client: Arc<Client>,
204 app_key: String,
205 app_secret: String,
206 tokens_file_path: String,
207 auth: SchwabAuth,
208 token_info: Arc<Mutex<StoredTokenInfo>>,
209}
210
211impl SchwabApi {
212 pub async fn new(
219 app_key: String,
220 app_secret: String,
221 tokens_file_path: String,
222 ) -> anyhow::Result<Self> {
223 let reqwest_client = Arc::new(Client::new());
224 let auth = SchwabAuth::new(reqwest_client.clone(), tokens_file_path.clone());
225
226 let json_string = tokio::fs::read_to_string(&tokens_file_path).await?;
227 let token_info: StoredTokenInfo = serde_json::from_str(&json_string)?;
228
229 Ok(Self {
230 reqwest_client,
231 app_key,
232 app_secret,
233 tokens_file_path,
234 auth,
235 token_info: Arc::new(Mutex::new(token_info)),
236 })
237 }
238
239 pub async fn default() -> anyhow::Result<Self> {
242 let app_key = env::var("SCHWAB_APP_KEY")
243 .map_err(|_| anyhow::anyhow!("SCHWAB_APP_KEY environment variable not set"))?;
244 let app_secret = env::var("SCHWAB_APP_SECRET")
245 .map_err(|_| anyhow::anyhow!("SCHWAB_APP_SECRET environment variable not set"))?;
246 Self::new(app_key, app_secret, TOKENS_FILE.to_owned()).await
247 }
248
249 async fn send_request(&self, mut builder: RequestBuilder) -> anyhow::Result<Response> {
251 let headers = self.construct_request_headers().await?;
253 builder = builder.headers(headers);
254
255 let retry_builder = builder
257 .try_clone()
258 .ok_or_else(|| anyhow::anyhow!("Failed to clone request for potential retry"))?;
259
260 let response = builder.send().await?;
262
263 if response.status() == StatusCode::UNAUTHORIZED {
265 info!("Token expired. Attempting to refresh...");
266 self.refresh_and_store_token().await?;
267
268 let retry_headers = self.construct_request_headers().await?;
270 let response = retry_builder.headers(retry_headers).send().await?;
271
272 info!("Request successful after token refresh.");
273 return Ok(response);
274 }
275
276 Ok(response)
277 }
278
279 async fn refresh_and_store_token(&self) -> anyhow::Result<()> {
281 let refresh_token = {
282 let token_data = self.token_info.lock().await;
283 token_data.refresh_token.clone()
284 };
285
286 let new_token_info = self
287 .auth
288 .refresh_tokens(&self.app_key, &self.app_secret, &refresh_token)
289 .await?;
290
291 {
293 let mut token_data = self.token_info.lock().await;
294 *token_data = new_token_info.clone();
295 }
296
297 let json_string = serde_json::to_string_pretty(&new_token_info)?;
299 tokio::fs::write(&self.tokens_file_path, json_string).await?;
300 info!("Successfully refreshed and stored new token.");
301
302 Ok(())
303 }
304
305 async fn construct_request_headers(&self) -> anyhow::Result<HeaderMap> {
307 let mut headers = HeaderMap::new();
308
309 let token_data = self.token_info.lock().await;
310 let auth_header = format!("Bearer {}", token_data.access_token);
311 headers.insert("Authorization", auth_header.parse()?);
312
313 Ok(headers)
314 }
315
316 pub async fn get_preferences(&self) -> anyhow::Result<UserPreferencesResponse> {
317 let builder = self
318 .reqwest_client
319 .get(format!("{SCHWAB_TRADER_API_URL}/userPreference"));
320
321 let response = self.send_request(builder).await?;
322 response.json().await.map_err(Into::into)
323 }
324
325 pub async fn get_quotes(
326 &self,
327 symbols: Vec<String>,
328 fields: Option<Vec<QuoteFields>>,
329 indicative: Option<bool>,
330 ) -> anyhow::Result<QuotesResponse> {
331 let url = format!("{}/quotes", SCHWAB_MARKET_DATA_API_URL);
332
333 let params = parse_params(vec![
334 ("symbols", Some(symbols.join(","))),
335 (
336 "fields",
337 fields.map(|v| {
338 dedup_ordered(v)
339 .iter()
340 .map(|f| f.to_string())
341 .collect::<Vec<String>>()
342 .join(",")
343 }),
344 ),
345 ("indicative", indicative.map(|v| v.to_string().to_lowercase())),
346 ]);
347
348 let builder = self.reqwest_client.get(url).query(¶ms);
349 let response = self.send_request(builder).await?;
350 response.json().await.map_err(Into::into)
351 }
352
353 pub async fn get_chains(
354 &self,
355 symbol: String,
356 contract_type: ContractType,
357 strike_count: u64,
358 include_underlying_quote: bool,
359 ) -> anyhow::Result<ChainsResponse> {
360 let url = format!("{}/chains", SCHWAB_MARKET_DATA_API_URL);
361
362 let params = parse_params(vec![
363 ("symbol", Some(symbol)),
364 ("contractType", Some(contract_type.to_string())),
365 ("strikeCount", Some(strike_count.to_string())),
366 (
367 "includeUnderlyingQuote",
368 Some(include_underlying_quote.to_string()),
369 ),
370 ]);
371
372 let builder = self.reqwest_client.get(url).query(¶ms);
373 let response = self.send_request(builder).await?;
374 response.json().await.map_err(Into::into)
375 }
376
377 pub async fn quote(
378 &self,
379 symbol_id: String,
380 fields: Option<Vec<QuoteFields>>,
381 ) -> anyhow::Result<QuotesResponse> {
382 let url = format!(
383 "{}/{}/quotes",
384 SCHWAB_MARKET_DATA_API_URL,
385 encode(&symbol_id)
386 );
387
388 let params = parse_params(vec![(
389 "fields",
390 fields.map(|v| {
391 dedup_ordered(v)
392 .iter()
393 .map(|f| f.to_string())
394 .collect::<Vec<String>>()
395 .join(",")
396 }),
397 )]);
398
399 let builder = self.reqwest_client.get(url).query(¶ms);
400 let response = self.send_request(builder).await?;
401 response.json().await.map_err(Into::into)
402 }
403
404 pub async fn option_expiration_chain(
405 &self,
406 symbol: String,
407 ) -> anyhow::Result<ExpirationChainResponse> {
408 let url = format!("{}/expirationchain", SCHWAB_MARKET_DATA_API_URL);
409 let params = parse_params(vec![("symbol", Some(symbol))]);
410
411 let builder = self.reqwest_client.get(url).query(¶ms);
412 let response = self.send_request(builder).await?;
413 response.json().await.map_err(Into::into)
414 }
415
416 pub async fn price_history(
417 &self,
418 symbol: String,
419 period_type: Option<PeriodType>,
420 period: Option<u64>,
421 frequency_type: Option<FrequencyType>,
422 frequency: Option<u64>,
423 start_date: Option<DateTime<Utc>>,
424 end_date: Option<DateTime<Utc>>,
425 need_extended_hours_data: Option<bool>,
426 need_previous_close: Option<bool>,
427 ) -> anyhow::Result<PriceHistoryResponse> {
428 let url = format!("{}/pricehistory", SCHWAB_MARKET_DATA_API_URL);
429
430 let params = parse_params(vec![
431 ("symbol", Some(symbol)),
432 ("periodType", period_type.map(|p| p.to_string())),
433 ("period", period.map(|p| p.to_string())),
434 ("frequencyType", frequency_type.map(|f| f.to_string())),
435 ("frequency", frequency.map(|f| f.to_string())),
436 ("startDate", time_to_epoch_ms(start_date)),
437 ("endDate", time_to_epoch_ms(end_date)),
438 (
439 "needExtendedHoursData",
440 need_extended_hours_data.map(|b| b.to_string()),
441 ),
442 (
443 "needPreviousClose",
444 need_previous_close.map(|b| b.to_string()),
445 ),
446 ]);
447
448 let builder = self.reqwest_client.get(url).query(¶ms);
449 let response = self.send_request(builder).await?;
450 response.json().await.map_err(Into::into)
451 }
452
453 pub async fn movers(
454 &self,
455 symbol: String,
456 sort: Option<Sort>,
457 frequency: Option<u64>,
458 ) -> anyhow::Result<MoversResponse> {
459 let url = format!("{}/movers/{}", SCHWAB_MARKET_DATA_API_URL, encode(&symbol));
460 let params = parse_params(vec![
461 ("sort", sort.map(|s| s.to_string())),
462 ("frequency", frequency.map(|f| f.to_string())),
463 ]);
464
465 let builder = self.reqwest_client.get(url).query(¶ms);
466 let response = self.send_request(builder).await?;
467 response.json().await.map_err(Into::into)
468 }
469
470 pub async fn market_hours(
471 &self,
472 symbols: Vec<MarketSymbol>,
473 date: Option<DateTime<Utc>>,
474 ) -> anyhow::Result<MarketHoursResponse> {
475 let url = format!("{}/markets", SCHWAB_MARKET_DATA_API_URL);
476
477 let symbols_string = symbols
478 .iter()
479 .map(|s| s.to_string())
480 .collect::<Vec<String>>()
481 .join(",");
482
483 let params = parse_params(vec![
484 ("markets", Some(symbols_string)),
485 ("date", time_to_yyyymmdd(date)),
486 ]);
487
488 let builder = self.reqwest_client.get(url).query(¶ms);
489 let response = self.send_request(builder).await?;
490 response.json().await.map_err(Into::into)
491 }
492
493 pub async fn market_hour(
494 &self,
495 market_id: MarketSymbol,
496 date: Option<DateTime<Utc>>,
497 ) -> anyhow::Result<MarketHours> {
498 let url = format!(
499 "{}/markets/{}",
500 SCHWAB_MARKET_DATA_API_URL,
501 market_id.to_string()
502 );
503
504 let params = parse_params(vec![("date", time_to_yyyymmdd(date))]);
505
506 let builder = self.reqwest_client.get(url).query(¶ms);
507 let response = self.send_request(builder).await?;
508
509 let mut response_map: MarketHoursResponse = response.json().await?;
512 let market_hours = response_map
513 .into_values()
514 .next()
515 .ok_or_else(|| anyhow::anyhow!("Market hours response was empty"))?;
516 Ok(market_hours)
517 }
518
519 pub async fn instruments(
520 &self,
521 symbol: String,
522 projection: Projection,
523 ) -> anyhow::Result<InstrumentsResponse> {
524 let url = format!("{}/instruments", SCHWAB_MARKET_DATA_API_URL);
525
526 let params = parse_params(vec![
527 ("symbol", Some(symbol)),
528 ("projection", Some(projection.to_string())),
529 ]);
530
531 let builder = self.reqwest_client.get(url).query(¶ms);
532 let response = self.send_request(builder).await?;
533 response.json().await.map_err(Into::into)
534 }
535
536 pub async fn instrument_cusip(&self, cusip_id: String) -> anyhow::Result<InstrumentsResponse> {
537 let url = format!(
538 "{}/instruments/{}",
539 SCHWAB_MARKET_DATA_API_URL,
540 encode(&cusip_id)
541 );
542
543 let builder = self.reqwest_client.get(url);
544 let response = self.send_request(builder).await?;
545 response.json().await.map_err(Into::into)
546 }
547
548 pub(crate) async fn token_info(&self) -> StoredTokenInfo {
549 self.token_info.lock().await.clone()
550 }
551}