1extern crate alloc;
8pub mod cache;
9use backon::{BackoffBuilder as _, ExponentialBuilder};
10use betfair_adapter::{Authenticated, BetfairRpcClient, Unauthenticated};
11pub use betfair_stream_types as types;
12use betfair_stream_types::{
13 request::{RequestMessage, authentication_message, heartbeat_message::HeartbeatMessage},
14 response::{
15 ResponseMessage,
16 connection_message::ConnectionMessage,
17 status_message::{ErrorCode, StatusMessage},
18 },
19};
20use cache::{
21 primitives::{MarketBookCache, OrderBookCache},
22 tracker::StreamState,
23};
24use core::fmt;
25use core::{pin::pin, time::Duration};
26use eyre::Context as _;
27use futures::{
28 FutureExt, SinkExt as _, StreamExt as _,
29 future::{self, BoxFuture, select},
30};
31use std::sync::Arc;
32use tokio::{
33 net::TcpStream,
34 sync::mpsc::{self, Receiver, Sender},
35 task::JoinHandle,
36 time::sleep,
37};
38use tokio_stream::wrappers::{IntervalStream, ReceiverStream};
39use tokio_util::{
40 bytes,
41 codec::{Decoder, Encoder, Framed},
42};
43
/// Configuration for a Betfair Exchange Stream API connection.
///
/// Create one with [`BetfairStreamBuilder::new`] (messages run through [`Cache`]) or
/// [`BetfairStreamBuilder::new_without_cache`] (messages forwarded unchanged), optionally
/// enable heartbeats with [`BetfairStreamBuilder::with_heartbeat`], then start the stream
/// task with `start` or `start_with`.
#[derive(Debug, Clone)]
pub struct BetfairStreamBuilder<T: MessageProcessor> {
    /// Unauthenticated RPC client. It is authenticated when the stream task starts, and its
    /// stream URL and application key are used to open and authenticate the connection.
    pub client: BetfairRpcClient<Unauthenticated>,
    /// When set, a heartbeat request is sent to the stream at this interval.
    pub heartbeat_interval: Option<Duration>,
    /// Converts each received [`ResponseMessage`] into the values delivered to the consumer.
    pub processor: T,
}
60
/// Handle to a running stream task, returned together with the task handle by
/// `BetfairStreamBuilder::start` / `BetfairStreamBuilder::start_with`.
#[derive(Debug)]
pub struct BetfairStreamClient<T: MessageProcessor> {
    /// Requests pushed here are written to the Betfair stream connection.
    pub send_to_stream: Sender<RequestMessage>,
    /// Processed messages received from Betfair, including the messages exchanged during
    /// each connection handshake.
    pub sink: Receiver<T::Output>,
}
71
/// [`MessageProcessor`] that applies market and order change messages to an in-memory
/// [`StreamState`] and emits the resulting cached books.
#[derive(Debug, Clone)]
pub struct Cache {
    /// State that market/order change messages are applied to.
    state: StreamState,
}
79
/// Output produced by the [`Cache`] processor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CachedMessage {
    /// Connection message, forwarded unchanged.
    Connection(ConnectionMessage),

    /// Market books returned by `StreamState::market_change_update` for a market change
    /// message (cloned out of the cache).
    MarketChange(Vec<MarketBookCache>),

    /// Order books returned by `StreamState::order_change_update` for an order change
    /// message (cloned out of the cache).
    OrderChange(Vec<OrderBookCache>),

    /// Status message, forwarded unchanged.
    Status(StatusMessage),
}
102
103impl MessageProcessor for Cache {
104 type Output = CachedMessage;
105
106 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output> {
107 match message {
108 ResponseMessage::Connection(connection_message) => {
109 Some(CachedMessage::Connection(connection_message))
110 }
111 ResponseMessage::MarketChange(market_change_message) => self
112 .state
113 .market_change_update(market_change_message)
114 .map(|markets| markets.into_iter().cloned().collect::<Vec<_>>())
115 .map(CachedMessage::MarketChange),
116 ResponseMessage::OrderChange(order_change_message) => self
117 .state
118 .order_change_update(order_change_message)
119 .map(|markets| markets.into_iter().cloned().collect::<Vec<_>>())
120 .map(CachedMessage::OrderChange),
121 ResponseMessage::Status(status_message) => Some(CachedMessage::Status(status_message)),
122 }
123 }
124}
125
126#[derive(Debug)]
128pub struct Forwarder;
129impl MessageProcessor for Forwarder {
130 type Output = ResponseMessage;
131
132 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output> {
133 Some(message)
134 }
135}
/// Converts messages received from the Betfair stream into the values delivered to the
/// consumer through [`BetfairStreamClient::sink`].
///
/// The processor is stored in the builder and moved into the stream task when it starts.
pub trait MessageProcessor: Send + Sync + 'static {
    /// Type of the values sent to the consumer.
    type Output: Send + Clone + Sync + 'static + core::fmt::Debug;

    /// Processes one received message.
    ///
    /// After the handshake, returning `None` suppresses the message. During the connection
    /// handshake, a `None` result is treated as a fatal error by the stream task.
    fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output>;
}
148
149impl<T: MessageProcessor> BetfairStreamBuilder<T> {
150 pub fn new(client: BetfairRpcClient<Unauthenticated>) -> BetfairStreamBuilder<Cache> {
163 BetfairStreamBuilder {
164 client,
165 heartbeat_interval: None,
166 processor: Cache {
167 state: StreamState::new(),
168 },
169 }
170 }
171
172 pub fn new_without_cache(
185 client: BetfairRpcClient<Unauthenticated>,
186 ) -> BetfairStreamBuilder<Forwarder> {
187 BetfairStreamBuilder {
188 client,
189 heartbeat_interval: None,
190 processor: Forwarder,
191 }
192 }
193
194 pub fn with_heartbeat(mut self, interval: Duration) -> Self {
204 self.heartbeat_interval = Some(interval);
205 self
206 }
207
208 pub fn start_with<const C: usize, Sp, H>(self, spawner: Sp) -> (BetfairStreamClient<T>, H)
230 where
231 Sp: FnOnce(BoxFuture<'static, eyre::Result<()>>) -> H,
232 {
233 let (to_stream_tx, to_stream_rx) = mpsc::channel(C);
234 let (from_stream_tx, from_stream_rx) = mpsc::channel(C);
235
236 let fut = self.run(from_stream_tx, to_stream_rx).boxed();
238 let handle = spawner(fut);
239
240 (
241 BetfairStreamClient {
242 send_to_stream: to_stream_tx,
243 sink: from_stream_rx,
244 },
245 handle,
246 )
247 }
248
249 pub fn start<const C: usize>(self) -> (BetfairStreamClient<T>, JoinHandle<eyre::Result<()>>) {
262 self.start_with::<C, _, _>(|fut| tokio::spawn(fut))
263 }
264
265 async fn run(
266 self,
267 from_stream_tx: Sender<T::Output>,
268 to_stream_rx: Receiver<RequestMessage>,
269 ) -> eyre::Result<()> {
270 if let Some(hb) = self.heartbeat_interval {
271 let heartbeat_stream = {
272 let mut interval = tokio::time::interval(hb);
273 interval.reset();
274 let interval_stream = IntervalStream::new(interval).fuse();
275 interval_stream
276 .map(move |instant| HeartbeatMessage {
277 id: Some(
278 instant
279 .into_std()
280 .elapsed()
281 .as_secs()
282 .try_into()
283 .unwrap_or_default(),
284 ),
285 })
286 .map(RequestMessage::Heartbeat)
287 .boxed()
288 };
289 let input_stream = futures::stream::select_all([
290 heartbeat_stream,
291 ReceiverStream::new(to_stream_rx).boxed(),
292 ]);
293
294 self.run_base(from_stream_tx, input_stream).await
295 } else {
296 self.run_base(from_stream_tx, ReceiverStream::new(to_stream_rx))
297 .await
298 }
299 }
300
301 async fn run_base(
302 mut self,
303 mut from_stream_tx: Sender<T::Output>,
304 mut to_stream_rx: impl futures::Stream<Item = RequestMessage> + Unpin,
305 ) -> eyre::Result<()> {
306 let (mut client, _) = self.client.clone().authenticate().await?;
307 let mut backoff = ExponentialBuilder::new().build();
308 let mut first_call = true;
309 'retry: loop {
310 if !first_call {
311 let Some(delay) = backoff.next() else {
313 eyre::bail!("connection retry attempts exceeded")
314 };
315 sleep(delay).await;
316 }
317 first_call = true;
318
319 let mut stream = self
321 .connect_with_retry(&mut from_stream_tx, &mut client)
322 .await?;
323 tracing::info!("Connected to {}", self.client.stream.url());
324
325 loop {
326 let stream_next = pin!(stream.next());
327 let to_stream_rx_next = pin!(to_stream_rx.next());
328 match select(to_stream_rx_next, stream_next).await {
329 future::Either::Left((request, _)) => {
330 let Some(request) = request else {
331 tracing::warn!("request returned None");
332 continue 'retry;
333 };
334
335 tracing::debug!(?request, "sending to betfair");
336 let Ok(()) = stream.send(request).await else {
337 tracing::warn!("could not send request to stream");
338 continue 'retry;
339 };
340 }
341 future::Either::Right((message, _)) => {
342 let Some(message) = message else {
343 tracing::warn!("stream returned None");
344 continue 'retry;
345 };
346
347 match message {
348 Ok(message) => {
349 let message = self.processor.process_message(message);
350 tracing::debug!(?message, "received from betfair");
351 let Some(message) = message else {
352 continue;
353 };
354
355 let Ok(()) = from_stream_tx.send(message).await else {
356 tracing::warn!("could not send stream message to sink");
357 continue 'retry;
358 };
359 }
360 Err(err) => tracing::warn!(?err, "reading message error"),
361 }
362 }
363 }
364 }
365 }
366 }
367
368 #[tracing::instrument(skip_all, err)]
370 async fn connect_with_retry(
371 &mut self,
372 from_stream_tx: &mut Sender<T::Output>,
373 client: &mut Arc<BetfairRpcClient<Authenticated>>,
374 ) -> eyre::Result<Framed<tokio_rustls::client::TlsStream<TcpStream>, StreamAPIClientCodec>>
375 {
376 let mut backoff = ExponentialBuilder::new().build();
377 let mut delay = async || {
378 if let Some(delay) = backoff.next() {
379 sleep(delay).await;
380 Ok(())
381 } else {
382 eyre::bail!("exceeded retry attempts, could not connect");
383 }
384 };
385
386 loop {
387 let server_addr = self.client.stream.url();
388 let host = server_addr
389 .host_str()
390 .ok_or_else(|| eyre::eyre!("invalid betfair url"))?;
391 let port = server_addr.port().unwrap_or(443);
392
393 let domain_str = server_addr
394 .domain()
395 .ok_or_else(|| eyre::eyre!("domain must be known"))?;
396 let domain = rustls::pki_types::ServerName::try_from(domain_str.to_owned())
397 .wrap_err("failed to parse server name")?;
398
399 let Some(socket_addr) = tokio::net::lookup_host((host, port)).await?.next() else {
401 eyre::bail!("no valid socket addresses for {host}:{port}")
402 };
403
404 let tcp_stream = TcpStream::connect(socket_addr).await;
405 let Ok(stream) = tcp_stream else {
406 tracing::error!(err = ?tcp_stream.unwrap_err(), "Connect error. Retrying...");
407 delay().await?;
408 continue;
409 };
410 let tls_stream = tls_connector()?.connect(domain.clone(), stream).await?;
411 let mut tls_stream = Framed::new(tls_stream, StreamAPIClientCodec);
412
413 match self
414 .handshake(from_stream_tx, client, &mut tls_stream)
415 .await
416 {
417 Ok(()) => return Ok(tls_stream),
418 Err(err) => match err {
419 HandshakeErr::WaitAndRetry => {
420 delay().await?;
421 continue;
422 }
423 HandshakeErr::Reauthenticate => {
424 let (new_client, _) = self.client.clone().authenticate().await?;
425 *client = new_client;
426 delay().await?;
427 continue;
428 }
429 HandshakeErr::Fatal => eyre::bail!("fatal error in stream processing"),
430 },
431 }
432 }
433 }
434
435 #[tracing::instrument(err, skip_all)]
436 async fn handshake(
437 &mut self,
438 from_stream_tx: &mut Sender<T::Output>,
439 client: &BetfairRpcClient<Authenticated>,
440 stream: &mut Framed<tokio_rustls::client::TlsStream<TcpStream>, StreamAPIClientCodec>,
441 ) -> Result<(), HandshakeErr> {
442 let res = stream
444 .next()
445 .await
446 .transpose()
447 .inspect_err(|err| {
448 tracing::warn!(?err, "error when parsing stream message");
449 })
450 .map_err(|_| HandshakeErr::WaitAndRetry)?
451 .ok_or(HandshakeErr::WaitAndRetry)?;
452 tracing::info!(?res, "message from stream");
453 let message = self
454 .processor
455 .process_message(res.clone())
456 .ok_or(HandshakeErr::Fatal)
457 .inspect_err(|err| tracing::error!(?err))?;
458 from_stream_tx
459 .send(message.clone())
460 .await
461 .inspect_err(|err| tracing::warn!(?err))
462 .map_err(|_| HandshakeErr::Fatal)?;
463 let ResponseMessage::Connection(_) = &res else {
464 tracing::warn!("stream responded with invalid connection message");
465 return Err(HandshakeErr::Reauthenticate);
466 };
467
468 let msg = authentication_message::AuthenticationMessage {
470 id: Some(-1),
471 session: client.session_token().0.expose_secret().clone(),
472 app_key: self
473 .client
474 .secret_provider
475 .application_key
476 .0
477 .expose_secret()
478 .clone(),
479 };
480 stream
481 .send(RequestMessage::Authentication(msg))
482 .await
483 .inspect_err(|err| tracing::warn!(?err, "stream exited"))
484 .map_err(|_| HandshakeErr::WaitAndRetry)?;
485
486 let message = stream
488 .next()
489 .await
490 .transpose()
491 .inspect_err(|err| {
492 tracing::warn!(?err, "error when parsing stream message");
493 })
494 .map_err(|_| HandshakeErr::WaitAndRetry)?
495 .ok_or(HandshakeErr::WaitAndRetry)?;
496 let processed_message = self
497 .processor
498 .process_message(message.clone())
499 .ok_or(HandshakeErr::Fatal)
500 .inspect_err(|err| tracing::warn!(?err))
501 .map_err(|_| HandshakeErr::Fatal)?;
502 from_stream_tx
503 .send(processed_message)
504 .await
505 .inspect_err(|err| tracing::warn!(?err))
506 .map_err(|_| HandshakeErr::Fatal)?;
507 tracing::info!(?message, "message from stream");
508 let ResponseMessage::Status(status_message) = &message else {
509 tracing::warn!("expected status message, got {message:?}");
510 return Err(HandshakeErr::WaitAndRetry);
511 };
512
513 let StatusMessage::Failure(err) = &status_message else {
514 return Ok(());
515 };
516
517 tracing::error!(?err, "stream respondend wit an error");
518 let action = match err.error_code {
519 ErrorCode::NoAppKey => HandshakeErr::Fatal,
520 ErrorCode::InvalidAppKey => HandshakeErr::Fatal,
521 ErrorCode::NoSession => HandshakeErr::Reauthenticate,
522 ErrorCode::InvalidSessionInformation => HandshakeErr::Reauthenticate,
523 ErrorCode::NotAuthorized => HandshakeErr::Reauthenticate,
524 ErrorCode::InvalidInput => HandshakeErr::Fatal,
525 ErrorCode::InvalidClock => HandshakeErr::Fatal,
526 ErrorCode::UnexpectedError => HandshakeErr::Fatal,
527 ErrorCode::Timeout => HandshakeErr::WaitAndRetry,
528 ErrorCode::SubscriptionLimitExceeded => HandshakeErr::WaitAndRetry,
529 ErrorCode::InvalidRequest => HandshakeErr::Fatal,
530 ErrorCode::ConnectionFailed => HandshakeErr::WaitAndRetry,
531 ErrorCode::MaxConnectionLimitExceeded => HandshakeErr::Fatal,
532 ErrorCode::TooManyRequests => HandshakeErr::WaitAndRetry,
533 };
534
535 Err(action)
536 }
537}
538
/// Failed handshake outcome, telling `connect_with_retry` how to recover.
#[derive(Debug)]
enum HandshakeErr {
    /// Back off, then reconnect with the current session.
    WaitAndRetry,
    /// Re-authenticate to obtain a new client/session, back off, then reconnect.
    Reauthenticate,
    /// Give up: `connect_with_retry` returns an error.
    Fatal,
}
545
546impl fmt::Display for HandshakeErr {
547 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
548 write!(f, "Stream Handshake Error {:?}", self)
549 }
550}
551
552impl core::error::Error for HandshakeErr {}
553
554#[tracing::instrument(err)]
555fn tls_connector() -> eyre::Result<tokio_rustls::TlsConnector> {
556 use tokio_rustls::TlsConnector;
557
558 let mut roots = rustls::RootCertStore::empty();
559 let native_certs = rustls_native_certs::load_native_certs();
560 for cert in native_certs.certs {
561 roots.add(cert)?;
562 }
563
564 let config = rustls::ClientConfig::builder()
565 .with_root_certificates(roots)
566 .with_no_client_auth();
567 Ok(TlsConnector::from(Arc::new(config)))
568}
569
/// Codec for the Stream API's line-delimited JSON framing: decodes one
/// [`ResponseMessage`] per received line and encodes each [`RequestMessage`] as one
/// CRLF-terminated JSON line.
pub struct StreamAPIClientCodec;
572
573impl Decoder for StreamAPIClientCodec {
574 type Item = ResponseMessage;
575 type Error = eyre::Report;
576
577 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
578 if let Some(pos) = src.iter().position(|&byte| byte == b'\n') {
580 let delimiter_size = if pos > 0 && src[pos - 1] == b'\r' {
582 2
583 } else {
584 1
585 };
586
587 let line = src.split_to(pos + 1);
589
590 let (json_part, _) = line.split_at(line.len().saturating_sub(delimiter_size));
592
593 let data = serde_json::from_slice::<Self::Item>(json_part)?;
595 return Ok(Some(data));
596 }
597 Ok(None)
598 }
599}
600
601impl Encoder<RequestMessage> for StreamAPIClientCodec {
602 type Error = eyre::Report;
603
604 fn encode(
605 &mut self,
606 item: RequestMessage,
607 dst: &mut bytes::BytesMut,
608 ) -> Result<(), Self::Error> {
609 let json = serde_json::to_string(&item)?;
611 dst.extend_from_slice(json.as_bytes());
613 dst.extend_from_slice(b"\r\n");
614 Ok(())
615 }
616}
617
#[cfg(test)]
mod tests {
    use core::fmt::Write as _;

