1use std::collections::{HashMap, HashSet};
2
3use log::{debug, warn};
4#[cfg(feature = "wasm")]
5use serde::Serialize;
6#[cfg(feature = "wasm")]
7use serde_wasm_bindgen::{from_value, to_value};
8use uuid::Uuid;
9#[cfg(feature = "wasm")]
10use wasm_bindgen::prelude::*;
11
12use butterflow_models::node::NodeType;
13use butterflow_models::trigger::TriggerType;
14use butterflow_models::{Error, Result, Strategy, StrategyType, Task, TaskStatus, WorkflowRun};
15
// TypeScript declarations emitted into the generated bindings (wasm builds only).
// They describe the JS-facing shapes of `MatrixTaskChanges` and
// `RunnableTaskChanges`, plus the aliases (`Uuid`, `Task`, `WorkflowRun`,
// `State`) that the `unchecked_param_type` / `unchecked_return_type`
// annotations on the exported `Scheduler` methods refer to.
// Keep the interface field names in sync with the Rust structs of the same name.
#[cfg(feature = "wasm")]
#[wasm_bindgen(typescript_custom_section)]
const MATRIX_TASK_CHANGES: &'static str = r#"
type Uuid = string;

type Task = import("../types").Task;
type WorkflowRun = import("../types").WorkflowRun;
type State = Record<string, unknown>;

interface MatrixTaskChanges {
    new_tasks: Task[];
    tasks_to_mark_wont_do: Uuid[];
    master_tasks_to_update: Uuid[];
}

