use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use whisky::WError;
use crate::responses::stream::StreamMessage;
/// Exponential-backoff policy used when a WebSocket stream reconnects after a failure.
///
/// The delay before retry `n` (0-based) is `initial_delay_ms * backoff_multiplier^n`,
/// capped at `max_delay_ms`, plus a random jitter of up to `±jitter_factor` of that value.
#[derive(Debug, Clone, PartialEq)]
pub struct ReconnectConfig {
    /// Maximum retries since the last successful connection.
    /// `None` retries forever; `Some(0)` never retries.
    pub max_retries: Option<u32>,
    /// Delay before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Upper bound on the backoff delay (applied before jitter), in milliseconds.
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by for each successive attempt.
    pub backoff_multiplier: f64,
    /// Timeout for a single connection attempt, in milliseconds.
    pub connect_timeout_ms: u64,
    /// Fraction of the delay used as a symmetric random jitter range; `0.0` disables jitter.
    pub jitter_factor: f64,
}

impl Default for ReconnectConfig {
    /// 10 retries, 1 s initial delay doubling up to 60 s, 30 s connect timeout, ±25 % jitter.
    fn default() -> Self {
        Self {
            max_retries: Some(10),
            initial_delay_ms: 1_000,
            max_delay_ms: 60_000,
            backoff_multiplier: 2.0,
            connect_timeout_ms: 30_000,
            jitter_factor: 0.25,
        }
    }
}
impl ReconnectConfig {
    /// Policy that never reconnects (`max_retries = Some(0)`); other fields keep their defaults.
    pub fn no_reconnect() -> Self {
        ReconnectConfig {
            max_retries: Some(0),
            ..Default::default()
        }
    }

    /// Policy that retries forever (`max_retries = None`); other fields keep their defaults.
    pub fn infinite() -> Self {
        ReconnectConfig {
            max_retries: None,
            ..Default::default()
        }
    }

