1extern crate alloc;
8pub mod cache;
9use backon::{BackoffBuilder as _, ExponentialBuilder};
10use betfair_adapter::{Authenticated, BetfairRpcClient, Unauthenticated};
11pub use betfair_stream_types as types;
12use betfair_stream_types::{
13 request::{RequestMessage, authentication_message, heartbeat_message::HeartbeatMessage},
14 response::{
15 ResponseMessage,
16 connection_message::ConnectionMessage,
17 status_message::{ErrorCode, StatusMessage},
18 },
19};
20use cache::{
21 primitives::{MarketBookCache, OrderBookCache},
22 tracker::StreamState,
23};
24use core::fmt;
25use core::{pin::pin, time::Duration};
26use eyre::Context as _;
27use futures::{
28 SinkExt as _, StreamExt as _,
29 future::{self, select},
30};
31use std::sync::Arc;
32use tokio::{
33 net::TcpStream,
34 sync::mpsc::{self, Receiver, Sender},
35 task::JoinHandle,
36 time::sleep,
37};
38use tokio_stream::wrappers::{IntervalStream, ReceiverStream};
39use tokio_util::{
40 bytes,
41 codec::{Decoder, Encoder, Framed},
42};
43
/// Configuration for a Betfair stream connection.
///
/// Create one with `new` (output goes through [`Cache`]) or `new_without_cache` (raw
/// [`ResponseMessage`]s via [`Forwarder`]), optionally add a heartbeat with `with_heartbeat`,
/// then call `start` to spawn the background stream task.
pub struct BetfairStreamBuilder<T: MessageProcessor> {
    /// Client used to authenticate (and re-authenticate) and to look up the stream endpoint URL.
    pub client: BetfairRpcClient<Unauthenticated>,
    /// When `Some`, heartbeat requests are sent to the stream at this interval.
    pub heartbeat_interval: Option<Duration>,
    /// Processor applied to every decoded response before it is delivered to the consumer.
    pub processor: T,
}
59
/// Handle to a running stream task, returned by `BetfairStreamBuilder::start`.
pub struct BetfairStreamClient<T: MessageProcessor> {
    /// Requests sent on this channel are written to the Betfair stream by the background task.
    pub send_to_stream: Sender<RequestMessage>,
    /// Receives the processor's output for each message read from the stream.
    pub sink: Receiver<T::Output>,
}
69
/// A [`MessageProcessor`] that feeds market and order change messages into a local state
/// tracker and emits the cached books the tracker reports back.
pub struct Cache {
    /// State tracker that applies market/order change messages.
    state: StreamState,
}
76
/// Output produced by the [`Cache`] processor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CachedMessage {
    /// A connection message, passed through unchanged.
    Connection(ConnectionMessage),

    /// Clones of the cached market books returned by the state tracker for a market change
    /// message (presumably the markets touched by that update).
    MarketChange(Vec<MarketBookCache>),

    /// Clones of the cached order books returned by the state tracker for an order change
    /// message (presumably the markets touched by that update).
    OrderChange(Vec<OrderBookCache>),

    /// A status message, passed through unchanged.
    Status(StatusMessage),
}
99
100impl MessageProcessor for Cache {
101 type Output = CachedMessage;
102
103 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output> {
104 match message {
105 ResponseMessage::Connection(connection_message) => {
106 Some(CachedMessage::Connection(connection_message))
107 }
108 ResponseMessage::MarketChange(market_change_message) => {
109 let data = self
110 .state
111 .market_change_update(market_change_message)
112 .map(|markets| markets.into_iter().cloned().collect::<Vec<_>>())
113 .map(CachedMessage::MarketChange);
114
115 data
116 }
117 ResponseMessage::OrderChange(order_change_message) => {
118 let data = self
119 .state
120 .order_change_update(order_change_message)
121 .map(|markets| markets.into_iter().cloned().collect::<Vec<_>>())
122 .map(CachedMessage::OrderChange);
123
124 data
125 }
126 ResponseMessage::Status(status_message) => Some(CachedMessage::Status(status_message)),
127 }
128 }
129}
130
131pub struct Forwarder;
133impl MessageProcessor for Forwarder {
134 type Output = ResponseMessage;
135
136 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output> {
137 Some(message)
138 }
139}
/// Transforms decoded stream responses into the values delivered to the consumer.
///
/// The stream task calls [`MessageProcessor::process_message`] for every decoded
/// [`ResponseMessage`], including the messages received during the connection handshake.
pub trait MessageProcessor: Send + Sync + 'static {
    /// The value sent to the consumer (see `BetfairStreamClient::sink`) for a processed message.
    type Output: Send + Clone + Sync + 'static + core::fmt::Debug;

    /// Processes one response message.
    ///
    /// In the main receive loop, returning `None` means nothing is forwarded for this message.
    fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output>;
}
152
153impl<T: MessageProcessor> BetfairStreamBuilder<T> {
154 pub fn new(client: BetfairRpcClient<Unauthenticated>) -> BetfairStreamBuilder<Cache> {
167 BetfairStreamBuilder {
168 client,
169 heartbeat_interval: None,
170 processor: Cache {
171 state: StreamState::new(),
172 },
173 }
174 }
175
176 pub fn new_without_cache(
189 client: BetfairRpcClient<Unauthenticated>,
190 ) -> BetfairStreamBuilder<Forwarder> {
191 BetfairStreamBuilder {
192 client,
193 heartbeat_interval: None,
194 processor: Forwarder,
195 }
196 }
197
198 pub fn with_heartbeat(mut self, interval: Duration) -> Self {
208 self.heartbeat_interval = Some(interval);
209 self
210 }
211
212 pub async fn start(self) -> (BetfairStreamClient<T>, JoinHandle<eyre::Result<()>>) {
224 let (to_stream_tx, to_stream_rx) = mpsc::channel(100);
225 let (from_stream_tx, from_stream_rx) = mpsc::channel(100);
226
227 let task = tokio::task::spawn(self.run(from_stream_tx, to_stream_rx));
228
229 (
230 BetfairStreamClient {
231 send_to_stream: to_stream_tx,
232 sink: from_stream_rx,
233 },
234 task,
235 )
236 }
237
238 async fn run(
239 self,
240 from_stream_tx: Sender<T::Output>,
241 to_stream_rx: Receiver<RequestMessage>,
242 ) -> eyre::Result<()> {
243 if let Some(hb) = self.heartbeat_interval {
244 let heartbeat_stream = {
245 let mut interval = tokio::time::interval(hb);
246 interval.reset();
247 let interval_stream = IntervalStream::new(interval).fuse();
248 interval_stream
249 .map(move |instant| HeartbeatMessage {
250 id: Some(
251 instant
252 .into_std()
253 .elapsed()
254 .as_secs()
255 .try_into()
256 .unwrap_or_default(),
257 ),
258 })
259 .map(RequestMessage::Heartbeat)
260 .boxed()
261 };
262 let input_stream = futures::stream::select_all([
263 heartbeat_stream,
264 ReceiverStream::new(to_stream_rx).boxed(),
265 ]);
266
267 self.run_base(from_stream_tx, input_stream).await
268 } else {
269 self.run_base(from_stream_tx, ReceiverStream::new(to_stream_rx))
270 .await
271 }
272 }
273
274 async fn run_base(
275 mut self,
276 mut from_stream_tx: Sender<T::Output>,
277 mut to_stream_rx: impl futures::Stream<Item = RequestMessage> + Unpin,
278 ) -> eyre::Result<()> {
279 let (mut client, _) = self.client.clone().authenticate().await?;
280 let mut backoff = ExponentialBuilder::new().build();
281 let mut first_call = true;
282 'retry: loop {
283 if !first_call {
284 let Some(delay) = backoff.next() else {
286 eyre::bail!("connection retry attempts exceeded")
287 };
288 sleep(delay).await;
289 }
290 first_call = true;
291
292 let mut stream = self
294 .connect_with_retry(&mut from_stream_tx, &mut client)
295 .await?;
296 tracing::info!("Connected to {}", self.client.stream.url());
297
298 loop {
299 let stream_next = pin!(stream.next());
300 let to_stream_rx_next = pin!(to_stream_rx.next());
301 match select(to_stream_rx_next, stream_next).await {
302 future::Either::Left((request, _)) => {
303 let Some(request) = request else {
304 tracing::warn!("request returned None");
305 continue 'retry;
306 };
307
308 tracing::debug!(?request, "sending to betfair");
309 let Ok(()) = stream.send(request).await else {
310 tracing::warn!("could not send request to stream");
311 continue 'retry;
312 };
313 }
314 future::Either::Right((message, _)) => {
315 let Some(message) = message else {
316 tracing::warn!("stream returned None");
317 continue 'retry;
318 };
319
320 match message {
321 Ok(message) => {
322 let message = self.processor.process_message(message);
323 tracing::debug!(?message, "received from betfair");
324 let Some(message) = message else {
325 continue;
326 };
327
328 let Ok(()) = from_stream_tx.send(message).await else {
329 tracing::warn!("could not send stream message to sink");
330 continue 'retry;
331 };
332 }
333 Err(err) => tracing::warn!(?err, "reading message error"),
334 }
335 }
336 }
337 }
338 }
339 }
340
341 #[tracing::instrument(skip_all, err)]
343 async fn connect_with_retry(
344 &mut self,
345 from_stream_tx: &mut Sender<T::Output>,
346 client: &mut Arc<BetfairRpcClient<Authenticated>>,
347 ) -> eyre::Result<Framed<tokio_rustls::client::TlsStream<TcpStream>, StreamAPIClientCodec>>
348 {
349 let mut backoff = ExponentialBuilder::new().build();
350 let mut delay = async || {
351 if let Some(delay) = backoff.next() {
352 sleep(delay).await;
353 Ok(())
354 } else {
355 eyre::bail!("exceeded retry attempts, could not connect");
356 }
357 };
358
359 loop {
360 let server_addr = self.client.stream.url();
361 let host = server_addr
362 .host_str()
363 .ok_or_else(|| eyre::eyre!("invalid betfair url"))?;
364 let port = server_addr.port().unwrap_or(443);
365
366 let domain_str = server_addr
367 .domain()
368 .ok_or_else(|| eyre::eyre!("domain must be known"))?;
369 let domain = rustls::pki_types::ServerName::try_from(domain_str.to_owned())
370 .wrap_err("failed to parse server name")?;
371
372 let Some(socket_addr) = tokio::net::lookup_host((host, port)).await?.next() else {
374 eyre::bail!("no valid socket addresses for {host}:{port}")
375 };
376
377 let tcp_stream = TcpStream::connect(socket_addr).await;
378 let Ok(stream) = tcp_stream else {
379 tracing::error!(err = ?tcp_stream.unwrap_err(), "Connect error. Retrying...");
380 delay().await?;
381 continue;
382 };
383 let tls_stream = tls_connector()?.connect(domain.clone(), stream).await?;
384 let mut tls_stream = Framed::new(tls_stream, StreamAPIClientCodec);
385
386 match self
387 .handshake(from_stream_tx, &client, &mut tls_stream)
388 .await
389 {
390 Ok(()) => return Ok(tls_stream),
391 Err(err) => match err {
392 HandshakeErr::WaitAndRetry => {
393 delay().await?;
394 continue;
395 }
396 HandshakeErr::Reauthenticate => {
397 let (new_client, _) = self.client.clone().authenticate().await?;
398 *client = new_client;
399 delay().await?;
400 continue;
401 }
402 HandshakeErr::Fatal => eyre::bail!("fatal error in stream processing"),
403 },
404 }
405 }
406 }
407
408 #[tracing::instrument(err, skip_all)]
409 async fn handshake(
410 &mut self,
411 from_stream_tx: &mut Sender<T::Output>,
412 client: &BetfairRpcClient<Authenticated>,
413 stream: &mut Framed<tokio_rustls::client::TlsStream<TcpStream>, StreamAPIClientCodec>,
414 ) -> Result<(), HandshakeErr> {
415 let res = stream
417 .next()
418 .await
419 .transpose()
420 .inspect_err(|err| {
421 tracing::warn!(?err, "error when parsing stream message");
422 })
423 .map_err(|_| HandshakeErr::WaitAndRetry)?
424 .ok_or(HandshakeErr::WaitAndRetry)?;
425 tracing::info!(?res, "message from stream");
426 let message = self
427 .processor
428 .process_message(res.clone())
429 .ok_or(HandshakeErr::Fatal)
430 .inspect_err(|err| tracing::error!(?err))?;
431 from_stream_tx
432 .send(message.clone())
433 .await
434 .inspect_err(|err| tracing::warn!(?err))
435 .map_err(|_| HandshakeErr::Fatal)?;
436 let ResponseMessage::Connection(_) = &res else {
437 tracing::warn!("stream responded with invalid connection message");
438 return Err(HandshakeErr::Reauthenticate);
439 };
440
441 let msg = authentication_message::AuthenticationMessage {
443 id: Some(-1),
444 session: client.session_token().0.expose_secret().clone(),
445 app_key: self
446 .client
447 .secret_provider
448 .application_key
449 .0
450 .expose_secret()
451 .clone(),
452 };
453 stream
454 .send(RequestMessage::Authentication(msg))
455 .await
456 .inspect_err(|err| tracing::warn!(?err, "stream exited"))
457 .map_err(|_| HandshakeErr::WaitAndRetry)?;
458
459 let message = stream
461 .next()
462 .await
463 .transpose()
464 .inspect_err(|err| {
465 tracing::warn!(?err, "error when parsing stream message");
466 })
467 .map_err(|_| HandshakeErr::WaitAndRetry)?
468 .ok_or(HandshakeErr::WaitAndRetry)?;
469 let processed_message = self
470 .processor
471 .process_message(message.clone())
472 .ok_or(HandshakeErr::Fatal)
473 .inspect_err(|err| tracing::warn!(?err))
474 .map_err(|_| HandshakeErr::Fatal)?;
475 from_stream_tx
476 .send(processed_message)
477 .await
478 .inspect_err(|err| tracing::warn!(?err))
479 .map_err(|_| HandshakeErr::Fatal)?;
480 tracing::info!(?message, "message from stream");
481 let ResponseMessage::Status(status_message) = &message else {
482 tracing::warn!("expected status message, got {message:?}");
483 return Err(HandshakeErr::WaitAndRetry);
484 };
485
486 let StatusMessage::Failure(err) = &status_message else {
487 return Ok(());
488 };
489
490 tracing::error!(?err, "stream respondend wit an error");
491 let action = match err.error_code {
492 ErrorCode::NoAppKey => HandshakeErr::Fatal,
493 ErrorCode::InvalidAppKey => HandshakeErr::Fatal,
494 ErrorCode::NoSession => HandshakeErr::Reauthenticate,
495 ErrorCode::InvalidSessionInformation => HandshakeErr::Reauthenticate,
496 ErrorCode::NotAuthorized => HandshakeErr::Reauthenticate,
497 ErrorCode::InvalidInput => HandshakeErr::Fatal,
498 ErrorCode::InvalidClock => HandshakeErr::Fatal,
499 ErrorCode::UnexpectedError => HandshakeErr::Fatal,
500 ErrorCode::Timeout => HandshakeErr::WaitAndRetry,
501 ErrorCode::SubscriptionLimitExceeded => HandshakeErr::WaitAndRetry,
502 ErrorCode::InvalidRequest => HandshakeErr::Fatal,
503 ErrorCode::ConnectionFailed => HandshakeErr::WaitAndRetry,
504 ErrorCode::MaxConnectionLimitExceeded => HandshakeErr::Fatal,
505 ErrorCode::TooManyRequests => HandshakeErr::WaitAndRetry,
506 };
507
508 Err(action)
509 }
510}
511
/// Failed stream handshake, carrying the recovery action the connect loop should take.
#[derive(Debug)]
enum HandshakeErr {
    /// Back off, then open a new connection with the current session.
    WaitAndRetry,
    /// Obtain a fresh session, back off, then open a new connection.
    Reauthenticate,
    /// Stop connecting; the connect loop bails with an error.
    Fatal,
}

