use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use log::{debug, warn};
use tokio::sync::mpsc;
use super::common::{process_decode_result, DecoderContext, ProcessingResult};
use super::StreamDecoder;
use crate::messages::{OutgoingMessages, RequestMessage, ResponseMessage};
use crate::transport::{AsyncInternalSubscription, AsyncMessageBus};
use crate::Error;
/// Boxed callback that builds the request message used to cancel a subscription.
///
/// In this file it is invoked as `cancel_fn(server_version, id, Some(&context))`,
/// where `id` is the subscription's request id or, when that is absent, its order id.
type CancelFn = Box<dyn Fn(i32, Option<i32>, Option<&DecoderContext>) -> Result<RequestMessage, Error> + Send + Sync>;
/// Reference-counted, `Send + Sync` callback that decodes a `ResponseMessage`
/// into a `T`, given the subscription's `DecoderContext`.
type DecoderFn<T> = Arc<dyn Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync>;
/// Asynchronous stream of `T` values belonging to one request.
///
/// Values are produced either by decoding raw messages on demand or by reading
/// already-decoded results from a channel (see `SubscriptionInner`). Calling
/// `cancel`, or dropping the subscription, attempts to send a cancel request
/// when both `message_bus` and `cancel_fn` are present.
pub struct Subscription<T> {
/// Where values come from: a decoder-backed stream or a pre-decoded channel.
inner: SubscriptionInner<T>,
/// Request id; returned by `request_id()` and preferred as the id passed to `cancel_fn`.
request_id: Option<i32>,
/// Order id; used as the id passed to `cancel_fn` only when `request_id` is `None`.
order_id: Option<i32>,
/// Outgoing message type supplied at construction. Stored and cloned, but not
/// otherwise read by the non-test code in this file.
_message_type: Option<OutgoingMessages>,
/// Decoder context; its `server_version` and the context itself are passed to `cancel_fn`.
context: DecoderContext,
/// Set once `cancel` has run or a handle has been dropped. The `Arc` is shared
/// with clones of this subscription.
cancelled: Arc<AtomicBool>,
/// Set when decoding reports end of stream; afterwards `next` returns `None`
/// without reading further. The `Arc` is shared with clones of this subscription.
stream_ended: Arc<AtomicBool>,
/// Bus used to send the cancel request. `None` for subscriptions built with `new`.
message_bus: Option<Arc<dyn AsyncMessageBus>>,
/// Builder for the cancel request. Populated by `new_from_internal`; `None` otherwise.
cancel_fn: Option<Arc<CancelFn>>,
}
/// The two ways a `Subscription` can obtain its values.
enum SubscriptionInner<T> {
/// Raw messages are pulled from `subscription` and turned into `T` by `decoder`.
WithDecoder {
/// Source of raw response messages.
subscription: AsyncInternalSubscription,
/// Callback applied to each raw message.
decoder: DecoderFn<T>,
/// Context handed to `decoder` on every call.
context: DecoderContext,
},
/// Values were decoded elsewhere and arrive ready-made over an unbounded channel.
/// This variant cannot be cloned (its `Clone` implementation panics).
PreDecoded { receiver: mpsc::UnboundedReceiver<Result<T, Error>> },
}
impl<T> Clone for SubscriptionInner<T> {
fn clone(&self) -> Self {
match self {
SubscriptionInner::WithDecoder {
subscription,
decoder,
context,
} => SubscriptionInner::WithDecoder {
subscription: subscription.clone(),
decoder: decoder.clone(),
context: context.clone(),
},
SubscriptionInner::PreDecoded { .. } => {
panic!("Cannot clone pre-decoded subscriptions");
}
}
}
}
impl<T> Clone for Subscription<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
request_id: self.request_id,
order_id: self.order_id,
_message_type: self._message_type,
context: self.context.clone(),
cancelled: self.cancelled.clone(),
stream_ended: self.stream_ended.clone(),
message_bus: self.message_bus.clone(),
cancel_fn: self.cancel_fn.clone(),
}
}
}
impl<T> Subscription<T> {
#[allow(clippy::too_many_arguments)]
pub fn with_decoder<D>(
internal: AsyncInternalSubscription,
message_bus: Arc<dyn AsyncMessageBus>,
decoder: D,
request_id: Option<i32>,
order_id: Option<i32>,
message_type: Option<OutgoingMessages>,
context: DecoderContext,
) -> Self
where
D: Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync + 'static,
{
Self {
inner: SubscriptionInner::WithDecoder {
subscription: internal,
decoder: Arc::new(decoder),
context: context.clone(),
},
request_id,
order_id,
_message_type: message_type,
context,
cancelled: Arc::new(AtomicBool::new(false)),
stream_ended: Arc::new(AtomicBool::new(false)),
message_bus: Some(message_bus),
cancel_fn: None,
}
}
#[allow(clippy::too_many_arguments)]
pub fn new_with_decoder<F>(
internal: AsyncInternalSubscription,
message_bus: Arc<dyn AsyncMessageBus>,
decoder: F,
request_id: Option<i32>,
order_id: Option<i32>,
message_type: Option<OutgoingMessages>,
context: DecoderContext,
) -> Self
where
F: Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync + 'static,
{
Self::with_decoder(internal, message_bus, decoder, request_id, order_id, message_type, context)
}
#[allow(clippy::too_many_arguments)]
pub fn with_decoder_components<D>(
internal: AsyncInternalSubscription,
message_bus: Arc<dyn AsyncMessageBus>,
decoder: D,
request_id: Option<i32>,
order_id: Option<i32>,
message_type: Option<OutgoingMessages>,
context: DecoderContext,
) -> Self
where
D: Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync + 'static,
{
Self::with_decoder(internal, message_bus, decoder, request_id, order_id, message_type, context)
}
pub(crate) fn new_from_internal<D>(
internal: AsyncInternalSubscription,
message_bus: Arc<dyn AsyncMessageBus>,
request_id: Option<i32>,
order_id: Option<i32>,
message_type: Option<OutgoingMessages>,
context: DecoderContext,
) -> Self
where
D: StreamDecoder<T> + 'static,
T: 'static,
{
let mut sub = Self::with_decoder_components(internal, message_bus, D::decode, request_id, order_id, message_type, context);
sub.cancel_fn = Some(Arc::new(Box::new(D::cancel_message)));
sub
}
pub(crate) fn new_from_internal_simple<D>(
internal: AsyncInternalSubscription,
context: DecoderContext,
message_bus: Arc<dyn AsyncMessageBus>,
) -> Self
where
D: StreamDecoder<T> + 'static,
T: 'static,
{
Self::new_from_internal::<D>(internal, message_bus, None, None, None, context)
}
pub fn new(receiver: mpsc::UnboundedReceiver<Result<T, Error>>) -> Self {
Self {
inner: SubscriptionInner::PreDecoded { receiver },
request_id: None,
order_id: None,
_message_type: None,
context: DecoderContext::default(),
cancelled: Arc::new(AtomicBool::new(false)),
stream_ended: Arc::new(AtomicBool::new(false)),
message_bus: None,
cancel_fn: None,
}
}
pub async fn next(&mut self) -> Option<Result<T, Error>>
where
T: 'static,
{
if self.stream_ended.load(Ordering::Relaxed) {
return None;
}
match &mut self.inner {
SubscriptionInner::WithDecoder {
subscription,
decoder,
context,
} => loop {
match subscription.next().await {
Some(Ok(mut message)) => {
let result = decoder(context, &mut message);
match process_decode_result(result) {
ProcessingResult::Success(val) => return Some(Ok(val)),
ProcessingResult::EndOfStream => {
self.stream_ended.store(true, Ordering::Relaxed);
return None;
}
ProcessingResult::Skip => {
log::trace!("skipping unexpected message on shared channel");
continue;
}
ProcessingResult::Error(err) => return Some(Err(err)),
}
}
Some(Err(e)) => return Some(Err(e)),
None => return None,
}
},
SubscriptionInner::PreDecoded { receiver } => receiver.recv().await,
}
}
pub fn request_id(&self) -> Option<i32> {
self.request_id
}
}
impl<T> Subscription<T> {
pub async fn cancel(&self) {
if self.cancelled.load(Ordering::Relaxed) {
return;
}
self.cancelled.store(true, Ordering::Relaxed);
if let (Some(message_bus), Some(cancel_fn)) = (&self.message_bus, &self.cancel_fn) {
let id = self.request_id.or(self.order_id);
if let Ok(message) = cancel_fn(self.context.server_version, id, Some(&self.context)) {
if let Err(e) = message_bus.send_message(message).await {
warn!("error sending cancel message: {e}")
}
}
}
}
}
impl<T> Drop for Subscription<T> {
fn drop(&mut self) {
debug!("dropping async subscription");
if self.cancelled.load(Ordering::Relaxed) {
return;
}
self.cancelled.store(true, Ordering::Relaxed);
if let (Some(message_bus), Some(cancel_fn)) = (&self.message_bus, &self.cancel_fn) {
let message_bus = message_bus.clone();
let id = self.request_id.or(self.order_id);
let context = self.context.clone();
if let Ok(message) = cancel_fn(context.server_version, id, Some(&context)) {
tokio::spawn(async move {
if let Err(e) = message_bus.send_message(message).await {
warn!("error sending cancel message in drop: {e}");
}
});
}
}
}
}
// Unit tests for `Subscription`. They are compiled only for test builds with the
// `async` feature enabled.
#[cfg(all(test, feature = "async"))]
mod tests {
use super::*;
use crate::market_data::realtime::Bar;
use crate::messages::OutgoingMessages;
use crate::stubs::MessageBusStub;
use std::sync::RwLock;
use time::OffsetDateTime;
use tokio::sync::{broadcast, mpsc};
// A message sent on the broadcast channel is handed to the decoder, and the
// decoder's value is what `next()` yields. The decoder here ignores the message
// and returns a fixed `Bar`.
#[tokio::test]
async fn test_subscription_with_decoder() {
let message_bus = Arc::new(MessageBusStub {
request_messages: RwLock::new(vec![]),
response_messages: vec!["1|9000|20241231 12:00:00|100.5|101.0|100.0|100.25|1000|100.2|5|0".to_string()],
});
let (tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx.resubscribe());
let subscription: Subscription<Bar> = Subscription::with_decoder(
internal,
message_bus,
|_context, _msg| {
let bar = Bar {
date: OffsetDateTime::now_utc(),
open: 100.5,
high: 101.0,
low: 100.0,
close: 100.25,
volume: 1000.0,
wap: 100.2,
count: 5,
};
Ok(bar)
},
Some(9000),
None,
Some(OutgoingMessages::RequestRealTimeBars),
DecoderContext::default(),
);
let msg = ResponseMessage::from("1\09000\020241231 12:00:00\0100.5\0101.0\0100.0\0100.25\01000\0100.2\05\00");
tx.send(msg).unwrap();
let mut sub = subscription;
let result = sub.next().await;
assert!(result.is_some());
let bar = result.unwrap().unwrap();
assert_eq!(bar.open, 100.5);
assert_eq!(bar.high, 101.0);
}
// `new_with_decoder` stores the request id and message type it is given.
#[tokio::test]
async fn test_subscription_new_with_decoder() {
let message_bus = Arc::new(MessageBusStub::default());
let (_tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx);
let subscription: Subscription<String> = Subscription::new_with_decoder(
internal,
message_bus,
|_context, _msg| Ok("decoded".to_string()),
Some(1),
None,
Some(OutgoingMessages::RequestMarketData),
DecoderContext::default(),
);
assert_eq!(subscription.request_id, Some(1));
assert_eq!(subscription._message_type, Some(OutgoingMessages::RequestMarketData));
}
// `with_decoder_components` stores both the request id and the order id.
#[tokio::test]
async fn test_subscription_with_decoder_components() {
let message_bus = Arc::new(MessageBusStub::default());
let (_tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx);
let subscription: Subscription<i32> = Subscription::with_decoder_components(
internal,
message_bus,
|_context, _msg| Ok(42),
Some(100),
Some(200),
Some(OutgoingMessages::RequestPositions),
DecoderContext::default(),
);
assert_eq!(subscription.request_id, Some(100));
assert_eq!(subscription.order_id, Some(200));
}
// A subscription built with `new` yields the value sent on its mpsc channel.
#[tokio::test]
async fn test_subscription_new_from_receiver() {
let (tx, rx) = mpsc::unbounded_channel();
let mut subscription = Subscription::new(rx);
tx.send(Ok("test".to_string())).unwrap();
let result = subscription.next().await;
assert!(result.is_some());
assert_eq!(result.unwrap().unwrap(), "test");
}
// A decoder error (other than the end-of-stream / unexpected-response cases
// exercised below) is surfaced to the caller as `Some(Err(_))`.
#[tokio::test]
async fn test_subscription_next_with_error() {
let message_bus = Arc::new(MessageBusStub::default());
let (tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx);
let mut subscription: Subscription<String> = Subscription::with_decoder(
internal,
message_bus,
|_context, _msg| Err(Error::Simple("decode error".into())),
None,
None,
None,
DecoderContext::default(),
);
let msg = ResponseMessage::from("test\0");
tx.send(msg).unwrap();
let result = subscription.next().await;
assert!(result.is_some());
assert!(result.unwrap().is_err());
}
// A decoder returning `Error::EndOfStream` makes `next()` return `None`.
#[tokio::test]
async fn test_subscription_next_end_of_stream() {
let message_bus = Arc::new(MessageBusStub::default());
let (tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx);
let mut subscription: Subscription<String> = Subscription::with_decoder(
internal,
message_bus,
|_context, _msg| Err(Error::EndOfStream),
None,
None,
None,
DecoderContext::default(),
);
let msg = ResponseMessage::from("test\0");
tx.send(msg).unwrap();
let result = subscription.next().await;
assert!(result.is_none());
}
// Once end of stream has been seen, later `next()` calls return `None` without
// invoking the decoder again, even though more messages are queued: the call
// count must stay at 1.
#[tokio::test]
async fn test_subscription_no_retries_after_end_of_stream() {
let message_bus = Arc::new(MessageBusStub::default());
let (tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx);
let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let mut subscription: Subscription<String> = Subscription::with_decoder(
internal,
message_bus,
move |_context, _msg| {
let n = call_count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if n == 0 {
Err(Error::EndOfStream)
} else {
Err(Error::UnexpectedResponse(ResponseMessage::from("stray\0")))
}
},
None,
None,
None,
DecoderContext::default(),
);
tx.send(ResponseMessage::from("end\0")).unwrap();
let result = subscription.next().await;
assert!(result.is_none());
tx.send(ResponseMessage::from("stray1\0")).unwrap();
tx.send(ResponseMessage::from("stray2\0")).unwrap();
let result = subscription.next().await;
assert!(result.is_none());
assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 1);
}
// The decoder fails with `UnexpectedResponse` for the first 20 messages and
// succeeds on the 21st. A single `next()` call is expected to skip all 20 and
// return the success, i.e. there is no cap on consecutive skipped messages.
#[tokio::test]
async fn test_subscription_skips_unexpected_messages_without_retry_limit() {
let message_bus = Arc::new(MessageBusStub::default());
let (tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx);
let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let mut subscription: Subscription<String> = Subscription::with_decoder(
internal,
message_bus,
move |_context, _msg| {
let n = call_count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if n < 20 {
Err(Error::UnexpectedResponse(ResponseMessage::from("stray\0")))
} else {
Ok("success".to_string())
}
},
None,
None,
None,
DecoderContext::default(),
);
for _ in 0..21 {
tx.send(ResponseMessage::from("msg\0")).unwrap();
}
let result = subscription.next().await;
assert!(
result.is_some(),
"subscription should not have stopped after skipping unexpected messages"
);
assert_eq!(result.unwrap().unwrap(), "success");
assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 21);
}
// `cancel()` sets the `cancelled` flag, and calling it a second time completes
// without error. The test does not inspect what was sent on the message bus.
#[tokio::test]
async fn test_subscription_cancel() {
let message_bus = Arc::new(MessageBusStub::default());
let (_tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx);
let cancel_fn: CancelFn = Box::new(|_version, _id, _ctx| {
let mut msg = RequestMessage::new();
msg.push_field(&OutgoingMessages::CancelMarketData);
Ok(msg)
});
let mut subscription: Subscription<String> = Subscription::with_decoder(
internal,
message_bus.clone(),
|_context, _msg| Ok("test".to_string()),
Some(123),
None,
Some(OutgoingMessages::RequestMarketData),
DecoderContext::default(),
);
subscription.cancel_fn = Some(Arc::new(cancel_fn));
subscription.cancel().await;
assert!(subscription.cancelled.load(Ordering::Relaxed));
subscription.cancel().await;
}
// Cloning a decoder-backed subscription carries over ids, message type and
// context.
#[tokio::test]
async fn test_subscription_clone() {
let message_bus = Arc::new(MessageBusStub::default());
let (_tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx);
let subscription: Subscription<String> = Subscription::with_decoder(
internal,
message_bus,
|_context, _msg| Ok("test".to_string()),
Some(456),
Some(789),
Some(OutgoingMessages::RequestPositions),
DecoderContext::default()
.with_smart_depth(true)
.with_request_type(OutgoingMessages::RequestPositions),
);
let cloned = subscription.clone();
assert_eq!(cloned.request_id, Some(456));
assert_eq!(cloned.order_id, Some(789));
assert_eq!(cloned._message_type, Some(OutgoingMessages::RequestPositions));
assert!(cloned.context.is_smart_depth);
}
// Dropping a subscription that has a cancel builder must not panic. The short
// sleep gives the task spawned from `drop` a chance to run; nothing is asserted
// about the message that was sent.
#[tokio::test]
async fn test_subscription_drop_with_cancel() {
let message_bus = Arc::new(MessageBusStub::default());
let (_tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx);
let cancel_fn: CancelFn = Box::new(|_version, _id, _ctx| {
let mut msg = RequestMessage::new();
msg.push_field(&OutgoingMessages::CancelMarketData);
Ok(msg)
});
{
let mut subscription: Subscription<String> = Subscription::with_decoder(
internal,
message_bus.clone(),
|_context, _msg| Ok("test".to_string()),
Some(999),
None,
Some(OutgoingMessages::RequestMarketData),
DecoderContext::default(),
);
subscription.cancel_fn = Some(Arc::new(cancel_fn));
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
// Cloning the inner source of a pre-decoded subscription panics with a fixed
// message.
#[tokio::test]
#[should_panic(expected = "Cannot clone pre-decoded subscriptions")]
async fn test_subscription_inner_clone_panic() {
let (_tx, rx) = mpsc::unbounded_channel::<Result<String, Error>>();
let subscription = Subscription::new(rx);
let _ = subscription.inner.clone();
}
// The decoder context passed to `with_decoder` is stored unchanged.
#[tokio::test]
async fn test_subscription_with_context() {
let message_bus = Arc::new(MessageBusStub::default());
let (_tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx);
let context = DecoderContext::default()
.with_smart_depth(true)
.with_request_type(OutgoingMessages::RequestMarketDepth);
let subscription: Subscription<String> = Subscription::with_decoder(
internal,
message_bus,
|_context, _msg| Ok("test".to_string()),
None,
None,
None,
context.clone(),
);
assert_eq!(subscription.context, context);
}
// Building from a `StreamDecoder` implementation installs a cancel builder.
#[tokio::test]
async fn test_subscription_new_from_internal_simple() {
struct TestDecoder;
impl StreamDecoder<String> for TestDecoder {
fn decode(_context: &DecoderContext, _msg: &mut ResponseMessage) -> Result<String, Error> {
Ok("decoded".to_string())
}
fn cancel_message(_server_version: i32, _id: Option<i32>, _context: Option<&DecoderContext>) -> Result<RequestMessage, Error> {
let mut msg = RequestMessage::new();
msg.push_field(&OutgoingMessages::CancelMarketData);
Ok(msg)
}
}
let message_bus = Arc::new(MessageBusStub::default());
let (_tx, rx) = broadcast::channel(100);
let internal = AsyncInternalSubscription::new(rx);
let subscription: Subscription<String> =
Subscription::new_from_internal_simple::<TestDecoder>(internal, DecoderContext::default(), message_bus);
assert!(subscription.cancel_fn.is_some());
}
}