    /// Sets the retry limit; `0` disables reconnecting.
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = Some(max_retries);
        self
    }

    /// Sets the delay before the first retry, in milliseconds.
    pub fn with_initial_delay_ms(mut self, delay: u64) -> Self {
        self.initial_delay_ms = delay;
        self
    }

    /// Sets the cap on the backoff delay (applied before jitter), in milliseconds.
    pub fn with_max_delay_ms(mut self, delay: u64) -> Self {
        self.max_delay_ms = delay;
        self
    }

    /// Sets the factor the delay is multiplied by on each successive attempt.
    pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Sets the timeout for each connection attempt, in milliseconds.
    pub fn with_connect_timeout_ms(mut self, timeout: u64) -> Self {
        self.connect_timeout_ms = timeout;
        self
    }

    /// Sets the jitter fraction, clamped to `[0.0, 1.0]`; NaN is stored as `0.0` (no jitter).
    pub fn with_jitter_factor(mut self, factor: f64) -> Self {
        self.jitter_factor = if factor.is_nan() {
            0.0
        } else {
            factor.clamp(0.0, 1.0)
        };
        self
    }

    /// Backoff value for `attempt` (0-based) before jitter and rounding.
    ///
    /// Computes `initial_delay_ms * backoff_multiplier^attempt`, clamped to
    /// `[0, max_delay_ms]`. A NaN intermediate is treated as 0. NaN arises from `0 * inf`
    /// when `initial_delay_ms == 0` and the power overflows, or from a NaN multiplier.
    fn capped_delay(&self, attempt: u32) -> f64 {
        // `powi` takes an i32: saturate instead of letting a huge attempt wrap negative.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let raw = self.initial_delay_ms as f64 * self.backoff_multiplier.powi(exponent);
        if raw.is_nan() {
            return 0.0;
        }
        raw.clamp(0.0, self.max_delay_ms as f64)
    }

    /// Randomized delay, in milliseconds, to wait before retry `attempt` (0-based).
    ///
    /// Uses the same capped backoff value as [`Self::base_delay_for_attempt`] and adds a
    /// uniform random offset in `±value * jitter_factor`. The result can therefore exceed
    /// `max_delay_ms` by up to the jitter range. When jitter is enabled the result is at
    /// least 1 ms; without jitter it equals the truncated backoff value.
    pub fn delay_for_attempt(&self, attempt: u32) -> u64 {
        use rand::Rng;
        let capped = self.capped_delay(attempt);
        if self.jitter_factor > 0.0 {
            let jitter_range = capped * self.jitter_factor;
            // `gen_range` panics on a non-finite range. That happens when an infinite factor is
            // set through the public field. A zero range would add nothing, so skip sampling.
            let jitter = if jitter_range > 0.0 && jitter_range.is_finite() {
                rand::thread_rng().gen_range(-jitter_range..=jitter_range)
            } else {
                0.0
            };
            ((capped + jitter) as u64).max(1)
        } else {
            capped as u64
        }
    }

    /// Deterministic (jitter-free) delay for `attempt`.
    ///
    /// Equals `initial_delay_ms * backoff_multiplier^attempt` capped at `max_delay_ms`,
    /// truncated to whole milliseconds.
    pub fn base_delay_for_attempt(&self, attempt: u32) -> u64 {
        self.capped_delay(attempt) as u64
    }

    /// Whether another retry is allowed once `attempt` retries have already been made.
    pub fn should_retry(&self, attempt: u32) -> bool {
        match self.max_retries {
            None => true,
            Some(max) => attempt < max,
        }
    }
}
/// Event delivered on the channel returned by `AccountStream::subscribe_with_reconnect`.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// A message received from the server.
    Message(StreamMessage),
    /// A WebSocket connection was established (sent on every successful (re)connect).
    Connected,
    /// A retry has been scheduled.
    /// `attempt` is the 1-based retry number; `delay_ms` is the wait before it.
    Reconnecting { attempt: u32, delay_ms: u64 },
    /// The connection ended, or a connection attempt failed; `reason` says why.
    Disconnected { reason: String },
    /// The retry budget is exhausted and the background task stops after sending this.
    MaxRetriesExceeded,
}
/// Errors describing account-stream failures.
#[derive(Debug)]
pub enum StreamError {
    /// Connecting failed; the payload describes the failure.
    ConnectionFailed(String),
    /// The connection was closed.
    ConnectionClosed,
    /// A message could not be pushed through the internal channel.
    ChannelSendError,
    /// The stream had already been closed.
    StreamClosed,
    /// Connecting took longer than the allowed timeout.
    ConnectionTimeout,
}

impl std::fmt::Display for StreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only `ConnectionFailed` needs formatting; every other variant maps to a fixed text.
        let fixed_text = match self {
            StreamError::ConnectionFailed(msg) => {
                return write!(f, "Connection failed: {}", msg);
            }
            StreamError::ConnectionClosed => "Connection closed",
            StreamError::ChannelSendError => "Failed to send message through channel",
            StreamError::StreamClosed => "Stream was already closed",
            StreamError::ConnectionTimeout => "Connection timeout exceeded",
        };
        f.write_str(fixed_text)
    }
}

impl std::error::Error for StreamError {}
/// Converts a [`StreamError`] into a `WError` tagged `"StreamError"`, carrying its `Display` text.
impl From<StreamError> for WError {
    fn from(err: StreamError) -> Self {
        let description = err.to_string();
        WError::new("StreamError", &description)
    }
}
/// Handle to a stream's background task; closing or dropping it asks the task to stop.
#[derive(Debug)]
pub struct StreamHandle {
    /// Sending half of the close-signal channel; `None` once `close` has been called.
    close_tx: Option<mpsc::Sender<()>>,
}

impl StreamHandle {
    /// Asks the background task to close the WebSocket and stop.
    ///
    /// Only the first call sends the signal; later calls are no-ops. A send failure
    /// (the task has already exited) is deliberately ignored.
    pub async fn close(&mut self) {
        if let Some(tx) = self.close_tx.take() {
            let _ = tx.send(()).await;
        }
    }