    use super::*;

    /// A connection message as sent by the server when a connection opens.
    const CONNECTION: &str = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
    /// An order change heartbeat message.
    const ORDER_HEARTBEAT: &str =
        r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
    const CRLF: &str = "\r\n";

    /// Decodes the next frame from `buf`, panicking if decoding fails or none is ready.
    fn decode_next(codec: &mut StreamAPIClientCodec, buf: &mut bytes::BytesMut) -> ResponseMessage {
        codec
            .decode(buf)
            .expect("decoding should succeed")
            .expect("a complete frame should be buffered")
    }

    #[tokio::test]
    async fn can_resolve_host_ipv4() {
        let url = url::Url::parse("tcptls://stream-api.betfair.com:443").unwrap();
        let host = url.host_str().unwrap();
        let default_port = if url.scheme() == "https" { 443 } else { 80 };
        let port = url.port().unwrap_or(default_port);

        let mut resolved = tokio::net::lookup_host((host, port)).await.unwrap();
        let first = resolved.next().unwrap();

        assert!(first.ip().is_ipv4());
        assert_eq!(first.port(), 443);
    }

    #[test]
    fn can_decode_single_message() {
        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(format!("{CONNECTION}{CRLF}").as_bytes());

        let decoded = decode_next(&mut codec, &mut buf);

        assert!(matches!(decoded, ResponseMessage::Connection(_)));
    }

