use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use log::{debug, error, warn};
use super::common::{process_decode_result, should_store_error, DecoderContext, ProcessingResult};
use super::StreamDecoder;
use crate::errors::Error;
use crate::messages::{OutgoingMessages, ResponseMessage};
use crate::transport::{InternalSubscription, MessageBus};
/// A subscription whose raw responses are pulled from an `InternalSubscription` and
/// decoded into `T` with `T::decode`.
///
/// Dropping a `Subscription` calls `cancel` (see the `Drop` impl in this file).
// The `T: StreamDecoder<T>` bound is kept on the struct itself because the `Drop` impl
// must repeat the struct's bounds exactly. `private_bounds` is allowed presumably
// because `StreamDecoder` is less visible than this `pub` type — TODO confirm.
#[allow(private_bounds)]
pub struct Subscription<T: StreamDecoder<T>> {
// Context passed to `T::decode` and `T::cancel_message`.
context: DecoderContext,
// Bus used to send the cancel request (per-request, per-order, or shared).
message_bus: Arc<dyn MessageBus>,
// Copied from the internal subscription in `new`; checked first when choosing how to cancel.
request_id: Option<i32>,
// Copied from the internal subscription; used for cancelling when there is no request id.
order_id: Option<i32>,
// Copied from the internal subscription; used for cancelling when there is neither a
// request id nor an order id.
message_type: Option<OutgoingMessages>,
// Ties the decoded item type `T` to the struct without storing a `T`.
phantom: PhantomData<T>,
// Set once `cancel` has run, so later calls (including from `Drop`) do nothing.
cancelled: AtomicBool,
// Set when a decoded item reports `is_snapshot_end`; `cancel` is then skipped.
snapshot_ended: AtomicBool,
// Set when decoding reports end-of-stream; `next` then returns `None` immediately.
stream_ended: AtomicBool,
// Channel the raw responses are read from.
subscription: InternalSubscription,
// Error recorded by the most recent handled response; taken by `error()`.
error: Mutex<Option<Error>>,
}
/// What a read loop should do after handling a single response.
enum NextAction<T> {
    /// The response was discarded; pull the next one.
    Skip,
    /// Stop reading and hand this value to the caller (`None` means "no item").
    Return(Option<T>),
}
#[allow(private_bounds)]
impl<T: StreamDecoder<T>> Subscription<T> {
pub(crate) fn new(message_bus: Arc<dyn MessageBus>, subscription: InternalSubscription, context: DecoderContext) -> Self {
let request_id = subscription.request_id;
let order_id = subscription.order_id;
let message_type = subscription.message_type;
Subscription {
context,
message_bus,
request_id,
order_id,
message_type,
subscription,
phantom: PhantomData,
cancelled: AtomicBool::new(false),
snapshot_ended: AtomicBool::new(false),
stream_ended: AtomicBool::new(false),
error: Mutex::new(None),
}
}
pub fn cancel(&self) {
if self.snapshot_ended.load(Ordering::Relaxed) {
return;
}
if self.cancelled.load(Ordering::Relaxed) {
return;
}
self.cancelled.store(true, Ordering::Relaxed);
if let Some(request_id) = self.request_id {
if let Ok(message) = T::cancel_message(self.context.server_version, self.request_id, Some(&self.context)) {
if let Err(e) = self.message_bus.cancel_subscription(request_id, &message) {
warn!("error cancelling subscription: {e}")
}
self.subscription.cancel();
}
} else if let Some(order_id) = self.order_id {
if let Ok(message) = T::cancel_message(self.context.server_version, self.request_id, Some(&self.context)) {
if let Err(e) = self.message_bus.cancel_order_subscription(order_id, &message) {
warn!("error cancelling order subscription: {e}")
}
self.subscription.cancel();
}
} else if let Some(message_type) = self.message_type {
if let Ok(message) = T::cancel_message(self.context.server_version, self.request_id, Some(&self.context)) {
if let Err(e) = self.message_bus.cancel_shared_subscription(message_type, &message) {
warn!("error cancelling shared subscription: {e}")
}
self.subscription.cancel();
}
} else {
debug!("Could not determine cancel method")
}
}
pub fn request_id(&self) -> Option<i32> {
self.request_id
}
pub fn next(&self) -> Option<T> {
if self.stream_ended.load(Ordering::Relaxed) {
return None;
}
loop {
match self.handle_response(self.subscription.next()) {
NextAction::Return(val) => return val,
NextAction::Skip => continue,
}
}
}
pub fn error(&self) -> Option<Error> {
let mut error = self.error.lock().unwrap();
error.take()
}
fn clear_error(&self) {
let mut error = self.error.lock().unwrap();
*error = None;
}
fn handle_response(&self, response: Option<Result<ResponseMessage, Error>>) -> NextAction<T> {
self.clear_error();
match response {
Some(Ok(mut message)) => match process_decode_result(T::decode(&self.context, &mut message)) {
ProcessingResult::Success(val) => {
if val.is_snapshot_end() {
self.snapshot_ended.store(true, Ordering::Relaxed);
}
NextAction::Return(Some(val))
}
ProcessingResult::Skip => {
log::trace!("skipping unexpected message on shared channel");
NextAction::Skip
}
ProcessingResult::EndOfStream => {
self.stream_ended.store(true, Ordering::Relaxed);
NextAction::Return(None)
}
ProcessingResult::Error(err) => {
error!("error decoding message: {err}");
let mut error = self.error.lock().unwrap();
*error = Some(err);
NextAction::Return(None)
}
},
Some(Err(e)) => {
if should_store_error(&e) {
let mut error = self.error.lock().unwrap();
*error = Some(e);
}
NextAction::Return(None)
}
None => NextAction::Return(None),
}
}
pub fn try_next(&self) -> Option<T> {
loop {
match self.handle_response(self.subscription.try_next()) {
NextAction::Return(val) => return val,
NextAction::Skip => continue,
}
}
}
pub fn next_timeout(&self, timeout: Duration) -> Option<T> {
let deadline = Instant::now() + timeout;
loop {
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
return None;
}
match self.handle_response(self.subscription.next_timeout(remaining)) {
NextAction::Return(val) => return val,
NextAction::Skip => continue,
}
}
}
pub fn iter(&self) -> SubscriptionIter<'_, T> {
SubscriptionIter { subscription: self }
}
pub fn try_iter(&self) -> SubscriptionTryIter<'_, T> {
SubscriptionTryIter { subscription: self }
}
pub fn timeout_iter(&self, timeout: Duration) -> SubscriptionTimeoutIter<'_, T> {
SubscriptionTimeoutIter { subscription: self, timeout }
}
}
/// Cancels the subscription when it goes out of scope.
///
/// `cancel` returns early after a completed snapshot or an earlier explicit cancel, so
/// dropping does not send a second cancel request.
impl<T> Drop for Subscription<T>
where
    T: StreamDecoder<T>,
{
    fn drop(&mut self) {
        debug!("dropping subscription");
        let this: &Subscription<T> = self;
        this.cancel();
    }
}
/// Borrowing iterator returned by `Subscription::iter`.
///
/// Each step delegates to `Subscription::next` and yields whatever it returns.
#[allow(private_bounds)]
pub struct SubscriptionIter<'a, T>
where
    T: StreamDecoder<T>,
{
    subscription: &'a Subscription<T>,
}

