use std::borrow::Borrow as _;
use std::borrow::Cow;
use std::cmp::Ordering;
use std::marker::PhantomData;
use std::ops::Deref;
use async_trait::async_trait;
use chrono::DateTime;
use chrono::Utc;
use futures::stream::Fuse;
use futures::stream::FusedStream;
use futures::stream::Map;
use futures::stream::SplitSink;
use futures::stream::SplitStream;
use futures::Future;
use futures::FutureExt as _;
use futures::Sink;
use futures::StreamExt as _;
use num_decimal::Num;
use serde::de::Deserializer;
use serde::ser::SerializeSeq as _;
use serde::ser::Serializer;
use serde::Deserialize;
use serde::Serialize;
use serde_json::from_slice as json_from_slice;
use serde_json::from_str as json_from_str;
use serde_json::to_string as to_json;
use serde_json::Error as JsonError;
use thiserror::Error as ThisError;
use tokio::net::TcpStream;
use tungstenite::MaybeTlsStream;
use tungstenite::WebSocketStream;
use websocket_util::subscribe;
use websocket_util::subscribe::MessageStream;
use websocket_util::tungstenite::Error as WebSocketError;
use websocket_util::wrap;
use websocket_util::wrap::Wrapper;
use super::unfold::Unfold;
use crate::subscribable::Subscribable;
use crate::websocket::connect;
use crate::websocket::MessageResult;
use crate::ApiInfo;
use crate::Error;
use crate::Str;
/// The message type handed to users of the stream: the `UserMessage`
/// associated type of the `subscribe::Message` implementation for
/// `ParsedMessage`.
type UserMessage = <ParsedMessage as subscribe::Message>::UserMessage;
/// Drive `future` to completion while working with `stream`, by
/// delegating to `subscribe::drive` specialized for `ParsedMessage`.
///
/// On success the future's output is returned. The `Err` variant
/// carries a `UserMessage`.
// NOTE(review): presumably the `Err` case is a message read from
// `stream` for which `is_error` holds -- confirm against the
// documentation of `websocket_util::subscribe::drive`.
#[inline]
pub async fn drive<F, S>(future: F, stream: &mut S) -> Result<F::Output, UserMessage>
where
F: Future + Unpin,
S: FusedStream<Item = UserMessage> + Unpin,
{
subscribe::drive::<ParsedMessage, _, _>(future, stream).await
}
// A non-public module housing the `Sealed` marker trait. Because the
// module is not exported, other crates cannot name `Sealed` and hence
// cannot implement traits that require it.
mod private {
/// Marker trait used as a supertrait of `Source`.
pub trait Sealed {}
}
/// A trait identifying the source of realtime market data.
///
/// The trait requires `private::Sealed`, which lives in a non-public
/// module, so it cannot be implemented outside of this crate.
pub trait Source: private::Sealed {
/// The name of the source; `RealtimeData::connect` uses it as the last
/// component of the stream URL path (`v2/<name>`).
#[doc(hidden)]
fn as_str() -> &'static str;
}
/// Marker type selecting the `iex` data source.
///
/// The enum has no variants and therefore cannot be instantiated; it
/// is only meant to be used as a type parameter, e.g.
/// `RealtimeData<IEX>`.
#[derive(Clone, Copy, Debug)]
pub enum IEX {}
impl Source for IEX {
/// Returns `"iex"`.
#[inline]
fn as_str() -> &'static str {
"iex"
}
}
impl private::Sealed for IEX {}
/// Marker type selecting the `sip` data source.
///
/// The enum has no variants and therefore cannot be instantiated; it
/// is only meant to be used as a type parameter, e.g.
/// `RealtimeData<SIP>`.
#[derive(Clone, Copy, Debug)]
pub enum SIP {}
impl Source for SIP {
/// Returns `"sip"`.
#[inline]
fn as_str() -> &'static str {
"sip"
}
}
impl private::Sealed for SIP {}
/// A symbol for which market data can be requested; an alias for the
/// crate's `Str` type.
pub type Symbol = Str;
fn is_normalized(symbols: &[Symbol]) -> bool {
#[inline]
fn check<'a>(last: &'a mut &'a Symbol) -> impl FnMut(&'a Symbol) -> bool + 'a {
move |curr| {
if let Some(Ordering::Greater) | None = PartialOrd::partial_cmp(last, &curr) {
return false
}
*last = curr;
true
}
}
let mut it = symbols.iter();
let mut last = match it.next() {
Some(e) => e,
None => return true,
};
it.all(check(&mut last))
}
fn normalize(symbols: Cow<'static, [Symbol]>) -> Cow<'static, [Symbol]> {
fn normalize_now(symbols: Cow<'static, [Symbol]>) -> Cow<'static, [Symbol]> {
let mut symbols = symbols.into_owned();
symbols.sort_by(|x, y| x.partial_cmp(y).unwrap());
symbols.dedup();
Cow::from(symbols)
}
if !is_normalized((*symbols).borrow()) {
let symbols = normalize_now(symbols);
debug_assert!(is_normalized(&symbols));
symbols
} else {
symbols
}
}
/// A bar for a symbol, as exchanged over the stream.
///
/// The `serde` renames give the (short) JSON keys used on the wire; see
/// the `serialize_deserialize_bar` test for a sample payload.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Bar {
/// The symbol the bar belongs to (JSON key `S`).
#[serde(rename = "S")]
pub symbol: String,
/// The bar's open price (JSON key `o`).
#[serde(rename = "o")]
pub open_price: Num,
/// The bar's high price (JSON key `h`).
#[serde(rename = "h")]
pub high_price: Num,
/// The bar's low price (JSON key `l`).
#[serde(rename = "l")]
pub low_price: Num,
/// The bar's close price (JSON key `c`).
#[serde(rename = "c")]
pub close_price: Num,
/// The bar's volume (JSON key `v`).
#[serde(rename = "v")]
pub volume: u64,
/// The bar's timestamp (JSON key `t`).
#[serde(rename = "t")]
pub timestamp: DateTime<Utc>,
}
/// A quote for a symbol, as exchanged over the stream.
///
/// The `serde` renames give the (short) JSON keys used on the wire; see
/// the `serialize_deserialize_quote` test for a sample payload. Keys of
/// that payload without a matching field here are not captured.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Quote {
/// The symbol the quote belongs to (JSON key `S`).
#[serde(rename = "S")]
pub symbol: String,
/// The bid price (JSON key `bp`).
#[serde(rename = "bp")]
pub bid_price: Num,
/// The bid size (JSON key `bs`).
#[serde(rename = "bs")]
pub bid_size: u64,
/// The ask price (JSON key `ap`).
#[serde(rename = "ap")]
pub ask_price: Num,
/// The ask size (JSON key `as`).
#[serde(rename = "as")]
pub ask_size: u64,
/// The quote's timestamp (JSON key `t`).
#[serde(rename = "t")]
pub timestamp: DateTime<Utc>,
}
/// A trade for a symbol, as exchanged over the stream.
///
/// The `serde` renames give the (short) JSON keys used on the wire; see
/// the `serialize_deserialize_trade` test for a sample payload. Keys of
/// that payload without a matching field here are not captured.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Trade {
/// The symbol the trade belongs to (JSON key `S`).
#[serde(rename = "S")]
pub symbol: String,
/// The trade's ID (JSON key `i`).
#[serde(rename = "i")]
pub trade_id: u64,
/// The trade's price (JSON key `p`).
#[serde(rename = "p")]
pub trade_price: Num,
/// The trade's size (JSON key `s`).
#[serde(rename = "s")]
pub trade_size: u64,
/// The trade's timestamp (JSON key `t`).
#[serde(rename = "t")]
pub timestamp: DateTime<Utc>,
}
/// An error as reported by the streaming API.
///
/// The `Display` representation has the form `<message> (<code>)`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize, ThisError)]
#[error("{message} ({code})")]
pub struct StreamApiError {
/// The numeric error code (JSON key `code`).
#[serde(rename = "code")]
pub code: u64,
/// The textual error message (JSON key `msg`).
#[serde(rename = "msg")]
pub message: String,
}
/// A single message as decoded from the stream.
///
/// The enum is internally tagged: the JSON key `T` selects the variant
/// and the remaining keys make up the variant's payload.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[doc(hidden)]
#[serde(tag = "T")]
#[allow(clippy::large_enum_variant)]
pub enum DataMessage {
/// A bar (`"T": "b"`).
#[serde(rename = "b")]
Bar(Bar),
/// A quote (`"T": "q"`).
#[serde(rename = "q")]
Quote(Quote),
/// A trade (`"T": "t"`).
#[serde(rename = "t")]
Trade(Trade),
/// A message describing the subscribed market data
/// (`"T": "subscription"`).
#[serde(rename = "subscription")]
Subscription(MarketData),
/// A success indication (`"T": "success"`); any accompanying keys
/// (such as `msg`) are not captured.
#[serde(rename = "success")]
Success,
/// An error indication (`"T": "error"`).
#[serde(rename = "error")]
Error(StreamApiError),
}
/// A market data item as delivered to users of the stream.
///
/// This is the subset of `DataMessage` variants that `classify` treats
/// as user messages. The enum is marked `#[non_exhaustive]`, so
/// matches in other crates need a wildcard arm.
#[derive(Debug)]
#[non_exhaustive]
pub enum Data {
/// A bar.
Bar(Bar),
/// A quote.
Quote(Quote),
/// A trade.
Trade(Trade),
}
impl Data {
  /// Report whether this object holds a [`Bar`].
  #[inline]
  pub fn is_bar(&self) -> bool {
    matches!(self, Self::Bar(_))
  }