interface RunnableTaskChanges {
    tasks_to_await_trigger: Uuid[];
    runnable_tasks: Uuid[];
}
"#;
36
/// Changes produced by `Scheduler::calculate_matrix_task_changes` after
/// re-evaluating every matrix node whose values come from workflow state.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct MatrixTaskChanges {
    /// Tasks to create: a master task for a matrix node that had none yet, and
    /// one child task per state item that has no child task.
    pub new_tasks: Vec<Task>,
    /// IDs of existing child tasks whose item is no longer present in state and
    /// whose status is not `Completed`, `Failed` or `WontDo`.
    pub tasks_to_mark_wont_do: Vec<Uuid>,
    /// IDs of the master tasks of every evaluated matrix node (each listed once),
    /// including master tasks created in `new_tasks`.
    pub master_tasks_to_update: Vec<Uuid>,
}
44
/// Result of `Scheduler::find_runnable_tasks`: what to do with the pending,
/// non-master tasks of a run.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct RunnableTaskChanges {
    /// Pending tasks whose node is `Manual` or has a `Manual` trigger; they wait
    /// for a trigger regardless of their dependencies.
    pub tasks_to_await_trigger: Vec<Uuid>,
    /// Pending, non-manual tasks whose dependency nodes have at least one task
    /// and all of those tasks `Completed`.
    pub runnable_tasks: Vec<Uuid>,
}
51
/// Stateless task scheduler.
///
/// It has no fields; every decision is computed from the `WorkflowRun`, tasks
/// and state passed to its methods. With the `wasm` feature enabled the type is
/// also exported to JavaScript through `wasm_bindgen`.
#[cfg_attr(feature = "wasm", wasm_bindgen)]
pub struct Scheduler {}
58
59#[cfg(not(feature = "wasm"))]
60impl Default for Scheduler {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66#[cfg(not(feature = "wasm"))]
67impl Scheduler {
68 pub fn new() -> Self {
69 Self {}
70 }
71
72 pub async fn calculate_initial_tasks(&self, workflow_run: &WorkflowRun) -> Result<Vec<Task>> {
74 self.calculate_initial_tasks_internal(workflow_run).await
75 }
76
77 pub async fn calculate_matrix_task_changes(
79 &self,
80 workflow_run_id: Uuid,
81 workflow_run: &WorkflowRun,
82 tasks: &[Task],
83 state: &HashMap<String, serde_json::Value>,
84 ) -> Result<MatrixTaskChanges> {
85 self.calculate_matrix_task_changes_internal(workflow_run_id, workflow_run, tasks, state)
86 .await
87 }
88
89 pub async fn find_runnable_tasks(
91 &self,
92 workflow_run: &WorkflowRun,
93 tasks: &[Task],
94 ) -> Result<RunnableTaskChanges> {
95 self.find_runnable_tasks_internal(workflow_run, tasks).await
96 }
97}
98
/// JavaScript-facing API (wasm builds). Arguments arrive as `JsValue`s and are
/// deserialized with `serde_wasm_bindgen`. Every successful result goes through
/// the same JSON-compatible serializer (`serialize_for_js`), so all methods
/// return identically shaped plain objects. Failures reject with a string.
#[cfg(feature = "wasm")]
#[wasm_bindgen]
impl Scheduler {
    /// Creates a scheduler; exported to JS as the class constructor.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self {}
    }

    /// JS entry point for computing the initial tasks of a workflow run.
    ///
    /// # Errors
    ///
    /// Rejects with a string message when `workflow_run_js` does not deserialize
    /// into a `WorkflowRun`, when the calculation fails, or when the resulting
    /// tasks cannot be converted back to JS.
    #[wasm_bindgen(js_name = calculateInitialTasks, unchecked_return_type = "Task[]")]
    pub async fn calculate_initial_tasks_wasm(
        &self,
        #[wasm_bindgen(unchecked_param_type = "WorkflowRun")] workflow_run_js: JsValue,
    ) -> std::result::Result<JsValue, JsValue> {
        let workflow_run: WorkflowRun = from_value(workflow_run_js)
            .map_err(|e| JsValue::from_str(&format!("Failed to deserialize WorkflowRun: {}", e)))?;

        let tasks = self
            .calculate_initial_tasks_internal(&workflow_run)
            .await
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        serialize_for_js(&tasks, "tasks")
    }

    /// JS entry point for reconciling matrix child tasks with workflow state.
    ///
    /// # Errors
    ///
    /// Rejects with a string message when any argument fails to deserialize,
    /// when `workflow_run_id_js` is not a valid UUID string, when the calculation
    /// fails, or when the result cannot be converted back to JS.
    #[wasm_bindgen(js_name = calculateMatrixTaskChanges, unchecked_return_type = "MatrixTaskChanges")]
    pub async fn calculate_matrix_task_changes_wasm(
        &self,
        #[wasm_bindgen(unchecked_param_type = "Uuid")] workflow_run_id_js: JsValue,
        #[wasm_bindgen(unchecked_param_type = "WorkflowRun")] workflow_run_js: JsValue,
        #[wasm_bindgen(unchecked_param_type = "Task[]")] tasks_js: JsValue,
        #[wasm_bindgen(unchecked_param_type = "State")] state_js: JsValue,
    ) -> std::result::Result<JsValue, JsValue> {
        let workflow_run_id_str: String = from_value(workflow_run_id_js).map_err(|e| {
            JsValue::from_str(&format!("Failed to deserialize workflow_run_id: {}", e))
        })?;
        let workflow_run_id = Uuid::parse_str(&workflow_run_id_str).map_err(|e| {
            JsValue::from_str(&format!("Invalid UUID format for workflow_run_id: {}", e))
        })?;

        let workflow_run: WorkflowRun = from_value(workflow_run_js)
            .map_err(|e| JsValue::from_str(&format!("Failed to deserialize WorkflowRun: {}", e)))?;
        let tasks: Vec<Task> = from_value(tasks_js)
            .map_err(|e| JsValue::from_str(&format!("Failed to deserialize tasks: {}", e)))?;
        let state: HashMap<String, serde_json::Value> = from_value(state_js)
            .map_err(|e| JsValue::from_str(&format!("Failed to deserialize state: {}", e)))?;

        let changes = self
            .calculate_matrix_task_changes_internal(workflow_run_id, &workflow_run, &tasks, &state)
            .await
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // `new_tasks` contains full `Task`s, so this must use the same serializer
        // as `calculateInitialTasks`; the default `to_value` would emit
        // `matrix_values` as a JS `Map` and missing fields as `undefined`.
        serialize_for_js(&changes, "MatrixTaskChanges")
    }

    /// JS entry point for finding tasks that can run or must await a trigger.
    ///
    /// # Errors
    ///
    /// Rejects with a string message when an argument fails to deserialize, when
    /// the calculation fails (e.g. a task references an unknown node), or when
    /// the result cannot be converted back to JS.
    #[wasm_bindgen(js_name = findRunnableTasks, unchecked_return_type = "RunnableTaskChanges")]
    pub async fn find_runnable_tasks_wasm(
        &self,
        #[wasm_bindgen(unchecked_param_type = "WorkflowRun")] workflow_run_js: JsValue,
        #[wasm_bindgen(unchecked_param_type = "Task[]")] tasks_js: JsValue,
    ) -> std::result::Result<JsValue, JsValue> {
        let workflow_run: WorkflowRun = from_value(workflow_run_js)
            .map_err(|e| JsValue::from_str(&format!("Failed to deserialize WorkflowRun: {}", e)))?;
        let tasks: Vec<Task> = from_value(tasks_js)
            .map_err(|e| JsValue::from_str(&format!("Failed to deserialize tasks: {}", e)))?;

        let changes = self
            .find_runnable_tasks_internal(&workflow_run, &tasks)
            .await
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        serialize_for_js(&changes, "RunnableTaskChanges")
    }
}

