use std::any::TypeId;
use std::collections::VecDeque;
use std::fmt::Display;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::block::{
BlockStructure, Connection, NextStrategy, OperatorReceiver, OperatorStructure, Replication,
};
use crate::channel::RecvError::Disconnected;
use crate::channel::SelectResult;
use crate::network::{Coord, NetworkMessage, NetworkReceiver, NetworkSender, ReceiverEndpoint};
use crate::operator::end::End;
use crate::operator::iteration::iteration_end::IterationEnd;
use crate::operator::iteration::leader::IterationLeader;
use crate::operator::iteration::state_handler::IterationStateHandler;
use crate::operator::iteration::{
IterationResult, IterationStateHandle, IterationStateLock, StateFeedback,
};
use crate::operator::source::Source;
use crate::operator::start::Start;
use crate::operator::{ExchangeData, Operator, StreamElement};
use crate::scheduler::{BlockId, ExecutionMetadata};
use crate::stream::Stream;
/// "Clones" a value by ignoring it and producing the type's default instead.
///
/// Used through `#[derivative(Clone(clone_with = "..."))]` on fields whose
/// content must not be carried over to a cloned `Iterate`.
fn clone_with_default<T>(_original: &T) -> T
where
    T: Default,
{
    Default::default()
}
/// Operator that sits at the head of an iteration body.
///
/// It emits the elements coming from the input block first and, once the
/// input has sent `FlushAndRestart`, replays the elements that came back from
/// the feedback block until the leader reports that the iteration is finished.
#[derive(Derivative)]
#[derivative(Debug, Clone)]
pub struct Iterate<Out: ExchangeData, State: ExchangeData> {
/// Coordinate of this replica; a placeholder until `setup` overwrites it
/// with `metadata.coord`.
coord: Coord,
/// Handler used to lock the iteration state, receive the state updates
/// from the leader and synchronize on them.
state: IterationStateHandler<State>,
/// Receiver of the elements sent by the input block. `None` before `setup`
/// and again after the channel disconnects. Not carried over by `clone`:
/// the clone gets the default value.
#[derivative(Clone(clone_with = "clone_with_default"))]
input_receiver: Option<NetworkReceiver<Out>>,
/// Receiver of the elements sent back by the feedback block. `None` before
/// `setup`. Not carried over by `clone`: the clone gets the default value.
#[derivative(Clone(clone_with = "clone_with_default"))]
feedback_receiver: Option<NetworkReceiver<Out>>,
/// Shared cell holding the id of the block that ends the feedback edge;
/// it is read in `setup` and `structure`.
feedback_end_block_id: Arc<AtomicUsize>,
/// Id of the block that feeds the input of the iteration.
input_block_id: BlockId,
/// Sender towards the output block; `None` before `setup`.
output_sender: Option<NetworkSender<Out>>,
/// Shared cell holding the id of the output block; it is read in `setup`
/// and `structure`.
output_block_id: Arc<AtomicUsize>,
/// Elements replayed by `next_stored`; refilled by swapping with
/// `feedback_content` once a whole round of feedback has been received.
content: VecDeque<StreamElement<Out>>,
/// Elements received from the input block and not yet emitted.
input_stash: VecDeque<StreamElement<Out>>,
/// Elements received from the feedback block during the current round.
feedback_content: VecDeque<StreamElement<Out>>,
/// Set when `FlushAndRestart` is taken from `input_stash`; cleared again
/// when the leader reports `IterationResult::Finished`.
input_finished: bool,
}
impl<Out: ExchangeData, State: ExchangeData> Iterate<Out, State> {
fn new(
state_ref: IterationStateHandle<State>,
input_block_id: BlockId,
leader_block_id: BlockId,
feedback_end_block_id: Arc<AtomicUsize>,
output_block_id: Arc<AtomicUsize>,
state_lock: Arc<IterationStateLock>,
) -> Self {
Self {
coord: Coord::new(0, 0, 0),
input_receiver: None,
feedback_receiver: None,
feedback_end_block_id,
input_block_id,
output_sender: None,
output_block_id,
content: Default::default(),
input_stash: Default::default(),
feedback_content: Default::default(),
input_finished: false,
state: IterationStateHandler::new(leader_block_id, state_ref, state_lock),
}
}
fn next_input(&mut self) -> Option<StreamElement<Out>> {
let item = self.input_stash.pop_front()?;
let el = match &item {
StreamElement::FlushAndRestart => {
log::debug!("input finished for iterate {}", self.coord);
self.input_finished = true;
self.state.lock();
StreamElement::FlushAndRestart
}
StreamElement::Item(_)
| StreamElement::Timestamped(_, _)
| StreamElement::Watermark(_)
| StreamElement::FlushBatch => item,
StreamElement::Terminate => {
log::debug!("Iterate at {} is terminating", self.coord);
let message = NetworkMessage::new_single(StreamElement::Terminate, self.coord);
self.output_sender.as_ref().unwrap().send(message).unwrap();
item
}
};
Some(el)
}
fn next_stored(&mut self) -> Option<StreamElement<Out>> {
let item = self.content.pop_front()?;
if matches!(item, StreamElement::FlushAndRestart) {
self.state.lock();
}
Some(item)
}
fn feedback_finished(&self) -> bool {
matches!(
self.feedback_content.back(),
Some(StreamElement::FlushAndRestart)
)
}
pub(crate) fn input_or_feedback(&mut self) {
let rx_feedback = self.feedback_receiver.as_ref().unwrap();
if let Some(rx_input) = self.input_receiver.as_ref() {
match rx_input.select(rx_feedback) {
SelectResult::A(Ok(msg)) => {
self.input_stash.extend(msg);
}
SelectResult::B(Ok(msg)) => {
self.feedback_content.extend(msg);
}
SelectResult::A(Err(Disconnected)) => {
self.input_receiver = None;
self.input_or_feedback();
}
SelectResult::B(Err(Disconnected)) => {
log::error!("feedback_receiver disconnected!");
panic!("feedback_receiver disconnected!");
}
}
} else {
self.feedback_content.extend(rx_feedback.recv().unwrap());
}
}
pub(crate) fn wait_update(&mut self) -> StateFeedback<State> {
let rx_state = self.state.state_receiver().unwrap();
loop {
let state_msg = if let Some(rx_input) = self.input_receiver.as_ref() {
match rx_state.select(rx_input) {
SelectResult::A(Ok(state_msg)) => state_msg,
SelectResult::A(Err(Disconnected)) => {
log::error!("state_receiver disconnected!");
panic!("state_receiver disconnected!");
}
SelectResult::B(Ok(msg)) => {
self.input_stash.extend(msg);
continue;
}
SelectResult::B(Err(Disconnected)) => {
self.input_receiver = None;
continue;
}
}
} else {
rx_state.recv().unwrap()
};
assert!(state_msg.num_items() == 1);
match state_msg.into_iter().next().unwrap() {
StreamElement::Item((should_continue, new_state)) => {
return (should_continue, new_state);
}
StreamElement::FlushBatch => {}
StreamElement::FlushAndRestart => {}
m => unreachable!(
"Iterate received invalid message from IterationLeader: {}",
m.variant()
),
}
}
}
}
impl<Out: ExchangeData, State: ExchangeData + Sync> Operator for Iterate<Out, State> {
    type Out = Out;

