betfair_stream_api/connection/
mod.rs1pub(crate) mod builder;
3pub(crate) mod cron;
4pub(crate) mod handshake;
5
6use core::convert::Infallible as Never;
7use core::task::Poll;
8
9use betfair_stream_types::request::RequestMessage;
10use betfair_stream_types::response::connection_message::ConnectionMessage;
11use betfair_stream_types::response::status_message::StatusMessage;
12use betfair_stream_types::response::ResponseMessage;
13use futures::Stream;
14use tokio::task::JoinSet;
15
16use self::builder::wrap_with_cache_layer;
17use self::cron::FatalError;
18use crate::cache::primitives::{MarketBookCache, OrderBookCache};
19
20#[derive(Debug)]
22#[pin_project::pin_project]
23pub struct StreamApi<T> {
24 join_set: JoinSet<Result<Never, FatalError>>,
25 rt_handle: tokio::runtime::Handle,
26 is_shutting_down: bool,
27 data_feed: tokio::sync::mpsc::Receiver<ExternalUpdates<T>>,
28 command_sender: tokio::sync::broadcast::Sender<RequestMessage>,
29}
30
31impl<T> StreamApi<T> {
32 #[must_use] pub const fn new(
34 join_set: JoinSet<Result<Never, FatalError>>,
35 data_feed: tokio::sync::mpsc::Receiver<ExternalUpdates<T>>,
36 command_sender: tokio::sync::broadcast::Sender<RequestMessage>,
37 rt_handle: tokio::runtime::Handle,
38 ) -> Self {
39 Self {
40 is_shutting_down: false,
41 join_set,
42 rt_handle,
43 data_feed,
44 command_sender,
45 }
46 }
47
48 #[must_use]
50 pub const fn command_sender(&self) -> &tokio::sync::broadcast::Sender<RequestMessage> {
51 &self.command_sender
52 }
53}
54
55impl StreamApi<ResponseMessage> {
56 #[must_use]
58 pub fn enable_cache(mut self) -> StreamApi<CacheEnabledMessages> {
59 let output_queue_reader_post_cache =
60 wrap_with_cache_layer(&mut self.join_set, self.data_feed, &self.rt_handle);
61 StreamApi {
62 join_set: self.join_set,
63 rt_handle: self.rt_handle,
64 is_shutting_down: self.is_shutting_down,
65 data_feed: output_queue_reader_post_cache,
66 command_sender: self.command_sender,
67 }
68 }
69}
70
71#[derive(Debug, Clone)]
73pub enum ExternalUpdates<T> {
74 Layer(T),
76 Metadata(MetadataUpdates),
78}
79
80#[derive(Debug, Clone)]
82pub enum CacheEnabledMessages {
83 MarketChange(Vec<MarketBookCache>),
85 OrderChange(Vec<OrderBookCache>),
87 Connection(ConnectionMessage),
89 Status(StatusMessage),
91}
92
/// Connection lifecycle notifications, delivered through `ExternalUpdates::Metadata`.
///
/// Only the shape is defined here. The points where each variant is emitted live in the
/// connection tasks, so the descriptions below follow the variant names and should be
/// confirmed against the emitting code.
#[derive(Debug, Clone)]
pub enum MetadataUpdates {
    /// The connection went away.
    Disconnected,
    /// A TCP connection was established.
    TcpConnected,
    /// A connection attempt failed.
    FailedToConnect,
    /// The authentication request was written to the connection.
    AuthenticationMessageSent,
    /// Authentication succeeded.
    Authenticated {
        /// Connection allowance reported on authentication. NOTE(review): presumably the number
        /// of remaining stream connections for the account — verify.
        connections_available: i32,
        /// Identifier of this connection, if one was reported.
        connection_id: Option<String>,
    },
    /// Authentication was rejected or could not be completed.
    FailedToAuthenticate,
}
114
115impl<T> Stream for StreamApi<T> {
116 type Item = ExternalUpdates<T>;
117
118 fn poll_next(
119 mut self: core::pin::Pin<&mut Self>,
120 cx: &mut core::task::Context<'_>,
121 ) -> Poll<Option<Self::Item>> {
122 if self.join_set.is_empty() && self.is_shutting_down {
124 tracing::warn!("StreamApiConnection: No tasks remaining, shutting down.");
125 return Poll::Ready(None);
126 }
127
128 match self.join_set.poll_join_next(cx) {
130 Poll::Ready(Some(Ok(Err(err)))) => {
131 tracing::error!(?err, "Error returned by a task");
132 self.join_set.abort_all();
133 self.is_shutting_down = true;
134 cx.waker().wake_by_ref();
135 }
136 Poll::Ready(Some(Ok(Ok(_e)))) => {
137 cx.waker().wake_by_ref();
138 return Poll::Pending;
139 }
140 Poll::Ready(Some(Err(err))) => {
141 tracing::error!(?err, "Error in join_set");
142 self.join_set.abort_all();
143 self.is_shutting_down = true;
144 cx.waker().wake_by_ref();
145 }
146 Poll::Ready(None) => {
147 self.is_shutting_down = true;
149 }
150 Poll::Pending => {}
151 }
152
153 match self.data_feed.poll_recv(cx) {
155 Poll::Ready(Some(update)) => Poll::Ready(Some(update)),
156 Poll::Ready(None) => {
157 tracing::warn!("StreamApiConnection: Data feed closed.");
159 self.join_set.abort_all();
160 self.is_shutting_down = true;
161 cx.waker().wake_by_ref();
162 Poll::Ready(None)
163 }
164 Poll::Pending if self.is_shutting_down => {
165 Poll::Ready(None)
167 }
168 Poll::Pending => Poll::Pending,
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use core::time::Duration;

    use futures::stream::StreamExt;
    use tokio::sync::{broadcast, mpsc};
    use tokio::task::JoinSet;
    use tokio::time::timeout;

    use super::*;

    #[tokio::test]
    async fn stream_api_connection_poll_next_shuts_down_on_empty_join_set_when_shutting_down() {
        let (_data_sender, data_receiver) = mpsc::channel(10);
        let (command_sender, _) = broadcast::channel(10);
        let join_set = JoinSet::new();
        let handle = tokio::runtime::Handle::current();

        let mut connection = StreamApi::<()>::new(join_set, data_receiver, command_sender, handle);
        connection.is_shutting_down = true;

        assert!(
            connection.next().await.is_none(),
            "Stream should shut down immediately when no tasks are left"
        );
    }

    #[tokio::test]
    async fn stream_api_connection_poll_next_shuts_down_on_empty_join_set() {
        let (_data_sender, data_receiver) = mpsc::channel(10);
        let (command_sender, _) = broadcast::channel(10);
        let join_set = JoinSet::new();
        let handle = tokio::runtime::Handle::current();

        let mut connection = StreamApi::<()>::new(join_set, data_receiver, command_sender, handle);

        assert!(
            connection.next().await.is_none(),
            "Stream should shut down immediately when no tasks are left"
        );
    }

    #[tokio::test]
    async fn stream_api_connection_receives_updates() {
        let (data_sender, data_receiver) = mpsc::channel(10);
        let (command_sender, _) = broadcast::channel(10);
        let mut join_set = JoinSet::new();
        join_set.spawn(futures::future::pending());
        let handle = tokio::runtime::Handle::current();

        let mut connection = StreamApi::new(join_set, data_receiver, command_sender, handle);

        let expected_update = ExternalUpdates::Layer("Test".to_owned());
        data_sender.send(expected_update.clone()).await.unwrap();

        match connection.next().await {
            Some(update) => match update {
                ExternalUpdates::Layer(content) => assert_eq!(content, "Test"),
                _ => panic!("Unexpected update type"),
            },
            _ => panic!("Expected to receive an update"),
        }

        assert!(
            timeout(Duration::from_millis(100), connection.next())
                .await
                .is_err(),
            "Stream should remain pending after receiving an update"
        );
    }

    #[tokio::test]
    async fn stream_api_connection_receives_updates_then_closes_empty_join_set() {
        let (data_sender, data_receiver) = mpsc::channel(10);
        let (command_sender, _) = broadcast::channel(10);
        let join_set = JoinSet::new();
        let handle = tokio::runtime::Handle::current();

        let mut connection = StreamApi::new(join_set, data_receiver, command_sender, handle);

        let expected_update = ExternalUpdates::Layer("Test".to_owned());
        data_sender.send(expected_update.clone()).await.unwrap();

        match connection.next().await {
            Some(update) => match update {
                ExternalUpdates::Layer(content) => assert_eq!(content, "Test"),
                _ => panic!("Unexpected update type"),
            },
            _ => panic!("Expected to receive an update"),
        }

        assert!(
            connection.next().await.is_none(),
            "Stream should return None after receiving an update and closing"
        );
    }

    #[tokio::test]
    async fn stream_api_connection_closes_after_join_set_returns() {
        let (data_sender, data_receiver) = mpsc::channel(10);
        let (command_sender, _) = broadcast::channel(10);
        let mut join_set = JoinSet::new();
        join_set.spawn(futures::future::ready(Err(FatalError)));
        let handle = tokio::runtime::Handle::current();

        let mut connection = StreamApi::new(join_set, data_receiver, command_sender, handle);

        let expected_update = ExternalUpdates::Layer("Test".to_owned());
        data_sender.send(expected_update.clone()).await.unwrap();

        match connection.next().await {
            Some(update) => match update {
                ExternalUpdates::Layer(content) => assert_eq!(content, "Test"),
                _ => panic!("Unexpected update type"),
            },
            _ => panic!("Expected to receive an update"),
        }

        assert!(
            connection.next().await.is_none(),
            "Stream should return None after receiving an update and closing"
        );
    }

    /// A failing task must end the stream even while sibling tasks are still running and no
    /// data is buffered.
    #[tokio::test]
    async fn stream_api_connection_shuts_down_on_fatal_error_without_data() {
        let (_data_sender, data_receiver) = mpsc::channel::<ExternalUpdates<String>>(10);
        let (command_sender, _) = broadcast::channel(10);
        let mut join_set = JoinSet::new();
        join_set.spawn(futures::future::ready(Err(FatalError)));
        join_set.spawn(futures::future::pending());
        let handle = tokio::runtime::Handle::current();

        let mut connection = StreamApi::new(join_set, data_receiver, command_sender, handle);

        let next = timeout(Duration::from_secs(1), connection.next())
            .await
            .expect("Stream should terminate after a task reports a fatal error");
        assert!(
            next.is_none(),
            "Stream should end once a task fails and no data is buffered"
        );
    }

    #[tokio::test]
    async fn stream_api_connection_shuts_down_on_data_feed_close() {
        let (data_sender, data_receiver) = mpsc::channel::<ExternalUpdates<String>>(1);
        let (command_sender, _) = broadcast::channel(10);
        let join_set = JoinSet::new();
        let handle = tokio::runtime::Handle::current();

        let mut connection = StreamApi::new(join_set, data_receiver, command_sender, handle);

        drop(data_sender);
        assert!(
            connection.next().await.is_none(),
            "Stream should have ended due to data feed close"
        );
    }

    /// Closing the data feed must end the stream even while background tasks are still alive.
    #[tokio::test]
    async fn stream_api_connection_shuts_down_on_data_feed_close_with_running_tasks() {
        let (data_sender, data_receiver) = mpsc::channel::<ExternalUpdates<String>>(1);
        let (command_sender, _) = broadcast::channel(10);
        let mut join_set = JoinSet::new();
        join_set.spawn(futures::future::pending());
        let handle = tokio::runtime::Handle::current();

        let mut connection = StreamApi::new(join_set, data_receiver, command_sender, handle);

        drop(data_sender);
        let next = timeout(Duration::from_secs(1), connection.next())
            .await
            .expect("Stream should terminate when the data feed closes");
        assert!(
            next.is_none(),
            "Stream should have ended due to data feed close despite running tasks"
        );
    }
}