erdos 0.4.0

ERDOS is a platform for developing self-driving cars and robotics applications.
Documentation
use std::{collections::HashMap, sync::Arc};

use tokio::{
    sync::{broadcast, mpsc},
    task::JoinHandle,
};

use crate::OperatorId;

use super::{lattice::ExecutionLattice, operator_executors::OperatorExecutorT};

/// Notification broadcast to the event runners about new work in a lattice.
///
/// Operator executors are handed a sender for these notifications when they are spawned.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum EventNotification {
    /// The operator with this ID inserted new events into its lattice.
    AddedEvents(OperatorId),
}

/// Control message the worker broadcasts to its event runners.
#[derive(Debug, Clone)]
enum EventRunnerNotification {
    /// Replaces the runner's snapshot of the per-operator lattices.
    UpdateLattices(Arc<HashMap<OperatorId, Arc<ExecutionLattice>>>),
    /// Asks the runner to exit.
    Shutdown,
}

/// Control message the worker broadcasts to operator executors.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum OperatorExecutorNotification {
    /// Asks every operator executor to shut down.
    Shutdown,
}

/// Message sent to the worker by the tasks it manages.
///
/// Each operator executor receives a sender for these messages when it is spawned.
#[derive(Debug, Clone)]
pub(crate) enum WorkerNotification {
    /// The operator executor with this ID was destroyed; the worker then awaits its task.
    DestroyedOperator(OperatorId),
}
/// Drains `lattice`: takes each event handed out by `get_event`, runs its callback, and reports
/// it back through `mark_as_completed`, stopping as soon as `get_event` yields `None`.
async fn process_events(lattice: &ExecutionLattice) {
    loop {
        match lattice.get_event().await {
            Some((event, event_id)) => {
                (event.callback)();
                lattice.mark_as_completed(event_id).await;
            }
            None => break,
        }
    }
}

/// Returns an iterator over the values of `lattices` that starts at position `offset`
/// (taken modulo the number of entries) and yields every value exactly once, wrapping around.
///
/// Event runners pass their own id as `offset` so that concurrent runners begin their sweeps
/// at different lattices. The order is that of `HashMap::values`, i.e. arbitrary. An empty map
/// yields nothing.
fn get_lattices_iterator<K, V>(
    offset: usize,
    lattices: &HashMap<K, V>,
) -> impl Iterator<Item = &V> {
    let len = lattices.len();
    // `skip(start)` starts exactly at `offset` (the previous `nth(offset)` consumed one element
    // too many). Reducing modulo `len` avoids walking the cycle repeatedly for large offsets;
    // the guard avoids `% 0` on an empty map.
    let start = if len == 0 { 0 } else { offset % len };
    lattices.values().cycle().skip(start).take(len)
}

/// An `event_runner` invocation is in charge of executing callbacks associated with an event.
/// Upon receipt of an `AddedEvents` notification, it queries the lattice for events that are
/// ready to run, executes them, and notifies the lattice of their completion.
///
/// After draining the lattice of the operator named in the notification, the runner sweeps
/// every lattice in its current snapshot once, starting at a position derived from `id`, so
/// that other operators' events are not starved.
///
/// * `id` - identifier of this runner, used in logs and as the sweep's starting offset.
/// * `lattices` - initial snapshot of the per-operator lattices; replaced on `UpdateLattices`.
/// * `events_channel` - receives `AddedEvents` notifications.
/// * `control_channel` - receives lattice updates and shutdown requests.
///
/// The runner returns on `EventRunnerNotification::Shutdown`, or when either channel reports
/// an error other than `Lagged`.
async fn event_runner(
    id: usize,
    mut lattices: Arc<HashMap<OperatorId, Arc<ExecutionLattice>>>,
    mut events_channel: broadcast::Receiver<EventNotification>,
    mut control_channel: broadcast::Receiver<EventRunnerNotification>,
) {
    tracing::debug!("Worker: started event runner {}", id);

    loop {
        tokio::select! {
            event_result = events_channel.recv() => {
                match event_result {
                    Ok(EventNotification::AddedEvents(operator_id)) => {
                        if let Some(lattice) = lattices.get(&operator_id) {
                            process_events(lattice).await;
                        }
                        // Iterate through remaining lattices to avoid starvation.
                        for lattice in get_lattices_iterator(id, &lattices) {
                            process_events(lattice).await;
                        }
                    }
                    // `Lagged` means newer notifications are still buffered; handling any of
                    // them sweeps every lattice, so the skipped ones need no special handling.
                    Err(broadcast::error::RecvError::Lagged(_)) => (),
                    Err(e) => {
                        tracing::error!("Event runner {}: shutting down due to error {:?}", id, e);
                        return;
                    }
                }
            },
            control_result = control_channel.recv() => {
                match control_result {
                    Ok(EventRunnerNotification::UpdateLattices(updated_lattices)) => {
                        lattices = updated_lattices;
                    }
                    Ok(EventRunnerNotification::Shutdown) => {
                        tracing::debug!("Event runner {}: shutting down", id);
                        return;
                    }
                    // Each `UpdateLattices` carries a full snapshot, so skipping stale control
                    // messages is harmless; the next `recv` yields the newest retained one.
                    Err(broadcast::error::RecvError::Lagged(_)) => (),
                    Err(e) => {
                        // The control channel is closed: every subsequent `recv` fails
                        // immediately, so continuing would spin this loop and flood the logs.
                        tracing::error!("Event runner {}: shutting down due to error {:?}", id, e);
                        return;
                    }
                }
            }
        };
    }
}

