use std::{collections::HashMap, sync::Arc};
use tokio::{
sync::{broadcast, mpsc},
task::JoinHandle,
};
use crate::OperatorId;
use super::{lattice::ExecutionLattice, operator_executors::OperatorExecutorT};
/// Notification received by event runners on the events broadcast channel.
///
/// The worker hands a clone of the sending side of that channel to every
/// operator executor it spawns.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum EventNotification {
/// Events are available for the operator with the given ID. On receipt, an
/// event runner processes that operator's lattice first and then sweeps the
/// lattices it knows about.
/// NOTE(review): presumably sent by an operator executor after it inserts
/// events into its lattice -- confirm against the executor implementation.
AddedEvents(OperatorId),
}
/// Control message broadcast by the worker to its event runners.
#[derive(Clone, Debug)]
enum EventRunnerNotification {
/// Replaces the event runner's map of operator IDs to execution lattices
/// with the provided one.
UpdateLattices(Arc<HashMap<OperatorId, Arc<ExecutionLattice>>>),
/// Makes the event runner return from its loop.
Shutdown,
}
/// Control message broadcast by the worker to the operator executors it spawned.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum OperatorExecutorNotification {
/// Broadcast when the worker shuts down, before it awaits the operator
/// executor tasks.
/// NOTE(review): executors are presumably expected to finish `execute` on
/// receipt -- confirm against the executor implementation.
Shutdown,
}
/// Notification received by the worker on its unbounded notification channel.
///
/// The worker passes a clone of the sending side to every operator executor
/// it spawns.
#[derive(Clone, Debug)]
pub(crate) enum WorkerNotification {
/// The operator with the given ID was destroyed; the worker reacts by
/// removing and awaiting the task that ran its executor.
DestroyedOperator(OperatorId),
}
/// Drains `lattice`: pulls events one at a time, runs each event's callback,
/// and reports the event as completed. Returns once the lattice has no event
/// to hand out.
async fn process_events(lattice: &ExecutionLattice) {
    loop {
        let (runnable, runnable_id) = match lattice.get_event().await {
            Some(next) => next,
            // Nothing left to run right now.
            None => break,
        };
        // Run the callback to completion before marking the event as done.
        (runnable.callback)();
        lattice.mark_as_completed(runnable_id).await;
    }
}
/// Returns an iterator that visits every lattice in `lattices` exactly once,
/// starting from a position rotated by `offset`.
///
/// The rotation discards the first `offset + 1` entries of the endlessly
/// repeated value sequence, so callers passing different offsets begin their
/// sweep at different lattices. An empty map yields an empty iterator.
fn get_lattices_iterator(
    offset: usize,
    lattices: &HashMap<OperatorId, Arc<ExecutionLattice>>,
) -> impl Iterator<Item = &Arc<ExecutionLattice>> {
    let total = lattices.len();
    lattices
        .values()
        .cycle()
        // Two skips rather than `skip(offset + 1)` so the addition cannot overflow.
        .skip(offset)
        .skip(1)
        .take(total)
}
/// Main loop of a single event runner task.
///
/// The runner waits on two broadcast channels at once:
///
/// * `events_channel`: on `AddedEvents(operator_id)` it first drains the
///   lattice of that operator (if known) and then sweeps the lattices returned
///   by `get_lattices_iterator(id, ..)`, so that runners with different `id`s
///   start their sweep at different lattices.
/// * `control_channel`: `UpdateLattices` swaps in a new lattice map and
///   `Shutdown` terminates the runner.
///
/// A `Lagged` error on either channel is ignored (the missed messages are
/// dropped). Any other receive error, i.e. the channel being closed, terminates
/// the runner.
///
/// # Arguments
///
/// * `id` - Index of this runner; used for logging and as the sweep offset.
/// * `lattices` - Initial map from operator ID to execution lattice.
/// * `events_channel` - Receiver for event notifications.
/// * `control_channel` - Receiver for control messages from the worker.
async fn event_runner(
    id: usize,
    mut lattices: Arc<HashMap<OperatorId, Arc<ExecutionLattice>>>,
    mut events_channel: broadcast::Receiver<EventNotification>,
    mut control_channel: broadcast::Receiver<EventRunnerNotification>,
) {
    tracing::debug!("Worker: started event runner {}", id);
    loop {
        tokio::select! {
            event_result = events_channel.recv() => {
                match event_result {
                    Ok(EventNotification::AddedEvents(operator_id)) => {
                        // Prioritize the operator that announced new events.
                        if let Some(lattice) = lattices.get(&operator_id) {
                            process_events(lattice).await;
                        }
                        // Then sweep the lattices, starting at a per-runner offset.
                        for lattice in get_lattices_iterator(id, &lattices) {
                            process_events(lattice).await;
                        }
                    }
                    // Missed notifications are dropped; later ones trigger a full sweep.
                    Err(broadcast::error::RecvError::Lagged(_)) => (),
                    Err(e) => {
                        tracing::error!("Event runner {}: shutting down due to error {:?}", id, e);
                        return;
                    }
                }
            },
            control_result = control_channel.recv() => {
                match control_result {
                    Ok(EventRunnerNotification::UpdateLattices(updated_lattices)) => {
                        lattices = updated_lattices;
                    }
                    Ok(EventRunnerNotification::Shutdown) => {
                        tracing::debug!("Event runner {}: shutting down", id);
                        return;
                    }
                    Err(broadcast::error::RecvError::Lagged(_)) => (),
                    Err(e) => {
                        tracing::error!("Event runner {}: shutting down due to error {:?}", id, e);
                        // A closed channel makes `recv` resolve immediately on every
                        // call; without returning here the loop would spin forever.
                        return;
                    }
                }
            }
        };
    }
}
/// Owns the operator executor tasks and event runner tasks, together with the
/// channels used to coordinate them.
pub(crate) struct Worker {
/// Number of event runner tasks started by `spawn_tasks`.
num_event_runners: usize,
/// Execution lattice of every operator spawned so far, keyed by operator ID.
lattices: HashMap<OperatorId, Arc<ExecutionLattice>>,
/// Shared snapshot of `lattices`; rebuilt each time an operator is spawned
/// and handed to the event runners.
lattices_arc: Arc<HashMap<OperatorId, Arc<ExecutionLattice>>>,
/// Join handles of the spawned operator executor tasks, keyed by operator
/// ID. An entry is removed and awaited when its operator is destroyed.
operator_executor_tasks: HashMap<OperatorId, JoinHandle<()>>,
/// Join handles of the spawned event runner tasks; drained and awaited
/// during shutdown.
event_runner_tasks: Vec<JoinHandle<()>>,
/// Sending side of the channel on which event runners receive
/// `EventNotification`s; a clone is passed to every operator executor.
events_channel: broadcast::Sender<EventNotification>,
/// Broadcasts control messages (lattice updates, shutdown) to event runners.
event_runner_notifications: broadcast::Sender<EventRunnerNotification>,
/// Broadcasts the shutdown notification to operator executors.
operator_executor_notifications: broadcast::Sender<OperatorExecutorNotification>,
/// Sending side of the worker notification channel; a clone is passed to
/// every operator executor.
worker_notifications_tx: mpsc::UnboundedSender<WorkerNotification>,
/// Receiving side of the worker notification channel, read by `execute`.
worker_notifications_rx: mpsc::UnboundedReceiver<WorkerNotification>,
}
impl Worker {
    /// Creates a worker configured to run `num_event_runners` event runners.
    ///
    /// Only the channels are created here; no task is spawned until
    /// `spawn_tasks` is called.
    pub fn new(num_event_runners: usize) -> Self {
        // The initial receivers are discarded: tasks subscribe to the senders
        // when they are spawned.
        let (events_channel, _) = broadcast::channel(16);
        let (event_runner_notifications, _) = broadcast::channel(1);
        let (operator_executor_notifications, _) = broadcast::channel(1);
        let (worker_notifications_tx, worker_notifications_rx) = mpsc::unbounded_channel();
        Self {
            num_event_runners,
            lattices: HashMap::new(),
            lattices_arc: Arc::new(HashMap::new()),
            operator_executor_tasks: HashMap::new(),
            event_runner_tasks: Vec::new(),
            events_channel,
            event_runner_notifications,
            operator_executor_notifications,
            worker_notifications_tx,
            worker_notifications_rx,
        }
    }

