// subliminal 0.0.4
//
// Base crate for subliminal microservices project
// Documentation
use crate::message_queue::MessageQueue;
use crate::task::TaskExecutionData;
use log::{debug, error, info};
use std::{borrow::BorrowMut, collections::HashMap, thread, time::Duration};
use tokio::sync::mpsc::{Receiver, Sender};
use ulid::Ulid;

use super::{
    errors::DispatchError,
    worker::{Worker, WorkerState, WorkerStateMachine, WorkerStatusReport},
};

/// Hands incoming tasks to a pool of workers and keeps track of the state each worker last reported.
pub struct TaskDispatcher {
    /// Receiving end of the channel on which workers send their status reports.
    worker_status_rx: Receiver<WorkerStatusReport>,
    /// Task input channel of every worker, keyed by worker id.
    worker_channels: HashMap<Ulid, Sender<TaskExecutionData>>,
    /// Last known state of every worker, together with a clone of its task sender, keyed by worker id.
    worker_states: HashMap<Ulid, WorkerStateMachine>,
    /// Receiving end of the task request channel, i.e. the tasks waiting to be dispatched.
    task_request_rx: Receiver<TaskExecutionData>,
}

// Responsible for dispatching tasks to workers
impl TaskDispatcher {
    /// Starts the dispatcher with the specified number of workers.
    /// As the consumers pop tasks from the message broker, they will place them into the task request channel.
    /// The dispatcher will then pick up the tasks from the task request channel and dispatch them to the workers.
    pub async fn start(
        n_workers: u32,
        task_request_rx: Receiver<TaskExecutionData>,
        message_queue: MessageQueue,
    ) {
        let mut worker_pool = Vec::new();
        let mut worker_channels = HashMap::new();

        // These channels are used by the workers to report their status back to the dispatcher and deliver updates to the message queue
        let (worker_status_tx, worker_status_rx) = tokio::sync::mpsc::channel(1);
        let (update_tx, mut update_rx) = tokio::sync::mpsc::channel(1);

        // Create the workers and assign them a task input channel and a task output channel
        for i in 0..n_workers {
            let (task_tx, task_rx) = tokio::sync::mpsc::channel(1);
            let worker = Worker::new(
                format!("worker-{}", i),
                task_rx,
                worker_status_tx.clone(),
                update_tx.clone(),
            );
            info!("Spawned worker {}", worker.name);
            let worker_id = worker.id();

            // Store the channel used to send tasks to the worker and add the worker to the worker pool
            worker_channels.insert(worker_id, task_tx);
            worker_pool.push(worker);
        }

        // Worker state machine contains the worker state and the channel used to send tasks to it
        let mut worker_states = HashMap::new();
        for (id, tx) in worker_channels.borrow_mut() {
            worker_states.insert(
                id.to_owned(),
                WorkerStateMachine {
                    task_tx: tx.to_owned(),
                    worker_state: WorkerState::Idle,
                },
            );
        }

        // Start the workers in separate threads
        for mut worker in worker_pool {
            tokio::spawn(async move {
                worker.work().await;
            });
        }

        // Start a thread that monitors the task_update_rx channel and updates the message queue
        // TODO: Make sure this drains even if the dispatcher thread dies
        // TODO: Batching?
        tokio::spawn(async move {
            loop {
                match update_rx.try_recv() {
                    Ok(update) => {
                        debug!("Got update: {:?}", update);
                        match message_queue.push_task_execution_update(update).await {
                            Ok(_) => {}
                            Err(e) => {
                                error!("Error updating execution record: {:?}", e);
                            }
                        }
                    }
                    Err(e) => match e {
                        tokio::sync::mpsc::error::TryRecvError::Empty => {
                            continue;
                        }
                        tokio::sync::mpsc::error::TryRecvError::Disconnected => {
                            error!("Task update channel closed!");
                            break;
                        }
                    },
                }
            }
        });

        // Start the dispatcher
        TaskDispatcher {
            worker_status_rx,
            worker_channels,
            worker_states,
            task_request_rx,
        }
        .dispatch()
        .await;
    }

    /// Poll the worker status channel for updates and update the worker states accordingly
    async fn update_worker_states(&mut self) -> Result<(), DispatchError> {
        // Each worker can only be IDLE or RUNNING, so we allow for 2 updates per worker
        for _ in 0..self.worker_channels.len() * 2 {
            match self.worker_status_rx.try_recv() {
                // If there was an update in the channel, update the worker state
                Ok(worker_status_update) => {
                    if let Some(worker_state) = self.worker_states.get_mut(&worker_status_update.id)
                    {
                        worker_state.worker_state = worker_status_update.worker_state;

                    // This should never happen, but we check just in case
                    } else {
                        return Err(DispatchError::WorkerStatusReceiverError(String::from(
                            "Worker state not found in worker state map!",
                        )));
                    }
                }
                // If there was an error encountered, determine what kind of error it was and act accordingly
                Err(e) => match e {
                    tokio::sync::mpsc::error::TryRecvError::Empty => {
                        continue;
                    }
                    tokio::sync::mpsc::error::TryRecvError::Disconnected => {
                        error!("Worker status channel closed!");
                        break;
                    }
                },
            }
        }
        Ok(())
    }

    /// Get the task sender of an idle worker, if there is one.
    ///
    /// The worker states are refreshed first. The worker picked is the first idle one in the
    /// state map's iteration order, which for a `HashMap` is arbitrary.
    ///
    /// Returns `Ok(None)` when no worker is currently recorded as idle.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while refreshing the worker states.
    async fn get_idle_job_transmitter(
        &mut self,
    ) -> Result<Option<Sender<TaskExecutionData>>, DispatchError> {
        // Update the worker states first
        self.update_worker_states().await?;

        // Stop at the first idle worker instead of collecting all of them
        let idle_worker = self
            .worker_states
            .iter()
            .find(|(_, machine)| machine.worker_state == WorkerState::Idle);

        Ok(idle_worker.map(|(worker_id, machine)| {
            debug!("Found idle worker {}", worker_id);
            machine.task_tx.clone()
        }))
    }

    /// Start a thread that monitors the input queue and dispatches tasks to the workers as they become available
    async fn dispatch(mut self) {
        loop {
            // Get an idle worker transmitter and send a task to it.
            // Breaks out of the loop if an IO channel is closed, or errors occurred when getting an idle worker transmitter
            match self.get_idle_job_transmitter().await {
                Ok(Some(transmitter)) => match self.task_request_rx.try_recv() {
                    Ok(task) => {
                        if let Err(e) = transmitter.send(task).await {
                            error!(
                                "Error sending task to worker! {:?} stopping dispatcher...",
                                e
                            );
                            break;
                        }
                    }
                    Err(e) => match e {
                        tokio::sync::mpsc::error::TryRecvError::Empty => {
                            continue;
                        }
                        tokio::sync::mpsc::error::TryRecvError::Disconnected => {
                            error!("Task request channel closed!");
                            break;
                        }
                    },
                },
                Ok(None) => {
                    if let Err(e) = self.update_worker_states().await {
                        error!(
                            "Error updating worker states! {:?} stopping dispatcher...",
                            e
                        );
                        break;
                    }
                }
                Err(e) => {
                    error!("Error getting idle worker! {:?} stopping dispatcher...", e);
                    break;
                }
            }
            thread::sleep(Duration::from_millis(100));
        }
    }
}