Skip to main content

ibapi/subscriptions/
async.rs

1//! Asynchronous subscription implementation
2
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5
6use log::{debug, warn};
7use tokio::sync::mpsc;
8
9use super::common::{process_decode_result, DecoderContext, ProcessingResult};
10use super::StreamDecoder;
11use crate::messages::{OutgoingMessages, RequestMessage, ResponseMessage};
12use crate::transport::{AsyncInternalSubscription, AsyncMessageBus};
13use crate::Error;
14
15// Type aliases to reduce complexity
16type CancelFn = Box<dyn Fn(i32, Option<i32>, Option<&DecoderContext>) -> Result<RequestMessage, Error> + Send + Sync>;
17type DecoderFn<T> = Arc<dyn Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync>;
18
19/// Asynchronous subscription for streaming data
20pub struct Subscription<T> {
21    inner: SubscriptionInner<T>,
22    /// Metadata for cancellation
23    request_id: Option<i32>,
24    order_id: Option<i32>,
25    _message_type: Option<OutgoingMessages>,
26    context: DecoderContext,
27    cancelled: Arc<AtomicBool>,
28    stream_ended: Arc<AtomicBool>,
29    message_bus: Option<Arc<dyn AsyncMessageBus>>,
30    /// Cancel message generator
31    cancel_fn: Option<Arc<CancelFn>>,
32}
33
34enum SubscriptionInner<T> {
35    /// Subscription with decoder - receives ResponseMessage and decodes to T
36    WithDecoder {
37        subscription: AsyncInternalSubscription,
38        decoder: DecoderFn<T>,
39        context: DecoderContext,
40    },
41    /// Pre-decoded subscription - receives T directly
42    PreDecoded { receiver: mpsc::UnboundedReceiver<Result<T, Error>> },
43}
44
45impl<T> Clone for SubscriptionInner<T> {
46    fn clone(&self) -> Self {
47        match self {
48            SubscriptionInner::WithDecoder {
49                subscription,
50                decoder,
51                context,
52            } => SubscriptionInner::WithDecoder {
53                subscription: subscription.clone(),
54                decoder: decoder.clone(),
55                context: context.clone(),
56            },
57            SubscriptionInner::PreDecoded { .. } => {
58                // Can't clone mpsc receivers
59                panic!("Cannot clone pre-decoded subscriptions");
60            }
61        }
62    }
63}
64
65impl<T> Clone for Subscription<T> {
66    fn clone(&self) -> Self {
67        Self {
68            inner: self.inner.clone(),
69            request_id: self.request_id,
70            order_id: self.order_id,
71            _message_type: self._message_type,
72            context: self.context.clone(),
73            cancelled: self.cancelled.clone(),
74            stream_ended: self.stream_ended.clone(),
75            message_bus: self.message_bus.clone(),
76            cancel_fn: self.cancel_fn.clone(),
77        }
78    }
79}
80
81impl<T> Subscription<T> {
82    /// Create a subscription from an internal subscription and a decoder
83    #[allow(clippy::too_many_arguments)]
84    pub fn with_decoder<D>(
85        internal: AsyncInternalSubscription,
86        message_bus: Arc<dyn AsyncMessageBus>,
87        decoder: D,
88        request_id: Option<i32>,
89        order_id: Option<i32>,
90        message_type: Option<OutgoingMessages>,
91        context: DecoderContext,
92    ) -> Self
93    where
94        D: Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync + 'static,
95    {
96        Self {
97            inner: SubscriptionInner::WithDecoder {
98                subscription: internal,
99                decoder: Arc::new(decoder),
100                context: context.clone(),
101            },
102            request_id,
103            order_id,
104            _message_type: message_type,
105            context,
106            cancelled: Arc::new(AtomicBool::new(false)),
107            stream_ended: Arc::new(AtomicBool::new(false)),
108            message_bus: Some(message_bus),
109            cancel_fn: None,
110        }
111    }
112
113    /// Create a subscription from an internal subscription with a decoder function
114    #[allow(clippy::too_many_arguments)]
115    pub fn new_with_decoder<F>(
116        internal: AsyncInternalSubscription,
117        message_bus: Arc<dyn AsyncMessageBus>,
118        decoder: F,
119        request_id: Option<i32>,
120        order_id: Option<i32>,
121        message_type: Option<OutgoingMessages>,
122        context: DecoderContext,
123    ) -> Self
124    where
125        F: Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync + 'static,
126    {
127        Self::with_decoder(internal, message_bus, decoder, request_id, order_id, message_type, context)
128    }
129
130    /// Create a subscription from components and a decoder (alias for with_decoder)
131    #[allow(clippy::too_many_arguments)]
132    pub fn with_decoder_components<D>(
133        internal: AsyncInternalSubscription,
134        message_bus: Arc<dyn AsyncMessageBus>,
135        decoder: D,
136        request_id: Option<i32>,
137        order_id: Option<i32>,
138        message_type: Option<OutgoingMessages>,
139        context: DecoderContext,
140    ) -> Self
141    where
142        D: Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync + 'static,
143    {
144        Self::with_decoder(internal, message_bus, decoder, request_id, order_id, message_type, context)
145    }
146
147    /// Create a subscription from an internal subscription using the DataStream decoder
148    pub(crate) fn new_from_internal<D>(
149        internal: AsyncInternalSubscription,
150        message_bus: Arc<dyn AsyncMessageBus>,
151        request_id: Option<i32>,
152        order_id: Option<i32>,
153        message_type: Option<OutgoingMessages>,
154        context: DecoderContext,
155    ) -> Self
156    where
157        D: StreamDecoder<T> + 'static,
158        T: 'static,
159    {
160        let mut sub = Self::with_decoder_components(internal, message_bus, D::decode, request_id, order_id, message_type, context);
161        // Store the cancel function
162        sub.cancel_fn = Some(Arc::new(Box::new(D::cancel_message)));
163        sub
164    }
165
166    /// Create a subscription from internal subscription without explicit metadata
167    pub(crate) fn new_from_internal_simple<D>(
168        internal: AsyncInternalSubscription,
169        context: DecoderContext,
170        message_bus: Arc<dyn AsyncMessageBus>,
171    ) -> Self
172    where
173        D: StreamDecoder<T> + 'static,
174        T: 'static,
175    {
176        // The AsyncInternalSubscription already has cleanup logic, so we don't need cancel metadata
177        Self::new_from_internal::<D>(internal, message_bus, None, None, None, context)
178    }
179
180    /// Create subscription from existing receiver (for backward compatibility)
181    pub fn new(receiver: mpsc::UnboundedReceiver<Result<T, Error>>) -> Self {
182        // This creates a subscription that expects pre-decoded messages
183        // Used for compatibility with existing code that manually decodes
184        Self {
185            inner: SubscriptionInner::PreDecoded { receiver },
186            request_id: None,
187            order_id: None,
188            _message_type: None,
189            context: DecoderContext::default(),
190            cancelled: Arc::new(AtomicBool::new(false)),
191            stream_ended: Arc::new(AtomicBool::new(false)),
192            message_bus: None,
193            cancel_fn: None,
194        }
195    }
196
197    /// Get the next value from the subscription
198    pub async fn next(&mut self) -> Option<Result<T, Error>>
199    where
200        T: 'static,
201    {
202        if self.stream_ended.load(Ordering::Relaxed) {
203            return None;
204        }
205
206        match &mut self.inner {
207            SubscriptionInner::WithDecoder {
208                subscription,
209                decoder,
210                context,
211            } => loop {
212                match subscription.next().await {
213                    Some(Ok(mut message)) => {
214                        let result = decoder(context, &mut message);
215                        match process_decode_result(result) {
216                            ProcessingResult::Success(val) => return Some(Ok(val)),
217                            ProcessingResult::EndOfStream => {
218                                self.stream_ended.store(true, Ordering::Relaxed);
219                                return None;
220                            }
221                            ProcessingResult::Skip => {
222                                log::trace!("skipping unexpected message on shared channel");
223                                continue;
224                            }
225                            ProcessingResult::Error(err) => return Some(Err(err)),
226                        }
227                    }
228                    Some(Err(e)) => return Some(Err(e)),
229                    None => return None,
230                }
231            },
232            SubscriptionInner::PreDecoded { receiver } => receiver.recv().await,
233        }
234    }
235
236    /// Get the request ID associated with this subscription
237    pub fn request_id(&self) -> Option<i32> {
238        self.request_id
239    }
240}
241
242impl<T> Subscription<T> {
243    /// Cancel the subscription
244    pub async fn cancel(&self) {
245        if self.cancelled.load(Ordering::Relaxed) {
246            return;
247        }
248
249        self.cancelled.store(true, Ordering::Relaxed);
250
251        if let (Some(message_bus), Some(cancel_fn)) = (&self.message_bus, &self.cancel_fn) {
252            let id = self.request_id.or(self.order_id);
253            if let Ok(message) = cancel_fn(self.context.server_version, id, Some(&self.context)) {
254                if let Err(e) = message_bus.send_message(message).await {
255                    warn!("error sending cancel message: {e}")
256                }
257            }
258        }
259
260        // The AsyncInternalSubscription's Drop will handle cleanup
261    }
262}
263
264impl<T> Drop for Subscription<T> {
265    fn drop(&mut self) {
266        debug!("dropping async subscription");
267
268        // Check if already cancelled
269        if self.cancelled.load(Ordering::Relaxed) {
270            return;
271        }
272
273        self.cancelled.store(true, Ordering::Relaxed);
274
275        // Try to send cancel message if we have the necessary components
276        if let (Some(message_bus), Some(cancel_fn)) = (&self.message_bus, &self.cancel_fn) {
277            let message_bus = message_bus.clone();
278            let id = self.request_id.or(self.order_id);
279            let context = self.context.clone();
280
281            // Clone the cancel function for use in the spawned task
282            if let Ok(message) = cancel_fn(context.server_version, id, Some(&context)) {
283                // Spawn a task to send the cancel message since drop can't be async
284                tokio::spawn(async move {
285                    if let Err(e) = message_bus.send_message(message).await {
286                        warn!("error sending cancel message in drop: {e}");
287                    }
288                });
289            }
290        }
291
292        // The AsyncInternalSubscription's Drop will handle channel cleanup
293    }
294}
295
296// Note: Stream trait implementation removed because tokio's broadcast::Receiver
297// doesn't provide poll_recv. Users should use the async next() method instead.
298// If Stream is needed, users can convert using futures::stream::unfold.
299
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;
    use crate::market_data::realtime::Bar;
    use crate::messages::OutgoingMessages;
    use crate::stubs::MessageBusStub;
    use std::sync::atomic::AtomicUsize;
    use std::sync::RwLock;
    use time::OffsetDateTime;
    use tokio::sync::{broadcast, mpsc};

    /// Stub message bus with default contents.
    fn stub_bus() -> Arc<MessageBusStub> {
        Arc::new(MessageBusStub::default())
    }

    /// Broadcast channel (capacity 100) together with an internal subscription reading from it.
    fn internal_channel() -> (broadcast::Sender<ResponseMessage>, AsyncInternalSubscription) {
        let (sender, receiver) = broadcast::channel(100);
        (sender, AsyncInternalSubscription::new(receiver))
    }

    /// Cancel function that always produces a `CancelMarketData` request.
    fn cancel_market_data() -> CancelFn {
        Box::new(|_version, _id, _ctx| {
            let mut msg = RequestMessage::new();
            msg.push_field(&OutgoingMessages::CancelMarketData);
            Ok(msg)
        })
    }

    /// Decoder-backed subscription with no ids, no message type and a default context.
    fn plain_subscription<T, D>(internal: AsyncInternalSubscription, decoder: D) -> Subscription<T>
    where
        D: Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync + 'static,
    {
        Subscription::with_decoder(internal, stub_bus(), decoder, None, None, None, DecoderContext::default())
    }

    #[tokio::test]
    async fn test_subscription_with_decoder() {
        let message_bus = Arc::new(MessageBusStub {
            request_messages: RwLock::new(vec![]),
            response_messages: vec!["1|9000|20241231 12:00:00|100.5|101.0|100.0|100.25|1000|100.2|5|0".to_string()],
        });

        let (tx, rx) = broadcast::channel(100);
        let internal = AsyncInternalSubscription::new(rx.resubscribe());

        let mut sub: Subscription<Bar> = Subscription::with_decoder(
            internal,
            message_bus,
            |_context, _msg| {
                Ok(Bar {
                    date: OffsetDateTime::now_utc(),
                    open: 100.5,
                    high: 101.0,
                    low: 100.0,
                    close: 100.25,
                    volume: 1000.0,
                    wap: 100.2,
                    count: 5,
                })
            },
            Some(9000),
            None,
            Some(OutgoingMessages::RequestRealTimeBars),
            DecoderContext::default(),
        );

        // Push one raw message through the channel
        tx.send(ResponseMessage::from(
            "1\09000\020241231 12:00:00\0100.5\0101.0\0100.0\0100.25\01000\0100.2\05\00",
        ))
        .unwrap();

        // The decoded bar must come out the other end
        let result = sub.next().await;
        assert!(result.is_some());
        let bar = result.unwrap().unwrap();
        assert_eq!(bar.open, 100.5);
        assert_eq!(bar.high, 101.0);
    }

    #[tokio::test]
    async fn test_subscription_new_with_decoder() {
        let (_tx, internal) = internal_channel();

        let subscription: Subscription<String> = Subscription::new_with_decoder(
            internal,
            stub_bus(),
            |_context, _msg| Ok("decoded".to_string()),
            Some(1),
            None,
            Some(OutgoingMessages::RequestMarketData),
            DecoderContext::default(),
        );

        assert_eq!(subscription.request_id, Some(1));
        assert_eq!(subscription._message_type, Some(OutgoingMessages::RequestMarketData));
    }

    #[tokio::test]
    async fn test_subscription_with_decoder_components() {
        let (_tx, internal) = internal_channel();

        let subscription: Subscription<i32> = Subscription::with_decoder_components(
            internal,
            stub_bus(),
            |_context, _msg| Ok(42),
            Some(100),
            Some(200),
            Some(OutgoingMessages::RequestPositions),
            DecoderContext::default(),
        );

        assert_eq!(subscription.request_id, Some(100));
        assert_eq!(subscription.order_id, Some(200));
    }

    #[tokio::test]
    async fn test_subscription_new_from_receiver() {
        let (tx, rx) = mpsc::unbounded_channel();
        let mut subscription = Subscription::new(rx);

        // Pre-decoded values are forwarded as-is
        tx.send(Ok("test".to_string())).unwrap();

        let result = subscription.next().await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().unwrap(), "test");
    }

    #[tokio::test]
    async fn test_subscription_next_with_error() {
        let (tx, internal) = internal_channel();
        let mut subscription: Subscription<String> =
            plain_subscription(internal, |_context, _msg| Err(Error::Simple("decode error".into())));

        // Any message makes the decoder fail
        tx.send(ResponseMessage::from("test\0")).unwrap();

        let result = subscription.next().await;
        assert!(result.is_some());
        assert!(result.unwrap().is_err());
    }

    #[tokio::test]
    async fn test_subscription_next_end_of_stream() {
        let (tx, internal) = internal_channel();
        let mut subscription: Subscription<String> =
            plain_subscription(internal, |_context, _msg| Err(Error::EndOfStream));

        // Any message makes the decoder report end of stream
        tx.send(ResponseMessage::from("test\0")).unwrap();

        let result = subscription.next().await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_subscription_no_retries_after_end_of_stream() {
        let (tx, internal) = internal_channel();

        let call_count = Arc::new(AtomicUsize::new(0));
        let decoder_calls = Arc::clone(&call_count);

        let mut subscription: Subscription<String> = plain_subscription(internal, move |_context, _msg| {
            match decoder_calls.fetch_add(1, Ordering::Relaxed) {
                0 => Err(Error::EndOfStream),
                _ => Err(Error::UnexpectedResponse(ResponseMessage::from("stray\0"))),
            }
        });

        // First message triggers EndOfStream
        tx.send(ResponseMessage::from("end\0")).unwrap();
        assert!(subscription.next().await.is_none());

        // Stray messages arriving after the stream ended
        tx.send(ResponseMessage::from("stray1\0")).unwrap();
        tx.send(ResponseMessage::from("stray2\0")).unwrap();

        // Later calls return None immediately, without touching the decoder
        assert!(subscription.next().await.is_none());

        // Only the EndOfStream message reached the decoder
        assert_eq!(call_count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_subscription_skips_unexpected_messages_without_retry_limit() {
        let (tx, internal) = internal_channel();

        let call_count = Arc::new(AtomicUsize::new(0));
        let decoder_calls = Arc::clone(&call_count);

        // The decoder reports UnexpectedResponse for the first 20 messages (more than
        // MAX_DECODE_RETRIES=10) and succeeds afterwards. If UnexpectedResponse
        // counted toward the retry limit, the subscription would give up after 10.
        let mut subscription: Subscription<String> = plain_subscription(internal, move |_context, _msg| {
            if decoder_calls.fetch_add(1, Ordering::Relaxed) < 20 {
                Err(Error::UnexpectedResponse(ResponseMessage::from("stray\0")))
            } else {
                Ok("success".to_string())
            }
        });

        // 21 messages: 20 get skipped as unexpected, the last one succeeds
        (0..21).for_each(|_| {
            tx.send(ResponseMessage::from("msg\0")).unwrap();
        });

        let result = subscription.next().await;
        assert!(
            result.is_some(),
            "subscription should not have stopped after skipping unexpected messages"
        );
        assert_eq!(result.unwrap().unwrap(), "success");
        // Every message went through the decoder (20 skipped + 1 success)
        assert_eq!(call_count.load(Ordering::Relaxed), 21);
    }

    #[tokio::test]
    async fn test_subscription_cancel() {
        let message_bus = stub_bus();
        let (_tx, internal) = internal_channel();

        let mut subscription: Subscription<String> = Subscription::with_decoder(
            internal,
            message_bus.clone(),
            |_context, _msg| Ok("test".to_string()),
            Some(123),
            None,
            Some(OutgoingMessages::RequestMarketData),
            DecoderContext::default(),
        );
        subscription.cancel_fn = Some(Arc::new(cancel_market_data()));

        subscription.cancel().await;

        // The flag records that cancellation happened
        assert!(subscription.cancelled.load(Ordering::Relaxed));

        // A second cancel must be a no-op
        subscription.cancel().await;
    }

    #[tokio::test]
    async fn test_subscription_clone() {
        let (_tx, internal) = internal_channel();

        let context = DecoderContext::default()
            .with_smart_depth(true)
            .with_request_type(OutgoingMessages::RequestPositions);

        let subscription: Subscription<String> = Subscription::with_decoder(
            internal,
            stub_bus(),
            |_context, _msg| Ok("test".to_string()),
            Some(456),
            Some(789),
            Some(OutgoingMessages::RequestPositions),
            context,
        );

        let cloned = subscription.clone();
        assert_eq!(cloned.request_id, Some(456));
        assert_eq!(cloned.order_id, Some(789));
        assert_eq!(cloned._message_type, Some(OutgoingMessages::RequestPositions));
        assert!(cloned.context.is_smart_depth);
    }

    #[tokio::test]
    async fn test_subscription_drop_with_cancel() {
        let message_bus = stub_bus();
        let (_tx, internal) = internal_channel();
        let cancel_fn = cancel_market_data();

        {
            let mut subscription: Subscription<String> = Subscription::with_decoder(
                internal,
                message_bus.clone(),
                |_context, _msg| Ok("test".to_string()),
                Some(999),
                None,
                Some(OutgoingMessages::RequestMarketData),
                DecoderContext::default(),
            );
            subscription.cancel_fn = Some(Arc::new(cancel_fn));
            // Leaving this scope drops the subscription, which should send the cancel message
        }

        // Let the spawned cancel task run
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    }

    #[tokio::test]
    #[should_panic(expected = "Cannot clone pre-decoded subscriptions")]
    async fn test_subscription_inner_clone_panic() {
        let (_tx, rx) = mpsc::unbounded_channel::<Result<String, Error>>();
        let subscription = Subscription::new(rx);

        // PreDecoded sources hold an mpsc receiver, which cannot be cloned
        let _ = subscription.inner.clone();
    }

    #[tokio::test]
    async fn test_subscription_with_context() {
        let (_tx, internal) = internal_channel();

        let context = DecoderContext::default()
            .with_smart_depth(true)
            .with_request_type(OutgoingMessages::RequestMarketDepth);

        let subscription: Subscription<String> = Subscription::with_decoder(
            internal,
            stub_bus(),
            |_context, _msg| Ok("test".to_string()),
            None,
            None,
            None,
            context.clone(),
        );

        assert_eq!(subscription.context, context);
    }

    #[tokio::test]
    async fn test_subscription_new_from_internal_simple() {
        // Minimal decoder type used only by this test
        struct TestDecoder;

        impl StreamDecoder<String> for TestDecoder {
            fn decode(_context: &DecoderContext, _msg: &mut ResponseMessage) -> Result<String, Error> {
                Ok("decoded".to_string())
            }

            fn cancel_message(_server_version: i32, _id: Option<i32>, _context: Option<&DecoderContext>) -> Result<RequestMessage, Error> {
                let mut msg = RequestMessage::new();
                msg.push_field(&OutgoingMessages::CancelMarketData);
                Ok(msg)
            }
        }

        let (_tx, internal) = internal_channel();

        let subscription: Subscription<String> =
            Subscription::new_from_internal_simple::<TestDecoder>(internal, DecoderContext::default(), stub_bus());

        assert!(subscription.cancel_fn.is_some());
    }
}