1extern crate alloc;
8pub mod cache;
9use backon::{BackoffBuilder as _, ExponentialBuilder};
10use betfair_adapter::{Authenticated, BetfairRpcClient, Unauthenticated};
11pub use betfair_stream_types as types;
12use betfair_stream_types::{
13 request::{RequestMessage, authentication_message, heartbeat_message::HeartbeatMessage},
14 response::{
15 ResponseMessage,
16 connection_message::ConnectionMessage,
17 status_message::{ErrorCode, StatusMessage},
18 },
19};
20pub use bytes::Bytes;
21use cache::{
22 primitives::{MarketBookCache, OrderBookCache},
23 tracker::StreamState,
24};
25use core::fmt;
26use core::{pin::pin, time::Duration};
27use eyre::Context as _;
28use futures::{
29 FutureExt, SinkExt as _, StreamExt as _,
30 future::{self, BoxFuture, select},
31};
32use std::sync::Arc;
33use tokio::{
34 net::TcpStream,
35 sync::mpsc::{self, Receiver, Sender},
36 task::JoinHandle,
37 time::sleep,
38};
39use tokio_stream::wrappers::{IntervalStream, ReceiverStream};
40use tokio_util::codec::{Decoder, Encoder, Framed};
41
/// Configures and launches a background task that maintains a Betfair Stream API connection.
///
/// Create one with `BetfairStreamBuilder::new` (cached output) or
/// `BetfairStreamBuilder::new_without_cache` (raw output), optionally enable heartbeats with
/// `with_heartbeat`, then launch it with `start` or `start_with`.
#[derive(Debug, Clone)]
pub struct BetfairStreamBuilder<T: MessageProcessor> {
    /// Unauthenticated RPC client. The stream task authenticates it and reads the stream URL
    /// and the application key from it.
    pub client: BetfairRpcClient<Unauthenticated>,
    /// When set, a heartbeat request is merged into the outgoing requests on this interval.
    pub heartbeat_interval: Option<Duration>,
    /// Processor applied to every message received from the stream.
    pub processor: T,
}
58
/// Handle for communicating with a running stream task.
#[derive(Debug)]
pub struct BetfairStreamClient<T: MessageProcessor> {
    /// Requests pushed here are forwarded to Betfair by the stream task.
    pub send_to_stream: Sender<RequestMessage>,
    /// Receives the [`MessageProcessor`] output for each message read from the stream.
    /// Dropping this receiver makes the task shut down the next time it delivers a message.
    pub sink: Receiver<T::Output>,
}
69
/// [`MessageProcessor`] that applies market and order change messages to an in-memory
/// [`StreamState`] and emits clones of the books the state returns, as [`CachedMessage`]s.
#[derive(Debug, Clone)]
pub struct Cache {
    /// Market/order cache updated by every change message this processor sees.
    state: StreamState,
}
77
/// Output item produced by the [`Cache`] processor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CachedMessage {
    /// Connection message, forwarded unchanged.
    Connection(ConnectionMessage),

    /// Market books returned by the cache after applying a market change message
    /// (cloned out of the cache).
    MarketChange(Vec<MarketBookCache>),

    /// Order books returned by the cache after applying an order change message
    /// (cloned out of the cache).
    OrderChange(Vec<OrderBookCache>),

    /// Status message, forwarded unchanged.
    Status(StatusMessage),
}
100
101impl MessageProcessor for Cache {
102 type Output = CachedMessage;
103
104 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output> {
105 match message {
106 ResponseMessage::Connection(connection_message) => {
107 Some(CachedMessage::Connection(connection_message))
108 }
109 ResponseMessage::MarketChange(market_change_message) => self
110 .state
111 .market_change_update(market_change_message)
112 .map(|markets| markets.into_iter().cloned().collect::<Vec<_>>())
113 .map(CachedMessage::MarketChange),
114 ResponseMessage::OrderChange(order_change_message) => self
115 .state
116 .order_change_update(order_change_message)
117 .map(|markets| markets.into_iter().cloned().collect::<Vec<_>>())
118 .map(CachedMessage::OrderChange),
119 ResponseMessage::Status(status_message) => Some(CachedMessage::Status(status_message)),
120 }
121 }
122}
123
/// [`MessageProcessor`] that forwards every [`ResponseMessage`] unchanged.
///
/// Derives `Clone` so that `BetfairStreamBuilder<Forwarder>` is `Clone`, like
/// `BetfairStreamBuilder<Cache>` (the builder's derived `Clone` requires `T: Clone`).
#[derive(Debug, Clone, Copy, Default)]
pub struct Forwarder;
127impl MessageProcessor for Forwarder {
128 type Output = ResponseMessage;
129
130 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output> {
131 Some(message)
132 }
133}
134pub trait MessageProcessor: Send + Sync + 'static {
138 type Output: Send + Clone + Sync + 'static + core::fmt::Debug;
140
141 fn on_message_received(&mut self, raw: Bytes, message: &ResponseMessage) {
152 let _ = raw;
153 let _ = message;
154 }
155
156 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output>;
160}
161
162impl<T: MessageProcessor> BetfairStreamBuilder<T> {
163 pub fn new(client: BetfairRpcClient<Unauthenticated>) -> BetfairStreamBuilder<Cache> {
176 BetfairStreamBuilder {
177 client,
178 heartbeat_interval: None,
179 processor: Cache {
180 state: StreamState::new(),
181 },
182 }
183 }
184
185 pub fn new_without_cache(
198 client: BetfairRpcClient<Unauthenticated>,
199 ) -> BetfairStreamBuilder<Forwarder> {
200 BetfairStreamBuilder {
201 client,
202 heartbeat_interval: None,
203 processor: Forwarder,
204 }
205 }
206
207 pub fn with_heartbeat(mut self, interval: Duration) -> Self {
217 self.heartbeat_interval = Some(interval);
218 self
219 }
220
221 pub fn start_with<const C: usize, Sp, H>(self, spawner: Sp) -> (BetfairStreamClient<T>, H)
243 where
244 Sp: FnOnce(BoxFuture<'static, eyre::Result<()>>) -> H,
245 {
246 let (to_stream_tx, to_stream_rx) = mpsc::channel(C);
247 let (from_stream_tx, from_stream_rx) = mpsc::channel(C);
248
249 let fut = self.run(from_stream_tx, to_stream_rx).boxed();
251 let handle = spawner(fut);
252
253 (
254 BetfairStreamClient {
255 send_to_stream: to_stream_tx,
256 sink: from_stream_rx,
257 },
258 handle,
259 )
260 }
261
262 pub fn start<const C: usize>(self) -> (BetfairStreamClient<T>, JoinHandle<eyre::Result<()>>) {
275 self.start_with::<C, _, _>(|fut| tokio::spawn(fut))
276 }
277
278 async fn run(
279 self,
280 from_stream_tx: Sender<T::Output>,
281 to_stream_rx: Receiver<RequestMessage>,
282 ) -> eyre::Result<()> {
283 if let Some(hb) = self.heartbeat_interval {
284 let heartbeat_stream = {
285 let mut interval = tokio::time::interval(hb);
286 interval.reset();
287 let interval_stream = IntervalStream::new(interval).fuse();
288 interval_stream
289 .map(move |instant| HeartbeatMessage {
290 id: Some(
291 instant
292 .into_std()
293 .elapsed()
294 .as_secs()
295 .try_into()
296 .unwrap_or_default(),
297 ),
298 })
299 .map(RequestMessage::Heartbeat)
300 .boxed()
301 };
302 let input_stream = futures::stream::select_all([
303 heartbeat_stream,
304 ReceiverStream::new(to_stream_rx).boxed(),
305 ]);
306
307 self.run_base(from_stream_tx, input_stream).await
308 } else {
309 self.run_base(from_stream_tx, ReceiverStream::new(to_stream_rx))
310 .await
311 }
312 }
313
314 async fn run_base(
315 mut self,
316 mut from_stream_tx: Sender<T::Output>,
317 mut to_stream_rx: impl futures::Stream<Item = RequestMessage> + Unpin,
318 ) -> eyre::Result<()> {
319 let (mut client, _) = self.client.clone().authenticate().await?;
320 let mut backoff = ExponentialBuilder::new().build();
321 let mut first_call = true;
322 'retry: loop {
323 if !first_call {
324 let Some(delay) = backoff.next() else {
326 eyre::bail!("connection retry attempts exceeded")
327 };
328 sleep(delay).await;
329 }
330 first_call = true;
331
332 let mut stream = self
334 .connect_with_retry(&mut from_stream_tx, &mut client)
335 .await?;
336 tracing::info!("Connected to {}", self.client.stream.url());
337
338 loop {
339 let stream_next = pin!(stream.next());
340 let to_stream_rx_next = pin!(to_stream_rx.next());
341 match select(to_stream_rx_next, stream_next).await {
342 future::Either::Left((request, _)) => {
343 let Some(request) = request else {
344 tracing::info!("request channel closed, shutting down stream task");
345 return Ok(());
346 };
347
348 tracing::debug!(?request, "sending to betfair");
349 let Ok(()) = stream.send(request).await else {
350 tracing::warn!("could not send request to stream");
351 continue 'retry;
352 };
353 }
354 future::Either::Right((message, _)) => {
355 let Some(message) = message else {
356 tracing::warn!("stream returned None");
357 continue 'retry;
358 };
359
360 match message {
361 Ok((raw, message)) => {
362 self.processor.on_message_received(raw, &message);
363 let message = self.processor.process_message(message);
364 tracing::debug!(?message, "received from betfair");
365 let Some(message) = message else {
366 continue;
367 };
368
369 if let Err(err) = from_stream_tx.send(message).await {
370 tracing::info!(
371 "output channel receiver dropped, shutting down stream task: {:?}",
372 err
373 );
374 return Ok(());
375 };
376 }
377 Err(err) => tracing::warn!(?err, "reading message error"),
378 }
379 }
380 }
381 }
382 }
383 }
384
385 #[tracing::instrument(skip_all, err)]
387 async fn connect_with_retry(
388 &mut self,
389 from_stream_tx: &mut Sender<T::Output>,
390 client: &mut Arc<BetfairRpcClient<Authenticated>>,
391 ) -> eyre::Result<Framed<tokio_rustls::client::TlsStream<TcpStream>, StreamAPIClientCodec>>
392 {
393 let mut backoff = ExponentialBuilder::new().build();
394 let mut delay = async || {
395 if let Some(delay) = backoff.next() {
396 sleep(delay).await;
397 Ok(())
398 } else {
399 eyre::bail!("exceeded retry attempts, could not connect");
400 }
401 };
402
403 loop {
404 let server_addr = self.client.stream.url();
405 let host = server_addr
406 .host_str()
407 .ok_or_else(|| eyre::eyre!("invalid betfair url"))?;
408 let port = server_addr.port().unwrap_or(443);
409
410 let domain_str = server_addr
411 .domain()
412 .ok_or_else(|| eyre::eyre!("domain must be known"))?;
413 let domain = rustls::pki_types::ServerName::try_from(domain_str.to_owned())
414 .wrap_err("failed to parse server name")?;
415
416 let Some(socket_addr) = tokio::net::lookup_host((host, port)).await?.next() else {
418 eyre::bail!("no valid socket addresses for {host}:{port}")
419 };
420
421 let tcp_stream = TcpStream::connect(socket_addr).await;
422 let Ok(stream) = tcp_stream else {
423 tracing::error!(err = ?tcp_stream.unwrap_err(), "Connect error. Retrying...");
424 delay().await?;
425 continue;
426 };
427 let tls_stream = tls_connector()?.connect(domain.clone(), stream).await?;
428 let mut tls_stream = Framed::new(tls_stream, StreamAPIClientCodec);
429
430 match self
431 .handshake(from_stream_tx, client, &mut tls_stream)
432 .await
433 {
434 Ok(()) => return Ok(tls_stream),
435 Err(err) => match err {
436 HandshakeErr::WaitAndRetry => {
437 delay().await?;
438 continue;
439 }
440 HandshakeErr::Reauthenticate => {
441 let (new_client, _) = self.client.clone().authenticate().await?;
442 *client = new_client;
443 delay().await?;
444 continue;
445 }
446 HandshakeErr::Fatal => eyre::bail!("fatal error in stream processing"),
447 },
448 }
449 }
450 }
451
452 #[tracing::instrument(err, skip_all)]
453 async fn handshake(
454 &mut self,
455 from_stream_tx: &mut Sender<T::Output>,
456 client: &BetfairRpcClient<Authenticated>,
457 stream: &mut Framed<tokio_rustls::client::TlsStream<TcpStream>, StreamAPIClientCodec>,
458 ) -> Result<(), HandshakeErr> {
459 let (raw, res) = stream
461 .next()
462 .await
463 .transpose()
464 .inspect_err(|err| {
465 tracing::warn!(?err, "error when parsing stream message");
466 })
467 .map_err(|_| HandshakeErr::WaitAndRetry)?
468 .ok_or(HandshakeErr::WaitAndRetry)?;
469 tracing::info!(?res, "message from stream");
470 self.processor.on_message_received(raw, &res);
471 let message = self
472 .processor
473 .process_message(res.clone())
474 .ok_or(HandshakeErr::Fatal)
475 .inspect_err(|_err| {
476 tracing::error!(
477 "processor.process_message returned None for connection message: {:?}",
478 res
479 )
480 })?;
481 from_stream_tx
482 .send(message.clone())
483 .await
484 .inspect_err(|err| {
485 tracing::warn!("failed to send connection message to channel: {:?}", err)
486 })
487 .map_err(|_| HandshakeErr::Fatal)?;
488 let ResponseMessage::Connection(_) = &res else {
489 tracing::warn!("stream responded with invalid connection message");
490 return Err(HandshakeErr::Reauthenticate);
491 };
492
493 let msg = authentication_message::AuthenticationMessage {
495 id: Some(-1),
496 session: client.session_token().0.expose_secret().clone(),
497 app_key: self
498 .client
499 .secret_provider
500 .application_key
501 .0
502 .expose_secret()
503 .clone(),
504 };
505 stream
506 .send(RequestMessage::Authentication(msg))
507 .await
508 .inspect_err(|err| tracing::warn!(?err, "stream exited"))
509 .map_err(|_| HandshakeErr::WaitAndRetry)?;
510
511 let (raw, message) = stream
513 .next()
514 .await
515 .transpose()
516 .inspect_err(|err| {
517 tracing::warn!(?err, "error when parsing stream message");
518 })
519 .map_err(|_| HandshakeErr::WaitAndRetry)?
520 .ok_or(HandshakeErr::WaitAndRetry)?;
521 self.processor.on_message_received(raw, &message);
522 let processed_message = self
523 .processor
524 .process_message(message.clone())
525 .ok_or(HandshakeErr::Fatal)
526 .inspect_err(|_err| {
527 tracing::warn!(
528 "processor.process_message returned None for status message: {:?}",
529 message
530 )
531 })
532 .map_err(|_| HandshakeErr::Fatal)?;
533 from_stream_tx
534 .send(processed_message)
535 .await
536 .inspect_err(|err| {
537 tracing::warn!("failed to send status message to channel: {:?}", err)
538 })
539 .map_err(|_| HandshakeErr::Fatal)?;
540 tracing::info!(?message, "message from stream");
541 let ResponseMessage::Status(status_message) = &message else {
542 tracing::warn!("expected status message, got {message:?}");
543 return Err(HandshakeErr::WaitAndRetry);
544 };
545
546 let StatusMessage::Failure(err) = &status_message else {
547 return Ok(());
548 };
549
550 tracing::error!(?err, "stream respondend with an error");
551 let action = match err.error_code {
552 ErrorCode::NoAppKey => HandshakeErr::Fatal,
553 ErrorCode::InvalidAppKey => HandshakeErr::Fatal,
554 ErrorCode::NoSession => HandshakeErr::Reauthenticate,
555 ErrorCode::InvalidSessionInformation => HandshakeErr::Reauthenticate,
556 ErrorCode::NotAuthorized => HandshakeErr::Reauthenticate,
557 ErrorCode::InvalidInput => HandshakeErr::Fatal,
558 ErrorCode::InvalidClock => HandshakeErr::Fatal,
559 ErrorCode::UnexpectedError => HandshakeErr::Fatal,
560 ErrorCode::Timeout => HandshakeErr::WaitAndRetry,
561 ErrorCode::SubscriptionLimitExceeded => HandshakeErr::WaitAndRetry,
562 ErrorCode::InvalidRequest => HandshakeErr::Fatal,
563 ErrorCode::ConnectionFailed => HandshakeErr::WaitAndRetry,
564 ErrorCode::MaxConnectionLimitExceeded => HandshakeErr::Fatal,
565 ErrorCode::TooManyRequests => HandshakeErr::WaitAndRetry,
566 };
567
568 Err(action)
569 }
570}
571
/// How `connect_with_retry` should react to a failed stream handshake.
#[derive(Debug)]
enum HandshakeErr {
    /// Transient problem: back off, then open a new connection.
    WaitAndRetry,
    /// Session problem: authenticate again, back off, then open a new connection.
    Reauthenticate,
    /// Unrecoverable problem: stop trying to connect.
    Fatal,
}

