use async_trait::async_trait;
use std::{collections::HashMap, pin::Pin, task::Poll};
use futures::{
stream::{SplitStream, Stream},
SinkExt, StreamExt,
};
use serde::{Deserialize, Serialize};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use crate::model::websocket::{Channel, CoinbaseSubscription, CoinbaseWebsocketMessage, Subscribe, SubscribeCmd};
use openlimits_exchange::errors::OpenLimitsError;
use crate::model::websocket::ChannelType;
use crate::CoinbaseParameters;
use openlimits_exchange::traits::stream::{ExchangeStream, Subscriptions};
use futures::stream::BoxStream;
use std::sync::Mutex;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use super::shared::Result;
use openlimits_exchange::exchange::Environment;
/// Websocket feed URL used when `parameters.environment` is anything other than
/// `Environment::Sandbox`.
const WS_URL_PROD: &str = "wss://ws-feed.exchange.coinbase.com";
/// Websocket feed URL used when `parameters.environment` is `Environment::Sandbox`.
const WS_URL_SANDBOX: &str = "wss://ws-feed-public.sandbox.exchange.coinbase.com";
/// Untagged two-way sum type.
///
/// Because of `#[serde(untagged)]`, deserialization tries `Left` first and falls back to
/// `Right`, so the variant order below is significant.
///
/// NOTE(review): not referenced by the code visible in this module; confirm it is still needed.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum Either<A, B> {
    /// First alternative (attempted first when deserializing).
    Left(A),
    /// Second alternative.
    Right(B),
}
/// Websocket connection produced by `connect_async` over a (possibly TLS-wrapped) TCP stream.
type WSStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Coinbase websocket client.
///
/// Each subscription added through `subscribe_` gets its own websocket connection whose read
/// half is stored in `subscriptions`; the messages of all of them are consumed through this
/// type's `Stream` implementation.
pub struct CoinbaseWebsocket {
    /// Read halves keyed by subscription; filled by `subscribe_` and polled by `poll_next`.
    pub subscriptions: HashMap<CoinbaseSubscription, SplitStream<WSStream>>,
    /// Client parameters; `environment` selects the sandbox or production feed URL.
    pub parameters: CoinbaseParameters,
    /// One sender per stream created by `create_stream_specific`. Sending `()` tells the
    /// background task that owns that connection's write half to close it.
    disconnection_senders: Mutex<Vec<UnboundedSender<()>>>,
}
impl CoinbaseWebsocket {
    /// Opens a dedicated connection for `subscription` and stores its read half in
    /// `self.subscriptions`, replacing any stream previously stored under the same key.
    ///
    /// # Errors
    /// Propagates any error returned by [`CoinbaseWebsocket::connect`].
    pub async fn subscribe_(&mut self, subscription: CoinbaseSubscription) -> Result<()> {
        // Every subscription kind maps to exactly one channel and one product id.
        let (channel_type, product_id) = match &subscription {
            CoinbaseSubscription::Level2(id) => (ChannelType::Level2, id.clone()),
            CoinbaseSubscription::Heartbeat(id) => (ChannelType::Heartbeat, id.clone()),
            CoinbaseSubscription::Matches(id) => (ChannelType::Matches, id.clone()),
        };
        let request = Subscribe {
            _type: SubscribeCmd::Subscribe,
            auth: None,
            channels: vec![Channel::Name(channel_type)],
            product_ids: vec![product_id],
        };
        let read_half = self.connect(request).await?;
        self.subscriptions.insert(subscription, read_half);
        Ok(())
    }

    /// Connects to the feed for the configured environment, sends `subscribe` as a JSON
    /// text frame and returns the read half of the socket.
    ///
    /// The write half is dropped as soon as the subscribe frame has been sent.
    ///
    /// # Errors
    /// Fails if the websocket handshake fails, `subscribe` cannot be serialized to JSON, or
    /// the frame cannot be sent.
    ///
    /// # Panics
    /// Only if the hard-coded feed URL cannot be parsed.
    pub async fn connect(&self, subscribe: Subscribe) -> Result<SplitStream<WSStream>> {
        let endpoint = url::Url::parse(self.feed_url()).expect("Couldn't parse url.");
        let (socket, _handshake_response) = connect_async(&endpoint).await?;
        let (mut write_half, read_half) = socket.split();
        let payload = serde_json::to_string(&subscribe)?;
        write_half.send(Message::Text(payload)).await?;
        Ok(read_half)
    }

    /// Feed URL matching `self.parameters.environment`.
    fn feed_url(&self) -> &'static str {
        if self.parameters.environment == Environment::Sandbox {
            WS_URL_SANDBOX
        } else {
            WS_URL_PROD
        }
    }
}
impl Stream for CoinbaseWebsocket {
    type Item = Result<CoinbaseWebsocketMessage>;

