Skip to main content

celers_broker_sql/
workflow.rs

//! Workflow types, task hooks, and builder patterns
//!
//! This module contains workflow orchestration types including DAG workflows,
//! stage-based execution, task lifecycle hooks, and builder patterns.
5
6use celers_core::{Result, SerializedTask};
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use uuid::Uuid;
11
12/// Workflow state
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum WorkflowState {
16    Pending,
17    Running,
18    Completed,
19    Failed,
20    Cancelled,
21}
22
23impl std::fmt::Display for WorkflowState {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            WorkflowState::Pending => write!(f, "pending"),
27            WorkflowState::Running => write!(f, "running"),
28            WorkflowState::Completed => write!(f, "completed"),
29            WorkflowState::Failed => write!(f, "failed"),
30            WorkflowState::Cancelled => write!(f, "cancelled"),
31        }
32    }
33}
34
35/// Workflow stage state
36#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
37#[serde(rename_all = "lowercase")]
38pub enum StageState {
39    Pending,
40    Running,
41    Completed,
42    Failed,
43    Skipped,
44}
45
46impl std::fmt::Display for StageState {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        match self {
49            StageState::Pending => write!(f, "pending"),
50            StageState::Running => write!(f, "running"),
51            StageState::Completed => write!(f, "completed"),
52            StageState::Failed => write!(f, "failed"),
53            StageState::Skipped => write!(f, "skipped"),
54        }
55    }
56}
57
58/// Workflow definition
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct Workflow {
61    pub id: Uuid,
62    pub workflow_name: String,
63    pub state: WorkflowState,
64    pub config: serde_json::Value,
65    pub created_at: DateTime<Utc>,
66    pub started_at: Option<DateTime<Utc>>,
67    pub completed_at: Option<DateTime<Utc>>,
68    pub error_message: Option<String>,
69}
70
71/// Workflow stage for parallel execution
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct WorkflowStage {
74    pub id: Uuid,
75    pub workflow_id: Uuid,
76    pub stage_number: i32,
77    pub stage_name: String,
78    pub state: StageState,
79    pub task_count: i32,
80    pub completed_count: i32,
81    pub failed_count: i32,
82    pub started_at: Option<DateTime<Utc>>,
83    pub completed_at: Option<DateTime<Utc>>,
84}
85
86/// Task dependency for DAG execution
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct TaskDependency {
89    pub id: Uuid,
90    pub task_id: Uuid,
91    pub parent_task_id: Uuid,
92    pub workflow_id: Option<Uuid>,
93    pub stage_id: Option<Uuid>,
94    pub satisfied: bool,
95    pub created_at: DateTime<Utc>,
96    pub satisfied_at: Option<DateTime<Utc>>,
97}
98
99/// Workflow statistics
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct WorkflowStatistics {
102    pub workflow_id: Uuid,
103    pub workflow_name: String,
104    pub workflow_state: WorkflowState,
105    pub created_at: DateTime<Utc>,
106    pub started_at: Option<DateTime<Utc>>,
107    pub completed_at: Option<DateTime<Utc>>,
108    pub total_stages: i64,
109    pub completed_stages: i64,
110    pub failed_stages: i64,
111    pub running_stages: i64,
112    pub total_tasks: i64,
113    pub completed_tasks: i64,
114    pub failed_tasks: i64,
115    pub duration_secs: Option<i64>,
116}
117
118/// Workflow builder for creating complex DAG workflows
119#[derive(Debug, Clone)]
120pub struct WorkflowBuilder {
121    workflow_name: String,
122    stages: Vec<WorkflowStageBuilder>,
123}
124
125/// Workflow stage builder
126#[derive(Debug, Clone)]
127pub struct WorkflowStageBuilder {
128    stage_name: String,
129    tasks: Vec<SerializedTask>,
130    dependencies: Vec<String>, // Stage names this stage depends on
131}
132
133impl WorkflowBuilder {
134    /// Create a new workflow builder
135    pub fn new(workflow_name: String) -> Self {
136        Self {
137            workflow_name,
138            stages: Vec::new(),
139        }
140    }
141
142    /// Add a new stage to the workflow
143    pub fn add_stage(mut self, stage_name: String) -> Self {
144        self.stages.push(WorkflowStageBuilder {
145            stage_name,
146            tasks: Vec::new(),
147            dependencies: Vec::new(),
148        });
149        self
150    }
151
152    /// Add a task to the current stage
153    pub fn add_task_to_stage(mut self, task: SerializedTask) -> Self {
154        if let Some(stage) = self.stages.last_mut() {
155            stage.tasks.push(task);
156        }
157        self
158    }
159
160    /// Add dependencies to the current stage
161    pub fn add_stage_dependencies(mut self, dependencies: Vec<String>) -> Self {
162        if let Some(stage) = self.stages.last_mut() {
163            stage.dependencies = dependencies;
164        }
165        self
166    }
167
168    /// Get the workflow name
169    pub fn workflow_name(&self) -> &str {
170        &self.workflow_name
171    }
172
173    /// Get the stages
174    pub fn stages(&self) -> &[WorkflowStageBuilder] {
175        &self.stages
176    }
177}
178
179impl WorkflowStageBuilder {
180    /// Get the stage name
181    pub fn stage_name(&self) -> &str {
182        &self.stage_name
183    }
184
185    /// Get the tasks in this stage
186    pub fn tasks(&self) -> &[SerializedTask] {
187        &self.tasks
188    }
189
190    /// Get the dependencies
191    pub fn dependencies(&self) -> &[String] {
192        &self.dependencies
193    }
194}
195
196/// Type alias for async lifecycle hook functions
197///
198/// Hooks are async functions that take a hook context and serialized task,
199/// and return a Result. They can be used to inject custom logic at various
200/// points in the task lifecycle.
201pub type HookFn = Arc<
202    dyn Fn(
203            &HookContext,
204            &SerializedTask,
205        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
206        + Send
207        + Sync,
208>;
209
210/// Context passed to lifecycle hooks
211#[derive(Debug, Clone)]
212pub struct HookContext {
213    /// Queue name
214    pub queue_name: String,
215    /// Task ID (if available)
216    pub task_id: Option<Uuid>,
217    /// Current timestamp
218    pub timestamp: DateTime<Utc>,
219    /// Additional metadata
220    pub metadata: serde_json::Value,
221}
222
223/// Task lifecycle hook enum
224#[derive(Clone)]
225pub enum TaskHook {
226    /// Called before a task is enqueued
227    BeforeEnqueue(HookFn),
228    /// Called after a task is successfully enqueued
229    AfterEnqueue(HookFn),
230    /// Called before a task is dequeued (reserved for future use)
231    BeforeDequeue(HookFn),
232    /// Called after a task is dequeued
233    AfterDequeue(HookFn),
234    /// Called before a task is acknowledged
235    BeforeAck(HookFn),
236    /// Called after a task is acknowledged
237    AfterAck(HookFn),
238    /// Called before a task is rejected
239    BeforeReject(HookFn),
240    /// Called after a task is rejected
241    AfterReject(HookFn),
242}
243
244/// Container for all registered hooks
245#[derive(Clone, Default)]
246pub struct TaskHooks {
247    pub(crate) before_enqueue: Vec<HookFn>,
248    pub(crate) after_enqueue: Vec<HookFn>,
249    pub(crate) before_dequeue: Vec<HookFn>,
250    pub(crate) after_dequeue: Vec<HookFn>,
251    pub(crate) before_ack: Vec<HookFn>,
252    pub(crate) after_ack: Vec<HookFn>,
253    pub(crate) before_reject: Vec<HookFn>,
254    pub(crate) after_reject: Vec<HookFn>,
255}
256
257impl TaskHooks {
258    /// Create empty hooks container
259    pub fn new() -> Self {
260        Self::default()
261    }
262
263    /// Add a hook
264    pub fn add(&mut self, hook: TaskHook) {
265        match hook {
266            TaskHook::BeforeEnqueue(f) => self.before_enqueue.push(f),
267            TaskHook::AfterEnqueue(f) => self.after_enqueue.push(f),
268            TaskHook::BeforeDequeue(f) => self.before_dequeue.push(f),
269            TaskHook::AfterDequeue(f) => self.after_dequeue.push(f),
270            TaskHook::BeforeAck(f) => self.before_ack.push(f),
271            TaskHook::AfterAck(f) => self.after_ack.push(f),
272            TaskHook::BeforeReject(f) => self.before_reject.push(f),
273            TaskHook::AfterReject(f) => self.after_reject.push(f),
274        }
275    }
276
277    /// Clear all hooks
278    pub fn clear(&mut self) {
279        self.before_enqueue.clear();
280        self.after_enqueue.clear();
281        self.before_dequeue.clear();
282        self.after_dequeue.clear();
283        self.before_ack.clear();
284        self.after_ack.clear();
285        self.before_reject.clear();
286        self.after_reject.clear();
287    }
288
289    /// Execute before_enqueue hooks
290    pub(crate) async fn run_before_enqueue(
291        &self,
292        ctx: &HookContext,
293        task: &SerializedTask,
294    ) -> Result<()> {
295        for hook in &self.before_enqueue {
296            hook(ctx, task).await?;
297        }
298        Ok(())
299    }
300
301    /// Execute after_enqueue hooks
302    pub(crate) async fn run_after_enqueue(
303        &self,
304        ctx: &HookContext,
305        task: &SerializedTask,
306    ) -> Result<()> {
307        for hook in &self.after_enqueue {
308            hook(ctx, task).await?;
309        }
310        Ok(())
311    }
312
313    /// Execute before_ack hooks
314    #[allow(dead_code)]
315    pub(crate) async fn run_before_ack(
316        &self,
317        ctx: &HookContext,
318        task: &SerializedTask,
319    ) -> Result<()> {
320        for hook in &self.before_ack {
321            hook(ctx, task).await?;
322        }
323        Ok(())
324    }
325
326    /// Execute after_ack hooks
327    #[allow(dead_code)]
328    pub(crate) async fn run_after_ack(
329        &self,
330        ctx: &HookContext,
331        task: &SerializedTask,
332    ) -> Result<()> {
333        for hook in &self.after_ack {
334            hook(ctx, task).await?;
335        }
336        Ok(())
337    }
338
339    /// Execute before_reject hooks
340    #[allow(dead_code)]
341    pub(crate) async fn run_before_reject(
342        &self,
343        ctx: &HookContext,
344        task: &SerializedTask,
345    ) -> Result<()> {
346        for hook in &self.before_reject {
347            hook(ctx, task).await?;
348        }
349        Ok(())
350    }
351
352    /// Execute after_reject hooks
353    #[allow(dead_code)]
354    pub(crate) async fn run_after_reject(
355        &self,
356        ctx: &HookContext,
357        task: &SerializedTask,
358    ) -> Result<()> {
359        for hook in &self.after_reject {
360            hook(ctx, task).await?;
361        }
362        Ok(())
363    }
364
365    /// Execute after_dequeue hooks
366    #[allow(dead_code)]
367    pub(crate) async fn run_after_dequeue(
368        &self,
369        ctx: &HookContext,
370        task: &SerializedTask,
371    ) -> Result<()> {
372        for hook in &self.after_dequeue {
373            hook(ctx, task).await?;
374        }
375        Ok(())
376    }
377}