1extern crate alloc;
8pub mod cache;
9use backon::{BackoffBuilder as _, ExponentialBuilder};
10use betfair_adapter::{Authenticated, BetfairRpcClient, Unauthenticated};
11pub use betfair_stream_types as types;
12use betfair_stream_types::{
13 request::{RequestMessage, authentication_message, heartbeat_message::HeartbeatMessage},
14 response::{
15 ResponseMessage,
16 connection_message::ConnectionMessage,
17 status_message::{ErrorCode, StatusMessage},
18 },
19};
20use cache::{
21 primitives::{MarketBookCache, OrderBookCache},
22 tracker::StreamState,
23};
24use core::fmt;
25use core::{pin::pin, time::Duration};
26use eyre::Context as _;
27use futures::{
28 SinkExt as _, StreamExt as _,
29 future::{self, select},
30};
31use std::sync::Arc;
32use tokio::{
33 net::TcpStream,
34 sync::mpsc::{self, Receiver, Sender},
35 task::JoinHandle,
36 time::sleep,
37};
38use tokio_stream::wrappers::{IntervalStream, ReceiverStream};
39use tokio_util::{
40 bytes,
41 codec::{Decoder, Encoder, Framed},
42};
43
44#[derive(Debug, Clone)]
52pub struct BetfairStreamBuilder<T: MessageProcessor> {
53 pub client: BetfairRpcClient<Unauthenticated>,
55 pub heartbeat_interval: Option<Duration>,
57 pub processor: T,
59}
60
61#[derive(Debug)]
65pub struct BetfairStreamClient<T: MessageProcessor> {
66 pub send_to_stream: Sender<RequestMessage>,
68 pub sink: Receiver<T::Output>,
70}
71
72#[derive(Debug, Clone)]
76pub struct Cache {
77 state: StreamState,
78}
79
80#[derive(Debug, Clone, PartialEq, Eq)]
85pub enum CachedMessage {
86 Connection(ConnectionMessage),
90
91 MarketChange(Vec<MarketBookCache>),
93
94 OrderChange(Vec<OrderBookCache>),
97
98 Status(StatusMessage),
101}
102
103impl MessageProcessor for Cache {
104 type Output = CachedMessage;
105
106 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output> {
107 match message {
108 ResponseMessage::Connection(connection_message) => {
109 Some(CachedMessage::Connection(connection_message))
110 }
111 ResponseMessage::MarketChange(market_change_message) => {
112 let data = self
113 .state
114 .market_change_update(market_change_message)
115 .map(|markets| markets.into_iter().cloned().collect::<Vec<_>>())
116 .map(CachedMessage::MarketChange);
117
118 data
119 }
120 ResponseMessage::OrderChange(order_change_message) => {
121 let data = self
122 .state
123 .order_change_update(order_change_message)
124 .map(|markets| markets.into_iter().cloned().collect::<Vec<_>>())
125 .map(CachedMessage::OrderChange);
126
127 data
128 }
129 ResponseMessage::Status(status_message) => Some(CachedMessage::Status(status_message)),
130 }
131 }
132}
133
134#[derive(Debug)]
136pub struct Forwarder;
137impl MessageProcessor for Forwarder {
138 type Output = ResponseMessage;
139
140 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output> {
141 Some(message)
142 }
143}
144pub trait MessageProcessor: Send + Sync + 'static {
148 type Output: Send + Clone + Sync + 'static + core::fmt::Debug;
150
151 fn process_message(&mut self, message: ResponseMessage) -> Option<Self::Output>;
155}
156
157impl<T: MessageProcessor> BetfairStreamBuilder<T> {
158 pub fn new(client: BetfairRpcClient<Unauthenticated>) -> BetfairStreamBuilder<Cache> {
171 BetfairStreamBuilder {
172 client,
173 heartbeat_interval: None,
174 processor: Cache {
175 state: StreamState::new(),
176 },
177 }
178 }
179
180 pub fn new_without_cache(
193 client: BetfairRpcClient<Unauthenticated>,
194 ) -> BetfairStreamBuilder<Forwarder> {
195 BetfairStreamBuilder {
196 client,
197 heartbeat_interval: None,
198 processor: Forwarder,
199 }
200 }
201
202 pub fn with_heartbeat(mut self, interval: Duration) -> Self {
212 self.heartbeat_interval = Some(interval);
213 self
214 }
215
216 pub async fn start(self) -> (BetfairStreamClient<T>, JoinHandle<eyre::Result<()>>) {
228 let (to_stream_tx, to_stream_rx) = mpsc::channel(100);
229 let (from_stream_tx, from_stream_rx) = mpsc::channel(100);
230
231 let task = tokio::task::spawn(self.run(from_stream_tx, to_stream_rx));
232
233 (
234 BetfairStreamClient {
235 send_to_stream: to_stream_tx,
236 sink: from_stream_rx,
237 },
238 task,
239 )
240 }
241
242 async fn run(
243 self,
244 from_stream_tx: Sender<T::Output>,
245 to_stream_rx: Receiver<RequestMessage>,
246 ) -> eyre::Result<()> {
247 if let Some(hb) = self.heartbeat_interval {
248 let heartbeat_stream = {
249 let mut interval = tokio::time::interval(hb);
250 interval.reset();
251 let interval_stream = IntervalStream::new(interval).fuse();
252 interval_stream
253 .map(move |instant| HeartbeatMessage {
254 id: Some(
255 instant
256 .into_std()
257 .elapsed()
258 .as_secs()
259 .try_into()
260 .unwrap_or_default(),
261 ),
262 })
263 .map(RequestMessage::Heartbeat)
264 .boxed()
265 };
266 let input_stream = futures::stream::select_all([
267 heartbeat_stream,
268 ReceiverStream::new(to_stream_rx).boxed(),
269 ]);
270
271 self.run_base(from_stream_tx, input_stream).await
272 } else {
273 self.run_base(from_stream_tx, ReceiverStream::new(to_stream_rx))
274 .await
275 }
276 }
277
278 async fn run_base(
279 mut self,
280 mut from_stream_tx: Sender<T::Output>,
281 mut to_stream_rx: impl futures::Stream<Item = RequestMessage> + Unpin,
282 ) -> eyre::Result<()> {
283 let (mut client, _) = self.client.clone().authenticate().await?;
284 let mut backoff = ExponentialBuilder::new().build();
285 let mut first_call = true;
286 'retry: loop {
287 if !first_call {
288 let Some(delay) = backoff.next() else {
290 eyre::bail!("connection retry attempts exceeded")
291 };
292 sleep(delay).await;
293 }
294 first_call = true;
295
296 let mut stream = self
298 .connect_with_retry(&mut from_stream_tx, &mut client)
299 .await?;
300 tracing::info!("Connected to {}", self.client.stream.url());
301
302 loop {
303 let stream_next = pin!(stream.next());
304 let to_stream_rx_next = pin!(to_stream_rx.next());
305 match select(to_stream_rx_next, stream_next).await {
306 future::Either::Left((request, _)) => {
307 let Some(request) = request else {
308 tracing::warn!("request returned None");
309 continue 'retry;
310 };
311
312 tracing::debug!(?request, "sending to betfair");
313 let Ok(()) = stream.send(request).await else {
314 tracing::warn!("could not send request to stream");
315 continue 'retry;
316 };
317 }
318 future::Either::Right((message, _)) => {
319 let Some(message) = message else {
320 tracing::warn!("stream returned None");
321 continue 'retry;
322 };
323
324 match message {
325 Ok(message) => {
326 let message = self.processor.process_message(message);
327 tracing::debug!(?message, "received from betfair");
328 let Some(message) = message else {
329 continue;
330 };
331
332 let Ok(()) = from_stream_tx.send(message).await else {
333 tracing::warn!("could not send stream message to sink");
334 continue 'retry;
335 };
336 }
337 Err(err) => tracing::warn!(?err, "reading message error"),
338 }
339 }
340 }
341 }
342 }
343 }
344
345 #[tracing::instrument(skip_all, err)]
347 async fn connect_with_retry(
348 &mut self,
349 from_stream_tx: &mut Sender<T::Output>,
350 client: &mut Arc<BetfairRpcClient<Authenticated>>,
351 ) -> eyre::Result<Framed<tokio_rustls::client::TlsStream<TcpStream>, StreamAPIClientCodec>>
352 {
353 let mut backoff = ExponentialBuilder::new().build();
354 let mut delay = async || {
355 if let Some(delay) = backoff.next() {
356 sleep(delay).await;
357 Ok(())
358 } else {
359 eyre::bail!("exceeded retry attempts, could not connect");
360 }
361 };
362
363 loop {
364 let server_addr = self.client.stream.url();
365 let host = server_addr
366 .host_str()
367 .ok_or_else(|| eyre::eyre!("invalid betfair url"))?;
368 let port = server_addr.port().unwrap_or(443);
369
370 let domain_str = server_addr
371 .domain()
372 .ok_or_else(|| eyre::eyre!("domain must be known"))?;
373 let domain = rustls::pki_types::ServerName::try_from(domain_str.to_owned())
374 .wrap_err("failed to parse server name")?;
375
376 let Some(socket_addr) = tokio::net::lookup_host((host, port)).await?.next() else {
378 eyre::bail!("no valid socket addresses for {host}:{port}")
379 };
380
381 let tcp_stream = TcpStream::connect(socket_addr).await;
382 let Ok(stream) = tcp_stream else {
383 tracing::error!(err = ?tcp_stream.unwrap_err(), "Connect error. Retrying...");
384 delay().await?;
385 continue;
386 };
387 let tls_stream = tls_connector()?.connect(domain.clone(), stream).await?;
388 let mut tls_stream = Framed::new(tls_stream, StreamAPIClientCodec);
389
390 match self
391 .handshake(from_stream_tx, client, &mut tls_stream)
392 .await
393 {
394 Ok(()) => return Ok(tls_stream),
395 Err(err) => match err {
396 HandshakeErr::WaitAndRetry => {
397 delay().await?;
398 continue;
399 }
400 HandshakeErr::Reauthenticate => {
401 let (new_client, _) = self.client.clone().authenticate().await?;
402 *client = new_client;
403 delay().await?;
404 continue;
405 }
406 HandshakeErr::Fatal => eyre::bail!("fatal error in stream processing"),
407 },
408 }
409 }
410 }
411
412 #[tracing::instrument(err, skip_all)]
413 async fn handshake(
414 &mut self,
415 from_stream_tx: &mut Sender<T::Output>,
416 client: &BetfairRpcClient<Authenticated>,
417 stream: &mut Framed<tokio_rustls::client::TlsStream<TcpStream>, StreamAPIClientCodec>,
418 ) -> Result<(), HandshakeErr> {
419 let res = stream
421 .next()
422 .await
423 .transpose()
424 .inspect_err(|err| {
425 tracing::warn!(?err, "error when parsing stream message");
426 })
427 .map_err(|_| HandshakeErr::WaitAndRetry)?
428 .ok_or(HandshakeErr::WaitAndRetry)?;
429 tracing::info!(?res, "message from stream");
430 let message = self
431 .processor
432 .process_message(res.clone())
433 .ok_or(HandshakeErr::Fatal)
434 .inspect_err(|err| tracing::error!(?err))?;
435 from_stream_tx
436 .send(message.clone())
437 .await
438 .inspect_err(|err| tracing::warn!(?err))
439 .map_err(|_| HandshakeErr::Fatal)?;
440 let ResponseMessage::Connection(_) = &res else {
441 tracing::warn!("stream responded with invalid connection message");
442 return Err(HandshakeErr::Reauthenticate);
443 };
444
445 let msg = authentication_message::AuthenticationMessage {
447 id: Some(-1),
448 session: client.session_token().0.expose_secret().clone(),
449 app_key: self
450 .client
451 .secret_provider
452 .application_key
453 .0
454 .expose_secret()
455 .clone(),
456 };
457 stream
458 .send(RequestMessage::Authentication(msg))
459 .await
460 .inspect_err(|err| tracing::warn!(?err, "stream exited"))
461 .map_err(|_| HandshakeErr::WaitAndRetry)?;
462
463 let message = stream
465 .next()
466 .await
467 .transpose()
468 .inspect_err(|err| {
469 tracing::warn!(?err, "error when parsing stream message");
470 })
471 .map_err(|_| HandshakeErr::WaitAndRetry)?
472 .ok_or(HandshakeErr::WaitAndRetry)?;
473 let processed_message = self
474 .processor
475 .process_message(message.clone())
476 .ok_or(HandshakeErr::Fatal)
477 .inspect_err(|err| tracing::warn!(?err))
478 .map_err(|_| HandshakeErr::Fatal)?;
479 from_stream_tx
480 .send(processed_message)
481 .await
482 .inspect_err(|err| tracing::warn!(?err))
483 .map_err(|_| HandshakeErr::Fatal)?;
484 tracing::info!(?message, "message from stream");
485 let ResponseMessage::Status(status_message) = &message else {
486 tracing::warn!("expected status message, got {message:?}");
487 return Err(HandshakeErr::WaitAndRetry);
488 };
489
490 let StatusMessage::Failure(err) = &status_message else {
491 return Ok(());
492 };
493
494 tracing::error!(?err, "stream respondend wit an error");
495 let action = match err.error_code {
496 ErrorCode::NoAppKey => HandshakeErr::Fatal,
497 ErrorCode::InvalidAppKey => HandshakeErr::Fatal,
498 ErrorCode::NoSession => HandshakeErr::Reauthenticate,
499 ErrorCode::InvalidSessionInformation => HandshakeErr::Reauthenticate,
500 ErrorCode::NotAuthorized => HandshakeErr::Reauthenticate,
501 ErrorCode::InvalidInput => HandshakeErr::Fatal,
502 ErrorCode::InvalidClock => HandshakeErr::Fatal,
503 ErrorCode::UnexpectedError => HandshakeErr::Fatal,
504 ErrorCode::Timeout => HandshakeErr::WaitAndRetry,
505 ErrorCode::SubscriptionLimitExceeded => HandshakeErr::WaitAndRetry,
506 ErrorCode::InvalidRequest => HandshakeErr::Fatal,
507 ErrorCode::ConnectionFailed => HandshakeErr::WaitAndRetry,
508 ErrorCode::MaxConnectionLimitExceeded => HandshakeErr::Fatal,
509 ErrorCode::TooManyRequests => HandshakeErr::WaitAndRetry,
510 };
511
512 Err(action)
513 }
514}
515
/// How the connection loop should react to a failed stream handshake.
#[derive(Debug)]
enum HandshakeErr {
    /// Back off, then attempt a brand-new connection.
    WaitAndRetry,
    /// Obtain a fresh session from the RPC client, back off, then reconnect.
    Reauthenticate,
    /// Give up; retrying cannot succeed.
    Fatal,
}
522
523impl fmt::Display for HandshakeErr {
524 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
525 write!(f, "Stream Handshake Error {:?}", self)
526 }
527}
528
529impl core::error::Error for HandshakeErr {}
530
531#[tracing::instrument(err)]
532fn tls_connector() -> eyre::Result<tokio_rustls::TlsConnector> {
533 use tokio_rustls::TlsConnector;
534
535 let mut roots = rustls::RootCertStore::empty();
536 let native_certs = rustls_native_certs::load_native_certs();
537 for cert in native_certs.certs {
538 roots.add(cert)?;
539 }
540
541 let config = rustls::ClientConfig::builder()
542 .with_root_certificates(roots)
543 .with_no_client_auth();
544 Ok(TlsConnector::from(Arc::new(config)))
545}
546
/// Line-delimited JSON codec for the Betfair Stream API: decodes [`ResponseMessage`]s and
/// encodes [`RequestMessage`]s terminated by `\r\n`.
pub struct StreamAPIClientCodec;
549
550impl Decoder for StreamAPIClientCodec {
551 type Item = ResponseMessage;
552 type Error = eyre::Report;
553
554 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
555 if let Some(pos) = src.iter().position(|&byte| byte == b'\n') {
557 let delimiter_size = if pos > 0 && src[pos - 1] == b'\r' {
559 2
560 } else {
561 1
562 };
563
564 let line = src.split_to(pos + 1);
566
567 let (json_part, _) = line.split_at(line.len().saturating_sub(delimiter_size));
569
570 let data = serde_json::from_slice::<Self::Item>(json_part)?;
572 return Ok(Some(data));
573 }
574 Ok(None)
575 }
576}
577
578impl Encoder<RequestMessage> for StreamAPIClientCodec {
579 type Error = eyre::Report;
580
581 fn encode(
582 &mut self,
583 item: RequestMessage,
584 dst: &mut bytes::BytesMut,
585 ) -> Result<(), Self::Error> {
586 let json = serde_json::to_string(&item)?;
588 dst.extend_from_slice(json.as_bytes());
590 dst.extend_from_slice(b"\r\n");
591 Ok(())
592 }
593}
594
#[cfg(test)]
mod tests {

