use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::block::{BlockStructure, OperatorReceiver, OperatorStructure};
use crate::channel::{RecvTimeoutError, SelectResult};
use crate::network::{Coord, NetworkMessage};
use crate::operator::start::{SimpleStartReceiver, StartReceiver};
use crate::operator::{Data, ExchangeData, StreamElement};
use crate::scheduler::{BlockId, ExecutionMetadata};
/// Element yielded by a two-input start receiver.
///
/// Every item is tagged with the side it was received from, and each side has a dedicated
/// marker variant that `BinaryStartReceiver::process_side` emits when the last expected
/// `FlushAndRestart` of that side has been seen.
///
/// The declaration order of the variants is relied upon by the derived ordering traits, so it
/// must not be changed.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub(crate) enum BinaryElement<OutL, OutR>
where
    OutL: Data,
    OutR: Data,
{
    /// An item coming from the left side.
    Left(OutL),
    /// An item coming from the right side.
    Right(OutR),
    /// Marker: the left side has no more items for the current round.
    LeftEnd,
    /// Marker: the right side has no more items for the current round.
    RightEnd,
}
/// Bookkeeping for one of the two inputs of a `BinaryStartReceiver`.
///
/// `Out` is the type received from the previous block, `Item` is the type of the messages that
/// are stored in the cache (the already-wrapped output type).
#[derive(Debug, Clone)]
struct SideReceiver<Out, Item>
where
    Out: ExchangeData,
    Item: ExchangeData,
{
    /// The receiver the messages of this side are actually read from.
    receiver: SimpleStartReceiver<Out>,
    /// Number of previous replicas of this side; filled in by `setup`.
    instances: usize,
    /// How many `FlushAndRestart` are still expected before this side is done for the round.
    missing_flush_and_restart: usize,
    /// How many `Terminate` are still expected from this side.
    missing_terminate: usize,
    /// Whether the messages of this side are stored so that they can be returned again.
    cached: bool,
    /// The stored messages (only populated when `cached` is set).
    cache: Vec<NetworkMessage<Item>>,
    /// Set by `reset` on a cached side; from then on the cache may be replayed.
    cache_full: bool,
    /// Index inside `cache` of the next message that a replay will return.
    cache_pointer: usize,
}
impl<Out: ExchangeData, Item: ExchangeData> SideReceiver<Out, Item> {
    /// Build the state for a side fed by `previous_block_id`.
    ///
    /// All the counters start at zero and the cache starts empty; the counters get their real
    /// value in [`SideReceiver::setup`].
    fn new(previous_block_id: BlockId, cached: bool) -> Self {
        Self {
            receiver: SimpleStartReceiver::new(previous_block_id),
            instances: 0,
            missing_flush_and_restart: 0,
            missing_terminate: 0,
            cached,
            cache: Vec::new(),
            cache_full: false,
            cache_pointer: 0,
        }
    }

