brokerage_api/schwab/schwab_api.rs

1use std::{fmt, sync::Arc};
2
3use chrono::{DateTime, Utc};
4use reqwest::{
5    Client,
6    header::{HeaderMap, HeaderValue},
7};
8use serde_json::Value;
9use urlencoding::encode;
10
11use crate::{schwab::schwab_auth::StoredTokenInfo, util::dedup_ordered};
12
/// Represents the type of contract for an options chain (the `contractType` query parameter).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ContractType {
    /// Call options.
    Call,
    /// Put options.
    Put,
    /// All options (both call and put).
    All,
}

impl ContractType {
    /// Returns the value sent to the Schwab API for this contract type (e.g. `"CALL"`),
    /// without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            ContractType::Call => "CALL",
            ContractType::Put => "PUT",
            ContractType::All => "ALL",
        }
    }
}

impl fmt::Display for ContractType {
    /// Writes the API value returned by [`ContractType::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
32
/// Represents the fields to be returned in a quote (the `fields` query parameter).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuoteFields {
    /// Quote data.
    Quote,
    /// Fundamental data.
    Fundamental,
    /// Extended data.
    Extended,
    /// Reference data.
    Reference,
    /// Regular data.
    Regular,
}

impl QuoteFields {
    /// Returns the value sent to the Schwab API for this field (e.g. `"quote"`),
    /// without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            QuoteFields::Quote => "quote",
            QuoteFields::Fundamental => "fundamental",
            QuoteFields::Extended => "extended",
            QuoteFields::Reference => "reference",
            QuoteFields::Regular => "regular",
        }
    }
}

impl fmt::Display for QuoteFields {
    /// Writes the API value returned by [`QuoteFields::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
59
/// Represents the period type for price history (the `periodType` query parameter).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PeriodType {
    /// Day period type.
    Day,
    /// Month period type.
    Month,
    /// Year period type.
    Year,
    /// Year to date period type.
    Ytd,
}

impl PeriodType {
    /// Returns the value sent to the Schwab API for this period type (e.g. `"day"`),
    /// without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            PeriodType::Day => "day",
            PeriodType::Month => "month",
            PeriodType::Year => "year",
            PeriodType::Ytd => "ytd",
        }
    }
}

impl fmt::Display for PeriodType {
    /// Writes the API value returned by [`PeriodType::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
83
/// Represents the frequency type for price history (the `frequencyType` query parameter).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrequencyType {
    /// Minute frequency type.
    Minute,
    /// Daily frequency type.
    Daily,
    /// Weekly frequency type.
    Weekly,
    /// Monthly frequency type.
    Monthly,
}

impl FrequencyType {
    /// Returns the value sent to the Schwab API for this frequency type (e.g. `"minute"`),
    /// without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            FrequencyType::Minute => "minute",
            FrequencyType::Daily => "daily",
            FrequencyType::Weekly => "weekly",
            FrequencyType::Monthly => "monthly",
        }
    }
}

impl fmt::Display for FrequencyType {
    /// Writes the API value returned by [`FrequencyType::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
107
/// Represents the sort order for movers (the `sort` query parameter).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Sort {
    /// Sort by volume.
    Volume,
    /// Sort by trades.
    Trades,
    /// Sort by percent change up.
    PercentChangeUp,
    /// Sort by percent change down.
    PercentChangeDown,
}

impl Sort {
    /// Returns the value sent to the Schwab API for this sort order (e.g. `"VOLUME"`),
    /// without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Sort::Volume => "VOLUME",
            Sort::Trades => "TRADES",
            Sort::PercentChangeUp => "PERCENT_CHANGE_UP",
            Sort::PercentChangeDown => "PERCENT_CHANGE_DOWN",
        }
    }
}

impl fmt::Display for Sort {
    /// Writes the API value returned by [`Sort::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
131
/// Represents the projection type for instruments (the `projection` query parameter).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Projection {
    /// Symbol search projection.
    SymbolSearch,
    /// Symbol regex projection.
    SymbolRegex,
    /// Description search projection.
    DescSearch,
    /// Description regex projection.
    DescRegex,
    /// Search projection.
    Search,
    /// Fundamental projection.
    Fundamental,
}