    use core::fmt::Write as _;

    use super::*;

    /// Resolves the production stream host over real DNS. Ignored by default: it needs network
    /// access, and the address family returned first depends on the local resolver.
    #[tokio::test]
    #[ignore = "requires network access (DNS lookup of stream-api.betfair.com)"]
    async fn can_resolve_host_ipv4() {
        let url = url::Url::parse("tcptls://stream-api.betfair.com:443").unwrap();
        let host = url.host_str().unwrap();
        let port = url
            .port()
            .unwrap_or_else(|| if url.scheme() == "https" { 443 } else { 80 });
        let socket_addr = tokio::net::lookup_host((host, port))
            .await
            .unwrap()
            .next()
            .unwrap();
        assert!(socket_addr.ip().is_ipv4());
        assert_eq!(socket_addr.port(), 443);
    }

    #[test]
    fn can_decode_single_message() {
        let msg = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let separator = "\r\n";
        let data = format!("{msg}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg, ResponseMessage::Connection(_)));
    }

    #[test]
    fn can_decode_lf_terminated_message() {
        let msg = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let data = format!("{msg}\n");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg, ResponseMessage::Connection(_)));
        assert!(buf.is_empty());
    }

    #[test]
    fn incomplete_message_is_not_consumed() {
        let partial = r#"{"op":"connection","connectionId""#;

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(partial.as_bytes());

        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(&buf[..], partial.as_bytes());
    }

    #[test]
    fn can_decode_multiple_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two = r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}{msg_two}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    #[test]
    fn can_decode_multiple_partial_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two_pt_one = r#"{"op":"ocm","id":3,"clk""#;
        let msg_two_pt_two = r#":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}{msg_two_pt_one}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two_attempt = codec.decode(&mut buf).unwrap();
        assert!(msg_two_attempt.is_none());
        buf.write_str(msg_two_pt_two).unwrap();
        buf.write_str(separator).unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    #[test]
    fn can_decode_subsequent_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two = r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two_attempt = codec.decode(&mut buf).unwrap();
        assert!(msg_two_attempt.is_none());
        let data = format!("{msg_two}{separator}");
        buf.write_str(data.as_str()).unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    #[test]
    fn can_encode_message() {
        let msg = RequestMessage::Authentication(
            betfair_stream_types::request::authentication_message::AuthenticationMessage {
                id: Some(1),
                session: "sss".to_owned(),
                app_key: "aaaa".to_owned(),
            },
        );
        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        let data = buf.freeze();
        let data = core::str::from_utf8(&data).unwrap();

        assert!(data.ends_with("\r\n"));
        assert!(data.starts_with("{\"op\":\"authentication\""));
    }
}