use crate::error::{KrakyError, Result};
use futures_util::Stream;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
/// Default capacity, in messages, of the bounded channel behind each subscription.
pub const DEFAULT_BUFFER_SIZE: usize = 1_000;

/// Controls how many messages a subscription may hold before new ones are dropped.
#[derive(Debug, Clone)]
pub struct BackpressureConfig {
    /// Capacity of the bounded channel backing a subscription.
    pub buffer_size: usize,
}

impl Default for BackpressureConfig {
    /// Uses [`DEFAULT_BUFFER_SIZE`] as the channel capacity.
    fn default() -> Self {
        Self::with_buffer_size(DEFAULT_BUFFER_SIZE)
    }
}

impl BackpressureConfig {
    /// Builds a config with an explicit channel capacity.
    pub fn with_buffer_size(buffer_size: usize) -> Self {
        BackpressureConfig { buffer_size }
    }
}
/// Atomic counters describing what happened to messages offered to a subscription.
///
/// Shared via `Arc` between the sending side and the [`Subscription`] handle.
#[derive(Debug, Default)]
pub struct SubscriptionStats {
    /// Messages successfully enqueued into the subscription's channel
    /// (not necessarily consumed yet).
    pub delivered: AtomicU64,
    /// Messages discarded because the channel was full.
    pub dropped: AtomicU64,
}

impl SubscriptionStats {
    /// Number of messages accepted into the channel so far.
    pub fn delivered(&self) -> u64 {
        let counter = &self.delivered;
        counter.load(Ordering::Relaxed)
    }

    /// Number of messages discarded because the channel was full.
    pub fn dropped(&self) -> u64 {
        let counter = &self.dropped;
        counter.load(Ordering::Relaxed)
    }

    /// Percentage (0.0 to 100.0) of offered messages that were dropped.
    ///
    /// Returns `0.0` when nothing has been offered yet. The two counters are read
    /// independently with relaxed ordering, so the result is a best-effort snapshot.
    pub fn drop_rate(&self) -> f64 {
        let accepted = self.delivered() as f64;
        let lost = self.dropped() as f64;
        let offered = accepted + lost;
        if offered == 0.0 {
            return 0.0;
        }
        100.0 * (lost / offered)
    }
}
/// Receiving half of a subscription, handed to the user.
///
/// Messages can be pulled one at a time with [`Subscription::next`] or consumed
/// through the [`Stream`] implementation.
pub struct Subscription<T> {
    receiver: mpsc::Receiver<T>,
    id: String,
    stats: Arc<SubscriptionStats>,
}

impl<T> Subscription<T> {
    /// Wraps a channel receiver together with its identifier and shared counters.
    pub(crate) fn new(
        receiver: mpsc::Receiver<T>,
        id: String,
        stats: Arc<SubscriptionStats>,
    ) -> Self {
        Subscription {
            receiver,
            id,
            stats,
        }
    }

    /// Waits for the next message.
    ///
    /// Delegates to `mpsc::Receiver::recv`, so it yields `None` once the sending
    /// side is gone and the buffer has been drained.
    pub async fn next(&mut self) -> Option<T> {
        let rx = &mut self.receiver;
        rx.recv().await
    }