  /// Report whether this object holds a [`Quote`].
  #[inline]
  pub fn is_quote(&self) -> bool {
    matches!(self, Self::Quote(_))
  }

  /// Report whether this object holds a [`Trade`].
  #[inline]
  pub fn is_trade(&self) -> bool {
    matches!(self, Self::Trade(_))
  }
}
/// A message that concerns the state of the connection or subscription
/// rather than market data.
///
/// This is the subset of `DataMessage` variants that `classify` treats
/// as control messages.
#[derive(Debug)]
#[doc(hidden)]
pub enum ControlMessage {
/// The currently active subscriptions, as reported by the server.
Subscription(MarketData),
/// An indication that the previous operation succeeded.
Success,
/// An error reported by the server.
Error(StreamApiError),
}
/// A message as it comes out of the parsing stage: the outer layer
/// carries WebSocket errors, the inner one JSON decoding errors.
type ParsedMessage = MessageResult<Result<DataMessage, JsonError>, WebSocketError>;
impl subscribe::Message for ParsedMessage {
type UserMessage = Result<Result<Data, JsonError>, WebSocketError>;
type ControlMessage = ControlMessage;
fn classify(self) -> subscribe::Classification<Self::UserMessage, Self::ControlMessage> {
match self {
MessageResult::Ok(Ok(message)) => match message {
DataMessage::Bar(bar) => subscribe::Classification::UserMessage(Ok(Ok(Data::Bar(bar)))),
DataMessage::Quote(quote) => {
subscribe::Classification::UserMessage(Ok(Ok(Data::Quote(quote))))
},
DataMessage::Trade(trade) => {
subscribe::Classification::UserMessage(Ok(Ok(Data::Trade(trade))))
},
DataMessage::Subscription(data) => {
subscribe::Classification::ControlMessage(ControlMessage::Subscription(data))
},
DataMessage::Success => subscribe::Classification::ControlMessage(ControlMessage::Success),
DataMessage::Error(error) => {
subscribe::Classification::ControlMessage(ControlMessage::Error(error))
},
},
MessageResult::Ok(Err(err)) => subscribe::Classification::UserMessage(Ok(Err(err))),
MessageResult::Err(err) => subscribe::Classification::UserMessage(Err(err)),
}
}
#[inline]
fn is_error(user_message: &Self::UserMessage) -> bool {
user_message
.as_ref()
.map(|result| result.is_err())
.unwrap_or(true)
}
}
#[inline]
fn normalized_from_str<'de, D>(deserializer: D) -> Result<Cow<'static, [Symbol]>, D::Error>
where
D: Deserializer<'de>,
{
Cow::<'static, [Symbol]>::deserialize(deserializer).map(normalize)
}
/// A list of symbols that is kept in normalized form.
///
/// The wrapped field is private; all construction paths in this file
/// (the `From` implementations and deserialization) pass the symbols
/// through `normalize`.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct SymbolList(#[serde(deserialize_with = "normalized_from_str")] Cow<'static, [Symbol]>);
impl Deref for SymbolList {
type Target = [Symbol];
/// Provide read-only access to the symbols as a slice.
fn deref(&self) -> &Self::Target {
self.0.borrow()
}
}
impl From<Cow<'static, [Symbol]>> for SymbolList {
/// Create a `SymbolList` from the given symbols, normalizing them via
/// `normalize` in the process.
#[inline]
fn from(symbols: Cow<'static, [Symbol]>) -> Self {
Self(normalize(symbols))
}
}
impl From<Vec<String>> for SymbolList {
#[inline]
fn from(symbols: Vec<String>) -> Self {
Self(normalize(Cow::from(
IntoIterator::into_iter(symbols)
.map(Symbol::from)
.collect::<Vec<_>>(),
)))
}
}
impl<const N: usize> From<[&'static str; N]> for SymbolList {
#[inline]
fn from(symbols: [&'static str; N]) -> Self {
Self(normalize(Cow::from(
IntoIterator::into_iter(symbols)
.map(Symbol::from)
.collect::<Vec<_>>(),
)))
}
}
// (De)serialization helpers for `Symbols::All`, which is represented
// as a sequence containing the single string `*`.
mod symbols_all {
  use super::*;
  use serde::de::Error;
  use serde::de::Unexpected;

  /// Deserialize the "all symbols" marker: a one-element sequence
  /// whose only element is the string `*`. Any other string is
  /// rejected with an "invalid value" error.
  pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<(), D::Error>
  where
    D: Deserializer<'de>,
  {
    let [symbol] = <[&str; 1]>::deserialize(deserializer)?;
    match symbol {
      "*" => Ok(()),
      other => Err(Error::invalid_value(
        Unexpected::Str(other),
        &"the string \"*\"",
      )),
    }
  }

  /// Serialize the "all symbols" marker as a sequence of length one
  /// holding the string `*`.
  pub(crate) fn serialize<S>(serializer: S) -> Result<S::Ok, S::Error>
  where
    S: Serializer,
  {
    let mut list = serializer.serialize_seq(Some(1))?;
    list.serialize_element("*")?;
    list.end()
  }
}
/// The set of symbols to work with: either all of them or an explicit
/// list.
///
/// The enum is untagged: `All` is (de)serialized via the `symbols_all`
/// module (as `["*"]`), `List` as a plain sequence of symbols.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(untagged)]
pub enum Symbols {
/// All symbols.
#[serde(with = "symbols_all")]
All,
/// An explicit, normalized list of symbols.
List(SymbolList),
}
impl Symbols {
  /// Check whether the object covers no symbols at all.
  ///
  /// That is only the case for a `List` without entries; `All` is
  /// never considered empty.
  #[inline]
  pub fn is_empty(&self) -> bool {
    matches!(self, Self::List(list) if list.is_empty())
  }
}
impl Default for Symbols {
  /// The default is an empty list of symbols.
  fn default() -> Self {
    let none: [&'static str; 0] = [];
    Self::List(SymbolList::from(none))
  }
}
/// A description of the market data to subscribe to or unsubscribe
/// from, and also of the subscriptions reported back by the server.
///
/// Each field falls back to its default (an empty list) when it is
/// absent during deserialization.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct MarketData {
/// The symbols for which bars are of interest.
#[serde(default)]
pub bars: Symbols,
/// The symbols for which quotes are of interest.
#[serde(default)]
pub quotes: Symbols,
/// The symbols for which trades are of interest.
#[serde(default)]
pub trades: Symbols,
}
impl MarketData {
  /// Replace the `bars` member with an explicit list made up of the
  /// given symbols.
  #[inline]
  pub fn set_bars<S>(&mut self, symbols: S)
  where
    S: Into<SymbolList>,
  {
    let list: SymbolList = symbols.into();
    self.bars = Symbols::List(list);
  }

  /// Replace the `quotes` member with an explicit list made up of the
  /// given symbols.
  #[inline]
  pub fn set_quotes<S>(&mut self, symbols: S)
  where
    S: Into<SymbolList>,
  {
    let list: SymbolList = symbols.into();
    self.quotes = Symbols::List(list);
  }

  /// Replace the `trades` member with an explicit list made up of the
  /// given symbols.
  #[inline]
  pub fn set_trades<S>(&mut self, symbols: S)
  where
    S: Into<SymbolList>,
  {
    let list: SymbolList = symbols.into();
    self.trades = Symbols::List(list);
  }
}
/// A request sent to the server.
///
/// The enum is internally tagged: the JSON key `action` selects the
/// variant (see the `serialize_deserialize_*_request` tests for the
/// resulting JSON).
#[derive(Debug, Deserialize, Eq, PartialEq, Serialize)]
#[doc(hidden)]
#[serde(tag = "action")]
pub enum Request<'d> {
/// A request to authenticate (`"action": "auth"`).
#[serde(rename = "auth")]
Authenticate {
/// The key ID to authenticate with (JSON key `key`).
#[serde(rename = "key")]
key_id: Cow<'d, str>,
/// The secret to authenticate with (JSON key `secret`).
#[serde(rename = "secret")]
secret: Cow<'d, str>,
},
/// A request to subscribe to the given market data
/// (`"action": "subscribe"`).
#[serde(rename = "subscribe")]
Subscribe(Cow<'d, MarketData>),
/// A request to unsubscribe from the given market data
/// (`"action": "unsubscribe"`).
#[serde(rename = "unsubscribe")]
Unsubscribe(Cow<'d, MarketData>),
}
/// A handle for adjusting the market data subscriptions of a realtime
/// data stream.
#[derive(Debug)]
pub struct Subscription<S> {
/// The underlying `websocket_util` subscription used for sending
/// requests and receiving the corresponding control messages.
subscription: subscribe::Subscription<S, ParsedMessage, wrap::Message>,
/// The subscriptions as most recently reported by the server; updated
/// by `subscribe_unsubscribe`.
subscriptions: MarketData,
}
impl<S> Subscription<S> {
  /// Wrap the given low-level subscription; the set of known
  /// subscriptions starts out empty.
  #[inline]
  fn new(subscription: subscribe::Subscription<S, ParsedMessage, wrap::Message>) -> Self {
    let subscriptions = MarketData::default();
    Self {
      subscription,
      subscriptions,
    }
  }
}
impl<S> Subscription<S>
where
S: Sink<wrap::Message> + Unpin,
{
async fn authenticate(
&mut self,
key_id: &str,
secret: &str,
) -> Result<Result<(), Error>, S::Error> {
let request = Request::Authenticate {
key_id: key_id.into(),
secret: secret.into(),
};
let json = match to_json(&request) {
Ok(json) => json,
Err(err) => return Ok(Err(Error::Json(err))),
};
let message = wrap::Message::Text(json);
let response = self.subscription.send(message).await?;
match response {
Some(response) => match response {
Ok(ControlMessage::Success) => Ok(Ok(())),
Ok(ControlMessage::Subscription(..)) => Ok(Err(Error::Str(
"server responded with unexpected subscription message".into(),
))),
Ok(ControlMessage::Error(error)) => Ok(Err(Error::Str(
format!(
"failed to authenticate with server: {} ({})",
error.message, error.code
)
.into(),
))),
Err(()) => Ok(Err(Error::Str("failed to authenticate with server".into()))),
},
None => Ok(Err(Error::Str(
"stream was closed before authorization message was received".into(),
))),
}
}
async fn subscribe_unsubscribe(
&mut self,
request: &Request<'_>,
) -> Result<Result<(), Error>, S::Error> {
let json = match to_json(request) {
Ok(json) => json,
Err(err) => return Ok(Err(Error::Json(err))),
};
let message = wrap::Message::Text(json);
let response = self.subscription.send(message).await?;
match response {
Some(response) => match response {
Ok(ControlMessage::Subscription(data)) => {
self.subscriptions = data;
Ok(Ok(()))
},
Ok(ControlMessage::Error(error)) => Ok(Err(Error::Str(
format!("failed to subscribe: {error}").into(),
))),
Ok(_) => Ok(Err(Error::Str(
"server responded with unexpected message".into(),
))),
Err(()) => Ok(Err(Error::Str("failed to adjust subscription".into()))),
},
None => Ok(Err(Error::Str(
"stream was closed before subscription confirmation message was received".into(),
))),
}
}
#[inline]
pub async fn subscribe(&mut self, subscribe: &MarketData) -> Result<Result<(), Error>, S::Error> {
let request = Request::Subscribe(Cow::Borrowed(subscribe));
self.subscribe_unsubscribe(&request).await
}
#[inline]
pub async fn unsubscribe(
&mut self,
unsubscribe: &MarketData,
) -> Result<Result<(), Error>, S::Error> {
let request = Request::Unsubscribe(Cow::Borrowed(unsubscribe));
self.subscribe_unsubscribe(&request).await
}
#[inline]
pub fn subscriptions(&self) -> &MarketData {
&self.subscriptions
}
}
/// The signature of the function decoding a raw WebSocket message into
/// a batch of `DataMessage` objects.
type ParseFn = fn(
Result<wrap::Message, WebSocketError>,
) -> Result<Result<Vec<DataMessage>, JsonError>, WebSocketError>;
/// The signature of the function converting a single decoded message
/// into a `ParsedMessage`.
type MapFn = fn(Result<Result<DataMessage, JsonError>, WebSocketError>) -> ParsedMessage;
/// The full type of the message stream as assembled in
/// `RealtimeData::connect`: the wrapped WebSocket stream, mapped
/// through `ParseFn`, passed through `Unfold`, and mapped through
/// `MapFn`.
type Stream = Map<
Unfold<Map<Wrapper<WebSocketStream<MaybeTlsStream<TcpStream>>>, ParseFn>, DataMessage, JsonError>,
MapFn,
>;
/// A type representing the realtime market data stream of source `S`
/// (e.g., `IEX` or `SIP`); it is used via its `Subscribable`
/// implementation.
#[derive(Debug)]
pub struct RealtimeData<S> {
/// Ties the otherwise unused source type parameter to the struct.
_phantom: PhantomData<S>,
}
#[async_trait]
impl<S> Subscribable for RealtimeData<S>
where
S: Source,
{
type Input = ApiInfo;
type Subscription = Subscription<SplitSink<Stream, wrap::Message>>;
type Stream = Fuse<MessageStream<SplitStream<Stream>, ParsedMessage>>;
/// Connect to the realtime data stream of source `S`, wait for the
/// server's initial success message, and authenticate.
///
/// On success the fused message stream and the `Subscription` handle
/// for adjusting subscriptions are returned.
async fn connect(api_info: &Self::Input) -> Result<(Self::Stream, Self::Subscription), Error> {
// Decode a raw WebSocket message (text or binary) into a batch of
// `DataMessage` objects; WebSocket errors are passed through.
fn parse(
result: Result<wrap::Message, WebSocketError>,
) -> Result<Result<Vec<DataMessage>, JsonError>, WebSocketError> {
result.map(|message| match message {
wrap::Message::Text(string) => json_from_str::<Vec<DataMessage>>(&string),
wrap::Message::Binary(data) => json_from_slice::<Vec<DataMessage>>(&data),
})
}
let ApiInfo {
data_stream_base_url: url,
key_id,
secret,
..
} = api_info;
// The source selects the endpoint: `v2/iex` or `v2/sip`.
let mut url = url.clone();
url.set_path(&format!("v2/{}", S::as_str()));
// NOTE(review): `Unfold` presumably turns each decoded batch into
// individual `DataMessage` items -- confirm against `super::unfold`.
let stream =
Unfold::new(connect(&url).await?.map(parse as ParseFn)).map(MessageResult::from as MapFn);
let (send, recv) = stream.split();
let (stream, subscription) = subscribe::subscribe(recv, send);
let mut stream = stream.fuse();
let mut subscription = Subscription::new(subscription);
// Read the initial control message; the stream has to be polled
// concurrently for it to arrive, hence the use of `drive`.
let connect = subscription.subscription.read().boxed().fuse();
let message = drive(connect, &mut stream).await.map_err(|result| {
// NOTE(review): the `unwrap_err` assumes that `drive` only reports
// user messages for which `is_error` holds -- confirm.
result
.map(|result| Error::Json(result.unwrap_err()))
.map_err(Error::WebSocket)
.unwrap_or_else(|err| err)
})?;
// Anything but a success indication is treated as an error.
match message {
Some(Ok(ControlMessage::Success)) => (),
Some(Ok(_)) => {
return Err(Error::Str(
"server responded with unexpected initial message".into(),
))
},
Some(Err(())) => return Err(Error::Str("failed to read connected message".into())),
None => {
return Err(Error::Str(
"stream was closed before connected message was received".into(),
))
},
}
let authenticate = subscription.authenticate(key_id, secret).boxed().fuse();
// The three `?` operators unwrap, in order: the result of `drive`,
// the sink-level result, and the authentication result proper.
let () = drive(authenticate, &mut stream).await.map_err(|result| {
result
.map(|result| Error::Json(result.unwrap_err()))
.map_err(Error::WebSocket)
.unwrap_or_else(|err| err)
})???;
Ok((stream, subscription))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
use std::time::Duration;
use chrono::DateTime;
use futures::SinkExt as _;
use futures::TryStreamExt as _;
use serial_test::serial;
use serde_json::from_str as json_from_str;
use test_log::test;
use tokio::time::timeout;
use websocket_util::test::WebSocketStream;
use websocket_util::tungstenite::Message;
use crate::api::API_BASE_URL;
use crate::websocket::test::mock_stream;
use crate::Client;
/// The message the mock server sends right after a connection got
/// established.
const CONN_RESP: &str = r#"[{"T":"success","msg":"connected"}]"#;
/// The authentication request the mock server expects to receive.
const AUTH_REQ: &str = r#"{"action":"auth","key":"USER12345678","secret":"justletmein"}"#;
/// The mock server's response to a successful authentication.
const AUTH_RESP: &str = r#"[{"T":"success","msg":"authenticated"}]"#;
/// A subscribe request for bars of two symbols.
const SUB_REQ: &str = r#"{"action":"subscribe","bars":["AAPL","VOO"],"quotes":[],"trades":[]}"#;
/// The mock server's confirmation of `SUB_REQ`.
const SUB_RESP: &str = r#"[{"T":"subscription","bars":["AAPL","VOO"]}]"#;
/// A subscribe request without any symbols, which the mock server
/// answers with `SUB_ERR_RESP`.
const SUB_ERR_REQ: &str = r#"{"action":"subscribe","bars":[],"quotes":[],"trades":[]}"#;
/// The mock server's error response to `SUB_ERR_REQ`.
const SUB_ERR_RESP: &str = r#"[{"T":"error","code":400,"msg":"invalid syntax"}]"#;
/// Check that `Symbols::is_empty` reports `true` only for an empty
/// list.
#[test]
fn symbols_is_empty() {
assert!(!Symbols::All.is_empty());
assert!(!Symbols::List(SymbolList::from(["SPY"])).is_empty());
assert!(Symbols::List(SymbolList::from([])).is_empty());
}
/// Check that a `DataMessage::Bar` can be deserialized from JSON and
/// survives a serialization round trip.
#[test]
fn serialize_deserialize_bar() {
let json = r#"{
"T": "b",
"S": "SPY",
"o": 388.985,
"h": 389.13,
"l": 388.975,
"c": 389.12,
"v": 49378,
"t": "2021-02-22T19:15:00Z"
}"#;
let message = json_from_str::<DataMessage>(json).unwrap();
let bar = match &message {
DataMessage::Bar(bar) => bar,
_ => panic!("Decoded unexpected message variant: {message:?}"),
};
assert_eq!(bar.symbol, "SPY");
assert_eq!(bar.open_price, Num::new(388985, 1000));
assert_eq!(bar.high_price, Num::new(38913, 100));
assert_eq!(bar.low_price, Num::new(388975, 1000));
assert_eq!(bar.close_price, Num::new(38912, 100));
assert_eq!(bar.volume, 49378);
assert_eq!(
bar.timestamp,
DateTime::<Utc>::from_str("2021-02-22T19:15:00Z").unwrap()
);
// Serializing and deserializing again has to yield an equal message.
assert_eq!(
json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
message
);
}
/// Check that a `DataMessage::Quote` can be deserialized from JSON and
/// survives a serialization round trip.
#[test]
fn serialize_deserialize_quote() {
  // The payload contains keys (`bx`, `ax`, `c`, `z`) that `Quote` has
  // no fields for.
  let input = r#"{
"T": "q",
"S": "NVDA",
"bx": "P",
"bp": 258.8,
"bs": 2,
"ax": "A",
"ap": 259.99,
"as": 5,
"c": [
"R"
],
"z": "C",
"t": "2022-01-18T23:09:42.151875584Z"
}"#;

  let message = json_from_str::<DataMessage>(input).unwrap();
  let quote = match &message {
    DataMessage::Quote(quote) => quote,
    _ => panic!("Decoded unexpected message variant: {message:?}"),
  };

  let expected_time = DateTime::<Utc>::from_str("2022-01-18T23:09:42.151875584Z").unwrap();
  assert_eq!(quote.symbol, "NVDA");
  assert_eq!(quote.bid_price, Num::new(2588, 10));
  assert_eq!(quote.bid_size, 2);
  assert_eq!(quote.ask_price, Num::new(25999, 100));
  assert_eq!(quote.ask_size, 5);
  assert_eq!(quote.timestamp, expected_time);

  // Serializing and deserializing again has to yield an equal message.
  let serialized = to_json(&message).unwrap();
  let round_tripped = json_from_str::<DataMessage>(&serialized).unwrap();
  assert_eq!(round_tripped, message);
}
/// Check that a `DataMessage::Trade` can be deserialized from JSON and
/// survives a serialization round trip.
#[test]
fn serialize_deserialize_trade() {
let json: &str = r#"{
"T": "t",
"i": 96921,
"S": "AAPL",
"x": "D",
"p": 126.55,
"s": 1,
"t": "2021-02-22T15:51:44.208Z",
"c": ["@", "I"],
"z": "C"
}"#;
let message = json_from_str::<DataMessage>(json).unwrap();
let trade = match &message {
DataMessage::Trade(trade) => trade,
_ => panic!("Decoded unexpected message variant: {message:?}"),
};
assert_eq!(trade.symbol, "AAPL");
assert_eq!(trade.trade_id, 96921);
assert_eq!(trade.trade_price, Num::new(12655, 100));
assert_eq!(trade.trade_size, 1);
assert_eq!(
trade.timestamp,
DateTime::<Utc>::from_str("2021-02-22T15:51:44.208Z").unwrap()
);
// Serializing and deserializing again has to yield an equal message.
assert_eq!(
json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
message
);
}
/// Check that a `DataMessage::Success` can be deserialized from JSON
/// (with the `msg` key being ignored) and survives a serialization
/// round trip.
#[test]
fn serialize_deserialize_success() {
let json = r#"{"T":"success","msg":"authenticated"}"#;
let message = json_from_str::<DataMessage>(json).unwrap();
let () = match message {
DataMessage::Success => (),
_ => panic!("Decoded unexpected message variant: {message:?}"),
};
assert_eq!(
json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
message
);
}
/// Check that `DataMessage::Error` objects can be deserialized from
/// JSON and survive a serialization round trip, using two different
/// error payloads.
#[test]
fn serialize_deserialize_error() {
let json = r#"{"T":"error","code":400,"msg":"invalid syntax"}"#;
let message = json_from_str::<DataMessage>(json).unwrap();
let error = match &message {
DataMessage::Error(error) => error,
_ => panic!("Decoded unexpected message variant: {message:?}"),
};
assert_eq!(error.code, 400);
assert_eq!(error.message, "invalid syntax");
assert_eq!(
json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
message
);
// Same checks for a second error code/message pair.
let json = r#"{"T":"error","code":500,"msg":"internal error"}"#;
let message = json_from_str::<DataMessage>(json).unwrap();
let error = match &message {
DataMessage::Error(error) => error,
_ => panic!("Decoded unexpected message variant: {message:?}"),
};
assert_eq!(error.code, 500);
assert_eq!(error.message, "internal error");
assert_eq!(
json_from_str::<DataMessage>(&to_json(&message).unwrap()).unwrap(),
message
);
}
/// Check the exact JSON produced for an authentication request and
/// that it deserializes back into an equal request.
#[test]
fn serialize_deserialize_authentication_request() {
let request = Request::Authenticate {
key_id: "KEY-ID".into(),
secret: "SECRET-KEY".into(),
};
let json = to_json(&request).unwrap();
let expected = r#"{"action":"auth","key":"KEY-ID","secret":"SECRET-KEY"}"#;
assert_eq!(json, expected);
assert_eq!(json_from_str::<Request<'_>>(&json).unwrap(), request);
}
/// Check the exact JSON produced for a subscribe request and that it
/// deserializes back into an equal request.
#[test]
fn serialize_deserialize_subscribe_request() {
let mut data = MarketData::default();
data.set_bars(["AAPL", "VOO"]);
let request = Request::Subscribe(Cow::Borrowed(&data));
let json = to_json(&request).unwrap();
let expected = r#"{"action":"subscribe","bars":["AAPL","VOO"],"quotes":[],"trades":[]}"#;
assert_eq!(json, expected);
assert_eq!(json_from_str::<Request<'_>>(&json).unwrap(), request);
}
/// Check the exact JSON produced for an unsubscribe request and that
/// it deserializes back into an equal request.
#[test]
fn serialize_deserialize_unsubscribe_request() {
let mut data = MarketData::default();
data.set_bars(["VOO"]);
let request = Request::Unsubscribe(Cow::Borrowed(&data));
let json = to_json(&request).unwrap();
let expected = r#"{"action":"unsubscribe","bars":["VOO"],"quotes":[],"trades":[]}"#;
assert_eq!(json, expected);
assert_eq!(json_from_str::<Request<'_>>(&json).unwrap(), request);
}
/// Check that a `SymbolList` is normalized (here: sorted) as part of
/// deserialization.
#[test]
fn deserialize_symbol_list() {
let json = r#"["AAPL","XLK","SPY"]"#;
let list = json_from_str::<SymbolList>(json).unwrap();
let expected = SymbolList::from(["AAPL", "SPY", "XLK"]);
assert_eq!(list, expected);
}
/// Check `is_normalized` and `normalize` on empty, sorted, unsorted,
/// and duplicate-containing inputs.
#[test]
fn normalize_subscriptions() {
let subscriptions = [];
assert!(is_normalized(&subscriptions));
let subscriptions = ["MSFT".into(), "SPY".into()];
assert!(is_normalized(&subscriptions));
// Unsorted input gets sorted.
let mut subscriptions = Cow::from(vec!["SPY".into(), "MSFT".into()]);
assert!(!is_normalized(&subscriptions));
subscriptions = normalize(subscriptions);
assert!(is_normalized(&subscriptions));
let expected = [Cow::from("MSFT"), "SPY".into()];
assert_eq!(subscriptions.borrow(), expected);
// Unsorted input with a duplicate gets sorted and deduplicated.
let mut subscriptions = Cow::from(vec!["SPY".into(), "MSFT".into(), "MSFT".into()]);
assert!(!is_normalized(&subscriptions));
subscriptions = normalize(subscriptions);
assert!(is_normalized(&subscriptions));
assert_eq!(subscriptions.borrow(), expected);
}
/// Check that we can connect to, authenticate with, and subscribe at a
/// mock server speaking the expected protocol.
#[test(tokio::test)]
async fn authenticate_and_subscribe() {
// The server side of the exchange: greet, expect and confirm the
// authentication, expect and confirm the subscription, then close.
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
stream.send(Message::Text(CONN_RESP.to_string())).await?;
assert_eq!(
stream.next().await.unwrap()?,
Message::Text(AUTH_REQ.to_string()),
);
stream.send(Message::Text(AUTH_RESP.to_string())).await?;
assert_eq!(
stream.next().await.unwrap()?,
Message::Text(SUB_REQ.to_string()),
);
stream.send(Message::Text(SUB_RESP.to_string())).await?;
stream.send(Message::Close(None)).await?;
Ok(())
}
let (mut stream, mut subscription) =
mock_stream::<RealtimeData<IEX>, _, _>(test).await.unwrap();
let mut data = MarketData::default();
data.set_bars(["AAPL", "VOO"]);
let subscribe = subscription.subscribe(&data).boxed_local().fuse();
let () = drive(subscribe, &mut stream)
.await
.unwrap()
.unwrap()
.unwrap();
// Drain the remainder of the stream; no error is expected.
stream
.map_err(Error::WebSocket)
.try_for_each(|result| async { result.map(|_data| ()).map_err(Error::Json) })
.await
.unwrap();
}
/// Check that an error response to a subscribe request is reported
/// with the server's message and code.
#[test(tokio::test)]
async fn subscribe_error() {
// The server side of the exchange: greet, expect and confirm the
// authentication, then answer the subscription with an error.
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
stream.send(Message::Text(CONN_RESP.to_string())).await?;
assert_eq!(
stream.next().await.unwrap()?,
Message::Text(AUTH_REQ.to_string()),
);
stream.send(Message::Text(AUTH_RESP.to_string())).await?;
assert_eq!(
stream.next().await.unwrap()?,
Message::Text(SUB_ERR_REQ.to_string()),
);
stream.send(Message::Text(SUB_ERR_RESP.to_string())).await?;
stream.send(Message::Close(None)).await?;
Ok(())
}
let (mut stream, mut subscription) =
mock_stream::<RealtimeData<IEX>, _, _>(test).await.unwrap();
let data = MarketData::default();
let subscribe = subscription.subscribe(&data).boxed_local().fuse();
let error = drive(subscribe, &mut stream)
.await
.unwrap()
.unwrap()
.unwrap_err();
match error {
Error::Str(ref e) if e == "failed to subscribe: invalid syntax (400)" => {},
e => panic!("received unexpected error: {e}"),
}
}
/// Check that a second subscription adds to, rather than replaces, the
/// first one.
// NOTE(review): `ApiInfo::from_env` suggests this test needs API
// credentials in the environment and talks to the real service --
// confirm.
#[test(tokio::test)]
#[serial(realtime_data)]
async fn subscribe_resubscribe() {
let api_info = ApiInfo::from_env().unwrap();
let client = Client::new(api_info);
let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
let mut data = MarketData::default();
data.set_bars(["AAPL", "SPY"]);
let subscribe = subscription.subscribe(&data).boxed_local().fuse();
let () = drive(subscribe, &mut stream)
.await
.unwrap()
.unwrap()
.unwrap();
assert_eq!(subscription.subscriptions(), &data);
let mut data = MarketData::default();
data.set_bars(["XLK"]);
let subscribe = subscription.subscribe(&data).boxed_local().fuse();
let () = drive(subscribe, &mut stream)
.await
.unwrap()
.unwrap()
.unwrap();
// The reported subscriptions are the union of both requests.
let mut expected = MarketData::default();
expected.set_bars(["AAPL", "SPY", "XLK"]);
assert_eq!(subscription.subscriptions(), &expected);
}
/// Check that we can subscribe to bars for all symbols and that only
/// bars are received afterwards.
// NOTE(review): `ApiInfo::from_env` suggests this test needs API
// credentials in the environment and talks to the real service --
// confirm.
#[test(tokio::test)]
#[serial(realtime_data)]
async fn stream_market_data_updates() {
let api_info = ApiInfo::from_env().unwrap();
let client = Client::new(api_info);
let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
let data = MarketData {
bars: Symbols::All,
..Default::default()
};
let subscribe = subscription.subscribe(&data).boxed_local().fuse();
let () = drive(subscribe, &mut stream)
.await
.unwrap()
.unwrap()
.unwrap();
assert_eq!(subscription.subscriptions(), &data);
let read = stream
.map_err(Error::WebSocket)
.try_for_each(|result| async {
result
.map(|data| {
assert!(data.is_bar());
})
.map_err(Error::Json)
});
// The stream is expected to stay open; `read` completing within the
// timeout (stream end or error) is considered a failure.
if timeout(Duration::from_millis(100), read).await.is_ok() {
panic!("realtime data stream got exhausted unexpectedly")
}
}
/// Check that after subscribing to quotes only quotes are received.
// NOTE(review): `ApiInfo::from_env` suggests this test needs API
// credentials in the environment and talks to the real service --
// confirm.
#[test(tokio::test)]
#[serial(realtime_data)]
async fn stream_quotes() {
let api_info = ApiInfo::from_env().unwrap();
let client = Client::new(api_info);
let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
let mut data = MarketData::default();
data.set_quotes(["SPY"]);
let subscribe = subscription.subscribe(&data).boxed_local().fuse();
let () = drive(subscribe, &mut stream)
.await
.unwrap()
.unwrap()
.unwrap();
let read = stream
.map_err(Error::WebSocket)
.try_for_each(|result| async {
result
.map(|data| {
assert!(data.is_quote());
})
.map_err(Error::Json)
});
// The stream is expected to stay open; `read` completing within the
// timeout (stream end or error) is considered a failure.
if timeout(Duration::from_millis(100), read).await.is_ok() {
panic!("realtime data stream got exhausted unexpectedly")
}
}
/// Check that after subscribing to trades only trades are received.
// NOTE(review): `ApiInfo::from_env` suggests this test needs API
// credentials in the environment and talks to the real service --
// confirm.
#[test(tokio::test)]
#[serial(realtime_data)]
async fn stream_trades() {
let api_info = ApiInfo::from_env().unwrap();
let client = Client::new(api_info);
let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
let mut data = MarketData::default();
data.set_trades(["SPY"]);
let subscribe = subscription.subscribe(&data).boxed_local().fuse();
let () = drive(subscribe, &mut stream)
.await
.unwrap()
.unwrap()
.unwrap();
let read = stream
.map_err(Error::WebSocket)
.try_for_each(|result| async {
result
.map(|data| {
assert!(data.is_trade());
})
.map_err(Error::Json)
});
// The stream is expected to stay open; `read` completing within the
// timeout (stream end or error) is considered a failure.
if timeout(Duration::from_millis(100), read).await.is_ok() {
panic!("realtime data stream got exhausted unexpectedly")
}
}
/// Check that unsubscribing from a symbol that was never subscribed to
/// succeeds and leaves the subscriptions empty.
// NOTE(review): `ApiInfo::from_env` suggests this test needs API
// credentials in the environment and talks to the real service --
// confirm.
#[test(tokio::test)]
#[serial(realtime_data)]
async fn unsubscribe_not_subscribed_symbol() {
let api_info = ApiInfo::from_env().unwrap();
let client = Client::new(api_info);
let (mut stream, mut subscription) = client.subscribe::<RealtimeData<IEX>>().await.unwrap();
let mut data = MarketData::default();
data.set_bars(["AAPL"]);
let unsubscribe = subscription.unsubscribe(&data).boxed_local().fuse();
let () = drive(unsubscribe, &mut stream)
.await
.unwrap()
.unwrap()
.unwrap();
assert_eq!(subscription.subscriptions(), &MarketData::default());
}
/// Check that connecting with invalid credentials fails with an
/// authentication error.
#[test(tokio::test)]
#[serial(realtime_data)]
async fn stream_with_invalid_credentials() {
let api_info = ApiInfo::from_parts(API_BASE_URL, "invalid", "invalid-too").unwrap();
let client = Client::new(api_info);
let err = client.subscribe::<RealtimeData<IEX>>().await.unwrap_err();
// Only the prefix is checked; the remainder of the message may carry
// the server-provided details.
match err {
Error::Str(ref e) if e.starts_with("failed to authenticate with server") => (),
e => panic!("received unexpected error: {e}"),
}
}
}