use std::ops::Deref;
use bon::{Builder, bon};
use futures::{FutureExt, SinkExt, StreamExt, select};
use futures_timer::Delay;
use tracing::{debug, warn};
use web_time::Duration;
use super::models::{ServerMessage, TaggedMessage};
use super::topics::Topic;
use crate::errors::WSErrors;
use crate::types::{ClientMessage, OrderParams, RequestId};
use crate::{Client, SDKResult};
/// Default number of seconds to wait for the server's `connected` status message.
///
/// Used as the builder default for `WebsocketConfig::connection_timeout` and as the fallback
/// when the configured timeout cannot be converted for the timer.
const DEFAULT_CONNECTION_TIMEOUT_SECS: u64 = 10;
/// An open websocket connection to the exchange.
///
/// Handles produced by [`Client::connect_ws`] have already received the server's
/// `connected` status message before being returned to the caller.
pub struct WebsocketHandle {
    /// The underlying websocket, used both to send client messages and to receive frames.
    socket: reqwest_websocket::WebSocket,
}
impl WebsocketHandle {
    /// Opens a websocket connection to `ws_url` and waits for the server's `connected` status.
    ///
    /// The HTTP upgrade is performed with `ws_client`. Once the websocket is established, the
    /// first server message must be the `connected` status, which is awaited for at most
    /// `timeout`. Note that `timeout` does not bound the HTTP upgrade itself; that step is
    /// subject only to whatever timeouts `ws_client` was configured with.
    ///
    /// # Errors
    ///
    /// Returns an error if the upgrade request fails, if the upgraded response cannot be turned
    /// into a websocket, or if the `connected` handshake fails or times out.
    pub(crate) async fn connect(
        ws_client: &reqwest::Client,
        ws_url: &str,
        timeout: web_time::Duration,
    ) -> SDKResult<Self, WSErrors> {
        use reqwest_websocket::Upgrade;

        // `reqwest::Client::get` only borrows the client, so no clone is required here.
        let response: reqwest_websocket::UpgradeResponse =
            ws_client.get(ws_url).upgrade().send().await?;
        let websocket = response.into_websocket().await?;

        let mut handle = Self { socket: websocket };
        handle.wait_for_connected(timeout).await?;
        Ok(handle)
    }
}
/// Options controlling how [`Client::connect_ws`] establishes a websocket connection.
#[derive(Builder, Clone, Debug)]
pub struct WebsocketConfig {
    /// Maximum time to wait for the server's `connected` status message once the websocket
    /// upgrade has completed. Defaults to 10 seconds.
    #[builder(default = Duration::from_secs(DEFAULT_CONNECTION_TIMEOUT_SECS))]
    pub connection_timeout: Duration,
}
impl Default for WebsocketConfig {
fn default() -> Self {
Self::builder().build()
}
}
/// Gives read-only access to the underlying [`reqwest_websocket::WebSocket`], so its `&self`
/// methods can be called directly on a handle.
impl Deref for WebsocketHandle {
    type Target = reqwest_websocket::WebSocket;