    /// Set up the inner receiver and initialise every counter with the number of previous
    /// replicas it reports.
    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
        self.receiver.setup(metadata);
        let replicas = self.receiver.prev_replicas().len();
        self.instances = replicas;
        self.missing_flush_and_restart = replicas;
        self.missing_terminate = replicas;
    }

    /// Receive one message from this side.
    ///
    /// With `Some(timeout)` the inner `recv_timeout` is used and its result is returned as is;
    /// with `None` the inner `recv` is used and its message is wrapped in `Ok`.
    fn recv(&mut self, timeout: Option<Duration>) -> Result<NetworkMessage<Out>, RecvTimeoutError> {
        match timeout {
            Some(limit) => self.receiver.recv_timeout(limit),
            None => Ok(self.receiver.recv()),
        }
    }

    /// Prepare this side for a new round: expect again one `FlushAndRestart` per instance and,
    /// on a cached side, mark the cache as replayable and rewind it to its first message.
    fn reset(&mut self) {
        if self.cached {
            self.cache_pointer = 0;
            self.cache_full = true;
        }
        self.missing_flush_and_restart = self.instances;
    }

    /// Whether this side is done for the current round.
    ///
    /// A cached side is done once it has terminated; a non-cached one once every expected
    /// `FlushAndRestart` has been counted.
    fn is_ended(&self) -> bool {
        if !self.cached {
            return self.missing_flush_and_restart == 0;
        }
        self.is_terminated()
    }

    /// Whether every expected `Terminate` of this side has been counted.
    fn is_terminated(&self) -> bool {
        self.missing_terminate == 0
    }

    /// Whether the replay pointer is past the last cached message (trivially true when the
    /// cache is empty).
    fn cache_finished(&self) -> bool {
        self.cache.len() <= self.cache_pointer
    }

    /// Return a clone of the next cached message and advance the replay pointer.
    ///
    /// # Panics
    ///
    /// Panics if the pointer was already past the end of the cache.
    fn next_cached_item(&mut self) -> NetworkMessage<Item> {
        let index = self.cache_pointer;
        self.cache_pointer = index + 1;
        if self.cache_finished() {
            // Replayed messages do not go through the counting done when they were first
            // received, so the counter is forced to zero when the replay reaches the end.
            self.missing_flush_and_restart = 0;
        }
        self.cache[index].clone()
    }
}
/// Start receiver that reads from two previous blocks and merges their messages into a single
/// stream of `BinaryElement`.
#[derive(Debug, Clone)]
pub(crate) struct BinaryStartReceiver<OutL, OutR>
where
    OutL: ExchangeData,
    OutR: ExchangeData,
{
    /// State of the left input.
    left: SideReceiver<OutL, BinaryElement<OutL, OutR>>,
    /// State of the right input.
    right: SideReceiver<OutR, BinaryElement<OutL, OutR>>,
    /// Set when both sides have just been reset; while it is set and one side is cached, the
    /// next message is read from the non-cached side before any cached message is replayed.
    first_message: bool,
}
impl<OutL: ExchangeData, OutR: ExchangeData> BinaryStartReceiver<OutL, OutR> {
pub(super) fn new(
left_block_id: BlockId,
right_block_id: BlockId,
left_cache: bool,
right_cache: bool,
) -> Self {
assert!(
!(left_cache && right_cache),
"At most one of the two sides can be cached"
);
Self {
left: SideReceiver::new(left_block_id, left_cache),
right: SideReceiver::new(right_block_id, right_cache),
first_message: false,
}
}
fn process_side<Out: ExchangeData>(
side: &mut SideReceiver<Out, BinaryElement<OutL, OutR>>,
message: NetworkMessage<Out>,
wrap: fn(Out) -> BinaryElement<OutL, OutR>,
end: BinaryElement<OutL, OutR>,
) -> NetworkMessage<BinaryElement<OutL, OutR>> {
let sender = message.sender();
let data = message
.into_iter()
.flat_map(|item| {
let mut res = Vec::new();
if matches!(item, StreamElement::FlushAndRestart) {
side.missing_flush_and_restart -= 1;
if side.missing_flush_and_restart == 0 {
res.push(StreamElement::Item(end.clone()));
}
}
if matches!(item, StreamElement::Terminate) {
side.missing_terminate -= 1;
}
if !side.cached || !matches!(item, StreamElement::Terminate) {
res.push(item.map(wrap));
}
res
})
.collect::<Vec<_>>();
let message = NetworkMessage::new_batch(data, sender);
if side.cached {
side.cache.push(message.clone());
side.cache_pointer = side.cache.len();
}
message
}
fn select(
&mut self,
timeout: Option<Duration>,
) -> Result<NetworkMessage<BinaryElement<OutL, OutR>>, RecvTimeoutError> {
if self.left.is_terminated() && self.right.is_terminated() {
let num_terminates = if self.left.cached {
self.left.instances
} else if self.right.cached {
self.right.instances
} else {
0
};
if num_terminates > 0 {
return Ok(NetworkMessage::new_batch(
(0..num_terminates)
.map(|_| StreamElement::Terminate)
.collect(),
Default::default(),
));
}
}
if self.left.is_ended()
&& self.right.is_ended()
&& self.left.cache_finished()
&& self.right.cache_finished()
{
self.left.reset();
self.right.reset();
self.first_message = true;
}
enum Side<L, R> {
Left(L),
Right(R),
}
let data = if self.first_message && (self.left.cached || self.right.cached) {
debug_assert!(!self.left.cached || self.left.cache_full);
debug_assert!(!self.right.cached || self.right.cache_full);
self.first_message = false;
if self.left.cached {
Side::Right(self.right.recv(timeout))
} else {
Side::Left(self.left.recv(timeout))
}
} else if self.left.cached && self.left.cache_full && !self.left.cache_finished() {
return Ok(self.left.next_cached_item());
} else if self.right.cached && self.right.cache_full && !self.right.cache_finished() {
return Ok(self.right.next_cached_item());
} else if self.left.is_ended() {
Side::Right(self.right.recv(timeout))
} else if self.right.is_ended() {
Side::Left(self.left.recv(timeout))
} else {
let left_terminated = self.left.is_terminated();
let right_terminated = self.right.is_terminated();
let left = self.left.receiver.receiver.as_mut().unwrap();
let right = self.right.receiver.receiver.as_mut().unwrap();
let data = match (left_terminated, right_terminated, timeout) {
(false, false, Some(timeout)) => left.select_timeout(right, timeout),
(false, false, None) => Ok(left.select(right)),
(true, false, Some(timeout)) => {
right.recv_timeout(timeout).map(|r| SelectResult::B(Ok(r)))
}
(false, true, Some(timeout)) => {
left.recv_timeout(timeout).map(|r| SelectResult::A(Ok(r)))
}
(true, false, None) => Ok(SelectResult::B(right.recv())),
(false, true, None) => Ok(SelectResult::A(left.recv())),
(true, true, _) => Err(RecvTimeoutError::Disconnected),
};
match data {
Ok(SelectResult::A(left)) => {
Side::Left(left.map_err(|_| RecvTimeoutError::Disconnected))
}
Ok(SelectResult::B(right)) => {
Side::Right(right.map_err(|_| RecvTimeoutError::Disconnected))
}
Err(e) => Side::Left(Err(e)),
}
};
match data {
Side::Left(Ok(left)) => Ok(Self::process_side(
&mut self.left,
left,
BinaryElement::Left,
BinaryElement::LeftEnd,
)),
Side::Right(Ok(right)) => Ok(Self::process_side(
&mut self.right,
right,
BinaryElement::Right,
BinaryElement::RightEnd,
)),
Side::Left(Err(e)) | Side::Right(Err(e)) => Err(e),
}
}
}
impl<OutL: ExchangeData, OutR: ExchangeData> StartReceiver for BinaryStartReceiver<OutL, OutR> {
    type Out = BinaryElement<OutL, OutR>;

