use std::fmt::Display;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::block::{BlockStructure, Connection, NextStrategy, OperatorStructure, Replication};
use crate::network::{Coord, NetworkMessage, NetworkSender};
use crate::operator::iteration::{IterationResult, StateFeedback};
use crate::operator::source::Source;
use crate::operator::start::{SimpleStartOperator, Start, StartReceiver};
use crate::operator::{ExchangeData, Operator, StreamElement};
use crate::profiler::{get_profiler, Profiler};
use crate::scheduler::{BlockId, ExecutionMetadata};
/// Source operator that drives an iterative computation.
///
/// Each round runs as follows:
/// 1. Wait for one delta update from every upstream replica.
/// 2. Fold each update into the global state with `global_fold`.
/// 3. Evaluate `loop_condition` and the iteration limit.
/// 4. Send an `IterationResult` together with the state to every feedback sender.
///
/// When the loop stops, the leader emits the final state, followed on the next call by a
/// `FlushAndRestart`.
#[derive(Derivative)]
#[derivative(Clone, Debug)]
pub struct IterationLeader<StateUpdate: ExchangeData, State: ExchangeData, Global, LoopCond>
where
Global: Fn(&mut State, StateUpdate) + Send + Clone,
LoopCond: Fn(&mut State) -> bool + Send + Clone,
{
/// Coordinate of this operator, assigned in `setup`.
coord: Coord,
/// Number of rounds completed in the current run; reset to 0 when a run ends.
iteration_index: usize,
/// Upper bound on the number of rounds in a single run.
max_iterations: usize,
/// Current global state. It is an `Option` so the final value can be moved out
/// when a run ends; it is immediately refilled with `initial_state`.
state: Option<State>,
/// State that every run starts from; cloned to reset `state` after each run.
initial_state: State,
/// Start operator that receives the delta updates; `None` until `setup` runs.
state_update_receiver: Option<SimpleStartOperator<StateUpdate>>,
/// Number of upstream replicas; one update per replica is awaited each round.
num_receivers: usize,
/// Shared id of the block that sends the delta updates, read in `setup`.
/// NOTE(review): presumably filled in elsewhere after construction — confirm.
feedback_block_id: Arc<AtomicUsize>,
/// Senders used to broadcast `(IterationResult, State)` after every round.
feedback_senders: Vec<NetworkSender<StateFeedback<State>>>,
/// When set, the next call to `next` emits `FlushAndRestart` and clears the flag.
flush_and_restart: bool,
/// Folds one delta update into the global state.
#[derivative(Debug = "ignore")]
global_fold: Global,
/// Evaluated on the state after every round; returning `false` ends the loop.
#[derivative(Debug = "ignore")]
loop_condition: LoopCond,
}
impl<StateUpdate, State, Global, LoopCond> Display
    for IterationLeader<StateUpdate, State, Global, LoopCond>
where
    StateUpdate: ExchangeData,
    State: ExchangeData,
    Global: Fn(&mut State, StateUpdate) + Send + Clone,
    LoopCond: Fn(&mut State) -> bool + Send + Clone,
{
    /// Renders the operator as `IterationLeader<...>`, with the state type
    /// spelled out by `std::any::type_name`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state_type = std::any::type_name::<State>();
        f.write_fmt(format_args!("IterationLeader<{}>", state_type))
    }
}
impl<DeltaUpdate: ExchangeData, State: ExchangeData, Global, LoopCond>
    IterationLeader<DeltaUpdate, State, Global, LoopCond>
