1extern crate alloc;
8pub mod cache;
9use backon::{BackoffBuilder as _, ExponentialBuilder};
10use betfair_adapter::{Authenticated, BetfairRpcClient, Unauthenticated};
11pub use betfair_stream_types as types;
12use betfair_stream_types::{
13 request::{RequestMessage, authentication_message, heartbeat_message::HeartbeatMessage},
14 response::{
15 ResponseMessage,
16 connection_message::ConnectionMessage,
17 status_message::{ErrorCode, StatusMessage},
18 },
19};
20use cache::{
21 primitives::{MarketBookCache, OrderBookCache},
22 tracker::StreamState,
23};
24use core::fmt;
25use core::{pin::pin, time::Duration};
26use eyre::Context as _;
27use futures::{
28 FutureExt, SinkExt as _, StreamExt as _,
29 future::{self, BoxFuture, select},
30};
31use std::sync::Arc;
32use tokio::{
33 net::TcpStream,
34 sync::mpsc::{self, Receiver, Sender},
35 task::JoinHandle,
36 time::sleep,
37};
38use tokio_stream::wrappers::{IntervalStream, ReceiverStream};
39use tokio_util::{
40 bytes,
41 codec::{Decoder, Encoder, Framed},
42};
43
/// Configuration for a Betfair stream API connection.
///
/// Create one with [`BetfairStreamBuilder::new`] (cached output) or
/// [`BetfairStreamBuilder::new_without_cache`] (raw messages), optionally enable a heartbeat
/// with [`BetfairStreamBuilder::with_heartbeat`], then call [`BetfairStreamBuilder::start`] or
/// [`BetfairStreamBuilder::start_with`].
#[derive(Debug, Clone)]
pub struct BetfairStreamBuilder<T: MessageProcessor> {
    /// RPC client; it is authenticated when the stream starts, and its stream URL and
    /// application key are used to open and authenticate the connection.
    pub client: BetfairRpcClient<Unauthenticated>,
    /// Interval between client heartbeat requests; `None` disables the heartbeat.
    pub heartbeat_interval: Option<Duration>,
    /// Converts every received message into the value delivered to the consumer.
    pub processor: T,
}
60
/// Handle to a running stream, returned by [`BetfairStreamBuilder::start`] and
/// [`BetfairStreamBuilder::start_with`].
#[derive(Debug)]
pub struct BetfairStreamClient<T: MessageProcessor> {
    /// Requests pushed here are written to the Betfair connection by the stream task.
    pub send_to_stream: Sender<RequestMessage>,
    /// Receives the processor's output for messages read from Betfair.
    pub sink: Receiver<T::Output>,
}
71
/// [`MessageProcessor`] that applies market and order change messages to a local
/// [`StreamState`] and emits [`CachedMessage`]s.
#[derive(Debug, Clone)]
pub struct Cache {
    state: StreamState,
}
79
/// Output produced by the [`Cache`] processor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CachedMessage {
    /// Connection message from the stream, passed through unchanged.
    Connection(ConnectionMessage),

    /// Market books returned by `StreamState::market_change_update` after applying a market
    /// change message.
    MarketChange(Vec<MarketBookCache>),

    /// Order books returned by `StreamState::order_change_update` after applying an order
    /// change message.
    OrderChange(Vec<OrderBookCache>),

    /// Status message from the stream, passed through unchanged.
    Status(StatusMessage),
}
102
103impl MessageProcessor for Cache {
104 type Output = CachedMessage;
105
106 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output> {
107 match message {
108 ResponseMessage::Connection(connection_message) => {
109 Some(CachedMessage::Connection(connection_message))
110 }
111 ResponseMessage::MarketChange(market_change_message) => {
112 let data = self
113 .state
114 .market_change_update(market_change_message)
115 .map(|markets| markets.into_iter().cloned().collect::<Vec<_>>())
116 .map(CachedMessage::MarketChange);
117
118 data
119 }
120 ResponseMessage::OrderChange(order_change_message) => {
121 let data = self
122 .state
123 .order_change_update(order_change_message)
124 .map(|markets| markets.into_iter().cloned().collect::<Vec<_>>())
125 .map(CachedMessage::OrderChange);
126
127 data
128 }
129 ResponseMessage::Status(status_message) => Some(CachedMessage::Status(status_message)),
130 }
131 }
132}
133
134#[derive(Debug)]
136pub struct Forwarder;
137impl MessageProcessor for Forwarder {
138 type Output = ResponseMessage;
139
140 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output> {
141 Some(message)
142 }
143}
/// Turns [`ResponseMessage`]s read from the Betfair stream into the values delivered to the
/// consumer through [`BetfairStreamClient::sink`].
pub trait MessageProcessor: Send + Sync + 'static {
    /// Value sent to the consumer for each processed message.
    type Output: Send + Clone + Sync + 'static + core::fmt::Debug;

    /// Processes one message received from the stream.
    ///
    /// In the main read loop, returning `None` means nothing is forwarded for this message.
    fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output>;
}
156
157impl<T: MessageProcessor> BetfairStreamBuilder<T> {
158 pub fn new(client: BetfairRpcClient<Unauthenticated>) -> BetfairStreamBuilder<Cache> {
171 BetfairStreamBuilder {
172 client,
173 heartbeat_interval: None,
174 processor: Cache {
175 state: StreamState::new(),
176 },
177 }
178 }
179
180 pub fn new_without_cache(
193 client: BetfairRpcClient<Unauthenticated>,
194 ) -> BetfairStreamBuilder<Forwarder> {
195 BetfairStreamBuilder {
196 client,
197 heartbeat_interval: None,
198 processor: Forwarder,
199 }
200 }
201
202 pub fn with_heartbeat(mut self, interval: Duration) -> Self {
212 self.heartbeat_interval = Some(interval);
213 self
214 }
215
216 pub fn start_with<const C: usize, Sp, H>(self, spawner: Sp) -> (BetfairStreamClient<T>, H)
238 where
239 Sp: FnOnce(BoxFuture<'static, eyre::Result<()>>) -> H,
240 {
241 let (to_stream_tx, to_stream_rx) = mpsc::channel(C);
242 let (from_stream_tx, from_stream_rx) = mpsc::channel(C);
243
244 let fut = self.run(from_stream_tx, to_stream_rx).boxed();
246 let handle = spawner(fut);
247
248 (
249 BetfairStreamClient {
250 send_to_stream: to_stream_tx,
251 sink: from_stream_rx,
252 },
253 handle,
254 )
255 }
256
257 pub fn start<const C: usize>(self) -> (BetfairStreamClient<T>, JoinHandle<eyre::Result<()>>) {
270 self.start_with::<C, _, _>(|fut| tokio::spawn(fut))
271 }
272
273 async fn run(
274 self,
275 from_stream_tx: Sender<T::Output>,
276 to_stream_rx: Receiver<RequestMessage>,
277 ) -> eyre::Result<()> {
278 if let Some(hb) = self.heartbeat_interval {
279 let heartbeat_stream = {
280 let mut interval = tokio::time::interval(hb);
281 interval.reset();
282 let interval_stream = IntervalStream::new(interval).fuse();
283 interval_stream
284 .map(move |instant| HeartbeatMessage {
285 id: Some(
286 instant
287 .into_std()
288 .elapsed()
289 .as_secs()
290 .try_into()
291 .unwrap_or_default(),
292 ),
293 })
294 .map(RequestMessage::Heartbeat)
295 .boxed()
296 };
297 let input_stream = futures::stream::select_all([
298 heartbeat_stream,
299 ReceiverStream::new(to_stream_rx).boxed(),
300 ]);
301
302 self.run_base(from_stream_tx, input_stream).await
303 } else {
304 self.run_base(from_stream_tx, ReceiverStream::new(to_stream_rx))
305 .await
306 }
307 }
308
309 async fn run_base(
310 mut self,
311 mut from_stream_tx: Sender<T::Output>,
312 mut to_stream_rx: impl futures::Stream<Item = RequestMessage> + Unpin,
313 ) -> eyre::Result<()> {
314 let (mut client, _) = self.client.clone().authenticate().await?;
315 let mut backoff = ExponentialBuilder::new().build();
316 let mut first_call = true;
317 'retry: loop {
318 if !first_call {
319 let Some(delay) = backoff.next() else {
321 eyre::bail!("connection retry attempts exceeded")
322 };
323 sleep(delay).await;
324 }
325 first_call = true;
326
327 let mut stream = self
329 .connect_with_retry(&mut from_stream_tx, &mut client)
330 .await?;
331 tracing::info!("Connected to {}", self.client.stream.url());
332
333 loop {
334 let stream_next = pin!(stream.next());
335 let to_stream_rx_next = pin!(to_stream_rx.next());
336 match select(to_stream_rx_next, stream_next).await {
337 future::Either::Left((request, _)) => {
338 let Some(request) = request else {
339 tracing::warn!("request returned None");
340 continue 'retry;
341 };
342
343 tracing::debug!(?request, "sending to betfair");
344 let Ok(()) = stream.send(request).await else {
345 tracing::warn!("could not send request to stream");
346 continue 'retry;
347 };
348 }
349 future::Either::Right((message, _)) => {
350 let Some(message) = message else {
351 tracing::warn!("stream returned None");
352 continue 'retry;
353 };
354
355 match message {
356 Ok(message) => {
357 let message = self.processor.process_message(message);
358 tracing::debug!(?message, "received from betfair");
359 let Some(message) = message else {
360 continue;
361 };
362
363 let Ok(()) = from_stream_tx.send(message).await else {
364 tracing::warn!("could not send stream message to sink");
365 continue 'retry;
366 };
367 }
368 Err(err) => tracing::warn!(?err, "reading message error"),
369 }
370 }
371 }
372 }
373 }
374 }
375
376 #[tracing::instrument(skip_all, err)]
378 async fn connect_with_retry(
379 &mut self,
380 from_stream_tx: &mut Sender<T::Output>,
381 client: &mut Arc<BetfairRpcClient<Authenticated>>,
382 ) -> eyre::Result<Framed<tokio_rustls::client::TlsStream<TcpStream>, StreamAPIClientCodec>>
383 {
384 let mut backoff = ExponentialBuilder::new().build();
385 let mut delay = async || {
386 if let Some(delay) = backoff.next() {
387 sleep(delay).await;
388 Ok(())
389 } else {
390 eyre::bail!("exceeded retry attempts, could not connect");
391 }
392 };
393
394 loop {
395 let server_addr = self.client.stream.url();
396 let host = server_addr
397 .host_str()
398 .ok_or_else(|| eyre::eyre!("invalid betfair url"))?;
399 let port = server_addr.port().unwrap_or(443);
400
401 let domain_str = server_addr
402 .domain()
403 .ok_or_else(|| eyre::eyre!("domain must be known"))?;
404 let domain = rustls::pki_types::ServerName::try_from(domain_str.to_owned())
405 .wrap_err("failed to parse server name")?;
406
407 let Some(socket_addr) = tokio::net::lookup_host((host, port)).await?.next() else {
409 eyre::bail!("no valid socket addresses for {host}:{port}")
410 };
411
412 let tcp_stream = TcpStream::connect(socket_addr).await;
413 let Ok(stream) = tcp_stream else {
414 tracing::error!(err = ?tcp_stream.unwrap_err(), "Connect error. Retrying...");
415 delay().await?;
416 continue;
417 };
418 let tls_stream = tls_connector()?.connect(domain.clone(), stream).await?;
419 let mut tls_stream = Framed::new(tls_stream, StreamAPIClientCodec);
420
421 match self
422 .handshake(from_stream_tx, client, &mut tls_stream)
423 .await
424 {
425 Ok(()) => return Ok(tls_stream),
426 Err(err) => match err {
427 HandshakeErr::WaitAndRetry => {
428 delay().await?;
429 continue;
430 }
431 HandshakeErr::Reauthenticate => {
432 let (new_client, _) = self.client.clone().authenticate().await?;
433 *client = new_client;
434 delay().await?;
435 continue;
436 }
437 HandshakeErr::Fatal => eyre::bail!("fatal error in stream processing"),
438 },
439 }
440 }
441 }
442
443 #[tracing::instrument(err, skip_all)]
444 async fn handshake(
445 &mut self,
446 from_stream_tx: &mut Sender<T::Output>,
447 client: &BetfairRpcClient<Authenticated>,
448 stream: &mut Framed<tokio_rustls::client::TlsStream<TcpStream>, StreamAPIClientCodec>,
449 ) -> Result<(), HandshakeErr> {
450 let res = stream
452 .next()
453 .await
454 .transpose()
455 .inspect_err(|err| {
456 tracing::warn!(?err, "error when parsing stream message");
457 })
458 .map_err(|_| HandshakeErr::WaitAndRetry)?
459 .ok_or(HandshakeErr::WaitAndRetry)?;
460 tracing::info!(?res, "message from stream");
461 let message = self
462 .processor
463 .process_message(res.clone())
464 .ok_or(HandshakeErr::Fatal)
465 .inspect_err(|err| tracing::error!(?err))?;
466 from_stream_tx
467 .send(message.clone())
468 .await
469 .inspect_err(|err| tracing::warn!(?err))
470 .map_err(|_| HandshakeErr::Fatal)?;
471 let ResponseMessage::Connection(_) = &res else {
472 tracing::warn!("stream responded with invalid connection message");
473 return Err(HandshakeErr::Reauthenticate);
474 };
475
476 let msg = authentication_message::AuthenticationMessage {
478 id: Some(-1),
479 session: client.session_token().0.expose_secret().clone(),
480 app_key: self
481 .client
482 .secret_provider
483 .application_key
484 .0
485 .expose_secret()
486 .clone(),
487 };
488 stream
489 .send(RequestMessage::Authentication(msg))
490 .await
491 .inspect_err(|err| tracing::warn!(?err, "stream exited"))
492 .map_err(|_| HandshakeErr::WaitAndRetry)?;
493
494 let message = stream
496 .next()
497 .await
498 .transpose()
499 .inspect_err(|err| {
500 tracing::warn!(?err, "error when parsing stream message");
501 })
502 .map_err(|_| HandshakeErr::WaitAndRetry)?
503 .ok_or(HandshakeErr::WaitAndRetry)?;
504 let processed_message = self
505 .processor
506 .process_message(message.clone())
507 .ok_or(HandshakeErr::Fatal)
508 .inspect_err(|err| tracing::warn!(?err))
509 .map_err(|_| HandshakeErr::Fatal)?;
510 from_stream_tx
511 .send(processed_message)
512 .await
513 .inspect_err(|err| tracing::warn!(?err))
514 .map_err(|_| HandshakeErr::Fatal)?;
515 tracing::info!(?message, "message from stream");
516 let ResponseMessage::Status(status_message) = &message else {
517 tracing::warn!("expected status message, got {message:?}");
518 return Err(HandshakeErr::WaitAndRetry);
519 };
520
521 let StatusMessage::Failure(err) = &status_message else {
522 return Ok(());
523 };
524
525 tracing::error!(?err, "stream respondend wit an error");
526 let action = match err.error_code {
527 ErrorCode::NoAppKey => HandshakeErr::Fatal,
528 ErrorCode::InvalidAppKey => HandshakeErr::Fatal,
529 ErrorCode::NoSession => HandshakeErr::Reauthenticate,
530 ErrorCode::InvalidSessionInformation => HandshakeErr::Reauthenticate,
531 ErrorCode::NotAuthorized => HandshakeErr::Reauthenticate,
532 ErrorCode::InvalidInput => HandshakeErr::Fatal,
533 ErrorCode::InvalidClock => HandshakeErr::Fatal,
534 ErrorCode::UnexpectedError => HandshakeErr::Fatal,
535 ErrorCode::Timeout => HandshakeErr::WaitAndRetry,
536 ErrorCode::SubscriptionLimitExceeded => HandshakeErr::WaitAndRetry,
537 ErrorCode::InvalidRequest => HandshakeErr::Fatal,
538 ErrorCode::ConnectionFailed => HandshakeErr::WaitAndRetry,
539 ErrorCode::MaxConnectionLimitExceeded => HandshakeErr::Fatal,
540 ErrorCode::TooManyRequests => HandshakeErr::WaitAndRetry,
541 };
542
543 Err(action)
544 }
545}
546
/// Tells `connect_with_retry` how to react to a failed stream handshake.
#[derive(Debug)]
enum HandshakeErr {
    /// Transient failure: wait for the next backoff delay, then reconnect.
    WaitAndRetry,
    /// Session problem: authenticate again, then reconnect.
    Reauthenticate,
    /// Unrecoverable failure: give up.
    Fatal,
}