    /// Returns `true` while `close` has not been called AND the background task is still
    /// running.
    ///
    /// The background task owns the receiving end of the close channel for as long as it runs.
    /// Once it exits on its own (server closed the socket, I/O error, retries exhausted, event
    /// receiver dropped), the sender reports the channel as closed and this returns `false`.
    pub fn is_active(&self) -> bool {
        self.close_tx.as_ref().map_or(false, |tx| !tx.is_closed())
    }
}

impl Drop for StreamHandle {
    /// Sends a non-blocking close signal if `close` was never called.
    ///
    /// Even if `try_send` fails, dropping the sender closes the channel. The stream
    /// tasks' `close_rx.recv()` branches complete on that as well.
    fn drop(&mut self) {
        if let Some(tx) = self.close_tx.take() {
            let _ = tx.try_send(());
        }
    }
}
/// Outcome of a single connection attempt/session, used to decide whether to reconnect.
enum ConnectionResult {
    /// A close was requested through the handle (or the handle's sender was dropped).
    UserClosed,
    /// The attempt failed or the live connection ended.
    ///
    /// The `String` is a human-readable reason. The `bool` is `true` when a connection had
    /// been established before the failure, and `false` when connecting itself failed.
    Error(String, bool),
    /// The event receiver was dropped, so there is nobody left to deliver events to.
    ReceiverDropped,
}
/// Client for the account WebSocket stream served at `{ws_url}/accounts/stream`.
pub struct AccountStream {
    /// Base WebSocket URL to which the `/accounts/stream` path is appended.
    ws_url: String,
    /// API key sent as the `api_key` query parameter.
    api_key: String,
}
impl AccountStream {
    /// Channel capacity used when the caller passes `buffer_size: None`.
    const DEFAULT_BUFFER_SIZE: usize = 100;

    /// Connect timeout used by [`AccountStream::subscribe`], which takes no reconnect policy.
    const SUBSCRIBE_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates a client for `{ws_url}/accounts/stream`, authenticating with `api_key`.
    ///
    /// No connection is made until one of the `subscribe*` methods is called.
    pub fn new(ws_url: String, api_key: String) -> Self {
        AccountStream { ws_url, api_key }
    }

    /// Builds the account-stream WebSocket URL.
    ///
    /// A trailing `/` on `ws_url` is dropped so the path is not doubled (`...//accounts`).
    /// The API key is percent-encoded so characters such as `&`, `#`, `+` or `=` cannot
    /// corrupt the query string; these are common in base64 keys.
    fn stream_endpoint(&self) -> String {
        format!(
            "{}/accounts/stream?api_key={}",
            self.ws_url.trim_end_matches('/'),
            Self::encode_query_value(&self.api_key)
        )
    }

