hatchet-sdk 0.2.8

This is an unofficial Rust SDK for Hatchet, a distributed, fault-tolerant task queue.
Documentation
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::mpsc;

use crate::clients::grpc::v0::dispatcher;
use crate::clients::grpc::v0::dispatcher::WorkerRegisterRequest;
use crate::clients::hatchet::Hatchet;
use crate::error::HatchetError;
use crate::runnables::*;
use crate::worker::action_listener::ActionListener;

/// A Hatchet worker: it registers workflows with Hatchet and runs the tasks Hatchet
/// assigns to it.
///
/// Obtain a builder from [Hatchet.worker()](crate::Hatchet::worker()), attach workflows or
/// standalone tasks with `add_task_or_workflow`, then call [`Worker::start`].
#[derive(derive_builder::Builder)]
#[builder(pattern = "owned")]
pub struct Worker {
    /// Name sent to Hatchet when the worker registers.
    pub name: String,
    /// Client used for registration, heartbeats and task dispatch.
    client: Hatchet,
    /// Slot count sent to Hatchet at registration; also sizes the local action channel.
    #[builder(default = 100)]
    slots: i32,
    /// Executable tasks keyed by `"<workflow name>:<task name>"`; shared with the dispatcher.
    #[builder(default)]
    tasks: Arc<Mutex<HashMap<String, Arc<dyn ExecutableTask>>>>,
    /// Workflow definitions to register with Hatchet when the worker starts.
    #[builder(default)]
    workflows: Vec<crate::clients::grpc::v1::workflows::CreateWorkflowVersionRequest>,
    /// Labels attached to the worker registration (sent as string values).
    #[builder(default)]
    labels: HashMap<String, String>,
}

impl Worker {
    /// Register a workflow with this worker. When the worker starts, it will register the workflow with Hatchet.
    /// Hatchet will then assign runs of the workflow to this worker.
    ///
    /// ```compile_fail
    /// use hatchet_sdk::{Context, Hatchet, EmptyModel};
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let hatchet = Hatchet::from_env().await.unwrap();
    ///     let my_task = hatchet.task::<EmptyModel, EmptyModel, MyError>("my-task", |input: EmptyModel, _ctx: Context| async move {
    ///     Ok(EmptyModel)
    /// });
    ///
    /// let my_workflow = hatchet.workflow("my-workflow")
    ///     .build()
    ///     .unwrap()
    ///     .add_task(&my_task)
    ///
    ///     let worker = hatchet.worker("my-worker").build().unwrap();
    ///     worker.add_task_or_workflow(my_workflow);
    /// }
    /// ```
    async fn register_workflows(&mut self) {
        for workflow in &self.workflows {
            self.client
                .admin_client
                .put_workflow(workflow.clone())
                .await
                .unwrap();
        }
    }

    /// Start the worker.
    /// This will register the worker with Hatchet and start listening for assigned tasks.
    /// Use ctrl+c to stop the worker.
    ///
    /// ```compile_fail
    /// use hatchet_sdk::{Context, Hatchet, EmptyModel, Runnable,Register};
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let hatchet = Hatchet::from_env().await.unwrap();
    ///     
    ///     let my_workflow = hatchet.
    ///         workflow::<EmptyModel, EmptyModel>("my-workflow")
    ///         .build()
    ///         .unwrap()
    ///         .add_task(&hatchet.task("my-task", async move |input: EmptyModel, _ctx: Context| -> anyhow::Result<EmptyModel> {
    ///             Ok(EmptyModel)
    ///         }))
    ///
    ///     let mut worker = hatchet.worker("my-worker")
    ///         .slots(5)
    ///         .build()
    ///         .unwrap()
    ///         .add_task_or_workflow(my_workflow);
    ///
    ///     worker.start().await.unwrap();
    /// }
    /// ```
    pub async fn start(&mut self) -> Result<(), HatchetError> {
        log::info!("STARTING HATCHET...");
        let mut actions = vec![];
        for workflow in &self.workflows {
            for task in &workflow.tasks {
                actions.push(task.action.clone());
            }
        }
        log::debug!("{} waiting for actions: {:?}", self.name, actions);

        let worker_id = Arc::new(
            Self::register_worker(
                &mut self.client,
                &self.name,
                actions,
                self.slots,
                self.labels.clone(),
            )
            .await?,
        );
        self.register_workflows().await;

        let (action_tx, mut action_rx) =
            mpsc::channel::<dispatcher::AssignedAction>(self.slots as usize);

        let dispatcher = Arc::new(tokio::sync::Mutex::new(
            crate::worker::task_dispatcher::TaskDispatcher {
                registry: self.tasks.clone(),
                client: self.client.clone(),
                task_runs: Arc::new(Mutex::new(HashMap::new())),
            },
        ));

        let action_listener = Arc::new(tokio::sync::Mutex::new(ActionListener::new(
            self.client.clone(),
        )));

        let worker_id_clone = worker_id.clone();
        tokio::spawn(async move {
            log::debug!("starting action listener");
            action_listener
                .lock()
                .await
                .listen(worker_id_clone, action_tx)
                .await as Result<(), HatchetError>
        });

        tokio::try_join!(
            async {
                const HEARTBEAT_INTERVAL: u64 = 4;
                loop {
                    log::debug!("sending heartbeat");
                    self.client.dispatcher_client.heartbeat(&worker_id).await?;
                    tokio::time::sleep(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL)).await;
                }
                #[allow(unreachable_code)]
                Ok::<(), HatchetError>(())
            },
            async {
                while let Some(task) = action_rx.recv().await {
                    dispatcher
                        .lock()
                        .await
                        .dispatch(worker_id.clone(), task)
                        .await?
                }
                Ok(())
            }
        )?;

