1use std::borrow::Borrow as _;
5use std::borrow::Cow;
6use std::cmp::Ordering;
7use std::fmt::Debug;
8use std::marker::PhantomData;
9use std::ops::Deref;
10
11use async_trait::async_trait;
12
13use chrono::DateTime;
14use chrono::Utc;
15
16use futures::stream::Fuse;
17use futures::stream::FusedStream;
18use futures::stream::Map;
19use futures::stream::SplitSink;
20use futures::stream::SplitStream;
21use futures::Future;
22use futures::FutureExt as _;
23use futures::Sink;
24use futures::StreamExt as _;
25
26use num_decimal::Num;
27
28use serde::de::DeserializeOwned;
29use serde::de::Deserializer;
30use serde::ser::SerializeSeq as _;
31use serde::ser::Serializer;
32use serde::Deserialize;
33use serde::Serialize;
34use serde_json::from_slice as json_from_slice;
35use serde_json::from_str as json_from_str;
36use serde_json::to_string as to_json;
37use serde_json::Error as JsonError;
38
39use thiserror::Error as ThisError;
40
41use tokio::net::TcpStream;
42
43use tungstenite::MaybeTlsStream;
44use tungstenite::WebSocketStream;
45
46use url::Url;
47
48use websocket_util::subscribe;
49use websocket_util::subscribe::MessageStream;
50use websocket_util::tungstenite::Error as WebSocketError;
51use websocket_util::wrap;
52use websocket_util::wrap::Wrapper;
53
54use super::unfold::Unfold;
55
56use crate::subscribable::Subscribable;
57use crate::websocket::connect;
58use crate::websocket::MessageResult;
59use crate::ApiInfo;
60use crate::Error;
61use crate::Str;
62
63
/// The user-facing item type yielded by the realtime data stream, i.e., the
/// `UserMessage` associated type of the `subscribe::Message` implementation
/// for `ParsedMessage`.
type UserMessage<B, Q, T> = <ParsedMessage<B, Q, T> as subscribe::Message>::UserMessage;
65
66#[inline]
71pub async fn drive<F, S, B, Q, T>(
72 future: F,
73 stream: &mut S,
74) -> Result<F::Output, UserMessage<B, Q, T>>
75where
76 F: Future + Unpin,
77 S: FusedStream<Item = UserMessage<B, Q, T>> + Unpin,
78{
79 subscribe::drive::<ParsedMessage<B, Q, T>, _, _>(future, stream).await
80}
81
82
/// Private module housing the marker trait used to seal `Source`.
mod private {
  /// Marker trait that prevents `Source` from being implemented outside
  /// of this module.
  pub trait Sealed {}
}


/// The resolved location of a realtime data source.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub enum SourceVariant {
  /// A path component that is appended (as `v2/<component>`) to the
  /// configured data stream base URL.
  PathComponent(&'static str),
  /// A complete URL to connect to directly.
  Url(String),
}