impl Projection {
    /// Returns the value sent to the Schwab API for this projection (e.g. `"symbol-search"`),
    /// without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Projection::SymbolSearch => "symbol-search",
            Projection::SymbolRegex => "symbol-regex",
            Projection::DescSearch => "desc-search",
            Projection::DescRegex => "desc-regex",
            Projection::Search => "search",
            Projection::Fundamental => "fundamental",
        }
    }
}

impl fmt::Display for Projection {
    /// Writes the API value returned by [`Projection::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
161
/// Represents the market symbols for market hours (the `markets` query parameter, or the
/// market path segment for a single market).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MarketSymbol {
    /// Equity market.
    Equity,
    /// Option market.
    Option,
    /// Bond market.
    Bond,
    /// Future market.
    Future,
    /// Forex market.
    Forex,
}

impl MarketSymbol {
    /// Returns the value sent to the Schwab API for this market (e.g. `"equity"`),
    /// without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            MarketSymbol::Equity => "equity",
            MarketSymbol::Option => "option",
            MarketSymbol::Bond => "bond",
            MarketSymbol::Future => "future",
            MarketSymbol::Forex => "forex",
        }
    }
}

impl fmt::Display for MarketSymbol {
    /// Writes the API value returned by [`MarketSymbol::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
188
189/// A client for interacting with the Schwab API.
190pub struct SchwabApi {
191    reqwest_client: Arc<Client>,
192    base_url: String,
193    tokens_file_path: String,
194}
195
196impl SchwabApi {
197    /// Creates a new `SchwabApi` instance.
198    ///
199    /// # Arguments
200    ///
201    /// * `reqwest_client` - An `Arc` wrapped `reqwest::Client` to be used for making HTTP requests.
202    /// * `base_url` - The base URL for the Schwab API.
203    /// * `tokens_file_path` - The path to the file where tokens are stored.
204    ///
205    /// # Returns
206    ///
207    /// A new `SchwabApi` instance.
208    pub fn new(reqwest_client: Arc<Client>, base_url: String, tokens_file_path: String) -> Self {
209        Self {
210            reqwest_client,
211            base_url,
212            tokens_file_path,
213        }
214    }
215
216    /// Parses optional parameters into a `Vec` of `(String, String)` tuples.
217    /// Filters out `None` values and converts `Some` values to strings.
218    fn parse_params<T: ToString>(params: Vec<(&str, Option<T>)>) -> Vec<(String, String)> {
219        params
220            .into_iter()
221            .filter_map(|(key, value)| value.map(|v| (key.to_string(), v.to_string())))
222            .collect()
223    }
224
225    /// Converts a `DateTime<Utc>` or `String` to an epoch timestamp in milliseconds.
226    fn time_to_epoch_ms(date: Option<DateTime<Utc>>) -> Option<String> {
227        date.map(|d| d.timestamp_millis().to_string())
228    }
229
230    /// Converts a `DateTime<Utc>` to a "YYYY-MM-DD" string.
231    fn time_to_yyyymmdd(date: Option<DateTime<Utc>>) -> Option<String> {
232        date.map(|d| d.format("%Y-%m-%d").to_string())
233    }
234
235    /// Gets quotes for a list of symbols.
236    ///
237    /// # Arguments
238    ///
239    /// * `symbols` - A `Vec` of `String`s representing the symbols to get quotes for.
240    /// * `fields` - An `Option`al `Vec` of `QuoteFields` to be returned in the quote.
241    /// * `indicative` - An `Option`al `bool` indicating whether to return indicative quotes.
242    ///
243    /// # Returns
244    ///
245    /// An empty `Result` indicating success or failure.
246    pub async fn get_quotes(
247        &self,
248        symbols: Vec<String>,
249        fields: Option<Vec<QuoteFields>>,
250        indicative: Option<bool>,
251    ) -> anyhow::Result<Value, anyhow::Error> {
252        let symbols_string = symbols.join(",");
253        let fields_string = match fields {
254            Some(v) => dedup_ordered(v)
255                .iter()
256                .map(|f| f.to_string())
257                .collect::<Vec<String>>()
258                .join(","),
259            None => "".to_owned(),
260        };
261        let indicative_string = match indicative {
262            Some(v) => v.to_string().to_lowercase(),
263            None => "".to_owned(),
264        };
265
266        let headers = self.construct_request_headers().await?;
267
268        let request_url = format!(
269            "{}/quotes?symbols={}&fields={}&indicative={}",
270            self.base_url, symbols_string, fields_string, indicative_string
271        );
272        let response = self
273            .reqwest_client
274            .get(request_url)
275            .headers(headers)
276            .send()
277            .await?;
278
279        let response_json = serde_json::from_str(response.text().await?.as_str())?;
280        Ok(response_json)
281    }
282
283    /// Gets an options chain for a symbol.
284    ///
285    /// # Arguments
286    ///
287    /// * `symbol` - The symbol to get the options chain for.
288    /// * `contract_type` - The type of contract to get.
289    /// * `strike_count` - The number of strikes to return.
290    /// * `include_underlying_quote` - Whether to include the underlying quote in the response.
291    ///
292    /// # Returns
293    ///
294    /// An empty `Result` indicating success or failure.
295    pub async fn get_chains(
296        &self,
297        symbol: String,
298        contract_type: ContractType,
299        strike_count: u64,
300        include_underlying_quote: bool,
301    ) -> anyhow::Result<Value, anyhow::Error> {
302        let headers = self.construct_request_headers().await?;
303
304        let request_url = format!(
305            "{}/chains?symbol={}&contractType={}&strikeCount={}&includeUnderlyingQuote={}",
306            self.base_url,
307            symbol,
308            contract_type.to_string(),
309            strike_count.to_string(),
310            include_underlying_quote.to_string()
311        );
312        let response = self
313            .reqwest_client
314            .get(request_url)
315            .headers(headers)
316            .send()
317            .await?;
318
319        let response_json = serde_json::from_str(response.text().await?.as_str())?;
320        Ok(response_json)
321    }
322
323    /// Get quote for a single symbol.
324    ///
325    /// # Arguments
326    ///
327    /// * `symbol_id` - Ticker symbol.
328    /// * `fields` - Fields to get ("all", "quote", "fundamental").
329    ///
330    /// # Returns
331    ///
332    /// `Result` containing quote for a single symbol or an error.
333    pub async fn quote(
334        &self,
335        symbol_id: String,
336        fields: Option<Vec<QuoteFields>>,
337    ) -> anyhow::Result<Value, anyhow::Error> {
338        let headers = self.construct_request_headers().await?;
339
340        let fields_string = match fields {
341            Some(v) => dedup_ordered(v)
342                .iter()
343                .map(|f| f.to_string())
344                .collect::<Vec<String>>()
345                .join(","),
346            None => "".to_owned(),
347        };
348
349        let params = SchwabApi::parse_params(vec![("fields", Some(fields_string))]);
350
351        let request_url = format!("{}/{}/quotes", self.base_url, encode(&symbol_id));
352        let response = self
353            .reqwest_client
354            .get(request_url)
355            .headers(headers)
356            .query(&params)
357            .send()
358            .await?;
359
360        let response_json = serde_json::from_str(response.text().await?.as_str())?;
361        Ok(response_json)
362    }
363
364    /// Get an option expiration chain for a ticker.
365    ///
366    /// # Arguments
367    ///
368    /// * `symbol` - Ticker symbol.
369    ///
370    /// # Returns
371    ///
372    /// `Result` containing option expiration chain or an error.
373    pub async fn option_expiration_chain(
374        &self,
375        symbol: String,
376    ) -> anyhow::Result<Value, anyhow::Error> {
377        let headers = self.construct_request_headers().await?;
378
379        let params = SchwabApi::parse_params(vec![("symbol", Some(symbol))]);
380
381        let request_url = format!("{}/expirationchain", self.base_url);
382        let response = self
383            .reqwest_client
384            .get(request_url)
385            .headers(headers)
386            .query(&params)
387            .send()
388            .await?;
389
390        let response_json = serde_json::from_str(response.text().await?.as_str())?;
391        Ok(response_json)
392    }
393
394    /// Get price history for a ticker.
395    ///
396    /// # Arguments
397    ///
398    /// * `symbol` - Ticker symbol.
399    /// * `period_type` - Period type ("day"|"month"|"year"|"ytd").
400    /// * `period` - Period.
401    /// * `frequency_type` - Frequency type ("minute"|"daily"|"weekly"|"monthly").
402    /// * `frequency` - Frequency (frequencyType: options), (minute: 1, 5, 10, 15, 30), (daily: 1), (weekly: 1), (monthly: 1).
403    /// * `start_date` - Start date.
404    /// * `end_date` - End date.
405    /// * `need_extended_hours_data` - Need extended hours data (True|False).
406    /// * `need_previous_close` - Need previous close (True|False).
407    ///
408    /// # Returns
409    ///
410    /// `Result` containing candle history or an error.
411    pub async fn price_history(
412        &self,
413        symbol: String,
414        period_type: Option<PeriodType>,
415        period: Option<u64>,
416        frequency_type: Option<FrequencyType>,
417        frequency: Option<u64>,
418        start_date: Option<DateTime<Utc>>,
419        end_date: Option<DateTime<Utc>>,
420        need_extended_hours_data: Option<bool>,
421        need_previous_close: Option<bool>,
422    ) -> anyhow::Result<Value, anyhow::Error> {
423        let headers = self.construct_request_headers().await?;
424
425        let params = SchwabApi::parse_params(vec![
426            ("symbol", Some(symbol)),
427            ("periodType", period_type.map(|p| p.to_string())),
428            ("period", period.map(|p| p.to_string())),
429            ("frequencyType", frequency_type.map(|f| f.to_string())),
430            ("frequency", frequency.map(|f| f.to_string())),
431            ("startDate", SchwabApi::time_to_epoch_ms(start_date)),
432            ("endDate", SchwabApi::time_to_epoch_ms(end_date)),
433            (
434                "needExtendedHoursData",
435                need_extended_hours_data.map(|b| b.to_string()),
436            ),
437            (
438                "needPreviousClose",
439                need_previous_close.map(|b| b.to_string()),
440            ),
441        ]);
442
443        let request_url = format!("{}/pricehistory", self.base_url);
444        let response = self
445            .reqwest_client
446            .get(request_url)
447            .headers(headers)
448            .query(&params)
449            .send()
450            .await?;
451
452        let response_json = serde_json::from_str(response.text().await?.as_str())?;
453        Ok(response_json)
454    }
455
456    /// Get movers in a specific index and direction.
457    ///
458    /// # Arguments
459    ///
460    /// * `symbol` - Symbol ("$DJI"|"$COMPX"|"$SPX"|"NYSE"|"NASDAQ"|"OTCBB"|"INDEX_ALL"|"EQUITY_ALL"|"OPTION_ALL"|"OPTION_PUT"|"OPTION_CALL").
461    /// * `sort` - Sort ("VOLUME"|"TRADES"|"PERCENT_CHANGE_UP"|"PERCENT_CHANGE_DOWN").
462    /// * `frequency` - Frequency (0|1|5|10|30|60).
463    ///
464    /// # Notes
465    ///
466    /// Must be called within market hours (there aren't really movers outside of market hours).
467    ///
468    /// # Returns
469    ///
470    /// `Result` containing movers or an error.
471    pub async fn movers(
472        &self,
473        symbol: String,
474        sort: Option<Sort>,
475        frequency: Option<u64>,
476    ) -> anyhow::Result<Value, anyhow::Error> {
477        let headers = self.construct_request_headers().await?;
478
479        let params = SchwabApi::parse_params(vec![
480            ("sort", sort.map(|s| s.to_string())),
481            ("frequency", frequency.map(|f| f.to_string())),
482        ]);
483
484        let request_url = format!("{}/movers/{}", self.base_url, encode(&symbol));
485        let response = self
486            .reqwest_client
487            .get(request_url)
488            .headers(headers)
489            .query(&params)
490            .send()
491            .await?;
492
493        let response_json = serde_json::from_str(response.text().await?.as_str())?;
494        Ok(response_json)
495    }
496
497    /// Get Market Hours for dates in the future across different markets.
498    ///
499    /// # Arguments
500    ///
501    /// * `symbols` - List of market symbols ("equity", "option", "bond", "future", "forex").
502    /// * `date` - Date.
503    ///
504    /// # Returns
505    ///
506    /// `Result` containing market hours or an error.
507    pub async fn market_hours(
508        &self,
509        symbols: Vec<MarketSymbol>,
510        date: Option<DateTime<Utc>>,
511    ) -> anyhow::Result<Value, anyhow::Error> {
512        let headers = self.construct_request_headers().await?;
513
514        let symbols_string = symbols
515            .iter()
516            .map(|s| s.to_string())
517            .collect::<Vec<String>>()
518            .join(",");
519
520        let params = SchwabApi::parse_params(vec![
521            ("markets", Some(symbols_string)),
522            ("date", SchwabApi::time_to_yyyymmdd(date)),
523        ]);
524
525        let request_url = format!("{}/markets", self.base_url);
526        let response = self
527            .reqwest_client
528            .get(request_url)
529            .headers(headers)
530            .query(&params)
531            .send()
532            .await?;
533
534        let response_json = serde_json::from_str(response.text().await?.as_str())?;
535        Ok(response_json)
536    }
537
538    /// Get Market Hours for dates in the future for a single market.
539    ///
540    /// # Arguments
541    ///
542    /// * `market_id` - Market id ("equity"|"option"|"bond"|"future"|"forex").
543    /// * `date` - Date.
544    ///
545    /// # Returns
546    ///
547    /// `Result` containing market hours or an error.
548    pub async fn market_hour(
549        &self,
550        market_id: MarketSymbol,
551        date: Option<DateTime<Utc>>,
552    ) -> anyhow::Result<Value, anyhow::Error> {
553        let headers = self.construct_request_headers().await?;
554
555        let params = SchwabApi::parse_params(vec![("date", SchwabApi::time_to_yyyymmdd(date))]);
556
557        let request_url = format!("{}/markets/{}", self.base_url, market_id.to_string());
558        let response = self
559            .reqwest_client
560            .get(request_url)
561            .headers(headers)
562            .query(&params)
563            .send()
564            .await?;
565
566        let response_json = serde_json::from_str(response.text().await?.as_str())?;
567        Ok(response_json)
568    }
569
570    /// Get instruments for a list of symbols.
571    ///
572    /// # Arguments
573    ///
574    /// * `symbol` - Symbol.
575    /// * `projection` - Projection ("symbol-search"|"symbol-regex"|"desc-search"|"desc-regex"|"search"|"fundamental").
576    ///
577    /// # Returns
578    ///
579    /// `Result` containing instruments or an error.
580    pub async fn instruments(
581        &self,
582        symbol: String,
583        projection: Projection,
584    ) -> anyhow::Result<Value, anyhow::Error> {
585        let headers = self.construct_request_headers().await?;
586
587        let params = SchwabApi::parse_params(vec![
588            ("symbol", Some(symbol)),
589            ("projection", Some(projection.to_string())),
590        ]);
591
592        let request_url = format!("{}/instruments", self.base_url);
593        let response = self
594            .reqwest_client
595            .get(request_url)
596            .headers(headers)
597            .query(&params)
598            .send()
599            .await?;
600
601        let response_json = serde_json::from_str(response.text().await?.as_str())?;
602        Ok(response_json)
603    }
604
605    /// Get instrument for a single cusip.
606    ///
607    /// # Arguments
608    ///
609    /// * `cusip_id` - Cusip id.
610    ///
611    /// # Returns
612    ///
613    /// `Result` containing instrument or an error.
614    pub async fn instrument_cusip(&self, cusip_id: String) -> anyhow::Result<Value, anyhow::Error> {
615        let headers = self.construct_request_headers().await?;
616
617        let request_url = format!("{}/instruments/{}", self.base_url, encode(&cusip_id));
618        let response = self
619            .reqwest_client
620            .get(request_url)
621            .headers(headers)
622            .send()
623            .await?;
624
625        let response_json = serde_json::from_str(response.text().await?.as_str())?;
626        Ok(response_json)
627    }
628
629    /// Constructs the request headers for a Schwab API request.
630    ///
631    /// # Returns
632    ///
633    /// A `HeaderMap` containing the required headers for a Schwab API request.
634    async fn construct_request_headers(&self) -> anyhow::Result<HeaderMap, anyhow::Error> {
635        let mut headers = HeaderMap::new();
636
637        let json_string = tokio::fs::read_to_string(&self.tokens_file_path).await?;
638        let data: StoredTokenInfo = serde_json::from_str(&json_string)?;
639        let auth_header = format!("Bearer {}", data.access_token.as_str());
640
641        headers.append("Accept", HeaderValue::from_str("application/json")?);
642        headers.append(
643            "Authorization",
644            HeaderValue::from_str(auth_header.as_str())?,
645        );
646
647        Ok(headers)
648    }
649}