Skip to main content

hatchet_sdk/
context.rs

1use std::sync::atomic::{AtomicI64, Ordering};
2use std::sync::{Arc, OnceLock};
3use tokio::sync::mpsc;
4
5use crate::{GetWorkflowRunResponse, Hatchet, HatchetError};
6
/// The context object is used to interact with the Hatchet API from within a task.
///
/// It carries the identifiers of the current workflow run and task run, a
/// client handle, and lazily created channels used to forward log lines and
/// stream events from the task.
pub struct Context {
    /// Sender half of the log channel; created on first use together with a
    /// spawned task that forwards each message via `put_log`.
    log_tx: OnceLock<mpsc::Sender<String>>,
    /// Sender half of the stream channel; created on first use together with a
    /// spawned task that forwards each `(payload, index)` pair via
    /// `put_stream_event`.
    stream_tx: OnceLock<mpsc::Sender<(Vec<u8>, i64)>>,
    /// Counter that supplies the index attached to each stream event; starts
    /// at 0 and is incremented once per `put_stream` call.
    stream_index: Arc<AtomicI64>,
    /// Hatchet client used for workflow-run lookups and cloned into the
    /// forwarding tasks.
    client: Hatchet,
    /// Identifier passed to the REST client to fetch the current workflow run.
    workflow_run_id: String,
    /// Identifier used to locate this task within the workflow run and sent
    /// along with log lines and stream events.
    task_run_external_id: String,
}
16
17impl std::fmt::Debug for Context {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("Context")
20            .field("workflow_run_id", &self.workflow_run_id)
21            .field("task_run_external_id", &self.task_run_external_id)
22            .finish()
23    }
24}
25
26impl Context {
27    pub(crate) fn new(client: Hatchet, workflow_run_id: &str, task_run_external_id: &str) -> Self {
28        Self {
29            log_tx: OnceLock::new(),
30            stream_tx: OnceLock::new(),
31            stream_index: Arc::new(AtomicI64::new(0)),
32            client,
33            workflow_run_id: workflow_run_id.to_string(),
34            task_run_external_id: task_run_external_id.to_string(),
35        }
36    }
37
38    fn get_or_init_log_tx(&self) -> &mpsc::Sender<String> {
39        self.log_tx.get_or_init(|| {
40            let mut log_client = self.client.clone();
41            let (tx, mut rx) = mpsc::channel::<String>(100);
42            let task_id = self.task_run_external_id.clone();
43            tokio::spawn(async move {
44                while let Some(message) = rx.recv().await {
45                    if let Err(e) = log_client.event_client.put_log(&task_id, message).await {
46                        log::warn!("failed to send log to hatchet: {e}");
47                    }
48                }
49            });
50            tx
51        })
52    }
53
54    fn get_or_init_stream_tx(&self) -> &mpsc::Sender<(Vec<u8>, i64)> {
55        self.stream_tx.get_or_init(|| {
56            let mut stream_client = self.client.clone();
57            let (tx, mut rx) = mpsc::channel::<(Vec<u8>, i64)>(100);
58            let task_id = self.task_run_external_id.clone();
59            tokio::spawn(async move {
60                while let Some((message, index)) = rx.recv().await {
61                    if let Err(e) = stream_client
62                        .event_client
63                        .put_stream_event(&task_id, message, Some(index))
64                        .await
65                    {
66                        log::warn!("failed to send stream event to hatchet: {e}");
67                    }
68                }
69            });
70            tx
71        })
72    }
73
74    /// Get the output of a parent task in a DAG.
75    ///
76    /// ```compile_fail
77    /// let task = hatchet.task("my-task", |_input: EmptyModel, ctx: Context| async move {
78    ///     let parent_output = ctx.parent_output("parent_task").await.unwrap();
79    ///     Ok(EmptyModel)
80    /// });
81    /// ```
82    pub async fn parent_output(
83        &self,
84        parent_step_name: &str,
85    ) -> Result<serde_json::Value, HatchetError> {
86        let workflow_run = self.get_current_workflow().await?;
87
88        let current_task = workflow_run
89            .tasks
90            .iter()
91            .find(|task| task.task_external_id == self.task_run_external_id)
92            .ok_or_else(|| HatchetError::ParentTaskNotFound {
93                parent_step_name: parent_step_name.to_string(),
94            })?;
95
96        let parent = current_task
97            .input
98            .parents
99            .get(parent_step_name)
100            .ok_or_else(|| HatchetError::ParentTaskNotFound {
101                parent_step_name: parent_step_name.to_string(),
102            })?;
103
104        Ok(parent.0.clone())
105    }
106
107    pub async fn filter_payload(&self) -> Result<serde_json::Value, HatchetError> {
108        let workflow_run = self.get_current_workflow().await?;
109
110        let current_task = workflow_run
111            .tasks
112            .iter()
113            .find(|task| task.task_external_id == self.task_run_external_id)
114            .unwrap();
115
116        Ok(current_task.input.triggers.filter_payload.clone())
117    }
118
119    /// Log a line to the Hatchet API. This will send the log line to the Hatchet API and return immediately.
120    /// ```compile_fail
121    /// use hatchet_sdk::{Hatchet, EmptyModel};
122    /// let hatchet = Hatchet::from_env().await.unwrap();
123    /// let task = hatchet.task("my-task", |_input: EmptyModel, ctx: Context| async move {
124    ///     ctx.log("Hello, world!").await.unwrap();
125    ///     Ok(EmptyModel)
126    /// });
127    /// ```
128    pub async fn log(&self, message: &str) -> Result<(), HatchetError> {
129        self.get_or_init_log_tx()
130            .send(message.to_string())
131            .await
132            .map_err(|e| HatchetError::InternalError(e.to_string()))?;
133        Ok(())
134    }
135
136    /// Send a stream event to the Hatchet API. Useful for streaming data such as LLM tokens
137    /// or progress updates from a task to consumers.
138    ///
139    /// ```compile_fail
140    /// use hatchet_sdk::{Hatchet, EmptyModel};
141    /// let hatchet = Hatchet::from_env().await.unwrap();
142    /// let task = hatchet.task("my-task", |_input: EmptyModel, ctx: Context| async move {
143    ///     ctx.put_stream(b"chunk 1".to_vec()).await.unwrap();
144    ///     ctx.put_stream(b"chunk 2".to_vec()).await.unwrap();
145    ///     Ok(EmptyModel)
146    /// });
147    /// ```
148    pub async fn put_stream(&self, data: impl Into<Vec<u8>>) -> Result<(), HatchetError> {
149        let index = self.stream_index.fetch_add(1, Ordering::SeqCst);
150        self.get_or_init_stream_tx()
151            .send((data.into(), index))
152            .await
153            .map_err(|e| HatchetError::StreamError(e.to_string()))?;
154        Ok(())
155    }
156
157    async fn get_current_workflow(&self) -> Result<GetWorkflowRunResponse, HatchetError> {
158        self.client
159            .workflow_rest_client
160            .get(&self.workflow_run_id)
161            .await
162    }
163}