brokerage_api/schwab/
schwab_api.rs

1use std::{env, fmt, sync::Arc};
2
3use chrono::{DateTime, Utc};
4use reqwest::{header::HeaderMap, Client, RequestBuilder, Response, StatusCode};
5use tokio::sync::Mutex;
6use tracing::info;
7use urlencoding::encode;
8
9use crate::{
10    schwab::{
11        common::{SCHWAB_MARKET_DATA_API_URL, SCHWAB_TRADER_API_URL, TOKENS_FILE},
12        models::{
13            market_data::{
14                ChainsResponse, ExpirationChainResponse, InstrumentsResponse, MarketHours,
15                MarketHoursResponse, MoversResponse, PriceHistoryResponse, QuotesResponse,
16            },
17            trader::UserPreferencesResponse,
18        },
19        schwab_auth::{SchwabAuth, StoredTokenInfo},
20    },
21    util::{dedup_ordered, parse_params, time_to_epoch_ms, time_to_yyyymmdd},
22};
23
/// Represents the type of contract for an options chain.
///
/// Sent as the `contractType` query parameter of the option-chains request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ContractType {
    /// Call options.
    Call,
    /// Put options.
    Put,
    /// All options (both call and put).
    All,
}
33
34impl fmt::Display for ContractType {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        match self {
37            ContractType::All => write!(f, "ALL"),
38            ContractType::Call => write!(f, "CALL"),
39            ContractType::Put => write!(f, "PUT"),
40        }
41    }
42}
43
/// Represents the fields to be returned in a quote.
///
/// Passed, comma-joined, as the `fields` query parameter of the quote requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuoteFields {
    /// Quote data.
    Quote,
    /// Fundamental data.
    Fundamental,
    /// Extended data.
    Extended,
    /// Reference data.
    Reference,
    /// Regular data.
    Regular,
}
58
59impl fmt::Display for QuoteFields {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        match self {
62            QuoteFields::Quote => write!(f, "quote"),
63            QuoteFields::Fundamental => write!(f, "fundamental"),
64            QuoteFields::Extended => write!(f, "extended"),
65            QuoteFields::Reference => write!(f, "reference"),
66            QuoteFields::Regular => write!(f, "regular"),
67        }
68    }
69}
70
/// Represents the period type for price history.
///
/// Sent as the `periodType` query parameter of the price-history request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PeriodType {
    /// Day period type.
    Day,
    /// Month period type.
    Month,
    /// Year period type.
    Year,
    /// Year to date period type.
    Ytd,
}
83
84impl fmt::Display for PeriodType {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        match self {
87            PeriodType::Day => write!(f, "day"),
88            PeriodType::Month => write!(f, "month"),
89            PeriodType::Year => write!(f, "year"),
90            PeriodType::Ytd => write!(f, "ytd"),
91        }
92    }
93}
94
/// Represents the frequency type for price history.
///
/// Sent as the `frequencyType` query parameter of the price-history request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrequencyType {
    /// Minute frequency type.
    Minute,
    /// Daily frequency type.
    Daily,
    /// Weekly frequency type.
    Weekly,
    /// Monthly frequency type.
    Monthly,
}
107
108impl fmt::Display for FrequencyType {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        match self {
111            FrequencyType::Minute => write!(f, "minute"),
112            FrequencyType::Daily => write!(f, "daily"),
113            FrequencyType::Weekly => write!(f, "weekly"),
114            FrequencyType::Monthly => write!(f, "monthly"),
115        }
116    }
117}
118
/// Represents the sort order for movers.
///
/// Sent as the `sort` query parameter of the movers request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Sort {
    /// Sort by volume.
    Volume,
    /// Sort by trades.
    Trades,
    /// Sort by percent change up.
    PercentChangeUp,
    /// Sort by percent change down.
    PercentChangeDown,
}
131
132impl fmt::Display for Sort {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        match self {
135            Sort::Volume => write!(f, "VOLUME"),
136            Sort::Trades => write!(f, "TRADES"),
137            Sort::PercentChangeUp => write!(f, "PERCENT_CHANGE_UP"),
138            Sort::PercentChangeDown => write!(f, "PERCENT_CHANGE_DOWN"),
139        }
140    }
141}
142
/// Represents the projection type for instruments.
///
/// Sent as the `projection` query parameter of the instruments request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Projection {
    /// Symbol search projection.
    SymbolSearch,
    /// Symbol regex projection.
    SymbolRegex,
    /// Description search projection.
    DescSearch,
    /// Description regex projection.
    DescRegex,
    /// Search projection.
    Search,
    /// Fundamental projection.
    Fundamental,
}
159
160impl fmt::Display for Projection {
161    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162        match self {
163            Projection::SymbolSearch => write!(f, "symbol-search"),
164            Projection::SymbolRegex => write!(f, "symbol-regex"),
165            Projection::DescSearch => write!(f, "desc-search"),
166            Projection::DescRegex => write!(f, "desc-regex"),
167            Projection::Search => write!(f, "search"),
168            Projection::Fundamental => write!(f, "fundamental"),
169        }
170    }
171}
172
/// Represents the market symbols for market hours.
///
/// Used both in the comma-joined `markets` query parameter and as the market path
/// segment of the single-market hours request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MarketSymbol {
    /// Equity market.
    Equity,
    /// Option market.
    Option,
    /// Bond market.
    Bond,
    /// Future market.
    Future,
    /// Forex market.
    Forex,
}
187
188impl fmt::Display for MarketSymbol {
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        match self {
191            MarketSymbol::Equity => write!(f, "equity"),
192            MarketSymbol::Option => write!(f, "option"),
193            MarketSymbol::Bond => write!(f, "bond"),
194            MarketSymbol::Future => write!(f, "future"),
195            MarketSymbol::Forex => write!(f, "forex"),
196        }
197    }
198}
199
200/// A client for interacting with the Schwab API, with automatic token refreshing.
201#[derive(Debug, Clone)]
202pub struct SchwabApi {
203    reqwest_client: Arc<Client>,
204    app_key: String,
205    app_secret: String,
206    tokens_file_path: String,
207    auth: SchwabAuth,
208    token_info: Arc<Mutex<StoredTokenInfo>>,
209}
210
211impl SchwabApi {
212    /// Creates a new `SchwabApi` instance.
213    ///
214    /// # Arguments
215    /// * `app_key` - Your Schwab application key (Client ID).
216    /// * `app_secret` - Your Schwab application secret (Client Secret).
217    /// * `tokens_file_path` - The path to where the `tokens.json` file is stored.
218    pub async fn new(
219        app_key: String,
220        app_secret: String,
221        tokens_file_path: String,
222    ) -> anyhow::Result<Self> {
223        let reqwest_client = Arc::new(Client::new());
224        let auth = SchwabAuth::new(reqwest_client.clone(), tokens_file_path.clone());
225
226        let json_string = tokio::fs::read_to_string(&tokens_file_path).await?;
227        let token_info: StoredTokenInfo = serde_json::from_str(&json_string)?;
228
229        Ok(Self {
230            reqwest_client,
231            app_key,
232            app_secret,
233            tokens_file_path,
234            auth,
235            token_info: Arc::new(Mutex::new(token_info)),
236        })
237    }
238
239    /// Creates a new `SchwabApi` instance with default settings, loading credentials from environment variables.
240    /// It expects `SCHWAB_APP_KEY` and `SCHWAB_APP_SECRET` to be set.
241    pub async fn default() -> anyhow::Result<Self> {
242        let app_key = env::var("SCHWAB_APP_KEY")
243            .map_err(|_| anyhow::anyhow!("SCHWAB_APP_KEY environment variable not set"))?;
244        let app_secret = env::var("SCHWAB_APP_SECRET")
245            .map_err(|_| anyhow::anyhow!("SCHWAB_APP_SECRET environment variable not set"))?;
246        Self::new(app_key, app_secret, TOKENS_FILE.to_owned()).await
247    }
248
249    /// Centralized request sender that handles authentication and token refreshing.
250    async fn send_request(&self, mut builder: RequestBuilder) -> anyhow::Result<Response> {
251        // Sign the request with the current access token from memory
252        let headers = self.construct_request_headers().await?;
253        builder = builder.headers(headers);
254
255        // Clone the request before sending, so we can retry it if the token is expired
256        let retry_builder = builder
257            .try_clone()
258            .ok_or_else(|| anyhow::anyhow!("Failed to clone request for potential retry"))?;
259
260        // Send the initial request
261        let response = builder.send().await?;
262
263        // Check if the token expired (401 Unauthorized)
264        if response.status() == StatusCode::UNAUTHORIZED {
265            info!("Token expired. Attempting to refresh...");
266            self.refresh_and_store_token().await?;
267
268            // Re-sign the cloned request with the new token
269            let retry_headers = self.construct_request_headers().await?;
270            let response = retry_builder.headers(retry_headers).send().await?;
271
272            info!("Request successful after token refresh.");
273            return Ok(response);
274        }
275
276        Ok(response)
277    }
278
279    /// Refreshes the token, updates the in-memory store, and writes the new token to the file.
280    async fn refresh_and_store_token(&self) -> anyhow::Result<()> {
281        let refresh_token = {
282            let token_data = self.token_info.lock().await;
283            token_data.refresh_token.clone()
284        };
285
286        let new_token_info = self
287            .auth
288            .refresh_tokens(&self.app_key, &self.app_secret, &refresh_token)
289            .await?;
290
291        // Update the in-memory token
292        {
293            let mut token_data = self.token_info.lock().await;
294            *token_data = new_token_info.clone();
295        }
296
297        // Persist the new token to the file for future sessions
298        let json_string = serde_json::to_string_pretty(&new_token_info)?;
299        tokio::fs::write(&self.tokens_file_path, json_string).await?;
300        info!("Successfully refreshed and stored new token.");
301
302        Ok(())
303    }
304
305    /// Constructs the request headers from the in-memory token.
306    async fn construct_request_headers(&self) -> anyhow::Result<HeaderMap> {
307        let mut headers = HeaderMap::new();
308
309        let token_data = self.token_info.lock().await;
310        let auth_header = format!("Bearer {}", token_data.access_token);
311        headers.insert("Authorization", auth_header.parse()?);
312
313        Ok(headers)
314    }
315
316    pub async fn get_preferences(&self) -> anyhow::Result<UserPreferencesResponse> {
317        let builder = self
318            .reqwest_client
319            .get(format!("{SCHWAB_TRADER_API_URL}/userPreference"));
320
321        let response = self.send_request(builder).await?;
322        response.json().await.map_err(Into::into)
323    }
324
325    pub async fn get_quotes(
326        &self,
327        symbols: Vec<String>,
328        fields: Option<Vec<QuoteFields>>,
329        indicative: Option<bool>,
330    ) -> anyhow::Result<QuotesResponse> {
331        let url = format!("{}/quotes", SCHWAB_MARKET_DATA_API_URL);
332
333        let params = parse_params(vec![
334            ("symbols", Some(symbols.join(","))),
335            (
336                "fields",
337                fields.map(|v| {
338                    dedup_ordered(v)
339                        .iter()
340                        .map(|f| f.to_string())
341                        .collect::<Vec<String>>()
342                        .join(",")
343                }),
344            ),
345            ("indicative", indicative.map(|v| v.to_string().to_lowercase())),
346        ]);
347
348        let builder = self.reqwest_client.get(url).query(&params);
349        let response = self.send_request(builder).await?;
350        response.json().await.map_err(Into::into)
351    }
352
353    pub async fn get_chains(
354        &self,
355        symbol: String,
356        contract_type: ContractType,
357        strike_count: u64,
358        include_underlying_quote: bool,
359    ) -> anyhow::Result<ChainsResponse> {
360        let url = format!("{}/chains", SCHWAB_MARKET_DATA_API_URL);
361
362        let params = parse_params(vec![
363            ("symbol", Some(symbol)),
364            ("contractType", Some(contract_type.to_string())),
365            ("strikeCount", Some(strike_count.to_string())),
366            (
367                "includeUnderlyingQuote",
368                Some(include_underlying_quote.to_string()),
369            ),
370        ]);
371
372        let builder = self.reqwest_client.get(url).query(&params);
373        let response = self.send_request(builder).await?;
374        response.json().await.map_err(Into::into)
375    }
376
377    pub async fn quote(
378        &self,
379        symbol_id: String,
380        fields: Option<Vec<QuoteFields>>,
381    ) -> anyhow::Result<QuotesResponse> {
382        let url = format!(
383            "{}/{}/quotes",
384            SCHWAB_MARKET_DATA_API_URL,
385            encode(&symbol_id)
386        );
387
388        let params = parse_params(vec![(
389            "fields",
390            fields.map(|v| {
391                dedup_ordered(v)
392                    .iter()
393                    .map(|f| f.to_string())
394                    .collect::<Vec<String>>()
395                    .join(",")
396            }),
397        )]);
398
399        let builder = self.reqwest_client.get(url).query(&params);
400        let response = self.send_request(builder).await?;
401        response.json().await.map_err(Into::into)
402    }
403
404    pub async fn option_expiration_chain(
405        &self,
406        symbol: String,
407    ) -> anyhow::Result<ExpirationChainResponse> {
408        let url = format!("{}/expirationchain", SCHWAB_MARKET_DATA_API_URL);
409        let params = parse_params(vec![("symbol", Some(symbol))]);
410
411        let builder = self.reqwest_client.get(url).query(&params);
412        let response = self.send_request(builder).await?;
413        response.json().await.map_err(Into::into)
414    }
415
416    pub async fn price_history(
417        &self,
418        symbol: String,
419        period_type: Option<PeriodType>,
420        period: Option<u64>,
421        frequency_type: Option<FrequencyType>,
422        frequency: Option<u64>,
423        start_date: Option<DateTime<Utc>>,
424        end_date: Option<DateTime<Utc>>,
425        need_extended_hours_data: Option<bool>,
426        need_previous_close: Option<bool>,
427    ) -> anyhow::Result<PriceHistoryResponse> {
428        let url = format!("{}/pricehistory", SCHWAB_MARKET_DATA_API_URL);
429
430        let params = parse_params(vec![
431            ("symbol", Some(symbol)),
432            ("periodType", period_type.map(|p| p.to_string())),
433            ("period", period.map(|p| p.to_string())),
434            ("frequencyType", frequency_type.map(|f| f.to_string())),
435            ("frequency", frequency.map(|f| f.to_string())),
436            ("startDate", time_to_epoch_ms(start_date)),
437            ("endDate", time_to_epoch_ms(end_date)),
438            (
439                "needExtendedHoursData",
440                need_extended_hours_data.map(|b| b.to_string()),
441            ),
442            (
443                "needPreviousClose",
444                need_previous_close.map(|b| b.to_string()),
445            ),
446        ]);
447
448        let builder = self.reqwest_client.get(url).query(&params);
449        let response = self.send_request(builder).await?;
450        response.json().await.map_err(Into::into)
451    }
452
453    pub async fn movers(
454        &self,
455        symbol: String,
456        sort: Option<Sort>,
457        frequency: Option<u64>,
458    ) -> anyhow::Result<MoversResponse> {
459        let url = format!("{}/movers/{}", SCHWAB_MARKET_DATA_API_URL, encode(&symbol));
460        let params = parse_params(vec![
461            ("sort", sort.map(|s| s.to_string())),
462            ("frequency", frequency.map(|f| f.to_string())),
463        ]);
464
465        let builder = self.reqwest_client.get(url).query(&params);
466        let response = self.send_request(builder).await?;
467        response.json().await.map_err(Into::into)
468    }
469
470    pub async fn market_hours(
471        &self,
472        symbols: Vec<MarketSymbol>,
473        date: Option<DateTime<Utc>>,
474    ) -> anyhow::Result<MarketHoursResponse> {
475        let url = format!("{}/markets", SCHWAB_MARKET_DATA_API_URL);
476
477        let symbols_string = symbols
478            .iter()
479            .map(|s| s.to_string())
480            .collect::<Vec<String>>()
481            .join(",");
482
483        let params = parse_params(vec![
484            ("markets", Some(symbols_string)),
485            ("date", time_to_yyyymmdd(date)),
486        ]);
487
488        let builder = self.reqwest_client.get(url).query(&params);
489        let response = self.send_request(builder).await?;
490        response.json().await.map_err(Into::into)
491    }
492
493    pub async fn market_hour(
494        &self,
495        market_id: MarketSymbol,
496        date: Option<DateTime<Utc>>,
497    ) -> anyhow::Result<MarketHours> {
498        let url = format!(
499            "{}/markets/{}",
500            SCHWAB_MARKET_DATA_API_URL,
501            market_id.to_string()
502        );
503
504        let params = parse_params(vec![("date", time_to_yyyymmdd(date))]);
505
506        let builder = self.reqwest_client.get(url).query(&params);
507        let response = self.send_request(builder).await?;
508
509        // The API wraps the single response in a map with the market name as the key.
510        // We find the first value in the map and return it.
511        let mut response_map: MarketHoursResponse = response.json().await?;
512        let market_hours = response_map
513            .into_values()
514            .next()
515            .ok_or_else(|| anyhow::anyhow!("Market hours response was empty"))?;
516        Ok(market_hours)
517    }
518
519    pub async fn instruments(
520        &self,
521        symbol: String,
522        projection: Projection,
523    ) -> anyhow::Result<InstrumentsResponse> {
524        let url = format!("{}/instruments", SCHWAB_MARKET_DATA_API_URL);
525
526        let params = parse_params(vec![
527            ("symbol", Some(symbol)),
528            ("projection", Some(projection.to_string())),
529        ]);
530
531        let builder = self.reqwest_client.get(url).query(&params);
532        let response = self.send_request(builder).await?;
533        response.json().await.map_err(Into::into)
534    }
535
536    pub async fn instrument_cusip(&self, cusip_id: String) -> anyhow::Result<InstrumentsResponse> {
537        let url = format!(
538            "{}/instruments/{}",
539            SCHWAB_MARKET_DATA_API_URL,
540            encode(&cusip_id)
541        );
542
543        let builder = self.reqwest_client.get(url);
544        let response = self.send_request(builder).await?;
545        response.json().await.map_err(Into::into)
546    }
547
548    pub(crate) async fn token_info(&self) -> StoredTokenInfo {
549        self.token_info.lock().await.clone()
550    }
551}