use std::time::Duration;
use futures_util::{SinkExt, Stream, StreamExt};
use serde::Serialize;
use serde_json::Value;
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::{client::IntoClientRequest, Message};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use tracing::{debug, warn};
use url::Url;
use crate::client::Client;
use crate::error::{Error, Result};
/// Concrete socket type used by [`WsClient`]: the stream returned by
/// `tokio_tungstenite::connect_async` over a TCP connection that may or may not
/// be wrapped in TLS.
pub type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// An open WebSocket connection to the gateway.
///
/// Built by `WsClient::connect`, which attaches the session cookie obtained
/// from the REST client during the handshake.
#[derive(Debug)]
pub struct WsClient {
/// The connected socket; frames are read from and written to it directly.
stream: WsStream,
}
/// Write half of a [`WsStream`], as returned by `WsClient::split`.
pub type WsSink = futures_util::stream::SplitSink<WsStream, Message>;
/// Read half of a [`WsStream`], as returned by `WsClient::split`.
pub type WsRecv = futures_util::stream::SplitStream<WsStream>;
/// One inbound WebSocket frame, classified by its JSON `topic` field.
///
/// Produced by this module's `classify` function from the frame text.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum WsMessage {
/// An empty frame or a bare `{}`.
Heartbeat,
/// A frame whose `topic` is `"system"`.
System(Value),
/// A frame whose `topic` is `smd+<conid>` with a numeric conid.
MarketData {
/// Contract id parsed from the `smd+` topic suffix.
conid: i64,
/// The whole decoded JSON frame (the `topic` key included).
payload: Value,
},
/// A frame whose `topic` is `sor`, `ortd`, or `ord`.
Order(Value),
/// A frame whose `topic` is `spl`, `pnl`, `ssd`, or `ssl`.
Pnl(Value),
/// Valid JSON whose `topic` is missing, not a string, unrecognised, or an
/// `smd+` topic with a non-numeric conid.
Other(Value),
/// Text that could not be parsed as JSON.
Malformed {
/// The original frame text.
text: String,
/// The `serde_json` parse error rendered as a string.
error: String,
},
}
impl WsMessage {
#[must_use]
pub fn topic(&self) -> &'static str {
match self {
Self::Heartbeat => "heartbeat",
Self::System(_) => "system",
Self::MarketData { .. } => "market_data",
Self::Order(_) => "order",
Self::Pnl(_) => "pnl",
Self::Other(_) => "other",
Self::Malformed { .. } => "malformed",
}
}
#[must_use]
pub fn as_value(&self) -> Option<&Value> {
match self {
Self::System(v) | Self::Order(v) | Self::Pnl(v) | Self::Other(v) => Some(v),
Self::MarketData { payload, .. } => Some(payload),
Self::Heartbeat | Self::Malformed { .. } => None,
}
}
}
/// Handle for an active stream subscription, returned by the `subscribe_*`
/// methods on `WsClient`.
///
/// It remembers the frame that undoes the subscription; send it with
/// `Subscription::cancel`.
#[derive(Debug, Clone)]
pub struct Subscription {
/// Text frame that cancels the subscription (e.g. `umd+<conid>+{}`).
cancel_payload: String,
/// Human-readable label such as `market_data:<conid>`, `orders`, or `pnl`.
pub name: String,
}
impl Subscription {
    /// Sends this subscription's cancel frame over `ws`, consuming the handle.
    ///
    /// # Errors
    ///
    /// Returns whatever `WsClient::send_text` returns when the frame cannot be sent.
    pub async fn cancel(self, ws: &mut WsClient) -> Result<()> {
        let Self { cancel_payload, .. } = self;
        ws.send_text(cancel_payload).await
    }

    /// The exact text frame that `cancel` will send (for example `umd+265598+{}`).
    #[must_use]
    pub fn cancel_payload(&self) -> &str {
        self.cancel_payload.as_str()
    }
}
/// Ordered list of market-data field codes requested by an `smd` subscription.
///
/// Codes are stored as strings, exactly as they are serialised into the
/// `{"fields":[...]}` body of the subscribe frame.
#[derive(Debug, Clone)]
pub struct MarketDataFields(Vec<String>);

