use std::borrow::Cow;
use async_trait::async_trait;
use futures::stream::Fuse;
use futures::stream::Map;
use futures::stream::SplitSink;
use futures::stream::SplitStream;
use futures::FutureExt as _;
use futures::Sink;
use futures::StreamExt as _;
use serde::Deserialize;
use serde::Serialize;
use serde_json::from_slice as json_from_slice;
use serde_json::from_str as json_from_str;
use serde_json::to_string as to_json;
use serde_json::Error as JsonError;
use tokio::net::TcpStream;
use tungstenite::MaybeTlsStream;
use tungstenite::WebSocketStream;
use websocket_util::subscribe;
use websocket_util::subscribe::MessageStream;
use websocket_util::tungstenite::Error as WebSocketError;
use websocket_util::wrap;
use websocket_util::wrap::Wrapper;
use crate::api::v2::order;
use crate::api_info::ApiInfo;
use crate::subscribable::Subscribable;
use crate::websocket::connect;
use crate::websocket::MessageResult;
use crate::Error;
/// The status of an order as carried in the `event` field of an
/// [`OrderUpdate`].
///
/// Every variant is (de)serialized using the string given in its
/// `serde(rename)` attribute. Note that the wire name does not always
/// mirror the variant name (e.g. `Filled` is `"fill"`).
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub enum OrderStatus {
#[serde(rename = "new")]
New,
#[serde(rename = "replaced")]
Replaced,
/// Wire name carries an `order_` prefix: `"order_replace_rejected"`.
#[serde(rename = "order_replace_rejected")]
ReplaceRejected,
#[serde(rename = "partial_fill")]
PartialFill,
/// Wire name is `"fill"`, not `"filled"`.
#[serde(rename = "fill")]
Filled,
#[serde(rename = "done_for_day")]
DoneForDay,
#[serde(rename = "canceled")]
Canceled,
/// Wire name carries an `order_` prefix: `"order_cancel_rejected"`.
#[serde(rename = "order_cancel_rejected")]
CancelRejected,
#[serde(rename = "expired")]
Expired,
#[serde(rename = "pending_cancel")]
PendingCancel,
#[serde(rename = "stopped")]
Stopped,
#[serde(rename = "rejected")]
Rejected,
#[serde(rename = "suspended")]
Suspended,
#[serde(rename = "pending_new")]
PendingNew,
#[serde(rename = "pending_replace")]
PendingReplace,
#[serde(rename = "calculated")]
Calculated,
/// Catch-all variant: any status string that matches none of the
/// variants above deserializes to `Unknown` (`serde(other)`). When
/// serialized it is emitted as `"unknown"`.
#[serde(other, rename(serialize = "unknown"))]
Unknown,
}
/// The kind of stream that can be named in a listen request or a
/// listening confirmation.
///
/// Only a single stream is modeled here; it is called `trade_updates`
/// on the wire.
#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[doc(hidden)]
pub enum StreamType {
#[serde(rename = "trade_updates")]
OrderUpdates,
}
/// A list of streams, serialized as an object with a single `streams`
/// array.
///
/// The list is held in a `Cow` so that it can either borrow a slice
/// (see the `From<&[StreamType]>` conversion) or own its data (as the
/// `Streams<'static>` payloads of decoded messages do).
#[derive(Debug, Deserialize, Serialize)]
#[doc(hidden)]
pub struct Streams<'d> {
/// The streams being requested or confirmed.
pub streams: Cow<'d, [StreamType]>,
}
impl<'d> From<&'d [StreamType]> for Streams<'d> {
  /// Wrap a borrowed slice of stream types without copying it.
  #[inline]
  fn from(streams: &'d [StreamType]) -> Self {
    Self {
      streams: Cow::Borrowed(streams),
    }
  }
}
/// The outcome of an authentication attempt as reported by the server
/// in the `status` field of an [`Authentication`] message.
#[derive(Debug, Deserialize, Eq, PartialEq, Serialize)]
#[doc(hidden)]
#[allow(missing_copy_implementations)]
pub enum AuthenticationStatus {
/// Wire name `"authorized"`.
#[serde(rename = "authorized")]
Authorized,
/// Wire name `"unauthorized"`.
#[serde(rename = "unauthorized")]
Unauthorized,
}
/// The payload of an `authorization` message (see
/// [`OrderMessage::AuthenticationMessage`]).
///
/// Only the `status` field is modeled; other fields present in the
/// JSON (the tests in this file show an `action` field) are not
/// captured by this type.
#[derive(Debug, Deserialize, Serialize)]
#[doc(hidden)]
pub struct Authentication {
/// Whether the authentication attempt was accepted.
#[serde(rename = "status")]
pub status: AuthenticationStatus,
}
/// A request sent from the client to the server.
///
/// The enum is adjacently tagged: the variant is named by the `action`
/// field and its payload is placed in the `data` field, e.g.
/// `{"action":"listen","data":{"streams":[...]}}`.
#[derive(Debug, Deserialize, Serialize)]
#[doc(hidden)]
#[serde(tag = "action", content = "data")]
pub enum Request<'d> {
/// A request to authenticate using a key ID and secret.
#[serde(rename = "authenticate")]
Authenticate {
/// The key ID to authenticate with.
#[serde(rename = "key_id")]
key_id: Cow<'d, str>,
/// The secret belonging to the key ID; called `secret_key` on the
/// wire.
#[serde(rename = "secret_key")]
secret: Cow<'d, str>,
},
/// A request to listen to the given set of streams.
#[serde(rename = "listen")]
Listen(Streams<'d>),
}
/// Messages that steer the subscription itself rather than carrying
/// order data.
///
/// These are what `classify` produces for the `authorization` and
/// `listening` variants of [`OrderMessage`], and what `authenticate`
/// and `listen` inspect as the response to their requests.
#[derive(Debug)]
#[doc(hidden)]
pub enum ControlMessage {
/// The server's answer to an authentication request.
AuthenticationMessage(Authentication),
/// The server's confirmation of the streams being listened to.
ListeningMessage(Streams<'static>),
}
/// Any message decoded from the websocket stream.
///
/// The enum is adjacently tagged: the `stream` field of the JSON object
/// selects the variant and the `data` field holds its payload.
#[derive(Debug, Deserialize, Serialize)]
#[doc(hidden)]
#[serde(tag = "stream", content = "data")]
#[allow(clippy::large_enum_variant)]
pub enum OrderMessage {
/// An update for an order (`"stream":"trade_updates"`).
#[serde(rename = "trade_updates")]
OrderUpdate(OrderUpdate),
/// The response to an authentication request
/// (`"stream":"authorization"`).
#[serde(rename = "authorization")]
AuthenticationMessage(Authentication),
/// The confirmation of the streams being listened to
/// (`"stream":"listening"`).
#[serde(rename = "listening")]
ListeningMessage(Streams<'static>),
}
/// An update concerning a single order, as delivered to users of the
/// stream.
///
/// Fields of the JSON payload other than `event` and `order` are not
/// captured by this type.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct OrderUpdate {
/// The status reported for the order by this update.
#[serde(rename = "event")]
pub event: OrderStatus,
/// The order the update refers to.
#[serde(rename = "order")]
pub order: order::Order,
}
/// A message received over the websocket after a JSON decoding attempt:
/// the outer layer reports websocket errors, the inner one JSON errors.
type ParsedMessage = MessageResult<Result<OrderMessage, JsonError>, WebSocketError>;

impl subscribe::Message for ParsedMessage {
  type UserMessage = Result<Result<OrderUpdate, JsonError>, WebSocketError>;
  type ControlMessage = ControlMessage;