impl fmt::Display for HandshakeErr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Stream Handshake Error {self:?}")
    }
}

impl core::error::Error for HandshakeErr {}
586
587#[tracing::instrument(err)]
588fn tls_connector() -> eyre::Result<tokio_rustls::TlsConnector> {
589 use tokio_rustls::TlsConnector;
590
591 let mut roots = rustls::RootCertStore::empty();
592 let native_certs = rustls_native_certs::load_native_certs();
593 for cert in native_certs.certs {
594 roots.add(cert)?;
595 }
596
597 let config = rustls::ClientConfig::builder()
598 .with_root_certificates(roots)
599 .with_no_client_auth();
600 Ok(TlsConnector::from(Arc::new(config)))
601}
602
/// Codec for the Betfair Stream API wire format: one JSON message per line.
///
/// Encoded lines are terminated with `\r\n`. When decoding, both `\r\n` and a bare `\n` are
/// accepted as terminators.
#[derive(Debug, Clone, Copy, Default)]
pub struct StreamAPIClientCodec;
606impl Decoder for StreamAPIClientCodec {
607 type Item = (bytes::Bytes, ResponseMessage);
608 type Error = eyre::Report;
609
610 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
611 if let Some(pos) = src.iter().position(|&byte| byte == b'\n') {
613 let delimiter_size = if pos > 0 && src[pos - 1] == b'\r' {
615 2
616 } else {
617 1
618 };
619
620 let mut line = src.split_to(pos + 1);
622
623 line.truncate(line.len().saturating_sub(delimiter_size));
625
626 let raw = line.freeze();
628
629 let data = serde_json::from_slice::<ResponseMessage>(&raw)?;
631 return Ok(Some((raw, data)));
632 }
633 Ok(None)
634 }
635}
636
637impl Encoder<RequestMessage> for StreamAPIClientCodec {
638 type Error = eyre::Report;
639
640 fn encode(
641 &mut self,
642 item: RequestMessage,
643 dst: &mut bytes::BytesMut,
644 ) -> Result<(), Self::Error> {
645 let json = serde_json::to_string(&item)?;
647 dst.extend_from_slice(json.as_bytes());
649 dst.extend_from_slice(b"\r\n");
650 Ok(())
651 }
652}
653
#[cfg(test)]
mod tests {