where
    Global: Fn(&mut State, DeltaUpdate) + Send + Clone,
    LoopCond: Fn(&mut State) -> bool + Send + Clone,
{
    /// Creates a leader that starts every run from `initial_state` and performs
    /// at most `num_iterations` rounds per run.
    ///
    /// * `global_fold` merges each received delta update into the state.
    /// * `loop_condition` is evaluated on the state after every round. The run
    ///   stops as soon as it returns `false` or the iteration limit is reached.
    /// * `feedback_block_id` is read in `setup` to locate the block that sends
    ///   the delta updates.
    pub fn new(
        initial_state: State,
        num_iterations: usize,
        global_fold: Global,
        loop_condition: LoopCond,
        feedback_block_id: Arc<AtomicUsize>,
    ) -> Self {
        Self {
            state_update_receiver: None,
            feedback_senders: Default::default(),
            coord: Coord::new(0, 0, 0),
            num_receivers: 0,
            max_iterations: num_iterations,
            iteration_index: 0,
            state: Some(initial_state.clone()),
            initial_state,
            feedback_block_id,
            flush_and_restart: false,
            global_fold,
            loop_condition,
        }
    }

    /// Receives one delta update from every upstream replica and folds each one
    /// into the current state with `global_fold`.
    ///
    /// Returns `Some(StreamElement::Terminate)` if a `Terminate` arrives before
    /// the round is complete. Otherwise returns `None` once all
    /// `num_receivers` updates have been folded.
    ///
    /// # Panics
    /// Panics if called before `setup`, or if an element other than `Item`,
    /// `Terminate`, `FlushBatch` or `FlushAndRestart` is received.
    fn process_updates(&mut self) -> Option<StreamElement<State>> {
        let rx = self
            .state_update_receiver
            .as_mut()
            .expect("IterationLeader::process_updates called before setup");
        let mut missing_state_updates = self.num_receivers;
        while missing_state_updates > 0 {
            match rx.next() {
                StreamElement::Item(state_update) => {
                    missing_state_updates -= 1;
                    log::trace!(
                        "iter_leader delta_update {}, {} left",
                        self.coord,
                        missing_state_updates
                    );
                    let state = self
                        .state
                        .as_mut()
                        .expect("IterationLeader state is restored after every run");
                    (self.global_fold)(state, state_update);
                }
                StreamElement::Terminate => {
                    log::trace!("iter_leader terminate {}", self.coord);
                    return Some(StreamElement::Terminate);
                }
                // Flush markers carry no update; they do not count toward the round.
                StreamElement::FlushAndRestart | StreamElement::FlushBatch => {}
                update => unreachable!(
                    "IterationLeader received an invalid message: {}",
                    update.variant()
                ),
            }
        }
        None
    }

    /// Decides whether the run ends after the round just completed.
    ///
    /// Always evaluates `loop_condition` first; it receives `&mut State` and may
    /// modify the state. It then checks the iteration limit.
    ///
    /// Returns `None` when both allow another round. Otherwise returns the final
    /// state and replaces the internal state with a fresh clone of
    /// `initial_state`, so a subsequent run starts over.
    fn final_result(&mut self) -> Option<State> {
        let loop_condition = (self.loop_condition)(
            self.state
                .as_mut()
                .expect("IterationLeader state is restored after every run"),
        );
        let more_iterations = self.iteration_index < self.max_iterations;
        if !loop_condition {
            log::trace!("iter_leader finish_condition {}", self.coord);
        }
        if !more_iterations {
            log::trace!("iter_leader finish_max_iter {}", self.coord);
        }

        if loop_condition && more_iterations {
            None
        } else {
            // Move the final state out and reset in one step.
            std::mem::replace(&mut self.state, Some(self.initial_state.clone()))
        }
    }
}
impl<DeltaUpdate: ExchangeData, State: ExchangeData, Global, LoopCond> Operator
    for IterationLeader<DeltaUpdate, State, Global, LoopCond>
