use std::sync::{
Arc,
atomic::{AtomicBool, AtomicU64, Ordering},
};
use ahash::AHashMap;
use nautilus_network::{
RECONNECTED,
websocket::{AuthTracker, WebSocketClient},
};
use serde_json::Value;
use tokio_tungstenite::tungstenite::Message;
use super::{
error::DeriveWsError,
messages::{DeriveWsChannel, DeriveWsFrame, WsSubscribeParams, WsSubscriptionPayload},
};
use crate::http::models::JsonRpcRequest;
/// Commands consumed by [`FeedHandler`] from its command channel.
#[derive(Debug)]
pub(super) enum HandlerCommand {
/// Hands the handler the WebSocket client it stores and uses for outgoing
/// traffic (requests, pongs, disconnect).
SetClient(WebSocketClient),
/// Asks the handler to send a JSON-RPC request built from `method` and `params`.
Request {
/// JSON-RPC method name.
method: &'static str,
/// JSON-RPC parameters, already converted to a JSON value.
params: Value,
/// One-shot channel on which the handler reports the request's outcome:
/// the result value, or the error that prevented or failed the call.
response_tx: tokio::sync::oneshot::Sender<Result<Value, DeriveWsError>>,
},
/// Disconnects the stored client (if any), raises the handler's signal flag
/// and makes `FeedHandler::next` return `None`.
Disconnect,
}
/// Messages surfaced to the consumer of the Derive WebSocket feed.
#[derive(Debug, Clone)]
pub enum DeriveWsMessage {
/// Authentication notification.
/// NOTE(review): never constructed by the handler code visible in this file;
/// presumably emitted by another component — confirm.
Authenticated,
/// The reconnect sentinel text frame was received from the raw message channel.
Reconnected,
/// A subscription frame was received; carries its parsed payload.
Subscription(WsSubscriptionPayload),
}
/// Multiplexes handler commands and raw WebSocket frames, correlating JSON-RPC
/// responses with the requests that are still awaiting them.
pub(super) struct FeedHandler {
/// Flag set to `true` when a `Disconnect` command is processed.
signal: Arc<AtomicBool>,
/// Client used for outgoing traffic; `None` until a `SetClient` command arrives.
client: Option<WebSocketClient>,
/// Incoming commands for the handler.
cmd_rx: tokio::sync::mpsc::UnboundedReceiver<HandlerCommand>,
/// Incoming raw WebSocket messages.
raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
/// Shared counter from which each outgoing request takes its id.
next_id: Arc<AtomicU64>,
/// Response senders for requests that were sent but not yet answered, keyed by request id.
pending: AHashMap<u64, tokio::sync::oneshot::Sender<Result<Value, DeriveWsError>>>,
/// Authentication tracker; invalidated when the reconnect sentinel is received.
auth_tracker: AuthTracker,
}
impl FeedHandler {
pub(super) fn new(
signal: Arc<AtomicBool>,
cmd_rx: tokio::sync::mpsc::UnboundedReceiver<HandlerCommand>,
raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
next_id: Arc<AtomicU64>,
auth_tracker: AuthTracker,
) -> Self {
Self {
signal,
client: None,
cmd_rx,
raw_rx,
next_id,
pending: AHashMap::new(),
auth_tracker,
}
}
pub(super) async fn next(&mut self) -> Option<DeriveWsMessage> {
loop {
tokio::select! {
Some(cmd) = self.cmd_rx.recv() => {
match cmd {
HandlerCommand::SetClient(client) => {
log::debug!("Setting WebSocket client in Derive handler");
self.client = Some(client);
}
HandlerCommand::Request { method, params, response_tx } => {
self.dispatch_request(method, params, response_tx).await;
}
HandlerCommand::Disconnect => {
log::debug!("Derive handler received disconnect command");
if let Some(ref client) = self.client {
client.disconnect().await;
}
self.signal.store(true, Ordering::SeqCst);
return None;
}
}
}
Some(raw) = self.raw_rx.recv() => {
match raw {
Message::Text(text) => {
if text.as_str() == RECONNECTED {
log::info!("Derive WebSocket reconnected sentinel received");
self.auth_tracker.invalidate();
self.fail_pending("WebSocket reconnected before response was received");
return Some(DeriveWsMessage::Reconnected);
}
match DeriveWsFrame::parse(&text) {
Ok(DeriveWsFrame::Response { id, result, error }) => {
if let Some(sender) = self.pending.remove(&id) {
let outcome = match (result, error) {
(_, Some(err)) => Err(DeriveWsError::JsonRpc {
code: err.code,
message: err.message,
data: err.data,
}),
(Some(value), None) => Ok(value),
(None, None) => Ok(Value::Null),
};
let _ = sender.send(outcome);
} else {
log::debug!(
"Derive WebSocket response with unknown id={id} dropped",
);
}
}
Ok(DeriveWsFrame::Subscription(payload)) => {
return Some(DeriveWsMessage::Subscription(payload));
}
Ok(DeriveWsFrame::Unknown(value)) => {
log::debug!("Derive WebSocket unknown frame: {value}");
}
Err(e) => {
log::error!(
"Derive WebSocket frame parse error: {e}, text: {text}",
);
}
}
}
Message::Ping(data) => {
if let Some(ref client) = self.client
&& let Err(e) = client.send_pong(data.to_vec()).await {
log::error!("Derive WebSocket send_pong failed: {e}");
}
}
Message::Close(_) => {
log::info!("Derive WebSocket close frame received");
return None;
}
_ => {}
}
}
else => {
log::debug!("Derive handler shutting down: channels closed");
return None;
}
}
}
}
async fn dispatch_request(
&mut self,
method: &'static str,
params: Value,
response_tx: tokio::sync::oneshot::Sender<Result<Value, DeriveWsError>>,
) {
let Some(ref client) = self.client else {
let _ = response_tx.send(Err(DeriveWsError::NotConnected));
return;
};
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let request = JsonRpcRequest::new(id, method, params);
let payload = match serde_json::to_string(&request) {
Ok(p) => p,
Err(e) => {
let _ = response_tx.send(Err(DeriveWsError::Serde(e)));
return;
}
};
self.pending.insert(id, response_tx);
log::debug!("Derive WebSocket sending `{method}` id={id}");
if let Err(e) = client.send_text(payload, None).await
&& let Some(sender) = self.pending.remove(&id)
{
let _ = sender.send(Err(DeriveWsError::transport(e.to_string())));
}
}
fn fail_pending(&mut self, reason: &str) {
if self.pending.is_empty() {
return;
}
log::debug!(
"Failing {} pending Derive WebSocket request(s): {reason}",
self.pending.len(),
);
for (_, sender) in self.pending.drain() {
let _ = sender.send(Err(DeriveWsError::transport(reason.to_string())));
}
}
}
/// Wraps a single channel into the parameters of a subscribe request.
#[must_use]
pub(super) fn subscribe_params(channel: DeriveWsChannel) -> WsSubscribeParams {
    let channels = vec![channel];
    WsSubscribeParams { channels }
}
/// Builds subscribe parameters for the slim ticker channel of `instrument_name`
/// at the given `interval`.
#[must_use]
pub(super) fn ticker_subscribe_params(instrument_name: &str, interval: &str) -> WsSubscribeParams {
    let ticker_channel = DeriveWsChannel::ticker_slim(instrument_name, interval);
    subscribe_params(ticker_channel)
}
/// Builds subscribe parameters for the order book channel of `instrument_name`
/// with the given `group` and `depth`.
#[must_use]
pub(super) fn orderbook_subscribe_params(
    instrument_name: &str,
    group: &str,
    depth: &str,
) -> WsSubscribeParams {
    let book_channel = DeriveWsChannel::orderbook(instrument_name, group, depth);
    subscribe_params(book_channel)
}
/// Builds subscribe parameters for the trades channel identified by
/// `instrument_type` and `currency`.
#[must_use]
pub(super) fn trades_subscribe_params(instrument_type: &str, currency: &str) -> WsSubscribeParams {
    let trades_channel = DeriveWsChannel::trades(instrument_type, currency);
    subscribe_params(trades_channel)
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// The generic builder wraps exactly the channel it was given.
    #[rstest]
    fn test_subscribe_params_carries_single_channel() {
        let expected = vec![DeriveWsChannel::ticker_slim("ETH-PERP", "1000")];
        let actual = subscribe_params(DeriveWsChannel::ticker_slim("ETH-PERP", "1000"));
        assert_eq!(actual.channels, expected);
    }

    /// The ticker builder yields the slim ticker channel for the instrument and interval.
    #[rstest]
    fn test_ticker_subscribe_params_formats_topic() {
        let expected = vec![DeriveWsChannel::ticker_slim("ETH-PERP", "1000")];
        let actual = ticker_subscribe_params("ETH-PERP", "1000");
        assert_eq!(actual.channels, expected);
    }

    /// The order book builder yields the order book channel for the instrument, group and depth.
    #[rstest]
    fn test_orderbook_subscribe_params_formats_topic() {
        let expected = vec![DeriveWsChannel::orderbook("ETH-PERP", "1", "10")];
        let actual = orderbook_subscribe_params("ETH-PERP", "1", "10");
        assert_eq!(actual.channels, expected);
    }

    /// The trades builder yields the trades channel for the instrument type and currency.
    #[rstest]
    fn test_trades_subscribe_params_formats_topic() {
        let expected = vec![DeriveWsChannel::trades("perp", "ETH")];
        let actual = trades_subscribe_params("perp", "ETH");
        assert_eq!(actual.channels, expected);
    }

    /// A request dispatched before any client is attached resolves to `NotConnected`.
    #[rstest]
    #[tokio::test]
    async fn test_dispatch_request_without_client_returns_not_connected() {
        // The sending halves are bound (but unused) so both channels stay open
        // for the lifetime of the test.
        let (_cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        let (_raw_tx, raw_rx) = tokio::sync::mpsc::unbounded_channel();
        let mut handler = FeedHandler::new(
            Arc::new(AtomicBool::new(false)),
            cmd_rx,
            raw_rx,
            Arc::new(AtomicU64::new(1)),
            AuthTracker::new(),
        );

        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
        let empty_subscription = WsSubscribeParams {
            channels: Vec::new(),
        };
        let params = serde_json::to_value(empty_subscription).unwrap();

        handler
            .dispatch_request("public/login", params, response_tx)
            .await;

        let outcome = response_rx.await.expect("oneshot resolved");
        assert!(
            matches!(outcome, Err(DeriveWsError::NotConnected)),
            "expected NotConnected, was {outcome:?}"
        );
    }
}