use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::mpsc;
use crate::clients::grpc::v0::dispatcher;
use crate::clients::grpc::v0::dispatcher::WorkerRegisterRequest;
use crate::clients::hatchet::Hatchet;
use crate::error::HatchetError;
use crate::runnables::*;
use crate::worker::action_listener::ActionListener;
/// A Hatchet worker holding the workflow definitions and executable tasks it serves.
///
/// Constructed through the `derive_builder`-generated `WorkerBuilder` (owned pattern).
/// `name` and `client` must be supplied; every other field falls back to a default.
#[derive(derive_builder::Builder)]
#[builder(pattern = "owned")]
pub struct Worker {
    /// Name the worker registers under.
    pub name: String,
    /// Client used for registration, heartbeats and task dispatch.
    client: Hatchet,
    /// Slot count sent at registration; also sizes the action channel. Defaults to 100.
    #[builder(default = 100)]
    slots: i32,
    /// Executable tasks keyed by `"<workflow name>:<task name>"`. Defaults to empty.
    #[builder(default)]
    tasks: Arc<Mutex<HashMap<String, Arc<dyn ExecutableTask>>>>,
    /// Workflow definitions registered with the admin API on start. Defaults to empty.
    #[builder(default)]
    workflows: Vec<crate::clients::grpc::v1::workflows::CreateWorkflowVersionRequest>,
    /// String labels attached to the worker at registration. Defaults to empty.
    #[builder(default)]
    labels: HashMap<String, String>,
}
impl Worker {
async fn register_workflows(&mut self) {
for workflow in &self.workflows {
self.client
.admin_client
.put_workflow(workflow.clone())
.await
.unwrap();
}
}
pub async fn start(&mut self) -> Result<(), HatchetError> {
log::info!("STARTING HATCHET...");
let mut actions = vec![];
for workflow in &self.workflows {
for task in &workflow.tasks {
actions.push(task.action.clone());
}
}
log::debug!("{} waiting for actions: {:?}", self.name, actions);
let worker_id = Arc::new(
Self::register_worker(
&mut self.client,
&self.name,
actions,
self.slots,
self.labels.clone(),
)
.await?,
);
self.register_workflows().await;
let (action_tx, mut action_rx) =
mpsc::channel::<dispatcher::AssignedAction>(self.slots as usize);
let dispatcher = Arc::new(tokio::sync::Mutex::new(
crate::worker::task_dispatcher::TaskDispatcher {
registry: self.tasks.clone(),
client: self.client.clone(),
task_runs: Arc::new(Mutex::new(HashMap::new())),
},
));
let action_listener = Arc::new(tokio::sync::Mutex::new(ActionListener::new(
self.client.clone(),
)));
let worker_id_clone = worker_id.clone();
tokio::spawn(async move {
log::debug!("starting action listener");
action_listener
.lock()
.await
.listen(worker_id_clone, action_tx)
.await as Result<(), HatchetError>
});
tokio::try_join!(
async {
const HEARTBEAT_INTERVAL: u64 = 4;
loop {
log::debug!("sending heartbeat");
self.client.dispatcher_client.heartbeat(&worker_id).await?;
tokio::time::sleep(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL)).await;
}
#[allow(unreachable_code)]
Ok::<(), HatchetError>(())
},
async {
while let Some(task) = action_rx.recv().await {
dispatcher
.lock()
.await
.dispatch(worker_id.clone(), task)
.await?
}
Ok(())
}
)?;
Ok(())
}
async fn register_worker(
client: &mut Hatchet,
name: &str,
actions: Vec<String>,
slots: i32,
labels: HashMap<String, String>,
) -> Result<String, HatchetError> {
let registration = WorkerRegisterRequest {
worker_name: name.to_string(),
actions,
services: vec![],
slots: Some(slots),
labels: labels
.into_iter()
.map(|(k, v)| {
(
k,
dispatcher::WorkerLabels {
str_value: Some(v),
int_value: None,
},
)
})
.collect(),
webhook_id: None,
runtime_info: None,
};
let response = client
.dispatcher_client
.register_worker(registration)
.await?;
Ok(response.into_inner().worker_id)
}
}
impl<I, O> Register<Workflow<I, O>, I, O> for Worker
where
    I: Serialize + Send + Sync + 'static,
    O: DeserializeOwned + Send + Sync + 'static,
{
    /// Adds a multi-task workflow to the worker.
    ///
    /// The workflow's proto definition is queued for registration. Each of its executable
    /// tasks is stored in the task registry under `"<workflow name>:<task name>"`.
    ///
    /// # Panics
    ///
    /// Panics if the task registry mutex is poisoned.
    fn add_task_or_workflow(mut self, workflow: &Workflow<I, O>) -> Self {
        let proto = workflow.to_proto();
        self.workflows.push(proto);

        for executable in workflow.executable_tasks.iter() {
            let registry_key = format!("{}:{}", workflow.name, executable.name());
            let mut registry = self.tasks.lock().unwrap();
            registry.insert(registry_key, Arc::from(executable.clone()));
        }

        self
    }
}
impl<I, O> Register<Task<I, O>, I, O> for Worker
where
    I: DeserializeOwned + Serialize + Send + Sync + 'static,
    O: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// Adds a standalone task to the worker.
    ///
    /// The task is wrapped in a single-task workflow proto, which is queued for
    /// registration. Its executable is stored under `"<name>:<name>"`, because the
    /// workflow and the task share the task's name.
    ///
    /// # Panics
    ///
    /// Panics if the task registry mutex is poisoned.
    fn add_task_or_workflow(mut self, task: &Task<I, O>) -> Self {
        self.workflows.push(task.to_standalone_workflow_proto());

        let registry_key = format!("{0}:{0}", task.name);
        self.tasks
            .lock()
            .unwrap()
            .insert(registry_key, Arc::from(task.into_executable()));

        self
    }
}
/// Registration of a task or workflow definition `T` on a worker.
///
/// `I` and `O` are the type parameters of the definition being registered (for example
/// `Workflow<I, O>` or `Task<I, O>`). The method consumes `self` and returns it, so
/// registrations can be chained.
pub trait Register<
    T,
    I: Serialize + Send + Sync + 'static,
    O: DeserializeOwned + Send + Sync + 'static,
>
{
    /// Adds `workflow` to this value and returns the updated value.
    fn add_task_or_workflow(self, workflow: &T) -> Self;
}