  /// Split incoming messages into those meant for the user (order
  /// updates and all errors) and those controlling the subscription
  /// (authentication and listening responses).
  fn classify(self) -> subscribe::Classification<Self::UserMessage, Self::ControlMessage> {
    // Errors of either layer are always handed to the user as-is.
    let decoded = match self {
      MessageResult::Ok(Ok(decoded)) => decoded,
      MessageResult::Ok(Err(json_err)) => {
        return subscribe::Classification::UserMessage(Ok(Err(json_err)))
      },
      MessageResult::Err(ws_err) => return subscribe::Classification::UserMessage(Err(ws_err)),
    };

    match decoded {
      OrderMessage::OrderUpdate(update) => subscribe::Classification::UserMessage(Ok(Ok(update))),
      OrderMessage::AuthenticationMessage(authentication) => {
        let control = ControlMessage::AuthenticationMessage(authentication);
        subscribe::Classification::ControlMessage(control)
      },
      OrderMessage::ListeningMessage(streams) => {
        let control = ControlMessage::ListeningMessage(streams);
        subscribe::Classification::ControlMessage(control)
      },
    }
  }

  /// A user message is an error unless both the websocket layer and
  /// the JSON layer report success.
  #[inline]
  fn is_error(user_message: &Self::UserMessage) -> bool {
    !matches!(user_message, Ok(Ok(_)))
  }
}
/// A handle for controlling the order update subscription.
///
/// It is a thin wrapper around a `subscribe::Subscription` that sends
/// `wrap::Message` objects through the sink `S` and interprets
/// responses as [`ParsedMessage`] objects.
#[derive(Debug)]
pub struct Subscription<S>(subscribe::Subscription<S, ParsedMessage, wrap::Message>);
impl<S> Subscription<S>
where
S: Sink<wrap::Message> + Unpin,
{
async fn authenticate(
&mut self,
key_id: &str,
secret: &str,
) -> Result<Result<(), Error>, S::Error> {
let request = Request::Authenticate {
key_id: key_id.into(),
secret: secret.into(),
};
let json = match to_json(&request) {
Ok(json) => json,
Err(err) => return Ok(Err(Error::Json(err))),
};
let message = wrap::Message::Text(json);
let response = self.0.send(message).await?;
match response {
Some(response) => match response {
Ok(ControlMessage::AuthenticationMessage(authentication)) => {
if authentication.status != AuthenticationStatus::Authorized {
return Ok(Err(Error::Str("authentication not successful".into())))
}
Ok(Ok(()))
},
Ok(_) => Ok(Err(Error::Str(
"server responded with an unexpected message".into(),
))),
Err(()) => Ok(Err(Error::Str("failed to authenticate with server".into()))),
},
None => Ok(Err(Error::Str(
"stream was closed before authorization message was received".into(),
))),
}
}
async fn listen(&mut self) -> Result<Result<(), Error>, S::Error> {
let streams = Streams::from([StreamType::OrderUpdates].as_ref());
let request = Request::Listen(streams);
let json = match to_json(&request) {
Ok(json) => json,
Err(err) => return Ok(Err(Error::Json(err))),
};
let message = wrap::Message::Text(json);
let response = self.0.send(message).await?;
match response {
Some(response) => match response {
Ok(ControlMessage::ListeningMessage(streams)) => {
if !streams.streams.contains(&StreamType::OrderUpdates) {
return Ok(Err(Error::Str(
"server did not subscribe us to order update stream".into(),
)))
}
Ok(Ok(()))
},
Ok(_) => Ok(Err(Error::Str(
"server responded with an unexpected message".into(),
))),
Err(()) => Ok(Err(Error::Str(
"failed to listen to order update stream".into(),
))),
},
None => Ok(Err(Error::Str(
"stream was closed before listen message was received".into(),
))),
}
}
}
/// The websocket stream type in use: a wrapped TCP (optionally TLS)
/// websocket connection whose items are converted by a [`MapFn`].
type Stream = Map<Wrapper<WebSocketStream<MaybeTlsStream<TcpStream>>>, MapFn>;
/// The signature of the function converting raw websocket results into
/// [`ParsedMessage`] objects. A plain function pointer is used so that
/// the [`Stream`] type can be spelled out.
type MapFn = fn(Result<wrap::Message, WebSocketError>) -> ParsedMessage;
/// A type used for subscribing to order updates.
///
/// The enum has no variants and hence cannot be instantiated; it only
/// serves as the type implementing [`Subscribable`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OrderUpdates {}
#[async_trait]
impl Subscribable for OrderUpdates {
type Input = ApiInfo;
type Subscription = Subscription<SplitSink<Stream, wrap::Message>>;
type Stream = Fuse<MessageStream<SplitStream<Stream>, ParsedMessage>>;
async fn connect(api_info: &Self::Input) -> Result<(Self::Stream, Self::Subscription), Error> {
fn map(result: Result<wrap::Message, WebSocketError>) -> ParsedMessage {
MessageResult::from(result.map(|message| match message {
wrap::Message::Text(string) => json_from_str::<OrderMessage>(&string),
wrap::Message::Binary(data) => json_from_slice::<OrderMessage>(&data),
}))
}
let ApiInfo {
api_stream_url: url,
key_id,
secret,
..
} = api_info;
let stream = connect(url).await?.map(map as MapFn);
let (send, recv) = stream.split();
let (stream, subscription) = subscribe::subscribe(recv, send);
let mut stream = stream.fuse();
let mut subscription = Subscription(subscription);
let authenticate = subscription.authenticate(key_id, secret).boxed().fuse();
let () = subscribe::drive::<ParsedMessage, _, _>(authenticate, &mut stream)
.await
.map_err(|result| {
result
.map(|result| Error::Json(result.unwrap_err()))
.map_err(Error::WebSocket)
.unwrap_or_else(|err| err)
})???;
let listen = subscription.listen().boxed().fuse();
let () = subscribe::drive::<ParsedMessage, _, _>(listen, &mut stream)
.await
.map_err(|result| {
result
.map(|result| Error::Json(result.unwrap_err()))
.map_err(Error::WebSocket)
.unwrap_or_else(|err| err)
})???;
Ok((stream, subscription))
}
}
#[cfg(test)]
mod tests {
use super::*;
use futures::channel::oneshot::channel;
use futures::future::ok;
use futures::future::ready;
use futures::SinkExt;
use futures::TryStreamExt;
use serde_json::from_str as json_from_str;
use test_log::test;
use websocket_util::test::WebSocketStream;
use websocket_util::tungstenite::error::ProtocolError;
use websocket_util::tungstenite::Message;
use crate::api::v2::order;
use crate::api::v2::order_util::order_aapl;
use crate::api::API_BASE_URL;
use crate::websocket::test::mock_stream;
use crate::Client;
use crate::Error;
// Canned handshake messages shared by the mock-server tests below.

// The authentication request the mock servers expect as the first
// message from the client (presumably matching the credentials that
// `mock_stream` configures -- confirm against that helper).
const AUTH_REQ: &str =
r#"{"action":"authenticate","data":{"key_id":"USER12345678","secret_key":"justletmein"}}"#;
// A response reporting successful authentication.
const AUTH_RESP: &str =
r#"{"stream":"authorization","data":{"action":"authenticate","status":"authorized"}}"#;
// The listen request the client is expected to send after
// authenticating.
const STREAM_REQ: &str = r#"{"action":"listen","data":{"streams":["trade_updates"]}}"#;
// A response confirming that the order update stream is listened to.
const STREAM_RESP: &str = r#"{"stream":"listening","data":{"streams":["trade_updates"]}}"#;
/// Check that an authentication request is serialized as expected.
#[test]
fn encode_authentication_request() {
  let request = Request::Authenticate {
    key_id: Cow::Borrowed("some-key"),
    secret: Cow::Borrowed("super-secret-secret"),
  };

  let encoded = to_json(&request).unwrap();
  assert_eq!(
    encoded,
    r#"{"action":"authenticate","data":{"key_id":"some-key","secret_key":"super-secret-secret"}}"#,
  );
}
/// Check that a listen request is serialized as expected.
#[test]
fn encode_listen_request() {
  let wanted = [StreamType::OrderUpdates];
  let request = Request::Listen(Streams::from(&wanted[..]));

  let encoded = to_json(&request).unwrap();
  assert_eq!(
    encoded,
    r#"{"action":"listen","data":{"streams":["trade_updates"]}}"#,
  );
}
/// Check that a `trade_updates` message decodes into an order update,
/// ignoring payload fields that `OrderUpdate` does not model.
#[test]
fn decode_order_update() {
  let json = r#"{
"stream":"trade_updates","data":{
"event":"new","execution_id":"11111111-2222-3333-4444-555555555555","order":{
"asset_class":"us_equity","asset_id":"11111111-2222-3333-4444-555555555555",
"canceled_at":null,"client_order_id":"11111111-2222-3333-4444-555555555555",
"created_at":"2021-12-09T19:48:46.176628398Z","expired_at":null,
"extended_hours":false,"failed_at":null,"filled_at":null,
"filled_avg_price":null,"filled_qty":"0","hwm":null,
"id":"11111111-2222-3333-4444-555555555555","legs":null,"limit_price":"1",
"notional":null,"order_class":"simple","order_type":"limit","qty":"1",
"replaced_at":null,"replaced_by":null,"replaces":null,"side":"buy",
"status":"new","stop_price":null,"submitted_at":"2021-12-09T19:48:46.175261379Z",
"symbol":"AAPL","time_in_force":"day","trail_percent":null,"trail_price":null,
"type":"limit","updated_at":"2021-12-09T19:48:46.185346448Z"
},"timestamp":"2021-12-09T19:48:46.182987144Z"
}
}"#;

