use std::sync::{Arc, Barrier};
use lazy_init::Lazy;
use crate::network::{Coord, NetworkReceiver, ReceiverEndpoint};
use crate::operator::iteration::{IterationStateHandle, IterationStateLock, StateFeedback};
use crate::operator::ExchangeData;
use crate::scheduler::{BlockId, ExecutionMetadata};
use super::IterationResult;
/// Per-replica bookkeeping used to receive a new iteration state and publish it to the
/// replicas that live on the same host.
///
/// One replica per host (the "local leader") writes the state through `state_ref`; all the
/// local replicas then rendezvous on `state_barrier` before continuing.
#[derive(Debug)]
pub(crate) struct IterationStateHandler<State: ExchangeData> {
    /// Coordinate of this replica. Holds the default value until `setup` overwrites it with
    /// the coordinate found in the execution metadata.
    pub coord: Coord,
    /// Receiver of the state updates. `None` until `setup` runs, and always `None` in a
    /// freshly cloned handler.
    pub new_state_receiver: Option<NetworkReceiver<StateFeedback<State>>>,
    /// Block id that, paired with this replica's coordinate, identifies the endpoint from
    /// which state updates are received.
    pub leader_block_id: BlockId,
    /// Whether this replica has the smallest coordinate among the replicas on its host.
    pub is_local_leader: bool,
    /// Number of replicas running on the same host as this one; it is the party count of
    /// `state_barrier`.
    pub num_local_replicas: usize,
    /// Handle to the iteration state; only the local leader writes through it.
    pub state_ref: IterationStateHandle<State>,
    /// Lazily created barrier, shared between clones of this handler through the `Arc`.
    pub state_barrier: Arc<Lazy<Barrier>>,
    /// Iteration state lock, shared through the `Arc`; released by the local leader once the
    /// barrier has been passed.
    pub state_lock: Arc<IterationStateLock>,
}
impl<State: ExchangeData + Clone> Clone for IterationStateHandler<State> {
    /// Duplicates the handler without its receiver.
    ///
    /// The copy shares the barrier and the state lock with the original (both sit behind an
    /// `Arc`), gets a clone of the state handle, and starts with `new_state_receiver` set to
    /// `None`.
    fn clone(&self) -> Self {
        let Self {
            coord,
            new_state_receiver: _,
            leader_block_id,
            is_local_leader,
            num_local_replicas,
            state_ref,
            state_barrier,
            state_lock,
        } = self;

        Self {
            coord: *coord,
            // The receiver is deliberately left out of the copy.
            new_state_receiver: None,
            leader_block_id: *leader_block_id,
            is_local_leader: *is_local_leader,
            num_local_replicas: *num_local_replicas,
            state_ref: state_ref.clone(),
            state_barrier: Arc::clone(state_barrier),
            state_lock: Arc::clone(state_lock),
        }
    }
}
/// Picks the leader among `replicas`: the one with the smallest coordinate.
///
/// # Panics
///
/// Panics if `replicas` is empty.
fn select_leader(replicas: &[Coord]) -> Coord {
    let smallest = replicas.iter().copied().min();
    smallest.unwrap()
}
impl<State: ExchangeData> IterationStateHandler<State> {
    /// Builds a handler that is not yet bound to a replica.
    ///
    /// `coord`, `is_local_leader`, `num_local_replicas` and `new_state_receiver` hold
    /// placeholder values until [`setup`](Self::setup) is called.
    pub(crate) fn new(
        leader_block_id: BlockId,
        state_ref: IterationStateHandle<State>,
        state_lock: Arc<IterationStateLock>,
    ) -> Self {
        Self {
            coord: Default::default(),
            is_local_leader: false,
            num_local_replicas: 0,
            new_state_receiver: None,
            leader_block_id,
            state_ref,
            state_barrier: Arc::new(Default::default()),
            state_lock,
        }
    }

    /// Binds the handler to the replica described by `metadata`.
    ///
    /// Records the replica's coordinate, counts the replicas that share its host, decides
    /// whether this replica is the local leader (smallest coordinate on the host) and obtains
    /// the receiver for the endpoint `(metadata.coord, leader_block_id)`.
    ///
    /// # Panics
    ///
    /// Panics if no replica in `metadata.replicas` has the same host as `metadata.coord`,
    /// since the leader is selected from an empty list in that case.
    pub(crate) fn setup(&mut self, metadata: &mut ExecutionMetadata) {
        let coord = metadata.coord;

        // Borrow the replica list and copy out only the coordinates instead of cloning the
        // whole collection just to filter it.
        let local_replicas: Vec<_> = metadata
            .replicas
            .iter()
            .copied()
            .filter(|replica| replica.host_id == coord.host_id)
            .collect();

        self.is_local_leader = select_leader(&local_replicas) == coord;
        self.num_local_replicas = local_replicas.len();
        self.coord = coord;

        let endpoint = ReceiverEndpoint::new(coord, self.leader_block_id);
        self.new_state_receiver = Some(metadata.network.get_receiver(endpoint));
    }

    /// Acquires the shared iteration state lock.
    pub(crate) fn lock(&self) {
        self.state_lock.lock();
    }

    /// Returns the receiver of state updates, or `None` if [`setup`](Self::setup) has not
    /// been called on this handler.
    pub(crate) fn state_receiver(&self) -> Option<&NetworkReceiver<StateFeedback<State>>> {
        self.new_state_receiver.as_ref()
    }

    /// Publishes a state update on this host and returns whether the iteration should go on.
    ///
    /// The local leader stores the new state, then every local replica waits on the shared
    /// barrier (created on first use with `num_local_replicas` parties). After the barrier
    /// the local leader releases the state lock. The returned value is the first component
    /// of `state_update`.
    ///
    /// This call blocks until all the local replicas have reached the barrier.
    pub(crate) fn wait_sync_state(
        &mut self,
        state_update: StateFeedback<State>,
    ) -> IterationResult {
        let (should_continue, new_state) = state_update;

        // The state is written once per host, by the local leader only.
        if self.is_local_leader {
            // SAFETY: NOTE(review): the contract of `IterationStateHandle::set` is not visible
            // from this file. Presumably this is sound because only the local leader writes and
            // the other local replicas cannot get past the barrier below before the write has
            // completed — confirm against the documentation of `set`.
            unsafe {
                self.state_ref.set(new_state);
            }
        }

        // No local replica proceeds until the leader has stored the new state.
        self.state_barrier
            .get_or_create(|| Barrier::new(self.num_local_replicas))
            .wait();

        if self.is_local_leader {
            self.state_lock.unlock();
        }

        should_continue
    }
}