    /// Set up the left side and then the right side.
    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
        self.left.setup(metadata);
        self.right.setup(metadata);
    }

    /// Previous replicas of the left side followed by those of the right side.
    fn prev_replicas(&self) -> Vec<Coord> {
        let mut replicas = self.left.receiver.prev_replicas();
        replicas.extend(self.right.receiver.prev_replicas());
        replicas
    }

    /// Total number of instances that belong to a cached side.
    fn cached_replicas(&self) -> usize {
        let from_left = if self.left.cached {
            self.left.instances
        } else {
            0
        };
        let from_right = if self.right.cached {
            self.right.instances
        } else {
            0
        };
        from_left + from_right
    }

    /// Receive the next message using the timed variants of the receive operations.
    fn recv_timeout(
        &mut self,
        timeout: Duration,
    ) -> Result<NetworkMessage<BinaryElement<OutL, OutR>>, RecvTimeoutError> {
        let limit = Some(timeout);
        self.select(limit)
    }

    /// Receive the next message without a timeout.
    ///
    /// # Panics
    ///
    /// Panics if the underlying selection reports an error.
    fn recv(&mut self) -> NetworkMessage<BinaryElement<OutL, OutR>> {
        let outcome = self.select(None);
        outcome.expect("receiver failed")
    }

    /// Describe this receiver as a "Start" operator with two receivers, the left one first.
    fn structure(&self) -> BlockStructure {
        let mut operator = OperatorStructure::new::<BinaryElement<OutL, OutR>, _>("Start");

        let left_receiver = OperatorReceiver::new::<OutL>(self.left.receiver.previous_block_id);
        operator.receivers.push(left_receiver);

        let right_receiver = OperatorReceiver::new::<OutR>(self.right.receiver.previous_block_id);
        operator.receivers.push(right_receiver);

        BlockStructure::default().add_operator(operator)
    }
}