use chrono::serde::ts_milliseconds::deserialize as datetime_from_timestamp;
use chrono::DateTime;
use chrono::Utc;
use futures::stream::unfold;
use futures::Stream;
use futures::StreamExt;
use num_decimal::Num;
use serde::Deserialize;
use serde_json::from_slice as from_json_slice;
use serde_json::from_str as from_json_str;
use serde_json::Error as JsonError;
use tracing::debug;
use tracing::trace;
use tungstenite::connect_async;
use websocket_util::tungstenite::Error as WebSocketError;
use websocket_util::wrap::Message as WebSocketMessage;
use websocket_util::wrap::Wrapper;
use crate::api_info::ApiInfo;
use crate::error::Error;
use crate::events::handshake::handshake;
use crate::events::subscription::Subscription;
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct Trade {
#[serde(rename = "sym")]
pub symbol: String,
#[serde(rename = "x")]
pub exchange: u64,
#[serde(rename = "p")]
pub price: Num,
#[serde(rename = "s")]
pub quantity: u64,
#[serde(rename = "t", deserialize_with = "datetime_from_timestamp")]
pub timestamp: DateTime<Utc>,
}
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct Quote {
#[serde(rename = "sym")]
pub symbol: String,
#[serde(rename = "bx")]
pub bid_exchange: u64,
#[serde(rename = "bp")]
pub bid_price: Num,
#[serde(rename = "bs")]
pub bid_quantity: u64,
#[serde(rename = "ax")]
pub ask_exchange: u64,
#[serde(rename = "ap")]
pub ask_price: Num,
#[serde(rename = "as")]
pub ask_quantity: u64,
#[serde(rename = "t", deserialize_with = "datetime_from_timestamp")]
pub timestamp: DateTime<Utc>,
}
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct Aggregate {
#[serde(rename = "sym")]
pub symbol: String,
#[serde(rename = "v")]
pub volume: u64,
#[serde(rename = "vw")]
pub volume_weighted_average_price: Num,
#[serde(rename = "o")]
pub open_price: Num,
#[serde(rename = "c")]
pub close_price: Num,
#[serde(rename = "h")]
pub high_price: Num,
#[serde(rename = "l")]
pub low_price: Num,
#[serde(rename = "s", deserialize_with = "datetime_from_timestamp")]
pub start_timestamp: DateTime<Utc>,
#[serde(rename = "e", deserialize_with = "datetime_from_timestamp")]
pub end_timestamp: DateTime<Utc>,
}
#[derive(Copy, Clone, Debug, Deserialize, PartialEq)]
pub(crate) enum Code {
#[serde(rename = "connected")]
Connected,
#[serde(rename = "disconnected")]
Disconnected,
#[serde(rename = "auth_success")]
AuthSuccess,
#[serde(rename = "auth_failed")]
AuthFailure,
#[serde(rename = "success")]
Success,
}
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub(crate) struct Status {
#[serde(rename = "status")]
pub code: Code,
#[serde(rename = "message")]
pub message: String,
}
#[derive(Clone, Debug, Deserialize, PartialEq)]
#[allow(clippy::large_enum_variant)]
#[serde(tag = "ev")]
pub(crate) enum Message {
#[serde(rename = "status")]
Status(Status),
#[serde(rename = "A")]
SecondAggregate(Aggregate),
#[serde(rename = "AM")]
MinuteAggregate(Aggregate),
#[serde(rename = "T")]
Trade(Trade),
#[serde(rename = "Q")]
Quote(Quote),
}
#[cfg(test)]
impl Message {
pub fn into_status(self) -> Option<Status> {
match self {
Message::Status(status) => Some(status),
_ => None,
}
}
}
pub(crate) type Messages = Vec<Message>;
#[derive(Clone, Debug, Deserialize, PartialEq)]
#[allow(clippy::large_enum_variant)]
#[serde(tag = "ev")]
pub enum Event {
#[serde(rename = "A")]
SecondAggregate(Aggregate),
#[serde(rename = "AM")]
MinuteAggregate(Aggregate),
#[serde(rename = "T")]
Trade(Trade),
#[serde(rename = "Q")]
Quote(Quote),
}
impl Event {
pub fn symbol(&self) -> &str {
match self {
Event::SecondAggregate(aggregate) | Event::MinuteAggregate(aggregate) => &aggregate.symbol,
Event::Trade(trade) => &trade.symbol,
Event::Quote(quote) => "e.symbol,
}
}
#[cfg(test)]
fn to_trade(&self) -> Option<&Trade> {
match self {
Event::Trade(trade) => Some(trade),
_ => None,
}
}
#[cfg(test)]
fn to_quote(&self) -> Option<&Quote> {
match self {
Event::Quote(quote) => Some(quote),
_ => None,
}
}
}
fn process_message(message: Message) -> Option<Result<Event, WebSocketError>> {
let event = match message {
Message::Status(status) => {
if status.code == Code::Disconnected {
return Some(Err(WebSocketError::AlreadyClosed))
} else {
return None
}
},
Message::SecondAggregate(aggregate) => Event::SecondAggregate(aggregate),
Message::MinuteAggregate(aggregate) => Event::MinuteAggregate(aggregate),
Message::Trade(trade) => Event::Trade(trade),
Message::Quote(quote) => Event::Quote(quote),
};
Some(Ok(event))
}
async fn handle_msg<S>(
stop: &mut bool,
stream: &mut S,
messages: &mut Vec<Message>,
) -> Option<Result<Result<Event, JsonError>, WebSocketError>>
where
S: Stream<Item = Result<Result<Vec<Message>, JsonError>, WebSocketError>> + Unpin,
{
if *stop {
None
} else {
let result = loop {
match messages.pop() {
Some(message) => {
let result = process_message(message);
match result {
Some(result) => {
if result.is_err() {
*stop = true;
}
break result.map(Ok)
},
None => continue,
}
},
None => {
let next_msg = StreamExt::next(stream).await;
if let Some(result) = next_msg {
match result {
Ok(result) => match result {
Ok(new) => {
*messages = new;
continue
},
Err(err) => break Ok(Err(err)),
},
Err(err) => break Err(err),
}
} else {
return None
}
},
};
};
Some(result)
}
}
#[allow(clippy::cognitive_complexity)]
pub async fn stream<S>(
api_info: ApiInfo,
subscriptions: S,
) -> Result<impl Stream<Item = Result<Result<Event, JsonError>, WebSocketError>>, Error>
where
S: IntoIterator<Item = Subscription>,
{
let ApiInfo {
stream_url: url,
api_key,
..
} = api_info;
debug!(message = "connecting", url = display(&url));
let (mut stream, response) = connect_async(url).await?;
debug!("connection successful");
trace!(response = debug(&response));
handshake(&mut stream, api_key, subscriptions).await?;
debug!("subscription successful");
let stream = Wrapper::builder().build(stream).map(|result| {
result.map(|message| match message {
WebSocketMessage::Text(string) => from_json_str::<Messages>(&string),
WebSocketMessage::Binary(data) => from_json_slice::<Messages>(&data),
})
});
let stream = Box::pin(stream);
let stream = unfold(
(false, (stream, Vec::new())),
|(mut stop, (mut stream, mut messages))| async move {
let result = handle_msg(&mut stop, &mut stream, &mut messages).await;
result.map(|result| (result, (stop, (stream, messages))))
},
);
Ok(stream)
}
#[cfg(test)]
mod tests {
use super::*;
use std::future::Future;
use futures::future::ready;
use futures::SinkExt;
use futures::StreamExt;
use futures::TryStreamExt;
use serde_json::from_str as from_json;
use test_log::test;
use tungstenite::tungstenite::Message as WebSocketMessage;
use url::Url;
use websocket_util::test::mock_server;
use websocket_util::test::WebSocketStream;
use crate::events::subscription::Stock;
const API_KEY: &str = "USER12345678";
const CONNECTED_MSG: &str =
r#"[{"ev":"status","status":"connected","message":"Connected Successfully"}]"#;
const DISCONNECTED_MSG: &str =
r#"[{"ev":"status","status":"disconnected","message":"Reason: Max connections reached"}]"#;
const AUTH_REQ: &str = r#"{"action":"auth","params":"USER12345678"}"#;
const AUTH_RESP: &str = r#"[{"ev":"status","status":"auth_success","message":"authenticated"}]"#;
const SUB_REQ: &str = r#"{"action":"subscribe","params":"T.MSFT,Q.*"}"#;
const SUB_RESP: &str = {
r#"[
{"ev":"status","status":"success","message":"subscribed to: T.MSFT"},
{"ev":"status","status":"success","message":"subscribed to: Q.*"}]
"#
};
const MSFT_TRADE_MSG: &str = {
r#"[{"ev":"T","sym":"MSFT","i":8310,"x":4,"p":156.9799,"s":3,"c":[37],"t":1577818283019,"z":3}]"#
};
const UFO_QUOTE_MSG: &str = {
r#"[
{"ev":"Q","sym":"UFO","c":1,"bx":8,"ax":12,"bp":26.4,"ap":26.47,"bs":1,"as":3,"t":1577818659363,"z":3},
{"ev":"Q","sym":"UFO","c":1,"bx":8,"ax":12,"bp":26.4,"ap":26.47,"bs":1,"as":11,"t":1577818659365,"z":3}
]"#
};
async fn mock_stream<F, R, S>(
f: F,
subscriptions: S,
) -> Result<impl Stream<Item = Result<Result<Event, JsonError>, WebSocketError>>, Error>
where
F: FnOnce(WebSocketStream) -> R + Send + Sync + 'static,
R: Future<Output = Result<(), WebSocketError>> + Send + Sync + 'static,
S: IntoIterator<Item = Subscription>,
{
let addr = mock_server(f).await;
let api_info = ApiInfo {
api_url: Url::parse("http://example.com").unwrap(),
stream_url: Url::parse(&format!("ws://{}", addr)).unwrap(),
api_key: API_KEY.to_string(),
};
stream(api_info, subscriptions).await
}
#[test]
fn deserialize_trade() {
let response = r#"{
"ev": "T",
"sym": "SPY",
"i": 436698869,
"x": 19,
"p": 293.67,
"s": 100,
"c": [],
"t": 1583527402638,
"z": 2
}"#;
let trade = from_json::<Trade>(response).unwrap();
assert_eq!(trade.symbol, "SPY");
assert_eq!(trade.exchange, 19);
assert_eq!(trade.price, Num::new(29367, 100));
assert_eq!(trade.quantity, 100);
assert_eq!(
trade.timestamp,
DateTime::parse_from_rfc3339("2020-03-06T15:43:22.638-05:00").unwrap()
);
}
#[test]
fn deserialize_quote() {
let response = r#"{
"ev": "Q",
"sym": "SPY",
"c": 0,
"bx": 12,
"ax": 11,
"bp": 294.31,
"ap": 294.33,
"bs": 1,
"as": 2,
"t": 1583527004684,
"z": 2
}"#;
let quote = from_json::<Quote>(response).unwrap();
assert_eq!(quote.symbol, "SPY");
assert_eq!(quote.bid_exchange, 12);
assert_eq!(quote.bid_price, Num::new(29431, 100));
assert_eq!(quote.bid_quantity, 1);
assert_eq!(quote.ask_exchange, 11);
assert_eq!(quote.ask_price, Num::new(29433, 100));
assert_eq!(quote.ask_quantity, 2);
assert_eq!(
quote.timestamp,
DateTime::parse_from_rfc3339("2020-03-06T15:36:44.684-05:00").unwrap()
);
}
#[test]
fn deserialize_aggregate() {
let response = r#"{
"ev": "A",
"sym": "SPY",
"v": 2287,
"av": 163569633,
"op": 298.71,
"vw": 294.6301,
"o": 293.79,
"c": 293.68,
"h": 293.8,
"l": 293.68,
"a": 293.7442,
"s": 1583527401000,
"e": 1583527402000
}"#;
let aggregate = from_json::<Aggregate>(response).unwrap();
assert_eq!(aggregate.symbol, "SPY");
assert_eq!(aggregate.volume, 2287);
assert_eq!(
aggregate.volume_weighted_average_price,
Num::new(2_946_301, 10000),
);
assert_eq!(aggregate.open_price, Num::new(29379, 100));
assert_eq!(aggregate.close_price, Num::new(29368, 100));
assert_eq!(aggregate.high_price, Num::new(2938, 10));
assert_eq!(aggregate.low_price, Num::new(29368, 100));
assert_eq!(
aggregate.start_timestamp,
DateTime::parse_from_rfc3339("2020-03-06T15:43:21-05:00").unwrap()
);
assert_eq!(
aggregate.end_timestamp,
DateTime::parse_from_rfc3339("2020-03-06T15:43:22-05:00").unwrap()
);
}
#[test]
fn parse_event() {
let response = r#"{
"ev": "AM",
"sym": "MSFT",
"v": 10204,
"av": 200304,
"op": 114.04,
"vw": 114.4040,
"o": 114.11,
"c": 114.14,
"h": 114.19,
"l": 114.09,
"a": 114.1314,
"s": 1536036818784,
"e": 1536036818784
}"#;
let event = from_json::<Event>(response).unwrap();
match event {
Event::MinuteAggregate(aggregate) => {
assert_eq!(aggregate.symbol, "MSFT");
},
_ => panic!("unexpected event: {:?}", event),
}
}
#[test]
fn parse_events() {
let response = r#"[
{"ev":"Q","sym":"XLE","c":0,"bx":11,"ax":12,"bp":59.88,
"ap":59.89,"bs":28,"as":67,"t":1577724127207,"z":2},
{"ev":"Q","sym":"AAPL","c":0,"bx":11,"ax":12,"bp":59.88,
"ap":59.89,"bs":28,"as":65,"t":1577724127207,"z":2}
]"#;
let messages = from_json::<Messages>(response).unwrap();
assert_eq!(messages.len(), 2);
match &messages[0] {
Message::Quote(Quote { symbol, .. }) if symbol == "XLE" => (),
e => panic!("unexpected event: {:?}", e),
}
match &messages[1] {
Message::Quote(Quote { symbol, .. }) if symbol == "AAPL" => (),
e => panic!("unexpected event: {:?}", e),
}
}
#[test(tokio::test)]
async fn stream_msft() {
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
stream
.send(WebSocketMessage::Text(CONNECTED_MSG.to_string()))
.await?;
assert_eq!(
stream.next().await.unwrap()?,
WebSocketMessage::Text(AUTH_REQ.to_string()),
);
stream
.send(WebSocketMessage::Text(AUTH_RESP.to_string()))
.await?;
assert_eq!(
stream.next().await.unwrap()?,
WebSocketMessage::Text(SUB_REQ.to_string()),
);
stream
.send(WebSocketMessage::Text(SUB_RESP.to_string()))
.await?;
stream
.send(WebSocketMessage::Text(MSFT_TRADE_MSG.to_string()))
.await?;
stream
.send(WebSocketMessage::Text(UFO_QUOTE_MSG.to_string()))
.await?;
stream.send(WebSocketMessage::Close(None)).await?;
Ok(())
}
let subscriptions = vec![
Subscription::Trades(Stock::Symbol("MSFT".into())),
Subscription::Quotes(Stock::All),
];
let mut stream = Box::pin(mock_stream(test, subscriptions).await.unwrap());
let trade = stream.next().await.unwrap().unwrap().unwrap();
assert_eq!(trade.to_trade().unwrap().symbol, "MSFT");
let quote = stream.next().await.unwrap().unwrap().unwrap();
let quote0 = quote.to_quote().unwrap();
assert_eq!(quote0.symbol, "UFO");
assert_eq!(quote0.ask_quantity, 11);
let quote = stream.next().await.unwrap().unwrap().unwrap();
let quote1 = quote.to_quote().unwrap();
assert_eq!(quote1.symbol, "UFO");
assert_eq!(quote1.ask_quantity, 3);
assert!(stream.next().await.is_none());
}
#[test(tokio::test)]
async fn interleaved_trade() {
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
stream
.send(WebSocketMessage::Text(CONNECTED_MSG.to_string()))
.await?;
assert_eq!(
stream.next().await.unwrap()?,
WebSocketMessage::Text(AUTH_REQ.to_string()),
);
stream
.send(WebSocketMessage::Text(AUTH_RESP.to_string()))
.await?;
assert_eq!(
stream.next().await.unwrap()?,
WebSocketMessage::Text(SUB_REQ.to_string()),
);
stream
.send(WebSocketMessage::Text(MSFT_TRADE_MSG.to_string()))
.await?;
stream
.send(WebSocketMessage::Text(SUB_RESP.to_string()))
.await?;
stream.send(WebSocketMessage::Close(None)).await?;
Ok(())
}
let subscriptions = vec![
Subscription::Trades(Stock::Symbol("MSFT".into())),
Subscription::Quotes(Stock::All),
];
let _ = mock_stream(test, subscriptions)
.await
.unwrap()
.try_for_each(|_| ready(Ok(())))
.await
.unwrap();
}
#[test(tokio::test)]
async fn disconnect() {
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
stream
.send(WebSocketMessage::Text(CONNECTED_MSG.to_string()))
.await?;
assert_eq!(
stream.next().await.unwrap()?,
WebSocketMessage::Text(AUTH_REQ.to_string()),
);
stream
.send(WebSocketMessage::Text(AUTH_RESP.to_string()))
.await?;
assert_eq!(
stream.next().await.unwrap()?,
WebSocketMessage::Text(SUB_REQ.to_string()),
);
stream
.send(WebSocketMessage::Text(SUB_RESP.to_string()))
.await?;
stream
.send(WebSocketMessage::Text(MSFT_TRADE_MSG.to_string()))
.await?;
stream
.send(WebSocketMessage::Text(DISCONNECTED_MSG.to_string()))
.await?;
stream
.send(WebSocketMessage::Text(UFO_QUOTE_MSG.to_string()))
.await?;
stream.send(WebSocketMessage::Close(None)).await?;
Ok(())
}
let subscriptions = vec![
Subscription::Trades(Stock::Symbol("MSFT".into())),
Subscription::Quotes(Stock::All),
];
let mut stream = Box::pin(mock_stream(test, subscriptions).await.unwrap());
assert!(stream.next().await.unwrap().is_ok());
assert!(stream.next().await.unwrap().is_err());
assert!(stream.next().await.is_none());
}
}