impl fmt::Display for HandshakeErr {
    /// Renders as `Stream Handshake Error <Variant>`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Stream Handshake Error ")?;
        write!(formatter, "{self:?}")
    }
}

impl core::error::Error for HandshakeErr {}
526
527#[tracing::instrument(err)]
528fn tls_connector() -> eyre::Result<tokio_rustls::TlsConnector> {
529 use tokio_rustls::TlsConnector;
530
531 let mut roots = rustls::RootCertStore::empty();
532 let native_certs = rustls_native_certs::load_native_certs();
533 for cert in native_certs.certs {
534 roots.add(cert)?;
535 }
536
537 let config = rustls::ClientConfig::builder()
538 .with_root_certificates(roots)
539 .with_no_client_auth();
540 Ok(TlsConnector::from(Arc::new(config)))
541}
542
/// Line-delimited JSON codec for the Betfair stream.
///
/// Decodes newline-terminated [`ResponseMessage`]s and encodes [`RequestMessage`]s
/// terminated by CRLF.
#[derive(Debug, Clone, Copy, Default)]
pub struct StreamAPIClientCodec;
545
546impl Decoder for StreamAPIClientCodec {
547 type Item = ResponseMessage;
548 type Error = eyre::Report;
549
550 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
551 if let Some(pos) = src.iter().position(|&byte| byte == b'\n') {
553 let delimiter_size = if pos > 0 && src[pos - 1] == b'\r' {
555 2
556 } else {
557 1
558 };
559
560 let line = src.split_to(pos + 1);
562
563 let (json_part, _) = line.split_at(line.len().saturating_sub(delimiter_size));
565
566 let data = serde_json::from_slice::<Self::Item>(json_part)?;
568 return Ok(Some(data));
569 }
570 Ok(None)
571 }
572}
573
574impl Encoder<RequestMessage> for StreamAPIClientCodec {
575 type Error = eyre::Report;
576
577 fn encode(
578 &mut self,
579 item: RequestMessage,
580 dst: &mut bytes::BytesMut,
581 ) -> Result<(), Self::Error> {
582 let json = serde_json::to_string(&item)?;
584 dst.extend_from_slice(json.as_bytes());
586 dst.extend_from_slice(b"\r\n");
587 Ok(())
588 }
589}
590
#[cfg(test)]
mod tests {