    fn deref(&self) -> &reqwest_websocket::WebSocket {
        let Self { socket } = self;
        socket
    }
}
#[bon]
impl Client {
    /// Connects to this client's websocket endpoint.
    ///
    /// When `config` is omitted, [`WebsocketConfig::default`] is used. The returned handle has
    /// already received the server's `connected` status message.
    ///
    /// # Errors
    ///
    /// Propagates any failure from the websocket upgrade or the `connected` handshake,
    /// including a handshake timeout.
    #[builder]
    pub async fn connect_ws(
        &self,
        config: Option<WebsocketConfig>,
    ) -> SDKResult<WebsocketHandle, WSErrors> {
        let handshake_timeout = config.unwrap_or_default().connection_timeout;
        let url = self.ws_url();
        WebsocketHandle::connect(&self.ws_client, url, handshake_timeout).await
    }
}
impl WebsocketHandle {
    /// Waits for the server's initial `connected` status message.
    ///
    /// The first [`ServerMessage`] returned by [`WebsocketHandle::recv`] must be a tagged
    /// `Status` message whose `status` field equals `"connected"`. Any other message fails the
    /// handshake.
    ///
    /// # Errors
    ///
    /// - [`WSErrors::WsConnectionTimeout`] if no message arrives within `timeout`.
    /// - [`WSErrors::WsHandshakeFailed`] (carrying the `Debug` form of the message) if the first
    ///   message is not the `connected` status.
    /// - Any error returned by [`WebsocketHandle::recv`].
    async fn wait_for_connected(&mut self, timeout: Duration) -> SDKResult<(), WSErrors> {
        // `futures_timer::Delay` takes a `std::time::Duration`. Where `web_time::Duration` is
        // already that type, the conversion is the identity, hence the clippy allow.
        // NOTE(review): the fallback only matters if this conversion can fail on some target;
        // confirm whether it is still needed.
        #[allow(clippy::useless_conversion)]
        let timeout = Delay::new(
            timeout
                .try_into()
                .unwrap_or(std::time::Duration::from_secs(DEFAULT_CONNECTION_TIMEOUT_SECS)),
        );
        debug!("Waiting for connected message from websocket.");
        select! {
            result = self.recv().fuse() => {
                match result? {
                    ServerMessage::Tagged(TaggedMessage::Status(status))
                        if status.status == "connected" =>
                    {
                        debug!("Successfully got connected message, continuing");
                        Ok(())
                    }
                    other => Err(WSErrors::WsHandshakeFailed(format!("{other:?}"))),
                }
            }
            _ = timeout.fuse() => {
                Err(WSErrors::WsConnectionTimeout)
            }
        }
    }

    /// Serializes `msg` to JSON and sends it as a single text frame.
    ///
    /// This only writes the request. Any reply from the server arrives through
    /// [`WebsocketHandle::recv`].
    ///
    /// # Errors
    ///
    /// Returns an error if `msg` cannot be serialized or if the socket write fails.
    pub async fn send(&mut self, msg: ClientMessage) -> SDKResult<(), WSErrors> {
        let string_msg = serde_json::to_string(&msg)?;
        self.socket.send(reqwest_websocket::Message::Text(string_msg)).await?;
        Ok(())
    }

    /// Receives the next application message from the server.
    ///
    /// Text and binary frames are decoded as JSON into a [`ServerMessage`]. A frame that fails
    /// to decode is not treated as an error. Instead, the failure is logged and the frame is
    /// returned as [`ServerMessage::Unknown`], carrying the parse error text and the raw
    /// payload. Binary payloads are converted to UTF-8 lossily. All other frame kinds except
    /// `Close` are skipped.
    ///
    /// # Errors
    ///
    /// - [`WSErrors::WsClosed`] when the server sends a close frame.
    /// - [`WSErrors::WsStreamEnded`] when the underlying stream ends.
    /// - Any transport error yielded by the underlying socket.
    pub async fn recv(&mut self) -> SDKResult<ServerMessage, WSErrors> {
        while let Some(msg) = self.socket.next().await {
            match msg? {
                reqwest_websocket::Message::Text(text) => {
                    return Ok(match serde_json::from_str::<ServerMessage>(&text) {
                        Ok(v) => v,
                        Err(e) => Self::unknown_message(e, text),
                    });
                }
                reqwest_websocket::Message::Binary(data) => {
                    // Build the lossy UTF-8 copy only when decoding actually failed, rather
                    // than copying every successfully-decoded binary frame.
                    return Ok(match serde_json::from_slice::<ServerMessage>(&data) {
                        Ok(v) => v,
                        Err(e) => {
                            Self::unknown_message(e, String::from_utf8_lossy(&data).into_owned())
                        }
                    });
                }
                reqwest_websocket::Message::Close { code, reason } => {
                    return Err(WSErrors::WsClosed { code, reason });
                }
                _ => continue,
            }
        }
        Err(WSErrors::WsStreamEnded)
    }