where
    Global: Fn(&mut State, DeltaUpdate) + Send + Clone,
    LoopCond: Fn(&mut State) -> bool + Send + Clone,
{
    type Out = State;

    /// Binds the operator to its coordinate and collects the senders used to
    /// broadcast the per-round feedback.
    ///
    /// It also builds and sets up the `Start` operator that receives delta
    /// updates from the block whose id is stored in `feedback_block_id`.
    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
        self.coord = metadata.coord;
        self.feedback_senders = metadata
            .network
            .get_senders(metadata.coord)
            .into_iter()
            .map(|(_, s)| s)
            .collect();

        let feedback_block_id = self.feedback_block_id.load(Ordering::Acquire) as BlockId;
        let mut delta_update_receiver = Start::single(feedback_block_id, None);
        delta_update_receiver.setup(metadata);
        // One delta update is awaited from each upstream replica per round.
        self.num_receivers = delta_update_receiver.receiver().prev_replicas().len();
        self.state_update_receiver = Some(delta_update_receiver);
    }

    /// Runs rounds until the loop stops, then yields the final state.
    ///
    /// After every round it sends `(IterationResult, State)` to every feedback
    /// sender. When a run ends it returns `Item(final_state)`, and the following
    /// call returns `FlushAndRestart`. A `Terminate` received mid-round is
    /// forwarded immediately.
    ///
    /// # Panics
    /// Panics if sending the feedback fails, or if called before `setup`.
    fn next(&mut self) -> StreamElement<State> {
        if self.flush_and_restart {
            self.flush_and_restart = false;
            return StreamElement::FlushAndRestart;
        }

        loop {
            log::trace!(
                "iter_leader {} {} delta updates left",
                self.coord,
                self.num_receivers
            );
            if let Some(value) = self.process_updates() {
                return value;
            }

            get_profiler().iteration_boundary(self.coord.block_id);
            self.iteration_index += 1;
            let result = self.final_result();

            // If the run just ended, `final_result` has already reset the state to
            // `initial_state`, so that reset state is what gets broadcast.
            // NOTE(review): `from_condition(true)` presumably means "continue" — confirm.
            let state_feedback = (
                IterationResult::from_condition(result.is_none()),
                self.state
                    .as_ref()
                    .expect("IterationLeader state is restored after every run")
                    .clone(),
            );
            for sender in &self.feedback_senders {
                let message = NetworkMessage::new_single(
                    StreamElement::Item(state_feedback.clone()),
                    self.coord,
                );
                sender
                    .send(message)
                    .expect("IterationLeader failed to send the state feedback");
            }

            if let Some(state) = result {
                self.flush_and_restart = true;
                self.iteration_index = 0;
                return StreamElement::Item(state);
            }
        }
    }

    /// Describes this operator for the block structure: its own node plus one
    /// feedback connection per distinct destination block.
    ///
    /// Iterating the senders instead of indexing `[0]` avoids a panic when
    /// there are no senders. It also lists every block that `next` actually
    /// broadcasts to.
    fn structure(&self) -> BlockStructure {
        let mut operator = OperatorStructure::new::<State, _>("IterationLeader");
        let mut feedback_blocks = Vec::new();
        for sender in &self.feedback_senders {
            let block_id = sender.receiver_endpoint.coord.block_id;
            if !feedback_blocks.contains(&block_id) {
                feedback_blocks.push(block_id);
                operator
                    .connections
                    .push(Connection::new::<StateFeedback<State>, _>(
                        block_id,
                        &NextStrategy::only_one(),
                    ));
            }
        }
        self.state_update_receiver
            .as_ref()
            .expect("IterationLeader::structure called before setup")
            .structure()
            .add_operator(operator)
    }
}
impl<StateUpdate, State, Global, LoopCond> Source
    for IterationLeader<StateUpdate, State, Global, LoopCond>
where
    StateUpdate: ExchangeData,
    State: ExchangeData,
    Global: Fn(&mut State, StateUpdate) + Send + Clone,
    LoopCond: Fn(&mut State) -> bool + Send + Clone,
{
    /// Always requests `Replication::One`. The leader folds every delta update
    /// into a single global state, so one instance is expected.
    fn replication(&self) -> Replication {
        Replication::One
    }
}