    /// Spawns the configured number of event runners, then one task per
    /// operator executor in `operator_executors`.
    ///
    /// Event runners are spawned first so that they are subscribed to the
    /// control channel before the lattice updates triggered by spawning the
    /// operators are broadcast.
    pub async fn spawn_tasks(&mut self, operator_executors: Vec<Box<dyn OperatorExecutorT>>) {
        for i in 0..self.num_event_runners {
            self.spawn_event_runner(i).await;
        }
        for operator_executor in operator_executors {
            self.spawn_operator(operator_executor).await;
        }
    }

    /// Handles worker notifications until the notification channel yields
    /// `None`, then shuts the worker down.
    ///
    /// NOTE(review): the worker itself keeps a sender for this channel, so the
    /// loop presumably does not end on its own while `self` is alive -- confirm
    /// how callers are expected to stop it.
    pub async fn execute(&mut self) {
        while let Some(notification) = self.worker_notifications_rx.recv().await {
            match notification {
                WorkerNotification::DestroyedOperator(operator_id) => {
                    self.on_destroyed_operator(operator_id).await
                }
            }
        }
        self.shutdown().await;
    }

    /// Stops all operator executors and then all event runners, awaiting each
    /// task.
    ///
    /// Broadcasting a notification fails when nobody is subscribed (for example
    /// when every task has already finished, or none was ever spawned). That is
    /// not an error during shutdown, so it is only logged.
    async fn shutdown(&mut self) {
        tracing::info!("Worker: shutting down");
        if self
            .operator_executor_notifications
            .send(OperatorExecutorNotification::Shutdown)
            .is_err()
        {
            tracing::debug!(
                "[Worker] no operator executor is subscribed to the shutdown notification"
            );
        }
        // Collect the IDs first: `on_destroyed_operator` mutates the task map.
        let operator_ids: Vec<_> = self.operator_executor_tasks.keys().cloned().collect();
        for operator_id in operator_ids {
            self.on_destroyed_operator(operator_id).await;
        }
        tracing::debug!("[Worker] shut down all operator executors");

        if self
            .event_runner_notifications
            .send(EventRunnerNotification::Shutdown)
            .is_err()
        {
            tracing::debug!("[Worker] no event runner is subscribed to the shutdown notification");
        }
        for (i, event_runner_task) in self.event_runner_tasks.drain(..).enumerate() {
            match event_runner_task.await {
                Ok(_) => (),
                Err(e) => tracing::error!(
                    "[Worker] shutting down event runner {} errored with {:?}",
                    i,
                    e
                ),
            }
        }
        tracing::debug!("[Worker] shut down all event runners");
        tracing::info!("[Worker] finished shutting down");
    }