impl<'a, T> Iterator for SubscriptionIter<'a, T>
where
    T: StreamDecoder<T>,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let source: &Subscription<T> = self.subscription;
        source.next()
    }
}
impl<'a, T: StreamDecoder<T>> IntoIterator for &'a Subscription<T> {
type Item = T;
type IntoIter = SubscriptionIter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
#[allow(private_bounds)]
pub struct SubscriptionOwnedIter<T: StreamDecoder<T>> {
subscription: Subscription<T>,
}
impl<T: StreamDecoder<T>> Iterator for SubscriptionOwnedIter<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
self.subscription.next()
}
}
impl<T: StreamDecoder<T>> IntoIterator for Subscription<T> {
type Item = T;
type IntoIter = SubscriptionOwnedIter<T>;
fn into_iter(self) -> Self::IntoIter {
SubscriptionOwnedIter { subscription: self }
}
}
/// Borrowing iterator returned by `Subscription::try_iter`.
///
/// Each step delegates to `Subscription::try_next` and yields whatever it returns.
#[allow(private_bounds)]
pub struct SubscriptionTryIter<'a, T>
where
    T: StreamDecoder<T>,
{
    subscription: &'a Subscription<T>,
}

impl<'a, T> Iterator for SubscriptionTryIter<'a, T>
where
    T: StreamDecoder<T>,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let source: &Subscription<T> = self.subscription;
        source.try_next()
    }
}
/// Borrowing iterator returned by `Subscription::timeout_iter`.
///
/// Every step calls `Subscription::next_timeout` with the stored `timeout`, so each item
/// gets a fresh budget; the iterator yields `None` whenever that call does.
#[allow(private_bounds)]
pub struct SubscriptionTimeoutIter<'a, T>
where
    T: StreamDecoder<T>,
{
    subscription: &'a Subscription<T>,
    timeout: Duration,
}

