use std::{fmt::Debug, future::Future, panic, pin::Pin};
use log::{error, info};
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::mpsc::{Receiver, Sender};
use crate::message_queue::{MessageQueue, MessageQueueBuilder};
use super::consumer::consume;
use super::dispatcher::TaskDispatcher;
use crate::task::{Task, TaskExecutionData};
/// Runtime half of an execution node: it spawns the task consumers and then drives the
/// dispatcher.
///
/// The type carries no state; it only groups [`ExecutionNode::spawn`]. Nodes are assembled and
/// started through `ExecutionNodeBuilder`.
struct ExecutionNode;

impl ExecutionNode {
    /// Detaches every consumer future onto the tokio runtime, then awaits
    /// `TaskDispatcher::start` on the current task until it completes.
    ///
    /// * `task_consumers` - consumer futures to spawn, each as its own tokio task.
    /// * `n_workers` - worker count forwarded unchanged to `TaskDispatcher::start`.
    /// * `message_queue` - queue handed to the dispatcher.
    /// * `task_request_rx` - receiving half of the task request channel, handed to the dispatcher.
    ///
    /// The consumers' `JoinHandle`s are discarded, so nothing here notices a consumer exiting.
    pub async fn spawn(
        task_consumers: Vec<Pin<Box<dyn Future<Output = ()> + Send>>>,
        n_workers: u32,
        message_queue: MessageQueue,
        task_request_rx: Receiver<TaskExecutionData>,
    ) {
        info!("Starting execution node...");

        task_consumers.into_iter().for_each(|consumer| {
            info!("Spawning consumer...");
            // Fire-and-forget: the handle is dropped immediately, detaching the task.
            let _ = tokio::spawn(consumer);
        });

        info!("Spawning dispatcher...");
        TaskDispatcher::start(n_workers, task_request_rx, message_queue).await;
    }
}
/// Builder that assembles an execution node from task consumers and a dispatcher.
///
/// Create it with [`ExecutionNodeBuilder::new`], register consumers with
/// [`ExecutionNodeBuilder::with_consumer`], and start the node with
/// [`ExecutionNodeBuilder::build`].
pub struct ExecutionNodeBuilder {
    /// Consumer futures, kept unstarted until the node is built and they are spawned.
    task_consumers: Vec<Pin<Box<dyn Future<Output = ()> + Send>>>,
    /// Sending half of the task request channel; a clone is handed to each consumer.
    task_request_tx: Sender<TaskExecutionData>,
    /// Receiving half of the task request channel; handed to the dispatcher.
    task_request_rx: Receiver<TaskExecutionData>,
    /// Number of dispatcher workers; the request channel is bounded to twice this value.
    n_workers: u32,
    /// Project id passed to every `MessageQueueBuilder` this builder creates.
    gcp_project_id: String,
    /// Topic id used when building the dispatcher's message queue.
    execution_updates_topic_id: String,
}
impl ExecutionNodeBuilder {
    /// Creates a builder for a node whose dispatcher runs `n_workers` workers.
    ///
    /// * `n_workers` - forwarded to the dispatcher; the task request channel is bounded to
    ///   `n_workers * 2` pending requests.
    /// * `project_id` - project id used for every message queue this builder creates.
    /// * `execution_updates_topic_id` - topic id for the dispatcher's message queue.
    ///
    /// Side effect: installs a process-wide panic hook. The hook runs the previously installed
    /// hook and then calls `std::process::exit(1)`. A panic in any thread or spawned task
    /// therefore terminates the whole process. Calling `new` more than once wraps the hook
    /// again.
    ///
    /// # Panics
    ///
    /// Panics if `n_workers` is 0, because a bounded tokio channel needs a non-zero capacity.
    /// The hook is already installed at that point, so the process then exits with code 1.
    pub async fn new(n_workers: u32, project_id: &str, execution_updates_topic_id: &str) -> Self {
        let orig_hook = panic::take_hook();
        panic::set_hook(Box::new(move |panic_info| {
            orig_hook(panic_info);
            std::process::exit(1);
        }));

        let (task_request_tx, task_request_rx) =
            tokio::sync::mpsc::channel::<TaskExecutionData>(n_workers as usize * 2);

        Self {
            task_consumers: Vec::new(),
            n_workers,
            task_request_tx,
            task_request_rx,
            gcp_project_id: project_id.to_string(),
            execution_updates_topic_id: execution_updates_topic_id.to_string(),
        }
    }

    /// Registers a consumer that reads tasks of type `T` from `subscription_id`.
    ///
    /// A dedicated message queue is built for the subscription. The future returned by
    /// `consume::<T>` receives that queue and a clone of the task request sender. It is stored,
    /// and it is only spawned when [`build`](Self::build) is awaited.
    ///
    /// Logs the error and exits the process with code 1 if the message queue cannot be created.
    pub async fn with_consumer<T>(mut self, subscription_id: &str) -> Self
    where
        T: DeserializeOwned + Serialize + Send + Sync + Debug + Clone + Task + 'static,
    {
        let consumer_message_queue = match MessageQueueBuilder::new()
            .with_project_id(&self.gcp_project_id)
            .with_incoming_executions_subscription_id(subscription_id)
            .build()
            .await
        {
            Ok(queue) => queue,
            Err(e) => {
                error!("Error creating consumer message queue: {:?}", e);
                std::process::exit(1);
            }
        };
        self.task_consumers.push(Box::pin(consume::<T>(
            consumer_message_queue,
            self.task_request_tx.clone(),
        )));
        self
    }

    /// Builds the dispatcher's message queue, spawns every registered consumer, and then
    /// awaits the dispatcher. The returned future runs for as long as `TaskDispatcher::start`
    /// does.
    ///
    /// After this call, the only senders on the task request channel are the ones owned by the
    /// consumers. Once all consumers have finished, the dispatcher's receiver observes the
    /// channel as closed. If no consumer was registered, the channel is closed from the start.
    ///
    /// Logs the error and exits the process with code 1 if the dispatcher's message queue
    /// cannot be created.
    pub async fn build(self) {
        // Take every field by value. Leaving `task_request_tx` inside a partially moved `self`
        // would keep it alive until this function returns, which is after the dispatcher ends.
        let Self {
            task_consumers,
            task_request_tx,
            task_request_rx,
            n_workers,
            gcp_project_id,
            execution_updates_topic_id,
        } = self;

        // A tokio mpsc receiver only sees `None` once every `Sender` is dropped. Release the
        // builder's own handle so that closure is driven purely by the consumers.
        drop(task_request_tx);

        let dispatcher_message_queue = match MessageQueueBuilder::new()
            .with_project_id(&gcp_project_id)
            .with_execution_updates_topic_id(&execution_updates_topic_id)
            .build()
            .await
        {
            Ok(queue) => queue,
            Err(e) => {
                error!("Error creating message queue for dispatch: {:?}", e);
                std::process::exit(1);
            }
        };

        ExecutionNode::spawn(
            task_consumers,
            n_workers,
            dispatcher_message_queue,
            task_request_rx,
        )
        .await;
    }
}