    /// Records the replica coordinate and opens the channels of this operator:
    /// the input receiver, the feedback receiver and the sender towards the
    /// output block. The state handler is set up last.
    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
        let here = metadata.coord;
        self.coord = here;

        let input_endpoint = ReceiverEndpoint::new(here, self.input_block_id);
        self.input_receiver = Some(metadata.network.get_receiver(input_endpoint));

        let feedback_block = self.feedback_end_block_id.load(Ordering::Acquire) as BlockId;
        let feedback_endpoint = ReceiverEndpoint::new(here, feedback_block);
        self.feedback_receiver = Some(metadata.network.get_receiver(feedback_endpoint));

        // The output endpoint lives in the output block, on the same host and
        // replica as this operator.
        let output_block = self.output_block_id.load(Ordering::Acquire) as BlockId;
        let output_coord = Coord::new(output_block, here.host_id, here.replica_id);
        let output_endpoint = ReceiverEndpoint::new(output_coord, here.block_id);
        self.output_sender = Some(metadata.network.get_sender(output_endpoint));

        self.state.setup(metadata);
    }

    /// Produces the next element of the iteration body.
    ///
    /// While the input is not finished, elements come from the input stash.
    /// Afterwards the stored elements are replayed; when they run out the
    /// operator waits for the whole feedback round and for the state update
    /// from the leader, and either starts a new round or, when the iteration
    /// is finished, ships the remaining elements to the output block.
    fn next(&mut self) -> StreamElement<Out> {
        loop {
            // Pick up whatever feedback is already available, without blocking.
            let feedback_rx = self.feedback_receiver.as_ref().unwrap();
            while let Ok(batch) = feedback_rx.try_recv() {
                self.feedback_content.extend(batch);
            }

            if !self.input_finished {
                // Block until the input stash has something to hand out.
                loop {
                    if let Some(element) = self.next_input() {
                        return element;
                    }
                    self.input_or_feedback();
                }
            }

            if let Some(element) = self.next_stored() {
                return element;
            }

            // Nothing stored: wait for the feedback round to be complete.
            while !self.feedback_finished() {
                self.input_or_feedback();
            }

            log::debug!("Iterate at {} has finished the iteration", self.coord);
            assert!(self.content.is_empty());
            std::mem::swap(&mut self.content, &mut self.feedback_content);

            let state_update = self.wait_update();
            let outcome = self.state.wait_sync_state(state_update);
            if matches!(outcome, IterationResult::Finished) {
                log::debug!("Iterate block at {} finished", self.coord,);
                self.input_finished = false;
                let message =
                    NetworkMessage::new_batch(self.content.drain(..).collect(), self.coord);
                self.output_sender.as_ref().unwrap().send(message).unwrap();
            }
        }
    }

    /// Describes this operator: three receivers (leader state feedback,
    /// feedback block, input block, in this order) and one connection towards
    /// the output block.
    fn structure(&self) -> BlockStructure {
        let feedback_block = self.feedback_end_block_id.load(Ordering::Acquire) as BlockId;
        let output_block = self.output_block_id.load(Ordering::Acquire) as BlockId;

        let from_leader = OperatorReceiver::new::<StateFeedback<State>>(self.state.leader_block_id);
        let from_feedback = OperatorReceiver::new::<Out>(feedback_block);
        let from_input = OperatorReceiver::new::<Out>(self.input_block_id);

        let mut operator = OperatorStructure::new::<Out, _>("Iterate");
        operator.receivers.push(from_leader);
        operator.receivers.push(from_feedback);
        operator.receivers.push(from_input);
        operator.connections.push(Connection::new::<Out, _>(
            output_block,
            &NextStrategy::only_one(),
        ));

        BlockStructure::default().add_operator(operator)
    }
}
impl<Out: ExchangeData, State: ExchangeData + Sync> Display for Iterate<Out, State> {
    /// Writes `Iterate<T>`, where `T` is the type name of the elements
    /// flowing through the operator.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let out_type = std::any::type_name::<Out>();
        write!(f, "Iterate<{}>", out_type)
    }
}
impl<Out: ExchangeData, OperatorChain> Stream<OperatorChain>
where
OperatorChain: Operator<Out = Out> + 'static,
{
/// Builds an iteration whose body is fed by this stream.
///
/// The method creates and wires the blocks of the iteration: the input block
/// (this stream, closed with an `End`), the `Iterate` block at the head of the
/// body, the body itself (built by `body`), the block that folds the body
/// output into a `StateUpdate`, the `IterationLeader` block, the feedback edge
/// back to `Iterate`, and the output block.
///
/// # Parameters
///
/// - `num_iterations`, `global_fold`, `loop_condition`: handed to
///   `IterationLeader::new` together with `initial_state`.
/// - `initial_state`: also used (cloned) to create the `IterationStateHandle`
///   given to the body.
/// - `body`: receives the stream that starts at `Iterate` and a handle to the
///   state, and returns the stream that ends the body.
/// - `local_fold`: folds the elements leaving the body into a `StateUpdate`,
///   starting from `StateUpdate::default()`.
///
/// # Returns
///
/// A pair of streams: the first one starts from the leader block and carries
/// `State` values, the second one starts from the output block and carries
/// `Out` values.
///
/// # Panics
///
/// - If the replication of the current block is not unlimited.
/// - If the iteration context of the stream returned by `body` differs from
///   the one of the stream passed to it.
pub fn iterate<
Body,
StateUpdate: ExchangeData + Default,
State: ExchangeData + Sync,
OperatorChain2,
>(
self,
num_iterations: usize,
initial_state: State,
body: Body,
local_fold: impl Fn(&mut StateUpdate, Out) + Send + Clone + 'static,
global_fold: impl Fn(&mut State, StateUpdate) + Send + Clone + 'static,
loop_condition: impl Fn(&mut State) -> bool + Send + Clone + 'static,
) -> (
Stream<impl Operator<Out = State>>,
Stream<impl Operator<Out = Out>>,
)
where
Body: FnOnce(
Stream<Iterate<Out, State>>,
IterationStateHandle<State>,
) -> Stream<OperatorChain2>,
OperatorChain2: Operator<Out = Out> + 'static,
{
assert!(
self.block.scheduling.replication.is_unlimited(),
"Cannot have an iteration block with limited parallelism"
);
// One handle goes to the `Iterate` operator, its clone goes to the body.
let state = IterationStateHandle::new(initial_state.clone());
let state_clone = state.clone();
let batch_mode = self.block.batch_mode;
let ctx = self.ctx;
// Block ids that are not known yet: the cells start at 0 and are filled in
// below, once the corresponding blocks have been created.
let shared_state_update_id = Arc::new(AtomicUsize::new(0));
let shared_feedback_id = Arc::new(AtomicUsize::new(0));
let shared_output_id = Arc::new(AtomicUsize::new(0));
// Block hosting the leader of the iteration.
let leader_block = ctx.lock().new_block(
IterationLeader::new(
initial_state,
num_iterations,
global_fold,
loop_condition,
shared_state_update_id.clone(),
),
batch_mode,
self.block.iteration_ctx.clone(),
);
let state_lock = Arc::new(IterationStateLock::default());
// Close the current block: it becomes the input of the iteration.
let mut input_block = self
.block
.add_operator(|prev| End::new(prev, NextStrategy::only_one(), batch_mode));
input_block.is_only_one_strategy = true;
// Block at the head of the body, starting with the `Iterate` operator.
let iter_source = Iterate::new(
state,
input_block.id,
leader_block.id,
shared_feedback_id.clone(),
shared_output_id.clone(),
state_lock.clone(),
);
let mut iter_block =
ctx.lock()
.new_block(iter_source, batch_mode, input_block.iteration_ctx.clone());
let iter_id = iter_block.id;
// The body runs inside one more level of iteration context.
iter_block.iteration_ctx.push(state_lock.clone());
// Remembered to verify, after the body has been built, that it returned a
// stream with the same iteration context.
let pre_iter_stack = iter_block.iteration_ctx();
// Block receiving the final output of the iteration from `Iterate`.
let output_block = ctx.lock().new_block(
Start::single(iter_block.id, iter_block.iteration_ctx.last().cloned()),
batch_mode,
Default::default(),
);
let output_id = output_block.id;
let iter_stream = Stream::new(ctx.clone(), iter_block);
// Let the caller build the body of the iteration.
let body_stream = body(iter_stream, state_clone);
// Close the body block with an `End` that is told to ignore the output
// block as a destination.
let mut body_stream = body_stream.split_block(
move |prev, next_strategy, batch_mode| {
let mut end = End::new(prev, next_strategy, batch_mode);
end.ignore_destination(output_id);
end
},
NextStrategy::only_one(),
);
let body_id = body_stream.block.id;
let post_iter_stack = body_stream.block.iteration_ctx();
assert_eq!(
pre_iter_stack, post_iter_stack,
"The body of the iteration should return the stream given as parameter"
);
// Leave the iteration context that was pushed for the body.
body_stream.block.iteration_ctx.pop().unwrap();
// Block that folds the output of the body into a `StateUpdate` and sends it
// to the leader.
let state_block = ctx.lock().new_block(
Start::single(body_stream.block.id, Some(state_lock)),
batch_mode,
Default::default(),
);
let state_stream = Stream::new(ctx.clone(), state_block);
let state_stream = state_stream
.key_by(|_| ())
.fold(StateUpdate::default(), local_fold)
.drop_key()
.add_operator(|prev| IterationEnd::new(prev, leader_block.id));
let batch_mode = body_stream.block.batch_mode;
// Feedback edge: an `End` marked as feedback towards the `Iterate` block.
let mut feedback_stream = body_stream.add_operator(|prev| {
let mut end = End::new(prev, NextStrategy::only_one(), batch_mode);
end.mark_feedback(iter_id);
end
});
feedback_stream.block.is_only_one_strategy = true;
let mut ctx_lock = ctx.lock();
let scheduler = ctx_lock.scheduler_mut();
// Wire the blocks together:
//   input -> iterate, body -> state fold, state fold -> leader,
//   leader -> iterate, feedback -> iterate, iterate -> output.
scheduler.connect_blocks(input_block.id, iter_id, TypeId::of::<Out>());
scheduler.connect_blocks(body_id, state_stream.block.id, TypeId::of::<Out>());
scheduler.connect_blocks(
state_stream.block.id,
leader_block.id,
TypeId::of::<StateUpdate>(),
);
scheduler.connect_blocks(
leader_block.id,
iter_id,
TypeId::of::<StateFeedback<State>>(),
);
scheduler.connect_blocks(feedback_stream.block.id, iter_id, TypeId::of::<Out>());
scheduler.connect_blocks_fragile(iter_id, output_block.id, TypeId::of::<Out>());
// All the blocks exist now: publish the ids that were left pending above.
shared_state_update_id.store(state_stream.block.id as usize, Ordering::Release);
shared_feedback_id.store(feedback_stream.block.id as usize, Ordering::Release);
shared_output_id.store(output_block.id as usize, Ordering::Release);
// These blocks are complete; the leader and output blocks are returned to
// the caller as streams instead.
scheduler.schedule_block(state_stream.block);
scheduler.schedule_block(feedback_stream.block);
scheduler.schedule_block(input_block);
drop(ctx_lock);
(
Stream::new(ctx.clone(), leader_block).split_block(End::new, NextStrategy::random()),
Stream::new(ctx, output_block),
)
}
}
/// Lets `Iterate` be used where a `Source` is required.
impl<Out: ExchangeData, State: ExchangeData + Sync> Source for Iterate<Out, State> {
/// Always reports unlimited replication.
fn replication(&self) -> Replication {
Replication::Unlimited
}
}