/// A trait describing the source from which realtime data is streamed.
///
/// The trait is sealed: it requires `private::Sealed`, which cannot be
/// named outside of this module.
pub trait Source: private::Sealed {
  /// Retrieve the location of the data source.
  #[doc(hidden)]
  fn source() -> SourceVariant;
}
108
109
110#[derive(Clone, Copy, Debug)]
115pub enum IEX {}
116
117impl Source for IEX {
118 #[inline]
119 fn source() -> SourceVariant {
120 SourceVariant::PathComponent("iex")
121 }
122}
123
124impl private::Sealed for IEX {}
125
126
127#[derive(Clone, Copy, Debug)]
132pub enum SIP {}
133
134impl Source for SIP {
135 #[inline]
136 fn source() -> SourceVariant {
137 SourceVariant::PathComponent("sip")
138 }
139}
140
141impl private::Sealed for SIP {}
142
143
144#[derive(Clone, Copy, Debug)]
179pub struct CustomUrl<URL> {
180 _phantom: PhantomData<URL>,
181}
182
183impl<URL> Source for CustomUrl<URL>
184where
185 URL: Default + ToString,
186{
187 #[inline]
188 fn source() -> SourceVariant {
189 let url = URL::default();
190 SourceVariant::Url(url.to_string())
191 }
192}
193
194impl<URL> private::Sealed for CustomUrl<URL> {}
195
196
/// A ticker symbol, e.g., `AAPL`.
pub type Symbol = Str;
199
200
201fn is_normalized(symbols: &[Symbol]) -> bool {
206 #[inline]
210 fn check<'a>(last: &'a mut &'a Symbol) -> impl FnMut(&'a Symbol) -> bool + 'a {
211 move |curr| {
212 if let Some(Ordering::Greater) | None = PartialOrd::partial_cmp(last, &curr) {
213 return false
214 }
215 *last = curr;
216 true
217 }
218 }
219
220 let mut it = symbols.iter();
221 let mut last = match it.next() {
222 Some(e) => e,
223 None => return true,
224 };
225
226 it.all(check(&mut last))
227}
228
229
230fn normalize(symbols: Cow<'static, [Symbol]>) -> Cow<'static, [Symbol]> {
232 fn normalize_now(symbols: Cow<'static, [Symbol]>) -> Cow<'static, [Symbol]> {
233 let mut symbols = symbols.into_owned();
234 symbols.sort_by(|x, y| x.partial_cmp(y).unwrap());
238 symbols.dedup();
239 Cow::from(symbols)
240 }
241
242 if !is_normalized(&symbols) {
243 let symbols = normalize_now(symbols);
244 debug_assert!(is_normalized(&symbols));
245 symbols
246 } else {
247 symbols
248 }
249}
250
251
252#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
254pub struct Bar {
255 #[serde(rename = "S")]
257 pub symbol: String,
258 #[serde(rename = "o")]
260 pub open_price: Num,
261 #[serde(rename = "h")]
263 pub high_price: Num,
264 #[serde(rename = "l")]
266 pub low_price: Num,
267 #[serde(rename = "c")]
269 pub close_price: Num,
270 #[serde(rename = "v")]
272 pub volume: Num,
273 #[serde(rename = "t")]
275 pub timestamp: DateTime<Utc>,
276}
277
278
279#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
281pub struct Quote {
282 #[serde(rename = "S")]
284 pub symbol: String,
285 #[serde(rename = "bp")]
287 pub bid_price: Num,
288 #[serde(rename = "bs")]
290 pub bid_size: Num,
291 #[serde(rename = "ap")]
293 pub ask_price: Num,
294 #[serde(rename = "as")]
296 pub ask_size: Num,
297 #[serde(rename = "t")]
299 pub timestamp: DateTime<Utc>,
300}
301
302
303#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
305pub struct Trade {
306 #[serde(rename = "S")]
308 pub symbol: String,
309 #[serde(rename = "i")]
311 pub trade_id: u64,
312 #[serde(rename = "p")]
314 pub trade_price: Num,
315 #[serde(rename = "s")]
317 pub trade_size: Num,
318 #[serde(rename = "t")]
320 pub timestamp: DateTime<Utc>,
321}
322
323
324#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize, ThisError)]
326#[error("{message} ({code})")]
327pub struct StreamApiError {
328 #[serde(rename = "code")]
330 pub code: u64,
331 #[serde(rename = "msg")]
333 pub message: String,
334}
335
336
337#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
340#[doc(hidden)]
341#[serde(tag = "T")]
342#[allow(clippy::large_enum_variant)]
343pub enum DataMessage<B = Bar, Q = Quote, T = Trade> {
344 #[serde(rename = "b")]
346 Bar(B),
347 #[serde(rename = "q")]
349 Quote(Q),
350 #[serde(rename = "t")]
352 Trade(T),
353 #[serde(rename = "subscription")]
355 Subscription(MarketData),
356 #[serde(rename = "success")]
359 Success,
360 #[serde(rename = "error")]
362 Error(StreamApiError),
363}
364
365
366#[derive(Debug)]
368#[non_exhaustive]
369pub enum Data<B = Bar, Q = Quote, T = Trade> {
370 Bar(B),
372 Quote(Q),
374 Trade(T),
376}
377
378impl<B, Q, T> Data<B, Q, T> {
379 #[inline]
381 pub fn is_bar(&self) -> bool {
382 matches!(self, Self::Bar(..))
383 }
384
385 #[inline]
387 pub fn is_quote(&self) -> bool {
388 matches!(self, Self::Quote(..))
389 }
390
391 #[inline]
393 pub fn is_trade(&self) -> bool {
394 matches!(self, Self::Trade(..))
395 }
396}
397
398
399#[derive(Debug)]
401#[doc(hidden)]
402pub enum ControlMessage {
403 Subscription(MarketData),
405 Success,
408 Error(StreamApiError),
410}
411
412
/// A single message read from the websocket after JSON decoding: either a
/// decoded `DataMessage`, a JSON decoding error, or a websocket error
/// (wrapped in `MessageResult`).
type ParsedMessage<B, Q, T> =
  MessageResult<Result<DataMessage<B, Q, T>, JsonError>, WebSocketError>;
416
417impl<B, Q, T> subscribe::Message for ParsedMessage<B, Q, T> {
418 type UserMessage = Result<Result<Data<B, Q, T>, JsonError>, WebSocketError>;
419 type ControlMessage = ControlMessage;
420
421 fn classify(self) -> subscribe::Classification<Self::UserMessage, Self::ControlMessage> {
422 match self {
423 MessageResult::Ok(Ok(message)) => match message {
424 DataMessage::Bar(bar) => subscribe::Classification::UserMessage(Ok(Ok(Data::Bar(bar)))),
425 DataMessage::Quote(quote) => {
426 subscribe::Classification::UserMessage(Ok(Ok(Data::Quote(quote))))
427 },
428 DataMessage::Trade(trade) => {
429 subscribe::Classification::UserMessage(Ok(Ok(Data::Trade(trade))))
430 },
431 DataMessage::Subscription(data) => {
432 subscribe::Classification::ControlMessage(ControlMessage::Subscription(data))
433 },
434 DataMessage::Success => subscribe::Classification::ControlMessage(ControlMessage::Success),
435 DataMessage::Error(error) => {
436 subscribe::Classification::ControlMessage(ControlMessage::Error(error))
437 },
438 },
439 MessageResult::Ok(Err(err)) => subscribe::Classification::UserMessage(Ok(Err(err))),
441 MessageResult::Err(err) => subscribe::Classification::UserMessage(Err(err)),
443 }
444 }
445
446 #[inline]
447 fn is_error(user_message: &Self::UserMessage) -> bool {
448 user_message
453 .as_ref()
454 .map(|result| result.is_err())
455 .unwrap_or(true)
456 }
457}
458
459
460#[inline]
462fn normalized_from_str<'de, D>(deserializer: D) -> Result<Cow<'static, [Symbol]>, D::Error>
463where
464 D: Deserializer<'de>,
465{
466 Cow::<'static, [Symbol]>::deserialize(deserializer).map(normalize)
467}
468
469
470#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
472pub struct SymbolList(#[serde(deserialize_with = "normalized_from_str")] Cow<'static, [Symbol]>);
473
474impl Deref for SymbolList {
475 type Target = [Symbol];
476
477 fn deref(&self) -> &Self::Target {
478 self.0.borrow()
479 }
480}
481
482impl From<Cow<'static, [Symbol]>> for SymbolList {
483 #[inline]
484 fn from(symbols: Cow<'static, [Symbol]>) -> Self {
485 Self(normalize(symbols))
486 }
487}
488
489impl From<Vec<String>> for SymbolList {
490 #[inline]
491 fn from(symbols: Vec<String>) -> Self {
492 Self(normalize(Cow::from(
493 IntoIterator::into_iter(symbols)
494 .map(Symbol::from)
495 .collect::<Vec<_>>(),
496 )))
497 }
498}
499
500impl<const N: usize> From<[&'static str; N]> for SymbolList {
501 #[inline]
502 fn from(symbols: [&'static str; N]) -> Self {
503 Self(normalize(Cow::from(
504 IntoIterator::into_iter(symbols)
505 .map(Symbol::from)
506 .collect::<Vec<_>>(),
507 )))
508 }
509}
510
511
512mod symbols_all {
513 use super::*;
514
515 use serde::de::Error;
516 use serde::de::Unexpected;
517
518 pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<(), D::Error>
520 where
521 D: Deserializer<'de>,
522 {
523 let string = <[&str; 1]>::deserialize(deserializer)?;
524 if string == ["*"] {
525 Ok(())
526 } else {
527 Err(Error::invalid_value(
528 Unexpected::Str(string[0]),
529 &"the string \"*\"",
530 ))
531 }
532 }
533
534 pub(crate) fn serialize<S>(serializer: S) -> Result<S::Ok, S::Error>
536 where
537 S: Serializer,
538 {
539 let mut seq = serializer.serialize_seq(Some(1))?;
540 seq.serialize_element("*")?;
541 seq.end()
542 }
543}
544
545
546#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
551#[serde(untagged)]
552pub enum Symbols {
553 #[serde(with = "symbols_all")]
555 All,
556 List(SymbolList),
558}
559
560impl Symbols {
561 #[inline]
563 pub fn is_empty(&self) -> bool {
564 match self {
565 Self::List(list) => list.is_empty(),
566 Self::All => false,
567 }
568 }
569}
570
571impl Default for Symbols {
572 fn default() -> Self {
573 Self::List(SymbolList::from([]))
574 }
575}
576
577
578#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
580pub struct MarketData {
581 #[serde(default)]
583 pub bars: Symbols,
584 #[serde(default)]
586 pub quotes: Symbols,
587 #[serde(default)]
589 pub trades: Symbols,
590}
591
592impl MarketData {
593 #[inline]
596 pub fn set_bars<S>(&mut self, symbols: S)
597 where
598 S: Into<SymbolList>,
599 {
600 self.bars = Symbols::List(symbols.into());
601 }
602
603 #[inline]
606 pub fn set_quotes<S>(&mut self, symbols: S)
607 where
608 S: Into<SymbolList>,
609 {
610 self.quotes = Symbols::List(symbols.into());
611 }
612
613 #[inline]
616 pub fn set_trades<S>(&mut self, symbols: S)
617 where
618 S: Into<SymbolList>,
619 {
620 self.trades = Symbols::List(symbols.into());
621 }
622}
623
624
625#[derive(Debug, Deserialize, Eq, PartialEq, Serialize)]
627#[doc(hidden)]
628#[serde(tag = "action")]
629pub enum Request<'d> {
630 #[serde(rename = "auth")]
633 Authenticate {
634 #[serde(rename = "key")]
635 key_id: Cow<'d, str>,
636 #[serde(rename = "secret")]
637 secret: Cow<'d, str>,
638 },
639 #[serde(rename = "subscribe")]
642 Subscribe(Cow<'d, MarketData>),
643 #[serde(rename = "unsubscribe")]
646 Unsubscribe(Cow<'d, MarketData>),
647}
648
649
650#[derive(Debug)]
659pub struct Subscription<S, B, Q, T> {
660 subscription: subscribe::Subscription<S, ParsedMessage<B, Q, T>, wrap::Message>,
663 subscriptions: MarketData,
665}
666
667impl<S, B, Q, T> Subscription<S, B, Q, T> {
668 #[inline]
670 fn new(subscription: subscribe::Subscription<S, ParsedMessage<B, Q, T>, wrap::Message>) -> Self {
671 Self {
672 subscription,
673 subscriptions: MarketData::default(),
674 }
675 }
676}
677
678impl<S, B, Q, T> Subscription<S, B, Q, T>
679where
680 S: Sink<wrap::Message> + Unpin,
681{
682 async fn authenticate(
684 &mut self,
685 key_id: &str,
686 secret: &str,
687 ) -> Result<Result<(), Error>, S::Error> {
688 let request = Request::Authenticate {
689 key_id: key_id.into(),
690 secret: secret.into(),
691 };
692 let json = match to_json(&request) {
693 Ok(json) => json,
694 Err(err) => return Ok(Err(Error::Json(err))),
695 };
696 let message = wrap::Message::Text(json);
697 let response = self.subscription.send(message).await?;
698
699 match response {
700 Some(response) => match response {
701 Ok(ControlMessage::Success) => Ok(Ok(())),
702 Ok(ControlMessage::Subscription(..)) => Ok(Err(Error::Str(
703 "server responded with unexpected subscription message".into(),
704 ))),
705 Ok(ControlMessage::Error(error)) => Ok(Err(Error::Str(
706 format!(
707 "failed to authenticate with server: {} ({})",
708 error.message, error.code
709 )
710 .into(),
711 ))),
712 Err(()) => Ok(Err(Error::Str("failed to authenticate with server".into()))),
713 },
714 None => Ok(Err(Error::Str(
715 "stream was closed before authorization message was received".into(),
716 ))),
717 }
718 }
719
720 async fn subscribe_unsubscribe(
722 &mut self,
723 request: &Request<'_>,
724 ) -> Result<Result<(), Error>, S::Error> {
725 let json = match to_json(request) {
726 Ok(json) => json,
727 Err(err) => return Ok(Err(Error::Json(err))),
728 };
729 let message = wrap::Message::Text(json);
730 let response = self.subscription.send(message).await?;
731
732 match response {
733 Some(response) => match response {
734 Ok(ControlMessage::Subscription(data)) => {
735 self.subscriptions = data;
736 Ok(Ok(()))
737 },
738 Ok(ControlMessage::Error(error)) => Ok(Err(Error::Str(
739 format!("failed to subscribe: {error}").into(),
740 ))),
741 Ok(_) => Ok(Err(Error::Str(
742 "server responded with unexpected message".into(),
743 ))),
744 Err(()) => Ok(Err(Error::Str("failed to adjust subscription".into()))),
745 },
746 None => Ok(Err(Error::Str(
747 "stream was closed before subscription confirmation message was received".into(),
748 ))),
749 }
750 }
751
752 #[inline]
758 pub async fn subscribe(&mut self, subscribe: &MarketData) -> Result<Result<(), Error>, S::Error> {
759 let request = Request::Subscribe(Cow::Borrowed(subscribe));
760 self.subscribe_unsubscribe(&request).await
761 }
762
763 #[inline]
768 pub async fn unsubscribe(
769 &mut self,
770 unsubscribe: &MarketData,
771 ) -> Result<Result<(), Error>, S::Error> {
772 let request = Request::Unsubscribe(Cow::Borrowed(unsubscribe));
773 self.subscribe_unsubscribe(&request).await
774 }
775
776 #[inline]
778 pub fn subscriptions(&self) -> &MarketData {
779 &self.subscriptions
780 }
781}
782
783
/// Function decoding a raw websocket message into the batch of data
/// messages it contains.
type ParseFn<B, Q, T> = fn(
  Result<wrap::Message, WebSocketError>,
) -> Result<Result<Vec<DataMessage<B, Q, T>>, JsonError>, WebSocketError>;
/// Function converting a single decoded item into a `ParsedMessage`.
type MapFn<B, Q, T> =
  fn(Result<Result<DataMessage<B, Q, T>, JsonError>, WebSocketError>) -> ParsedMessage<B, Q, T>;
/// The full stream type: websocket -> batch decoding (`ParseFn`) ->
/// flattening of batches into individual items (`Unfold`) -> conversion
/// into `ParsedMessage` (`MapFn`).
/// NOTE(review): the flattening role of `Unfold` is inferred from its type
/// parameters; confirm against `super::unfold`.
type Stream<B, Q, T> = Map<
  Unfold<
    Map<Wrapper<WebSocketStream<MaybeTlsStream<TcpStream>>>, ParseFn<B, Q, T>>,
    DataMessage<B, Q, T>,
    JsonError,
  >,
  MapFn<B, Q, T>,
>;
797
798
799#[derive(Debug)]
807pub struct RealtimeData<S, B = Bar, Q = Quote, T = Trade> {
808 _phantom: PhantomData<(S, B, Q, T)>,
810}
811
812#[async_trait]
813impl<S, B, Q, T> Subscribable for RealtimeData<S, B, Q, T>
814where
815 S: Source,
816 B: Send + Unpin + Debug + DeserializeOwned,
817 Q: Send + Unpin + Debug + DeserializeOwned,
818 T: Send + Unpin + Debug + DeserializeOwned,
819{
820 type Input = ApiInfo;
821 type Subscription = Subscription<SplitSink<Stream<B, Q, T>, wrap::Message>, B, Q, T>;
822 type Stream = Fuse<MessageStream<SplitStream<Stream<B, Q, T>>, ParsedMessage<B, Q, T>>>;
823
824 async fn connect(api_info: &Self::Input) -> Result<(Self::Stream, Self::Subscription), Error> {
825 fn parse<B, Q, T>(
826 result: Result<wrap::Message, WebSocketError>,
827 ) -> Result<Result<Vec<DataMessage<B, Q, T>>, JsonError>, WebSocketError>
828 where
829 B: DeserializeOwned,
830 Q: DeserializeOwned,
831 T: DeserializeOwned,
832 {
833 result.map(|message| match message {
834 wrap::Message::Text(string) => json_from_str::<Vec<DataMessage<B, Q, T>>>(&string),
835 wrap::Message::Binary(data) => json_from_slice::<Vec<DataMessage<B, Q, T>>>(&data),
836 })
837 }
838
839 let ApiInfo {
840 data_stream_base_url: url,
841 key_id,
842 secret,
843 ..
844 } = api_info;
845
846 let url = match S::source() {
847 SourceVariant::PathComponent(component) => {
848 let mut url = url.clone();
849 url.set_path(&format!("v2/{}", component));
850 url
851 },
852 SourceVariant::Url(url) => Url::parse(&url)?,
853 };
854
855 let stream = Unfold::new(
856 connect(&url)
857 .await?
858 .map(parse::<B, Q, T> as ParseFn<_, _, _>),
859 )
860 .map(MessageResult::from as MapFn<B, Q, T>);
861 let (send, recv) = stream.split();
862 let (stream, subscription) = subscribe::subscribe(recv, send);
863 let mut stream = stream.fuse();
864 let mut subscription = Subscription::new(subscription);
865
866 let connect = subscription.subscription.read().boxed();
867 let message = drive(connect, &mut stream).await.map_err(|result| {
868 result
869 .map(|result| Error::Json(result.unwrap_err()))
870 .map_err(Error::WebSocket)
871 .unwrap_or_else(|err| err)
872 })?;
873
874 match message {
875 Some(Ok(ControlMessage::Success)) => (),
876 Some(Ok(_)) => {
877 return Err(Error::Str(
878 "server responded with unexpected initial message".into(),
879 ))
880 },
881 Some(Err(())) => return Err(Error::Str("failed to read connected message".into())),
882 None => {
883 return Err(Error::Str(
884 "stream was closed before connected message was received".into(),
885 ))
886 },
887 }
888
889 let authenticate = subscription.authenticate(key_id, secret).boxed();
890 let () = drive(authenticate, &mut stream).await.map_err(|result| {
891 result
892 .map(|result| Error::Json(result.unwrap_err()))
893 .map_err(Error::WebSocket)
894 .unwrap_or_else(|err| err)
895 })???;
896
897 Ok((stream, subscription))
898 }
899}
900
901
902#[allow(clippy::to_string_trait_impl)]
903#[cfg(test)]
904mod tests {
905 use super::*;
906
907 use std::str::FromStr;
908 use std::time::Duration;
909
910 use chrono::DateTime;
911
912 use futures::SinkExt as _;
913 use futures::TryStreamExt as _;
914
915 use serial_test::serial;
916
917 use serde_json::from_str as json_from_str;
918
919 use test_log::test;
920
921 use tokio::time::timeout;
922
923 use tungstenite::tungstenite::Utf8Bytes;
924
925 use websocket_util::test::WebSocketStream;
926 use websocket_util::tungstenite::Message;
927
928 use crate::api::API_BASE_URL;
929 use crate::websocket::test::mock_stream;
930 use crate::Client;
931
932
933 const CONN_RESP: &str = r#"[{"T":"success","msg":"connected"}]"#;
934 const AUTH_REQ: &str = r#"{"action":"auth","key":"USER12345678","secret":"justletmein"}"#;
939 const AUTH_RESP: &str = r#"[{"T":"success","msg":"authenticated"}]"#;
940 const SUB_REQ: &str = r#"{"action":"subscribe","bars":["AAPL","VOO"],"quotes":[],"trades":[]}"#;
941 const SUB_RESP: &str = r#"[{"T":"subscription","bars":["AAPL","VOO"]}]"#;
942 const SUB_ERR_REQ: &str = r#"{"action":"subscribe","bars":[],"quotes":[],"trades":[]}"#;
943 const SUB_ERR_RESP: &str = r#"[{"T":"error","code":400,"msg":"invalid syntax"}]"#;
944
945
946 #[test]
948 fn sip_source() {
949 assert_ne!(format!("{:?}", SIP::source()), "");
950 }
951
952 #[test]
954 fn data_classification() {
955 assert!(Data::<(), Quote, Trade>::Bar(()).is_bar());
956 assert!(Data::<Bar, (), Trade>::Quote(()).is_quote());
957 assert!(Data::<Bar, Quote, ()>::Trade(()).is_trade());
958 }
959
960 #[test]
962 fn symbols_is_empty() {
963 assert!(!Symbols::All.is_empty());
964 assert!(!Symbols::List(SymbolList::from(["SPY"])).is_empty());
965 assert!(Symbols::List(SymbolList::from([])).is_empty());
966 }
967
968 #[test]
971 fn serialize_deserialize_bar() {
972 let json = r#"{
973 "T": "b",
974 "S": "SPY",
975 "o": 388.985,
976 "h": 389.13,
977 "l": 388.975,
978 "c": 389.12,
979 "v": 49378,
980 "t": "2021-02-22T19:15:00Z"
981}"#;
982
983 let message = json_from_str::<DataMessage>(json).unwrap();
984 let bar = match &message {
985 DataMessage::Bar(bar) => bar,
986 _ => panic!("Decoded unexpected message variant: {message:?}"),
987 };
988 assert_eq!(bar.symbol, "SPY");
989 assert_eq!(bar.open_price, Num::new(388985, 1000));
990 assert_eq!(bar.high_price, Num::new(38913, 100));
991 assert_eq!(bar.low_price, Num::new(388975, 1000));
992 assert_eq!(bar.close_price, Num::new(38912, 100));
993 assert_eq!(bar.volume, Num::from(49378));
994 assert_eq!(
995 bar.timestamp,
996 DateTime::<Utc>::from_str("2021-02-22T19:15:00Z").unwrap()
997 );
998
999 assert_eq!(
1000 json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1001 message
1002 );
1003 }
1004
1005 #[test]
1008 fn serialize_deserialize_quote() {
1009 let json: &str = r#"{
1010 "T": "q",
1011 "S": "NVDA",
1012 "bx": "P",
1013 "bp": 258.8,
1014 "bs": 2,
1015 "ax": "A",
1016 "ap": 259.99,
1017 "as": 5,
1018 "c": [
1019 "R"
1020 ],
1021 "z": "C",
1022 "t": "2022-01-18T23:09:42.151875584Z"
1023}"#;
1024
1025 let message = json_from_str::<DataMessage>(json).unwrap();
1026 let quote = match &message {
1027 DataMessage::Quote(quote) => quote,
1028 _ => panic!("Decoded unexpected message variant: {message:?}"),
1029 };
1030 assert_eq!(quote.symbol, "NVDA");
1031 assert_eq!(quote.bid_price, Num::new(2588, 10));
1032 assert_eq!(quote.bid_size, Num::from(2));
1033 assert_eq!(quote.ask_price, Num::new(25999, 100));
1034 assert_eq!(quote.ask_size, Num::from(5));
1035
1036 assert_eq!(
1037 quote.timestamp,
1038 DateTime::<Utc>::from_str("2022-01-18T23:09:42.151875584Z").unwrap()
1039 );
1040
1041 assert_eq!(
1042 json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1043 message
1044 );
1045 }
1046
1047
1048 #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
1050 struct DetailedQuote {
1051 #[serde(rename = "S")]
1053 symbol: String,
1054 #[serde(rename = "bp")]
1056 bid_price: Num,
1057 #[serde(rename = "bs")]
1059 bid_size: Num,
1060 #[serde(rename = "bx")]
1062 bid_exchange_code: String,
1063 #[serde(rename = "ap")]
1065 ask_price: Num,
1066 #[serde(rename = "as")]
1068 ask_size: Num,
1069 #[serde(rename = "ax")]
1071 ask_exchange_code: String,
1072 #[serde(rename = "s")]
1074 trade_size: Num,
1075 #[serde(rename = "t")]
1077 timestamp: DateTime<Utc>,
1078 #[serde(rename = "c")]
1080 quote_conditions: Vec<String>,
1081 #[serde(rename = "z")]
1083 tape: String,
1084 }
1085
1086 #[test]
1089 fn serialize_deserialize_custom_quote() {
1090 let json: &str = r#"{
1091 "T": "q",
1092 "S": "NVDA",
1093 "bx": "P",
1094 "bp": 258.8,
1095 "bs": 2,
1096 "s": 3,
1097 "ax": "A",
1098 "ap": 259.99,
1099 "as": 5,
1100 "c": [
1101 "R"
1102 ],
1103 "z": "C",
1104 "t": "2022-01-18T23:09:42.151875584Z"
1105}"#;
1106
1107 let message = json_from_str::<DataMessage<Bar, DetailedQuote, Trade>>(json).unwrap();
1108 let quote = match &message {
1109 DataMessage::Quote(quote) => quote,
1110 _ => panic!("Decoded unexpected message variant: {message:?}"),
1111 };
1112 assert_eq!(quote.symbol, "NVDA");
1113 assert_eq!(quote.bid_price, Num::new(2588, 10));
1114 assert_eq!(quote.bid_size, Num::from(2));
1115 assert_eq!(quote.bid_exchange_code, "P");
1116 assert_eq!(quote.ask_price, Num::new(25999, 100));
1117 assert_eq!(quote.ask_size, Num::from(5));
1118 assert_eq!(quote.ask_exchange_code, "A");
1119 assert_eq!(quote.trade_size, Num::from(3));
1120
1121 assert_eq!(
1122 quote.timestamp,
1123 DateTime::<Utc>::from_str("2022-01-18T23:09:42.151875584Z").unwrap()
1124 );
1125
1126 assert_eq!(
1127 json_from_str::<DataMessage<Bar, DetailedQuote, Trade>>(&to_json(&message).unwrap()).unwrap(),
1128 message
1129 );
1130 }
1131
1132 #[test]
1135 fn serialize_deserialize_trade() {
1136 let json: &str = r#"{
1137 "T": "t",
1138 "i": 96921,
1139 "S": "AAPL",
1140 "x": "D",
1141 "p": 126.55,
1142 "s": 1,
1143 "t": "2021-02-22T15:51:44.208Z",
1144 "c": ["@", "I"],
1145 "z": "C"
1146}"#;
1147
1148 let message = json_from_str::<DataMessage>(json).unwrap();
1149 let trade = match &message {
1150 DataMessage::Trade(trade) => trade,
1151 _ => panic!("Decoded unexpected message variant: {message:?}"),
1152 };
1153 assert_eq!(trade.symbol, "AAPL");
1154 assert_eq!(trade.trade_id, 96921);
1155 assert_eq!(trade.trade_price, Num::new(12655, 100));
1156 assert_eq!(trade.trade_size, Num::from(1));
1157
1158 assert_eq!(
1159 trade.timestamp,
1160 DateTime::<Utc>::from_str("2021-02-22T15:51:44.208Z").unwrap()
1161 );
1162
1163 assert_eq!(
1164 json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1165 message
1166 );
1167 }
1168
1169
1170 #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
1172 struct DetailedTrade {
1173 #[serde(rename = "S")]
1175 symbol: String,
1176 #[serde(rename = "i")]
1178 trade_id: u64,
1179 #[serde(rename = "p")]
1181 trade_price: Num,
1182 #[serde(rename = "s")]
1184 trade_size: u64,
1185 #[serde(rename = "c")]
1187 conditions: Vec<String>,
1188 #[serde(rename = "t")]
1190 timestamp: DateTime<Utc>,
1191 #[serde(rename = "x")]
1193 exchange: String,
1194 #[serde(rename = "z")]
1196 tape: String,
1197 #[serde(rename = "u", default)]
1200 update: Option<String>,
1201 }
1202
1203 #[test]
1206 fn serialize_deserialize_custom_trade() {
1207 let json: &str = r#"{
1208 "T": "t",
1209 "i": 96921,
1210 "S": "AAPL",
1211 "x": "D",
1212 "p": 126.55,
1213 "s": 1,
1214 "t": "2021-02-22T15:51:44.208Z",
1215 "c": ["@", "I"],
1216 "z": "C",
1217 "u": "corrected"
1218}"#;
1219
1220 let message = json_from_str::<DataMessage<Bar, Quote, DetailedTrade>>(json).unwrap();
1221 let trade = match &message {
1222 DataMessage::Trade(trade) => trade,
1223 _ => panic!("Decoded unexpected message variant: {message:?}"),
1224 };
1225 assert_eq!(trade.symbol, "AAPL");
1226 assert_eq!(trade.trade_id, 96921);
1227 assert_eq!(trade.trade_price, Num::new(12655, 100));
1228 assert_eq!(trade.trade_size, 1);
1229
1230 assert_eq!(
1231 trade.timestamp,
1232 DateTime::<Utc>::from_str("2021-02-22T15:51:44.208Z").unwrap()
1233 );
1234
1235 assert_eq!(trade.conditions, vec!["@", "I"]);
1236 assert_eq!(trade.tape, "C");
1237 assert_eq!(trade.update, Some("corrected".to_string()));
1238
1239 assert_eq!(
1240 json_from_str::<DataMessage<Bar, Quote, DetailedTrade>>(&to_json(&message).unwrap()).unwrap(),
1241 message
1242 );
1243 }
1244
1245 #[test]
1248 fn serialize_deserialize_success() {
1249 let json = r#"{"T":"success","msg":"authenticated"}"#;
1250 let message = json_from_str::<DataMessage>(json).unwrap();
1251 let () = match message {
1252 DataMessage::Success => (),
1253 _ => panic!("Decoded unexpected message variant: {message:?}"),
1254 };
1255
1256 assert_eq!(
1257 json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1258 message
1259 );
1260 }
1261
1262 #[test]
1265 fn serialize_deserialize_error() {
1266 let json = r#"{"T":"error","code":400,"msg":"invalid syntax"}"#;
1267 let message = json_from_str::<DataMessage>(json).unwrap();
1268 let error = match &message {
1269 DataMessage::Error(error) => error,
1270 _ => panic!("Decoded unexpected message variant: {message:?}"),
1271 };
1272
1273 assert_eq!(error.code, 400);
1274 assert_eq!(error.message, "invalid syntax");
1275
1276 assert_eq!(
1277 json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1278 message
1279 );
1280
1281 let json = r#"{"T":"error","code":500,"msg":"internal error"}"#;
1282 let message = json_from_str::<DataMessage>(json).unwrap();
1283 let error = match &message {
1284 DataMessage::Error(error) => error,
1285 _ => panic!("Decoded unexpected message variant: {message:?}"),
1286 };
1287
1288 assert_eq!(error.code, 500);
1289 assert_eq!(error.message, "internal error");
1290
1291 assert_eq!(
1292 json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
1293 message
1294 );
1295 }
1296
1297 #[test]
1300 fn serialize_deserialize_authentication_request() {
1301 let request = Request::Authenticate {
1302 key_id: "KEY-ID".into(),
1303 secret: "SECRET-KEY".into(),
1304 };
1305 let json = to_json(&request).unwrap();
1306 let expected = r#"{"action":"auth","key":"KEY-ID","secret":"SECRET-KEY"}"#;
1307 assert_eq!(json, expected);
1308 assert_eq!(json_from_str::<Request<'_>>(&json).unwrap(), request);
1309 }
1310
1311 #[test]
1314 fn serialize_deserialize_subscribe_request() {
1315 let mut data = MarketData::default();
1316 data.set_bars(["AAPL", "VOO"]);
1317 let request = Request::Subscribe(Cow::Borrowed(&data));
1318
1319 let json = to_json(&request).unwrap();
1320 let expected = r#"{"action":"subscribe","bars":["AAPL","VOO"],"quotes":[],"trades":[]}"#;
1321 assert_eq!(json, expected);
1322 assert_eq!(json_from_str::<Request<'_>>(&json).unwrap(), request);
1323 }
1324
1325 #[test]
1328 fn serialize_deserialize_unsubscribe_request() {
1329 let mut data = MarketData::default();
1330 data.set_bars(["VOO"]);
1331 let request = Request::Unsubscribe(Cow::Borrowed(&data));
1332
1333 let json = to_json(&request).unwrap();
1334 let expected = r#"{"action":"unsubscribe","bars":["VOO"],"quotes":[],"trades":[]}"#;
1335 assert_eq!(json, expected);
1336 assert_eq!(json_from_str::<Request<'_>>(&json).unwrap(), request);
1337 }
1338
1339 #[test]
1341 fn deserialize_symbol_list() {
1342 let json = r#"["AAPL","XLK","SPY"]"#;
1343 let list = json_from_str::<SymbolList>(json).unwrap();
1344 let expected = SymbolList::from(["AAPL", "SPY", "XLK"]);
1345 assert_eq!(list, expected);
1346 }
1347
1348 #[test]
1350 fn normalize_subscriptions() {
1351 let subscriptions = [];
1352 assert!(is_normalized(&subscriptions));
1353
1354 let subscriptions = ["MSFT".into(), "SPY".into()];
1355 assert!(is_normalized(&subscriptions));
1356
1357 let mut subscriptions = Cow::from(vec!["SPY".into(), "MSFT".into()]);
1358 assert!(!is_normalized(&subscriptions));
1359 subscriptions = normalize(subscriptions);
1360 assert!(is_normalized(&subscriptions));
1361
1362 let expected = [Cow::from("MSFT"), "SPY".into()];
1363 assert_eq!(subscriptions.borrow(), expected);
1364
1365 let mut subscriptions = Cow::from(vec!["SPY".into(), "MSFT".into(), "MSFT".into()]);
1366 assert!(!is_normalized(&subscriptions));
1367 subscriptions = normalize(subscriptions);
1368 assert!(is_normalized(&subscriptions));
1369
1370 assert_eq!(subscriptions.borrow(), expected);
1371 }
1372
1373 #[test(tokio::test)]
1376 async fn authenticate_and_subscribe() {
1377 async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
1378 stream
1379 .send(Message::Text(Utf8Bytes::from_static(CONN_RESP)))
1380 .await?;
1381 assert_eq!(
1383 stream.next().await.unwrap()?,
1384 Message::Text(Utf8Bytes::from_static(AUTH_REQ)),
1385 );
1386 stream
1387 .send(Message::Text(Utf8Bytes::from_static(AUTH_RESP)))
1388 .await?;
1389
1390 assert_eq!(
1392 stream.next().await.unwrap()?,
1393 Message::Text(Utf8Bytes::from_static(SUB_REQ)),
1394 );
1395 stream
1396 .send(Message::Text(Utf8Bytes::from_static(SUB_RESP)))
1397 .await?;
1398 stream.send(Message::Close(None)).await?;
1399 Ok(())
1400 }
1401
1402 let (mut stream, mut subscription) =
1403 mock_stream::<RealtimeData<IEX>, _, _>(test).await.unwrap();
1404
1405 let mut data = MarketData::default();
1406 data.set_bars(["AAPL", "VOO"]);
1407
1408 let subscribe = subscription.subscribe(&data).boxed_local();
1409 let () = drive(subscribe, &mut stream)
1410 .await
1411 .unwrap()
1412 .unwrap()
1413 .unwrap();
1414
1415 stream
1416 .map_err(Error::WebSocket)
1417 .try_for_each(|result| async { result.map(|_data| ()).map_err(Error::Json) })
1418 .await
1419 .unwrap();
1420 }
1421
1422 #[test(tokio::test)]
1425 async fn subscribe_error() {
1426 async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
1427 stream
1428 .send(Message::Text(Utf8Bytes::from_static(CONN_RESP)))
1429 .await?;
1430 assert_eq!(
1432 stream.next().await.unwrap()?,
1433 Message::Text(Utf8Bytes::from_static(AUTH_REQ)),
1434 );
1435 stream
1436 .send(Message::Text(Utf8Bytes::from_static(AUTH_RESP)))
1437 .await?;
1438
1439 assert_eq!(
1441 stream.next().await.unwrap()?,
1442 Message::Text(Utf8Bytes::from_static(SUB_ERR_REQ)),
1443 );
1444 stream
1445 .send(Message::Text(Utf8Bytes::from_static(SUB_ERR_RESP)))
1446 .await?;
1447 stream.send(Message::Close(None)).await?;
1448 Ok(())
1449 }
1450
1451 let (mut stream, mut subscription) =
1452 mock_stream::<RealtimeData<IEX>, _, _>(test).await.unwrap();
1453
1454 let data = MarketData::default();
1455
1456 let subscribe = subscription.subscribe(&data).boxed_local();
1457 let error = drive(subscribe, &mut stream)
1458 .await
1459 .unwrap()
1460 .unwrap()
1461 .unwrap_err();
1462
1463 match error {
1464 Error::Str(ref e) if e == "failed to subscribe: invalid syntax (400)" => {},
1465 e => panic!("received unexpected error: {e}"),
1466 }
1467 }
1468
1469 #[test(tokio::test)]
1472 #[serial(realtime_data)]
1473 async fn subscribe_resubscribe() {
1474 let api_info = ApiInfo::from_env().unwrap();
1475 let client = Client::new(api_info);
1476 let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
1477
1478 let mut data = MarketData::default();
1479 data.set_bars(["AAPL", "SPY"]);
1480
1481 let subscribe = subscription.subscribe(&data).boxed_local();
1482 let () = drive(subscribe, &mut stream)
1483 .await
1484 .unwrap()
1485 .unwrap()
1486 .unwrap();
1487
1488 assert_eq!(subscription.subscriptions(), &data);
1489
1490 let mut data = MarketData::default();
1491 data.set_bars(["XLK"]);
1492 let subscribe = subscription.subscribe(&data).boxed_local();
1493 let () = drive(subscribe, &mut stream)
1494 .await
1495 .unwrap()
1496 .unwrap()
1497 .unwrap();
1498
1499 let mut expected = MarketData::default();
1500 expected.set_bars(["AAPL", "SPY", "XLK"]);
1501 assert_eq!(subscription.subscriptions(), &expected);
1502 }
1503
1504 #[test(tokio::test)]
1510 #[serial(realtime_data)]
1511 async fn stream_market_data_updates() {
1512 let api_info = ApiInfo::from_env().unwrap();
1513 let client = Client::new(api_info);
1514 let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
1515
1516 let data = MarketData {
1517 bars: Symbols::All,
1518 ..Default::default()
1519 };
1520
1521 let subscribe = subscription.subscribe(&data).boxed_local();
1522 let () = drive(subscribe, &mut stream)
1523 .await
1524 .unwrap()
1525 .unwrap()
1526 .unwrap();
1527
1528 assert_eq!(subscription.subscriptions(), &data);
1529
1530 let read = stream
1531 .map_err(Error::WebSocket)
1532 .try_for_each(|result| async {
1533 result
1534 .map(|data| {
1535 assert!(data.is_bar());
1536 })
1537 .map_err(Error::Json)
1538 });
1539
1540 if timeout(Duration::from_millis(100), read).await.is_ok() {
1541 panic!("realtime data stream got exhausted unexpectedly")
1542 }
1543 }
1544
1545 #[test(tokio::test)]
1551 #[serial(realtime_data)]
1552 async fn stream_quotes() {
1553 async fn test<S>()
1554 where
1555 S: Source,
1556 {
1557 let api_info = ApiInfo::from_env().unwrap();
1558 let client = Client::new(api_info);
1559 let (mut stream, mut subscription) = client.subscribe::<RealtimeData<S>>().await.unwrap();
1560
1561 let mut data = MarketData::default();
1562 data.set_quotes(["SPY"]);
1563
1564 let subscribe = subscription.subscribe(&data).boxed_local();
1565 let () = drive(subscribe, &mut stream)
1566 .await
1567 .unwrap()
1568 .unwrap()
1569 .unwrap();
1570
1571 let read = stream
1572 .map_err(Error::WebSocket)
1573 .try_for_each(|result| async {
1574 result
1575 .map(|data| {
1576 assert!(data.is_quote());
1577 })
1578 .map_err(Error::Json)
1579 });
1580
1581 if timeout(Duration::from_millis(100), read).await.is_ok() {
1582 panic!("realtime data stream got exhausted unexpectedly")
1583 }
1584 }
1585
1586 test::<IEX>().await;
1587
1588 #[derive(Default)]
1589 struct IexWithUrl;
1590
1591 impl ToString for IexWithUrl {
1592 fn to_string(&self) -> String {
1593 "wss://stream.data.alpaca.markets/v2/iex".into()
1594 }
1595 }
1596
1597 test::<CustomUrl<IexWithUrl>>().await;
1598 }
1599
1600 #[test(tokio::test)]
1602 #[serial(realtime_data)]
1603 async fn stream_trades() {
1604 let api_info = ApiInfo::from_env().unwrap();
1605 let client = Client::new(api_info);
1606 let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
1607
1608 let mut data = MarketData::default();
1609 data.set_trades(["SPY"]);
1610
1611 let subscribe = subscription.subscribe(&data).boxed_local();
1612 let () = drive(subscribe, &mut stream)
1613 .await
1614 .unwrap()
1615 .unwrap()
1616 .unwrap();
1617
1618 let read = stream
1619 .map_err(Error::WebSocket)
1620 .try_for_each(|result| async {
1621 result
1622 .map(|data| {
1623 assert!(data.is_trade());
1624 })
1625 .map_err(Error::Json)
1626 });
1627
1628 if timeout(Duration::from_millis(100), read).await.is_ok() {
1629 panic!("realtime data stream got exhausted unexpectedly")
1630 }
1631 }
1632
1633 #[test(tokio::test)]
1636 #[serial(realtime_data)]
1637 async fn unsubscribe_not_subscribed_symbol() {
1638 let api_info = ApiInfo::from_env().unwrap();
1639 let client = Client::new(api_info);
1640 let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
1641
1642 let mut data = MarketData::default();
1643 data.set_bars(["AAPL"]);
1644
1645 let unsubscribe = subscription.unsubscribe(&data).boxed_local();
1646 let () = drive(unsubscribe, &mut stream)
1647 .await
1648 .unwrap()
1649 .unwrap()
1650 .unwrap();
1651
1652 assert_eq!(subscription.subscriptions(), &MarketData::default());
1653 }
1654
1655 #[test(tokio::test)]
1658 #[serial(realtime_data)]
1659 async fn stream_with_invalid_credentials() {
1660 let api_info = ApiInfo::from_parts(API_BASE_URL, "invalid", "invalid-too").unwrap();
1661 let client = Client::new(api_info);
1662 let err = client.subscribe::<RealtimeData<IEX>>().await.unwrap_err();
1663
1664 match err {
1665 Error::Str(ref e) if e.starts_with("failed to authenticate with server") => (),
1666 e => panic!("received unexpected error: {e}"),
1667 }
1668 }
1669
1670 #[test(tokio::test)]
1672 #[serial(realtime_data)]
1673 async fn stream_with_invalid_url() {
1674 #[derive(Default)]
1675 struct Invalid;
1676
1677 impl ToString for Invalid {
1678 fn to_string(&self) -> String {
1679 "<invalid-url-is-invalid>".into()
1680 }
1681 }
1682
1683 let api_info = ApiInfo::from_env().unwrap();
1684 let client = Client::new(api_info);
1685 let err = client
1686 .subscribe::<RealtimeData<CustomUrl<Invalid>>>()
1687 .await
1688 .unwrap_err();
1689
1690 match err {
1691 Error::Url(..) => (),
1692 _ => panic!("Received unexpected error: {err:?}"),
1693 };
1694 }
1695}