/// Runs operator executor tasks together with the event runner tasks that execute the events
/// those operators insert into their lattices.
pub(crate) struct Worker {
    /// Number of tasks that execute events generated by operators.
    num_event_runners: usize,
    /// Lattices of events for each operator.
    lattices: HashMap<OperatorId, Arc<ExecutionLattice>>,
    /// Shared snapshot of `lattices` handed to the event runners; rebuilt whenever an operator
    /// is spawned.
    lattices_arc: Arc<HashMap<OperatorId, Arc<ExecutionLattice>>>,
    /// Tasks that manage the execution of operators, which generate events and insert them
    /// into the lattices.
    operator_executor_tasks: HashMap<OperatorId, JoinHandle<()>>,
    /// Tasks which process events generated by operators.
    event_runner_tasks: Vec<JoinHandle<()>>,
    /// Notifies event runners that new events have been inserted.
    events_channel: broadcast::Sender<EventNotification>,
    /// Notifies event runners of changes to the lattices or to shut down.
    event_runner_notifications: broadcast::Sender<EventRunnerNotification>,
    /// Notifies operator executors to shut down.
    operator_executor_notifications: broadcast::Sender<OperatorExecutorNotification>,
    /// Sending half of the channel used to notify the worker (e.g. that an operator has been
    /// destroyed); a clone is given to each spawned operator executor.
    worker_notifications_tx: mpsc::UnboundedSender<WorkerNotification>,
    /// Receiving half of the worker notification channel, drained by `execute`.
    worker_notifications_rx: mpsc::UnboundedReceiver<WorkerNotification>,
}

impl Worker {
    /// Creates a worker that will run `num_event_runners` event runner tasks.
    ///
    /// No tasks are started here; call [`Worker::spawn_tasks`] to start them.
    pub fn new(num_event_runners: usize) -> Self {
        let (events_channel, _) = broadcast::channel(16);
        // Only need to store most recent update to lattices or shutdown.
        let (event_runner_notifications, _) = broadcast::channel(1);
        // Only need to store shutdown.
        let (operator_executor_notifications, _) = broadcast::channel(1);
        // All updates are important.
        let (worker_notifications_tx, worker_notifications_rx) = mpsc::unbounded_channel();
        Self {
            num_event_runners,
            lattices: HashMap::new(),
            lattices_arc: Arc::new(HashMap::new()),
            operator_executor_tasks: HashMap::new(),
            event_runner_tasks: Vec::new(),
            events_channel,
            event_runner_notifications,
            operator_executor_notifications,
            worker_notifications_tx,
            worker_notifications_rx,
        }
    }

    /// Spawns the event runners first, then one task per operator executor.
    pub async fn spawn_tasks(&mut self, operator_executors: Vec<Box<dyn OperatorExecutorT>>) {
        // Spawn event runners.
        for i in 0..self.num_event_runners {
            self.spawn_event_runner(i).await;
        }
        // Spawn operator executors.
        for operator_executor in operator_executors {
            self.spawn_operator(operator_executor).await;
        }
    }

    /// Handles notifications from operator executors, awaiting the task of each destroyed
    /// operator, and shuts the worker down once the notification channel is closed.
    ///
    /// NOTE(review): `self` owns `worker_notifications_tx`, so the channel cannot close while
    /// the worker is alive and `recv()` never yields `None`. As written, this loop only ends if
    /// the future is dropped, in which case `shutdown` is not reached — confirm this is intended.
    pub async fn execute(&mut self) {
        // Manage destruction of operators.
        // TODO: in the future, scale up/down event runners, spawn new operators.
        while let Some(notification) = self.worker_notifications_rx.recv().await {
            match notification {
                WorkerNotification::DestroyedOperator(operator_id) => {
                    self.on_destroyed_operator(operator_id).await
                }
            }
        }
        self.shutdown().await;
    }

