betfair_stream_api/connection/
mod.rs

//! Module for managing the connection to the Betfair streaming API.
2pub(crate) mod builder;
3pub(crate) mod cron;
4pub(crate) mod handshake;
5
6use core::convert::Infallible as Never;
7use core::task::Poll;
8
9use betfair_stream_types::request::RequestMessage;
10use betfair_stream_types::response::connection_message::ConnectionMessage;
11use betfair_stream_types::response::status_message::StatusMessage;
12use betfair_stream_types::response::ResponseMessage;
13use futures::Stream;
14use tokio::task::JoinSet;
15
16use self::builder::wrap_with_cache_layer;
17use self::cron::FatalError;
18use crate::cache::primitives::{MarketBookCache, OrderBookCache};
19
/// Represents the streaming API connection.
///
/// The type implements [`Stream`] and yields [`ExternalUpdates<T>`] items taken from
/// `data_feed`, while also supervising the child tasks held in `join_set`.
/// `T` is the payload type carried by [`ExternalUpdates::Layer`].
#[derive(Debug)]
#[pin_project::pin_project]
pub struct StreamApi<T> {
    /// Child tasks supervised by the stream. Their success type is `Never`
    /// (`core::convert::Infallible`), so a task can only finish by returning a `FatalError`
    /// or by failing to join.
    join_set: JoinSet<Result<Never, FatalError>>,
    /// Runtime handle; handed to `wrap_with_cache_layer` when caching is enabled.
    rt_handle: tokio::runtime::Handle,
    /// Set to `true` by `poll_next` once shutdown has started (task failure, empty task set
    /// or closed data feed). `new` initialises it to `false`.
    is_shutting_down: bool,
    /// Receiving half of the channel whose items are yielded by the stream.
    data_feed: tokio::sync::mpsc::Receiver<ExternalUpdates<T>>,
    /// Broadcast sender for `RequestMessage` values; exposed through `command_sender()`.
    command_sender: tokio::sync::broadcast::Sender<RequestMessage>,
}
30
31impl<T> StreamApi<T> {
32    /// Creates a new instance of `StreamApi`.
33    #[must_use] pub const fn new(
34        join_set: JoinSet<Result<Never, FatalError>>,
35        data_feed: tokio::sync::mpsc::Receiver<ExternalUpdates<T>>,
36        command_sender: tokio::sync::broadcast::Sender<RequestMessage>,
37        rt_handle: tokio::runtime::Handle,
38    ) -> Self {
39        Self {
40            is_shutting_down: false,
41            join_set,
42            rt_handle,
43            data_feed,
44            command_sender,
45        }
46    }
47
48    /// Returns a reference to the command sender.
49    #[must_use]
50    pub const fn command_sender(&self) -> &tokio::sync::broadcast::Sender<RequestMessage> {
51        &self.command_sender
52    }
53}
54
55impl StreamApi<ResponseMessage> {
56    /// Enables caching for the stream API.
57    #[must_use]
58    pub fn enable_cache(mut self) -> StreamApi<CacheEnabledMessages> {
59        let output_queue_reader_post_cache =
60            wrap_with_cache_layer(&mut self.join_set, self.data_feed, &self.rt_handle);
61        StreamApi {
62            join_set: self.join_set,
63            rt_handle: self.rt_handle,
64            is_shutting_down: self.is_shutting_down,
65            data_feed: output_queue_reader_post_cache,
66            command_sender: self.command_sender,
67        }
68    }
69}
70
/// Represents external updates received from the data feed.
///
/// This is the item type yielded by the [`StreamApi`] stream. In this file `T` is
/// `ResponseMessage` for a plain connection and [`CacheEnabledMessages`] after
/// `StreamApi::enable_cache` has been called.
#[derive(Debug, Clone)]
pub enum ExternalUpdates<T> {
    /// A payload message of type `T`.
    Layer(T),
    /// A connection-state notification; see [`MetadataUpdates`].
    Metadata(MetadataUpdates),
}
79
/// Represents messages that have caching enabled.
///
/// This is the payload type of the `StreamApi` returned by `StreamApi::enable_cache`.
// NOTE(review): the values are presumably produced by `wrap_with_cache_layer` from the raw
// `ResponseMessage` feed — confirm in `builder`.
#[derive(Debug, Clone)]
pub enum CacheEnabledMessages {
    /// Market changes, carried as a list of `MarketBookCache` values.
    MarketChange(Vec<MarketBookCache>),
    /// Order changes, carried as a list of `OrderBookCache` values.
    OrderChange(Vec<OrderBookCache>),
    /// A `ConnectionMessage`.
    Connection(ConnectionMessage),
    /// A `StatusMessage`.
    Status(StatusMessage),
}
92
/// Represents various metadata updates related to the connection state.
///
/// Values of this type are delivered to stream consumers wrapped in
/// `ExternalUpdates::Metadata`.
#[derive(Debug, Clone)]
pub enum MetadataUpdates {
    /// Indicates disconnection.
    Disconnected,
    /// Indicates that the TCP connection has been established.
    TcpConnected,
    /// Indicates failure to connect.
    FailedToConnect,
    /// Indicates that the authentication message has been sent.
    AuthenticationMessageSent,
    /// Indicates successful authentication.
    Authenticated {
        /// Number of available connections.
        // NOTE(review): presumably the figure reported by the server on authentication —
        // confirm where this variant is constructed.
        connections_available: i32,
        /// Connection ID, if one is known.
        connection_id: Option<String>,
    },
    /// Indicates failure to authenticate.
    FailedToAuthenticate,
}
114
115impl<T> Stream for StreamApi<T> {
116    type Item = ExternalUpdates<T>;
117
118    fn poll_next(
119        mut self: core::pin::Pin<&mut Self>,
120        cx: &mut core::task::Context<'_>,
121    ) -> Poll<Option<Self::Item>> {
122        // only return None if we are shutting down and there are no tasks left
123        if self.join_set.is_empty() && self.is_shutting_down {
124            tracing::warn!("StreamApiConnection: No tasks remaining, shutting down.");
125            return Poll::Ready(None);
126        }
127
128        // Poll the join set to check if any child tasks have completed
129        match self.join_set.poll_join_next(cx) {
130            Poll::Ready(Some(Ok(Err(err)))) => {
131                tracing::error!(?err, "Error returned by a task");
132                self.join_set.abort_all();
133                self.is_shutting_down = true;
134                cx.waker().wake_by_ref();
135            }
136            Poll::Ready(Some(Ok(Ok(_e)))) => {
137                cx.waker().wake_by_ref();
138                return Poll::Pending;
139            }
140            Poll::Ready(Some(Err(err))) => {
141                tracing::error!(?err, "Error in join_set");
142                self.join_set.abort_all();
143                self.is_shutting_down = true;
144                cx.waker().wake_by_ref();
145            }
146            Poll::Ready(None) => {
147                // All tasks have completed; commence shutdown
148                self.is_shutting_down = true;
149            }
150            Poll::Pending => {}
151        }
152
153        // Poll the data feed for new items
154        match self.data_feed.poll_recv(cx) {
155            Poll::Ready(Some(update)) => Poll::Ready(Some(update)),
156            Poll::Ready(None) => {
157                // No more data, initiate shutdown
158                tracing::warn!("StreamApiConnection: Data feed closed.");
159                self.join_set.abort_all();
160                self.is_shutting_down = true;
161                cx.waker().wake_by_ref();
162                Poll::Ready(None)
163            }
164            Poll::Pending if self.is_shutting_down => {
165                // If shutting down and no data available, end the stream
166                Poll::Ready(None)
167            }
168            Poll::Pending => Poll::Pending,
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use core::time::Duration;

    use futures::stream::StreamExt;
    use tokio::sync::{broadcast, mpsc};
    use tokio::task::JoinSet;
    use tokio::time::timeout;

    use super::*;

    /// Task-set type accepted by `StreamApi::new`.
    type Tasks = JoinSet<Result<Never, FatalError>>;

    /// Builds a connection around `tasks` and returns it with the sending half of its data
    /// feed. Must be called from inside a Tokio runtime.
    fn connect<T>(
        tasks: Tasks,
        feed_capacity: usize,
    ) -> (mpsc::Sender<ExternalUpdates<T>>, StreamApi<T>) {
        let (feed_tx, feed_rx) = mpsc::channel(feed_capacity);
        let (command_tx, _) = broadcast::channel(10);
        let runtime = tokio::runtime::Handle::current();
        (feed_tx, StreamApi::new(tasks, feed_rx, command_tx, runtime))
    }

    /// Sends a `Layer("Test")` update and asserts that the stream yields it next.
    async fn send_and_expect_test_layer(
        feed_tx: &mpsc::Sender<ExternalUpdates<String>>,
        connection: &mut StreamApi<String>,
    ) {
        feed_tx
            .send(ExternalUpdates::Layer("Test".to_owned()))
            .await
            .unwrap();

        match connection.next().await {
            Some(ExternalUpdates::Layer(content)) => assert_eq!(content, "Test"),
            Some(ExternalUpdates::Metadata(_)) => panic!("Unexpected update type"),
            None => panic!("Expected to receive an update"),
        }
    }

    #[tokio::test]
    async fn stream_api_connection_poll_next_shuts_down_on_empty_join_set_when_shutting_down() {
        // The sender is kept alive so that only the shutdown flag can end the stream.
        let (_feed_tx, mut connection) = connect::<()>(Tasks::new(), 10);
        connection.is_shutting_down = true;

        let item = connection.next().await;
        assert!(
            item.is_none(),
            "Stream should shut down immediately when no tasks are left"
        );
    }

    #[tokio::test]
    async fn stream_api_connection_poll_next_shuts_down_on_empty_join_set() {
        let (_feed_tx, mut connection) = connect::<()>(Tasks::new(), 10);

        let item = connection.next().await;
        assert!(
            item.is_none(),
            "Stream should shut down immediately when no tasks are left"
        );
    }

    #[tokio::test]
    async fn stream_api_connection_receives_updates() {
        // A task that never finishes keeps the task set non-empty.
        let mut tasks = Tasks::new();
        tasks.spawn(futures::future::pending());
        let (feed_tx, mut connection) = connect(tasks, 10);

        send_and_expect_test_layer(&feed_tx, &mut connection).await;

        let second = timeout(Duration::from_millis(100), connection.next()).await;
        assert!(
            second.is_err(),
            "Stream should remain pending after receiving an update"
        );
    }

    #[tokio::test]
    async fn stream_api_connection_receives_updates_then_closes_empty_join_set() {
        let (feed_tx, mut connection) = connect(Tasks::new(), 10);

        send_and_expect_test_layer(&feed_tx, &mut connection).await;

        let second = connection.next().await;
        assert!(
            second.is_none(),
            "Stream should return None after receiving an update and closing"
        );
    }

    #[tokio::test]
    async fn stream_api_connection_closes_after_join_set_returns() {
        // The only task fails immediately with a fatal error.
        let mut tasks = Tasks::new();
        tasks.spawn(futures::future::ready(Err(FatalError)));
        let (feed_tx, mut connection) = connect(tasks, 10);

        send_and_expect_test_layer(&feed_tx, &mut connection).await;

        let second = connection.next().await;
        assert!(
            second.is_none(),
            "Stream should return None after receiving an update and closing"
        );
    }

    #[tokio::test]
    async fn stream_api_connection_shuts_down_on_data_feed_close() {
        let (feed_tx, mut connection) = connect::<String>(Tasks::new(), 1);

        drop(feed_tx); // This closes the data_feed channel

        let item = connection.next().await;
        assert!(
            item.is_none(),
            "Stream should have ended due to data feed close"
        );
    }
}
311        );
312    }
313}