// hatchet_sdk/worker/worker.rs
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::mpsc;

use crate::clients::grpc::v0::dispatcher;
use crate::clients::grpc::v0::dispatcher::WorkerRegisterRequest;
use crate::clients::hatchet::Hatchet;
use crate::error::HatchetError;
use crate::runnables::*;
use crate::worker::action_listener::ActionListener;
14
15#[derive(derive_builder::Builder)]
18#[builder(pattern = "owned")]
19pub struct Worker {
20 pub name: String,
21 client: Hatchet,
22 #[builder(default = 100)]
23 slots: i32,
24 #[builder(default = Arc::new(Mutex::new(HashMap::new())))]
25 tasks: Arc<Mutex<HashMap<String, Arc<dyn ExecutableTask>>>>,
26 #[builder(default = vec![])]
27 workflows: Vec<crate::clients::grpc::v1::workflows::CreateWorkflowVersionRequest>,
28 #[builder(default = HashMap::new())]
29 labels: HashMap<String, String>,
30}
31
32impl Worker {
33 async fn register_workflows(&mut self) {
56 for workflow in &self.workflows {
57 self.client
58 .admin_client
59 .put_workflow(workflow.clone())
60 .await
61 .unwrap();
62 }
63 }
64
65 pub async fn start(&mut self) -> Result<(), HatchetError> {
94 log::info!("STARTING HATCHET...");
95 let mut actions = vec![];
96 for workflow in &self.workflows {
97 for task in &workflow.tasks {
98 actions.push(task.action.clone());
99 }
100 }
101 log::debug!("{} waiting for actions: {:?}", self.name, actions);
102
103 let worker_id = Arc::new(
104 Self::register_worker(
105 &mut self.client,
106 &self.name,
107 actions,
108 self.slots,
109 self.labels.clone(),
110 )
111 .await?,
112 );
113 self.register_workflows().await;
114
115 let (action_tx, mut action_rx) =
116 mpsc::channel::<dispatcher::AssignedAction>(self.slots as usize);
117
118 let dispatcher = Arc::new(tokio::sync::Mutex::new(
119 crate::worker::task_dispatcher::TaskDispatcher {
120 registry: self.tasks.clone(),
121 client: self.client.clone(),
122 task_runs: Arc::new(Mutex::new(HashMap::new())),
123 },
124 ));
125
126 let action_listener = Arc::new(tokio::sync::Mutex::new(ActionListener::new(
127 self.client.clone(),
128 )));
129
130 let worker_id_clone = worker_id.clone();
131 tokio::spawn(async move {
132 log::debug!("starting action listener");
133 action_listener
134 .lock()
135 .await
136 .listen(worker_id_clone, action_tx)
137 .await as Result<(), HatchetError>
138 });
139
140 tokio::try_join!(
141 async {
142 const HEARTBEAT_INTERVAL: u64 = 4;
143 loop {
144 log::debug!("sending heartbeat");
145 self.client.dispatcher_client.heartbeat(&worker_id).await?;
146 tokio::time::sleep(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL)).await;
147 }
148 #[allow(unreachable_code)]
149 Ok::<(), HatchetError>(())
150 },
151 async {
152 while let Some(task) = action_rx.recv().await {
153 dispatcher
154 .lock()
155 .await
156 .dispatch(worker_id.clone(), task)
157 .await?
158 }
159 Ok(())
160 }
161 )?;
162
163 Ok(())
164 }
165
166 async fn register_worker(
167 client: &mut Hatchet,
168 name: &str,
169 actions: Vec<String>,
170 slots: i32,
171 labels: HashMap<String, String>,
172 ) -> Result<String, HatchetError> {
173 let registration = WorkerRegisterRequest {
174 worker_name: name.to_string(),
175 actions,
176 services: vec![],
177 slots: Some(slots),
178 labels: labels
179 .into_iter()
180 .map(|(k, v)| {
181 (
182 k,
183 dispatcher::WorkerLabels {
184 str_value: Some(v),
185 int_value: None,
186 },
187 )
188 })
189 .collect(),
190 webhook_id: None,
191 runtime_info: None,
192 };
193
194 let response = client
195 .dispatcher_client
196 .register_worker(registration)
197 .await?;
198
199 Ok(response.into_inner().worker_id)
200 }
201}
202
203impl<I, O> Register<Workflow<I, O>, I, O> for Worker
204where
205 I: Serialize + Send + Sync + 'static,
206 O: DeserializeOwned + Send + Sync + 'static,
207{
208 fn add_task_or_workflow(mut self, workflow: &Workflow<I, O>) -> Self {
209 self.workflows.push(workflow.to_proto());
210
211 for task in &workflow.executable_tasks {
212 let fully_qualified_name = format!("{}:{}", workflow.name, task.name());
213 self.tasks
214 .lock()
215 .unwrap()
216 .insert(fully_qualified_name, Arc::from(task.clone()));
217 }
218 self
219 }
220}
221
222impl<I, O> Register<Task<I, O>, I, O> for Worker
223where
224 I: DeserializeOwned + Serialize + Send + Sync + 'static,
225 O: Serialize + DeserializeOwned + Send + Sync + 'static,
226{
227 fn add_task_or_workflow(mut self, workflow: &Task<I, O>) -> Self {
228 let workflow_proto = workflow.to_standalone_workflow_proto();
229 self.workflows.push(workflow_proto);
230
231 let fully_qualified_name = format!("{}:{}", workflow.name, workflow.name);
232 self.tasks
233 .lock()
234 .unwrap()
235 .insert(fully_qualified_name, Arc::from(workflow.into_executable()));
236 self
237 }
238}
239
240pub trait Register<T, I, O>
241where
242 I: Serialize + Send + Sync + 'static,
243 O: DeserializeOwned + Send + Sync + 'static,
244{
245 fn add_task_or_workflow(self, workflow: &T) -> Self;
246}