impl fmt::Display for HandshakeErr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The derived Debug form is the bare variant name, e.g. "Stream Handshake Error Fatal".
        write!(f, "Stream Handshake Error {self:?}")
    }
}

impl core::error::Error for HandshakeErr {}
561
562#[tracing::instrument(err)]
563fn tls_connector() -> eyre::Result<tokio_rustls::TlsConnector> {
564 use tokio_rustls::TlsConnector;
565
566 let mut roots = rustls::RootCertStore::empty();
567 let native_certs = rustls_native_certs::load_native_certs();
568 for cert in native_certs.certs {
569 roots.add(cert)?;
570 }
571
572 let config = rustls::ClientConfig::builder()
573 .with_root_certificates(roots)
574 .with_no_client_auth();
575 Ok(TlsConnector::from(Arc::new(config)))
576}
577
/// Codec for the stream wire format: one JSON message per line, written with a `\r\n`
/// terminator and read with either a `\n` or `\r\n` terminator.
#[derive(Debug, Clone, Copy, Default)]
pub struct StreamAPIClientCodec;
580
581impl Decoder for StreamAPIClientCodec {
582 type Item = ResponseMessage;
583 type Error = eyre::Report;
584
585 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
586 if let Some(pos) = src.iter().position(|&byte| byte == b'\n') {
588 let delimiter_size = if pos > 0 && src[pos - 1] == b'\r' {
590 2
591 } else {
592 1
593 };
594
595 let line = src.split_to(pos + 1);
597
598 let (json_part, _) = line.split_at(line.len().saturating_sub(delimiter_size));
600
601 let data = serde_json::from_slice::<Self::Item>(json_part)?;
603 return Ok(Some(data));
604 }
605 Ok(None)
606 }
607}
608
609impl Encoder<RequestMessage> for StreamAPIClientCodec {
610 type Error = eyre::Report;
611
612 fn encode(
613 &mut self,
614 item: RequestMessage,
615 dst: &mut bytes::BytesMut,
616 ) -> Result<(), Self::Error> {
617 let json = serde_json::to_string(&item)?;
619 dst.extend_from_slice(json.as_bytes());
621 dst.extend_from_slice(b"\r\n");
622 Ok(())
623 }
624}
625
#[cfg(test)]
mod tests {

