1use std::{
7 collections::HashMap,
8 fmt,
9 sync::Arc,
10 time::{Duration, Instant},
11};
12
13use serde::{Deserialize, Serialize};
14use tokio::sync::{mpsc, RwLock};
15
16pub mod definition;
17pub mod engine;
18pub mod state_store;
19pub mod coordinator;
20pub mod monitor;
21
22pub use definition::*;
23pub use engine::*;
24pub use state_store::*;
25pub use coordinator::*;
26pub use monitor::*;
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
30pub struct WorkflowId(pub [u8; 16]);
31
32impl WorkflowId {
33 pub fn generate() -> Self {
35 let mut id = [0u8; 16];
36 use rand::Rng;
37 rand::thread_rng().fill(&mut id);
38 Self(id)
39 }
40}
41
42impl fmt::Display for WorkflowId {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 write!(f, "{}", hex::encode(&self.0[..8]))
45 }
46}
47
48#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
50pub struct StageId(pub String);
51
52impl fmt::Display for StageId {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 write!(f, "{}", self.0)
55 }
56}
57
58#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
60pub struct Version {
61 pub major: u32,
62 pub minor: u32,
63 pub patch: u32,
64}
65
66impl fmt::Display for Version {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
69 }
70}
71
72#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
74pub enum WorkflowEvent {
75 Start,
77 StageCompleted { stage_id: StageId },
79 StageFailed { stage_id: StageId, error: String },
81 External { event_type: String, data: Vec<u8> },
83 Timeout { stage_id: StageId },
85 Cancel,
87 SystemError { error: String },
89}
90
91#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
93pub enum WorkflowStatus {
94 Initializing,
96 Running { current_stage: StageId },
98 Waiting { stage: StageId, event: String },
100 Paused { stage: StageId },
102 Completed { result: WorkflowResult },
104 Failed { error: WorkflowError },
106 Cancelled,
108}
109
110#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
112pub struct WorkflowResult {
113 pub output: HashMap<String, Vec<u8>>,
115 pub duration: Duration,
117 pub metrics: WorkflowMetrics,
119}
120
121#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
123pub struct WorkflowError {
124 pub code: String,
126 pub message: String,
128 pub stage: Option<StageId>,
130 pub trace: Option<String>,
132 pub recovery_hints: Vec<String>,
134}
135
136impl fmt::Display for WorkflowError {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 write!(f, "[{}] {}", self.code, self.message)?;
139 if let Some(stage) = &self.stage {
140 write!(f, " at stage {}", stage)?;
141 }
142 Ok(())
143 }
144}
145
146impl std::error::Error for WorkflowError {}
147
148#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
150pub struct WorkflowMetrics {
151 pub stages_executed: u32,
153 pub retry_count: u32,
155 pub error_count: u32,
157 pub stage_durations: HashMap<StageId, Duration>,
159 pub custom: HashMap<String, f64>,
161}
162
163#[derive(Debug, Clone)]
165pub struct WorkflowHandle {
166 pub id: WorkflowId,
168 event_tx: mpsc::Sender<WorkflowEvent>,
170 status: Arc<RwLock<WorkflowStatus>>,
172}
173
174impl WorkflowHandle {
175 pub fn new(id: WorkflowId, event_tx: mpsc::Sender<WorkflowEvent>) -> Self {
177 Self {
178 id,
179 event_tx,
180 status: Arc::new(RwLock::new(WorkflowStatus::Initializing)),
181 }
182 }
183
184 pub async fn send_event(&self, event: WorkflowEvent) -> Result<(), WorkflowError> {
186 self.event_tx.send(event).await.map_err(|_| WorkflowError {
187 code: "SEND_FAILED".to_string(),
188 message: "Failed to send event to workflow".to_string(),
189 stage: None,
190 trace: None,
191 recovery_hints: vec!["Workflow may have terminated".to_string()],
192 })
193 }
194
195 pub async fn status(&self) -> WorkflowStatus {
197 self.status.read().await.clone()
198 }
199
200 pub async fn cancel(&self) -> Result<(), WorkflowError> {
202 self.send_event(WorkflowEvent::Cancel).await
203 }
204
205 pub(crate) async fn update_status(&self, status: WorkflowStatus) {
207 *self.status.write().await = status;
208 }
209}
210
211#[derive(Debug)]
213pub struct WorkflowContext {
214 pub workflow_id: WorkflowId,
216 pub current_stage: StageId,
218 pub state: HashMap<String, Vec<u8>>,
220 pub metrics: WorkflowMetrics,
222 pub stage_start: Instant,
224}
225
226impl WorkflowContext {
227 pub fn set_state(&mut self, key: String, value: Vec<u8>) {
229 self.state.insert(key, value);
230 }
231
232 pub fn get_state(&self, key: &str) -> Option<&Vec<u8>> {
234 self.state.get(key)
235 }
236
237 pub fn record_metric(&mut self, name: String, value: f64) {
239 self.metrics.custom.insert(name, value);
240 }
241}
242
243#[async_trait::async_trait]
245pub trait WorkflowAction: Send + Sync {
246 async fn execute(&self, context: &mut WorkflowContext) -> Result<(), WorkflowError>;
248
249 fn name(&self) -> &str;
251}
252
253#[async_trait::async_trait]
255pub trait Condition: Send + Sync {
256 async fn check(&self, context: &WorkflowContext) -> bool;
258
259 fn description(&self) -> &str;
261}
262
263#[derive(Debug, Clone, Serialize, Deserialize)]
265pub struct ErrorHandler {
266 pub max_retries: u32,
268 pub backoff: BackoffStrategy,
270 pub fallback_stage: Option<StageId>,
272 pub propagate: bool,
274}
275
276#[derive(Debug, Clone, Serialize, Deserialize)]
278pub enum BackoffStrategy {
279 Fixed { delay: Duration },
281 Exponential { initial: Duration, max: Duration, factor: f64 },
283 Linear { initial: Duration, increment: Duration },
285}
286
287impl BackoffStrategy {
288 pub fn calculate_delay(&self, attempt: u32) -> Duration {
290 match self {
291 BackoffStrategy::Fixed { delay } => *delay,
292 BackoffStrategy::Exponential { initial, max, factor } => {
293 let delay = initial.as_millis() as f64 * factor.powi(attempt as i32);
294 let delay_ms = delay.min(max.as_millis() as f64) as u64;
295 Duration::from_millis(delay_ms)
296 }
297 BackoffStrategy::Linear { initial, increment } => {
298 *initial + increment.saturating_mul(attempt)
299 }
300 }
301 }
302}
303
304#[derive(Debug, Clone, Serialize, Deserialize)]
306pub enum RollbackStrategy {
307 None,
309 Compensate { actions: Vec<String> },
311 RestoreCheckpoint { checkpoint_id: String },
313 JumpToStage { stage_id: StageId },
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_workflow_id_generation() {
        // Two ids built from 16 random bytes each should never collide in practice.
        let id1 = WorkflowId::generate();
        let id2 = WorkflowId::generate();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_workflow_id_display_is_short_hex() {
        // Only the first 8 of the 16 bytes are rendered, as lowercase hex.
        let id = WorkflowId([0xab; 16]);
        assert_eq!(id.to_string(), "abababababababab");
    }

    #[test]
    fn test_backoff_strategy() {
        let fixed = BackoffStrategy::Fixed { delay: Duration::from_secs(1) };
        assert_eq!(fixed.calculate_delay(0), Duration::from_secs(1));
        assert_eq!(fixed.calculate_delay(5), Duration::from_secs(1));

        let exponential = BackoffStrategy::Exponential {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(10),
            factor: 2.0,
        };
        assert_eq!(exponential.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(exponential.calculate_delay(1), Duration::from_millis(200));
        assert_eq!(exponential.calculate_delay(2), Duration::from_millis(400));
        // 100ms * 2^10 exceeds `max`, so the cap applies.
        assert_eq!(exponential.calculate_delay(10), Duration::from_secs(10));

        let linear = BackoffStrategy::Linear {
            initial: Duration::from_secs(1),
            increment: Duration::from_millis(500),
        };
        assert_eq!(linear.calculate_delay(0), Duration::from_secs(1));
        assert_eq!(linear.calculate_delay(2), Duration::from_secs(2));
    }

    #[test]
    fn test_version_display() {
        let version = Version { major: 1, minor: 2, patch: 3 };
        assert_eq!(version.to_string(), "1.2.3");
    }

    #[test]
    fn test_version_ordering_is_lexicographic() {
        // Derived `Ord` compares major, then minor, then patch.
        let newer = Version { major: 1, minor: 10, patch: 0 };
        let older = Version { major: 1, minor: 9, patch: 99 };
        assert!(newer > older);
    }

    #[test]
    fn test_workflow_error_display() {
        let mut err = WorkflowError {
            code: "E1".to_string(),
            message: "boom".to_string(),
            stage: None,
            trace: None,
            recovery_hints: Vec::new(),
        };
        assert_eq!(err.to_string(), "[E1] boom");
        err.stage = Some(StageId("fetch".to_string()));
        assert_eq!(err.to_string(), "[E1] boom at stage fetch");
    }
}