use std::rc::Rc;
use indexmap::{IndexMap, IndexSet};
use itertools::Itertools;
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::errorhandling::MalstromFatal;
use crate::runtime::communication::Distributable;
use crate::runtime::{BiCommunicationClient, CommunicationClient, OperatorOperatorComm};
use crate::snapshot::{deserialize_state, PersistenceClient};
use crate::types::{OperatorId, WorkerId};
/// Context handed to an operator, identifying where it runs and giving it
/// access to the operator-to-operator communication backend.
pub struct OperatorContext<'a> {
    /// Worker this operator instance runs on.
    pub worker_id: WorkerId,
    /// Identifier of the operator owning this context.
    pub operator_id: OperatorId,
    /// Backend used to open communication clients to other workers.
    pub(super) communication: &'a mut dyn OperatorOperatorComm,
}

// The constructor is test-only. Keeping it in its own `cfg(test)` impl lets the
// always-compiled impl below elide the lifetime instead of silencing
// `clippy::needless_lifetimes`.
#[cfg(test)]
impl<'a> OperatorContext<'a> {
    /// Assembles a context directly from its parts (test helper).
    pub(crate) fn new(
        worker_id: WorkerId,
        operator_id: OperatorId,
        communication: &'a mut dyn OperatorOperatorComm,
    ) -> Self {
        OperatorContext {
            communication,
            operator_id,
            worker_id,
        }
    }
}

impl OperatorContext<'_> {
    /// Opens a bidirectional communication client towards `other_worker`,
    /// scoped to this context's `operator_id`.
    ///
    /// # Panics
    ///
    /// Panics if `other_worker` is this context's own worker. An error while
    /// constructing the client is passed to `MalstromFatal::malstrom_fatal`
    /// (presumably a fatal abort — confirm against `errorhandling`).
    pub fn create_communication_client<T: Distributable>(
        &self,
        other_worker: WorkerId,
    ) -> BiCommunicationClient<T> {
        assert!(other_worker != self.worker_id);
        let client = BiCommunicationClient::new(other_worker, self.operator_id, self.communication);
        client.malstrom_fatal()
    }
}
/// Context handed to an operator while it is being built.
///
/// Gives access to previously persisted operator state, the set of workers
/// known to this context, and the communication backend for opening clients
/// to other workers.
pub struct BuildContext<'a> {
    /// Worker on which the operator is being built.
    pub worker_id: WorkerId,
    /// Identifier of the operator being built; also the key used when loading
    /// persisted state.
    pub operator_id: OperatorId,
    /// Name of the operator.
    pub operator_name: String,
    /// Backend from which persisted operator state is loaded.
    persistence_backend: Rc<dyn PersistenceClient>,
    /// Backend used to open communication clients to other workers.
    pub(crate) communication: &'a mut dyn OperatorOperatorComm,
    /// Worker ids known to this context; may include `worker_id` itself.
    worker_ids: IndexSet<WorkerId>,
}

impl<'a> BuildContext<'a> {
    /// Assembles a build context from its parts. `name` becomes
    /// `operator_name`.
    pub(crate) fn new(
        worker_id: WorkerId,
        operator_id: OperatorId,
        name: String,
        persistence_backend: Rc<dyn PersistenceClient>,
        communication: &'a mut dyn OperatorOperatorComm,
        worker_ids: IndexSet<WorkerId>,
    ) -> Self {
        Self {
            worker_id,
            operator_id,
            operator_name: name,
            persistence_backend,
            communication,
            worker_ids,
        }
    }

    /// Loads the persisted state stored for this context's `operator_id`.
    ///
    /// Returns `None` when the persistence backend has nothing for this
    /// operator. Otherwise the stored value is decoded with
    /// `deserialize_state`; see that function for how decoding failures are
    /// handled.
    pub fn load_state<S: Serialize + DeserializeOwned>(&self) -> Option<S> {
        self.persistence_backend
            .load(&self.operator_id)
            .map(deserialize_state)
    }

    /// Returns the worker ids known to this context, in insertion order.
    pub fn get_worker_ids(&self) -> &IndexSet<WorkerId> {
        &self.worker_ids
    }

    /// Opens a bidirectional communication client towards `other_worker`,
    /// scoped to this context's `operator_id`.
    ///
    /// # Panics
    ///
    /// Panics if `other_worker` is this context's own worker. This is the same
    /// precondition `OperatorContext::create_communication_client` enforces.
    /// An error while constructing the client is passed to
    /// `MalstromFatal::malstrom_fatal` (presumably a fatal abort — confirm).
    pub fn create_communication_client<T: Distributable>(
        &mut self,
        other_worker: WorkerId,
    ) -> BiCommunicationClient<T> {
        assert!(other_worker != self.worker_id);
        CommunicationClient::new(other_worker, self.operator_id, self.communication)
            .malstrom_fatal()
    }

    /// Opens one communication client for every known worker other than this
    /// one, keyed by worker id in the order of `worker_ids`.
    pub fn create_all_communication_clients<T: Distributable>(
        &mut self,
    ) -> IndexMap<WorkerId, BiCommunicationClient<T>> {
        // Collect the peer ids up front: `create_communication_client` needs
        // `&mut self`, which cannot coexist with a live borrow of `worker_ids`.
        let other_workers = self
            .worker_ids
            .iter()
            .filter(|wid| **wid != self.worker_id)
            .cloned()
            .collect_vec();
        other_workers
            .into_iter()
            .map(|wid| (wid, self.create_communication_client(wid)))
            .collect()
    }
}