apca/data/v2/
stream.rs

1// Copyright (C) 2021-2024 The apca Developers
2// SPDX-License-Identifier: GPL-3.0-or-later
3
4use std::borrow::Borrow as _;
5use std::borrow::Cow;
6use std::cmp::Ordering;
7use std::fmt::Debug;
8use std::marker::PhantomData;
9use std::ops::Deref;
10
11use async_trait::async_trait;
12
13use chrono::DateTime;
14use chrono::Utc;
15
16use futures::stream::Fuse;
17use futures::stream::FusedStream;
18use futures::stream::Map;
19use futures::stream::SplitSink;
20use futures::stream::SplitStream;
21use futures::Future;
22use futures::FutureExt as _;
23use futures::Sink;
24use futures::StreamExt as _;
25
26use num_decimal::Num;
27
28use serde::de::DeserializeOwned;
29use serde::de::Deserializer;
30use serde::ser::SerializeSeq as _;
31use serde::ser::Serializer;
32use serde::Deserialize;
33use serde::Serialize;
34use serde_json::from_slice as json_from_slice;
35use serde_json::from_str as json_from_str;
36use serde_json::to_string as to_json;
37use serde_json::Error as JsonError;
38
39use thiserror::Error as ThisError;
40
41use tokio::net::TcpStream;
42
43use tungstenite::MaybeTlsStream;
44use tungstenite::WebSocketStream;
45
46use url::Url;
47
48use websocket_util::subscribe;
49use websocket_util::subscribe::MessageStream;
50use websocket_util::tungstenite::Error as WebSocketError;
51use websocket_util::wrap;
52use websocket_util::wrap::Wrapper;
53
54use super::unfold::Unfold;
55
56use crate::subscribable::Subscribable;
57use crate::websocket::connect;
58use crate::websocket::MessageResult;
59use crate::ApiInfo;
60use crate::Error;
61use crate::Str;
62
63
64type UserMessage<B, Q, T> = <ParsedMessage<B, Q, T> as subscribe::Message>::UserMessage;
65
66/// Helper function to drive a [`Subscription`] related future to
67/// completion. The function makes sure to poll the provided stream,
68/// which is assumed to be associated with the `Subscription` that the
69/// future belongs to, so that control messages can be received.
70#[inline]
71pub async fn drive<F, S, B, Q, T>(
72  future: F,
73  stream: &mut S,
74) -> Result<F::Output, UserMessage<B, Q, T>>
75where
76  F: Future + Unpin,
77  S: FusedStream<Item = UserMessage<B, Q, T>> + Unpin,
78{
79  subscribe::drive::<ParsedMessage<B, Q, T>, _, _>(future, stream).await
80}
81
82
/// Home of the `Sealed` marker trait used to seal `Source`: as this
/// module is private, code outside of the enclosing module cannot name
/// (and hence implement) the trait.
mod private {
  /// Marker super-trait of `Source`.
  pub trait Sealed {}
}
86
87
#[doc(hidden)]
#[derive(Clone, Debug)]
pub enum SourceVariant {
  /// A path component to combine with the configured data stream base
  /// URL; the resulting path is `v2/<component>`.
  PathComponent(&'static str),
  /// A complete URL, used verbatim (after parsing).
  Url(String),
}
97
98
99/// A trait representing the source from which to stream real time data.
100// TODO: Once we can use enumerations as const generic parameters we
101//       should probably switch over to repurposing `data::v2::Feed`
102//       here instead.
103pub trait Source: private::Sealed {
104  /// Return the source.
105  #[doc(hidden)]
106  fn source() -> SourceVariant;
107}
108
109
110/// Use the Investors Exchange (IEX) as the data source.
111///
112/// This source is available unconditionally, i.e., with the free and
113/// unlimited plans.
114#[derive(Clone, Copy, Debug)]
115pub enum IEX {}
116
117impl Source for IEX {
118  #[inline]
119  fn source() -> SourceVariant {
120    SourceVariant::PathComponent("iex")
121  }
122}
123
124impl private::Sealed for IEX {}
125
126
127/// Use CTA (administered by NYSE) and UTP (administered by Nasdaq) SIPs
128/// as the data source.
129///
130/// This source is only usable with the unlimited market data plan.
131#[derive(Clone, Copy, Debug)]
132pub enum SIP {}
133
134impl Source for SIP {
135  #[inline]
136  fn source() -> SourceVariant {
137    SourceVariant::PathComponent("sip")
138  }
139}
140
141impl private::Sealed for SIP {}
142
143
144/// A realtime data source that uses a custom URL.
145///
146/// This type provides a way to stream realtime data from a custom URL.
147/// The endpoint at said URL has to follow the `v2` handshake and
148/// message protocol. Provided that is the case, usage could be as
149/// follows:
150/// ```no_run
151/// # use apca::ApiInfo;
152/// # use apca::Client;
153/// # use apca::data::v2::stream::CustomUrl;
154/// # use apca::data::v2::stream::RealtimeData;
155/// // The v1beta3 crypto API happens to be using the same handshake as
156/// // the v2 stock APIs.
157/// #[derive(Default)]
158/// struct Crypto;
159///
160/// impl ToString for Crypto {
161///   fn to_string(&self) -> String {
162///     "wss://stream.data.alpaca.markets/v1beta3/crypto/us".into()
163///   }
164/// }
165///
166/// let api_info = ApiInfo::from_env().unwrap();
167/// let client = Client::new(api_info);
168/// # tokio::runtime::Runtime::new().unwrap().block_on(async move {
169/// let (mut stream, mut subscription) = client
170///   .subscribe::<RealtimeData<CustomUrl<Crypto>>>()
171///   .await
172///   .unwrap();
173/// # })
174///
175/// // Use `subscription` to subscribe to quotes, trades, or bars, then
176/// // handle `stream` as usual.
177/// ```
178#[derive(Clone, Copy, Debug)]
179pub struct CustomUrl<URL> {
180  _phantom: PhantomData<URL>,
181}
182
183impl<URL> Source for CustomUrl<URL>
184where
185  URL: Default + ToString,
186{
187  #[inline]
188  fn source() -> SourceVariant {
189    let url = URL::default();
190    SourceVariant::Url(url.to_string())
191  }
192}
193
194impl<URL> private::Sealed for CustomUrl<URL> {}
195
196
197/// A symbol.
198pub type Symbol = Str;
199
200
201/// Check whether a slice of symbols is normalized.
202///
203/// Such a slice is normalized if it is sorted lexically and all
204/// duplicates are removed.
205fn is_normalized(symbols: &[Symbol]) -> bool {
206  // The body here is effectively a copy of `Iterator::is_sorted_by`. We
207  // should use that once it's stable.
208
209  #[inline]
210  fn check<'a>(last: &'a mut &'a Symbol) -> impl FnMut(&'a Symbol) -> bool + 'a {
211    move |curr| {
212      if let Some(Ordering::Greater) | None = PartialOrd::partial_cmp(last, &curr) {
213        return false
214      }
215      *last = curr;
216      true
217    }
218  }
219
220  let mut it = symbols.iter();
221  let mut last = match it.next() {
222    Some(e) => e,
223    None => return true,
224  };
225
226  it.all(check(&mut last))
227}
228
229
230/// Normalize a list of symbols.
231fn normalize(symbols: Cow<'static, [Symbol]>) -> Cow<'static, [Symbol]> {
232  fn normalize_now(symbols: Cow<'static, [Symbol]>) -> Cow<'static, [Symbol]> {
233    let mut symbols = symbols.into_owned();
234    // Unwrapping here is fine, as we know that there is no
235    // `Symbol::All` variant in the list and so we cannot encounter
236    // variants that are not comparable.
237    symbols.sort_by(|x, y| x.partial_cmp(y).unwrap());
238    symbols.dedup();
239    Cow::from(symbols)
240  }
241
242  if !is_normalized(&symbols) {
243    let symbols = normalize_now(symbols);
244    debug_assert!(is_normalized(&symbols));
245    symbols
246  } else {
247    symbols
248  }
249}
250
251
252/// Aggregate data for an equity.
253#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
254pub struct Bar {
255  /// The bar's symbol.
256  #[serde(rename = "S")]
257  pub symbol: String,
258  /// The bar's open price.
259  #[serde(rename = "o")]
260  pub open_price: Num,
261  /// The bar's high price.
262  #[serde(rename = "h")]
263  pub high_price: Num,
264  /// The bar's low price.
265  #[serde(rename = "l")]
266  pub low_price: Num,
267  /// The bar's close price.
268  #[serde(rename = "c")]
269  pub close_price: Num,
270  /// The bar's volume.
271  #[serde(rename = "v")]
272  pub volume: Num,
273  /// The bar's time stamp.
274  #[serde(rename = "t")]
275  pub timestamp: DateTime<Utc>,
276}
277
278
279/// A quote for an equity.
280#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
281pub struct Quote {
282  /// The quote's symbol.
283  #[serde(rename = "S")]
284  pub symbol: String,
285  /// The bid's price.
286  #[serde(rename = "bp")]
287  pub bid_price: Num,
288  /// The bid's size.
289  #[serde(rename = "bs")]
290  pub bid_size: Num,
291  /// The ask's price.
292  #[serde(rename = "ap")]
293  pub ask_price: Num,
294  /// The ask's size.
295  #[serde(rename = "as")]
296  pub ask_size: Num,
297  /// The quote's time stamp.
298  #[serde(rename = "t")]
299  pub timestamp: DateTime<Utc>,
300}
301
302
303/// A trade for an equity.
304#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
305pub struct Trade {
306  /// The trade's symbol.
307  #[serde(rename = "S")]
308  pub symbol: String,
309  /// The trade's ID.
310  #[serde(rename = "i")]
311  pub trade_id: u64,
312  /// The trade's price.
313  #[serde(rename = "p")]
314  pub trade_price: Num,
315  /// The trade's size.
316  #[serde(rename = "s")]
317  pub trade_size: Num,
318  /// The trade's time stamp.
319  #[serde(rename = "t")]
320  pub timestamp: DateTime<Utc>,
321}
322
323
324/// An error as reported by the Alpaca Stream API.
325#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize, ThisError)]
326#[error("{message} ({code})")]
327pub struct StreamApiError {
328  /// The error code being reported.
329  #[serde(rename = "code")]
330  pub code: u64,
331  /// A message providing more details about the error.
332  #[serde(rename = "msg")]
333  pub message: String,
334}
335
336
337/// An enum representing the different messages we may receive over our
338/// websocket channel.
339#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
340#[doc(hidden)]
341#[serde(tag = "T")]
342#[allow(clippy::large_enum_variant)]
343pub enum DataMessage<B = Bar, Q = Quote, T = Trade> {
344  /// A variant representing aggregate data for a given symbol.
345  #[serde(rename = "b")]
346  Bar(B),
347  /// A variant representing a quote for a given symbol.
348  #[serde(rename = "q")]
349  Quote(Q),
350  /// A variant representing a trade for a given symbol.
351  #[serde(rename = "t")]
352  Trade(T),
353  /// A control message describing the current list of subscriptions.
354  #[serde(rename = "subscription")]
355  Subscription(MarketData),
356  /// A control message indicating that the last operation was
357  /// successful.
358  #[serde(rename = "success")]
359  Success,
360  /// An error reported by the Alpaca Stream API.
361  #[serde(rename = "error")]
362  Error(StreamApiError),
363}
364
365
366/// A data item as received over our websocket channel.
367#[derive(Debug)]
368#[non_exhaustive]
369pub enum Data<B = Bar, Q = Quote, T = Trade> {
370  /// A variant representing aggregate data for a given symbol.
371  Bar(B),
372  /// A variant representing quote data for a given symbol.
373  Quote(Q),
374  /// A variant representing trade data for a given symbol.
375  Trade(T),
376}
377
378impl<B, Q, T> Data<B, Q, T> {
379  /// Check whether this object is of the `Bar` variant.
380  #[inline]
381  pub fn is_bar(&self) -> bool {
382    matches!(self, Self::Bar(..))
383  }
384
385  /// Check whether this object is of the `Quote` variant.
386  #[inline]
387  pub fn is_quote(&self) -> bool {
388    matches!(self, Self::Quote(..))
389  }
390
391  /// Check whether this object is of the `Trade` variant.
392  #[inline]
393  pub fn is_trade(&self) -> bool {
394    matches!(self, Self::Trade(..))
395  }
396}
397
398
399/// An enumeration of the supported control messages.
400#[derive(Debug)]
401#[doc(hidden)]
402pub enum ControlMessage {
403  /// A control message describing the current list of subscriptions.
404  Subscription(MarketData),
405  /// A control message indicating that the last operation was
406  /// successful.
407  Success,
408  /// An error reported by the Alpaca Stream API.
409  Error(StreamApiError),
410}
411
412
413/// A websocket message that we tried to parse.
414type ParsedMessage<B, Q, T> =
415  MessageResult<Result<DataMessage<B, Q, T>, JsonError>, WebSocketError>;
416
417impl<B, Q, T> subscribe::Message for ParsedMessage<B, Q, T> {
418  type UserMessage = Result<Result<Data<B, Q, T>, JsonError>, WebSocketError>;
419  type ControlMessage = ControlMessage;
420
421  fn classify(self) -> subscribe::Classification<Self::UserMessage, Self::ControlMessage> {
422    match self {
423      MessageResult::Ok(Ok(message)) => match message {
424        DataMessage::Bar(bar) => subscribe::Classification::UserMessage(Ok(Ok(Data::Bar(bar)))),
425        DataMessage::Quote(quote) => {
426          subscribe::Classification::UserMessage(Ok(Ok(Data::Quote(quote))))
427        },
428        DataMessage::Trade(trade) => {
429          subscribe::Classification::UserMessage(Ok(Ok(Data::Trade(trade))))
430        },
431        DataMessage::Subscription(data) => {
432          subscribe::Classification::ControlMessage(ControlMessage::Subscription(data))
433        },
434        DataMessage::Success => subscribe::Classification::ControlMessage(ControlMessage::Success),
435        DataMessage::Error(error) => {
436          subscribe::Classification::ControlMessage(ControlMessage::Error(error))
437        },
438      },
439      // JSON errors are directly passed through.
440      MessageResult::Ok(Err(err)) => subscribe::Classification::UserMessage(Ok(Err(err))),
441      // WebSocket errors are also directly pushed through.
442      MessageResult::Err(err) => subscribe::Classification::UserMessage(Err(err)),
443    }
444  }
445
446  #[inline]
447  fn is_error(user_message: &Self::UserMessage) -> bool {
448    // Both outer `WebSocketError` and inner `JsonError` errors
449    // constitute errors in our sense. Note, however, that an API error
450    // does not. It's just a regular control message from our
451    // perspective.
452    user_message
453      .as_ref()
454      .map(|result| result.is_err())
455      .unwrap_or(true)
456  }
457}
458
459
460/// Deserialize a normalized list of symbols from a string.
461#[inline]
462fn normalized_from_str<'de, D>(deserializer: D) -> Result<Cow<'static, [Symbol]>, D::Error>
463where
464  D: Deserializer<'de>,
465{
466  Cow::<'static, [Symbol]>::deserialize(deserializer).map(normalize)
467}
468
469
470/// A type representing a normalized list of symbols.
471#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
472pub struct SymbolList(#[serde(deserialize_with = "normalized_from_str")] Cow<'static, [Symbol]>);
473
474impl Deref for SymbolList {
475  type Target = [Symbol];
476
477  fn deref(&self) -> &Self::Target {
478    self.0.borrow()
479  }
480}
481
482impl From<Cow<'static, [Symbol]>> for SymbolList {
483  #[inline]
484  fn from(symbols: Cow<'static, [Symbol]>) -> Self {
485    Self(normalize(symbols))
486  }
487}
488
489impl From<Vec<String>> for SymbolList {
490  #[inline]
491  fn from(symbols: Vec<String>) -> Self {
492    Self(normalize(Cow::from(
493      IntoIterator::into_iter(symbols)
494        .map(Symbol::from)
495        .collect::<Vec<_>>(),
496    )))
497  }
498}
499
500impl<const N: usize> From<[&'static str; N]> for SymbolList {
501  #[inline]
502  fn from(symbols: [&'static str; N]) -> Self {
503    Self(normalize(Cow::from(
504      IntoIterator::into_iter(symbols)
505        .map(Symbol::from)
506        .collect::<Vec<_>>(),
507    )))
508  }
509}
510
511
512mod symbols_all {
513  use super::*;
514
515  use serde::de::Error;
516  use serde::de::Unexpected;
517
518  /// Deserialize the [`Symbols::All`] variant.
519  pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<(), D::Error>
520  where
521    D: Deserializer<'de>,
522  {
523    let string = <[&str; 1]>::deserialize(deserializer)?;
524    if string == ["*"] {
525      Ok(())
526    } else {
527      Err(Error::invalid_value(
528        Unexpected::Str(string[0]),
529        &"the string \"*\"",
530      ))
531    }
532  }
533
534  /// Serialize the [`Symbols::All`] variant.
535  pub(crate) fn serialize<S>(serializer: S) -> Result<S::Ok, S::Error>
536  where
537    S: Serializer,
538  {
539    let mut seq = serializer.serialize_seq(Some(1))?;
540    seq.serialize_element("*")?;
541    seq.end()
542  }
543}
544
545
546/// An enumeration of symbols to subscribe to.
547// Please note that the order of variants is important for
548// deserialization purposes: we first need to check whether we are
549// dealing with the `All` variant.
550#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
551#[serde(untagged)]
552pub enum Symbols {
553  /// A variant representing all symbols.
554  #[serde(with = "symbols_all")]
555  All,
556  /// A list of symbols to work with.
557  List(SymbolList),
558}
559
560impl Symbols {
561  /// Check whether the object represents no symbols.
562  #[inline]
563  pub fn is_empty(&self) -> bool {
564    match self {
565      Self::List(list) => list.is_empty(),
566      Self::All => false,
567    }
568  }
569}
570
571impl Default for Symbols {
572  fn default() -> Self {
573    Self::List(SymbolList::from([]))
574  }
575}
576
577
578/// A type defining the market data a client intends to subscribe to.
579#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
580pub struct MarketData {
581  /// The aggregate bars to subscribe to.
582  #[serde(default)]
583  pub bars: Symbols,
584  /// The quotes to subscribe to.
585  #[serde(default)]
586  pub quotes: Symbols,
587  /// The trades to subscribe to.
588  #[serde(default)]
589  pub trades: Symbols,
590}
591
592impl MarketData {
593  /// A convenience function for setting the [`bars`][MarketData::bars]
594  /// member.
595  #[inline]
596  pub fn set_bars<S>(&mut self, symbols: S)
597  where
598    S: Into<SymbolList>,
599  {
600    self.bars = Symbols::List(symbols.into());
601  }
602
603  /// A convenience function for setting the [`quotes`][MarketData::quotes]
604  /// member.
605  #[inline]
606  pub fn set_quotes<S>(&mut self, symbols: S)
607  where
608    S: Into<SymbolList>,
609  {
610    self.quotes = Symbols::List(symbols.into());
611  }
612
613  /// A convenience function for setting the [`trades`][MarketData::trades]
614  /// member.
615  #[inline]
616  pub fn set_trades<S>(&mut self, symbols: S)
617  where
618    S: Into<SymbolList>,
619  {
620    self.trades = Symbols::List(symbols.into());
621  }
622}
623
624
625/// A control message "request" sent over a websocket channel.
626#[derive(Debug, Deserialize, Eq, PartialEq, Serialize)]
627#[doc(hidden)]
628#[serde(tag = "action")]
629pub enum Request<'d> {
630  /// A control message indicating whether or not we were authenticated
631  /// successfully.
632  #[serde(rename = "auth")]
633  Authenticate {
634    #[serde(rename = "key")]
635    key_id: Cow<'d, str>,
636    #[serde(rename = "secret")]
637    secret: Cow<'d, str>,
638  },
639  /// A control message subscribing the client to receive updates for
640  /// the provided symbols.
641  #[serde(rename = "subscribe")]
642  Subscribe(Cow<'d, MarketData>),
643  /// A control message unsubscribing the client from receiving updates
644  /// for the provided symbols.
645  #[serde(rename = "unsubscribe")]
646  Unsubscribe(Cow<'d, MarketData>),
647}
648
649
650/// A subscription allowing certain control operations pertaining
651/// a real time market data stream.
652///
653/// # Notes
654/// - in order for any [`subscribe`][Subscription::subscribe] or
655///   [`unsubscribe`][Subscription::unsubscribe] operation to resolve,
656///   the associated [`MessageStream`] stream needs to be polled;
657///   consider using the [`drive`] function for that purpose
658#[derive(Debug)]
659pub struct Subscription<S, B, Q, T> {
660  /// Our internally used subscription object for sending control
661  /// messages.
662  subscription: subscribe::Subscription<S, ParsedMessage<B, Q, T>, wrap::Message>,
663  /// The currently active individual market data subscriptions.
664  subscriptions: MarketData,
665}
666
667impl<S, B, Q, T> Subscription<S, B, Q, T> {
668  /// Create a `Subscription` object wrapping the `websocket_util` based one.
669  #[inline]
670  fn new(subscription: subscribe::Subscription<S, ParsedMessage<B, Q, T>, wrap::Message>) -> Self {
671    Self {
672      subscription,
673      subscriptions: MarketData::default(),
674    }
675  }
676}
677
678impl<S, B, Q, T> Subscription<S, B, Q, T>
679where
680  S: Sink<wrap::Message> + Unpin,
681{
682  /// Authenticate the connection using Alpaca credentials.
683  async fn authenticate(
684    &mut self,
685    key_id: &str,
686    secret: &str,
687  ) -> Result<Result<(), Error>, S::Error> {
688    let request = Request::Authenticate {
689      key_id: key_id.into(),
690      secret: secret.into(),
691    };
692    let json = match to_json(&request) {
693      Ok(json) => json,
694      Err(err) => return Ok(Err(Error::Json(err))),
695    };
696    let message = wrap::Message::Text(json);
697    let response = self.subscription.send(message).await?;
698
699    match response {
700      Some(response) => match response {
701        Ok(ControlMessage::Success) => Ok(Ok(())),
702        Ok(ControlMessage::Subscription(..)) => Ok(Err(Error::Str(
703          "server responded with unexpected subscription message".into(),
704        ))),
705        Ok(ControlMessage::Error(error)) => Ok(Err(Error::Str(
706          format!(
707            "failed to authenticate with server: {} ({})",
708            error.message, error.code
709          )
710          .into(),
711        ))),
712        Err(()) => Ok(Err(Error::Str("failed to authenticate with server".into()))),
713      },
714      None => Ok(Err(Error::Str(
715        "stream was closed before authorization message was received".into(),
716      ))),
717    }
718  }
719
720  /// Handle sending of a subscribe or unsubscribe request.
721  async fn subscribe_unsubscribe(
722    &mut self,
723    request: &Request<'_>,
724  ) -> Result<Result<(), Error>, S::Error> {
725    let json = match to_json(request) {
726      Ok(json) => json,
727      Err(err) => return Ok(Err(Error::Json(err))),
728    };
729    let message = wrap::Message::Text(json);
730    let response = self.subscription.send(message).await?;
731
732    match response {
733      Some(response) => match response {
734        Ok(ControlMessage::Subscription(data)) => {
735          self.subscriptions = data;
736          Ok(Ok(()))
737        },
738        Ok(ControlMessage::Error(error)) => Ok(Err(Error::Str(
739          format!("failed to subscribe: {error}").into(),
740        ))),
741        Ok(_) => Ok(Err(Error::Str(
742          "server responded with unexpected message".into(),
743        ))),
744        Err(()) => Ok(Err(Error::Str("failed to adjust subscription".into()))),
745      },
746      None => Ok(Err(Error::Str(
747        "stream was closed before subscription confirmation message was received".into(),
748      ))),
749    }
750  }
751
752  /// Subscribe to the provided market data.
753  ///
754  /// Contained in `subscribe` are the *additional* symbols to subscribe
755  /// to. Use the [`unsubscribe`][Self::unsubscribe] method to
756  /// unsubscribe from receiving data for certain symbols.
757  #[inline]
758  pub async fn subscribe(&mut self, subscribe: &MarketData) -> Result<Result<(), Error>, S::Error> {
759    let request = Request::Subscribe(Cow::Borrowed(subscribe));
760    self.subscribe_unsubscribe(&request).await
761  }
762
763  /// Unsubscribe from receiving market data for the provided symbols.
764  ///
765  /// Subscriptions of market data for symbols other than the ones
766  /// provided to this function are left untouched.
767  #[inline]
768  pub async fn unsubscribe(
769    &mut self,
770    unsubscribe: &MarketData,
771  ) -> Result<Result<(), Error>, S::Error> {
772    let request = Request::Unsubscribe(Cow::Borrowed(unsubscribe));
773    self.subscribe_unsubscribe(&request).await
774  }
775
776  /// Inquire the currently active individual market data subscriptions.
777  #[inline]
778  pub fn subscriptions(&self) -> &MarketData {
779    &self.subscriptions
780  }
781}
782
783
784type ParseFn<B, Q, T> = fn(
785  Result<wrap::Message, WebSocketError>,
786) -> Result<Result<Vec<DataMessage<B, Q, T>>, JsonError>, WebSocketError>;
787type MapFn<B, Q, T> =
788  fn(Result<Result<DataMessage<B, Q, T>, JsonError>, WebSocketError>) -> ParsedMessage<B, Q, T>;
789type Stream<B, Q, T> = Map<
790  Unfold<
791    Map<Wrapper<WebSocketStream<MaybeTlsStream<TcpStream>>>, ParseFn<B, Q, T>>,
792    DataMessage<B, Q, T>,
793    JsonError,
794  >,
795  MapFn<B, Q, T>,
796>;
797
798
799/// A type used for requesting a subscription to real time market
800/// data.
801///
802/// The bar (`B`), quote (`Q`), and trade (`T`) types used can be
803/// overwritten to extend/customize the default types ([`Bar`],
804/// [`Quote`], and [`Trade`], respectively) that are provided by the
805/// library.
806#[derive(Debug)]
807pub struct RealtimeData<S, B = Bar, Q = Quote, T = Trade> {
808  /// Phantom data to make sure that we "use" `S`.
809  _phantom: PhantomData<(S, B, Q, T)>,
810}
811
812#[async_trait]
813impl<S, B, Q, T> Subscribable for RealtimeData<S, B, Q, T>
814where
815  S: Source,
816  B: Send + Unpin + Debug + DeserializeOwned,
817  Q: Send + Unpin + Debug + DeserializeOwned,
818  T: Send + Unpin + Debug + DeserializeOwned,
819{
820  type Input = ApiInfo;
821  type Subscription = Subscription<SplitSink<Stream<B, Q, T>, wrap::Message>, B, Q, T>;
822  type Stream = Fuse<MessageStream<SplitStream<Stream<B, Q, T>>, ParsedMessage<B, Q, T>>>;
823
824  async fn connect(api_info: &Self::Input) -> Result<(Self::Stream, Self::Subscription), Error> {
825    fn parse<B, Q, T>(
826      result: Result<wrap::Message, WebSocketError>,
827    ) -> Result<Result<Vec<DataMessage<B, Q, T>>, JsonError>, WebSocketError>
828    where
829      B: DeserializeOwned,
830      Q: DeserializeOwned,
831      T: DeserializeOwned,
832    {
833      result.map(|message| match message {
834        wrap::Message::Text(string) => json_from_str::<Vec<DataMessage<B, Q, T>>>(&string),
835        wrap::Message::Binary(data) => json_from_slice::<Vec<DataMessage<B, Q, T>>>(&data),
836      })
837    }
838
839    let ApiInfo {
840      data_stream_base_url: url,
841      key_id,
842      secret,
843      ..
844    } = api_info;
845
846    let url = match S::source() {
847      SourceVariant::PathComponent(component) => {
848        let mut url = url.clone();
849        url.set_path(&format!("v2/{}", component));
850        url
851      },
852      SourceVariant::Url(url) => Url::parse(&url)?,
853    };
854
855    let stream = Unfold::new(
856      connect(&url)
857        .await?
858        .map(parse::<B, Q, T> as ParseFn<_, _, _>),
859    )
860    .map(MessageResult::from as MapFn<B, Q, T>);
861    let (send, recv) = stream.split();
862    let (stream, subscription) = subscribe::subscribe(recv, send);
863    let mut stream = stream.fuse();
864    let mut subscription = Subscription::new(subscription);
865
866    let connect = subscription.subscription.read().boxed();
867    let message = drive(connect, &mut stream).await.map_err(|result| {
868      result
869        .map(|result| Error::Json(result.unwrap_err()))
870        .map_err(Error::WebSocket)
871        .unwrap_or_else(|err| err)
872    })?;
873
874    match message {
875      Some(Ok(ControlMessage::Success)) => (),
876      Some(Ok(_)) => {
877        return Err(Error::Str(
878          "server responded with unexpected initial message".into(),
879        ))
880      },
881      Some(Err(())) => return Err(Error::Str("failed to read connected message".into())),
882      None => {
883        return Err(Error::Str(
884          "stream was closed before connected message was received".into(),
885        ))
886      },
887    }
888
889    let authenticate = subscription.authenticate(key_id, secret).boxed();
890    let () = drive(authenticate, &mut stream).await.map_err(|result| {
891      result
892        .map(|result| Error::Json(result.unwrap_err()))
893        .map_err(Error::WebSocket)
894        .unwrap_or_else(|err| err)
895    })???;
896
897    Ok((stream, subscription))
898  }
899}
900
901
902#[allow(clippy::to_string_trait_impl)]
903#[cfg(test)]
904mod tests {
905  use super::*;
906
907  use std::str::FromStr;
908  use std::time::Duration;
909
910  use chrono::DateTime;
911
912  use futures::SinkExt as _;
913  use futures::TryStreamExt as _;
914
915  use serial_test::serial;
916
917  use serde_json::from_str as json_from_str;
918
919  use test_log::test;
920
921  use tokio::time::timeout;
922
923  use tungstenite::tungstenite::Utf8Bytes;
924
925  use websocket_util::test::WebSocketStream;
926  use websocket_util::tungstenite::Message;
927
928  use crate::api::API_BASE_URL;
929  use crate::websocket::test::mock_stream;
930  use crate::Client;
931
932
933  const CONN_RESP: &str = r#"[{"T":"success","msg":"connected"}]"#;
934  // TODO: Until we can interpolate more complex expressions using
935  //       `std::format` in a const context we have to hard code the
936  //       values of `crate::websocket::test::KEY_ID` and
937  //       `crate::websocket::test::SECRET` here.
938  const AUTH_REQ: &str = r#"{"action":"auth","key":"USER12345678","secret":"justletmein"}"#;
939  const AUTH_RESP: &str = r#"[{"T":"success","msg":"authenticated"}]"#;
940  const SUB_REQ: &str = r#"{"action":"subscribe","bars":["AAPL","VOO"],"quotes":[],"trades":[]}"#;
941  const SUB_RESP: &str = r#"[{"T":"subscription","bars":["AAPL","VOO"]}]"#;
942  const SUB_ERR_REQ: &str = r#"{"action":"subscribe","bars":[],"quotes":[],"trades":[]}"#;
943  const SUB_ERR_RESP: &str = r#"[{"T":"error","code":400,"msg":"invalid syntax"}]"#;
944
945
946  /// Exercise the `Sip::source` method.
947  #[test]
948  fn sip_source() {
949    assert_ne!(format!("{:?}", SIP::source()), "");
950  }
951
952  /// Exercise the various `is_*` methods of the `Data` enum.
953  #[test]
954  fn data_classification() {
955    assert!(Data::<(), Quote, Trade>::Bar(()).is_bar());
956    assert!(Data::<Bar, (), Trade>::Quote(()).is_quote());
957    assert!(Data::<Bar, Quote, ()>::Trade(()).is_trade());
958  }
959
960  /// Test that the [`Symbols::is_empty`] method works as expected.
961  #[test]
962  fn symbols_is_empty() {
963    assert!(!Symbols::All.is_empty());
964    assert!(!Symbols::List(SymbolList::from(["SPY"])).is_empty());
965    assert!(Symbols::List(SymbolList::from([])).is_empty());
966  }
967
968  /// Check that we can deserialize and serialize the
969  /// [`DataMessage::Bar`] variant.
970  #[test]
971  fn serialize_deserialize_bar() {
972    let json = r#"{
973  "T": "b",
974  "S": "SPY",
975  "o": 388.985,
976  "h": 389.13,
977  "l": 388.975,
978  "c": 389.12,
979  "v": 49378,
980  "t": "2021-02-22T19:15:00Z"
981}"#;
982
983    let message = json_from_str::<DataMessage>(json).unwrap();
984    let bar = match &message {
985      DataMessage::Bar(bar) => bar,
986      _ => panic!("Decoded unexpected message variant: {message:?}"),
987    };
988    assert_eq!(bar.symbol, "SPY");
989    assert_eq!(bar.open_price, Num::new(388985, 1000));
990    assert_eq!(bar.high_price, Num::new(38913, 100));
991    assert_eq!(bar.low_price, Num::new(388975, 1000));
992    assert_eq!(bar.close_price, Num::new(38912, 100));
993    assert_eq!(bar.volume, Num::from(49378));
994    assert_eq!(
995      bar.timestamp,
996      DateTime::<Utc>::from_str("2021-02-22T19:15:00Z").unwrap()
997    );
998
999    assert_eq!(
1000      json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1001      message
1002    );
1003  }
1004
1005  /// Check that we can serialize and deserialize the
1006  /// [`DataMessage::Quote`] variant.
1007  #[test]
1008  fn serialize_deserialize_quote() {
1009    let json: &str = r#"{
1010  "T": "q",
1011  "S": "NVDA",
1012  "bx": "P",
1013  "bp": 258.8,
1014  "bs": 2,
1015  "ax": "A",
1016  "ap": 259.99,
1017  "as": 5,
1018  "c": [
1019      "R"
1020  ],
1021  "z": "C",
1022  "t": "2022-01-18T23:09:42.151875584Z"
1023}"#;
1024
1025    let message = json_from_str::<DataMessage>(json).unwrap();
1026    let quote = match &message {
1027      DataMessage::Quote(quote) => quote,
1028      _ => panic!("Decoded unexpected message variant: {message:?}"),
1029    };
1030    assert_eq!(quote.symbol, "NVDA");
1031    assert_eq!(quote.bid_price, Num::new(2588, 10));
1032    assert_eq!(quote.bid_size, Num::from(2));
1033    assert_eq!(quote.ask_price, Num::new(25999, 100));
1034    assert_eq!(quote.ask_size, Num::from(5));
1035
1036    assert_eq!(
1037      quote.timestamp,
1038      DateTime::<Utc>::from_str("2022-01-18T23:09:42.151875584Z").unwrap()
1039    );
1040
1041    assert_eq!(
1042      json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1043      message
1044    );
1045  }
1046
1047
1048  /// A quote for an equity.
1049  #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
1050  struct DetailedQuote {
1051    /// The quote's symbol.
1052    #[serde(rename = "S")]
1053    symbol: String,
1054    /// The bid's price.
1055    #[serde(rename = "bp")]
1056    bid_price: Num,
1057    /// The bid's size.
1058    #[serde(rename = "bs")]
1059    bid_size: Num,
1060    /// The bid's exchange code.
1061    #[serde(rename = "bx")]
1062    bid_exchange_code: String,
1063    /// The ask's price.
1064    #[serde(rename = "ap")]
1065    ask_price: Num,
1066    /// The ask's size.
1067    #[serde(rename = "as")]
1068    ask_size: Num,
1069    /// The ask's exchange code.
1070    #[serde(rename = "ax")]
1071    ask_exchange_code: String,
1072    /// The trade's size.
1073    #[serde(rename = "s")]
1074    trade_size: Num,
1075    /// The quote's time stamp.
1076    #[serde(rename = "t")]
1077    timestamp: DateTime<Utc>,
1078    /// The quote's conditions.
1079    #[serde(rename = "c")]
1080    quote_conditions: Vec<String>,
1081    /// The tape.
1082    #[serde(rename = "z")]
1083    tape: String,
1084  }
1085
1086  /// Check that we can serialize and deserialize the
1087  /// [`DataMessage::Quote`] variant with a custom quote type.
1088  #[test]
1089  fn serialize_deserialize_custom_quote() {
1090    let json: &str = r#"{
1091  "T": "q",
1092  "S": "NVDA",
1093  "bx": "P",
1094  "bp": 258.8,
1095  "bs": 2,
1096  "s": 3,
1097  "ax": "A",
1098  "ap": 259.99,
1099  "as": 5,
1100  "c": [
1101      "R"
1102  ],
1103  "z": "C",
1104  "t": "2022-01-18T23:09:42.151875584Z"
1105}"#;
1106
1107    let message = json_from_str::<DataMessage<Bar, DetailedQuote, Trade>>(json).unwrap();
1108    let quote = match &message {
1109      DataMessage::Quote(quote) => quote,
1110      _ => panic!("Decoded unexpected message variant: {message:?}"),
1111    };
1112    assert_eq!(quote.symbol, "NVDA");
1113    assert_eq!(quote.bid_price, Num::new(2588, 10));
1114    assert_eq!(quote.bid_size, Num::from(2));
1115    assert_eq!(quote.bid_exchange_code, "P");
1116    assert_eq!(quote.ask_price, Num::new(25999, 100));
1117    assert_eq!(quote.ask_size, Num::from(5));
1118    assert_eq!(quote.ask_exchange_code, "A");
1119    assert_eq!(quote.trade_size, Num::from(3));
1120
1121    assert_eq!(
1122      quote.timestamp,
1123      DateTime::<Utc>::from_str("2022-01-18T23:09:42.151875584Z").unwrap()
1124    );
1125
1126    assert_eq!(
1127      json_from_str::<DataMessage<Bar, DetailedQuote, Trade>>(&to_json(&message).unwrap()).unwrap(),
1128      message
1129    );
1130  }
1131
1132  /// Check that we can serialize and deserialize the
1133  /// [`DataMessage::Trade`] variant.
1134  #[test]
1135  fn serialize_deserialize_trade() {
1136    let json: &str = r#"{
1137  "T": "t",
1138  "i": 96921,
1139  "S": "AAPL",
1140  "x": "D",
1141  "p": 126.55,
1142  "s": 1,
1143  "t": "2021-02-22T15:51:44.208Z",
1144  "c": ["@", "I"],
1145  "z": "C"
1146}"#;
1147
1148    let message = json_from_str::<DataMessage>(json).unwrap();
1149    let trade = match &message {
1150      DataMessage::Trade(trade) => trade,
1151      _ => panic!("Decoded unexpected message variant: {message:?}"),
1152    };
1153    assert_eq!(trade.symbol, "AAPL");
1154    assert_eq!(trade.trade_id, 96921);
1155    assert_eq!(trade.trade_price, Num::new(12655, 100));
1156    assert_eq!(trade.trade_size, Num::from(1));
1157
1158    assert_eq!(
1159      trade.timestamp,
1160      DateTime::<Utc>::from_str("2021-02-22T15:51:44.208Z").unwrap()
1161    );
1162
1163    assert_eq!(
1164      json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1165      message
1166    );
1167  }
1168
1169
1170  /// A trade for an equity.
1171  #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
1172  struct DetailedTrade {
1173    /// The trade's symbol.
1174    #[serde(rename = "S")]
1175    symbol: String,
1176    /// The trade's ID.
1177    #[serde(rename = "i")]
1178    trade_id: u64,
1179    /// The trade's price.
1180    #[serde(rename = "p")]
1181    trade_price: Num,
1182    /// The trade's size.
1183    #[serde(rename = "s")]
1184    trade_size: u64,
1185    /// The trade's conditions.
1186    #[serde(rename = "c")]
1187    conditions: Vec<String>,
1188    /// The trade's time stamp.
1189    #[serde(rename = "t")]
1190    timestamp: DateTime<Utc>,
1191    /// The trade's exchange.
1192    #[serde(rename = "x")]
1193    exchange: String,
1194    /// The trade's tape.
1195    #[serde(rename = "z")]
1196    tape: String,
1197    /// The trade's update, may be "canceled", "corrected", or
1198    /// "incorrect".
1199    #[serde(rename = "u", default)]
1200    update: Option<String>,
1201  }
1202
1203  /// Check that we can serialize and deserialize the
1204  /// [`DataMessage::Trade`] variant with a custom trade type.
1205  #[test]
1206  fn serialize_deserialize_custom_trade() {
1207    let json: &str = r#"{
1208  "T": "t",
1209  "i": 96921,
1210  "S": "AAPL",
1211  "x": "D",
1212  "p": 126.55,
1213  "s": 1,
1214  "t": "2021-02-22T15:51:44.208Z",
1215  "c": ["@", "I"],
1216  "z": "C",
1217  "u": "corrected"
1218}"#;
1219
1220    let message = json_from_str::<DataMessage<Bar, Quote, DetailedTrade>>(json).unwrap();
1221    let trade = match &message {
1222      DataMessage::Trade(trade) => trade,
1223      _ => panic!("Decoded unexpected message variant: {message:?}"),
1224    };
1225    assert_eq!(trade.symbol, "AAPL");
1226    assert_eq!(trade.trade_id, 96921);
1227    assert_eq!(trade.trade_price, Num::new(12655, 100));
1228    assert_eq!(trade.trade_size, 1);
1229
1230    assert_eq!(
1231      trade.timestamp,
1232      DateTime::<Utc>::from_str("2021-02-22T15:51:44.208Z").unwrap()
1233    );
1234
1235    assert_eq!(trade.conditions, vec!["@", "I"]);
1236    assert_eq!(trade.tape, "C");
1237    assert_eq!(trade.update, Some("corrected".to_string()));
1238
1239    assert_eq!(
1240      json_from_str::<DataMessage<Bar, Quote, DetailedTrade>>(&to_json(&message).unwrap()).unwrap(),
1241      message
1242    );
1243  }
1244
1245  /// Check that we can serialize and deserialize the
1246  /// [`DataMessage::Success`] variant.
1247  #[test]
1248  fn serialize_deserialize_success() {
1249    let json = r#"{"T":"success","msg":"authenticated"}"#;
1250    let message = json_from_str::<DataMessage>(json).unwrap();
1251    let () = match message {
1252      DataMessage::Success => (),
1253      _ => panic!("Decoded unexpected message variant: {message:?}"),
1254    };
1255
1256    assert_eq!(
1257      json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1258      message
1259    );
1260  }
1261
1262  /// Check that we can serialize and deserialize the
1263  /// [`DataMessage::Error`] variant.
1264  #[test]
1265  fn serialize_deserialize_error() {
1266    let json = r#"{"T":"error","code":400,"msg":"invalid syntax"}"#;
1267    let message = json_from_str::<DataMessage>(json).unwrap();
1268    let error = match &message {
1269      DataMessage::Error(error) => error,
1270      _ => panic!("Decoded unexpected message variant: {message:?}"),
1271    };
1272
1273    assert_eq!(error.code, 400);
1274    assert_eq!(error.message, "invalid syntax");
1275
1276    assert_eq!(
1277      json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1278      message
1279    );
1280
1281    let json = r#"{"T":"error","code":500,"msg":"internal error"}"#;
1282    let message = json_from_str::<DataMessage>(json).unwrap();
1283    let error = match &message {
1284      DataMessage::Error(error) => error,
1285      _ => panic!("Decoded unexpected message variant: {message:?}"),
1286    };
1287
1288    assert_eq!(error.code, 500);
1289    assert_eq!(error.message, "internal error");
1290
1291    assert_eq!(
1292      json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1293      message
1294    );
1295  }
1296
1297  /// Check that we can serialize and deserialize the
1298  /// [`Request::Authenticate`] variant properly.
1299  #[test]
1300  fn serialize_deserialize_authentication_request() {
1301    let request = Request::Authenticate {
1302      key_id: "KEY-ID".into(),
1303      secret: "SECRET-KEY".into(),
1304    };
1305    let json = to_json(&request).unwrap();
1306    let expected = r#"{"action":"auth","key":"KEY-ID","secret":"SECRET-KEY"}"#;
1307    assert_eq!(json, expected);
1308    assert_eq!(json_from_str::<Request<'_>>(&json).unwrap(), request);
1309  }
1310
1311  /// Check that we can serialize and deserialize the
1312  /// [`Request::Subscribe`] variant properly.
1313  #[test]
1314  fn serialize_deserialize_subscribe_request() {
1315    let mut data = MarketData::default();
1316    data.set_bars(["AAPL", "VOO"]);
1317    let request = Request::Subscribe(Cow::Borrowed(&data));
1318
1319    let json = to_json(&request).unwrap();
1320    let expected = r#"{"action":"subscribe","bars":["AAPL","VOO"],"quotes":[],"trades":[]}"#;
1321    assert_eq!(json, expected);
1322    assert_eq!(json_from_str::<Request<'_>>(&json).unwrap(), request);
1323  }
1324
1325  /// Check that we can serialize and deserialize the
1326  /// [`Request::Subscribe`] variant properly.
1327  #[test]
1328  fn serialize_deserialize_unsubscribe_request() {
1329    let mut data = MarketData::default();
1330    data.set_bars(["VOO"]);
1331    let request = Request::Unsubscribe(Cow::Borrowed(&data));
1332
1333    let json = to_json(&request).unwrap();
1334    let expected = r#"{"action":"unsubscribe","bars":["VOO"],"quotes":[],"trades":[]}"#;
1335    assert_eq!(json, expected);
1336    assert_eq!(json_from_str::<Request<'_>>(&json).unwrap(), request);
1337  }
1338
1339  /// Check that we can correctly deserialize a `SymbolList` object.
1340  #[test]
1341  fn deserialize_symbol_list() {
1342    let json = r#"["AAPL","XLK","SPY"]"#;
1343    let list = json_from_str::<SymbolList>(json).unwrap();
1344    let expected = SymbolList::from(["AAPL", "SPY", "XLK"]);
1345    assert_eq!(list, expected);
1346  }
1347
1348  /// Check that we can normalize `Symbol` slices.
1349  #[test]
1350  fn normalize_subscriptions() {
1351    let subscriptions = [];
1352    assert!(is_normalized(&subscriptions));
1353
1354    let subscriptions = ["MSFT".into(), "SPY".into()];
1355    assert!(is_normalized(&subscriptions));
1356
1357    let mut subscriptions = Cow::from(vec!["SPY".into(), "MSFT".into()]);
1358    assert!(!is_normalized(&subscriptions));
1359    subscriptions = normalize(subscriptions);
1360    assert!(is_normalized(&subscriptions));
1361
1362    let expected = [Cow::from("MSFT"), "SPY".into()];
1363    assert_eq!(subscriptions.borrow(), expected);
1364
1365    let mut subscriptions = Cow::from(vec!["SPY".into(), "MSFT".into(), "MSFT".into()]);
1366    assert!(!is_normalized(&subscriptions));
1367    subscriptions = normalize(subscriptions);
1368    assert!(is_normalized(&subscriptions));
1369
1370    assert_eq!(subscriptions.borrow(), expected);
1371  }
1372
1373  /// Check that we can correctly handle a successful subscription
1374  /// without pushing actual data.
1375  #[test(tokio::test)]
1376  async fn authenticate_and_subscribe() {
1377    async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
1378      stream
1379        .send(Message::Text(Utf8Bytes::from_static(CONN_RESP)))
1380        .await?;
1381      // Authentication.
1382      assert_eq!(
1383        stream.next().await.unwrap()?,
1384        Message::Text(Utf8Bytes::from_static(AUTH_REQ)),
1385      );
1386      stream
1387        .send(Message::Text(Utf8Bytes::from_static(AUTH_RESP)))
1388        .await?;
1389
1390      // Subscription.
1391      assert_eq!(
1392        stream.next().await.unwrap()?,
1393        Message::Text(Utf8Bytes::from_static(SUB_REQ)),
1394      );
1395      stream
1396        .send(Message::Text(Utf8Bytes::from_static(SUB_RESP)))
1397        .await?;
1398      stream.send(Message::Close(None)).await?;
1399      Ok(())
1400    }
1401
1402    let (mut stream, mut subscription) =
1403      mock_stream::<RealtimeData<IEX>, _, _>(test).await.unwrap();
1404
1405    let mut data = MarketData::default();
1406    data.set_bars(["AAPL", "VOO"]);
1407
1408    let subscribe = subscription.subscribe(&data).boxed_local();
1409    let () = drive(subscribe, &mut stream)
1410      .await
1411      .unwrap()
1412      .unwrap()
1413      .unwrap();
1414
1415    stream
1416      .map_err(Error::WebSocket)
1417      .try_for_each(|result| async { result.map(|_data| ()).map_err(Error::Json) })
1418      .await
1419      .unwrap();
1420  }
1421
1422  /// Check that we correctly handle errors reported as part of
1423  /// subscription.
1424  #[test(tokio::test)]
1425  async fn subscribe_error() {
1426    async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
1427      stream
1428        .send(Message::Text(Utf8Bytes::from_static(CONN_RESP)))
1429        .await?;
1430      // Authentication.
1431      assert_eq!(
1432        stream.next().await.unwrap()?,
1433        Message::Text(Utf8Bytes::from_static(AUTH_REQ)),
1434      );
1435      stream
1436        .send(Message::Text(Utf8Bytes::from_static(AUTH_RESP)))
1437        .await?;
1438
1439      // Subscription.
1440      assert_eq!(
1441        stream.next().await.unwrap()?,
1442        Message::Text(Utf8Bytes::from_static(SUB_ERR_REQ)),
1443      );
1444      stream
1445        .send(Message::Text(Utf8Bytes::from_static(SUB_ERR_RESP)))
1446        .await?;
1447      stream.send(Message::Close(None)).await?;
1448      Ok(())
1449    }
1450
1451    let (mut stream, mut subscription) =
1452      mock_stream::<RealtimeData<IEX>, _, _>(test).await.unwrap();
1453
1454    let data = MarketData::default();
1455
1456    let subscribe = subscription.subscribe(&data).boxed_local();
1457    let error = drive(subscribe, &mut stream)
1458      .await
1459      .unwrap()
1460      .unwrap()
1461      .unwrap_err();
1462
1463    match error {
1464      Error::Str(ref e) if e == "failed to subscribe: invalid syntax (400)" => {},
1465      e => panic!("received unexpected error: {e}"),
1466    }
1467  }
1468
1469  /// Check that we can adjust the current market data subscription on
1470  /// the fly.
1471  #[test(tokio::test)]
1472  #[serial(realtime_data)]
1473  async fn subscribe_resubscribe() {
1474    let api_info = ApiInfo::from_env().unwrap();
1475    let client = Client::new(api_info);
1476    let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
1477
1478    let mut data = MarketData::default();
1479    data.set_bars(["AAPL", "SPY"]);
1480
1481    let subscribe = subscription.subscribe(&data).boxed_local();
1482    let () = drive(subscribe, &mut stream)
1483      .await
1484      .unwrap()
1485      .unwrap()
1486      .unwrap();
1487
1488    assert_eq!(subscription.subscriptions(), &data);
1489
1490    let mut data = MarketData::default();
1491    data.set_bars(["XLK"]);
1492    let subscribe = subscription.subscribe(&data).boxed_local();
1493    let () = drive(subscribe, &mut stream)
1494      .await
1495      .unwrap()
1496      .unwrap()
1497      .unwrap();
1498
1499    let mut expected = MarketData::default();
1500    expected.set_bars(["AAPL", "SPY", "XLK"]);
1501    assert_eq!(subscription.subscriptions(), &expected);
1502  }
1503
1504  /// Check that we can stream realtime market data updates.
1505  ///
1506  /// Note that we do not have any control over whether the market is
1507  /// open or not and as such we can only try on a best-effort basis to
1508  /// receive and decode updates.
1509  #[test(tokio::test)]
1510  #[serial(realtime_data)]
1511  async fn stream_market_data_updates() {
1512    let api_info = ApiInfo::from_env().unwrap();
1513    let client = Client::new(api_info);
1514    let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
1515
1516    let data = MarketData {
1517      bars: Symbols::All,
1518      ..Default::default()
1519    };
1520
1521    let subscribe = subscription.subscribe(&data).boxed_local();
1522    let () = drive(subscribe, &mut stream)
1523      .await
1524      .unwrap()
1525      .unwrap()
1526      .unwrap();
1527
1528    assert_eq!(subscription.subscriptions(), &data);
1529
1530    let read = stream
1531      .map_err(Error::WebSocket)
1532      .try_for_each(|result| async {
1533        result
1534          .map(|data| {
1535            assert!(data.is_bar());
1536          })
1537          .map_err(Error::Json)
1538      });
1539
1540    if timeout(Duration::from_millis(100), read).await.is_ok() {
1541      panic!("realtime data stream got exhausted unexpectedly")
1542    }
1543  }
1544
1545  /// Check that we can stream realtime stock quotes.
1546  ///
1547  /// Note that we do not have any control over whether the market is
1548  /// open or not and as such we can only try on a best-effort basis to
1549  /// receive and decode updates.
1550  #[test(tokio::test)]
1551  #[serial(realtime_data)]
1552  async fn stream_quotes() {
1553    async fn test<S>()
1554    where
1555      S: Source,
1556    {
1557      let api_info = ApiInfo::from_env().unwrap();
1558      let client = Client::new(api_info);
1559      let (mut stream, mut subscription) = client.subscribe::<RealtimeData<S>>().await.unwrap();
1560
1561      let mut data = MarketData::default();
1562      data.set_quotes(["SPY"]);
1563
1564      let subscribe = subscription.subscribe(&data).boxed_local();
1565      let () = drive(subscribe, &mut stream)
1566        .await
1567        .unwrap()
1568        .unwrap()
1569        .unwrap();
1570
1571      let read = stream
1572        .map_err(Error::WebSocket)
1573        .try_for_each(|result| async {
1574          result
1575            .map(|data| {
1576              assert!(data.is_quote());
1577            })
1578            .map_err(Error::Json)
1579        });
1580
1581      if timeout(Duration::from_millis(100), read).await.is_ok() {
1582        panic!("realtime data stream got exhausted unexpectedly")
1583      }
1584    }
1585
1586    test::<IEX>().await;
1587
1588    #[derive(Default)]
1589    struct IexWithUrl;
1590
1591    impl ToString for IexWithUrl {
1592      fn to_string(&self) -> String {
1593        "wss://stream.data.alpaca.markets/v2/iex".into()
1594      }
1595    }
1596
1597    test::<CustomUrl<IexWithUrl>>().await;
1598  }
1599
1600  /// Check that we can stream realtime stock trades.
1601  #[test(tokio::test)]
1602  #[serial(realtime_data)]
1603  async fn stream_trades() {
1604    let api_info = ApiInfo::from_env().unwrap();
1605    let client = Client::new(api_info);
1606    let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
1607
1608    let mut data = MarketData::default();
1609    data.set_trades(["SPY"]);
1610
1611    let subscribe = subscription.subscribe(&data).boxed_local();
1612    let () = drive(subscribe, &mut stream)
1613      .await
1614      .unwrap()
1615      .unwrap()
1616      .unwrap();
1617
1618    let read = stream
1619      .map_err(Error::WebSocket)
1620      .try_for_each(|result| async {
1621        result
1622          .map(|data| {
1623            assert!(data.is_trade());
1624          })
1625          .map_err(Error::Json)
1626      });
1627
1628    if timeout(Duration::from_millis(100), read).await.is_ok() {
1629      panic!("realtime data stream got exhausted unexpectedly")
1630    }
1631  }
1632
1633  /// Check that the Alpaca API reports no error when unsubscribing
1634  /// from a symbol not currently subscribed to.
1635  #[test(tokio::test)]
1636  #[serial(realtime_data)]
1637  async fn unsubscribe_not_subscribed_symbol() {
1638    let api_info = ApiInfo::from_env().unwrap();
1639    let client = Client::new(api_info);
1640    let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
1641
1642    let mut data = MarketData::default();
1643    data.set_bars(["AAPL"]);
1644
1645    let unsubscribe = subscription.unsubscribe(&data).boxed_local();
1646    let () = drive(unsubscribe, &mut stream)
1647      .await
1648      .unwrap()
1649      .unwrap()
1650      .unwrap();
1651
1652    assert_eq!(subscription.subscriptions(), &MarketData::default());
1653  }
1654
1655  /// Test that we fail as expected when attempting to authenticate for
1656  /// real time market updates using invalid credentials.
1657  #[test(tokio::test)]
1658  #[serial(realtime_data)]
1659  async fn stream_with_invalid_credentials() {
1660    let api_info = ApiInfo::from_parts(API_BASE_URL, "invalid", "invalid-too").unwrap();
1661    let client = Client::new(api_info);
1662    let err = client.subscribe::<RealtimeData<IEX>>().await.unwrap_err();
1663
1664    match err {
1665      Error::Str(ref e) if e.starts_with("failed to authenticate with server") => (),
1666      e => panic!("received unexpected error: {e}"),
1667    }
1668  }
1669
1670  /// Check that we fail connection as expected on an invalid URL.
1671  #[test(tokio::test)]
1672  #[serial(realtime_data)]
1673  async fn stream_with_invalid_url() {
1674    #[derive(Default)]
1675    struct Invalid;
1676
1677    impl ToString for Invalid {
1678      fn to_string(&self) -> String {
1679        "<invalid-url-is-invalid>".into()
1680      }
1681    }
1682
1683    let api_info = ApiInfo::from_env().unwrap();
1684    let client = Client::new(api_info);
1685    let err = client
1686      .subscribe::<RealtimeData<CustomUrl<Invalid>>>()
1687      .await
1688      .unwrap_err();
1689
1690    match err {
1691      Error::Url(..) => (),
1692      _ => panic!("Received unexpected error: {err:?}"),
1693    };
1694  }
1695}