    /// Polls every subscribed socket (in `HashMap` iteration order) and yields the first
    /// message that is ready, decoded with [`parse_message`].
    ///
    /// Returns `Poll::Ready(None)` when no socket can produce more messages: either there
    /// are no subscriptions or every socket has reported end-of-stream. Returning
    /// `Poll::Pending` in that case would register no waker and hang the consumer forever.
    ///
    /// NOTE(review): iteration always starts from the same map entry, so a very busy
    /// subscription can starve the others; fixing that needs a stored poll cursor.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let mut any_pending = false;
        for stream in self.subscriptions.values_mut() {
            match Pin::new(stream).poll_next(cx) {
                // Transport errors are converted into `OpenLimitsError` by `?`.
                Poll::Ready(Some(message)) => return Poll::Ready(Some(parse_message(message?))),
                // This socket registered `cx`'s waker, so returning Pending below is sound.
                Poll::Pending => any_pending = true,
                // This socket is finished; keep checking the others.
                Poll::Ready(None) => {}
            }
        }
        if any_pending {
            Poll::Pending
        } else {
            Poll::Ready(None)
        }
    }
}
/// Decodes one websocket frame into a [`CoinbaseWebsocketMessage`].
///
/// Only `Message::Text` frames are accepted; their payload is parsed as JSON.
///
/// # Errors
/// * `OpenLimitsError::SocketError()` for any non-text frame.
/// * The converted `serde_json` error if the text is not a valid `CoinbaseWebsocketMessage`.
fn parse_message(ws_message: Message) -> Result<CoinbaseWebsocketMessage> {
    if let Message::Text(payload) = ws_message {
        serde_json::from_str(&payload).map_err(OpenLimitsError::from)
    } else {
        Err(OpenLimitsError::SocketError())
    }
}
#[async_trait]
impl ExchangeStream for CoinbaseWebsocket {
    type InitParams = CoinbaseParameters;
    type Subscription = CoinbaseSubscription;
    type Response = CoinbaseWebsocketMessage;

    /// Creates a client with no subscriptions and no registered disconnection senders.
    async fn new(parameters: Self::InitParams) -> Result<Self> {
        Ok(Self {
            subscriptions: Default::default(),
            parameters,
            disconnection_senders: Default::default(),
        })
    }

    /// Signals every stream created by `create_stream_specific` to close its write half, then
    /// forgets the corresponding senders.
    ///
    /// A poisoned lock is recovered instead of being skipped, so a poisoned mutex cannot
    /// silently keep every connection open.
    async fn disconnect(&self) {
        let mut senders = self
            .disconnection_senders
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        for sender in senders.drain(..) {
            // The close task may already be gone (receiver dropped); nothing to do then.
            sender.send(()).ok();
        }
    }

    /// Opens a new connection, subscribes it to the FIRST entry of `subscription`, waits for
    /// the matching `subscriptions` confirmation and returns the stream of parsed messages.
    ///
    /// Entries after the first are ignored.
    /// NOTE(review): supporting several entries requires building the expected confirmation
    /// for all of them.
    ///
    /// The write half is handed to a background task that closes it when `disconnect` is
    /// called. That task and its sender are created only after the subscription is confirmed.
    /// On any failure both socket halves are dropped on return, so no connection is left open.
    ///
    /// # Errors
    /// * `UnkownResponse` if `subscription` is empty, if the first reply is not the expected
    ///   confirmation, or if the socket closes before replying.
    /// * Serialization, connection, send and first-message errors are propagated unchanged.
    ///
    /// # Panics
    /// Only if the hard-coded feed URL cannot be parsed.
    async fn create_stream_specific(
        &self,
        subscription: Subscriptions<Self::Subscription>,
    ) -> Result<BoxStream<'static, Result<Self::Response>>> {
        // Validate before any network I/O; indexing `[0]` used to panic on an empty list.
        // NOTE(review): a dedicated "missing parameter" error variant would fit better here.
        let (channel_name, product_ids) = match subscription.as_slice().first() {
            Some(CoinbaseSubscription::Level2(product_id)) => {
                (ChannelType::Level2, vec![product_id.clone()])
            }
            Some(CoinbaseSubscription::Heartbeat(product_id)) => {
                (ChannelType::Heartbeat, vec![product_id.clone()])
            }
            Some(CoinbaseSubscription::Matches(product_id)) => {
                (ChannelType::Matches, vec![product_id.clone()])
            }
            None => {
                return Err(OpenLimitsError::UnkownResponse(String::from(
                    "No subscription provided",
                )));
            }
        };
        let subscribe = Subscribe {
            _type: SubscribeCmd::Subscribe,
            auth: None,
            channels: vec![Channel::Name(channel_name.clone())],
            product_ids: product_ids.clone(),
        };
        let subscribe = serde_json::to_string(&subscribe)?;

        let ws_url = if self.parameters.environment == Environment::Sandbox {
            WS_URL_SANDBOX
        } else {
            WS_URL_PROD
        };
        let endpoint = url::Url::parse(ws_url).expect("Couldn't parse url.");
        let (ws_stream, _) = connect_async(endpoint).await?;
        let (mut sink, stream) = ws_stream.split();
        sink.send(Message::Text(subscribe)).await?;

        let mut messages = stream.map(|message| parse_message(message?));
        let expected_response = CoinbaseWebsocketMessage::Subscriptions {
            channels: vec![Channel::WithProduct {
                name: channel_name,
                product_ids,
            }],
        };
        // Every early return below drops `sink` and `messages`, which closes the connection.
        match messages.next().await {
            Some(Ok(response)) if response == expected_response => {}
            Some(Ok(response)) => {
                return Err(OpenLimitsError::UnkownResponse(format!(
                    "Response: {:#?}, expected response: {:#?}",
                    response, expected_response
                )));
            }
            // Surface the real transport/parse error rather than a generic "No response".
            Some(Err(error)) => return Err(error),
            None => {
                return Err(OpenLimitsError::UnkownResponse(String::from("No response")));
            }
        }

        // Subscription confirmed: hand the write half to a task that closes it on `disconnect`.
        let (disconnection_sender, mut disconnection_receiver) = unbounded_channel();
        tokio::spawn(async move {
            if disconnection_receiver.recv().await.is_some() {
                sink.close().await.ok();
            }
        });
        self.disconnection_senders
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .push(disconnection_sender);
        Ok(messages.boxed())
    }
}