Skip to main content

butterflow_scheduler/
lib.rs

1use std::collections::{HashMap, HashSet};
2
3use log::{debug, warn};
4#[cfg(feature = "wasm")]
5use serde::Serialize;
6#[cfg(feature = "wasm")]
7use serde_wasm_bindgen::{from_value, to_value};
8use uuid::Uuid;
9#[cfg(feature = "wasm")]
10use wasm_bindgen::prelude::*;
11
12use butterflow_models::node::NodeType;
13use butterflow_models::trigger::TriggerType;
14use butterflow_models::{Error, Result, Strategy, StrategyType, Task, TaskStatus, WorkflowRun};
15
// TypeScript declarations emitted into the generated bindings (wasm builds only).
// They give concrete meaning to the `unchecked_param_type` / `unchecked_return_type`
// names used by the WASM entry points (`Uuid`, `Task`, `WorkflowRun`, `State`,
// `MatrixTaskChanges`, `RunnableTaskChanges`); the two interfaces mirror the Rust
// structs of the same names field for field.
#[cfg(feature = "wasm")]
#[wasm_bindgen(typescript_custom_section)]
const MATRIX_TASK_CHANGES: &'static str = r#"
type Uuid = string;

type Task = import("../types").Task;
type WorkflowRun = import("../types").WorkflowRun;
type State = Record<string, unknown>;

interface MatrixTaskChanges {
    new_tasks: Task[];
    tasks_to_mark_wont_do: Uuid[];
    master_tasks_to_update: Uuid[];
}

