1use derive_builder::Builder;
2use serde::Serialize;
3use serde::de::DeserializeOwned;
4
5use super::{ExecutableTask, ExtractRunnableOutput, Task, TriggerWorkflowOptions};
6use crate::clients::grpc::v1::workflows::{
7 CreateTaskOpts, CreateWorkflowVersionRequest, DefaultFilter as DefaultFilterProto,
8};
9use crate::clients::rest::features::crons::{CreateCronOpts, CronOptions, CronTrigger};
10use crate::clients::rest::features::schedules::{
11 CreateScheduleOpts, ScheduleOptions, ScheduledRun,
12};
13use crate::{GetWorkflowRunResponse, Hatchet, HatchetError};
14
/// A workflow definition with typed input `I` and output `O`.
///
/// Built with the derive_builder-generated `WorkflowBuilder` (owned pattern). `I` and `O` are
/// tracked only at the type level through `_phantom`; no values of those types are stored.
#[derive(Clone, Builder)]
#[builder(pattern = "owned")]
pub struct Workflow<I, O> {
    /// Workflow name. The request-building methods on this type send it lowercased.
    pub(crate) name: String,
    /// Client used for scheduling, cron creation, triggering and fetching runs.
    client: Hatchet,
    /// Executable counterparts of the tasks added through `add_task`.
    #[builder(default = vec![])]
    pub(crate) executable_tasks: Vec<Box<dyn ExecutableTask>>,
    /// Free-form description copied into the registration request.
    #[builder(default = String::from(""))]
    description: String,
    /// Version string copied into the registration request.
    #[builder(default = String::from(""))]
    version: String,
    /// Priority sent as `default_priority` when registering; the builder defaults it to 1.
    #[builder(default = 1)]
    default_priority: i32,
    /// Task protos added through `add_task`; also used for the duplicate-name check there.
    #[builder(default = vec![])]
    tasks: Vec<CreateTaskOpts>,
    /// Event keys sent as the workflow's `event_triggers`.
    #[builder(default = vec![])]
    on_events: Vec<String>,
    /// Cron trigger strings sent as the workflow's `cron_triggers`.
    #[builder(default = vec![])]
    cron_triggers: Vec<String>,
    /// Filters converted with `DefaultFilter::to_proto` at registration.
    #[builder(default = vec![])]
    default_filters: Vec<DefaultFilter>,
    /// Optional JSON schema for the input; serialized to JSON bytes at registration.
    #[builder(default)]
    input_json_schema: Option<serde_json::Value>,
    /// Zero-sized marker binding the `I` and `O` type parameters.
    #[builder(default = std::marker::PhantomData)]
    _phantom: std::marker::PhantomData<(I, O)>,
}
43
44impl<I, O> Workflow<I, O>
45where
46 I: Serialize + Send + Sync,
47 O: DeserializeOwned + Send + Sync,
48{
49 pub fn add_task<P>(mut self, task: &Task<I, P>) -> Self
50 where
51 I: serde::de::DeserializeOwned + Send + 'static,
52 P: serde::Serialize + Send + 'static,
53 {
54 if self
55 .tasks
56 .iter()
57 .any(|existing_task| existing_task.readable_id == task.name)
58 {
59 panic!("Duplicate tasks registered to workflow: {}", task.name);
60 }
61
62 self.tasks.push(task.to_task_proto(&self.name));
63 self.executable_tasks.push(task.into_executable());
64 self
65 }
66
67 pub(crate) fn to_proto(&self) -> CreateWorkflowVersionRequest {
68 CreateWorkflowVersionRequest {
69 name: self.name.clone().to_lowercase(),
70 description: self.description.clone(),
71 version: self.version.clone(),
72 event_triggers: self.on_events.clone(),
73 cron_triggers: self.cron_triggers.clone(),
74 tasks: self.tasks.clone(),
75 concurrency: None,
76 cron_input: None,
77 on_failure_task: None,
78 sticky: None,
79 default_priority: Some(self.default_priority),
80 concurrency_arr: vec![],
81 default_filters: self
82 .default_filters
83 .clone()
84 .into_iter()
85 .map(|f| f.to_proto())
86 .collect(),
87 input_json_schema: self
88 .input_json_schema
89 .as_ref()
90 .map(|value| serde_json::to_vec(value).expect("must be serializable")),
91 }
92 }
93
94 pub async fn schedule(
97 &self,
98 trigger_at: chrono::DateTime<chrono::Utc>,
99 input: &I,
100 options: Option<&ScheduleOptions>,
101 ) -> Result<ScheduledRun, HatchetError> {
102 let input_json =
103 serde_json::to_value(input).map_err(|e| HatchetError::JsonEncode(e.to_string()))?;
104 self.client
105 .schedules
106 .create(
107 &self.name.to_lowercase(),
108 CreateScheduleOpts {
109 trigger_at,
110 input: input_json,
111 additional_metadata: options.and_then(|o| o.additional_metadata.clone()),
112 priority: options.and_then(|o| o.priority),
113 },
114 )
115 .await
116 }
117
118 pub async fn cron(
121 &self,
122 name: &str,
123 expression: &str,
124 input: &I,
125 options: Option<&CronOptions>,
126 ) -> Result<CronTrigger, HatchetError> {
127 let input_json =
128 serde_json::to_value(input).map_err(|e| HatchetError::JsonEncode(e.to_string()))?;
129 self.client
130 .crons
131 .create(
132 &self.name.to_lowercase(),
133 CreateCronOpts {
134 name: name.to_string(),
135 expression: expression.to_string(),
136 input: input_json,
137 additional_metadata: options.and_then(|o| o.additional_metadata.clone()),
138 priority: options.and_then(|o| o.priority),
139 },
140 )
141 .await
142 }
143
144 async fn trigger(
145 &self,
146 input: &I,
147 options: &TriggerWorkflowOptions,
148 ) -> Result<String, HatchetError> {
149 let input_json =
150 serde_json::to_value(input).map_err(|e| HatchetError::JsonEncode(e.to_string()))?;
151
152 let additional_metadata = options.additional_metadata.clone().map(|v| v.to_string());
153 let desired_worker_id = options.desired_worker_id.clone();
154
155 let response = self
156 .client
157 .workflow_client
158 .trigger_workflow(
159 crate::clients::grpc::v0::workflows::TriggerWorkflowRequest {
160 name: self.name.clone().to_lowercase(),
161 input: input_json.to_string(),
162 parent_id: None,
163 parent_task_run_external_id: None,
164 child_index: None,
165 child_key: None,
166 additional_metadata,
167 desired_worker_id,
168 priority: None,
169 },
170 )
171 .await?;
172
173 Ok(response.workflow_run_id)
174 }
175
176 fn safely_get_action_name(&self, action_id: &str) -> Option<String> {
177 action_id.split(':').nth(1).map(|s| s.to_string())
178 }
179}
180
181impl<I, O> ExtractRunnableOutput<O> for Workflow<I, O>
182where
183 I: Serialize + Send + Sync + 'static,
184 O: DeserializeOwned + Send + Sync + 'static,
185{
186 fn extract_output(&self, workflow: GetWorkflowRunResponse) -> Result<O, HatchetError> {
187 let mut task_outputs = serde_json::Map::new();
188
189 for task in &workflow.tasks {
190 if let (Some(action_id), Some(output)) = (&task.action_id, &task.output)
191 && let Some(task_name) = self.safely_get_action_name(action_id)
192 {
193 task_outputs.insert(task_name, output.clone());
194 }
195 }
196
197 let output_value = serde_json::Value::Object(task_outputs);
198 serde_json::from_value(output_value)
199 .map_err(|e| HatchetError::JsonDecodeError(e.to_string()))
200 }
201}
202
203#[async_trait::async_trait]
204impl<I, O> super::Runnable<I, O> for Workflow<I, O>
205where
206 I: Serialize + Send + Sync + DeserializeOwned + 'static,
207 O: DeserializeOwned + Send + Sync + 'static,
208{
209 async fn get_run(&self, run_id: &str) -> Result<GetWorkflowRunResponse, HatchetError> {
210 self.client.workflow_rest_client.get(run_id).await
211 }
212
213 async fn run_no_wait(
214 &self,
215 input: &I,
216 options: Option<&TriggerWorkflowOptions>,
217 ) -> Result<String, HatchetError> {
218 Ok(self
219 .trigger(input, options.unwrap_or(&TriggerWorkflowOptions::default()))
220 .await?)
221 }
222}
223
224#[derive(Debug, Default, Clone)]
225pub struct DefaultFilter {
226 pub expression: String,
227 pub scope: String,
228 pub payload: Option<serde_json::Value>,
229}
230
231impl DefaultFilter {
232 pub fn new(expression: String, scope: String, payload: Option<serde_json::Value>) -> Self {
233 Self {
234 expression,
235 scope,
236 payload,
237 }
238 }
239}
240
241impl DefaultFilter {
242 pub fn to_proto(&self) -> DefaultFilterProto {
243 DefaultFilterProto {
244 expression: self.expression.clone(),
245 scope: self.scope.clone(),
246 payload: self.payload.clone().map(|v| v.to_string().into()),
247 }
248 }
249}