use crate::client::Client;
use crate::websocket::{Channel, SubscribeMessage, UnsubscribeMessage, WebSocketMessage};
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;
/// WebSocket endpoint dialed by `WebSocketClient::connect_market_data`.
const MARKET_DATA_WS_URL: &str = "wss://advanced-trade-ws.coinbase.com";
/// WebSocket endpoint dialed by `WebSocketClient::connect_user_data`.
const USER_DATA_WS_URL: &str = "wss://advanced-trade-ws-user.coinbase.com";

/// Stream type stored for each endpoint (the value produced by `connect_async`).
type WsStream =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;

/// Holds up to two WebSocket connections (market data and user data) together with
/// the in-process channel that decoded messages are forwarded into.
pub struct WebSocketClient<'a> {
    /// Shared API client; its `create_ws_jwt` supplies tokens for authenticated channels.
    client: Arc<Client<'a>>,
    /// Market-data connection: `None` until `connect_market_data` succeeds, and again
    /// after `close_market_data`.
    market_data_ws: Option<WsStream>,
    /// User-data connection: `None` until `connect_user_data` succeeds, and again
    /// after `close_user_data`.
    user_data_ws: Option<WsStream>,
    /// Sending half used by the listen loops to forward decoded messages.
    message_sender: mpsc::UnboundedSender<WebSocketMessage>,
    /// Receiving half; moved out (leaving `None`) by `get_message_receiver`.
    message_receiver: Option<mpsc::UnboundedReceiver<WebSocketMessage>>,
}
impl<'a> WebSocketClient<'a> {
pub fn new(client: Arc<Client<'a>>) -> Self {
let (sender, receiver) = mpsc::unbounded_channel();
WebSocketClient {
client,
market_data_ws: None,
user_data_ws: None,
message_sender: sender,
message_receiver: Some(receiver),
}
}
pub async fn connect_market_data(&mut self) -> Result<(), Box<dyn std::error::Error>> {
let (ws_stream, _) = tokio_tungstenite::connect_async(MARKET_DATA_WS_URL).await?;
self.market_data_ws = Some(ws_stream);
Ok(())
}
pub async fn connect_user_data(&mut self) -> Result<(), Box<dyn std::error::Error>> {
let (ws_stream, _) = tokio_tungstenite::connect_async(USER_DATA_WS_URL).await?;
self.user_data_ws = Some(ws_stream);
Ok(())
}
pub async fn subscribe(
&mut self,
channel: Channel,
product_ids: Option<Vec<String>>,
use_user_endpoint: bool,
) -> Result<(), Box<dyn std::error::Error>> {
let jwt = if channel.requires_authentication() {
Some(self.client.create_ws_jwt())
} else {
None
};
let subscribe_msg = SubscribeMessage {
message_type: "subscribe".to_string(),
channel: channel.as_str().to_string(),
product_ids,
jwt,
};
let message_json = serde_json::to_string(&subscribe_msg)?;
if use_user_endpoint {
if let Some(ref mut ws) = self.user_data_ws {
ws.send(Message::text(message_json)).await?;
} else {
return Err("User data WebSocket not connected".into());
}
} else {
if let Some(ref mut ws) = self.market_data_ws {
ws.send(Message::text(message_json)).await?;
} else {
return Err("Market data WebSocket not connected".into());
}
}
Ok(())
}
pub async fn unsubscribe(
&mut self,
channel: Channel,
product_ids: Option<Vec<String>>,
use_user_endpoint: bool,
) -> Result<(), Box<dyn std::error::Error>> {
let jwt = if channel.requires_authentication() {
Some(self.client.create_ws_jwt())
} else {
None
};
let unsubscribe_msg = UnsubscribeMessage {
message_type: "unsubscribe".to_string(),
channel: channel.as_str().to_string(),
product_ids,
jwt,
};
let message_json = serde_json::to_string(&unsubscribe_msg)?;
if use_user_endpoint {
if let Some(ref mut ws) = self.user_data_ws {
ws.send(Message::text(message_json)).await?;
} else {
return Err("User data WebSocket not connected".into());
}
} else {
if let Some(ref mut ws) = self.market_data_ws {
ws.send(Message::text(message_json)).await?;
} else {
return Err("Market data WebSocket not connected".into());
}
}
Ok(())
}
pub fn get_message_receiver(&mut self) -> Option<mpsc::UnboundedReceiver<WebSocketMessage>> {
self.message_receiver.take()
}
pub async fn listen_market_data(&mut self) -> Result<(), Box<dyn std::error::Error>> {
if let Some(ref mut ws) = self.market_data_ws {
while let Some(message) = ws.next().await {
match message {
Ok(Message::Text(text)) => {
if let Ok(ws_message) = serde_json::from_str::<WebSocketMessage>(&text) {
let _ = self.message_sender.send(ws_message);
}
}
Ok(Message::Close(_)) => {
break;
}
Err(e) => {
eprintln!("WebSocket error: {}", e);
break;
}
_ => {}
}
}
}
Ok(())
}
pub async fn listen_user_data(&mut self) -> Result<(), Box<dyn std::error::Error>> {
if let Some(ref mut ws) = self.user_data_ws {
while let Some(message) = ws.next().await {
match message {
Ok(Message::Text(text)) => {
if let Ok(ws_message) = serde_json::from_str::<WebSocketMessage>(&text) {
let _ = self.message_sender.send(ws_message);
}
}
Ok(Message::Close(_)) => {
break;
}
Err(e) => {
eprintln!("WebSocket error: {}", e);
break;
}
_ => {}
}
}
}
Ok(())
}
pub async fn close_market_data(&mut self) -> Result<(), Box<dyn std::error::Error>> {
if let Some(mut ws) = self.market_data_ws.take() {
ws.close(None).await?;
}
Ok(())
}
pub async fn close_user_data(&mut self) -> Result<(), Box<dyn std::error::Error>> {
if let Some(mut ws) = self.user_data_ws.take() {
ws.close(None).await?;
}
Ok(())
}
}