/// Converts `value` to a plain JSON-shaped `JsValue`.
///
/// Maps become plain objects and missing values become `null`, so every
/// exported `Scheduler` method returns the same representation. `what` names
/// the value in the error message ("Failed to serialize {what}: {error}").
#[cfg(feature = "wasm")]
fn serialize_for_js<T: Serialize + ?Sized>(
    value: &T,
    what: &str,
) -> std::result::Result<JsValue, JsValue> {
    let serializer = serde_wasm_bindgen::Serializer::json_compatible()
        .serialize_maps_as_objects(true)
        .serialize_missing_as_null(true);

    value
        .serialize(&serializer)
        .map_err(|e| JsValue::from_str(&format!("Failed to serialize {}: {}", what, e)))
}
191
192impl Scheduler {
194 async fn calculate_initial_tasks_internal(
195 &self,
196 workflow_run: &WorkflowRun,
197 ) -> Result<Vec<Task>> {
198 let mut tasks = Vec::new();
199
200 for node in &workflow_run.workflow.nodes {
201 if let Some(Strategy {
203 r#type: StrategyType::Matrix,
204 values,
205 from_state: _, }) = &node.strategy
207 {
208 let master_task = Task::new(workflow_run.id, node.id.clone(), true);
210 tasks.push(master_task.clone());
211
212 if let Some(values) = values {
214 for value in values {
215 let task = Task::new_matrix(
217 workflow_run.id,
218 node.id.clone(),
219 master_task.id,
220 value.clone(),
221 );
222 tasks.push(task);
223 }
224 }
225 } else {
227 let task = Task::new(workflow_run.id, node.id.clone(), false);
229 tasks.push(task);
230 }
231 }
232
233 Ok(tasks)
234 }
235
236 async fn calculate_matrix_task_changes_internal(
238 &self,
239 workflow_run_id: Uuid,
240 workflow_run: &WorkflowRun,
241 tasks: &[Task],
242 state: &HashMap<String, serde_json::Value>,
243 ) -> Result<MatrixTaskChanges> {
244 let mut new_tasks = Vec::new();
245 let mut tasks_to_mark_wont_do = Vec::new();
246 let mut master_tasks_to_update = Vec::new();
247
248 for node in &workflow_run.workflow.nodes {
249 if let Some(Strategy {
250 r#type: StrategyType::Matrix,
251 from_state: Some(state_key), .. }) = &node.strategy
254 {
255 debug!(
256 "Calculating changes for matrix node '{}' using state key '{}'",
257 node.id, state_key
258 );
259
260 let master_task_id =
262 match tasks.iter().find(|t| t.node_id == node.id && t.is_master) {
263 Some(master) => master.id,
264 None => {
265 let new_master_task = Task::new(workflow_run_id, node.id.clone(), true);
267 new_tasks.push(new_master_task.clone());
268 master_tasks_to_update.push(new_master_task.id);
269 new_master_task.id
270 }
271 };
272
273 if !master_tasks_to_update.contains(&master_task_id) {
275 master_tasks_to_update.push(master_task_id);
276 }
277
278 let state_value = state.get(state_key);
280
281 let mut current_item_values = Vec::new();
283
284 match state_value {
285 Some(serde_json::Value::Array(items)) => {
286 for item in items {
287 current_item_values.push(item.clone());
288 }
289 debug!("Found {} items in state array '{}'", items.len(), state_key);
290 }
291 Some(serde_json::Value::Object(_obj)) => {
292 warn!("Matrix from_state for object key '{}' is not yet fully supported, skipping.", state_key);
294 continue; }
296 _ => {
297 debug!("State key '{}' for matrix node '{}' is missing or not an array/object.", state_key, node.id);
299 }
300 }
301
302 let existing_child_tasks_by_value: HashMap<serde_json::Value, &Task> = tasks
305 .iter()
306 .filter(|t| {
307 t.master_task_id == Some(master_task_id) && t.matrix_values.is_some()
308 })
309 .filter_map(|t| {
310 serde_json::to_value(t.matrix_values.as_ref().unwrap())
311 .ok()
312 .map(|v| (v, t))
313 })
314 .collect();
315
316 let existing_child_values: HashSet<serde_json::Value> =
317 existing_child_tasks_by_value.keys().cloned().collect();
318
319 debug!(
320 "Found {} existing child tasks for node '{}'",
321 existing_child_tasks_by_value.len(),
322 node.id
323 );
324
325 let current_item_values_set: HashSet<_> =
327 current_item_values.iter().cloned().collect();
328
329 for item_value in current_item_values {
330 if !existing_child_values.contains(&item_value) {
331 let matrix_data = match item_value.as_object() {
333 Some(obj) => obj
334 .iter()
335 .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), serde_json::Value::String(s.to_string()))))
336 .collect::<HashMap<_, _>>(),
337 None => {
338 warn!(
339 "Matrix item for node '{}' is not a JSON object, skipping: {:?}",
340 node.id,
341 item_value
342 );
343 continue; }
345 };
346
347 let new_task = Task::new_matrix(
348 workflow_run_id,
349 node.id.clone(),
350 master_task_id,
351 matrix_data,
352 );
353 debug!(
354 "Need to create new task for node '{}', value: {:?}",
355 node.id, item_value
356 );
357 new_tasks.push(new_task);
358 }
359 }
360
361 for (task_value, task) in &existing_child_tasks_by_value {
363 if !current_item_values_set.contains(task_value) {
364 if !matches!(
367 task.status,
368 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::WontDo
369 ) {
370 debug!(
371 "Need to mark task {} (value {:?}) for node '{}' as WontDo",
372 task.id, task_value, node.id
373 );
374 tasks_to_mark_wont_do.push(task.id);
375 }
376 }
377 }
378 }
379 }
380
381 Ok(MatrixTaskChanges {
382 new_tasks,
383 tasks_to_mark_wont_do,
384 master_tasks_to_update,
385 })
386 }
387
388 async fn find_runnable_tasks_internal(
390 &self,
391 workflow_run: &WorkflowRun,
392 tasks: &[Task],
393 ) -> Result<RunnableTaskChanges> {
394 let mut runnable_tasks = Vec::new();
395 let mut tasks_to_await_trigger = Vec::new();
396
397 for task in tasks {
398 if task.status != TaskStatus::Pending || task.is_master {
400 continue;
401 }
402
403 let node = workflow_run
405 .workflow
406 .nodes
407 .iter()
408 .find(|n| n.id == task.node_id)
409 .ok_or_else(|| Error::NodeNotFound(task.node_id.clone()))?;
410
411 if node.r#type == NodeType::Manual
413 || node
414 .trigger
415 .as_ref()
416 .map(|t| t.r#type == TriggerType::Manual)
417 .unwrap_or(false)
418 {
419 tasks_to_await_trigger.push(task.id);
420 continue;
421 }
422
423 let mut dependencies_satisfied = true;
425 for dep_id in &node.depends_on {
426 let dep_tasks: Vec<&Task> = tasks.iter().filter(|t| t.node_id == *dep_id).collect();
428
429 if dep_tasks.is_empty() {
431 dependencies_satisfied = false;
432 break;
433 }
434
435 let all_completed = dep_tasks.iter().all(|t| t.status == TaskStatus::Completed);
437
438 if !all_completed {
439 dependencies_satisfied = false;
440 break;
441 }
442 }
443
444 if dependencies_satisfied {
445 runnable_tasks.push(task.id);
446 }
447 }
448
449 Ok(RunnableTaskChanges {
450 tasks_to_await_trigger,
451 runnable_tasks,
452 })
453 }
454}