use std::collections::HashMap;
use std::fmt::Display;
use crate::block::{
BatchMode, Batcher, BlockStructure, Connection, NextStrategy, OperatorStructure,
};
use crate::network::{Coord, ReceiverEndpoint};
use crate::operator::{ExchangeData, KeyerFn, Operator, StreamElement};
use crate::scheduler::{BlockId, ExecutionMetadata};
/// A group of positions into the `senders` vector of an [`End`] operator.
///
/// `End::setup_senders` builds these groups: one per sender for
/// `NextStrategy::All`, otherwise one per destination block id.
#[derive(Debug, Clone)]
pub(crate) struct BlockSenders {
    /// Positions (into `End::senders`) of the senders belonging to this group.
    pub indexes: Vec<usize>,
}

impl BlockSenders {
    /// Wraps the given list of sender positions into a group, taking ownership of it.
    pub(crate) fn new(indexes: Vec<usize>) -> Self {
        BlockSenders { indexes }
    }
}
/// Operator placed after `prev` that hands the elements `prev` produces to a
/// set of batching senders, and returns a `StreamElement<()>` for each of them.
///
/// The senders are obtained from the execution metadata in `setup`; until then
/// `senders` and `block_senders` are empty and `coord` is `None`.
pub struct End<OperatorChain, IndexFn>
where
    OperatorChain: Operator,
    OperatorChain::Out: Send + 'static,
    IndexFn: KeyerFn<u64, OperatorChain::Out>,
{
    /// Upstream operator chain; `next` pulls one element from it per call.
    prev: OperatorChain,
    /// Coordinate of this operator, stored by `setup` (`None` before that).
    coord: Option<Coord>,
    /// Strategy used both to group the senders and to pick, for each item,
    /// which sender inside every group receives it.
    next_strategy: NextStrategy<OperatorChain::Out, IndexFn>,
    /// Batching mode handed to every `Batcher` created in `setup`.
    batch_mode: BatchMode,
    /// Groups of positions into `senders`, computed by `setup_senders`.
    block_senders: Vec<BlockSenders>,
    /// Destination endpoint paired with the batcher that sends to it.
    senders: Vec<(ReceiverEndpoint, Batcher<OperatorChain::Out>)>,
    /// Block to which `Terminate` is not forwarded (set by `mark_feedback`).
    feedback_id: Option<BlockId>,
    /// Destination blocks whose senders are dropped during `setup`.
    ignore_block_ids: Vec<BlockId>,
}
/// Manual `Debug` implementation: prints every field except `senders`.
impl<OperatorChain, IndexFn> std::fmt::Debug for End<OperatorChain, IndexFn>
where
    OperatorChain: Operator + std::fmt::Debug,
    OperatorChain::Out: Send + 'static,
    IndexFn: KeyerFn<u64, OperatorChain::Out> + std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("End");
        out.field("prev", &self.prev);
        out.field("coord", &self.coord);
        out.field("next_strategy", &self.next_strategy);
        out.field("batch_mode", &self.batch_mode);
        out.field("block_senders", &self.block_senders);
        // `senders` is deliberately left out of the output.
        out.field("feedback_id", &self.feedback_id);
        out.field("ignore_block_ids", &self.ignore_block_ids);
        out.finish()
    }
}
/// Manual `Clone` implementation: every field is copied except `senders`,
/// which starts out empty in the clone.
impl<OperatorChain, IndexFn> Clone for End<OperatorChain, IndexFn>
where
    OperatorChain: Operator + Clone,
    OperatorChain::Out: Send + 'static,
    IndexFn: KeyerFn<u64, OperatorChain::Out> + Clone,
{
    fn clone(&self) -> Self {
        let prev = self.prev.clone();
        let next_strategy = self.next_strategy.clone();
        let block_senders = self.block_senders.clone();
        let ignore_block_ids = self.ignore_block_ids.clone();

        Self {
            prev,
            coord: self.coord,
            next_strategy,
            batch_mode: self.batch_mode,
            block_senders,
            // The batchers are not duplicated; `setup` fills this vector again.
            senders: Vec::new(),
            feedback_id: self.feedback_id,
            ignore_block_ids,
        }
    }
}
/// Displays the upstream chain, followed by ` -> Shuffle` or ` -> OnlyOne`
/// when the strategy is `Random` or `OnlyOne` respectively.
impl<OperatorChain, IndexFn> Display for End<OperatorChain, IndexFn>
where
    OperatorChain: Operator,
    OperatorChain::Out: Send + 'static,
    IndexFn: KeyerFn<u64, OperatorChain::Out>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if matches!(self.next_strategy, NextStrategy::Random) {
            return write!(f, "{} -> Shuffle", self.prev);
        }
        if matches!(self.next_strategy, NextStrategy::OnlyOne) {
            return write!(f, "{} -> OnlyOne", self.prev);
        }
        // Any other strategy adds nothing: delegate to the upstream chain.
        self.prev.fmt(f)
    }
}
impl<OperatorChain, IndexFn> End<OperatorChain, IndexFn>
where
    OperatorChain: Operator,
    OperatorChain::Out: Send + 'static,
    IndexFn: KeyerFn<u64, OperatorChain::Out>,
{
    /// Builds an `End` on top of `prev` with the given strategy and batch mode.
    ///
    /// No sender is attached yet: `coord` and `feedback_id` are `None` and all
    /// the vectors are empty until `setup` runs.
    pub(crate) fn new(
        prev: OperatorChain,
        next_strategy: NextStrategy<OperatorChain::Out, IndexFn>,
        batch_mode: BatchMode,
    ) -> Self {
        Self {
            prev,
            coord: None,
            next_strategy,
            batch_mode,
            block_senders: Vec::new(),
            senders: Vec::new(),
            feedback_id: None,
            ignore_block_ids: Vec::new(),
        }
    }

    /// Sorts `senders` by endpoint and recomputes `block_senders` from them.
    ///
    /// With `NextStrategy::All` every sender forms its own group; with any
    /// other strategy the senders are grouped by destination block id.
    ///
    /// # Panics
    ///
    /// With `NextStrategy::OnlyOne`, panics if any group does not contain
    /// exactly one sender.
    fn setup_senders(&mut self) {
        glidesort::sort_by_key(&mut self.senders, |entry| entry.0);

        if matches!(self.next_strategy, NextStrategy::All) {
            self.block_senders = (0..self.senders.len())
                .map(|position| BlockSenders::new(vec![position]))
                .collect();
        } else {
            let mut by_block = HashMap::<_, Vec<_>>::new();
            for (position, (endpoint, _)) in self.senders.iter().enumerate() {
                by_block
                    .entry(endpoint.coord.block_id)
                    .or_default()
                    .push(position);
            }
            // The groups come out in the map's iteration order.
            self.block_senders = by_block.into_values().map(BlockSenders::new).collect();
        }

        if matches!(self.next_strategy, NextStrategy::OnlyOne) {
            for group in &self.block_senders {
                assert_eq!(group.indexes.len(), 1);
            }
        }
    }

    /// Records `block_id` as the feedback block; `next` does not forward
    /// `Terminate` to senders that target it.
    pub(crate) fn mark_feedback(&mut self, block_id: BlockId) {
        self.feedback_id = Some(block_id);
    }

    /// Adds `block_id` to the destinations whose senders `setup` discards.
    pub(crate) fn ignore_destination(&mut self, block_id: BlockId) {
        self.ignore_block_ids.push(block_id);
    }
}
impl<OperatorChain, IndexFn> Operator for End<OperatorChain, IndexFn>
where
IndexFn: KeyerFn<u64, OperatorChain::Out>,
OperatorChain: Operator,
OperatorChain::Out: ExchangeData,
{
type Out = ();
fn setup(&mut self, metadata: &mut ExecutionMetadata) {
self.prev.setup(metadata);
let senders = metadata.network.get_senders(metadata.coord);
self.senders = senders
.into_iter()
.filter(|(endpoint, _)| !self.ignore_block_ids.contains(&endpoint.coord.block_id))
.map(|(coord, sender)| (coord, Batcher::new(sender, self.batch_mode, metadata.coord)))
.collect();
self.setup_senders();
self.coord = Some(metadata.coord);
}
fn next(&mut self) -> StreamElement<()> {
let message = self.prev.next();
let to_return = message.take();
match &message {
StreamElement::Watermark(_)
| StreamElement::Terminate
| StreamElement::FlushAndRestart => {
for block in self.block_senders.iter() {
for &sender_idx in block.indexes.iter() {
let sender = &mut self.senders[sender_idx];
if matches!(message, StreamElement::Terminate)
&& Some(sender.0.coord.block_id) == self.feedback_id
{
continue;
}
sender.1.enqueue(message.clone());
}
}
}
StreamElement::Item(item) | StreamElement::Timestamped(item, _) => {
let index = self.next_strategy.index(item);
for block in self.block_senders.iter() {
let index = index % block.indexes.len();
let sender_idx = block.indexes[index];
self.senders[sender_idx].1.enqueue(message.clone());
}
}
StreamElement::FlushBatch => {}
};
match to_return {
StreamElement::FlushAndRestart | StreamElement::FlushBatch => {
for (_, batcher) in self.senders.iter_mut() {
batcher.flush();
}
}
StreamElement::Terminate => {
log::debug!(
"{} received terminate, closing {} channels",
self.coord.unwrap(),
self.senders.len()
);
for (_, batcher) in self.senders.drain(..) {
batcher.end();
}
}
_ => {}
}
to_return
}
fn structure(&self) -> BlockStructure {
let mut operator = OperatorStructure::new::<OperatorChain::Out, _>("End");
for sender_group in &self.block_senders {
if !sender_group.indexes.is_empty() {
let block_id = self.senders[sender_group.indexes[0]].0.coord.block_id;
operator
.connections
.push(Connection::new::<OperatorChain::Out, _>(
block_id,
&self.next_strategy,
));
}
}
self.prev.structure().add_operator(operator)
}
}