  let message = json_from_str::<OrderMessage>(json).unwrap();
  if let OrderMessage::OrderUpdate(update) = &message {
    assert_eq!(update.event, OrderStatus::New);
    assert_eq!(update.order.side, order::Side::Buy);
  } else {
    panic!("Decoded unexpected message variant: {message:?}")
  }
}
/// Check that a successful authorization message decodes correctly.
#[test]
fn decode_authentication() {
  let json =
    r#"{"stream":"authorization","data":{"status":"authorized","action":"authenticate"}}"#;

  let message = json_from_str::<OrderMessage>(json).unwrap();
  if let OrderMessage::AuthenticationMessage(authentication) = &message {
    assert_eq!(authentication.status, AuthenticationStatus::Authorized);
  } else {
    panic!("Decoded unexpected message variant: {message:?}")
  }
}
/// Check that an authorization message reporting a failure decodes
/// correctly.
#[test]
fn decode_unauthorized_authentication() {
  let json = r#"{"stream":"authorization","data":{"status":"unauthorized","action":"listen"}}"#;

  let message = json_from_str::<OrderMessage>(json).unwrap();
  if let OrderMessage::AuthenticationMessage(authentication) = &message {
    assert_eq!(authentication.status, AuthenticationStatus::Unauthorized);
  } else {
    panic!("Decoded unexpected message variant: {message:?}")
  }
}
/// Check that a listening confirmation decodes correctly.
#[test]
fn decode_listening() {
  let json = r#"{"stream":"listening","data":{"streams":["trade_updates"]}}"#;

  let message = json_from_str::<OrderMessage>(json).unwrap();
  if let OrderMessage::ListeningMessage(streams) = &message {
    assert_eq!(streams.streams, vec![StreamType::OrderUpdates]);
  } else {
    panic!("Decoded unexpected message variant: {message:?}")
  }
}
/// Check the error reported when the server drops the connection
/// right after receiving the authentication request.
#[test(tokio::test)]
async fn broken_stream() {
  /// Read the authentication request and then return without
  /// responding or closing the connection properly.
  async fn server(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
    assert_eq!(
      stream.next().await.unwrap()?,
      Message::Text(AUTH_REQ.to_string()),
    );
    Ok(())
  }

  match mock_stream::<OrderUpdates, _, _>(server).await {
    Err(Error::WebSocket(WebSocketError::Protocol(
      ProtocolError::ResetWithoutClosingHandshake,
    ))) => (),
    Err(e) => panic!("received unexpected error: {e}"),
    Ok(..) => panic!("authentication succeeded unexpectedly"),
  }
}
/// Check the error reported when the server closes the connection
/// instead of answering the listen request.
#[test(tokio::test)]
async fn early_close() {
  /// Authenticate the client, then answer its listen request with a
  /// close frame.
  async fn server(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
    let auth_request = stream.next().await.unwrap()?;
    assert_eq!(auth_request, Message::Text(AUTH_REQ.to_string()));
    stream.send(Message::Text(AUTH_RESP.to_string())).await?;

    let listen_request = stream.next().await.unwrap()?;
    assert_eq!(listen_request, Message::Text(STREAM_REQ.to_string()));
    stream.send(Message::Close(None)).await?;
    Ok(())
  }

  match mock_stream::<OrderUpdates, _, _>(server).await {
    Err(Error::Str(ref e)) if e.starts_with("stream was closed before listen") => (),
    Err(e) => panic!("received unexpected error: {e}"),
    Ok(..) => panic!("operation succeeded unexpectedly"),
  }
}
/// Check the error reported when the server completes the handshake
/// and then goes away without a closing handshake.
#[test(tokio::test)]
async fn no_messages() {
  /// Perform the full authenticate/listen handshake and then return
  /// without closing the connection properly.
  async fn server(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
    let auth_request = stream.next().await.unwrap()?;
    assert_eq!(auth_request, Message::Text(AUTH_REQ.to_string()));
    stream.send(Message::Text(AUTH_RESP.to_string())).await?;

    let listen_request = stream.next().await.unwrap()?;
    assert_eq!(listen_request, Message::Text(STREAM_REQ.to_string()));
    stream.send(Message::Text(STREAM_RESP.to_string())).await?;
    Ok(())
  }

  let err = mock_stream::<OrderUpdates, _, _>(server).await.unwrap_err();
  match err {
    Error::WebSocket(WebSocketError::Protocol(ProtocolError::ResetWithoutClosingHandshake)) => (),
    e => panic!("received unexpected error: {e}"),
  }
}
/// Check that malformed JSON received while the handshake is still in
/// progress is reported as a JSON error.
#[test(tokio::test)]
async fn decode_error_during_handshake() {
  /// Authenticate the client and then send a message that is not
  /// valid JSON.
  async fn server(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
    let auth_request = stream.next().await.unwrap()?;
    assert_eq!(auth_request, Message::Text(AUTH_REQ.to_string()));
    stream.send(Message::Text(AUTH_RESP.to_string())).await?;

    let garbage = Message::Text("{ foobarbaz }".to_string());
    stream.send(garbage).await?;
    Ok(())
  }

  let e = mock_stream::<OrderUpdates, _, _>(server).await.unwrap_err();
  if !matches!(e, Error::Json(_)) {
    panic!("received unexpected error: {e}")
  }
}
/// Check that a JSON decoding error received after a successful
/// handshake does not make draining the stream fail.
#[test(tokio::test)]
async fn decode_error_errors_do_not_terminate() {
// Channel used by the client side to tell the mock server when it may
// send the malformed message.
let (sender, receiver) = channel();
let test = |mut stream: WebSocketStream| {
async move {
assert_eq!(
stream.next().await.unwrap()?,
Message::Text(AUTH_REQ.to_string()),
);
stream.send(Message::Text(AUTH_RESP.to_string())).await?;
assert_eq!(
stream.next().await.unwrap()?,
Message::Text(STREAM_REQ.to_string()),
);
stream.send(Message::Text(STREAM_RESP.to_string())).await?;
// Wait for the client's signal, which is only sent once
// `mock_stream` has returned, i.e., after the handshake. That way
// the malformed message is not consumed as part of the handshake.
let () = receiver.await.unwrap();
stream
.send(Message::Text("{ foobarbaz }".to_string()))
.await?;
stream.send(Message::Close(None)).await?;
Ok(())
}
};
let (stream, _subscription) = mock_stream::<OrderUpdates, _, _>(test).await.unwrap();
let () = sender.send(()).unwrap();
// Drain the stream. The malformed message surfaces as an item with an
// inner JSON error, which the closure ignores; only a websocket level
// error would make the final `unwrap` fail.
stream
.map_err(Error::from)
.try_for_each(|_| ready(Ok(())))
.await
.unwrap();
}
/// Check that a ping sent by the server is answered with a pong while
/// the client is merely draining the stream.
#[test(tokio::test)]
async fn ping_pong() {
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
// Regular authenticate/listen handshake.
assert_eq!(
stream.next().await.unwrap()?,
Message::Text(AUTH_REQ.to_string()),
);
stream.send(Message::Text(AUTH_RESP.to_string())).await?;
assert_eq!(
stream.next().await.unwrap()?,
Message::Text(STREAM_REQ.to_string()),
);
stream.send(Message::Text(STREAM_RESP.to_string())).await?;
// Send a ping with an empty payload and expect a pong with an empty
// payload as the very next message from the client.
// NOTE(review): the pong is presumably produced by the websocket
// wrapper layer below the subscription -- confirm.
stream.send(Message::Ping(Vec::new())).await?;
assert_eq!(stream.next().await.unwrap()?, Message::Pong(Vec::new()),);
stream.send(Message::Close(None)).await?;
Ok(())
}
let (stream, _subscription) = mock_stream::<OrderUpdates, _, _>(test).await.unwrap();
// The client takes no explicit action; it only drains the stream
// until the server closes it and fails on any websocket level error.
stream
.map_err(Error::from)
.try_for_each(|_| ready(Ok(())))
.await
.unwrap();
}
/// Check that an update for an order we submitted is delivered through
/// the stream.
///
/// This test takes its API information from the environment
/// (`ApiInfo::from_env`) and issues requests through a real `Client`.
#[test(tokio::test)]
async fn stream_order_events() {
let api_info = ApiInfo::from_env().unwrap();
let client = Client::new(api_info);
// Subscribe before submitting the order so that its updates are not
// missed.
let (stream, _subscription) = client.subscribe::<OrderUpdates>().await.unwrap();
let order = order_aapl(&client).await.unwrap();
// Delete the order again right away; the update checked below only
// needs to refer to the order.
client.issue::<order::Delete>(&order.id).await.unwrap();
let update = stream
// Strip the JSON layer: a decoding error fails the test right here.
.try_filter_map(|result| {
let update = result.unwrap();
ok(Some(update))
})
// Skip updates that refer to orders other than the one submitted
// above.
.try_skip_while(|update| ok(update.order.id != order.id))
.next()
.await
.unwrap()
.unwrap();
assert_eq!(order.id, update.order.id);
assert_eq!(order.asset_id, update.order.asset_id);
assert_eq!(order.symbol, update.order.symbol);
assert_eq!(order.asset_class, update.order.asset_class);
assert_eq!(order.type_, update.order.type_);
assert_eq!(order.side, update.order.side);
assert_eq!(order.time_in_force, update.order.time_in_force);
}
/// Check that subscribing with invalid credentials is reported as an
/// authentication failure.
#[test(tokio::test)]
async fn stream_with_invalid_credentials() {
  let client = Client::new(ApiInfo::from_parts(API_BASE_URL, "invalid", "invalid-too").unwrap());

  let e = client.subscribe::<OrderUpdates>().await.unwrap_err();
  let rejected = matches!(
    &e,
    Error::Str(reason) if reason == "authentication not successful"
  );
  if !rejected {
    panic!("received unexpected error: {e}")
  }
}
}