1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7
8use super::workflow::DefaultFilter;
9use super::{ExtractRunnableOutput, TriggerWorkflowOptions};
10use crate::clients::grpc::v1::workflows::{CreateTaskOpts, CreateWorkflowVersionRequest};
11use crate::clients::rest::features::crons::{CreateCronOpts, CronOptions, CronTrigger};
12use crate::clients::rest::features::schedules::{
13 CreateScheduleOpts, ScheduleOptions, ScheduledRun,
14};
15use crate::utils::duration_to_expr;
16use crate::{Context, GetWorkflowRunResponse, Hatchet, HatchetError};
17
18pub type TaskResult = Pin<Box<dyn Future<Output = Result<serde_json::Value, TaskError>> + Send>>;
19
20#[derive(Debug)]
21pub enum TaskError {
22 InputDeserialization(serde_json::Error),
23 OutputSerialization(serde_json::Error),
24 Execution(anyhow::Error),
25}
26
27impl std::fmt::Display for TaskError {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 match self {
30 TaskError::InputDeserialization(e) => write!(f, "Failed to deserialize input: {}", e),
31 TaskError::OutputSerialization(e) => write!(f, "Failed to serialize output: {}", e),
32 TaskError::Execution(e) => {
33 let error_message = format!("Task execution failed: {}", e);
34 if std::env::var("RUST_BACKTRACE").is_ok_and(|v| v != "0") {
35 write!(f, "{}\n\n{}", error_message, e.backtrace())
36 } else {
37 write!(f, "{}", error_message)
38 }
39 }
40 }
41 }
42}
43
44impl std::error::Error for TaskError {
45 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
46 match self {
47 TaskError::InputDeserialization(e) => Some(e),
48 TaskError::OutputSerialization(e) => Some(e),
49 TaskError::Execution(e) => Some(e.as_ref()),
50 }
51 }
52}
53
54pub trait ExecutableTask: Send + Sync + dyn_clone::DynClone {
55 fn execute(&self, input: serde_json::Value, ctx: Context) -> TaskResult;
56 fn name(&self) -> &str;
57}
58
59dyn_clone::clone_trait_object!(ExecutableTask);
60
61#[derive(Clone, derive_builder::Builder)]
64#[builder(pattern = "owned")]
65pub struct Task<I, O> {
66 client: Hatchet,
67 pub(crate) name: String,
68 handler: Arc<
69 dyn Fn(I, Context) -> Pin<Box<dyn Future<Output = anyhow::Result<O>> + Send>> + Send + Sync,
70 >,
71 #[builder(default = vec![])]
72 parents: Vec<String>,
73 #[builder(default = String::from(""))]
74 description: String,
75 #[builder(default = String::from(""))]
76 version: String,
77 #[builder(default = 1)]
78 default_priority: i32,
79 #[builder(default = vec![])]
80 on_events: Vec<String>,
81 #[builder(default = vec![])]
82 cron_triggers: Vec<String>,
83 #[builder(default = vec![])]
84 default_filters: Vec<DefaultFilter>,
85 #[builder(default = 0)]
86 retries: i32,
87 #[builder(default = std::time::Duration::from_secs(300))]
88 schedule_timeout: std::time::Duration,
89 #[builder(default = std::time::Duration::from_secs(60))]
90 execution_timeout: std::time::Duration,
91 #[builder(default)]
92 input_json_schema: Option<serde_json::Value>,
93}
94
95impl<I, O> Task<I, O>
96where
97 I: Serialize + for<'de> Deserialize<'de> + Send + 'static,
98 O: Serialize + Send + 'static,
99{
100 pub fn add_parent<J, P>(mut self, parent: &Task<J, P>) -> Self {
101 self.parents.push(parent.name.clone());
102 self
103 }
104
105 pub(crate) fn into_executable(&self) -> Box<dyn ExecutableTask> {
106 let handler = self.handler.clone();
107 let name = self.name.clone();
108
109 Box::new(TypeErasedTask {
110 name: name.clone(),
111 handler: Arc::new(
112 move |input: serde_json::Value, ctx: Context| -> TaskResult {
113 let handler = handler.clone();
114 Box::pin(async move {
115 let typed_input: I = serde_json::from_value(input)
116 .map_err(TaskError::InputDeserialization)?;
117
118 let result = handler(typed_input, ctx)
119 .await
120 .map_err(TaskError::Execution)?;
121
122 serde_json::to_value(result).map_err(TaskError::OutputSerialization)
123 }) as TaskResult
124 },
125 ),
126 })
127 }
128
129 pub(crate) fn to_task_proto(&self, workflow_name: &str) -> CreateTaskOpts {
130 CreateTaskOpts {
131 readable_id: self.name.clone(),
132 action: format!("{workflow_name}:{}", &self.name),
133 timeout: duration_to_expr(self.execution_timeout),
134 inputs: String::from("{{}}"),
135 parents: self.parents.clone(),
136 retries: self.retries,
137 rate_limits: vec![],
138 worker_labels: std::collections::HashMap::new(),
139 backoff_factor: None,
140 backoff_max_seconds: None,
141 concurrency: vec![],
142 conditions: None,
143 schedule_timeout: Some(duration_to_expr(self.schedule_timeout)),
144 }
145 }
146
147 pub(crate) fn to_standalone_workflow_proto(&self) -> CreateWorkflowVersionRequest {
148 let task_proto = self.to_task_proto(&self.name);
149 CreateWorkflowVersionRequest {
150 name: self.name.clone().to_lowercase(),
151 description: self.description.clone(),
152 version: self.version.clone(),
153 event_triggers: self.on_events.clone(),
154 cron_triggers: self.cron_triggers.clone(),
155 tasks: vec![task_proto],
156 concurrency: None,
157 cron_input: None,
158 on_failure_task: None,
159 sticky: None,
160 default_priority: Some(self.default_priority),
161 concurrency_arr: vec![],
162 default_filters: self
163 .default_filters
164 .clone()
165 .into_iter()
166 .map(|f| f.to_proto())
167 .collect(),
168 input_json_schema: self
169 .input_json_schema
170 .as_ref()
171 .map(|value| serde_json::to_vec(value).expect("must be serializable")),
172 }
173 }
174
175 pub async fn schedule(
178 &self,
179 trigger_at: chrono::DateTime<chrono::Utc>,
180 input: &I,
181 options: Option<&ScheduleOptions>,
182 ) -> Result<ScheduledRun, HatchetError> {
183 let input_json =
184 serde_json::to_value(input).map_err(|e| HatchetError::JsonEncode(e.to_string()))?;
185 self.client
186 .schedules
187 .create(
188 &self.name.to_lowercase(),
189 CreateScheduleOpts {
190 trigger_at,
191 input: input_json,
192 additional_metadata: options.and_then(|o| o.additional_metadata.clone()),
193 priority: options.and_then(|o| o.priority),
194 },
195 )
196 .await
197 }
198
199 pub async fn cron(
202 &self,
203 name: &str,
204 expression: &str,
205 input: &I,
206 options: Option<&CronOptions>,
207 ) -> Result<CronTrigger, HatchetError> {
208 let input_json =
209 serde_json::to_value(input).map_err(|e| HatchetError::JsonEncode(e.to_string()))?;
210 self.client
211 .crons
212 .create(
213 &self.name.to_lowercase(),
214 CreateCronOpts {
215 name: name.to_string(),
216 expression: expression.to_string(),
217 input: input_json,
218 additional_metadata: options.and_then(|o| o.additional_metadata.clone()),
219 priority: options.and_then(|o| o.priority),
220 },
221 )
222 .await
223 }
224
225 async fn trigger(
226 &self,
227 input: &I,
228 options: &TriggerWorkflowOptions,
229 ) -> Result<String, HatchetError> {
230 let input_json =
231 serde_json::to_value(input).map_err(|e| HatchetError::JsonEncode(e.to_string()))?;
232
233 let additional_metadata = options.additional_metadata.clone().map(|v| v.to_string());
234 let desired_worker_id = options.desired_worker_id.clone();
235
236 let response = self
237 .client
238 .workflow_client
239 .trigger_workflow(
240 crate::clients::grpc::v0::workflows::TriggerWorkflowRequest {
241 name: self.name.clone().to_lowercase(),
242 input: input_json.to_string(),
243 parent_id: None,
244 parent_task_run_external_id: None,
245 child_index: None,
246 child_key: None,
247 additional_metadata,
248 desired_worker_id,
249 priority: None,
250 },
251 )
252 .await?;
253
254 Ok(response.workflow_run_id)
255 }
256}
257
258impl<I, O> ExtractRunnableOutput<O> for Task<I, O>
259where
260 I: Serialize + DeserializeOwned + Send + Sync + 'static,
261 O: Serialize + DeserializeOwned + Send + Sync + 'static,
262{
263 fn extract_output(&self, workflow: GetWorkflowRunResponse) -> Result<O, HatchetError> {
264 let task_output = workflow
265 .tasks
266 .iter()
267 .find(|task| task.action_id == Some(format!("{}:{}", &self.name, &self.name)))
268 .and_then(|task| task.output.clone())
269 .ok_or(HatchetError::MissingOutput)?;
270
271 serde_json::from_value(task_output)
272 .map_err(|e| HatchetError::JsonDecodeError(e.to_string()))
273 }
274}
275
276#[async_trait::async_trait]
277impl<I, O> super::Runnable<I, O> for Task<I, O>
278where
279 I: Serialize + DeserializeOwned + Send + Sync + 'static,
280 O: Serialize + DeserializeOwned + Send + Sync + 'static,
281{
282 async fn get_run(&self, run_id: &str) -> Result<GetWorkflowRunResponse, HatchetError> {
283 self.client.workflow_rest_client.get(run_id).await
284 }
285 async fn run_no_wait(
286 &self,
287 input: &I,
288 options: Option<&TriggerWorkflowOptions>,
289 ) -> Result<String, HatchetError> {
290 Ok(self
291 .trigger(input, options.unwrap_or(&TriggerWorkflowOptions::default()))
292 .await?)
293 }
294}
295
296#[derive(Clone)]
297struct TypeErasedTask {
298 name: String,
299 handler: Arc<dyn Fn(serde_json::Value, Context) -> TaskResult + Send + Sync>,
300}
301
302impl ExecutableTask for TypeErasedTask {
303 fn execute(&self, input: serde_json::Value, ctx: Context) -> TaskResult {
304 (self.handler)(input, ctx)
305 }
306
307 fn name(&self) -> &str {
308 &self.name
309 }
310}