use super::{StreamBuilder, Streams};
use crate::{
error::DataError,
streams::{consumer::MarketStreamResult, reconnect::stream::ReconnectingStream},
subscription::SubscriptionKind,
};
use barter_instrument::exchange::ExchangeId;
use barter_integration::channel::Channel;
use futures_util::StreamExt;
use std::{collections::HashMap, fmt::Debug, future::Future, pin::Pin};
/// Boxed, pinned future that resolves to `Ok(())` or a [`DataError`].
///
/// [`MultiStreamBuilder`] stores one of these per added builder and awaits all of them
/// when it is initialised. The trait object carries no `Send` bound.
pub type BuilderInitFuture = Pin<Box<dyn Future<Output = Result<(), DataError>>>>;
/// Builder that funnels the streams of several `StreamBuilder`s into one `Output` channel
/// per exchange.
///
/// Builders are registered with `add`, and `init` awaits every registered initialisation
/// future before handing out the receiving half of each exchange channel.
pub struct MultiStreamBuilder<Output> {
    /// One [`Channel`] per [`ExchangeId`]; `add` clones the `tx` half and `init` returns the
    /// `rx` half.
    pub channels: HashMap<ExchangeId, Channel<Output>>,
    /// Initialisation futures queued by `add` and awaited together by `init`.
    pub futures: Vec<BuilderInitFuture>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add an `Output: Default`
// bound, even though an empty builder never needs to construct an `Output`.
impl<Output> Default for MultiStreamBuilder<Output> {
    /// Creates an empty builder with no channels and no pending initialisation futures.
    fn default() -> Self {
        Self {
            channels: HashMap::new(),
            futures: Vec::new(),
        }
    }
}
/// Hand-written [`Debug`]: the queued initialisation futures are opaque trait objects, so only
/// how many of them are pending is reported next to the per-exchange channels.
impl<Output> Debug for MultiStreamBuilder<Output>
where
    Output: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { channels, futures } = self;
        let num_futures = futures.len();

        let mut out = f.debug_struct("MultiStreamBuilder<Output>");
        out.field("channels", channels);
        out.field("num_futures", &num_futures);
        out.finish()
    }
}
impl<Output> MultiStreamBuilder<Output> {
pub fn new() -> Self {
Self {
channels: HashMap::new(),
futures: Vec::new(),
}
}
#[allow(clippy::should_implement_trait)]
pub fn add<InstrumentKey, Kind>(mut self, builder: StreamBuilder<InstrumentKey, Kind>) -> Self
where
Output:
From<MarketStreamResult<InstrumentKey, Kind::Event>> + Debug + Clone + Send + 'static,
InstrumentKey: Debug + Send + 'static,
Kind: SubscriptionKind + 'static,
Kind::Event: Send,
{
let mut exchange_txs = HashMap::with_capacity(builder.channels.len());
for exchange in builder.channels.keys().cloned() {
let exchange_tx = self.channels.entry(exchange).or_default().tx.clone();
exchange_txs.insert(exchange, exchange_tx);
}
self.futures.push(Box::pin(async move {
builder
.init()
.await?
.streams
.into_iter()
.for_each(|(exchange, exchange_rx)| {
let exchange_tx = exchange_txs
.remove(&exchange)
.expect("all exchange_txs should be present here");
tokio::spawn(
exchange_rx
.into_stream()
.map(Output::from)
.forward_to(exchange_tx),
);
});
Ok(())
}));
self
}
pub async fn init(self) -> Result<Streams<Output>, DataError> {
futures::future::try_join_all(self.futures).await?;
Ok(Streams {
streams: self
.channels
.into_iter()
.map(|(exchange, channel)| (exchange, channel.rx))
.collect(),
})
}
}