use std::{
fmt::Debug,
sync::{
Arc,
atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
},
time::Duration,
};
use alloy::signers::local::PrivateKeySigner;
use arc_swap::ArcSwap;
use dashmap::DashMap;
use nautilus_common::live::get_runtime;
use nautilus_core::UUID4;
use nautilus_network::{
mode::ConnectionMode,
ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
websocket::{
AuthTracker, TransportBackend, WebSocketClient, WebSocketConfig, channel_message_handler,
},
};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use ustr::Ustr;
use super::{
error::{DeriveWsError, Result},
handler::{
DeriveWsMessage, FeedHandler, HandlerCommand, orderbook_subscribe_params,
ticker_subscribe_params, trades_subscribe_params,
},
messages::{
DeriveWsChannel, WsLoginParams, WsLoginResult, WsSubscribeParams, WsSubscribeResult,
WsUnsubscribeParams, WsUnsubscribeResult, methods, orderbook_channel, rate_limit_key_for,
ticker_channel, trades_channel,
},
};
use crate::{
common::{
consts::{
RECONNECT_BACKOFF_FACTOR, RECONNECT_BASE_BACKOFF, RECONNECT_JITTER_MS,
RECONNECT_MAX_BACKOFF, RECONNECT_TIMEOUT, WS_HEARTBEAT_SECS, WS_REQUEST_TIMEOUT,
},
enums::DeriveEnvironment,
rate_limit::{self, DERIVE_MATCHING_RATE_KEY},
urls,
},
http::{
models::{
DeriveEmptyResult, DeriveOpenOrdersResult, DeriveOrder, DeriveOrderResult,
DeriveReplaceResult,
},
query::{
DeriveCancelAllParams, DeriveCancelParams, DeriveCancelTriggerOrderParams,
DeriveGetTriggerOrdersParams, DeriveOrderParams, DeriveReplaceParams,
DeriveTriggerOrderParams,
},
},
signing::auth::build_ws_login,
};
/// Credentials used to log in on the Derive WebSocket.
///
/// `Debug` is implemented by hand for this type so the signer is never
/// printed.
#[derive(Clone)]
pub struct DeriveWsCredentials {
/// Wallet address handed to `build_ws_login` when the login request is built.
pub wallet_address: String,
/// Signer handed to `build_ws_login` to produce the login signature.
pub signer: PrivateKeySigner,
}
impl DeriveWsCredentials {
    /// Builds credentials from a wallet address and a session key string.
    ///
    /// # Errors
    ///
    /// Returns a transport error whose message starts with
    /// `invalid session key:` when `session_key_hex` cannot be parsed into a
    /// [`PrivateKeySigner`].
    pub fn new(wallet_address: impl Into<String>, session_key_hex: &str) -> Result<Self> {
        match session_key_hex.parse::<PrivateKeySigner>() {
            Ok(signer) => Ok(Self {
                wallet_address: wallet_address.into(),
                signer,
            }),
            Err(e) => Err(DeriveWsError::transport(format!(
                "invalid session key: {e}"
            ))),
        }
    }
}
impl Debug for DeriveWsCredentials {
    /// Formats the credentials with the wallet address visible and the signer
    /// replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("DeriveWsCredentials");
        out.field("wallet_address", &self.wallet_address);
        // Never print key material; emit a constant marker instead.
        out.field("signer", &"***redacted***");
        out.finish()
    }
}
/// Rate limiter keyed by [`Ustr`] and driven by [`MonotonicClock`]; every
/// request helper in this file waits on it before enqueueing a request.
type WsRateLimiter = RateLimiter<Ustr, MonotonicClock>;
/// WebSocket client for Derive: owns the connection state, the stream task and
/// the shared pieces handed out to subscription and execution handles.
#[derive(Debug)]
pub struct DeriveWebSocketClient {
/// Endpoint passed to the transport configuration on `connect`.
url: String,
/// Transport backend passed to the transport configuration on `connect`.
transport_backend: TransportBackend,
/// Optional proxy URL passed to the transport configuration on `connect`.
proxy_url: Option<String>,
/// Current connection mode; replaced with the transport client's atomic on
/// `connect` and reset to `Closed` on teardown.
connection_mode: Arc<ArcSwap<AtomicU8>>,
/// Flag handed to the `FeedHandler`; set to `true` for the duration of a
/// teardown (presumably read by the handler as a stop signal — confirm in
/// the handler module).
signal: Arc<AtomicBool>,
/// Tracks whether the login request has succeeded.
auth_tracker: AuthTracker,
/// When present, `connect` and the reconnect path perform a login.
credentials: Option<DeriveWsCredentials>,
/// Counter starting at 1 that is handed to the `FeedHandler` (presumably the
/// source of request ids — confirm in the handler module).
next_id: Arc<AtomicU64>,
/// Sender for handler commands; swapped on `connect` and teardown and shared
/// with the handles created from this client.
cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
/// Receiver for messages forwarded by the stream task; `None` before
/// `connect`, after teardown, or once taken by the caller.
out_rx: Option<tokio::sync::mpsc::UnboundedReceiver<DeriveWsMessage>>,
/// Channel names recorded after successful subscribe requests; replayed
/// after a reconnect and cleared on teardown.
subscriptions: Arc<DashMap<String, ()>>,
/// Handle of the spawned stream task, if one is running.
task_handle: Option<tokio::task::JoinHandle<()>>,
/// Time allowed for a response to each request (`WS_REQUEST_TIMEOUT`).
request_timeout: Duration,
/// Identifier regenerated on every `connect` and on every reconnect.
conn_id: Arc<ArcSwap<String>>,
/// Rate limiter shared with every handle created from this client.
rate_limiter: Arc<WsRateLimiter>,
}
/// Cloneable handle for subscribe / unsubscribe requests; it shares the
/// command sender, subscription map and rate limiter of the client that
/// created it.
#[derive(Debug, Clone)]
pub struct DeriveWebSocketSubscriptionHandle {
/// Shared slot holding the current handler command sender.
cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
/// Channel names recorded after successful subscribe requests.
subscriptions: Arc<DashMap<String, ()>>,
/// Time allowed for a response to each request.
request_timeout: Duration,
/// Rate limiter awaited before each request is enqueued.
rate_limiter: Arc<WsRateLimiter>,
}
/// Cloneable handle for order-related requests; it shares the command sender,
/// connection id and rate limiter of the client that created it.
#[derive(Debug, Clone)]
pub struct DeriveWsExecutionHandle {
/// Shared slot holding the current handler command sender.
cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
/// Time allowed for a response to each request.
request_timeout: Duration,
/// Identifier the owning client regenerates on connect and reconnect.
conn_id: Arc<ArcSwap<String>>,
/// Rate limiter awaited before each request is enqueued.
rate_limiter: Arc<WsRateLimiter>,
}
impl DeriveWebSocketClient {
/// Creates a client without credentials.
///
/// When `url` is `None` the endpoint is taken from `urls::ws_url` for the
/// given `environment`.
#[must_use]
pub fn new(
    url: Option<String>,
    environment: DeriveEnvironment,
    transport_backend: TransportBackend,
    proxy_url: Option<String>,
) -> Self {
    let endpoint = match url {
        Some(explicit) => explicit,
        None => urls::ws_url(environment).to_string(),
    };
    Self::build(endpoint, transport_backend, proxy_url, None, None)
}
/// Creates a client that logs in with `credentials` when it connects.
///
/// When `url` is `None` the endpoint is taken from `urls::ws_url` for the
/// given `environment`. `max_matching_requests_per_second` is forwarded to
/// `rate_limit::matching_quota` to build the quota registered under the
/// matching rate key.
#[must_use]
pub fn with_credentials(
    url: Option<String>,
    environment: DeriveEnvironment,
    transport_backend: TransportBackend,
    proxy_url: Option<String>,
    credentials: DeriveWsCredentials,
    max_matching_requests_per_second: Option<u32>,
) -> Self {
    let endpoint = match url {
        Some(explicit) => explicit,
        None => urls::ws_url(environment).to_string(),
    };
    let quota = rate_limit::matching_quota(max_matching_requests_per_second);
    Self::build(
        endpoint,
        transport_backend,
        proxy_url,
        Some(credentials),
        Some(quota),
    )
}
/// Shared constructor behind `new` and `with_credentials`.
///
/// The client starts in the `Closed` connection mode, with no stream task, no
/// outbound receiver, and a placeholder command sender whose receiver has
/// already been dropped.
fn build(
    url: String,
    transport_backend: TransportBackend,
    proxy_url: Option<String>,
    credentials: Option<DeriveWsCredentials>,
    matching_quota: Option<Quota>,
) -> Self {
    let closed_mode = AtomicU8::new(ConnectionMode::Closed as u8);
    let connection_mode = Arc::new(ArcSwap::from_pointee(closed_mode));
    let (placeholder_tx, _) = tokio::sync::mpsc::unbounded_channel();

    // A keyed quota is registered only when a matching quota was supplied.
    let keyed_quotas: Vec<(Ustr, Quota)> = matching_quota
        .map(|quota| (Ustr::from(DERIVE_MATCHING_RATE_KEY), quota))
        .into_iter()
        .collect();
    let limiter = RateLimiter::new_with_quota(
        Some(rate_limit::non_matching_quota()),
        keyed_quotas,
    );

    Self {
        url,
        transport_backend,
        proxy_url,
        connection_mode,
        signal: Arc::new(AtomicBool::new(false)),
        auth_tracker: AuthTracker::new(),
        credentials,
        next_id: Arc::new(AtomicU64::new(1)),
        cmd_tx: Arc::new(tokio::sync::RwLock::new(placeholder_tx)),
        out_rx: None,
        subscriptions: Arc::new(DashMap::new()),
        task_handle: None,
        request_timeout: WS_REQUEST_TIMEOUT,
        conn_id: Arc::new(ArcSwap::from_pointee(UUID4::new().to_string())),
        rate_limiter: Arc::new(limiter),
    }
}
/// Returns the endpoint this client connects to.
#[must_use]
pub fn url(&self) -> &str {
    self.url.as_str()
}
/// Reports whether the auth tracker currently considers the session
/// authenticated.
#[must_use]
pub fn is_authenticated(&self) -> bool {
    let tracker = &self.auth_tracker;
    tracker.is_authenticated()
}
/// Reports whether the stored connection mode equals `ConnectionMode::Active`.
#[must_use]
pub fn is_active(&self) -> bool {
    let active = ConnectionMode::Active as u8;
    let mode = self.connection_mode.load();
    mode.load(Ordering::Relaxed) == active
}
/// Opens the WebSocket, starts the stream task and, when credentials are
/// configured, logs in.
///
/// Returns early (with a warning) when the client is already active, has a
/// running task handle and is either authenticated or has no credentials. Any
/// other leftover task is torn down first.
///
/// # Errors
///
/// Returns a transport error when the underlying connection fails or the
/// `SetClient` command cannot be enqueued, and the login error when the
/// initial login fails (the client is torn down before returning it).
pub async fn connect(&mut self) -> Result<()> {
// Without credentials there is nothing to authenticate, so that case counts as OK.
let auth_ok = self.credentials.is_none() || self.is_authenticated();
if self.is_active() && auth_ok && self.task_handle.is_some() {
log::warn!("Derive WebSocket already connected");
return Ok(());
}
if self.task_handle.is_some() {
log::debug!("Tearing down stale Derive WebSocket state before connect");
self.teardown().await;
}
let (message_handler, raw_rx) = channel_message_handler();
let cfg = WebSocketConfig {
url: self.url.clone(),
headers: vec![],
heartbeat: Some(WS_HEARTBEAT_SECS),
heartbeat_msg: None,
reconnect_timeout_ms: Some(RECONNECT_TIMEOUT.as_millis() as u64),
reconnect_delay_initial_ms: Some(RECONNECT_BASE_BACKOFF.as_millis() as u64),
reconnect_delay_max_ms: Some(RECONNECT_MAX_BACKOFF.as_millis() as u64),
reconnect_backoff_factor: Some(RECONNECT_BACKOFF_FACTOR),
reconnect_jitter_ms: Some(RECONNECT_JITTER_MS),
reconnect_max_attempts: None,
idle_timeout_ms: None,
backend: self.transport_backend,
proxy_url: self.proxy_url.clone(),
};
let client = WebSocketClient::connect(cfg, Some(message_handler), None, None, vec![], None)
.await
.map_err(|e| DeriveWsError::transport(e.to_string()))?;
client.set_auth_tracker(self.auth_tracker.clone(), false);
// Fresh command and outbound channels for this connection; existing handles
// pick up the new sender through the shared `cmd_tx` slot.
let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<DeriveWsMessage>();
*self.cmd_tx.write().await = cmd_tx.clone();
self.out_rx = Some(out_rx);
self.conn_id.store(Arc::new(UUID4::new().to_string()));
// From here on `is_active` reflects the transport client's own mode atomic.
self.connection_mode.store(client.connection_mode_atomic());
log::info!("Derive WebSocket connected: {}", self.url);
if let Err(e) = cmd_tx.send(HandlerCommand::SetClient(client)) {
return Err(DeriveWsError::transport(format!(
"failed to send SetClient command: {e}",
)));
}
// Clones moved into the stream task below.
let signal = Arc::clone(&self.signal);
let auth_tracker = self.auth_tracker.clone();
let next_id = Arc::clone(&self.next_id);
let credentials = self.credentials.clone();
let subscriptions = Arc::clone(&self.subscriptions);
let conn_id = Arc::clone(&self.conn_id);
let cmd_tx_for_loop = cmd_tx.clone();
let rate_limiter = Arc::clone(&self.rate_limiter);
let request_timeout = self.request_timeout;
// Stream task: forwards handler messages to `out_tx` until the handler ends
// or the outer receiver is dropped.
let stream_handle = get_runtime().spawn(async move {
let mut handler =
FeedHandler::new(signal, cmd_rx, raw_rx, next_id, auth_tracker.clone());
loop {
match handler.next().await {
Some(DeriveWsMessage::Reconnected) => {
log::info!("Derive WebSocket re-establishing session after reconnect");
// A reconnect gets a new connection id before the event is forwarded.
conn_id.store(Arc::new(UUID4::new().to_string()));
if out_tx.send(DeriveWsMessage::Reconnected).is_err() {
log::debug!("Derive outer receiver dropped, exiting stream loop");
break;
}
let cmd_tx_async = cmd_tx_for_loop.clone();
let auth_tracker_async = auth_tracker.clone();
let creds_async = credentials.clone();
let subs_async = Arc::clone(&subscriptions);
let rate_limiter_async = Arc::clone(&rate_limiter);
// Re-login and resubscribe on a separate task so this loop keeps
// forwarding messages while those requests wait for responses.
get_runtime().spawn(async move {
if let Some(creds) = creds_async
&& let Err(e) = login_via_handler(
&rate_limiter_async,
&cmd_tx_async,
&auth_tracker_async,
&creds,
request_timeout,
)
.await
{
log::error!("Derive WebSocket re-login failed: {e}");
}
// Snapshot the tracked channels, then resubscribe one request per
// channel; a failure is logged and the remaining channels are still tried.
let channels: Vec<String> =
subs_async.iter().map(|e| e.key().clone()).collect();
for channel in channels {
if let Err(e) = subscribe_via_handler(
&rate_limiter_async,
&cmd_tx_async,
vec![channel.clone()],
request_timeout,
)
.await
{
log::error!(
"Derive WebSocket resubscribe failed for {channel}: {e}",
);
}
}
});
}
Some(msg) => {
if out_tx.send(msg).is_err() {
log::debug!("Derive outer receiver dropped, exiting stream loop");
break;
}
}
None => {
log::debug!("Derive handler task ended");
break;
}
}
}
});
self.task_handle = Some(stream_handle);
// Initial login happens after the stream task is spawned; on failure the
// connection is torn down so the client is left disconnected.
if let Some(creds) = self.credentials.clone()
&& let Err(e) = login_via_handler(
&self.rate_limiter,
&cmd_tx,
&self.auth_tracker,
&creds,
self.request_timeout,
)
.await
{
log::warn!("Derive WebSocket login failed; tearing down transport: {e}");
self.teardown().await;
return Err(e);
}
Ok(())
}
/// Stops the stream task and resets the client to its disconnected state.
///
/// Sets the shared signal flag, enqueues a `Disconnect` command, waits up to
/// two seconds for the stream task and aborts it on timeout. Afterwards the
/// command sender is replaced by a placeholder, the outbound receiver is
/// dropped, the connection mode is reset to `Closed`, the auth tracker is
/// invalidated, tracked subscriptions are cleared and the signal flag is
/// reset.
async fn teardown(&mut self) {
self.signal.store(true, Ordering::Relaxed);
// Best effort: a send failure is only logged.
if let Err(e) = self.cmd_tx.read().await.send(HandlerCommand::Disconnect) {
log::debug!(
"Failed to enqueue Disconnect command (handler may already be shut down): {e}",
);
}
if let Some(handle) = self.task_handle.take() {
// Obtain the abort handle first because `handle` is moved into the select.
let abort_handle = handle.abort_handle();
tokio::select! {
result = handle => match result {
Ok(()) => log::debug!("Derive WebSocket task completed"),
Err(e) if e.is_cancelled() => log::debug!("Derive WebSocket task cancelled"),
Err(e) => log::error!("Derive WebSocket task error: {e:?}"),
},
() = tokio::time::sleep(Duration::from_secs(2)) => {
log::warn!("Timeout waiting for Derive WebSocket task, aborting");
abort_handle.abort();
}
}
}
// The placeholder's receiver is dropped immediately, so sends through
// existing handles fail until the next `connect` installs a new sender.
let (placeholder_tx, _) = tokio::sync::mpsc::unbounded_channel();
*self.cmd_tx.write().await = placeholder_tx;
self.out_rx = None;
self.connection_mode
.store(Arc::new(AtomicU8::new(ConnectionMode::Closed as u8)));
self.auth_tracker.invalidate();
self.subscriptions.clear();
// Reset the flag so a later `connect` starts with it cleared.
self.signal.store(false, Ordering::Relaxed);
}
/// Logs the disconnect and tears the connection down.
///
/// Always returns `Ok(())`.
pub async fn disconnect(&mut self) -> Result<()> {
    log::info!("Disconnecting Derive WebSocket");
    let shutdown = self.teardown();
    shutdown.await;
    Ok(())
}
/// Subscribes to the ticker channel for `instrument_name` at `interval` by
/// delegating to a fresh subscription handle.
pub async fn subscribe_ticker(&self, instrument_name: &str, interval: &str) -> Result<()> {
    let handle = self.subscription_handle();
    handle.subscribe_ticker(instrument_name, interval).await
}
/// Unsubscribes from the ticker channel for `instrument_name` at `interval`
/// by delegating to a fresh subscription handle.
pub async fn unsubscribe_ticker(&self, instrument_name: &str, interval: &str) -> Result<()> {
    let handle = self.subscription_handle();
    handle.unsubscribe_ticker(instrument_name, interval).await
}
/// Subscribes to the order book channel for `instrument_name` with the given
/// `group` and `depth` by delegating to a fresh subscription handle.
pub async fn subscribe_orderbook(
    &self,
    instrument_name: &str,
    group: &str,
    depth: &str,
) -> Result<()> {
    let handle = self.subscription_handle();
    handle
        .subscribe_orderbook(instrument_name, group, depth)
        .await
}
/// Unsubscribes from the order book channel for `instrument_name` with the
/// given `group` and `depth` by delegating to a fresh subscription handle.
pub async fn unsubscribe_orderbook(
    &self,
    instrument_name: &str,
    group: &str,
    depth: &str,
) -> Result<()> {
    let handle = self.subscription_handle();
    handle
        .unsubscribe_orderbook(instrument_name, group, depth)
        .await
}
/// Subscribes to the trades channel for `instrument_type` and `currency` by
/// delegating to a fresh subscription handle.
pub async fn subscribe_trades(&self, instrument_type: &str, currency: &str) -> Result<()> {
    let handle = self.subscription_handle();
    handle.subscribe_trades(instrument_type, currency).await
}
/// Unsubscribes from the trades channel for `instrument_type` and `currency`
/// by delegating to a fresh subscription handle.
pub async fn unsubscribe_trades(&self, instrument_type: &str, currency: &str) -> Result<()> {
    let handle = self.subscription_handle();
    handle.unsubscribe_trades(instrument_type, currency).await
}
/// Subscribes to an arbitrary list of channels by delegating to a fresh
/// subscription handle.
pub async fn subscribe_channels<C>(&self, channels: Vec<C>) -> Result<()>
where
    C: Into<DeriveWsChannel>,
{
    let handle = self.subscription_handle();
    handle.subscribe_channels(channels).await
}
/// Unsubscribes from an arbitrary list of channels by delegating to a fresh
/// subscription handle.
pub async fn unsubscribe_channels<C>(&self, channels: Vec<C>) -> Result<()>
where
    C: Into<DeriveWsChannel>,
{
    let handle = self.subscription_handle();
    handle.unsubscribe_channels(channels).await
}
/// Waits for the next message forwarded by the stream task.
///
/// Returns `None` when this client holds no receiver, or when the receiver
/// itself yields `None`.
pub async fn next_event(&mut self) -> Option<DeriveWsMessage> {
    let receiver = self.out_rx.as_mut()?;
    receiver.recv().await
}
/// Returns the number of channel names currently tracked as subscribed.
#[must_use]
pub fn subscription_count(&self) -> usize {
    let tracked = &self.subscriptions;
    tracked.len()
}
/// Creates a subscription handle that shares this client's command sender
/// slot, subscription map and rate limiter.
#[must_use]
pub fn subscription_handle(&self) -> DeriveWebSocketSubscriptionHandle {
    let cmd_tx = Arc::clone(&self.cmd_tx);
    let subscriptions = Arc::clone(&self.subscriptions);
    let rate_limiter = Arc::clone(&self.rate_limiter);
    let request_timeout = self.request_timeout;
    DeriveWebSocketSubscriptionHandle {
        cmd_tx,
        subscriptions,
        request_timeout,
        rate_limiter,
    }
}
/// Creates an execution handle that shares this client's command sender slot,
/// connection id and rate limiter.
#[must_use]
pub fn execution_handle(&self) -> DeriveWsExecutionHandle {
    let cmd_tx = Arc::clone(&self.cmd_tx);
    let conn_id = Arc::clone(&self.conn_id);
    let rate_limiter = Arc::clone(&self.rate_limiter);
    let request_timeout = self.request_timeout;
    DeriveWsExecutionHandle {
        cmd_tx,
        request_timeout,
        conn_id,
        rate_limiter,
    }
}
/// Moves the outbound message receiver out of the client, leaving `None`
/// behind; later calls to `next_event` then return `None`.
pub fn take_event_receiver(
    &mut self,
) -> Option<tokio::sync::mpsc::UnboundedReceiver<DeriveWsMessage>> {
    std::mem::take(&mut self.out_rx)
}
}
impl DeriveWebSocketSubscriptionHandle {
pub async fn subscribe_ticker(&self, instrument_name: &str, interval: &str) -> Result<()> {
let channel = ticker_channel(instrument_name, interval);
let params = ticker_subscribe_params(instrument_name, interval);
self.send_subscribe(channel, ¶ms).await
}
pub async fn unsubscribe_ticker(&self, instrument_name: &str, interval: &str) -> Result<()> {
let channel = ticker_channel(instrument_name, interval);
self.send_unsubscribe(channel).await
}
pub async fn subscribe_orderbook(
&self,
instrument_name: &str,
group: &str,
depth: &str,
) -> Result<()> {
let channel = orderbook_channel(instrument_name, group, depth);
let params = orderbook_subscribe_params(instrument_name, group, depth);
self.send_subscribe(channel, ¶ms).await
}
pub async fn unsubscribe_orderbook(
&self,
instrument_name: &str,
group: &str,
depth: &str,
) -> Result<()> {
let channel = orderbook_channel(instrument_name, group, depth);
self.send_unsubscribe(channel).await
}
pub async fn subscribe_trades(&self, instrument_type: &str, currency: &str) -> Result<()> {
let channel = trades_channel(instrument_type, currency);
let params = trades_subscribe_params(instrument_type, currency);
self.send_subscribe(channel, ¶ms).await
}
pub async fn unsubscribe_trades(&self, instrument_type: &str, currency: &str) -> Result<()> {
let channel = trades_channel(instrument_type, currency);
self.send_unsubscribe(channel).await
}
pub async fn subscribe_channels<C>(&self, channels: Vec<C>) -> Result<()>
where
C: Into<DeriveWsChannel>,
{
let channels = channels.into_iter().map(Into::into).collect::<Vec<_>>();
if channels.is_empty() {
return Ok(());
}
let topics = channel_topics(&channels);
let params = WsSubscribeParams { channels };
let cmd_tx = self.cmd_tx.read().await.clone();
let _: WsSubscribeResult = send_request(
&self.rate_limiter,
&cmd_tx,
methods::PUBLIC_SUBSCRIBE,
¶ms,
self.request_timeout,
)
.await?;
for channel in topics {
self.subscriptions.insert(channel, ());
}
Ok(())
}
pub async fn unsubscribe_channels<C>(&self, channels: Vec<C>) -> Result<()>
where
C: Into<DeriveWsChannel>,
{
let channels = channels.into_iter().map(Into::into).collect::<Vec<_>>();
if channels.is_empty() {
return Ok(());
}
let topics = channel_topics(&channels);
let params = WsUnsubscribeParams { channels };
let cmd_tx = self.cmd_tx.read().await.clone();
let _: WsUnsubscribeResult = send_request(
&self.rate_limiter,
&cmd_tx,
methods::PUBLIC_UNSUBSCRIBE,
¶ms,
self.request_timeout,
)
.await?;
for channel in topics {
self.subscriptions.remove(&channel);
}
Ok(())
}
async fn send_subscribe(&self, channel: String, params: &WsSubscribeParams) -> Result<()> {
let cmd_tx = self.cmd_tx.read().await.clone();
let _: WsSubscribeResult = send_request(
&self.rate_limiter,
&cmd_tx,
methods::PUBLIC_SUBSCRIBE,
params,
self.request_timeout,
)
.await?;
self.subscriptions.insert(channel, ());
Ok(())
}
async fn send_unsubscribe(&self, channel: String) -> Result<()> {
let params = WsUnsubscribeParams {
channels: vec![DeriveWsChannel::from(channel.clone())],
};
let cmd_tx = self.cmd_tx.read().await.clone();
let _: WsUnsubscribeResult = send_request(
&self.rate_limiter,
&cmd_tx,
methods::PUBLIC_UNSUBSCRIBE,
¶ms,
self.request_timeout,
)
.await?;
self.subscriptions.remove(&channel);
Ok(())
}
}
impl DeriveWsExecutionHandle {
#[must_use]
pub fn conn_id(&self) -> String {
self.conn_id.load_full().as_ref().clone()
}
pub async fn submit_order(&self, params: &DeriveOrderParams) -> Result<DeriveOrder> {
let cmd_tx = self.cmd_tx.read().await.clone();
let result: DeriveOrderResult = send_request_typed(
&self.rate_limiter,
&cmd_tx,
methods::PRIVATE_ORDER,
params,
self.request_timeout,
)
.await?;
Ok(result.order)
}
pub async fn submit_trigger_order(
&self,
params: &DeriveTriggerOrderParams,
) -> Result<DeriveOrder> {
let cmd_tx = self.cmd_tx.read().await.clone();
let result: DeriveOrderResult = send_request_typed(
&self.rate_limiter,
&cmd_tx,
methods::PRIVATE_TRIGGER_ORDER,
params,
self.request_timeout,
)
.await?;
Ok(result.order)
}
pub async fn modify_order(&self, params: &DeriveReplaceParams) -> Result<DeriveOrder> {
let cmd_tx = self.cmd_tx.read().await.clone();
let result: DeriveReplaceResult = send_request_typed(
&self.rate_limiter,
&cmd_tx,
methods::PRIVATE_REPLACE,
params,
self.request_timeout,
)
.await?;
Ok(result.order)
}
pub async fn cancel_order(&self, params: &DeriveCancelParams) -> Result<()> {
let cmd_tx = self.cmd_tx.read().await.clone();
let _: DeriveEmptyResult = send_request(
&self.rate_limiter,
&cmd_tx,
methods::PRIVATE_CANCEL,
params,
self.request_timeout,
)
.await?;
Ok(())
}
pub async fn cancel_trigger_order(
&self,
params: &DeriveCancelTriggerOrderParams,
) -> Result<DeriveOrder> {
let cmd_tx = self.cmd_tx.read().await.clone();
send_request_typed(
&self.rate_limiter,
&cmd_tx,
methods::PRIVATE_CANCEL_TRIGGER_ORDER,
params,
self.request_timeout,
)
.await
}
pub async fn get_trigger_orders(
&self,
params: &DeriveGetTriggerOrdersParams,
) -> Result<DeriveOpenOrdersResult> {
let cmd_tx = self.cmd_tx.read().await.clone();
send_request_typed(
&self.rate_limiter,
&cmd_tx,
methods::PRIVATE_GET_TRIGGER_ORDERS,
params,
self.request_timeout,
)
.await
}
pub async fn cancel_all_orders(&self, params: &DeriveCancelAllParams) -> Result<()> {
let cmd_tx = self.cmd_tx.read().await.clone();
let _: DeriveEmptyResult = send_request(
&self.rate_limiter,
&cmd_tx,
methods::PRIVATE_CANCEL_ALL,
params,
self.request_timeout,
)
.await?;
Ok(())
}
}
/// Sends one request through the handler and waits for its raw JSON result.
///
/// The params are serialized, the rate limiter is awaited for the key derived
/// from `method`, and the request is enqueued together with a oneshot sender
/// for the reply. `timeout` covers only the wait for the reply.
///
/// # Errors
///
/// Returns a serialization error when `params` cannot be converted to JSON, a
/// transport error when the command cannot be enqueued,
/// `DeriveWsError::RequestCancelled` when the reply sender is dropped,
/// `DeriveWsError::Timeout` when no reply arrives in time, or whatever error
/// the reply itself carries.
async fn send_raw<P>(
    rate_limiter: &WsRateLimiter,
    cmd_tx: &tokio::sync::mpsc::UnboundedSender<HandlerCommand>,
    method: &'static str,
    params: &P,
    timeout: Duration,
) -> Result<Value>
where
    P: Serialize + ?Sized,
{
    let payload = serde_json::to_value(params)?;
    let rate_keys = [rate_limit_key_for(method)];
    rate_limiter.await_keys_ready(Some(&rate_keys)).await;

    let (response_tx, response_rx) = tokio::sync::oneshot::channel();
    let command = HandlerCommand::Request {
        method,
        params: payload,
        response_tx,
    };
    if let Err(e) = cmd_tx.send(command) {
        return Err(DeriveWsError::transport(format!(
            "failed to enqueue `{method}`: {e}"
        )));
    }

    let Ok(received) = tokio::time::timeout(timeout, response_rx).await else {
        return Err(DeriveWsError::Timeout {
            method: method.to_owned(),
        });
    };
    // A receive error means the reply sender was dropped without answering.
    received.unwrap_or_else(|_| {
        Err(DeriveWsError::RequestCancelled {
            method: method.to_owned(),
        })
    })
}
async fn send_request<P, R>(
rate_limiter: &WsRateLimiter,
cmd_tx: &tokio::sync::mpsc::UnboundedSender<HandlerCommand>,
method: &'static str,
params: &P,
timeout: Duration,
) -> Result<R>
where
P: Serialize + ?Sized,
R: Default + DeserializeOwned,
{
let value = send_raw(rate_limiter, cmd_tx, method, params, timeout).await?;
let typed = if value.is_null() {
R::default()
} else {
serde_json::from_value(value)?
};
Ok(typed)
}
/// Sends a request and decodes the result into `R` with no fallback: a JSON
/// `null` result is passed to the deserializer like any other value.
///
/// # Errors
///
/// Propagates any error from `send_raw` and any deserialization error.
async fn send_request_typed<P, R>(
    rate_limiter: &WsRateLimiter,
    cmd_tx: &tokio::sync::mpsc::UnboundedSender<HandlerCommand>,
    method: &'static str,
    params: &P,
    timeout: Duration,
) -> Result<R>
where
    P: Serialize + ?Sized,
    R: DeserializeOwned,
{
    let payload = send_raw(rate_limiter, cmd_tx, method, params, timeout).await?;
    let decoded: R = serde_json::from_value(payload)?;
    Ok(decoded)
}
/// Renders each channel to its string form, preserving input order.
fn channel_topics(channels: &[DeriveWsChannel]) -> Vec<String> {
    let mut topics = Vec::with_capacity(channels.len());
    for channel in channels {
        topics.push(ToString::to_string(channel));
    }
    topics
}
/// Builds a signed login payload and sends `PUBLIC_LOGIN` through the handler,
/// recording the outcome on `auth_tracker`.
///
/// # Errors
///
/// Returns the error from `build_ws_login` when the payload cannot be built
/// (the tracker is not touched in that case), or the request error after
/// marking the tracker as failed.
async fn login_via_handler(
    rate_limiter: &WsRateLimiter,
    cmd_tx: &tokio::sync::mpsc::UnboundedSender<HandlerCommand>,
    auth_tracker: &AuthTracker,
    creds: &DeriveWsCredentials,
    timeout: Duration,
) -> Result<()> {
    let login = build_ws_login(&creds.wallet_address, &creds.signer)?;
    let params = WsLoginParams {
        wallet: login.wallet,
        timestamp: login.timestamp,
        signature: login.signature,
    };
    // The value returned by `begin` is kept alive until this function returns
    // but is not otherwise used here.
    let _receiver = auth_tracker.begin();
    match send_request::<_, WsLoginResult>(
        rate_limiter,
        cmd_tx,
        methods::PUBLIC_LOGIN,
        &params,
        timeout,
    )
    .await
    {
        Ok(_) => {
            auth_tracker.succeed();
            log::info!("Derive WebSocket authenticated");
            Ok(())
        }
        Err(e) => {
            auth_tracker.fail(e.to_string());
            Err(e)
        }
    }
}
/// Sends one `PUBLIC_SUBSCRIBE` request for `channels` through the handler.
///
/// Unlike the subscription handle methods, this does not touch the tracked
/// subscription map.
///
/// # Errors
///
/// Propagates any error from the subscribe request.
async fn subscribe_via_handler(
    rate_limiter: &WsRateLimiter,
    cmd_tx: &tokio::sync::mpsc::UnboundedSender<HandlerCommand>,
    channels: Vec<String>,
    timeout: Duration,
) -> Result<()> {
    let params = WsSubscribeParams {
        channels: channels.into_iter().map(DeriveWsChannel::from).collect(),
    };
    let _: WsSubscribeResult = send_request(
        rate_limiter,
        cmd_tx,
        methods::PUBLIC_SUBSCRIBE,
        &params,
        timeout,
    )
    .await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Well-formed session key shared by the credential tests.
    const SESSION_KEY: &str =
        "0x2ae8be44db8a590d20bffbe3b6872df9b569147d3bf6801a35a28281a4816bbd";

    /// Builds an unauthenticated client with default transport settings.
    fn public_client(environment: DeriveEnvironment) -> DeriveWebSocketClient {
        DeriveWebSocketClient::new(None, environment, TransportBackend::default(), None)
    }

    /// Builds a rate limiter with no quotas configured.
    fn unlimited_rate_limiter() -> WsRateLimiter {
        RateLimiter::new_with_quota(None, Vec::new())
    }

    #[rstest]
    fn test_public_client_defaults_to_environment_url() {
        let ws = public_client(DeriveEnvironment::Mainnet);

        let url = ws.url();
        assert!(url.starts_with("wss://"));
        assert!(url.contains("api.lyra.finance"));
        assert!(!ws.is_authenticated());
        assert!(!ws.is_active());
        assert_eq!(ws.subscription_count(), 0);
    }

    #[rstest]
    fn test_testnet_client_routes_to_demo_url() {
        let ws = public_client(DeriveEnvironment::Testnet);

        assert!(ws.url().contains("demo"));
    }

    #[rstest]
    fn test_credentials_constructor_parses_session_key() {
        let credentials = DeriveWsCredentials::new(
            "0x000000000000000000000000000000000000aaaa",
            SESSION_KEY,
        )
        .unwrap();
        assert!(credentials.wallet_address.starts_with("0x"));

        let ws = DeriveWebSocketClient::with_credentials(
            None,
            DeriveEnvironment::Testnet,
            TransportBackend::default(),
            None,
            credentials,
            None,
        );
        assert!(ws.url().contains("demo"));
        assert!(!ws.is_authenticated());
    }

    #[rstest]
    fn test_credentials_debug_redacts_signer() {
        let credentials = DeriveWsCredentials::new("0xWALLET", SESSION_KEY).unwrap();

        let rendered = format!("{credentials:?}");
        assert!(rendered.contains("redacted"));
        assert!(rendered.contains("0xWALLET"));
        assert!(!rendered.contains("2ae8be44"));
    }

    #[rstest]
    fn test_credentials_constructor_rejects_invalid_session_key() {
        let outcome = DeriveWsCredentials::new("0xWALLET", "not-a-hex-key");

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("invalid session key"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_send_raw_times_out_when_no_response_arrives() {
        // The receiver is kept alive but never read, so no reply is ever sent.
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
        let limiter = unlimited_rate_limiter();

        let outcome = send_raw(
            &limiter,
            &cmd_tx,
            methods::PRIVATE_ORDER,
            &serde_json::json!({}),
            Duration::from_millis(50),
        )
        .await;

        match outcome.expect_err("must time out") {
            DeriveWsError::Timeout { method } => assert_eq!(method, methods::PRIVATE_ORDER),
            other => panic!("expected Timeout, was {other:?}"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_send_request_typed_rejects_null_result() {
        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
        // Stand-in handler: answer the first request with a JSON null result.
        tokio::spawn(async move {
            let Some(HandlerCommand::Request { response_tx, .. }) = cmd_rx.recv().await else {
                return;
            };
            let _ = response_tx.send(Ok(Value::Null));
        });
        let limiter = unlimited_rate_limiter();

        let outcome: Result<DeriveOrderResult> = send_request_typed(
            &limiter,
            &cmd_tx,
            methods::PRIVATE_ORDER,
            &serde_json::json!({}),
            Duration::from_secs(1),
        )
        .await;

        assert!(matches!(outcome, Err(DeriveWsError::Serde(_))));
    }
}