use super::utils::connect_with_retry;
use std::{
collections::HashSet,
io::prelude::*,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
},
time::{Duration, Instant},
};
use flate2::read::{DeflateDecoder, GzDecoder};
use log::*;
use tungstenite::{client::AutoStream, Error, Message, WebSocket};
/// Margin subtracted from the configured ping interval before it is passed to
/// `connect_with_retry`, so the client wakes up somewhat before the interval
/// elapses. NOTE(review): presumably in seconds, like the interval — confirm.
const PING_LATENCY: u64 = 2;
/// Classification of an incoming text message, produced by an exchange-specific
/// `on_misc_msg` callback and dispatched by `WSClientInternal::handle_msg`.
pub(super) enum MiscMessage {
    /// A frame that `handle_msg` writes back to the server unchanged.
    WebSocket(Message),
    /// Tells the client to reconnect and replay its subscriptions.
    Reconnect,
    /// Housekeeping message: ignored and not forwarded to the user callback.
    Misc,
    /// Regular data message: forwarded to the user's `on_msg` callback.
    Normal,
}
/// Exchange-agnostic WebSocket client that the per-exchange wrappers delegate to.
pub(super) struct WSClientInternal<'a> {
    /// Exchange name, compared against the `EXCHANGE_NAME` constants of sibling modules.
    exchange: &'static str,
    /// Endpoint URL; reused when reconnecting.
    pub(super) url: String,
    /// The live connection; replaced wholesale on reconnect.
    ws_stream: Mutex<WebSocket<AutoStream>>,
    /// Channels tracked for (re)subscription; replayed as subscribe commands after a reconnect.
    channels: Mutex<HashSet<String>>,
    /// User callback that receives the payload of every `MiscMessage::Normal` message.
    on_msg: Arc<Mutex<dyn FnMut(String) + 'a + Send>>,
    /// Exchange-specific classifier for incoming text messages.
    on_misc_msg: fn(&str) -> MiscMessage,
    /// Builds subscribe (`true`) or unsubscribe (`false`) commands for a list of channels.
    channels_to_commands: fn(&[String], bool) -> Vec<String>,
    /// Set by `close()`; checked at the top of every `run()` iteration.
    should_stop: AtomicBool,
    /// Optional `(interval, ping text)` keep-alive configuration.
    ping_interval_and_msg: Option<(u64, &'static str)>,
}
impl<'a> WSClientInternal<'a> {
pub fn new(
exchange: &'static str,
url: &str,
on_msg: Arc<Mutex<dyn FnMut(String) + 'a + Send>>,
on_misc_msg: fn(&str) -> MiscMessage,
channels_to_commands: fn(&[String], bool) -> Vec<String>,
ping_interval_and_msg: Option<(u64, &'static str)>,
) -> Self {
let stream = connect_with_retry(
url,
if let Some(interval_and_msg) = ping_interval_and_msg {
Some(interval_and_msg.0 - PING_LATENCY)
} else {
None
},
);
WSClientInternal {
exchange,
url: url.to_string(),
ws_stream: Mutex::new(stream),
on_msg,
on_misc_msg,
channels: Mutex::new(HashSet::new()),
channels_to_commands,
should_stop: AtomicBool::new(false),
ping_interval_and_msg,
}
}
pub fn subscribe(&self, channels: &[String]) {
self.subscribe_or_unsubscribe(channels, true);
}
pub fn unsubscribe(&self, channels: &[String]) {
self.subscribe_or_unsubscribe(channels, false);
}
fn subscribe_or_unsubscribe(&self, channels: &[String], subscribe: bool) {
let mut diff = Vec::<String>::new();
{
let mut guard = self.channels.lock().unwrap();
for ch in channels.iter() {
if guard.insert(ch.clone()) {
diff.push(ch.clone());
}
}
}
if !diff.is_empty() {
let commands = (self.channels_to_commands)(&diff, subscribe);
let mut ws_stream = self.ws_stream.lock().unwrap();
commands.into_iter().for_each(|command| {
let ret = ws_stream.write_message(Message::Text(command));
if let Err(err) = ret {
error!("{}", err);
}
});
}
}
fn reconnect(&self) {
warn!("Reconnecting to {}", &self.url);
{
let mut guard = self.ws_stream.lock().unwrap();
*guard = connect_with_retry(
self.url.as_str(),
if let Some(interval_and_msg) = self.ping_interval_and_msg {
Some(interval_and_msg.0 - PING_LATENCY)
} else {
None
},
);
}
let channels = self
.channels
.lock()
.unwrap()
.iter()
.map(|s| s.to_string())
.collect::<Vec<String>>();
if !channels.is_empty() {
let commands = (self.channels_to_commands)(&channels, true);
let mut ws_stream = self.ws_stream.lock().unwrap();
commands.into_iter().for_each(|command| {
let ret = ws_stream.write_message(Message::Text(command));
if let Err(err) = ret {
error!("{}", err);
}
});
}
}
fn handle_msg(&self, txt: &str) -> bool {
match (self.on_misc_msg)(txt) {
MiscMessage::Misc => false,
MiscMessage::Reconnect => {
self.reconnect();
false
}
MiscMessage::WebSocket(ws_msg) => {
let ret = self.ws_stream.lock().unwrap().write_message(ws_msg);
if let Err(err) = ret {
error!("{}", err);
}
false
}
MiscMessage::Normal => {
if self.exchange == super::mxc::EXCHANGE_NAME
&& self.url.as_str() == super::mxc::SPOT_WEBSOCKET_URL
{
match txt.strip_prefix("42") {
Some(msg) => (self.on_msg.lock().unwrap())(msg.to_string()),
None => error!(
"{}, Not possible, should be handled by {}.on_misc_msg() previously",
txt, self.exchange
),
}
} else {
(self.on_msg.lock().unwrap())(txt.to_string());
}
true
}
}
}
pub fn run(&self, duration: Option<u64>) {
let now = Instant::now();
while !self.should_stop.load(Ordering::Acquire) {
let resp = self.ws_stream.lock().unwrap().read_message();
let normal = match resp {
Ok(msg) => match msg {
Message::Text(txt) => self.handle_msg(&txt),
Message::Binary(binary) => {
let mut txt = String::new();
let resp = if self.exchange == super::huobi::EXCHANGE_NAME
|| self.exchange == super::binance::EXCHANGE_NAME
{
let mut decoder = GzDecoder::new(&binary[..]);
decoder.read_to_string(&mut txt)
} else if self.exchange == super::okex::EXCHANGE_NAME {
let mut decoder = DeflateDecoder::new(&binary[..]);
decoder.read_to_string(&mut txt)
} else {
error!("Unknown binary format from {}", self.url);
panic!("Unknown binary format from {}", self.url);
};
match resp {
Ok(_) => self.handle_msg(&txt),
Err(err) => {
error!("Decompression failed, {}", err);
false
}
}
}
Message::Ping(resp) => {
info!(
"Received a ping frame: {}",
std::str::from_utf8(&resp).unwrap()
);
let ret = self
.ws_stream
.lock()
.unwrap()
.write_message(Message::Pong(resp));
if let Err(err) = ret {
error!("{}", err);
}
false
}
Message::Pong(resp) => {
let tmp = std::str::from_utf8(&resp);
warn!("Received a pong frame: {}", tmp.unwrap());
false
}
Message::Close(resp) => {
match resp {
Some(frame) => warn!("Received a Message::Close message with a CloseFrame: code: {}, reason: {}", frame.code, frame.reason),
None => warn!("Received a close message without CloseFrame"),
}
false
}
},
Err(err) => {
match err {
Error::ConnectionClosed => {
error!("ConnectionClosed");
self.reconnect();
}
Error::AlreadyClosed => {
error!("Impossible to happen, fix the bug in the code");
panic!("Impossible to happen, fix the bug in the code");
}
Error::Io(io_err) => {
if io_err.kind() == std::io::ErrorKind::WouldBlock {
info!("Sending ping");
let ping_msg = Message::Text(
self.ping_interval_and_msg.unwrap().1.to_string(),
);
if let Err(err) =
self.ws_stream.lock().unwrap().write_message(ping_msg)
{
error!("{}", err);
}
} else {
let err_msg = io_err.to_string();
if err_msg.contains("connection closed via error") {
warn!(
"I/O error thrown from read_message(): {}, {:?}",
io_err,
io_err.kind()
);
self.reconnect();
} else {
error!(
"I/O error thrown from read_message(): {}, {:?}",
io_err,
io_err.kind()
);
}
}
}
Error::Protocol(protocol_err) => {
if protocol_err.contains("Connection reset without closing handshake") {
error!("ResetWithoutClosingHandshake");
self.reconnect();
} else {
error!(
"Protocol error thrown from read_message(): {}",
protocol_err
);
}
}
_ => {
error!("Error thrown from read_message(): {}", err);
panic!("Error thrown from read_message(): {}", err);
}
}
false
}
};
if let Some(seconds) = duration {
if now.elapsed() > Duration::from_secs(seconds) && normal {
break;
}
}
}
}
pub fn close(&self) {
self.should_stop.store(true, Ordering::Release);
let ret = self.ws_stream.lock().unwrap().close(None);
if let Err(err) = ret {
error!("{}", err);
}
}
}
macro_rules! define_client {
    ($struct_name:ident, $exchange:ident, $default_url:ident, $channels_to_commands:ident, $on_misc_msg:ident, $ping_interval:expr) => {
        // Implements `WSClient` for `$struct_name`: construction wires a
        // `WSClientInternal` (stored in the `client` field) to the exchange's
        // constants and callbacks; each `subscribe_*` call is routed to the
        // matching per-channel-type trait impl on `$struct_name`, and the
        // generic lifecycle calls go straight to the inner client.
        impl<'a> WSClient<'a> for $struct_name<'a> {
            fn new(
                on_msg: Arc<Mutex<dyn FnMut(String) + 'a + Send>>,
                url: Option<&str>,
            ) -> $struct_name<'a> {
                // No explicit endpoint means "use the exchange's default URL".
                let endpoint = url.unwrap_or($default_url);
                let client = WSClientInternal::new(
                    $exchange,
                    endpoint,
                    on_msg,
                    $on_misc_msg,
                    $channels_to_commands,
                    $ping_interval,
                );
                $struct_name { client }
            }

            // --- channel-type specific subscriptions ---

            fn subscribe_trade(&self, channels: &[String]) {
                <$struct_name as Trade>::subscribe_trade(self, channels);
            }

            fn subscribe_orderbook(&self, channels: &[String]) {
                <$struct_name as OrderBook>::subscribe_orderbook(self, channels);
            }

            fn subscribe_orderbook_snapshot(&self, channels: &[String]) {
                <$struct_name as OrderBookSnapshot>::subscribe_orderbook_snapshot(self, channels);
            }

            fn subscribe_ticker(&self, channels: &[String]) {
                <$struct_name as Ticker>::subscribe_ticker(self, channels);
            }

            fn subscribe_bbo(&self, channels: &[String]) {
                <$struct_name as BBO>::subscribe_bbo(self, channels);
            }

            fn subscribe_candlestick(&self, pairs: &[String], interval: u32) {
                <$struct_name as Candlestick>::subscribe_candlestick(self, pairs, interval);
            }

            // --- raw channels and lifecycle, delegated to the inner client ---

            fn subscribe(&self, channels: &[String]) {
                self.client.subscribe(channels);
            }

            fn unsubscribe(&self, channels: &[String]) {
                self.client.unsubscribe(channels);
            }

            fn run(&self, duration: Option<u64>) {
                self.client.run(duration);
            }

            fn close(&self) {
                self.client.close();
            }
        }
    };
}