1use std::collections::{HashMap, VecDeque};
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use serde::{Deserialize, Serialize};
10
11use crate::ir::{WorkflowExecutionId, ActivityExecutionId, ExecutionStatus, WorkflowStrategyOp};
12use kotoba_errors::WorkflowError;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub enum SagaStatus {
17 Started,
19 Executing,
21 Completed,
23 Compensating,
25 Compensated,
27 Failed,
29 TimedOut,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct SagaContext {
36 pub saga_id: String,
37 pub workflow_id: WorkflowExecutionId,
38 pub status: SagaStatus,
39 pub start_time: chrono::DateTime<chrono::Utc>,
40 pub end_time: Option<chrono::DateTime<chrono::Utc>>,
41 pub timeout: Option<std::time::Duration>,
42 pub transaction_log: Vec<SagaTransaction>,
43 pub compensation_log: Vec<SagaCompensation>,
44 pub metadata: HashMap<String, serde_json::Value>,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct SagaTransaction {
50 pub transaction_id: String,
51 pub activity_ref: String,
52 pub activity_id: Option<ActivityExecutionId>,
53 pub inputs: HashMap<String, serde_json::Value>,
54 pub outputs: Option<HashMap<String, serde_json::Value>>,
55 pub status: ExecutionStatus,
56 pub timestamp: chrono::DateTime<chrono::Utc>,
57 pub compensation_ref: Option<String>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct SagaCompensation {
63 pub compensation_id: String,
64 pub original_transaction_id: String,
65 pub compensation_activity: String,
66 pub status: CompensationStatus,
67 pub timestamp: chrono::DateTime<chrono::Utc>,
68 pub error: Option<String>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub enum CompensationStatus {
74 Pending,
75 Executing,
76 Completed,
77 Failed,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct AdvancedSagaPattern {
83 pub name: String,
84 pub description: Option<String>,
85 pub version: String,
86
87 pub main_flow: WorkflowStrategyOp,
89
90 pub compensations: HashMap<String, WorkflowStrategyOp>,
92
93 pub config: SagaConfig,
95
96 pub dependencies: HashMap<String, Vec<String>>,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct SagaConfig {
103 pub timeout: Option<std::time::Duration>,
105
106 pub compensation_policy: CompensationPolicy,
108
109 pub parallelism: usize,
111
112 pub retry_config: Option<SagaRetryConfig>,
114
115 pub monitoring_config: SagaMonitoringConfig,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub enum CompensationPolicy {
122 ReverseOrder,
124 Parallel,
126 Custom(Vec<String>),
128 Conditional,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct SagaRetryConfig {
135 pub max_attempts: u32,
136 pub backoff_multiplier: f64,
137 pub max_backoff: std::time::Duration,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct SagaMonitoringConfig {
143 pub enable_metrics: bool,
144 pub enable_tracing: bool,
145 pub log_level: String,
146}
147
148pub struct SagaManager {
150 sagas: RwLock<HashMap<String, SagaContext>>,
151 patterns: RwLock<HashMap<String, AdvancedSagaPattern>>,
152 metrics: SagaMetrics,
153}
154
/// Aggregate saga statistics.
///
/// `Default` yields all-zero counters. `Clone` and `PartialEq` allow callers
/// to take and compare snapshots.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct SagaMetrics {
    /// Number of sagas started.
    pub total_sagas: u64,
    /// Number of sagas that completed.
    pub completed_sagas: u64,
    /// Number of sagas that failed.
    pub failed_sagas: u64,
    /// Number of sagas that were compensated.
    pub compensated_sagas: u64,
    /// Average saga execution time.
    pub avg_execution_time: std::time::Duration,
    /// Share of sagas that required compensation.
    // NOTE(review): presumably a fraction in 0.0..=1.0 - confirm with the producer.
    pub compensation_rate: f64,
}
164
165impl SagaManager {
166 pub fn new() -> Self {
167 Self {
168 sagas: RwLock::new(HashMap::new()),
169 patterns: RwLock::new(HashMap::new()),
170 metrics: SagaMetrics::default(),
171 }
172 }
173
174 pub async fn register_pattern(&self, pattern: AdvancedSagaPattern) -> Result<(), WorkflowError> {
176 let mut patterns = self.patterns.write().await;
177 patterns.insert(pattern.name.clone(), pattern);
178 Ok(())
179 }
180
181 pub async fn start_saga(
183 &self,
184 pattern_name: &str,
185 workflow_id: WorkflowExecutionId,
186 inputs: HashMap<String, serde_json::Value>,
187 ) -> Result<String, WorkflowError> {
188 let patterns = self.patterns.read().await;
189 let pattern = patterns.get(pattern_name)
190 .ok_or_else(|| WorkflowError::InvalidDefinition(format!("Saga pattern '{}' not found", pattern_name)))?;
191
192 let saga_id = uuid::Uuid::new_v4().to_string();
193 let context = SagaContext {
194 saga_id: saga_id.clone(),
195 workflow_id,
196 status: SagaStatus::Started,
197 start_time: chrono::Utc::now(),
198 end_time: None,
199 timeout: pattern.config.timeout,
200 transaction_log: Vec::new(),
201 compensation_log: Vec::new(),
202 metadata: inputs,
203 };
204
205 let mut sagas = self.sagas.write().await;
206 sagas.insert(saga_id.clone(), context);
207
208 Ok(saga_id)
209 }
210
211 pub async fn record_transaction(
213 &self,
214 saga_id: &str,
215 transaction: SagaTransaction,
216 ) -> Result<(), WorkflowError> {
217 let mut sagas = self.sagas.write().await;
218 if let Some(context) = sagas.get_mut(saga_id) {
219 context.transaction_log.push(transaction);
220 Ok(())
221 } else {
222 Err(WorkflowError::WorkflowNotFound(saga_id.to_string()))
223 }
224 }
225
226 pub async fn record_compensation(
228 &self,
229 saga_id: &str,
230 compensation: SagaCompensation,
231 ) -> Result<(), WorkflowError> {
232 let mut sagas = self.sagas.write().await;
233 if let Some(context) = sagas.get_mut(saga_id) {
234 context.compensation_log.push(compensation);
235 Ok(())
236 } else {
237 Err(WorkflowError::WorkflowNotFound(saga_id.to_string()))
238 }
239 }
240
241 pub async fn update_saga_status(
243 &self,
244 saga_id: &str,
245 status: SagaStatus,
246 ) -> Result<(), WorkflowError> {
247 let mut sagas = self.sagas.write().await;
248 if let Some(context) = sagas.get_mut(saga_id) {
249 context.status = status.clone();
250
251 if matches!(status, SagaStatus::Completed | SagaStatus::Compensated | SagaStatus::Failed | SagaStatus::TimedOut) {
253 context.end_time = Some(chrono::Utc::now());
254 }
255
256 Ok(())
257 } else {
258 Err(WorkflowError::WorkflowNotFound(saga_id.to_string()))
259 }
260 }
261
262 pub async fn get_saga_context(&self, saga_id: &str) -> Option<SagaContext> {
264 let sagas = self.sagas.read().await;
265 sagas.get(saga_id).cloned()
266 }
267
268 pub async fn get_compensable_transactions(&self, saga_id: &str) -> Result<Vec<SagaTransaction>, WorkflowError> {
270 let sagas = self.sagas.read().await;
271 if let Some(context) = sagas.get(saga_id) {
272 let compensable = context.transaction_log.iter()
273 .filter(|tx| tx.compensation_ref.is_some() && matches!(tx.status, ExecutionStatus::Completed))
274 .cloned()
275 .collect();
276 Ok(compensable)
277 } else {
278 Err(WorkflowError::WorkflowNotFound(saga_id.to_string()))
279 }
280 }
281
282 pub async fn get_compensation_order(&self, saga_id: &str, policy: &CompensationPolicy) -> Result<Vec<String>, WorkflowError> {
284 let compensable = self.get_compensable_transactions(saga_id).await?;
285
286 match policy {
287 CompensationPolicy::ReverseOrder => {
288 let mut order: Vec<String> = compensable.iter()
290 .rev()
291 .filter_map(|tx| tx.compensation_ref.clone())
292 .collect();
293 Ok(order)
294 }
295 CompensationPolicy::Parallel => {
296 let order: Vec<String> = compensable.iter()
298 .filter_map(|tx| tx.compensation_ref.clone())
299 .collect();
300 Ok(order)
301 }
302 CompensationPolicy::Custom(order) => {
303 Ok(order.clone())
304 }
305 CompensationPolicy::Conditional => {
306 Ok(Vec::new())
308 }
309 }
310 }
311
312 pub async fn resolve_dependencies(&self, pattern: &AdvancedSagaPattern, completed: &[String]) -> Vec<String> {
314 let mut ready = Vec::new();
315
316 for (activity, deps) in &pattern.dependencies {
317 if !completed.contains(activity) {
318 let all_deps_completed = deps.iter().all(|dep| completed.contains(dep));
319 if all_deps_completed {
320 ready.push(activity.clone());
321 }
322 }
323 }
324
325 ready
326 }
327
328 pub fn get_metrics(&self) -> &SagaMetrics {
330 &self.metrics
331 }
332
333 pub async fn get_running_sagas(&self) -> Vec<SagaContext> {
335 let sagas = self.sagas.read().await;
336 sagas.values()
337 .filter(|ctx| matches!(ctx.status, SagaStatus::Started | SagaStatus::Executing))
338 .cloned()
339 .collect()
340 }
341
342 pub async fn detect_timed_out_sagas(&self) -> Vec<String> {
344 let sagas = self.sagas.read().await;
345 let now = chrono::Utc::now();
346
347 sagas.iter()
348 .filter_map(|(id, ctx)| {
349 if let Some(timeout) = ctx.timeout {
350 let elapsed = now.signed_duration_since(ctx.start_time);
351 if elapsed.to_std().unwrap_or(std::time::Duration::from_secs(0)) > timeout {
352 Some(id.clone())
353 } else {
354 None
355 }
356 } else {
357 None
358 }
359 })
360 .collect()
361 }
362
363 pub async fn cleanup_saga(&self, saga_id: &str) -> Result<(), WorkflowError> {
365 let mut sagas = self.sagas.write().await;
366 sagas.remove(saga_id);
367 Ok(())
368 }
369}
370
371pub struct SagaExecutionEngine {
373 saga_manager: Arc<SagaManager>,
374 activity_registry: Arc<crate::executor::ActivityRegistry>,
375 state_manager: Arc<crate::executor::WorkflowStateManager>,
376}
377
378impl SagaExecutionEngine {
379 pub fn new(
380 saga_manager: Arc<SagaManager>,
381 activity_registry: Arc<crate::executor::ActivityRegistry>,
382 state_manager: Arc<crate::executor::WorkflowStateManager>,
383 ) -> Self {
384 Self {
385 saga_manager,
386 activity_registry,
387 state_manager,
388 }
389 }
390
391 pub async fn execute_advanced_saga(
393 &self,
394 pattern: &AdvancedSagaPattern,
395 workflow_id: WorkflowExecutionId,
396 inputs: HashMap<String, serde_json::Value>,
397 ) -> Result<(), WorkflowError> {
398 let saga_id = self.saga_manager.start_saga(&pattern.name, workflow_id.clone(), inputs).await?;
400
401 self.saga_manager.update_saga_status(&saga_id, SagaStatus::Executing).await?;
403
404 let mut execution_queue: VecDeque<String> = self.saga_manager.resolve_dependencies(pattern, &[]).await.into();
406 let mut completed = Vec::new();
407 let mut failed = false;
408
409 while !execution_queue.is_empty() && !failed {
410 let activity_ref = execution_queue.pop_front().unwrap();
411
412 match self.execute_activity_with_tracking(&saga_id, &activity_ref, HashMap::new()).await {
414 Ok(_) => {
415 completed.push(activity_ref);
416
417 let new_ready = self.saga_manager.resolve_dependencies(pattern, &completed).await;
419 for activity in new_ready {
420 if !execution_queue.contains(&activity) {
421 execution_queue.push_back(activity);
422 }
423 }
424 }
425 Err(e) => {
426 println!("Activity {} failed: {:?}", activity_ref, e);
427 failed = true;
428
429 self.execute_compensation(&saga_id, pattern).await?;
431 }
432 }
433 }
434
435 let final_status = if failed {
437 SagaStatus::Compensated
438 } else {
439 SagaStatus::Completed
440 };
441
442 self.saga_manager.update_saga_status(&saga_id, final_status).await?;
443 Ok(())
444 }
445
446 async fn execute_activity_with_tracking(
448 &self,
449 saga_id: &str,
450 activity_ref: &str,
451 inputs: HashMap<String, serde_json::Value>,
452 ) -> Result<HashMap<String, serde_json::Value>, WorkflowError> {
453 let result = self.activity_registry.execute(activity_ref, inputs.clone()).await?;
455
456 let transaction = SagaTransaction {
458 transaction_id: uuid::Uuid::new_v4().to_string(),
459 activity_ref: activity_ref.to_string(),
460 activity_id: None, inputs,
462 outputs: result.outputs.clone(),
463 status: if result.error.is_some() { ExecutionStatus::Failed } else { ExecutionStatus::Completed },
464 timestamp: chrono::Utc::now(),
465 compensation_ref: Some(format!("compensate_{}", activity_ref)), };
467
468 self.saga_manager.record_transaction(saga_id, transaction).await?;
469
470 if let Some(error) = result.error {
471 return Err(WorkflowError::InvalidDefinition(format!("Activity execution failed: {:?}", crate::executor::ActivityError::ExecutionFailed(error))));
472 }
473
474 Ok(result.outputs.unwrap_or_default())
475 }
476
477 async fn execute_compensation(
479 &self,
480 saga_id: &str,
481 pattern: &AdvancedSagaPattern,
482 ) -> Result<(), WorkflowError> {
483 self.saga_manager.update_saga_status(saga_id, SagaStatus::Compensating).await?;
484
485 let compensable = self.saga_manager.get_compensable_transactions(saga_id).await?;
486 let compensation_order = self.saga_manager.get_compensation_order(saga_id, &pattern.config.compensation_policy).await?;
487
488 for compensation_ref in compensation_order {
489 if let Some(compensation_strategy) = pattern.compensations.get(&compensation_ref) {
490 match self.execute_compensation_activity(saga_id, &compensation_ref).await {
492 Ok(_) => {
493 let compensation = SagaCompensation {
494 compensation_id: uuid::Uuid::new_v4().to_string(),
495 original_transaction_id: "".to_string(), compensation_activity: compensation_ref,
497 status: CompensationStatus::Completed,
498 timestamp: chrono::Utc::now(),
499 error: None,
500 };
501 self.saga_manager.record_compensation(saga_id, compensation).await?;
502 }
503 Err(e) => {
504 let compensation = SagaCompensation {
505 compensation_id: uuid::Uuid::new_v4().to_string(),
506 original_transaction_id: "".to_string(),
507 compensation_activity: compensation_ref,
508 status: CompensationStatus::Failed,
509 timestamp: chrono::Utc::now(),
510 error: Some(e.to_string()),
511 };
512 self.saga_manager.record_compensation(saga_id, compensation).await?;
513 }
514 }
515 }
516 }
517
518 Ok(())
519 }
520
521 async fn execute_compensation_activity(
523 &self,
524 saga_id: &str,
525 compensation_ref: &str,
526 ) -> Result<(), WorkflowError> {
527 println!("Executing compensation activity: {}", compensation_ref);
529 Ok(())
530 }
531}