impl MarketDataFields {
    /// Level-1 preset: codes `31`, `84`, `86`, `85`, `88`, `87`, in that order.
    ///
    /// NOTE(review): presumably last / bid / ask / ask size / bid size / volume;
    /// confirm against the gateway's field-code table.
    #[must_use]
    pub fn default_l1() -> Self {
        const LEVEL_ONE: [&str; 6] = ["31", "84", "86", "85", "88", "87"];
        Self::from_codes(LEVEL_ONE)
    }

    /// Builds a field list from any iterable of string-like codes, keeping their order.
    pub fn from_codes<I, S>(codes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut owned = Vec::new();
        owned.extend(codes.into_iter().map(Into::<String>::into));
        Self(owned)
    }

    /// Borrows the codes in request order.
    #[must_use]
    pub fn as_slice(&self) -> &[String] {
        self.0.as_slice()
    }
}

impl<S> FromIterator<S> for MarketDataFields
where
    S: Into<String>,
{
    /// Collects codes in iteration order; equivalent to `MarketDataFields::from_codes`.
    fn from_iter<I: IntoIterator<Item = S>>(items: I) -> Self {
        Self::from_codes(items)
    }
}
impl WsClient {
#[tracing::instrument(skip(client), level = "debug")]
pub async fn connect(client: &Client) -> Result<Self> {
let tickle = client.tickle().await?;
let session = tickle.session.ok_or(Error::NoSession)?;
let ws_url = ws_url_from_base(client.base_url())?;
let cookie = format!(r#"api={{"session":"{session}"}}"#);
debug!(%ws_url, "bezant: opening websocket");
let mut request = ws_url.as_str().into_client_request().map_err(|source| {
Error::WsHandshake {
url: ws_url.to_string(),
source,
}
})?;
request.headers_mut().insert(
"Cookie",
cookie.parse().map_err(|source| Error::Header {
name: "cookie",
source,
})?,
);
request.headers_mut().insert(
"User-Agent",
format!("bezant/{}", env!("CARGO_PKG_VERSION"))
.parse()
.map_err(|source| Error::Header {
name: "user-agent",
source,
})?,
);
let (stream, _) =
tokio_tungstenite::connect_async(request)
.await
.map_err(|source| Error::WsHandshake {
url: ws_url.to_string(),
source,
})?;
Ok(Self { stream })
}
#[tracing::instrument(skip(self, fields), fields(conid = conid), level = "debug")]
pub async fn subscribe_market_data(
&mut self,
conid: i64,
fields: &MarketDataFields,
) -> Result<Subscription> {
#[derive(Serialize)]
struct Body<'a> {
fields: &'a [String],
}
let body = Body {
fields: fields.as_slice(),
};
let payload = format!(
"smd+{conid}+{}",
serde_json::to_string(&body)
.map_err(|e| Error::WsProtocol(format!("serialise fields: {e}")))?
);
self.send_text(payload).await?;
Ok(Subscription {
cancel_payload: format!("umd+{conid}+{{}}"),
name: format!("market_data:{conid}"),
})
}
pub async fn unsubscribe_market_data(&mut self, conid: i64) -> Result<()> {
self.send_text(format!("umd+{conid}+{{}}")).await
}
pub async fn subscribe_orders(&mut self) -> Result<Subscription> {
self.send_text("sor+{}".into()).await?;
Ok(Subscription {
cancel_payload: "uor+{}".into(),
name: "orders".into(),
})
}
pub async fn subscribe_pnl(&mut self) -> Result<Subscription> {
self.send_text("spl+{}".into()).await?;
Ok(Subscription {
cancel_payload: "upl+{}".into(),
name: "pnl".into(),
})
}
pub async fn send_text(&mut self, payload: String) -> Result<()> {
self.stream
.send(Message::text(payload))
.await
.map_err(|source| Error::WsTransport { source })
}
pub async fn next_message(&mut self) -> Result<Option<WsMessage>> {
while let Some(raw) = self.stream.next().await {
let frame = raw.map_err(|source| Error::WsTransport { source })?;
match frame {
Message::Text(text) => return Ok(Some(classify(text.as_str()))),
Message::Binary(bytes) => {
let s = String::from_utf8_lossy(&bytes).to_string();
return Ok(Some(classify(&s)));
}
Message::Ping(data) => {
if let Err(e) = self.stream.send(Message::Pong(data)).await {
warn!(error = %e, "bezant: pong send failed");
}
}
Message::Pong(_) => {}
Message::Frame(_) => {}
Message::Close(_) => return Ok(None),
}
}
Ok(None)
}
pub fn raw_stream(self) -> impl Stream<Item = Result<WsMessage>> + Unpin {
Box::pin(futures_util::stream::unfold(
self.stream,
|mut s| async move {
loop {
match s.next().await {
None => return None,
Some(Err(source)) => {
return Some((Err(Error::WsTransport { source }), s))
}
Some(Ok(Message::Text(t))) => {
return Some((Ok(classify(t.as_str())), s));
}
Some(Ok(Message::Binary(b))) => {
let text = String::from_utf8_lossy(&b).to_string();
return Some((Ok(classify(&text)), s));
}
Some(Ok(Message::Ping(p))) => {
let _ = s.send(Message::Pong(p)).await;
}
Some(Ok(Message::Close(_))) => return None,
Some(Ok(_)) => {}
}
}
},
))
}
pub fn split(self) -> (WsSink, WsRecv) {
let (sink, stream) = self.stream.split();
(sink, stream)
}
#[must_use]
pub const fn recommended_keepalive() -> Duration {
Duration::from_secs(60)
}
}
fn ws_url_from_base(base: &Url) -> Result<Url> {
let mut ws = base.clone();
match ws.scheme() {
"https" => ws.set_scheme("wss").map_err(|()| Error::WsProtocol(
"failed to upgrade base URL scheme to wss".into(),
))?,
"http" => ws
.set_scheme("ws")
.map_err(|()| Error::WsProtocol("failed to upgrade base URL scheme to ws".into()))?,
s => {
return Err(Error::BadRequest(format!(
"unsupported WebSocket base scheme '{s}' (expected http/https)"
)))
}
}
{
let mut segs = ws.path_segments_mut().map_err(|()| Error::UrlNotABase {
url: base.to_string(),
})?;
segs.push("ws");
}
Ok(ws)
}
/// Turns the text of one gateway frame into a typed [`WsMessage`].
///
/// - `""` and `"{}"` are heartbeats.
/// - Text that is not valid JSON becomes `Malformed`, keeping the original text
///   and the parser's error message.
/// - Otherwise the JSON `topic` string decides the variant:
///   - `smd+<conid>` with a numeric conid is `MarketData`;
///   - `system` is `System`;
///   - `sor`/`ortd`/`ord` is `Order`;
///   - `spl`/`pnl`/`ssd`/`ssl` is `Pnl`;
///   - anything else is `Other`, including a missing or non-string topic and an
///     `smd+` topic with a non-numeric suffix.
fn classify(text: &str) -> WsMessage {
    if matches!(text, "" | "{}") {
        return WsMessage::Heartbeat;
    }

    let value = match serde_json::from_str::<Value>(text) {
        Ok(parsed) => parsed,
        Err(err) => {
            return WsMessage::Malformed {
                text: text.to_owned(),
                error: err.to_string(),
            };
        }
    };

    let topic = value.get("topic").and_then(Value::as_str).unwrap_or("");

    // The contract id is embedded in the market-data topic itself.
    let market_conid = topic
        .strip_prefix("smd+")
        .and_then(|suffix| suffix.parse::<i64>().ok());
    if let Some(conid) = market_conid {
        return WsMessage::MarketData {
            conid,
            payload: value,
        };
    }

    // Pick the wrapping constructor first, then hand it the owned JSON.
    let wrap: fn(Value) -> WsMessage = match topic {
        "system" => WsMessage::System,
        "sor" | "ortd" | "ord" => WsMessage::Order,
        "spl" | "pnl" | "ssd" | "ssl" => WsMessage::Pnl,
        _ => WsMessage::Other,
    };
    wrap(value)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` as a REST base URL and returns the derived WebSocket URL text.
    fn derive_ws(raw: &str) -> String {
        let base = Url::parse(raw).expect("test base URL should parse");
        ws_url_from_base(&base)
            .expect("http(s) base should convert")
            .to_string()
    }

    #[test]
    fn ws_url_flips_https_to_wss_and_appends_ws() {
        assert_eq!(
            derive_ws("https://localhost:5000/v1/api"),
            "wss://localhost:5000/v1/api/ws"
        );
    }

    #[test]
    fn ws_url_flips_http_to_ws() {
        assert_eq!(
            derive_ws("http://localhost:8080/v1/api"),
            "ws://localhost:8080/v1/api/ws"
        );
    }

    #[test]
    fn classify_identifies_market_data_by_topic() {
        let frame = r#"{"topic":"smd+265598","31":"150.25","_updated":1700000000}"#;
        let msg = classify(frame);
        if let WsMessage::MarketData { conid, .. } = msg {
            assert_eq!(conid, 265_598);
        } else {
            panic!("expected MarketData, got {msg:?}");
        }
    }

    #[test]
    fn classify_empty_brace_is_heartbeat() {
        let msg = classify("{}");
        assert!(matches!(msg, WsMessage::Heartbeat), "got {msg:?}");
    }

    #[test]
    fn classify_system_topic() {
        let msg = classify(r#"{"topic":"system","msg":"ready"}"#);
        assert!(matches!(msg, WsMessage::System(_)), "got {msg:?}");
    }

    #[test]
    fn classify_malformed_text() {
        let msg = classify("not-json");
        assert!(matches!(msg, WsMessage::Malformed { .. }), "got {msg:?}");
    }

    #[test]
    fn ws_message_topic_is_static_label() {
        let empty = || serde_json::json!({});
        let cases = [
            (WsMessage::Heartbeat, "heartbeat"),
            (
                WsMessage::MarketData {
                    conid: 1,
                    payload: empty(),
                },
                "market_data",
            ),
            (WsMessage::Order(empty()), "order"),
            (WsMessage::Pnl(empty()), "pnl"),
            (WsMessage::System(empty()), "system"),
            (WsMessage::Other(empty()), "other"),
            (
                WsMessage::Malformed {
                    text: "x".into(),
                    error: "y".into(),
                },
                "malformed",
            ),
        ];
        for (msg, expected) in cases {
            assert_eq!(msg.topic(), expected);
        }
    }

    #[test]
    fn ws_message_as_value_returns_payload_for_data_variants() {
        let payload = serde_json::json!({"hello": "world"});
        assert_eq!(WsMessage::Order(payload.clone()).as_value(), Some(&payload));

        let market = WsMessage::MarketData {
            conid: 1,
            payload: payload.clone(),
        };
        assert_eq!(market.as_value(), Some(&payload));

        assert_eq!(WsMessage::Heartbeat.as_value(), None);

        let broken = WsMessage::Malformed {
            text: "x".into(),
            error: "y".into(),
        };
        assert_eq!(broken.as_value(), None);
    }

    #[test]
    fn subscription_cancel_payload_round_trips_topic() {
        let sub = Subscription {
            cancel_payload: String::from("umd+265598+{}"),
            name: String::from("market_data:265598"),
        };
        assert_eq!(sub.cancel_payload(), "umd+265598+{}");
        assert_eq!(sub.name, "market_data:265598");
    }
}