    use core::fmt::Write as _;

    use super::*;

    /// Resolves the production stream host and checks that an IPv4 address is among the
    /// results. Requires network access.
    #[tokio::test]
    async fn can_resolve_host_ipv4() {
        let url = url::Url::parse("tcptls://stream-api.betfair.com:443").unwrap();
        let host = url.host_str().unwrap();
        // Same default port as `connect_with_retry`.
        let port = url.port().unwrap_or(443);
        // Resolver ordering is platform dependent (IPv6 may come first), so look for any IPv4
        // address instead of asserting on the first result.
        let socket_addr = tokio::net::lookup_host((host, port))
            .await
            .unwrap()
            .find(|addr| addr.is_ipv4())
            .expect("no IPv4 address resolved");
        assert!(socket_addr.ip().is_ipv4());
        assert_eq!(socket_addr.port(), 443);
    }

    /// A single `\r\n`-terminated message decodes in one call.
    #[test]
    fn can_decode_single_message() {
        let msg = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let separator = "\r\n";
        let data = format!("{msg}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg, ResponseMessage::Connection(_)));
    }

    /// Two complete messages in one buffer decode on consecutive calls.
    #[test]
    fn can_decode_multiple_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two = r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}{msg_two}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    /// A message split across reads is only decoded once its terminator arrives.
    #[test]
    fn can_decode_multiple_partial_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two_pt_one = r#"{"op":"ocm","id":3,"clk""#;
        let msg_two_pt_two = r#":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}{msg_two_pt_one}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two_attempt = codec.decode(&mut buf).unwrap();
        assert!(msg_two_attempt.is_none());
        buf.write_str(msg_two_pt_two).unwrap();
        buf.write_str(separator).unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    /// A message appended after the buffer was drained decodes on the next call.
    #[test]
    fn can_decode_subsequent_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two = r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two_attempt = codec.decode(&mut buf).unwrap();
        assert!(msg_two_attempt.is_none());
        let data = format!("{msg_two}{separator}");
        buf.write_str(data.as_str()).unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    /// Encoded requests are JSON lines terminated by `\r\n`.
    #[test]
    fn can_encode_message() {
        let msg = RequestMessage::Authentication(
            betfair_stream_types::request::authentication_message::AuthenticationMessage {
                id: Some(1),
                session: "sss".to_owned(),
                app_key: "aaaa".to_owned(),
            },
        );
        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        let data = buf.freeze();
        let data = core::str::from_utf8(&data).unwrap();

        assert!(data.ends_with("\r\n"));
        assert!(data.starts_with("{\"op\":\"authentication\""));
    }
}