    /// Percent-encodes `value` for a URL query, leaving only RFC 3986 unreserved characters
    /// (`A-Z a-z 0-9 - . _ ~`) literal.
    fn encode_query_value(value: &str) -> String {
        use std::fmt::Write;
        let mut encoded = String::with_capacity(value.len());
        for byte in value.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                encoded.push(char::from(byte));
            } else {
                // Writing into a `String` cannot fail.
                let _ = write!(encoded, "%{:02X}", byte);
            }
        }
        encoded
    }

    /// Resolves the requested channel capacity.
    ///
    /// `Some(0)` is raised to 1 because `tokio::sync::mpsc::channel` panics on a zero capacity.
    fn channel_capacity(buffer_size: Option<usize>) -> usize {
        buffer_size.unwrap_or(Self::DEFAULT_BUFFER_SIZE).max(1)
    }

    /// Connects once and forwards parsed text messages on the returned receiver.
    ///
    /// There is no reconnection. The background task stops when any of these happens:
    /// - the server sends a close frame;
    /// - a read or pong-send error occurs;
    /// - the stream ends;
    /// - the receiver is dropped;
    /// - the handle is closed or dropped.
    ///
    /// `buffer_size` defaults to 100.
    ///
    /// # Errors
    /// Returns a `WError` if the connection does not complete within 30 seconds or the
    /// WebSocket handshake fails.
    pub async fn subscribe(
        &self,
        buffer_size: Option<usize>,
    ) -> Result<(StreamHandle, mpsc::Receiver<StreamMessage>), WError> {
        let (message_tx, message_rx) =
            mpsc::channel::<StreamMessage>(Self::channel_capacity(buffer_size));
        let (close_tx, mut close_rx) = mpsc::channel::<()>(1);
        let ws_endpoint = self.stream_endpoint();
        let (ws_stream, _response) =
            timeout(Self::SUBSCRIBE_CONNECT_TIMEOUT, connect_async(&ws_endpoint))
                .await
                .map_err(|_| WError::new("AccountStream", "Connection timeout"))?
                .map_err(|e| WError::new("AccountStream", &format!("Connection failed: {}", e)))?;
        let (mut write, mut read) = ws_stream.split();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    // Completes on an explicit close signal and also when the handle's sender is
                    // dropped (`recv` yields `None`), so dropping the handle ends the task.
                    _ = close_rx.recv() => {
                        let _ = write.send(Message::Close(None)).await;
                        break;
                    }
                    msg = read.next() => {
                        match msg {
                            Some(Ok(Message::Text(text))) => {
                                let stream_msg = StreamMessage::from_json(&text);
                                if message_tx.send(stream_msg).await.is_err() {
                                    // Receiver dropped: nobody is listening any more.
                                    let _ = write.send(Message::Close(None)).await;
                                    break;
                                }
                            }
                            Some(Ok(Message::Ping(data))) => {
                                if write.send(Message::Pong(data)).await.is_err() {
                                    break;
                                }
                            }
                            Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
                            // Binary, Pong and raw frames carry nothing for this stream.
                            _ => {}
                        }
                    }
                }
            }
        });
        let handle = StreamHandle {
            close_tx: Some(close_tx),
        };
        Ok((handle, message_rx))
    }

    /// Starts a background task that connects and reconnects per `reconnect_config`.
    ///
    /// `reconnect_config` defaults to `ReconnectConfig::default()` and `buffer_size` to 100.
    /// Never fails: connecting happens in the background, and failures are reported as
    /// [`StreamEvent::Disconnected`] / [`StreamEvent::Reconnecting`] events.
    pub async fn subscribe_with_reconnect(
        &self,
        buffer_size: Option<usize>,
        reconnect_config: Option<ReconnectConfig>,
    ) -> Result<(StreamHandle, mpsc::Receiver<StreamEvent>), WError> {
        let config = reconnect_config.unwrap_or_default();
        let (event_tx, event_rx) =
            mpsc::channel::<StreamEvent>(Self::channel_capacity(buffer_size));
        let (close_tx, close_rx) = mpsc::channel::<()>(1);
        tokio::spawn(Self::run_reconnecting_stream(
            self.stream_endpoint(),
            config,
            event_tx,
            close_rx,
        ));
        let handle = StreamHandle {
            close_tx: Some(close_tx),
        };
        Ok((handle, event_rx))
    }

    /// Runs one connection from connect to disconnect, forwarding messages as events.
    ///
    /// A close request is honoured while connecting as well as while reading. Without that,
    /// a slow connect would delay shutdown by up to `connect_timeout`.
    async fn run_single_connection(
        ws_endpoint: &str,
        connect_timeout: Duration,
        event_tx: &mpsc::Sender<StreamEvent>,
        close_rx: &mut mpsc::Receiver<()>,
    ) -> ConnectionResult {
        let connect_result = tokio::select! {
            _ = close_rx.recv() => {
                return ConnectionResult::UserClosed;
            }
            result = timeout(connect_timeout, connect_async(ws_endpoint)) => result,
        };
        let ws_stream = match connect_result {
            Ok(Ok((stream, _response))) => stream,
            Ok(Err(e)) => {
                return ConnectionResult::Error(format!("Connection failed: {}", e), false);
            }
            Err(_) => {
                return ConnectionResult::Error("Connection timeout".to_string(), false);
            }
        };
        if event_tx.send(StreamEvent::Connected).await.is_err() {
            return ConnectionResult::ReceiverDropped;
        }
        let (mut write, mut read) = ws_stream.split();
        loop {
            tokio::select! {
                _ = close_rx.recv() => {
                    let _ = write.send(Message::Close(None)).await;
                    return ConnectionResult::UserClosed;
                }
                msg = read.next() => {
                    match msg {
                        Some(Ok(Message::Text(text))) => {
                            let stream_msg = StreamMessage::from_json(&text);
                            let event = StreamEvent::Message(stream_msg);
                            if event_tx.send(event).await.is_err() {
                                let _ = write.send(Message::Close(None)).await;
                                return ConnectionResult::ReceiverDropped;
                            }
                        }
                        Some(Ok(Message::Ping(data))) => {
                            if write.send(Message::Pong(data)).await.is_err() {
                                return ConnectionResult::Error(
                                    "Failed to send pong".to_string(),
                                    true,
                                );
                            }
                        }
                        Some(Ok(Message::Close(frame))) => {
                            // A close frame may carry an empty reason; fall back to a readable one.
                            let reason = frame
                                .map(|f| f.reason.to_string())
                                .filter(|r| !r.is_empty())
                                .unwrap_or_else(|| "Server closed connection".to_string());
                            return ConnectionResult::Error(reason, true);
                        }
                        Some(Err(e)) => {
                            return ConnectionResult::Error(format!("WebSocket error: {}", e), true);
                        }
                        None => {
                            return ConnectionResult::Error(
                                "Stream ended unexpectedly".to_string(),
                                true,
                            );
                        }
                        // Binary, Pong and raw frames carry nothing for this stream.
                        _ => {}
                    }
                }
            }
        }
    }

    /// Reconnect loop backing [`AccountStream::subscribe_with_reconnect`].
    ///
    /// The retry counter resets after every connection that was actually established. The
    /// loop stops when any of these happens:
    /// - the user closes the stream or drops the handle;
    /// - the event receiver is dropped;
    /// - the retry budget runs out.
    async fn run_reconnecting_stream(
        ws_endpoint: String,
        config: ReconnectConfig,
        event_tx: mpsc::Sender<StreamEvent>,
        mut close_rx: mpsc::Receiver<()>,
    ) {
        let connect_timeout = Duration::from_millis(config.connect_timeout_ms);
        let mut attempt: u32 = 0;
        loop {
            // Stop on a pending close request, or once the handle's sender is gone.
            match close_rx.try_recv() {
                Ok(()) | Err(mpsc::error::TryRecvError::Disconnected) => break,
                Err(mpsc::error::TryRecvError::Empty) => {}
            }
            let result = Self::run_single_connection(
                &ws_endpoint,
                connect_timeout,
                &event_tx,
                &mut close_rx,
            )
            .await;
            let (reason, was_connected) = match result {
                ConnectionResult::UserClosed | ConnectionResult::ReceiverDropped => break,
                ConnectionResult::Error(reason, was_connected) => (reason, was_connected),
            };
            if was_connected {
                attempt = 0;
            }
            // A failed send means the event receiver is gone. Stop here instead of
            // retrying forever in the background while the server stays unreachable.
            if event_tx
                .send(StreamEvent::Disconnected { reason })
                .await
                .is_err()
            {
                break;
            }
            if !config.should_retry(attempt) {
                let _ = event_tx.send(StreamEvent::MaxRetriesExceeded).await;
                break;
            }
            let delay_ms = config.delay_for_attempt(attempt);
            attempt = attempt.saturating_add(1);
            if event_tx
                .send(StreamEvent::Reconnecting { attempt, delay_ms })
                .await
                .is_err()
            {
                break;
            }
            tokio::select! {
                _ = close_rx.recv() => {
                    break;
                }
                _ = tokio::time::sleep(Duration::from_millis(delay_ms)) => {}
            }
        }
    }

    /// Subscribes without reconnection and feeds each message to `callback` until it
    /// returns `false` or the stream ends, then closes the stream.
    ///
    /// # Errors
    /// Propagates the connection error from [`AccountStream::subscribe`].
    pub async fn subscribe_with_callback<F, Fut>(&self, mut callback: F) -> Result<(), WError>
    where
        F: FnMut(StreamMessage) -> Fut,
        Fut: std::future::Future<Output = bool>,
    {
        let (mut handle, mut receiver) = self.subscribe(None).await?;
        while let Some(message) = receiver.recv().await {
            if !callback(message).await {
                break;
            }
        }
        handle.close().await;
        Ok(())
    }

    /// Subscribes with reconnection and feeds each [`StreamEvent`] to `callback` until it
    /// returns `false` or the event stream ends, then closes the stream.
    pub async fn subscribe_with_reconnect_callback<F, Fut>(
        &self,
        reconnect_config: Option<ReconnectConfig>,
        mut callback: F,
    ) -> Result<(), WError>
    where
        F: FnMut(StreamEvent) -> Fut,
        Fut: std::future::Future<Output = bool>,
    {
        let (mut handle, mut receiver) =
            self.subscribe_with_reconnect(None, reconnect_config).await?;
        while let Some(event) = receiver.recv().await {
            if !callback(event).await {
                break;
            }
        }
        handle.close().await;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes one sample of each account sub-type (plus an unknown type) and checks fields.
    #[test]
    fn test_stream_message_parsing() {
        let balance = StreamMessage::from_json(
            r#"{
"type": "Account",
"sub_type": "balance",
"balance": [
{"asset": "ADA", "asset_unit": "lovelace", "free": "1000000", "locked": "0"}
]
}"#,
        );
        assert!(balance.is_balance());
        assert!(!balance.is_order_info());
        if let StreamMessage::Balance(inner) = balance {
            assert_eq!(inner.balance.len(), 1);
            assert_eq!(inner.balance[0].asset, "ADA");
        }

        let order = StreamMessage::from_json(
            r#"{
"type": "Account",
"sub_type": "order_info",
"order": {
"id": "order-123",
"account_id": "acc-456",
"status": "open",
"symbol": "ADAUSDM",
"base_qty": "100",
"quote_qty": "50",
"side": "buy",
"price": "0.5",
"type": "limit",
"locked_base_qty": "0",
"locked_quote_qty": "50",
"executed_base_qty": "0",
"executed_quote_qty": "0",
"ob_open_order_base_qty": "100",
"commission_unit": "USDM",
"commission": "0",
"commission_rate_bp": 30,
"executed_price": "0",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
}"#,
        );
        assert!(order.is_order_info());
        if let StreamMessage::OrderInfo(inner) = order {
            assert_eq!(inner.order.id, "order-123");
            assert_eq!(inner.order.symbol, "ADAUSDM");
        }

        let points = StreamMessage::from_json(
            r#"{
"type": "Account",
"sub_type": "dlta_points",
"dlta_points": {
"delta": "100",
"new_total": "1000",
"season_points": "500",
"source_type": "trade",
"source_ref": "order-123",
"league": "gold"
}
}"#,
        );
        assert!(points.is_dlta_points());
        if let StreamMessage::DltaPoints(inner) = points {
            assert_eq!(inner.dlta_points.delta, "100");
            assert_eq!(inner.dlta_points.league, "gold");
        }

        let unknown = StreamMessage::from_json(r#"{"type": "Unknown", "data": "test"}"#);
        assert!(unknown.is_unknown());
    }

    #[test]
    fn test_reconnect_config_defaults() {
        let defaults = ReconnectConfig::default();
        assert_eq!(defaults.max_retries, Some(10));
        assert_eq!(defaults.initial_delay_ms, 1000);
        assert_eq!(defaults.max_delay_ms, 60000);
        assert_eq!(defaults.backoff_multiplier, 2.0);
        assert_eq!(defaults.connect_timeout_ms, 30000);
        assert_eq!(defaults.jitter_factor, 0.25);
    }

    #[test]
    fn test_reconnect_config_no_reconnect() {
        let never = ReconnectConfig::no_reconnect();
        assert_eq!(never.max_retries, Some(0));
        assert!(!never.should_retry(0));
    }

    #[test]
    fn test_reconnect_config_infinite() {
        let forever = ReconnectConfig::infinite();
        assert_eq!(forever.max_retries, None);
        for attempt in [0u32, 100, 1000].iter() {
            assert!(forever.should_retry(*attempt));
        }
    }

    #[test]
    fn test_reconnect_config_builder() {
        let built = ReconnectConfig::default()
            .with_max_retries(5)
            .with_initial_delay_ms(500)
            .with_max_delay_ms(30000)
            .with_backoff_multiplier(1.5)
            .with_connect_timeout_ms(10000)
            .with_jitter_factor(0.1);
        assert_eq!(built.max_retries, Some(5));
        assert_eq!(built.initial_delay_ms, 500);
        assert_eq!(built.max_delay_ms, 30000);
        assert_eq!(built.backoff_multiplier, 1.5);
        assert_eq!(built.connect_timeout_ms, 10000);
        assert_eq!(built.jitter_factor, 0.1);

        // Out-of-range jitter factors are clamped into [0.0, 1.0].
        assert_eq!(ReconnectConfig::default().with_jitter_factor(1.5).jitter_factor, 1.0);
        assert_eq!(ReconnectConfig::default().with_jitter_factor(-0.5).jitter_factor, 0.0);
    }

    #[test]
    fn test_reconnect_config_delay_calculation() {
        let doubling = ReconnectConfig::default()
            .with_initial_delay_ms(1000)
            .with_max_delay_ms(60000)
            .with_backoff_multiplier(2.0);
        // (attempt, expected jitter-free delay); the 60 s cap applies from attempt 6 on.
        let expectations: [(u32, u64); 8] = [
            (0, 1000),
            (1, 2000),
            (2, 4000),
            (3, 8000),
            (4, 16000),
            (5, 32000),
            (6, 60000),
            (10, 60000),
        ];
        for &(attempt, expected) in expectations.iter() {
            assert_eq!(doubling.base_delay_for_attempt(attempt), expected);
        }
    }

    #[test]
    fn test_reconnect_config_jitter_range() {
        let jittered = ReconnectConfig::default()
            .with_initial_delay_ms(1000)
            .with_max_delay_ms(60000)
            .with_backoff_multiplier(2.0)
            .with_jitter_factor(0.25);
        for _ in 0..100 {
            let delay = jittered.delay_for_attempt(0);
            assert!(
                (750..=1250).contains(&delay),
                "delay {} out of expected range [750, 1250]",
                delay
            );
        }

        let exact = ReconnectConfig::default()
            .with_initial_delay_ms(1000)
            .with_jitter_factor(0.0);
        assert_eq!(exact.delay_for_attempt(0), 1000);
        assert_eq!(exact.delay_for_attempt(1), 2000);
    }

    #[test]
    fn test_reconnect_config_should_retry() {
        let limited = ReconnectConfig::default().with_max_retries(3);
        for attempt in 0..3 {
            assert!(limited.should_retry(attempt));
        }
        for attempt in 3..5 {
            assert!(!limited.should_retry(attempt));
        }
    }

    #[test]
    fn test_stream_event_variants() {
        let _connected = StreamEvent::Connected;
        let _disconnected = StreamEvent::Disconnected {
            reason: "test".to_string(),
        };
        let _reconnecting = StreamEvent::Reconnecting {
            attempt: 1,
            delay_ms: 1000,
        };
        let _max_retries = StreamEvent::MaxRetriesExceeded;
        let empty_balance = StreamMessage::from_json(
            r#"{
"type": "Account",
"sub_type": "balance",
"balance": []
}"#,
        );
        let _message_event = StreamEvent::Message(empty_balance);
    }
}