1use std::sync::atomic::{AtomicI64, Ordering};
2use std::sync::{Arc, OnceLock};
3use tokio::sync::mpsc;
4
5use crate::{GetWorkflowRunResponse, Hatchet, HatchetError};
6
7pub struct Context {
9 log_tx: OnceLock<mpsc::Sender<String>>,
10 stream_tx: OnceLock<mpsc::Sender<(Vec<u8>, i64)>>,
11 stream_index: Arc<AtomicI64>,
12 client: Hatchet,
13 workflow_run_id: String,
14 task_run_external_id: String,
15}
16
17impl std::fmt::Debug for Context {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("Context")
20 .field("workflow_run_id", &self.workflow_run_id)
21 .field("task_run_external_id", &self.task_run_external_id)
22 .finish()
23 }
24}
25
26impl Context {
27 pub(crate) fn new(client: Hatchet, workflow_run_id: &str, task_run_external_id: &str) -> Self {
28 Self {
29 log_tx: OnceLock::new(),
30 stream_tx: OnceLock::new(),
31 stream_index: Arc::new(AtomicI64::new(0)),
32 client,
33 workflow_run_id: workflow_run_id.to_string(),
34 task_run_external_id: task_run_external_id.to_string(),
35 }
36 }
37
38 fn get_or_init_log_tx(&self) -> &mpsc::Sender<String> {
39 self.log_tx.get_or_init(|| {
40 let mut log_client = self.client.clone();
41 let (tx, mut rx) = mpsc::channel::<String>(100);
42 let task_id = self.task_run_external_id.clone();
43 tokio::spawn(async move {
44 while let Some(message) = rx.recv().await {
45 if let Err(e) = log_client.event_client.put_log(&task_id, message).await {
46 log::warn!("failed to send log to hatchet: {e}");
47 }
48 }
49 });
50 tx
51 })
52 }
53
54 fn get_or_init_stream_tx(&self) -> &mpsc::Sender<(Vec<u8>, i64)> {
55 self.stream_tx.get_or_init(|| {
56 let mut stream_client = self.client.clone();
57 let (tx, mut rx) = mpsc::channel::<(Vec<u8>, i64)>(100);
58 let task_id = self.task_run_external_id.clone();
59 tokio::spawn(async move {
60 while let Some((message, index)) = rx.recv().await {
61 if let Err(e) = stream_client
62 .event_client
63 .put_stream_event(&task_id, message, Some(index))
64 .await
65 {
66 log::warn!("failed to send stream event to hatchet: {e}");
67 }
68 }
69 });
70 tx
71 })
72 }
73
74 pub async fn parent_output(
83 &self,
84 parent_step_name: &str,
85 ) -> Result<serde_json::Value, HatchetError> {
86 let workflow_run = self.get_current_workflow().await?;
87
88 let current_task = workflow_run
89 .tasks
90 .iter()
91 .find(|task| task.task_external_id == self.task_run_external_id)
92 .ok_or_else(|| HatchetError::ParentTaskNotFound {
93 parent_step_name: parent_step_name.to_string(),
94 })?;
95
96 let parent = current_task
97 .input
98 .parents
99 .get(parent_step_name)
100 .ok_or_else(|| HatchetError::ParentTaskNotFound {
101 parent_step_name: parent_step_name.to_string(),
102 })?;
103
104 Ok(parent.0.clone())
105 }
106
107 pub async fn filter_payload(&self) -> Result<serde_json::Value, HatchetError> {
108 let workflow_run = self.get_current_workflow().await?;
109
110 let current_task = workflow_run
111 .tasks
112 .iter()
113 .find(|task| task.task_external_id == self.task_run_external_id)
114 .unwrap();
115
116 Ok(current_task.input.triggers.filter_payload.clone())
117 }
118
119 pub async fn log(&self, message: &str) -> Result<(), HatchetError> {
129 self.get_or_init_log_tx()
130 .send(message.to_string())
131 .await
132 .map_err(|e| HatchetError::InternalError(e.to_string()))?;
133 Ok(())
134 }
135
136 pub async fn put_stream(&self, data: impl Into<Vec<u8>>) -> Result<(), HatchetError> {
149 let index = self.stream_index.fetch_add(1, Ordering::SeqCst);
150 self.get_or_init_stream_tx()
151 .send((data.into(), index))
152 .await
153 .map_err(|e| HatchetError::StreamError(e.to_string()))?;
154 Ok(())
155 }
156
157 async fn get_current_workflow(&self) -> Result<GetWorkflowRunResponse, HatchetError> {
158 self.client
159 .workflow_rest_client
160 .get(&self.workflow_run_id)
161 .await
162 }
163}