    #[test]
    fn can_decode_multiple_messages() {
        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(
            format!("{CONNECTION}{CRLF}{ORDER_HEARTBEAT}{CRLF}").as_bytes(),
        );

        let first = decode_next(&mut codec, &mut buf);
        let second = decode_next(&mut codec, &mut buf);

        assert!(matches!(first, ResponseMessage::Connection(_)));
        assert!(matches!(second, ResponseMessage::OrderChange(_)));
    }

    #[test]
    fn can_decode_multiple_partial_messages() {
        let head = r#"{"op":"ocm","id":3,"clk""#;
        let tail = r#":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(format!("{CONNECTION}{CRLF}{head}").as_bytes());

        let first = decode_next(&mut codec, &mut buf);
        // Only half of the second frame is buffered so far.
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.write_str(tail).unwrap();
        buf.write_str(CRLF).unwrap();
        let second = decode_next(&mut codec, &mut buf);

        assert!(matches!(first, ResponseMessage::Connection(_)));
        assert!(matches!(second, ResponseMessage::OrderChange(_)));
    }

    #[test]
    fn can_decode_subsequent_messages() {
        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(format!("{CONNECTION}{CRLF}").as_bytes());

        let first = decode_next(&mut codec, &mut buf);
        // The buffer is drained; nothing more to decode yet.
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.write_str(format!("{ORDER_HEARTBEAT}{CRLF}").as_str())
            .unwrap();
        let second = decode_next(&mut codec, &mut buf);

        assert!(matches!(first, ResponseMessage::Connection(_)));
        assert!(matches!(second, ResponseMessage::OrderChange(_)));
    }

    #[test]
    fn can_encode_message() {
        let request = RequestMessage::Authentication(
            betfair_stream_types::request::authentication_message::AuthenticationMessage {
                id: Some(1),
                session: "sss".to_owned(),
                app_key: "aaaa".to_owned(),
            },
        );
        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::new();
        codec.encode(request, &mut buf).unwrap();

        let text = core::str::from_utf8(&buf).unwrap();

        assert!(text.ends_with("\r\n"));
        assert!(text.starts_with("{\"op\":\"authentication\""));
    }
}