impl<'a, T> Iterator for SubscriptionTimeoutIter<'a, T>
where
    T: StreamDecoder<T>,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let budget = self.timeout;
        self.subscription.next_timeout(budget)
    }
}
/// Marker trait with no required methods.
///
/// NOTE(review): nothing in this part of the file implements it or bounds on it. The name
/// suggests it marks subscription types whose responses arrive on a shared channel —
/// confirm against its implementors.
pub trait SharesChannel {}
#[cfg(all(test, feature = "sync"))]
mod tests {
    use super::*;
    use crate::messages::{OutgoingMessages, RequestMessage, ResponseMessage};
    use crate::stubs::MessageBusStub;
    use std::sync::atomic::AtomicUsize;
    use std::sync::Arc;

    /// Counts calls to `EndOfStreamItem::decode`, so a test can prove that nothing is
    /// decoded after the stream has ended. Only `test_no_retries_after_end_of_stream`
    /// uses `EndOfStreamItem`.
    static END_OF_STREAM_DECODES: AtomicUsize = AtomicUsize::new(0);

    /// Decoder whose `decode` always reports end-of-stream.
    #[derive(Debug)]
    struct EndOfStreamItem;

    impl StreamDecoder<EndOfStreamItem> for EndOfStreamItem {
        fn decode(_context: &DecoderContext, _msg: &mut ResponseMessage) -> Result<EndOfStreamItem, Error> {
            END_OF_STREAM_DECODES.fetch_add(1, Ordering::Relaxed);
            Err(Error::EndOfStream)
        }

        fn cancel_message(_server_version: i32, _id: Option<i32>, _context: Option<&DecoderContext>) -> Result<RequestMessage, Error> {
            let mut msg = RequestMessage::new();
            msg.push_field(&OutgoingMessages::CancelMarketData);
            Ok(msg)
        }
    }

    #[test]
    fn test_subscription_skips_unexpected_messages_without_limit() {
        static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);

        /// Decoder that reports the first 20 messages as unexpected, then succeeds.
        #[derive(Debug)]
        struct SkipThenSuccess;

        impl StreamDecoder<SkipThenSuccess> for SkipThenSuccess {
            fn decode(_context: &DecoderContext, _msg: &mut ResponseMessage) -> Result<SkipThenSuccess, Error> {
                let n = CALL_COUNT.fetch_add(1, Ordering::Relaxed);
                if n < 20 {
                    Err(Error::UnexpectedResponse(ResponseMessage::from("stray\0")))
                } else {
                    Ok(SkipThenSuccess)
                }
            }
        }

        CALL_COUNT.store(0, Ordering::Relaxed);
        // 22 messages: the 21st decode succeeds, so the final one is never reached.
        let mut responses: Vec<String> = (0..21).map(|_| "1|msg".to_string()).collect();
        responses.push("1|done".to_string());
        let stub = MessageBusStub::with_responses(responses);
        let message_bus = Arc::new(stub);
        let sub: Subscription<SkipThenSuccess> = {
            let internal = message_bus.send_request(1, &RequestMessage::new()).unwrap();
            Subscription::new(message_bus.clone(), internal, DecoderContext::default())
        };
        let result = sub.next();
        assert!(result.is_some(), "subscription should survive 20 skips and return valid message");
        assert_eq!(CALL_COUNT.load(Ordering::Relaxed), 21);
    }

    #[test]
    fn test_no_retries_after_end_of_stream() {
        END_OF_STREAM_DECODES.store(0, Ordering::Relaxed);
        let stub = MessageBusStub::with_responses(vec![
            // Decoding this message ends the stream.
            "1|data".to_string(),
            // Must never be decoded: the stream has already ended.
            "1|stray".to_string(),
        ]);
        let message_bus = Arc::new(stub);
        let sub: Subscription<EndOfStreamItem> = {
            let internal = message_bus.send_request(1, &RequestMessage::new()).unwrap();
            Subscription::new(message_bus.clone(), internal, DecoderContext::default())
        };
        assert!(sub.next().is_none());
        assert!(sub.next().is_none());
        assert!(sub.stream_ended.load(Ordering::Relaxed));
        // `EndOfStreamItem` returns `None` for every message, so the assertions above
        // cannot tell a retry from a short-circuit; the decode count can.
        assert_eq!(
            END_OF_STREAM_DECODES.load(Ordering::Relaxed),
            1,
            "no message may be decoded once the stream has ended"
        );
    }
}