use crate::block::{BlockStructure, Connection, NextStrategy, OperatorStructure};
use crate::network::{Coord, NetworkMessage, NetworkSender, ReceiverEndpoint};
use crate::operator::{ExchangeData, Operator, StreamElement};
use crate::scheduler::{BlockId, ExecutionMetadata};
/// Operator that terminates a chain producing `DeltaUpdate` values.
///
/// Every item yielded by the wrapped chain is forwarded over the network to
/// the single replica of the block identified by `leader_block_id`; the
/// operator itself only emits `()` elements downstream.
#[derive(Debug, Clone)]
pub struct IterationEnd<DeltaUpdate, OperatorChain>
where
    DeltaUpdate: ExchangeData,
    OperatorChain: Operator<Out = DeltaUpdate>,
{
    /// Upstream operator chain the elements are pulled from.
    prev: OperatorChain,
    /// Whether at least one `Item` has been seen since the last
    /// `FlushAndRestart` (or since construction).
    has_received_item: bool,
    /// Identifier of the block the updates are sent to.
    leader_block_id: BlockId,
    /// Channel towards the leader; `None` until `setup` has been called.
    leader_sender: Option<NetworkSender<DeltaUpdate>>,
    /// Coordinates of this operator, used as the sender of outgoing messages.
    /// Holds the default value until `setup` overwrites it.
    coord: Coord,
}
impl<DeltaUpdate, OperatorChain> std::fmt::Display for IterationEnd<DeltaUpdate, OperatorChain>
where
    DeltaUpdate: ExchangeData,
    OperatorChain: Operator<Out = DeltaUpdate>,
{
    /// Renders the upstream chain followed by this operator, annotated with
    /// the fully-qualified type name of `DeltaUpdate`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let delta_type = std::any::type_name::<DeltaUpdate>();
        write!(f, "{} -> IterationEnd<{}>", self.prev, delta_type)
    }
}
impl<DeltaUpdate: ExchangeData, OperatorChain> IterationEnd<DeltaUpdate, OperatorChain>
where
OperatorChain: Operator<Out = DeltaUpdate>,
{
pub fn new(prev: OperatorChain, leader_block_id: BlockId) -> Self {
Self {
prev,
has_received_item: false,
leader_block_id,
leader_sender: None,
coord: Default::default(),
}
}
}
impl<DeltaUpdate, OperatorChain> Operator for IterationEnd<DeltaUpdate, OperatorChain>
where
    DeltaUpdate: ExchangeData + Default,
    OperatorChain: Operator<Out = DeltaUpdate>,
{
    type Out = ();

    /// Opens the channel towards the leader, records this operator's
    /// coordinates and then sets up the upstream chain.
    ///
    /// # Panics
    ///
    /// Panics if the block identified by `leader_block_id` does not have
    /// exactly one replica.
    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
        let leader_replicas = metadata.network.replicas(self.leader_block_id);
        assert_eq!(
            leader_replicas.len(),
            1,
            "The IterationEnd block should not be replicated"
        );
        // The length check above guarantees that a first element exists.
        let leader_coord = leader_replicas.into_iter().next().unwrap();
        log::debug!("IterationEnd {} has {} as leader", metadata.coord, leader_coord);

        let endpoint = ReceiverEndpoint::new(leader_coord, metadata.coord.block_id);
        self.leader_sender = Some(metadata.network.get_sender(endpoint));
        self.coord = metadata.coord;

        // The upstream chain is initialised last, once this operator is ready.
        self.prev.setup(metadata);
    }

    /// Pulls one element from the upstream chain, forwards it to the leader
    /// when relevant and returns the corresponding `()`-typed element.
    ///
    /// # Panics
    ///
    /// Panics if a message has to be sent before `setup` created the sender,
    /// if sending to the leader fails, or if the upstream chain yields an
    /// element other than `Item`, `FlushAndRestart`, `Terminate` or
    /// `FlushBatch`.
    fn next(&mut self) -> StreamElement<()> {
        match self.prev.next() {
            item @ StreamElement::Item(_) => {
                // Forward the update untouched; downstream only sees a unit item.
                let message = NetworkMessage::new_single(item, self.coord);
                self.leader_sender.as_ref().unwrap().send(message).unwrap();
                self.has_received_item = true;
                StreamElement::Item(())
            }
            StreamElement::FlushAndRestart => {
                // No item was forwarded since the previous restart: send a
                // default-valued update instead (presumably so that the leader
                // always gets at least one update per round -- confirm against
                // the leader's implementation).
                if !self.has_received_item {
                    let placeholder = StreamElement::Item(DeltaUpdate::default());
                    let message = NetworkMessage::new_single(placeholder, self.coord);
                    self.leader_sender.as_ref().unwrap().send(message).unwrap();
                }
                self.has_received_item = false;
                StreamElement::FlushAndRestart
            }
            StreamElement::Terminate => {
                // The leader is notified of the termination as well.
                let message = NetworkMessage::new_single(StreamElement::Terminate, self.coord);
                self.leader_sender.as_ref().unwrap().send(message).unwrap();
                StreamElement::Terminate
            }
            // Carries no payload: only the element's type parameter changes,
            // so the mapping closure is never expected to run.
            flush @ StreamElement::FlushBatch => flush.map(|_| unreachable!()),
            _ => unreachable!(),
        }
    }

    /// Describes this operator, including its connection to the leader block,
    /// and appends it to the structure of the upstream chain.
    fn structure(&self) -> BlockStructure {
        let mut operator = OperatorStructure::new::<DeltaUpdate, _>("IterationEnd");
        let to_leader =
            Connection::new::<DeltaUpdate, _>(self.leader_block_id, &NextStrategy::only_one());
        operator.connections.push(to_leader);
        self.prev.structure().add_operator(operator)
    }
}