    use core::fmt::Write as _;

    use super::*;

    // Depends on live DNS and on the resolver returning an IPv4 address first, so it is kept
    // out of the default (offline) test run. Run it with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access to resolve stream-api.betfair.com"]
    async fn can_resolve_host_ipv4() {
        let url = url::Url::parse("tcptls://stream-api.betfair.com:443").unwrap();
        let host = url.host_str().unwrap();
        let port = url
            .port()
            .unwrap_or_else(|| if url.scheme() == "https" { 443 } else { 80 });
        let socket_addr = tokio::net::lookup_host((host, port))
            .await
            .unwrap()
            .next()
            .unwrap();
        assert!(socket_addr.ip().is_ipv4());
        assert_eq!(socket_addr.port(), 443);
    }

    #[test]
    fn can_decode_single_message() {
        let msg = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let separator = "\r\n";
        let data = format!("{msg}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg, ResponseMessage::Connection(_)));
    }

    /// A bare `\n` terminator (no `\r`) must also be accepted and fully consumed.
    #[test]
    fn can_decode_lf_delimited_message() {
        let msg = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let data = format!("{msg}\n");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg, ResponseMessage::Connection(_)));
        assert!(buf.is_empty());
    }

    #[test]
    fn can_decode_multiple_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two = r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}{msg_two}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    #[test]
    fn can_decode_multiple_partial_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two_pt_one = r#"{"op":"ocm","id":3,"clk""#;
        let msg_two_pt_two = r#":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}{msg_two_pt_one}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two_attempt = codec.decode(&mut buf).unwrap();
        assert!(msg_two_attempt.is_none());
        buf.write_str(msg_two_pt_two).unwrap();
        buf.write_str(separator).unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    #[test]
    fn can_decode_subsequent_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two = r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two_attempt = codec.decode(&mut buf).unwrap();
        assert!(msg_two_attempt.is_none());
        let data = format!("{msg_two}{separator}");
        buf.write_str(data.as_str()).unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    #[test]
    fn can_encode_message() {
        let msg = RequestMessage::Authentication(
            betfair_stream_types::request::authentication_message::AuthenticationMessage {
                id: Some(1),
                session: "sss".to_owned(),
                app_key: "aaaa".to_owned(),
            },
        );
        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        let data = buf.freeze();
        let data = core::str::from_utf8(&data).unwrap();

        assert!(data.ends_with("\r\n"));
        assert!(data.starts_with("{\"op\":\"authentication\""));
    }
}