    /// Identifier assigned to this subscription when it was created.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Delivery/drop counters shared with the sending side.
    pub fn stats(&self) -> &SubscriptionStats {
        &*self.stats
    }
}
/// Exposes a subscription as a `futures` [`Stream`], polling the same channel
/// that [`Subscription::next`] reads from.
impl<T> Stream for Subscription<T> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Subscription` is `Unpin`, so it is fine to drop the pin and poll the
        // receiver through a plain mutable reference.
        let this = self.get_mut();
        this.receiver.poll_recv(cx)
    }
}
/// Sending half of a subscription, kept by the client to fan messages out.
pub(crate) struct SubscriptionSender<T> {
    sender: mpsc::Sender<T>,
    // Retained but not read in this file; `allow` silences the unused-field lint.
    #[allow(dead_code)]
    id: String,
    // Retained but not read in this file; `allow` silences the unused-field lint.
    #[allow(dead_code)]
    channel: String,
    /// Symbol this subscription is filtered on; `"*"` matches every symbol.
    pub(crate) symbol: String,
    stats: Arc<SubscriptionStats>,
}

impl<T> SubscriptionSender<T> {
    /// Creates a connected sender/subscription pair using
    /// [`BackpressureConfig::default`].
    pub fn new(channel: String, symbol: String) -> (Self, Subscription<T>) {
        Self::with_config(channel, symbol, BackpressureConfig::default())
    }

    /// Creates a connected sender/subscription pair with an explicit buffer size.
    ///
    /// The subscription id has the form `"{channel}-{symbol}-{uuid v4}"`.
    /// A `buffer_size` of `0` is treated as `1`, because a zero-capacity tokio
    /// channel cannot be created.
    pub fn with_config(
        channel: String,
        symbol: String,
        config: BackpressureConfig,
    ) -> (Self, Subscription<T>) {
        // `mpsc::channel` panics on a capacity of 0. `buffer_size` is a public field
        // that callers may set to 0, so clamp it to a single-slot buffer instead of
        // crashing the caller.
        let capacity = config.buffer_size.max(1);
        let (sender, receiver) = mpsc::channel(capacity);
        let id = format!("{}-{}-{}", channel, symbol, uuid::Uuid::new_v4());
        let stats = Arc::new(SubscriptionStats::default());
        let subscription = Subscription::new(receiver, id.clone(), Arc::clone(&stats));
        let sender = Self {
            sender,
            id,
            channel,
            symbol,
            stats,
        };
        (sender, subscription)
    }

    /// Offers `data` to the subscriber without waiting.
    ///
    /// On success the message is counted in [`SubscriptionStats::delivered`].
    /// If the channel is full, the message is discarded and counted in
    /// [`SubscriptionStats::dropped`]. That case still returns `Ok(())`, so a slow
    /// consumer never stalls or fails the caller.
    ///
    /// # Errors
    ///
    /// Returns [`KrakyError::ChannelSend`] if the receiving [`Subscription`] has
    /// been closed or dropped.
    pub fn send(&self, data: T) -> Result<()> {
        match self.sender.try_send(data) {
            Ok(()) => {
                self.stats.delivered.fetch_add(1, Ordering::Relaxed);
                Ok(())
            }
            Err(mpsc::error::TrySendError::Full(_)) => {
                // Deliberate: backpressure is handled by dropping, not by failing.
                self.stats.dropped.fetch_add(1, Ordering::Relaxed);
                Ok(())
            }
            Err(mpsc::error::TrySendError::Closed(_)) => {
                Err(KrakyError::ChannelSend("subscription closed".to_string()))
            }
        }
    }

    /// Returns `true` once the receiving [`Subscription`] has been closed or dropped.
    // Only referenced by `SubscriptionManager::cleanup`, which nothing in this file
    // calls, so `allow` silences the unused lint.
    #[allow(dead_code)]
    pub fn is_closed(&self) -> bool {
        self.sender.is_closed()
    }
}
/// Per-channel registries of active subscriptions; each registry exists only when
/// its Cargo feature is enabled.
pub(crate) struct SubscriptionManager {
    /// Order book subscribers (`orderbook` feature).
    #[cfg(feature = "orderbook")]
    pub orderbook: Vec<SubscriptionSender<crate::models::OrderbookUpdate>>,
    /// Trade subscribers (`trades` feature).
    #[cfg(feature = "trades")]
    pub trades: Vec<SubscriptionSender<crate::models::Trade>>,
    /// Ticker subscribers (`ticker` feature).
    #[cfg(feature = "ticker")]
    pub ticker: Vec<SubscriptionSender<crate::models::Ticker>>,
    /// OHLC candle subscribers (`ohlc` feature).
    #[cfg(feature = "ohlc")]
    pub ohlc: Vec<SubscriptionSender<crate::models::OHLC>>,
}

impl Default for SubscriptionManager {
    /// Equivalent to [`SubscriptionManager::new`].
    fn default() -> Self {
        SubscriptionManager::new()
    }
}
impl SubscriptionManager {
    /// Creates a manager with no subscriptions on any enabled channel.
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "orderbook")]
            orderbook: Vec::new(),
            #[cfg(feature = "trades")]
            trades: Vec::new(),
            #[cfg(feature = "ticker")]
            ticker: Vec::new(),
            #[cfg(feature = "ohlc")]
            ohlc: Vec::new(),
        }
    }

    /// Removes every sender whose receiving `Subscription` has been closed or dropped.
    // Not called anywhere in this file; `allow` silences the unused lint.
    #[allow(dead_code)]
    pub fn cleanup(&mut self) {
        #[cfg(feature = "orderbook")]
        self.orderbook.retain(|s| !s.is_closed());
        #[cfg(feature = "trades")]
        self.trades.retain(|s| !s.is_closed());
        #[cfg(feature = "ticker")]
        self.ticker.retain(|s| !s.is_closed());
        #[cfg(feature = "ohlc")]
        self.ohlc.retain(|s| !s.is_closed());
    }

    /// Forwards an order book update to every interested subscriber.
    ///
    /// A subscriber receives a clone of the whole `update` when at least one entry
    /// in `update.data` matches its symbol. A subscriber on `"*"` matches any
    /// entry. Each subscriber receives the update at most once, and an update with
    /// no entries is sent to nobody. Send failures (closed subscribers) are ignored
    /// on purpose; they can be pruned with [`Self::cleanup`].
    #[cfg(feature = "orderbook")]
    pub fn dispatch_orderbook(&self, update: &crate::models::OrderbookUpdate) {
        // Match per subscriber rather than per entry. The payload is the whole
        // update, so sending once per matching entry would hand the same update to
        // a subscriber several times and inflate its delivered/dropped counters.
        for sub in &self.orderbook {
            let wanted = update
                .data
                .iter()
                .any(|data| sub.symbol == data.symbol || sub.symbol == "*");
            if wanted {
                let _ = sub.send(update.clone());
            }
        }
    }

    /// Converts each raw trade in `update` and forwards it to subscribers of that
    /// trade's symbol and to `"*"` subscribers. Send failures are ignored.
    #[cfg(feature = "trades")]
    pub fn dispatch_trade(&self, update: &crate::models::TradeUpdate) {
        for data in &update.data {
            let trade = data.to_trade();
            for sub in &self.trades {
                if sub.symbol == trade.symbol || sub.symbol == "*" {
                    let _ = sub.send(trade.clone());
                }
            }
        }
    }

    /// Converts each raw ticker entry in `update` and forwards it to subscribers of
    /// that symbol and to `"*"` subscribers. Send failures are ignored.
    #[cfg(feature = "ticker")]
    pub fn dispatch_ticker(&self, update: &crate::models::TickerUpdate) {
        for data in &update.data {
            let ticker = data.to_ticker();
            for sub in &self.ticker {
                if sub.symbol == ticker.symbol || sub.symbol == "*" {
                    let _ = sub.send(ticker.clone());
                }
            }
        }
    }

    /// Converts each raw candle in `update` and forwards it to subscribers of that
    /// symbol and to `"*"` subscribers. Send failures are ignored.
    #[cfg(feature = "ohlc")]
    pub fn dispatch_ohlc(&self, update: &crate::models::OHLCUpdate) {
        for data in &update.data {
            let ohlc = data.to_ohlc();
            for sub in &self.ohlc {
                if sub.symbol == ohlc.symbol || sub.symbol == "*" {
                    let _ = sub.send(ohlc.clone());
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A message pushed through a fresh sender/subscription pair arrives unchanged.
    #[tokio::test]
    async fn test_subscription_sender_receiver() {
        let (tx, mut rx) =
            SubscriptionSender::<String>::new("test".to_string(), "BTC/USD".to_string());
        tx.send(String::from("hello")).unwrap();
        assert_eq!(rx.next().await, Some(String::from("hello")));
    }

    /// Ids start with `"{channel}-{symbol}-"`, and the symbol is stored verbatim.
    #[test]
    fn test_subscription_id_format() {
        let (tx, rx) =
            SubscriptionSender::<String>::new("book".to_string(), "BTC/USD".to_string());
        assert!(rx.id().starts_with("book-BTC/USD-"));
        assert_eq!(tx.symbol, "BTC/USD");
    }

    /// With a 3-slot buffer and no reader, the first three sends are queued and the
    /// last two are dropped, yet every `send` still reports success.
    #[tokio::test]
    async fn test_backpressure_drops_messages() {
        let (tx, mut rx) = SubscriptionSender::<String>::with_config(
            "test".to_string(),
            "BTC/USD".to_string(),
            BackpressureConfig::with_buffer_size(3),
        );
        for n in 1..=5 {
            tx.send(format!("msg{}", n)).unwrap();
        }
        let stats = rx.stats();
        assert_eq!(stats.delivered(), 3);
        assert_eq!(stats.dropped(), 2);
        for n in 1..=3 {
            assert_eq!(rx.next().await, Some(format!("msg{}", n)));
        }
    }

    /// No traffic gives a 0% drop rate; 20 dropped out of 100 offered gives 20%.
    #[test]
    fn test_drop_rate_calculation() {
        let stats = SubscriptionStats::default();
        assert_eq!(stats.drop_rate(), 0.0);
        stats.delivered.store(80, Ordering::Relaxed);
        stats.dropped.store(20, Ordering::Relaxed);
        assert!((stats.drop_rate() - 20.0).abs() < 0.001);
    }
}