    // `Write` lets the tests append more bytes to a `BytesMut` via `write_str`.
    use core::fmt::Write as _;

    use super::*;

    /// Resolves the public stream endpoint. Requires network/DNS access.
    // NOTE(review): asserts that the *first* resolved address is IPv4; this may fail on
    // resolvers that return IPv6 first — confirm this is intended.
    #[tokio::test]
    async fn can_resolve_host_ipv4() {
        let url = url::Url::parse("tcptls://stream-api.betfair.com:443").unwrap();
        let host = url.host_str().unwrap();
        let port = url
            .port()
            .unwrap_or_else(|| if url.scheme() == "https" { 443 } else { 80 });
        let socket_addr = tokio::net::lookup_host((host, port))
            .await
            .unwrap()
            .next()
            .unwrap();
        assert!(socket_addr.ip().is_ipv4());
        assert_eq!(socket_addr.port(), 443);
    }

    /// A single CRLF-terminated line decodes, and `raw` excludes the terminator.
    #[test]
    fn can_decode_single_message() {
        let json = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let separator = "\r\n";
        let data = format!("{json}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let (raw, msg) = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg, ResponseMessage::Connection(_)));
        assert_eq!(&raw[..], json.as_bytes());
    }

    /// Two complete lines in one buffer decode one per call, in order.
    #[test]
    fn can_decode_multiple_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two = r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}{msg_two}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let (raw_one, msg_one) = codec.decode(&mut buf).unwrap().unwrap();
        let (raw_two, msg_two) = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert_eq!(
            &raw_one[..],
            r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#.as_bytes()
        );
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
        assert_eq!(&raw_two[..], r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#.as_bytes());
    }

