1use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5
6use log::{debug, warn};
7use tokio::sync::mpsc;
8
9use super::common::{process_decode_result, DecoderContext, ProcessingResult};
10use super::StreamDecoder;
11use crate::messages::{OutgoingMessages, RequestMessage, ResponseMessage};
12use crate::transport::{AsyncInternalSubscription, AsyncMessageBus};
13use crate::Error;
14
15type CancelFn = Box<dyn Fn(i32, Option<i32>, Option<&DecoderContext>) -> Result<RequestMessage, Error> + Send + Sync>;
17type DecoderFn<T> = Arc<dyn Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync>;
18
/// An asynchronous stream of `T` values belonging to a single request.
///
/// Values come either from raw [`ResponseMessage`]s decoded on demand
/// (`SubscriptionInner::WithDecoder`) or from a channel of already-decoded results
/// (`SubscriptionInner::PreDecoded`). When both `message_bus` and `cancel_fn` are
/// set, a cancel request is sent by `cancel()` or when the subscription is dropped.
pub struct Subscription<T> {
    // Source of values: raw messages plus a decoder, or a pre-decoded channel.
    inner: SubscriptionInner<T>,
    // Id passed to `cancel_fn`; preferred over `order_id` when both are present.
    request_id: Option<i32>,
    // Fallback id passed to `cancel_fn` when `request_id` is `None`.
    order_id: Option<i32>,
    // Outgoing message type of the originating request; stored but only read by tests in this file.
    _message_type: Option<OutgoingMessages>,
    // Context handed to `cancel_fn` when building the cancel request.
    context: DecoderContext,
    // Set once cancellation has been attempted; the `Arc` is shared by every clone.
    cancelled: Arc<AtomicBool>,
    // Set once the decoder reports end-of-stream; the `Arc` is shared by every clone.
    stream_ended: Arc<AtomicBool>,
    // Bus used to send the cancel request; `None` for subscriptions built by `new`.
    message_bus: Option<Arc<dyn AsyncMessageBus>>,
    // Builds the cancel request; `None` means cancel/drop send nothing.
    cancel_fn: Option<Arc<CancelFn>>,
}
33
/// Where a [`Subscription`] gets its values from.
enum SubscriptionInner<T> {
    /// Raw messages from the transport, decoded one at a time in `Subscription::next`.
    WithDecoder {
        // Raw message source.
        subscription: AsyncInternalSubscription,
        // Converts a raw message into a `T`; its result is classified by `process_decode_result`.
        decoder: DecoderFn<T>,
        // Context passed to `decoder` for every message.
        context: DecoderContext,
    },
    /// Results that were already decoded elsewhere and are delivered over a channel.
    PreDecoded { receiver: mpsc::UnboundedReceiver<Result<T, Error>> },
}
44
45impl<T> Clone for SubscriptionInner<T> {
46 fn clone(&self) -> Self {
47 match self {
48 SubscriptionInner::WithDecoder {
49 subscription,
50 decoder,
51 context,
52 } => SubscriptionInner::WithDecoder {
53 subscription: subscription.clone(),
54 decoder: decoder.clone(),
55 context: context.clone(),
56 },
57 SubscriptionInner::PreDecoded { .. } => {
58 panic!("Cannot clone pre-decoded subscriptions");
60 }
61 }
62 }
63}
64
65impl<T> Clone for Subscription<T> {
66 fn clone(&self) -> Self {
67 Self {
68 inner: self.inner.clone(),
69 request_id: self.request_id,
70 order_id: self.order_id,
71 _message_type: self._message_type,
72 context: self.context.clone(),
73 cancelled: self.cancelled.clone(),
74 stream_ended: self.stream_ended.clone(),
75 message_bus: self.message_bus.clone(),
76 cancel_fn: self.cancel_fn.clone(),
77 }
78 }
79}
80
81impl<T> Subscription<T> {
82 #[allow(clippy::too_many_arguments)]
84 pub fn with_decoder<D>(
85 internal: AsyncInternalSubscription,
86 message_bus: Arc<dyn AsyncMessageBus>,
87 decoder: D,
88 request_id: Option<i32>,
89 order_id: Option<i32>,
90 message_type: Option<OutgoingMessages>,
91 context: DecoderContext,
92 ) -> Self
93 where
94 D: Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync + 'static,
95 {
96 Self {
97 inner: SubscriptionInner::WithDecoder {
98 subscription: internal,
99 decoder: Arc::new(decoder),
100 context: context.clone(),
101 },
102 request_id,
103 order_id,
104 _message_type: message_type,
105 context,
106 cancelled: Arc::new(AtomicBool::new(false)),
107 stream_ended: Arc::new(AtomicBool::new(false)),
108 message_bus: Some(message_bus),
109 cancel_fn: None,
110 }
111 }
112
113 #[allow(clippy::too_many_arguments)]
115 pub fn new_with_decoder<F>(
116 internal: AsyncInternalSubscription,
117 message_bus: Arc<dyn AsyncMessageBus>,
118 decoder: F,
119 request_id: Option<i32>,
120 order_id: Option<i32>,
121 message_type: Option<OutgoingMessages>,
122 context: DecoderContext,
123 ) -> Self
124 where
125 F: Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync + 'static,
126 {
127 Self::with_decoder(internal, message_bus, decoder, request_id, order_id, message_type, context)
128 }
129
130 #[allow(clippy::too_many_arguments)]
132 pub fn with_decoder_components<D>(
133 internal: AsyncInternalSubscription,
134 message_bus: Arc<dyn AsyncMessageBus>,
135 decoder: D,
136 request_id: Option<i32>,
137 order_id: Option<i32>,
138 message_type: Option<OutgoingMessages>,
139 context: DecoderContext,
140 ) -> Self
141 where
142 D: Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync + 'static,
143 {
144 Self::with_decoder(internal, message_bus, decoder, request_id, order_id, message_type, context)
145 }
146
147 pub(crate) fn new_from_internal<D>(
149 internal: AsyncInternalSubscription,
150 message_bus: Arc<dyn AsyncMessageBus>,
151 request_id: Option<i32>,
152 order_id: Option<i32>,
153 message_type: Option<OutgoingMessages>,
154 context: DecoderContext,
155 ) -> Self
156 where
157 D: StreamDecoder<T> + 'static,
158 T: 'static,
159 {
160 let mut sub = Self::with_decoder_components(internal, message_bus, D::decode, request_id, order_id, message_type, context);
161 sub.cancel_fn = Some(Arc::new(Box::new(D::cancel_message)));
163 sub
164 }
165
166 pub(crate) fn new_from_internal_simple<D>(
168 internal: AsyncInternalSubscription,
169 context: DecoderContext,
170 message_bus: Arc<dyn AsyncMessageBus>,
171 ) -> Self
172 where
173 D: StreamDecoder<T> + 'static,
174 T: 'static,
175 {
176 Self::new_from_internal::<D>(internal, message_bus, None, None, None, context)
178 }
179
180 pub fn new(receiver: mpsc::UnboundedReceiver<Result<T, Error>>) -> Self {
182 Self {
185 inner: SubscriptionInner::PreDecoded { receiver },
186 request_id: None,
187 order_id: None,
188 _message_type: None,
189 context: DecoderContext::default(),
190 cancelled: Arc::new(AtomicBool::new(false)),
191 stream_ended: Arc::new(AtomicBool::new(false)),
192 message_bus: None,
193 cancel_fn: None,
194 }
195 }
196
197 pub async fn next(&mut self) -> Option<Result<T, Error>>
199 where
200 T: 'static,
201 {
202 if self.stream_ended.load(Ordering::Relaxed) {
203 return None;
204 }
205
206 match &mut self.inner {
207 SubscriptionInner::WithDecoder {
208 subscription,
209 decoder,
210 context,
211 } => loop {
212 match subscription.next().await {
213 Some(Ok(mut message)) => {
214 let result = decoder(context, &mut message);
215 match process_decode_result(result) {
216 ProcessingResult::Success(val) => return Some(Ok(val)),
217 ProcessingResult::EndOfStream => {
218 self.stream_ended.store(true, Ordering::Relaxed);
219 return None;
220 }
221 ProcessingResult::Skip => {
222 log::trace!("skipping unexpected message on shared channel");
223 continue;
224 }
225 ProcessingResult::Error(err) => return Some(Err(err)),
226 }
227 }
228 Some(Err(e)) => return Some(Err(e)),
229 None => return None,
230 }
231 },
232 SubscriptionInner::PreDecoded { receiver } => receiver.recv().await,
233 }
234 }
235
236 pub fn request_id(&self) -> Option<i32> {
238 self.request_id
239 }
240}
241
242impl<T> Subscription<T> {
243 pub async fn cancel(&self) {
245 if self.cancelled.load(Ordering::Relaxed) {
246 return;
247 }
248
249 self.cancelled.store(true, Ordering::Relaxed);
250
251 if let (Some(message_bus), Some(cancel_fn)) = (&self.message_bus, &self.cancel_fn) {
252 let id = self.request_id.or(self.order_id);
253 if let Ok(message) = cancel_fn(self.context.server_version, id, Some(&self.context)) {
254 if let Err(e) = message_bus.send_message(message).await {
255 warn!("error sending cancel message: {e}")
256 }
257 }
258 }
259
260 }
262}
263
264impl<T> Drop for Subscription<T> {
265 fn drop(&mut self) {
266 debug!("dropping async subscription");
267
268 if self.cancelled.load(Ordering::Relaxed) {
270 return;
271 }
272
273 self.cancelled.store(true, Ordering::Relaxed);
274
275 if let (Some(message_bus), Some(cancel_fn)) = (&self.message_bus, &self.cancel_fn) {
277 let message_bus = message_bus.clone();
278 let id = self.request_id.or(self.order_id);
279 let context = self.context.clone();
280
281 if let Ok(message) = cancel_fn(context.server_version, id, Some(&context)) {
283 tokio::spawn(async move {
285 if let Err(e) = message_bus.send_message(message).await {
286 warn!("error sending cancel message in drop: {e}");
287 }
288 });
289 }
290 }
291
292 }
294}
295
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;
    use crate::market_data::realtime::Bar;
    use crate::messages::OutgoingMessages;
    use crate::stubs::MessageBusStub;
    use std::sync::RwLock;
    use time::OffsetDateTime;
    use tokio::sync::{broadcast, mpsc};

    /// Builds a cancel function that emits a `CancelMarketData` request and records
    /// in `calls` how many times it was invoked.
    fn counting_cancel_fn(calls: Arc<std::sync::atomic::AtomicUsize>) -> CancelFn {
        Box::new(move |_version, _id, _ctx| {
            calls.fetch_add(1, Ordering::SeqCst);
            let mut msg = RequestMessage::new();
            msg.push_field(&OutgoingMessages::CancelMarketData);
            Ok(msg)
        })
    }

    /// A decoded message is returned from `next()`.
    #[tokio::test]
    async fn test_subscription_with_decoder() {
        let message_bus = Arc::new(MessageBusStub {
            request_messages: RwLock::new(vec![]),
            response_messages: vec!["1|9000|20241231 12:00:00|100.5|101.0|100.0|100.25|1000|100.2|5|0".to_string()],
        });

        let (tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx.resubscribe());

        let subscription: Subscription<Bar> = Subscription::with_decoder(
            internal,
            message_bus,
            |_context, _msg| {
                let bar = Bar {
                    date: OffsetDateTime::now_utc(),
                    open: 100.5,
                    high: 101.0,
                    low: 100.0,
                    close: 100.25,
                    volume: 1000.0,
                    wap: 100.2,
                    count: 5,
                };
                Ok(bar)
            },
            Some(9000),
            None,
            Some(OutgoingMessages::RequestRealTimeBars),
            DecoderContext::default(),
        );

        let msg = ResponseMessage::from("1\09000\020241231 12:00:00\0100.5\0101.0\0100.0\0100.25\01000\0100.2\05\00");
        tx.send(msg).unwrap();

        let mut sub = subscription;
        let result = sub.next().await;
        assert!(result.is_some());
        let bar = result.unwrap().unwrap();
        assert_eq!(bar.open, 100.5);
        assert_eq!(bar.high, 101.0);
    }

    /// `new_with_decoder` stores the request id and message type it was given.
    #[tokio::test]
    async fn test_subscription_new_with_decoder() {
        let message_bus = Arc::new(MessageBusStub::default());
        let (_tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx);

        let subscription: Subscription<String> = Subscription::new_with_decoder(
            internal,
            message_bus,
            |_context, _msg| Ok("decoded".to_string()),
            Some(1),
            None,
            Some(OutgoingMessages::RequestMarketData),
            DecoderContext::default(),
        );

        assert_eq!(subscription.request_id, Some(1));
        assert_eq!(subscription._message_type, Some(OutgoingMessages::RequestMarketData));
    }

    /// `with_decoder_components` stores both the request id and the order id.
    #[tokio::test]
    async fn test_subscription_with_decoder_components() {
        let message_bus = Arc::new(MessageBusStub::default());
        let (_tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx);

        let subscription: Subscription<i32> = Subscription::with_decoder_components(
            internal,
            message_bus,
            |_context, _msg| Ok(42),
            Some(100),
            Some(200),
            Some(OutgoingMessages::RequestPositions),
            DecoderContext::default(),
        );

        assert_eq!(subscription.request_id, Some(100));
        assert_eq!(subscription.order_id, Some(200));
    }

    /// A pre-decoded subscription yields whatever was sent on its channel.
    #[tokio::test]
    async fn test_subscription_new_from_receiver() {
        let (tx, rx) = mpsc::unbounded_channel();

        let mut subscription = Subscription::new(rx);

        tx.send(Ok("test".to_string())).unwrap();

        let result = subscription.next().await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().unwrap(), "test");
    }

    /// A decoder error is surfaced as `Some(Err(_))`.
    #[tokio::test]
    async fn test_subscription_next_with_error() {
        let message_bus = Arc::new(MessageBusStub::default());
        let (tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx);

        let mut subscription: Subscription<String> = Subscription::with_decoder(
            internal,
            message_bus,
            |_context, _msg| Err(Error::Simple("decode error".into())),
            None,
            None,
            None,
            DecoderContext::default(),
        );

        let msg = ResponseMessage::from("test\0");
        tx.send(msg).unwrap();

        let result = subscription.next().await;
        assert!(result.is_some());
        assert!(result.unwrap().is_err());
    }

    /// `Error::EndOfStream` from the decoder ends the stream.
    #[tokio::test]
    async fn test_subscription_next_end_of_stream() {
        let message_bus = Arc::new(MessageBusStub::default());
        let (tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx);

        let mut subscription: Subscription<String> = Subscription::with_decoder(
            internal,
            message_bus,
            |_context, _msg| Err(Error::EndOfStream),
            None,
            None,
            None,
            DecoderContext::default(),
        );

        let msg = ResponseMessage::from("test\0");
        tx.send(msg).unwrap();

        let result = subscription.next().await;
        assert!(result.is_none());
    }

    /// After end-of-stream, later calls return `None` without invoking the decoder again.
    #[tokio::test]
    async fn test_subscription_no_retries_after_end_of_stream() {
        let message_bus = Arc::new(MessageBusStub::default());
        let (tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx);

        let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let call_count_clone = call_count.clone();

        let mut subscription: Subscription<String> = Subscription::with_decoder(
            internal,
            message_bus,
            move |_context, _msg| {
                let n = call_count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                if n == 0 {
                    Err(Error::EndOfStream)
                } else {
                    Err(Error::UnexpectedResponse(ResponseMessage::from("stray\0")))
                }
            },
            None,
            None,
            None,
            DecoderContext::default(),
        );

        tx.send(ResponseMessage::from("end\0")).unwrap();
        let result = subscription.next().await;
        assert!(result.is_none());

        tx.send(ResponseMessage::from("stray1\0")).unwrap();
        tx.send(ResponseMessage::from("stray2\0")).unwrap();

        let result = subscription.next().await;
        assert!(result.is_none());

        assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 1);
    }

    /// Unexpected messages are skipped indefinitely until a real value arrives.
    #[tokio::test]
    async fn test_subscription_skips_unexpected_messages_without_retry_limit() {
        let message_bus = Arc::new(MessageBusStub::default());
        let (tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx);

        let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let call_count_clone = call_count.clone();

        let mut subscription: Subscription<String> = Subscription::with_decoder(
            internal,
            message_bus,
            move |_context, _msg| {
                let n = call_count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                if n < 20 {
                    Err(Error::UnexpectedResponse(ResponseMessage::from("stray\0")))
                } else {
                    Ok("success".to_string())
                }
            },
            None,
            None,
            None,
            DecoderContext::default(),
        );

        for _ in 0..21 {
            tx.send(ResponseMessage::from("msg\0")).unwrap();
        }

        let result = subscription.next().await;
        assert!(
            result.is_some(),
            "subscription should not have stopped after skipping unexpected messages"
        );
        assert_eq!(result.unwrap().unwrap(), "success");
        assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 21);
    }

    /// `cancel()` marks the subscription cancelled and builds the cancel request exactly once,
    /// even when called repeatedly.
    #[tokio::test]
    async fn test_subscription_cancel() {
        let message_bus = Arc::new(MessageBusStub::default());
        let (_tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx);

        let cancel_calls = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        let mut subscription: Subscription<String> = Subscription::with_decoder(
            internal,
            message_bus.clone(),
            |_context, _msg| Ok("test".to_string()),
            Some(123),
            None,
            Some(OutgoingMessages::RequestMarketData),
            DecoderContext::default(),
        );
        subscription.cancel_fn = Some(Arc::new(counting_cancel_fn(cancel_calls.clone())));

        subscription.cancel().await;

        assert!(subscription.cancelled.load(Ordering::Relaxed));
        assert_eq!(cancel_calls.load(Ordering::SeqCst), 1);

        // A second cancel is a no-op.
        subscription.cancel().await;
        assert_eq!(cancel_calls.load(Ordering::SeqCst), 1);
    }

    /// Clones keep the ids, message type and context of the original.
    #[tokio::test]
    async fn test_subscription_clone() {
        let message_bus = Arc::new(MessageBusStub::default());
        let (_tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx);

        let subscription: Subscription<String> = Subscription::with_decoder(
            internal,
            message_bus,
            |_context, _msg| Ok("test".to_string()),
            Some(456),
            Some(789),
            Some(OutgoingMessages::RequestPositions),
            DecoderContext::default()
                .with_smart_depth(true)
                .with_request_type(OutgoingMessages::RequestPositions),
        );

        let cloned = subscription.clone();
        assert_eq!(cloned.request_id, Some(456));
        assert_eq!(cloned.order_id, Some(789));
        assert_eq!(cloned._message_type, Some(OutgoingMessages::RequestPositions));
        assert!(cloned.context.is_smart_depth);
    }

    /// Dropping an uncancelled subscription builds the cancel request exactly once.
    #[tokio::test]
    async fn test_subscription_drop_with_cancel() {
        let message_bus = Arc::new(MessageBusStub::default());
        let (_tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx);

        let cancel_calls = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        {
            let mut subscription: Subscription<String> = Subscription::with_decoder(
                internal,
                message_bus.clone(),
                |_context, _msg| Ok("test".to_string()),
                Some(999),
                None,
                Some(OutgoingMessages::RequestMarketData),
                DecoderContext::default(),
            );
            subscription.cancel_fn = Some(Arc::new(counting_cancel_fn(cancel_calls.clone())));
        }

        // The request is built synchronously inside `drop`; only the send is spawned.
        assert_eq!(cancel_calls.load(Ordering::SeqCst), 1);

        // Give the spawned send task a chance to run.
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    }

    /// Pre-decoded sources cannot be cloned.
    #[tokio::test]
    #[should_panic(expected = "Cannot clone pre-decoded subscriptions")]
    async fn test_subscription_inner_clone_panic() {
        let (_tx, rx) = mpsc::unbounded_channel::<Result<String, Error>>();
        let subscription = Subscription::new(rx);

        let _ = subscription.inner.clone();
    }

    /// The context passed at construction is stored unchanged.
    #[tokio::test]
    async fn test_subscription_with_context() {
        let message_bus = Arc::new(MessageBusStub::default());
        let (_tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx);

        let context = DecoderContext::default()
            .with_smart_depth(true)
            .with_request_type(OutgoingMessages::RequestMarketDepth);

        let subscription: Subscription<String> = Subscription::with_decoder(
            internal,
            message_bus,
            |_context, _msg| Ok("test".to_string()),
            None,
            None,
            None,
            context.clone(),
        );

        assert_eq!(subscription.context, context);
    }

    /// `new_from_internal_simple` installs the decoder's cancel function.
    #[tokio::test]
    async fn test_subscription_new_from_internal_simple() {
        struct TestDecoder;

        impl StreamDecoder<String> for TestDecoder {
            fn decode(_context: &DecoderContext, _msg: &mut ResponseMessage) -> Result<String, Error> {
                Ok("decoded".to_string())
            }

            fn cancel_message(_server_version: i32, _id: Option<i32>, _context: Option<&DecoderContext>) -> Result<RequestMessage, Error> {
                let mut msg = RequestMessage::new();
                msg.push_field(&OutgoingMessages::CancelMarketData);
                Ok(msg)
            }
        }

        let message_bus = Arc::new(MessageBusStub::default());
        let (_tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx);

        let subscription: Subscription<String> =
            Subscription::new_from_internal_simple::<TestDecoder>(internal, DecoderContext::default(), message_bus);

        assert!(subscription.cancel_fn.is_some());
    }
}