    /// Logs a JSON decode failure and wraps it, with the raw payload, as
    /// [`ServerMessage::Unknown`].
    fn unknown_message(e: serde_json::Error, raw: String) -> ServerMessage {
        warn!(?e, "Failed to parse ServerMessage, returning Unknown");
        ServerMessage::Unknown(e.to_string(), raw)
    }

    /// Sends a [`ClientMessage::Subscribe`] request for `topics`.
    ///
    /// Each topic is rendered with `to_string()`. `id` is forwarded unchanged as the request
    /// id.
    ///
    /// # Errors
    ///
    /// See [`WebsocketHandle::send`].
    pub async fn subscribe(
        &mut self,
        topics: impl IntoIterator<Item = Topic>,
        id: Option<RequestId>,
    ) -> SDKResult<(), WSErrors> {
        self.send(ClientMessage::Subscribe {
            id,
            params: topics.into_iter().map(|t| t.to_string()).collect(),
        })
        .await
    }

    /// Sends a [`ClientMessage::ListSubscriptions`] request.
    ///
    /// # Errors
    ///
    /// See [`WebsocketHandle::send`].
    pub async fn list_subscriptions(&mut self, id: Option<RequestId>) -> SDKResult<(), WSErrors> {
        self.send(ClientMessage::ListSubscriptions { id }).await
    }

    /// Sends a [`ClientMessage::OrderPlace`] request carrying the already-encoded transaction
    /// `tx`.
    ///
    /// NOTE(review): `tx` is presumably the base64-encoded signed transaction, as produced by
    /// [`WebsocketHandle::place_order`]; confirm the accepted encoding.
    ///
    /// # Errors
    ///
    /// See [`WebsocketHandle::send`].
    pub async fn order_place(
        &mut self,
        tx: impl Into<String>,
        id: Option<RequestId>,
    ) -> SDKResult<(), WSErrors> {
        self.send(ClientMessage::OrderPlace { id, params: OrderParams { tx: tx.into() } }).await
    }

    /// Sends a [`ClientMessage::OrderCancel`] request carrying the already-encoded transaction
    /// `tx`.
    ///
    /// NOTE(review): `tx` is presumably the base64-encoded signed transaction, as produced by
    /// [`WebsocketHandle::cancel_order`]; confirm the accepted encoding.
    ///
    /// # Errors
    ///
    /// See [`WebsocketHandle::send`].
    pub async fn order_cancel(
        &mut self,
        tx: impl Into<String>,
        id: Option<RequestId>,
    ) -> SDKResult<(), WSErrors> {
        self.send(ClientMessage::OrderCancel { id, params: OrderParams { tx: tx.into() } }).await
    }

    /// Base64-encodes `signed` and submits it via [`WebsocketHandle::order_place`].
    ///
    /// # Errors
    ///
    /// Returns [`WSErrors::WsError`] if encoding fails. Otherwise see
    /// [`WebsocketHandle::send`].
    pub async fn place_order(
        &mut self,
        signed: &bullet_exchange_interface::transaction::Transaction,
        id: Option<RequestId>,
    ) -> SDKResult<(), WSErrors> {
        let base64 =
            crate::Transaction::to_base64(signed).map_err(|e| WSErrors::WsError(e.to_string()))?;
        self.order_place(base64, id).await
    }

    /// Base64-encodes `signed` and submits it via [`WebsocketHandle::order_cancel`].
    ///
    /// # Errors
    ///
    /// Returns [`WSErrors::WsError`] if encoding fails. Otherwise see
    /// [`WebsocketHandle::send`].
    pub async fn cancel_order(
        &mut self,
        signed: &bullet_exchange_interface::transaction::Transaction,
        id: Option<RequestId>,
    ) -> SDKResult<(), WSErrors> {
        let base64 =
            crate::Transaction::to_base64(signed).map_err(|e| WSErrors::WsError(e.to_string()))?;
        self.order_cancel(base64, id).await
    }
}