    /// Asks all operator executors to shut down and awaits their tasks, then does the same for
    /// the event runners.
    async fn shutdown(&mut self) {
        tracing::info!("Worker: shutting down");
        // Shutdown operator executors. `send` only fails when there are no live receivers (e.g.
        // no operator was ever spawned, or every executor already dropped its receiver); then
        // there is nobody to notify, so this is not an error and must not panic.
        if self
            .operator_executor_notifications
            .send(OperatorExecutorNotification::Shutdown)
            .is_err()
        {
            tracing::debug!("[Worker] no operator executors left to notify of shutdown");
        }
        let operator_ids: Vec<_> = self.operator_executor_tasks.keys().cloned().collect();
        for operator_id in operator_ids {
            self.on_destroyed_operator(operator_id).await;
        }
        tracing::debug!("[Worker] shut down all operator executors");

        // Shutdown event runners. As above, a send error only means no runner is listening
        // (e.g. `num_event_runners == 0`).
        if self
            .event_runner_notifications
            .send(EventRunnerNotification::Shutdown)
            .is_err()
        {
            tracing::debug!("[Worker] no event runners left to notify of shutdown");
        }
        for (i, event_runner_task) in self.event_runner_tasks.drain(..).enumerate() {
            match event_runner_task.await {
                Ok(_) => (),
                Err(e) => tracing::error!(
                    "[Worker] shutting down event runner {} errored with {:?}",
                    i,
                    e
                ),
            }
        }
        tracing::debug!("[Worker] shut down all event runners");
        tracing::info!("[Worker] finished shutting down");
    }

    /// Publishes the operator's lattice to the event runners and spawns a task that runs the
    /// operator executor with its worker and event-runner channels.
    async fn spawn_operator(&mut self, mut operator_executor: Box<dyn OperatorExecutorT>) {
        let operator_id = operator_executor.operator_id();
        tracing::debug!("Worker: spawning operator with ID {}", operator_id);
        // Get lattice and share with event runners. Runners hold an immutable snapshot, so a
        // fresh copy of the map is published rather than mutating the shared one.
        self.lattices
            .insert(operator_id, operator_executor.lattice());
        self.lattices_arc = Arc::new(self.lattices.clone());
        // Fails only when no event runner is subscribed (e.g. `num_event_runners == 0`).
        // Runners spawned later start from `lattices_arc`, so nothing is lost by ignoring it.
        let update = EventRunnerNotification::UpdateLattices(Arc::clone(&self.lattices_arc));
        if self.event_runner_notifications.send(update).is_err() {
            tracing::debug!(
                "Worker: no event runners to notify of lattice for operator with ID {}",
                operator_id
            );
        }

        let channel_from_worker = self.operator_executor_notifications.subscribe();
        let channel_to_worker = self.worker_notifications_tx.clone();
        let channel_to_event_runners = self.events_channel.clone();
        let task = tokio::task::spawn(async move {
            operator_executor
                .execute(
                    channel_from_worker,
                    channel_to_worker,
                    channel_to_event_runners,
                )
                .await;
        });
        self.operator_executor_tasks.insert(operator_id, task);
    }

    /// Awaits the task of the operator executor with `operator_id`, if it is still tracked, and
    /// logs whether it finished cleanly. Untracked IDs are ignored.
    async fn on_destroyed_operator(&mut self, operator_id: OperatorId) {
        if let Some(task) = self.operator_executor_tasks.remove(&operator_id) {
            match task.await {
                Ok(_) => tracing::debug!(
                    "Worker: shut down task for operator executor with ID {}",
                    operator_id
                ),
                Err(e) => tracing::error!(
                    "Worker: error during shut down of task for operator executor with ID {}: {:?}",
                    operator_id,
                    e
                ),
            }
        }
    }

    /// Spawns event runner `id`, subscribed to the event and control channels and starting
    /// from the current lattice snapshot.
    async fn spawn_event_runner(&mut self, id: usize) {
        let events_channel = self.events_channel.subscribe();
        let control_channel = self.event_runner_notifications.subscribe();
        let lattices = Arc::clone(&self.lattices_arc);
        let task = tokio::task::spawn(async move {
            event_runner(id, lattices, events_channel, control_channel).await;
        });
        self.event_runner_tasks.push(task);
    }
}