        Ok(())
    }

    async fn register_worker(
        client: &mut Hatchet,
        name: &str,
        actions: Vec<String>,
        slots: i32,
        labels: HashMap<String, String>,
    ) -> Result<String, HatchetError> {
        let registration = WorkerRegisterRequest {
            worker_name: name.to_string(),
            actions,
            services: vec![],
            slots: Some(slots),
            labels: labels
                .into_iter()
                .map(|(k, v)| {
                    (
                        k,
                        dispatcher::WorkerLabels {
                            str_value: Some(v),
                            int_value: None,
                        },
                    )
                })
                .collect(),
            webhook_id: None,
            runtime_info: None,
        };

        let response = client
            .dispatcher_client
            .register_worker(registration)
            .await?;

        Ok(response.into_inner().worker_id)
    }
}

impl<I, O> Register<Workflow<I, O>, I, O> for Worker
where
    I: Serialize + Send + Sync + 'static,
    O: DeserializeOwned + Send + Sync + 'static,
{
    /// Queues `workflow` for registration when the worker starts, and records each of its
    /// executable tasks in the task registry under `"<workflow name>:<task name>"`.
    fn add_task_or_workflow(mut self, workflow: &Workflow<I, O>) -> Self {
        self.workflows.push(workflow.to_proto());

        // Take the registry lock once for all of this workflow's tasks; the guard must be
        // dropped before `self` is moved out.
        {
            let mut registry = self.tasks.lock().unwrap();
            for executable in workflow.executable_tasks.iter() {
                let action_key = format!("{}:{}", workflow.name, executable.name());
                registry.insert(action_key, Arc::from(executable.clone()));
            }
        }

        self
    }
}

impl<I, O> Register<Task<I, O>, I, O> for Worker
where
    I: DeserializeOwned + Serialize + Send + Sync + 'static,
    O: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    fn add_task_or_workflow(mut self, workflow: &Task<I, O>) -> Self {
        let workflow_proto = workflow.to_standalone_workflow_proto();
        self.workflows.push(workflow_proto);

        let fully_qualified_name = format!("{}:{}", workflow.name, workflow.name);
        self.tasks
            .lock()
            .unwrap()
            .insert(fully_qualified_name, Arc::from(workflow.into_executable()));
        self
    }
}

/// Attaches a runnable `T` (a workflow or a standalone task) with input type `I` and output
/// type `O` to a worker.
///
/// The worker is taken by value and returned, so registrations can be chained:
///
/// ```ignore
/// let worker = hatchet
///     .worker("my-worker")
///     .build()
///     .unwrap()
///     .add_task_or_workflow(&my_workflow);
/// ```
pub trait Register<
    T,
    I: Serialize + Send + Sync + 'static,
    O: DeserializeOwned + Send + Sync + 'static,
>
{
    /// Registers `runnable` with this worker and returns the updated worker.
    fn add_task_or_workflow(self, runnable: &T) -> Self;
}