    /// A trailing partial line yields `None` until the rest of it (and its terminator)
    /// arrives.
    #[test]
    fn can_decode_multiple_partial_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two_pt_one = r#"{"op":"ocm","id":3,"clk""#;
        let msg_two_pt_two = r#":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}{msg_two_pt_one}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let (_raw_one, msg_one) = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two_attempt = codec.decode(&mut buf).unwrap();
        assert!(msg_two_attempt.is_none());
        buf.write_str(msg_two_pt_two).unwrap();
        buf.write_str(separator).unwrap();
        let (_raw_two, msg_two) = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    /// Once the buffer is drained, decoding yields `None`. A later complete line then decodes.
    #[test]
    fn can_decode_subsequent_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two = r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let (_raw_one, msg_one) = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two_attempt = codec.decode(&mut buf).unwrap();
        assert!(msg_two_attempt.is_none());
        let data = format!("{msg_two}{separator}");
        buf.write_str(data.as_str()).unwrap();
        let (_raw_two, msg_two) = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    /// Encoding produces a CRLF-terminated JSON line tagged with the request's `op`.
    #[test]
    fn can_encode_message() {
        let msg = RequestMessage::Authentication(
            betfair_stream_types::request::authentication_message::AuthenticationMessage {
                id: Some(1),
                session: "sss".to_owned(),
                app_key: "aaaa".to_owned(),
            },
        );
        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        let data = buf.freeze();
        let data = core::str::from_utf8(&data).unwrap();

        assert!(data.ends_with("\r\n"));
        assert!(data.starts_with("{\"op\":\"authentication\""));
    }
}