barter_data/streams/builder/multi.rs
1use super::{StreamBuilder, Streams};
2use crate::{
3 error::DataError,
4 streams::{consumer::MarketStreamResult, reconnect::stream::ReconnectingStream},
5 subscription::SubscriptionKind,
6};
7use barter_instrument::exchange::ExchangeId;
8use barter_integration::channel::Channel;
9use futures_util::StreamExt;
10use std::{collections::HashMap, fmt::Debug, future::Future, pin::Pin};
11
12/// Communicative type alias representing the [`Future`] result of a [`StreamBuilder::init`] call
13/// generated whilst executing [`MultiStreamBuilder::add`].
14pub type BuilderInitFuture = Pin<Box<dyn Future<Output = Result<(), DataError>>>>;
15
16/// Builder to configure and initialise a common [`Streams<Output>`](Streams) instance from
17/// multiple [`StreamBuilder<SubscriptionKind>`](StreamBuilder)s.
18#[derive(Default)]
19pub struct MultiStreamBuilder<Output> {
20 pub channels: HashMap<ExchangeId, Channel<Output>>,
21 pub futures: Vec<BuilderInitFuture>,
22}
23
24impl<Output> Debug for MultiStreamBuilder<Output>
25where
26 Output: Debug,
27{
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("MultiStreamBuilder<Output>")
30 .field("channels", &self.channels)
31 .field("num_futures", &self.futures.len())
32 .finish()
33 }
34}
35
36impl<Output> MultiStreamBuilder<Output> {
37 /// Construct a new [`Self`].
38 pub fn new() -> Self {
39 Self {
40 channels: HashMap::new(),
41 futures: Vec::new(),
42 }
43 }
44
45 /// Add a [`StreamBuilder<SubscriptionKind>`](StreamBuilder) to the [`MultiStreamBuilder`]. Creates a
46 /// [`Future`] that calls [`StreamBuilder::init`] and maps the [`SubscriptionKind::Event`](SubscriptionKind)
47 /// into a common `Output`.
48 ///
49 /// Note that the created [`Future`] is not awaited until the [`MultiStreamBuilder::init`]
50 /// method is invoked.
51 #[allow(clippy::should_implement_trait)]
52 pub fn add<InstrumentKey, Kind>(mut self, builder: StreamBuilder<InstrumentKey, Kind>) -> Self
53 where
54 Output:
55 From<MarketStreamResult<InstrumentKey, Kind::Event>> + Debug + Clone + Send + 'static,
56 InstrumentKey: Debug + Send + 'static,
57 Kind: SubscriptionKind + 'static,
58 Kind::Event: Send,
59 {
60 // Allocate HashMap to hold the exchange_tx<Output> for each StreamBuilder exchange present
61 let mut exchange_txs = HashMap::with_capacity(builder.channels.len());
62
63 // Iterate over each StreamBuilder exchange present
64 for exchange in builder.channels.keys().cloned() {
65 // Insert ExchangeChannel<Output> Entry to Self for each exchange
66 let exchange_tx = self.channels.entry(exchange).or_default().tx.clone();
67
68 // Insert new exchange_tx<Output> into HashMap for each exchange
69 exchange_txs.insert(exchange, exchange_tx);
70 }
71
72 // Init Streams<Kind::Event> & send mapped Outputs to the associated exchange_tx
73 self.futures.push(Box::pin(async move {
74 builder
75 .init()
76 .await?
77 .streams
78 .into_iter()
79 .for_each(|(exchange, exchange_rx)| {
80 // Remove exchange_tx<Output> from HashMap that's associated with this tuple:
81 // (ExchangeId, exchange_rx<MarketStreamResult<InstrumentKey, SubscriptionKind::Event>>)
82 let exchange_tx = exchange_txs
83 .remove(&exchange)
84 .expect("all exchange_txs should be present here");
85
86 // Task to receive MarketStreamResult<SubscriptionKind::Event> and send Outputs via exchange_tx
87 tokio::spawn(
88 exchange_rx
89 .into_stream()
90 .map(Output::from)
91 .forward_to(exchange_tx),
92 );
93 });
94
95 Ok(())
96 }));
97
98 self
99 }
100
101 /// Initialise each [`StreamBuilder<SubscriptionKind>`](StreamBuilder) that was added to the
102 /// [`MultiStreamBuilder`] and map all [`Streams<SubscriptionKind::Event>`](Streams) into a common
103 /// [`Streams<Output>`](Streams).
104 pub async fn init(self) -> Result<Streams<Output>, DataError> {
105 // Await Stream initialisation perpetual and ensure success
106 futures::future::try_join_all(self.futures).await?;
107
108 // Construct Streams<Output> using each ExchangeChannel receiver
109 Ok(Streams {
110 streams: self
111 .channels
112 .into_iter()
113 .map(|(exchange, channel)| (exchange, channel.rx))
114 .collect(),
115 })
116 }
117}