    /// Registers the operator's lattice, publishes the updated lattice map to
    /// the event runners, and spawns a task that runs the operator executor.
    async fn spawn_operator(&mut self, mut operator_executor: Box<dyn OperatorExecutorT>) {
        let operator_id = operator_executor.operator_id();
        tracing::debug!("Worker: spawning operator with ID {}", operator_id);
        self.lattices
            .insert(operator_id, operator_executor.lattice());
        // Publish a fresh immutable snapshot rather than mutating the shared map.
        self.lattices_arc = Arc::new(self.lattices.clone());
        let update = EventRunnerNotification::UpdateLattices(Arc::clone(&self.lattices_arc));
        if self.event_runner_notifications.send(update).is_err() {
            // Sending fails only when no event runner is subscribed. Runners
            // spawned later start from `lattices_arc`, so nothing is lost.
            tracing::warn!(
                "Worker: no event runner is subscribed to receive the lattice of operator {}",
                operator_id
            );
        }
        let channel_from_worker = self.operator_executor_notifications.subscribe();
        let channel_to_worker = self.worker_notifications_tx.clone();
        let channel_to_event_runners = self.events_channel.clone();
        let task = tokio::task::spawn(async move {
            operator_executor
                .execute(
                    channel_from_worker,
                    channel_to_worker,
                    channel_to_event_runners,
                )
                .await;
        });
        self.operator_executor_tasks.insert(operator_id, task);
    }

    /// Removes the task of the given operator, if any, and waits for it to
    /// finish, logging the outcome. Unknown operator IDs are ignored.
    async fn on_destroyed_operator(&mut self, operator_id: OperatorId) {
        if let Some(task) = self.operator_executor_tasks.remove(&operator_id) {
            match task.await {
                Ok(_) => tracing::debug!(
                    "Worker: shut down task for operator executor with ID {}",
                    operator_id
                ),
                Err(e) => tracing::error!(
                    "Worker: error duing shut down of task for operator executor with ID {}: {:?}",
                    operator_id,
                    e
                ),
            }
        }
    }

    /// Spawns an event runner task with index `id`.
    ///
    /// Both channel subscriptions are created before the task is spawned, so
    /// the runner receives every message sent after this method returns.
    async fn spawn_event_runner(&mut self, id: usize) {
        let events_channel = self.events_channel.subscribe();
        let control_channel = self.event_runner_notifications.subscribe();
        let lattices = Arc::clone(&self.lattices_arc);
        let task = tokio::task::spawn(async move {
            event_runner(id, lattices, events_channel, control_channel).await;
        });
        self.event_runner_tasks.push(task);
    }
}