Skip to main content

hatchet_sdk/worker/
worker.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use serde::Serialize;
5use serde::de::DeserializeOwned;
6use tokio::sync::mpsc;
7
8use crate::clients::grpc::v0::dispatcher;
9use crate::clients::grpc::v0::dispatcher::WorkerRegisterRequest;
10use crate::clients::hatchet::Hatchet;
11use crate::error::HatchetError;
12use crate::runnables::*;
13use crate::worker::action_listener::ActionListener;
14
15/// A worker is a container for tasks that can be executed by a worker.
16/// See [Hatchet.worker()](crate::Hatchet::worker()) for more information.
17#[derive(derive_builder::Builder)]
18#[builder(pattern = "owned")]
19pub struct Worker {
20    pub name: String,
21    client: Hatchet,
22    #[builder(default = 100)]
23    slots: i32,
24    #[builder(default = Arc::new(Mutex::new(HashMap::new())))]
25    tasks: Arc<Mutex<HashMap<String, Arc<dyn ExecutableTask>>>>,
26    #[builder(default = vec![])]
27    workflows: Vec<crate::clients::grpc::v1::workflows::CreateWorkflowVersionRequest>,
28    #[builder(default = HashMap::new())]
29    labels: HashMap<String, String>,
30}
31
32impl Worker {
33    /// Register a workflow with this worker. When the worker starts, it will register the workflow with Hatchet.
34    /// Hatchet will then assign runs of the workflow to this worker.
35    ///
36    /// ```compile_fail
37    /// use hatchet_sdk::{Context, Hatchet, EmptyModel};
38    ///
39    /// #[tokio::main]
40    /// async fn main() {
41    ///     let hatchet = Hatchet::from_env().await.unwrap();
42    ///     let my_task = hatchet.task::<EmptyModel, EmptyModel, MyError>("my-task", |input: EmptyModel, _ctx: Context| async move {
43    ///     Ok(EmptyModel)
44    /// });
45    ///
46    /// let my_workflow = hatchet.workflow("my-workflow")
47    ///     .build()
48    ///     .unwrap()
49    ///     .add_task(&my_task)
50    ///
51    ///     let worker = hatchet.worker("my-worker").build().unwrap();
52    ///     worker.add_task_or_workflow(my_workflow);
53    /// }
54    /// ```
55    async fn register_workflows(&mut self) {
56        for workflow in &self.workflows {
57            self.client
58                .admin_client
59                .put_workflow(workflow.clone())
60                .await
61                .unwrap();
62        }
63    }
64
65    /// Start the worker.
66    /// This will register the worker with Hatchet and start listening for assigned tasks.
67    /// Use ctrl+c to stop the worker.
68    ///
69    /// ```compile_fail
70    /// use hatchet_sdk::{Context, Hatchet, EmptyModel, Runnable,Register};
71    ///
72    /// #[tokio::main]
73    /// async fn main() {
74    ///     let hatchet = Hatchet::from_env().await.unwrap();
75    ///     
76    ///     let my_workflow = hatchet.
77    ///         workflow::<EmptyModel, EmptyModel>("my-workflow")
78    ///         .build()
79    ///         .unwrap()
80    ///         .add_task(&hatchet.task("my-task", async move |input: EmptyModel, _ctx: Context| -> anyhow::Result<EmptyModel> {
81    ///             Ok(EmptyModel)
82    ///         }))
83    ///
84    ///     let mut worker = hatchet.worker("my-worker")
85    ///         .slots(5)
86    ///         .build()
87    ///         .unwrap()
88    ///         .add_task_or_workflow(my_workflow);
89    ///
90    ///     worker.start().await.unwrap();
91    /// }
92    /// ```
93    pub async fn start(&mut self) -> Result<(), HatchetError> {
94        log::info!("STARTING HATCHET...");
95        let mut actions = vec![];
96        for workflow in &self.workflows {
97            for task in &workflow.tasks {
98                actions.push(task.action.clone());
99            }
100        }
101        log::debug!("{} waiting for actions: {:?}", self.name, actions);
102
103        let worker_id = Arc::new(
104            Self::register_worker(
105                &mut self.client,
106                &self.name,
107                actions,
108                self.slots,
109                self.labels.clone(),
110            )
111            .await?,
112        );
113        self.register_workflows().await;
114
115        let (action_tx, mut action_rx) =
116            mpsc::channel::<dispatcher::AssignedAction>(self.slots as usize);
117
118        let dispatcher = Arc::new(tokio::sync::Mutex::new(
119            crate::worker::task_dispatcher::TaskDispatcher {
120                registry: self.tasks.clone(),
121                client: self.client.clone(),
122                task_runs: Arc::new(Mutex::new(HashMap::new())),
123            },
124        ));
125
126        let action_listener = Arc::new(tokio::sync::Mutex::new(ActionListener::new(
127            self.client.clone(),
128        )));
129
130        let worker_id_clone = worker_id.clone();
131        tokio::spawn(async move {
132            log::debug!("starting action listener");
133            action_listener
134                .lock()
135                .await
136                .listen(worker_id_clone, action_tx)
137                .await as Result<(), HatchetError>
138        });
139
140        tokio::try_join!(
141            async {
142                const HEARTBEAT_INTERVAL: u64 = 4;
143                loop {
144                    log::debug!("sending heartbeat");
145                    self.client.dispatcher_client.heartbeat(&worker_id).await?;
146                    tokio::time::sleep(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL)).await;
147                }
148                #[allow(unreachable_code)]
149                Ok::<(), HatchetError>(())
150            },
151            async {
152                while let Some(task) = action_rx.recv().await {
153                    dispatcher
154                        .lock()
155                        .await
156                        .dispatch(worker_id.clone(), task)
157                        .await?
158                }
159                Ok(())
160            }
161        )?;
162
163        Ok(())
164    }
165
166    async fn register_worker(
167        client: &mut Hatchet,
168        name: &str,
169        actions: Vec<String>,
170        slots: i32,
171        labels: HashMap<String, String>,
172    ) -> Result<String, HatchetError> {
173        let registration = WorkerRegisterRequest {
174            worker_name: name.to_string(),
175            actions,
176            services: vec![],
177            slots: Some(slots),
178            labels: labels
179                .into_iter()
180                .map(|(k, v)| {
181                    (
182                        k,
183                        dispatcher::WorkerLabels {
184                            str_value: Some(v),
185                            int_value: None,
186                        },
187                    )
188                })
189                .collect(),
190            webhook_id: None,
191            runtime_info: None,
192        };
193
194        let response = client
195            .dispatcher_client
196            .register_worker(registration)
197            .await?;
198
199        Ok(response.into_inner().worker_id)
200    }
201}
202
203impl<I, O> Register<Workflow<I, O>, I, O> for Worker
204where
205    I: Serialize + Send + Sync + 'static,
206    O: DeserializeOwned + Send + Sync + 'static,
207{
208    fn add_task_or_workflow(mut self, workflow: &Workflow<I, O>) -> Self {
209        self.workflows.push(workflow.to_proto());
210
211        for task in &workflow.executable_tasks {
212            let fully_qualified_name = format!("{}:{}", workflow.name, task.name());
213            self.tasks
214                .lock()
215                .unwrap()
216                .insert(fully_qualified_name, Arc::from(task.clone()));
217        }
218        self
219    }
220}
221
222impl<I, O> Register<Task<I, O>, I, O> for Worker
223where
224    I: DeserializeOwned + Serialize + Send + Sync + 'static,
225    O: Serialize + DeserializeOwned + Send + Sync + 'static,
226{
227    fn add_task_or_workflow(mut self, workflow: &Task<I, O>) -> Self {
228        let workflow_proto = workflow.to_standalone_workflow_proto();
229        self.workflows.push(workflow_proto);
230
231        let fully_qualified_name = format!("{}:{}", workflow.name, workflow.name);
232        self.tasks
233            .lock()
234            .unwrap()
235            .insert(fully_qualified_name, Arc::from(workflow.into_executable()));
236        self
237    }
238}
239
240pub trait Register<T, I, O>
241where
242    I: Serialize + Send + Sync + 'static,
243    O: DeserializeOwned + Send + Sync + 'static,
244{
245    fn add_task_or_workflow(self, workflow: &T) -> Self;
246}