// hatchet_sdk/runnables/task.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7
8use super::workflow::DefaultFilter;
9use super::{ExtractRunnableOutput, TriggerWorkflowOptions};
10use crate::clients::grpc::v1::workflows::{CreateTaskOpts, CreateWorkflowVersionRequest};
11use crate::clients::rest::features::crons::{CreateCronOpts, CronOptions, CronTrigger};
12use crate::clients::rest::features::schedules::{
13    CreateScheduleOpts, ScheduleOptions, ScheduledRun,
14};
15use crate::utils::duration_to_expr;
16use crate::{Context, GetWorkflowRunResponse, Hatchet, HatchetError};
17
18pub type TaskResult = Pin<Box<dyn Future<Output = Result<serde_json::Value, TaskError>> + Send>>;
19
20#[derive(Debug)]
21pub enum TaskError {
22    InputDeserialization(serde_json::Error),
23    OutputSerialization(serde_json::Error),
24    Execution(anyhow::Error),
25}
26
27impl std::fmt::Display for TaskError {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        match self {
30            TaskError::InputDeserialization(e) => write!(f, "Failed to deserialize input: {}", e),
31            TaskError::OutputSerialization(e) => write!(f, "Failed to serialize output: {}", e),
32            TaskError::Execution(e) => {
33                let error_message = format!("Task execution failed: {}", e);
34                if std::env::var("RUST_BACKTRACE").is_ok_and(|v| v != "0") {
35                    write!(f, "{}\n\n{}", error_message, e.backtrace())
36                } else {
37                    write!(f, "{}", error_message)
38                }
39            }
40        }
41    }
42}
43
44impl std::error::Error for TaskError {
45    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
46        match self {
47            TaskError::InputDeserialization(e) => Some(e),
48            TaskError::OutputSerialization(e) => Some(e),
49            TaskError::Execution(e) => Some(e.as_ref()),
50        }
51    }
52}
53
54pub trait ExecutableTask: Send + Sync + dyn_clone::DynClone {
55    fn execute(&self, input: serde_json::Value, ctx: Context) -> TaskResult;
56    fn name(&self) -> &str;
57}
58
59dyn_clone::clone_trait_object!(ExecutableTask);
60
61/// A task is a unit of work that can be executed by a worker.
62/// See [Hatchet.task()](crate::Hatchet::task()) for more information.
63#[derive(Clone, derive_builder::Builder)]
64#[builder(pattern = "owned")]
65pub struct Task<I, O> {
66    client: Hatchet,
67    pub(crate) name: String,
68    handler: Arc<
69        dyn Fn(I, Context) -> Pin<Box<dyn Future<Output = anyhow::Result<O>> + Send>> + Send + Sync,
70    >,
71    #[builder(default = vec![])]
72    parents: Vec<String>,
73    #[builder(default = String::from(""))]
74    description: String,
75    #[builder(default = String::from(""))]
76    version: String,
77    #[builder(default = 1)]
78    default_priority: i32,
79    #[builder(default = vec![])]
80    on_events: Vec<String>,
81    #[builder(default = vec![])]
82    cron_triggers: Vec<String>,
83    #[builder(default = vec![])]
84    default_filters: Vec<DefaultFilter>,
85    #[builder(default = 0)]
86    retries: i32,
87    #[builder(default = std::time::Duration::from_secs(300))]
88    schedule_timeout: std::time::Duration,
89    #[builder(default = std::time::Duration::from_secs(60))]
90    execution_timeout: std::time::Duration,
91    #[builder(default)]
92    input_json_schema: Option<serde_json::Value>,
93}
94
95impl<I, O> Task<I, O>
96where
97    I: Serialize + for<'de> Deserialize<'de> + Send + 'static,
98    O: Serialize + Send + 'static,
99{
100    pub fn add_parent<J, P>(mut self, parent: &Task<J, P>) -> Self {
101        self.parents.push(parent.name.clone());
102        self
103    }
104
105    pub(crate) fn into_executable(&self) -> Box<dyn ExecutableTask> {
106        let handler = self.handler.clone();
107        let name = self.name.clone();
108
109        Box::new(TypeErasedTask {
110            name: name.clone(),
111            handler: Arc::new(
112                move |input: serde_json::Value, ctx: Context| -> TaskResult {
113                    let handler = handler.clone();
114                    Box::pin(async move {
115                        let typed_input: I = serde_json::from_value(input)
116                            .map_err(TaskError::InputDeserialization)?;
117
118                        let result = handler(typed_input, ctx)
119                            .await
120                            .map_err(TaskError::Execution)?;
121
122                        serde_json::to_value(result).map_err(TaskError::OutputSerialization)
123                    }) as TaskResult
124                },
125            ),
126        })
127    }
128
129    pub(crate) fn to_task_proto(&self, workflow_name: &str) -> CreateTaskOpts {
130        CreateTaskOpts {
131            readable_id: self.name.clone(),
132            action: format!("{workflow_name}:{}", &self.name),
133            timeout: duration_to_expr(self.execution_timeout),
134            inputs: String::from("{{}}"),
135            parents: self.parents.clone(),
136            retries: self.retries,
137            rate_limits: vec![],
138            worker_labels: std::collections::HashMap::new(),
139            backoff_factor: None,
140            backoff_max_seconds: None,
141            concurrency: vec![],
142            conditions: None,
143            schedule_timeout: Some(duration_to_expr(self.schedule_timeout)),
144        }
145    }
146
147    pub(crate) fn to_standalone_workflow_proto(&self) -> CreateWorkflowVersionRequest {
148        let task_proto = self.to_task_proto(&self.name);
149        CreateWorkflowVersionRequest {
150            name: self.name.clone().to_lowercase(),
151            description: self.description.clone(),
152            version: self.version.clone(),
153            event_triggers: self.on_events.clone(),
154            cron_triggers: self.cron_triggers.clone(),
155            tasks: vec![task_proto],
156            concurrency: None,
157            cron_input: None,
158            on_failure_task: None,
159            sticky: None,
160            default_priority: Some(self.default_priority),
161            concurrency_arr: vec![],
162            default_filters: self
163                .default_filters
164                .clone()
165                .into_iter()
166                .map(|f| f.to_proto())
167                .collect(),
168            input_json_schema: self
169                .input_json_schema
170                .as_ref()
171                .map(|value| serde_json::to_vec(value).expect("must be serializable")),
172        }
173    }
174
175    /// Schedule this task's workflow to run at a specific future time.
176    /// See [`SchedulesClient::create`](crate::SchedulesClient::create) for the underlying API.
177    pub async fn schedule(
178        &self,
179        trigger_at: chrono::DateTime<chrono::Utc>,
180        input: &I,
181        options: Option<&ScheduleOptions>,
182    ) -> Result<ScheduledRun, HatchetError> {
183        let input_json =
184            serde_json::to_value(input).map_err(|e| HatchetError::JsonEncode(e.to_string()))?;
185        self.client
186            .schedules
187            .create(
188                &self.name.to_lowercase(),
189                CreateScheduleOpts {
190                    trigger_at,
191                    input: input_json,
192                    additional_metadata: options.and_then(|o| o.additional_metadata.clone()),
193                    priority: options.and_then(|o| o.priority),
194                },
195            )
196            .await
197    }
198
199    /// Create a recurring cron trigger for this task's workflow.
200    /// See [`CronsClient::create`](crate::CronsClient::create) for the underlying API.
201    pub async fn cron(
202        &self,
203        name: &str,
204        expression: &str,
205        input: &I,
206        options: Option<&CronOptions>,
207    ) -> Result<CronTrigger, HatchetError> {
208        let input_json =
209            serde_json::to_value(input).map_err(|e| HatchetError::JsonEncode(e.to_string()))?;
210        self.client
211            .crons
212            .create(
213                &self.name.to_lowercase(),
214                CreateCronOpts {
215                    name: name.to_string(),
216                    expression: expression.to_string(),
217                    input: input_json,
218                    additional_metadata: options.and_then(|o| o.additional_metadata.clone()),
219                    priority: options.and_then(|o| o.priority),
220                },
221            )
222            .await
223    }
224
225    async fn trigger(
226        &self,
227        input: &I,
228        options: &TriggerWorkflowOptions,
229    ) -> Result<String, HatchetError> {
230        let input_json =
231            serde_json::to_value(input).map_err(|e| HatchetError::JsonEncode(e.to_string()))?;
232
233        let additional_metadata = options.additional_metadata.clone().map(|v| v.to_string());
234        let desired_worker_id = options.desired_worker_id.clone();
235
236        let response = self
237            .client
238            .workflow_client
239            .trigger_workflow(
240                crate::clients::grpc::v0::workflows::TriggerWorkflowRequest {
241                    name: self.name.clone().to_lowercase(),
242                    input: input_json.to_string(),
243                    parent_id: None,
244                    parent_task_run_external_id: None,
245                    child_index: None,
246                    child_key: None,
247                    additional_metadata,
248                    desired_worker_id,
249                    priority: None,
250                },
251            )
252            .await?;
253
254        Ok(response.workflow_run_id)
255    }
256}
257
258impl<I, O> ExtractRunnableOutput<O> for Task<I, O>
259where
260    I: Serialize + DeserializeOwned + Send + Sync + 'static,
261    O: Serialize + DeserializeOwned + Send + Sync + 'static,
262{
263    fn extract_output(&self, workflow: GetWorkflowRunResponse) -> Result<O, HatchetError> {
264        let task_output = workflow
265            .tasks
266            .iter()
267            .find(|task| task.action_id == Some(format!("{}:{}", &self.name, &self.name)))
268            .and_then(|task| task.output.clone())
269            .ok_or(HatchetError::MissingOutput)?;
270
271        serde_json::from_value(task_output)
272            .map_err(|e| HatchetError::JsonDecodeError(e.to_string()))
273    }
274}
275
276#[async_trait::async_trait]
277impl<I, O> super::Runnable<I, O> for Task<I, O>
278where
279    I: Serialize + DeserializeOwned + Send + Sync + 'static,
280    O: Serialize + DeserializeOwned + Send + Sync + 'static,
281{
282    async fn get_run(&self, run_id: &str) -> Result<GetWorkflowRunResponse, HatchetError> {
283        self.client.workflow_rest_client.get(run_id).await
284    }
285    async fn run_no_wait(
286        &self,
287        input: &I,
288        options: Option<&TriggerWorkflowOptions>,
289    ) -> Result<String, HatchetError> {
290        Ok(self
291            .trigger(input, options.unwrap_or(&TriggerWorkflowOptions::default()))
292            .await?)
293    }
294}
295
296#[derive(Clone)]
297struct TypeErasedTask {
298    name: String,
299    handler: Arc<dyn Fn(serde_json::Value, Context) -> TaskResult + Send + Sync>,
300}
301
302impl ExecutableTask for TypeErasedTask {
303    fn execute(&self, input: serde_json::Value, ctx: Context) -> TaskResult {
304        (self.handler)(input, ctx)
305    }
306
307    fn name(&self) -> &str {
308        &self.name
309    }
310}