interface RunnableTaskChanges {
    tasks_to_await_trigger: Uuid[];
    runnable_tasks: Uuid[];
}
"#;
36
37/// Struct to hold the result of matrix task recompilation calculations
38#[derive(serde::Serialize, serde::Deserialize)]
39pub struct MatrixTaskChanges {
40    pub new_tasks: Vec<Task>,
41    pub tasks_to_mark_wont_do: Vec<Uuid>,
42    pub master_tasks_to_update: Vec<Uuid>,
43}
44
45/// Struct to hold the result of finding runnable tasks
46#[derive(serde::Serialize, serde::Deserialize)]
47pub struct RunnableTaskChanges {
48    pub tasks_to_await_trigger: Vec<Uuid>,
49    pub runnable_tasks: Vec<Uuid>,
50}
51
/// Task scheduler for workflow runs.
///
/// The struct has no fields; its methods derive their results purely from the
/// arguments passed in (plus logging). With the `wasm` feature enabled the type
/// is additionally exported to JavaScript through `wasm_bindgen`.
#[cfg_attr(feature = "wasm", wasm_bindgen)]
pub struct Scheduler {}
58
59#[cfg(not(feature = "wasm"))]
60impl Default for Scheduler {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66#[cfg(not(feature = "wasm"))]
67impl Scheduler {
68    pub fn new() -> Self {
69        Self {}
70    }
71
72    /// Calculate initial tasks for all nodes in a workflow
73    pub async fn calculate_initial_tasks(&self, workflow_run: &WorkflowRun) -> Result<Vec<Task>> {
74        self.calculate_initial_tasks_internal(workflow_run).await
75    }
76
77    /// Calculate changes needed for matrix tasks based on current state
78    pub async fn calculate_matrix_task_changes(
79        &self,
80        workflow_run_id: Uuid,
81        workflow_run: &WorkflowRun,
82        tasks: &[Task],
83        state: &HashMap<String, serde_json::Value>,
84    ) -> Result<MatrixTaskChanges> {
85        self.calculate_matrix_task_changes_internal(workflow_run_id, workflow_run, tasks, state)
86            .await
87    }
88
89    /// Find tasks that can be executed
90    pub async fn find_runnable_tasks(
91        &self,
92        workflow_run: &WorkflowRun,
93        tasks: &[Task],
94    ) -> Result<RunnableTaskChanges> {
95        self.find_runnable_tasks_internal(workflow_run, tasks).await
96    }
97}
98
/// Deserializes an argument passed from JavaScript to a WASM entry point.
///
/// `what` names the argument in the error message
/// (`"Failed to deserialize {what}: {error}"`).
#[cfg(feature = "wasm")]
fn deserialize_from_js<T: serde::de::DeserializeOwned>(
    value: JsValue,
    what: &str,
) -> std::result::Result<T, JsValue> {
    from_value(value)
        .map_err(|e| JsValue::from_str(&format!("Failed to deserialize {}: {}", what, e)))
}

/// Serializes the result of a WASM entry point for JavaScript.
///
/// Every entry point goes through this one JSON-compatible serializer so that
/// maps (such as the matrix data carried by tasks) arrive as plain JS objects
/// and missing values as `null`. The default `to_value` settings would instead
/// produce ES2015 `Map`s, making the output shape depend on which entry point
/// returned it.
///
/// `what` names the value in the error message
/// (`"Failed to serialize {what}: {error}"`).
#[cfg(feature = "wasm")]
fn serialize_for_js<T: Serialize + ?Sized>(
    value: &T,
    what: &str,
) -> std::result::Result<JsValue, JsValue> {
    let serializer = serde_wasm_bindgen::Serializer::json_compatible()
        .serialize_maps_as_objects(true)
        .serialize_missing_as_null(true);
    value
        .serialize(&serializer)
        .map_err(|e| JsValue::from_str(&format!("Failed to serialize {}: {}", what, e)))
}

#[cfg(feature = "wasm")]
#[wasm_bindgen]
impl Scheduler {
    /// Creates a scheduler; exposed to JavaScript as the class constructor.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self {}
    }

    /// Calculate initial tasks for a workflow run (WASM API).
    ///
    /// Resolves to a `Task[]`; rejects with a string message when an argument
    /// cannot be deserialized, the scheduler fails, or the result cannot be
    /// serialized.
    #[wasm_bindgen(js_name = calculateInitialTasks, unchecked_return_type = "Task[]")]
    pub async fn calculate_initial_tasks_wasm(
        &self,
        #[wasm_bindgen(unchecked_param_type = "WorkflowRun")] workflow_run_js: JsValue,
    ) -> std::result::Result<JsValue, JsValue> {
        let workflow_run: WorkflowRun = deserialize_from_js(workflow_run_js, "WorkflowRun")?;

        let tasks = self
            .calculate_initial_tasks_internal(&workflow_run)
            .await
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        serialize_for_js(&tasks, "tasks")
    }

    /// Calculate changes needed for matrix tasks based on current state (WASM API).
    ///
    /// `workflow_run_id_js` must be a UUID string and `state_js` a plain JSON
    /// object. Resolves to a `MatrixTaskChanges` object; rejects with a string
    /// message on any deserialization, scheduling or serialization error.
    #[wasm_bindgen(js_name = calculateMatrixTaskChanges, unchecked_return_type = "MatrixTaskChanges")]
    pub async fn calculate_matrix_task_changes_wasm(
        &self,
        #[wasm_bindgen(unchecked_param_type = "Uuid")] workflow_run_id_js: JsValue,
        #[wasm_bindgen(unchecked_param_type = "WorkflowRun")] workflow_run_js: JsValue,
        #[wasm_bindgen(unchecked_param_type = "Task[]")] tasks_js: JsValue,
        #[wasm_bindgen(unchecked_param_type = "State")] state_js: JsValue,
    ) -> std::result::Result<JsValue, JsValue> {
        let workflow_run_id_str: String =
            deserialize_from_js(workflow_run_id_js, "workflow_run_id")?;
        let workflow_run_id = Uuid::parse_str(&workflow_run_id_str).map_err(|e| {
            JsValue::from_str(&format!("Invalid UUID format for workflow_run_id: {}", e))
        })?;

        let workflow_run: WorkflowRun = deserialize_from_js(workflow_run_js, "WorkflowRun")?;
        let tasks: Vec<Task> = deserialize_from_js(tasks_js, "tasks")?;
        let state: HashMap<String, serde_json::Value> = deserialize_from_js(state_js, "state")?;

        let changes = self
            .calculate_matrix_task_changes_internal(workflow_run_id, &workflow_run, &tasks, &state)
            .await
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        serialize_for_js(&changes, "MatrixTaskChanges")
    }

    /// Find tasks that can be executed (WASM API).
    ///
    /// Resolves to a `RunnableTaskChanges` object; rejects with a string message
    /// on any deserialization, scheduling or serialization error.
    #[wasm_bindgen(js_name = findRunnableTasks, unchecked_return_type = "RunnableTaskChanges")]
    pub async fn find_runnable_tasks_wasm(
        &self,
        #[wasm_bindgen(unchecked_param_type = "WorkflowRun")] workflow_run_js: JsValue,
        #[wasm_bindgen(unchecked_param_type = "Task[]")] tasks_js: JsValue,
    ) -> std::result::Result<JsValue, JsValue> {
        let workflow_run: WorkflowRun = deserialize_from_js(workflow_run_js, "WorkflowRun")?;
        let tasks: Vec<Task> = deserialize_from_js(tasks_js, "tasks")?;

        let changes = self
            .find_runnable_tasks_internal(&workflow_run, &tasks)
            .await
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        serialize_for_js(&changes, "RunnableTaskChanges")
    }
}
191
192// Internal implementation shared by both Rust and WASM APIs
193impl Scheduler {
194    async fn calculate_initial_tasks_internal(
195        &self,
196        workflow_run: &WorkflowRun,
197    ) -> Result<Vec<Task>> {
198        let mut tasks = Vec::new();
199
200        for node in &workflow_run.workflow.nodes {
201            // Check if the node has a matrix strategy
202            if let Some(Strategy {
203                r#type: StrategyType::Matrix,
204                values,
205                from_state: _, // Corrected variable name
206            }) = &node.strategy
207            {
208                // Create a master task for the matrix
209                let master_task = Task::new(workflow_run.id, node.id.clone(), true);
210                tasks.push(master_task.clone());
211
212                // If the matrix uses values, create tasks for each value
213                if let Some(values) = values {
214                    for value in values {
215                        // Create a task for each matrix value
216                        let task = Task::new_matrix(
217                            workflow_run.id,
218                            node.id.clone(),
219                            master_task.id,
220                            value.clone(),
221                        );
222                        tasks.push(task);
223                    }
224                }
225                // If the matrix uses state, tasks will be created during recompilation
226            } else {
227                // Create a single task for the node
228                let task = Task::new(workflow_run.id, node.id.clone(), false);
229                tasks.push(task);
230            }
231        }
232
233        Ok(tasks)
234    }
235
236    /// Calculate changes needed for matrix tasks based on current state
237    async fn calculate_matrix_task_changes_internal(
238        &self,
239        workflow_run_id: Uuid,
240        workflow_run: &WorkflowRun,
241        tasks: &[Task],
242        state: &HashMap<String, serde_json::Value>,
243    ) -> Result<MatrixTaskChanges> {
244        let mut new_tasks = Vec::new();
245        let mut tasks_to_mark_wont_do = Vec::new();
246        let mut master_tasks_to_update = Vec::new();
247
248        for node in &workflow_run.workflow.nodes {
249            if let Some(Strategy {
250                r#type: StrategyType::Matrix,
251                from_state: Some(state_key), // Only process matrix nodes using from_state
252                .. // Use .. to ignore other fields like `values`
253            }) = &node.strategy
254            {
255                debug!(
256                    "Calculating changes for matrix node '{}' using state key '{}'",
257                    node.id, state_key
258                );
259
260                // Find the master task for this node
261                let master_task_id =
262                    match tasks.iter().find(|t| t.node_id == node.id && t.is_master) {
263                        Some(master) => master.id,
264                        None => {
265                            // Master task doesn't exist yet, create it
266                            let new_master_task = Task::new(workflow_run_id, node.id.clone(), true);
267                            new_tasks.push(new_master_task.clone());
268                            master_tasks_to_update.push(new_master_task.id);
269                            new_master_task.id
270                        }
271                    };
272
273                // Add master task to update list if not already there
274                if !master_tasks_to_update.contains(&master_task_id) {
275                    master_tasks_to_update.push(master_task_id);
276                }
277
278                // Get the current value from the state
279                let state_value = state.get(state_key);
280
281                // --- Calculate Values for Current State Items ---
282                let mut current_item_values = Vec::new();
283
284                match state_value {
285                    Some(serde_json::Value::Array(items)) => {
286                        for item in items {
287                            current_item_values.push(item.clone());
288                        }
289                        debug!("Found {} items in state array '{}'", items.len(), state_key);
290                    }
291                    Some(serde_json::Value::Object(_obj)) => {
292                        // Object mapping not fully supported yet
293                        warn!("Matrix from_state for object key '{}' is not yet fully supported, skipping.", state_key);
294                        continue; // Skip this node
295                    }
296                    _ => {
297                        // State key not found or not an array/object
298                        debug!("State key '{}' for matrix node '{}' is missing or not an array/object.", state_key, node.id);
299                    }
300                }
301
302                // --- Compare with Existing Tasks ---
303                // Store existing tasks keyed by their matrix_values for comparison
304                let existing_child_tasks_by_value: HashMap<serde_json::Value, &Task> = tasks
305                    .iter()
306                    .filter(|t| {
307                        t.master_task_id == Some(master_task_id) && t.matrix_values.is_some()
308                    })
309                    .filter_map(|t| {
310                        serde_json::to_value(t.matrix_values.as_ref().unwrap())
311                            .ok()
312                            .map(|v| (v, t))
313                    })
314                    .collect();
315
316                let existing_child_values: HashSet<serde_json::Value> =
317                    existing_child_tasks_by_value.keys().cloned().collect();
318
319                debug!(
320                    "Found {} existing child tasks for node '{}'",
321                    existing_child_tasks_by_value.len(),
322                    node.id
323                );
324
325                // --- Identify Tasks to Create ---
326                let current_item_values_set: HashSet<_> =
327                    current_item_values.iter().cloned().collect();
328
329                for item_value in current_item_values {
330                    if !existing_child_values.contains(&item_value) {
331                        // Task for this value doesn't exist, need to create it
332                        let matrix_data = match item_value.as_object() {
333                            Some(obj) => obj
334                                .iter()
335                                .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), serde_json::Value::String(s.to_string()))))
336                                .collect::<HashMap<_, _>>(),
337                            None => {
338                                warn!(
339                                    "Matrix item for node '{}' is not a JSON object, skipping: {:?}",
340                                    node.id,
341                                    item_value
342                                );
343                                continue; // Skip this item
344                            }
345                        };
346
347                        let new_task = Task::new_matrix(
348                            workflow_run_id,
349                            node.id.clone(),
350                            master_task_id,
351                            matrix_data,
352                        );
353                        debug!(
354                            "Need to create new task for node '{}', value: {:?}",
355                            node.id, item_value
356                        );
357                        new_tasks.push(new_task);
358                    }
359                }
360
361                // --- Identify Tasks to Mark as WontDo ---
362                for (task_value, task) in &existing_child_tasks_by_value {
363                    if !current_item_values_set.contains(task_value) {
364                        // This task's value is no longer in the current state
365                        // Mark as WontDo only if it's not already in a terminal state
366                        if !matches!(
367                            task.status,
368                            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::WontDo
369                        ) {
370                            debug!(
371                                "Need to mark task {} (value {:?}) for node '{}' as WontDo",
372                                task.id, task_value, node.id
373                            );
374                            tasks_to_mark_wont_do.push(task.id);
375                        }
376                    }
377                }
378            }
379        }
380
381        Ok(MatrixTaskChanges {
382            new_tasks,
383            tasks_to_mark_wont_do,
384            master_tasks_to_update,
385        })
386    }
387
388    /// Find tasks that can be executed
389    async fn find_runnable_tasks_internal(
390        &self,
391        workflow_run: &WorkflowRun,
392        tasks: &[Task],
393    ) -> Result<RunnableTaskChanges> {
394        let mut runnable_tasks = Vec::new();
395        let mut tasks_to_await_trigger = Vec::new();
396
397        for task in tasks {
398            // Only consider pending tasks and non-master tasks
399            if task.status != TaskStatus::Pending || task.is_master {
400                continue;
401            }
402
403            // Get the node for this task
404            let node = workflow_run
405                .workflow
406                .nodes
407                .iter()
408                .find(|n| n.id == task.node_id)
409                .ok_or_else(|| Error::NodeNotFound(task.node_id.clone()))?;
410
411            // Check if the node has a manual trigger
412            if node.r#type == NodeType::Manual
413                || node
414                    .trigger
415                    .as_ref()
416                    .map(|t| t.r#type == TriggerType::Manual)
417                    .unwrap_or(false)
418            {
419                tasks_to_await_trigger.push(task.id);
420                continue;
421            }
422
423            // Check if all dependencies are satisfied
424            let mut dependencies_satisfied = true;
425            for dep_id in &node.depends_on {
426                // Find all tasks for this dependency
427                let dep_tasks: Vec<&Task> = tasks.iter().filter(|t| t.node_id == *dep_id).collect();
428
429                // If there are no tasks for this dependency, it's not satisfied
430                if dep_tasks.is_empty() {
431                    dependencies_satisfied = false;
432                    break;
433                }
434
435                // Check if all tasks for this dependency are completed
436                let all_completed = dep_tasks.iter().all(|t| t.status == TaskStatus::Completed);
437
438                if !all_completed {
439                    dependencies_satisfied = false;
440                    break;
441                }
442            }
443
444            if dependencies_satisfied {
445                runnable_tasks.push(task.id);
446            }
447        }
448
449        Ok(RunnableTaskChanges {
450            tasks_to_await_trigger,
451            runnable_tasks,
452        })
453    }
454}