use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;
use crate::orchestration::{
BackoffCalculator, BackoffContext, ErrorClassifier, ErrorContext, StandardErrorClassifier,
};
use tasker_shared::{
errors::OrchestrationError,
models::WorkflowStep,
state_machine::{StepEvent, StepStateMachine, WorkflowStepState},
system_context::SystemContext,
TaskerError,
};
/// Outcome of handling one step error: which action was taken, the state the
/// step ended up in, and any scheduled retry time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorHandlingResult {
// UUID of the workflow step the error occurred on.
pub step_uuid: Uuid,
// Which handling path was taken (permanent failure, retry, etc.).
pub action: ErrorHandlingAction,
// State the step was left in after handling completed.
pub final_state: WorkflowStepState,
// True when a backoff delay was calculated and persisted for the step.
pub backoff_applied: bool,
// Earliest time the step may be retried; set only when backoff was applied.
pub next_retry_at: Option<chrono::DateTime<chrono::Utc>>,
// Human-readable summary of the error classification (category + message).
pub classification_summary: String,
}
/// Action taken by `ErrorHandlingService` in response to a step error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ErrorHandlingAction {
// Error was classified as non-retryable; step failed permanently.
MarkedAsPermanentFailure,
// Retryable error; step moved to WaitingForRetry with backoff scheduled.
TransitionedToWaitingForRetry,
// Step transitioned to the Error state (used on the final attempt path).
MarkedAsError,
// Backoff was applied without a state transition (legacy backoff path).
NoActionTaken,
}
/// Feature flags and defaults controlling how step errors are handled.
#[derive(Debug, Clone)]
pub struct ErrorHandlingConfig {
// When false, every error is treated as permanent regardless of classification.
pub use_error_classification: bool,
// When true, retryable errors transition the step to WaitingForRetry;
// otherwise only backoff is applied (legacy behavior).
pub use_waiting_for_retry_state: bool,
// Fallback attempt limit used when a step has no max_attempts set.
pub default_max_attempts: u32,
}
impl Default for ErrorHandlingConfig {
    /// Defaults enable both the classification and WaitingForRetry paths,
    /// with up to 3 attempts per step.
    fn default() -> Self {
        Self {
            default_max_attempts: 3,
            use_waiting_for_retry_state: true,
            use_error_classification: true,
        }
    }
}
/// Centralizes step error handling: classifies errors, applies backoff for
/// retryable failures, and drives step state-machine transitions.
pub struct ErrorHandlingService {
// Feature flags and defaults controlling handling behavior.
config: ErrorHandlingConfig,
// Pluggable classifier deciding whether an error is retryable.
error_classifier: Arc<dyn ErrorClassifier + Send + Sync>,
// Calculates and persists retry backoff timing.
backoff_calculator: BackoffCalculator,
// Shared context providing the database pool and state-machine plumbing.
system_context: Arc<SystemContext>,
}
impl std::fmt::Debug for ErrorHandlingService {
    /// Manual `Debug` impl: `dyn ErrorClassifier` and `SystemContext` do not
    /// implement `Debug`, so the classifier is summarized as a flag and the
    /// context is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ErrorHandlingService")
            .field("config", &self.config)
            .field("has_error_classifier", &true)
            .field("backoff_calculator", &self.backoff_calculator)
            // `finish_non_exhaustive` prints a trailing `..`, signalling that
            // fields (system_context) were intentionally left out rather than
            // presenting the output as complete.
            .finish_non_exhaustive()
    }
}
impl ErrorHandlingService {
    /// Construct a service using the default `StandardErrorClassifier`.
    pub fn new(
        config: ErrorHandlingConfig,
        backoff_calculator: BackoffCalculator,
        system_context: Arc<SystemContext>,
    ) -> Self {
        let error_classifier = Arc::new(StandardErrorClassifier::new());
        Self {
            config,
            error_classifier,
            backoff_calculator,
            system_context,
        }
    }

    /// Construct a service with a caller-supplied classifier (useful for
    /// tests or domain-specific classification policies).
    pub fn with_classifier(
        config: ErrorHandlingConfig,
        error_classifier: Arc<dyn ErrorClassifier + Send + Sync>,
        backoff_calculator: BackoffCalculator,
        system_context: Arc<SystemContext>,
    ) -> Self {
        Self {
            config,
            error_classifier,
            backoff_calculator,
            system_context,
        }
    }

    /// Entry point: classify `error` for `step` and route to the retryable or
    /// permanent handling path.
    ///
    /// Retryable classifications are only honored when
    /// `use_error_classification` is enabled; otherwise every error is
    /// treated as permanent.
    ///
    /// NOTE(review): `_error_message` is accepted but currently unused by
    /// context construction — confirm whether callers expect it to be
    /// propagated.
    ///
    /// # Errors
    /// Returns `TaskerError` when building the error context, applying
    /// backoff, or transitioning the state machine fails.
    pub async fn handle_step_error(
        &self,
        step: &WorkflowStep,
        error: &OrchestrationError,
        _error_message: Option<String>,
    ) -> Result<ErrorHandlingResult, TaskerError> {
        let error_context = self.create_error_context(step, _error_message).await?;
        let classification = self.error_classifier.classify_error(error, &error_context);
        if classification.is_retryable && self.config.use_error_classification {
            self.handle_retryable_error(step, &classification, &error_context)
                .await
        } else {
            self.handle_permanent_error(step, &classification).await
        }
    }

    /// Route a retryable error to the configured strategy: the
    /// WaitingForRetry state-machine path, or legacy backoff-only handling.
    async fn handle_retryable_error(
        &self,
        step: &WorkflowStep,
        classification: &crate::orchestration::ErrorClassification,
        error_context: &ErrorContext,
    ) -> Result<ErrorHandlingResult, TaskerError> {
        if self.config.use_waiting_for_retry_state {
            self.transition_to_waiting_for_retry(step, classification, error_context)
                .await
        } else {
            self.apply_backoff_legacy(step, classification).await
        }
    }

    /// Apply backoff, then transition the step to `WaitingForRetry`.
    ///
    /// Backoff is persisted before the state transition so a retry time
    /// exists once the step enters the waiting state.
    async fn transition_to_waiting_for_retry(
        &self,
        step: &WorkflowStep,
        classification: &crate::orchestration::ErrorClassification,
        _error_context: &ErrorContext,
    ) -> Result<ErrorHandlingResult, TaskerError> {
        let backoff_context =
            BackoffContext::new().with_error(classification.error_message.clone());
        let backoff_result = self
            .backoff_calculator
            .calculate_and_apply_backoff(&step.workflow_step_uuid, backoff_context)
            .await
            .map_err(|e| TaskerError::DatabaseError(e.to_string()))?;
        let mut state_machine = StepStateMachine::new(step.clone(), self.system_context.clone());
        let event = StepEvent::wait_for_retry(format!(
            "Error classified as retryable: {}",
            classification.error_message
        ));
        state_machine
            .transition(event)
            .await
            .map_err(|e| TaskerError::StateMachineError(e.to_string()))?;
        Ok(ErrorHandlingResult {
            step_uuid: step.workflow_step_uuid,
            action: ErrorHandlingAction::TransitionedToWaitingForRetry,
            final_state: WorkflowStepState::WaitingForRetry,
            backoff_applied: true,
            next_retry_at: Some(backoff_result.next_retry_at),
            classification_summary: format!(
                "Retryable error ({}): {}",
                classification.error_category, classification.error_message
            ),
        })
    }

    /// Legacy path: apply backoff without a state transition, leaving the
    /// step in the `Error` state.
    async fn apply_backoff_legacy(
        &self,
        step: &WorkflowStep,
        classification: &crate::orchestration::ErrorClassification,
    ) -> Result<ErrorHandlingResult, TaskerError> {
        let backoff_context =
            BackoffContext::new().with_error(classification.error_message.clone());
        let backoff_result = self
            .backoff_calculator
            .calculate_and_apply_backoff(&step.workflow_step_uuid, backoff_context)
            .await
            .map_err(|e| TaskerError::DatabaseError(e.to_string()))?;
        Ok(ErrorHandlingResult {
            step_uuid: step.workflow_step_uuid,
            action: ErrorHandlingAction::NoActionTaken,
            final_state: WorkflowStepState::Error,
            backoff_applied: true,
            next_retry_at: Some(backoff_result.next_retry_at),
            classification_summary: format!(
                "Retryable error with legacy backoff ({}): {}",
                classification.error_category, classification.error_message
            ),
        })
    }

    /// Transition the step to the `Error` state for a non-retryable error.
    /// No backoff is applied; the reported action distinguishes the
    /// final-attempt case from a permanently-failed classification.
    async fn handle_permanent_error(
        &self,
        step: &WorkflowStep,
        classification: &crate::orchestration::ErrorClassification,
    ) -> Result<ErrorHandlingResult, TaskerError> {
        let mut state_machine = StepStateMachine::new(step.clone(), self.system_context.clone());
        let event = StepEvent::fail_with_error(format!(
            "Permanent error: {}",
            classification.error_message
        ));
        state_machine
            .transition(event)
            .await
            .map_err(|e| TaskerError::StateMachineError(e.to_string()))?;
        let action = if classification.is_final_attempt {
            ErrorHandlingAction::MarkedAsError
        } else {
            ErrorHandlingAction::MarkedAsPermanentFailure
        };
        Ok(ErrorHandlingResult {
            step_uuid: step.workflow_step_uuid,
            action,
            final_state: WorkflowStepState::Error,
            backoff_applied: false,
            next_retry_at: None,
            classification_summary: format!(
                "Permanent error ({}): {}",
                classification.error_category, classification.error_message
            ),
        })
    }

    /// Build the `ErrorContext` handed to the classifier from the step's
    /// attempt counters and resolved name.
    async fn create_error_context(
        &self,
        step: &WorkflowStep,
        _error_message: Option<String>,
    ) -> Result<ErrorContext, TaskerError> {
        // Defensive conversions: a plain `as u32` cast would wrap a negative
        // i32 (e.g. -1) into a huge attempt count and corrupt retry
        // classification, so invalid values fall back to 0 / the configured
        // default instead.
        let attempts = step
            .attempts
            .and_then(|a| u32::try_from(a).ok())
            .unwrap_or(0);
        let max_attempts = step
            .max_attempts
            .and_then(|m| u32::try_from(m).ok())
            .unwrap_or(self.config.default_max_attempts);
        Ok(ErrorContext {
            step_uuid: step.workflow_step_uuid,
            task_uuid: step.task_uuid,
            // The stored counter is completed attempts; the one in flight is +1.
            attempt_number: attempts + 1,
            execution_duration: std::time::Duration::from_secs(0),
            step_name: self.get_step_name(step).await?,
            max_attempts,
            error_source: "orchestration".to_string(),
            metadata: std::collections::HashMap::new(),
        })
    }

    /// Resolve the step's display name from `tasker.named_steps`, falling
    /// back to "unknown_step" when no row matches.
    async fn get_step_name(&self, step: &WorkflowStep) -> Result<String, TaskerError> {
        let result = sqlx::query!(
            "SELECT name FROM tasker.named_steps WHERE named_step_uuid = $1",
            step.named_step_uuid
        )
        .fetch_optional(self.system_context.database_pool())
        .await
        .map_err(|e| TaskerError::DatabaseError(e.to_string()))?;
        Ok(result
            .map(|r| r.name)
            .unwrap_or_else(|| "unknown_step".to_string()))
    }

    /// Return true when the step's persisted backoff has elapsed and it may
    /// leave `WaitingForRetry`.
    pub async fn check_waiting_for_retry_readiness(
        &self,
        step_uuid: Uuid,
    ) -> Result<bool, TaskerError> {
        self.backoff_calculator
            .is_ready_to_retry(step_uuid)
            .await
            .map_err(|e| TaskerError::DatabaseError(e.to_string()))
    }

    /// Move a step out of `WaitingForRetry` (via the Retry event) and clear
    /// its persisted backoff so the next failure starts fresh.
    pub async fn transition_from_waiting_to_pending(
        &self,
        step: &WorkflowStep,
    ) -> Result<(), TaskerError> {
        let mut state_machine = StepStateMachine::new(step.clone(), self.system_context.clone());
        let event = StepEvent::Retry;
        state_machine
            .transition(event)
            .await
            .map_err(|e| TaskerError::StateMachineError(e.to_string()))?;
        self.backoff_calculator
            .clear_backoff(step.workflow_step_uuid)
            .await
            .map_err(|e| TaskerError::DatabaseError(e.to_string()))?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::orchestration::backoff_calculator::BackoffCalculatorConfig;
    use chrono::Utc;

    /// Defaults: both modern handling paths on, three attempts.
    #[test]
    fn test_error_handling_config_default() {
        let defaults = ErrorHandlingConfig::default();
        assert_eq!(defaults.default_max_attempts, 3);
        assert!(defaults.use_error_classification);
        assert!(defaults.use_waiting_for_retry_state);
    }

    /// Explicitly-built configs keep the values given.
    #[test]
    fn test_error_handling_config_custom() {
        let custom = ErrorHandlingConfig {
            use_error_classification: false,
            use_waiting_for_retry_state: false,
            default_max_attempts: 5,
        };
        assert_eq!(custom.default_max_attempts, 5);
        assert!(!custom.use_error_classification);
        assert!(!custom.use_waiting_for_retry_state);
    }

    /// A permanent-failure result carries no backoff or retry time.
    #[test]
    fn test_error_handling_result_construction() {
        let uuid = Uuid::now_v7();
        let outcome = ErrorHandlingResult {
            step_uuid: uuid,
            action: ErrorHandlingAction::MarkedAsPermanentFailure,
            final_state: WorkflowStepState::Error,
            backoff_applied: false,
            next_retry_at: None,
            classification_summary: "Permanent: connection refused".to_string(),
        };
        assert_eq!(outcome.step_uuid, uuid);
        assert!(outcome.next_retry_at.is_none());
        assert!(!outcome.backoff_applied);
        assert!(outcome.classification_summary.contains("Permanent"));
    }

    /// A retryable result records backoff and the WaitingForRetry state.
    #[test]
    fn test_error_handling_result_with_backoff() {
        let uuid = Uuid::now_v7();
        let scheduled = Utc::now() + chrono::Duration::seconds(30);
        let outcome = ErrorHandlingResult {
            step_uuid: uuid,
            action: ErrorHandlingAction::TransitionedToWaitingForRetry,
            final_state: WorkflowStepState::WaitingForRetry,
            backoff_applied: true,
            next_retry_at: Some(scheduled),
            classification_summary: "Retryable: timeout".to_string(),
        };
        assert!(matches!(
            outcome.action,
            ErrorHandlingAction::TransitionedToWaitingForRetry
        ));
        assert!(matches!(
            outcome.final_state,
            WorkflowStepState::WaitingForRetry
        ));
        assert!(outcome.backoff_applied);
        assert!(outcome.next_retry_at.is_some());
    }

    /// Every action variant serializes to non-empty JSON.
    #[test]
    fn test_error_handling_action_variants() {
        let actions = [
            ErrorHandlingAction::MarkedAsPermanentFailure,
            ErrorHandlingAction::TransitionedToWaitingForRetry,
            ErrorHandlingAction::MarkedAsError,
            ErrorHandlingAction::NoActionTaken,
        ];
        assert_eq!(actions.len(), 4, "Should have 4 action variants");
        for variant in actions.iter() {
            let encoded = serde_json::to_string(variant).expect("action should serialize");
            assert!(!encoded.is_empty());
        }
    }

    /// Serialized actions use the variant names verbatim.
    #[test]
    fn test_error_handling_action_serialization_values() {
        let encode =
            |action: &ErrorHandlingAction| serde_json::to_string(action).expect("serialize");
        assert!(encode(&ErrorHandlingAction::MarkedAsPermanentFailure)
            .contains("MarkedAsPermanentFailure"));
        assert!(encode(&ErrorHandlingAction::TransitionedToWaitingForRetry)
            .contains("TransitionedToWaitingForRetry"));
        assert!(encode(&ErrorHandlingAction::MarkedAsError).contains("MarkedAsError"));
        assert!(encode(&ErrorHandlingAction::NoActionTaken).contains("NoActionTaken"));
    }

    /// JSON serialize/deserialize round-trips a result intact.
    #[test]
    fn test_error_handling_result_serialization_roundtrip() {
        let uuid = Uuid::now_v7();
        let original = ErrorHandlingResult {
            step_uuid: uuid,
            action: ErrorHandlingAction::MarkedAsError,
            final_state: WorkflowStepState::Error,
            backoff_applied: false,
            next_retry_at: None,
            classification_summary: "retry limit exceeded".to_string(),
        };
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: ErrorHandlingResult = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.step_uuid, uuid);
        assert_eq!(decoded.classification_summary, "retry limit exceeded");
        assert!(!decoded.backoff_applied);
    }

    /// Service construction with the default classifier produces a usable
    /// Debug representation.
    #[sqlx::test(migrator = "tasker_shared::database::migrator::MIGRATOR")]
    async fn test_error_handling_service_creation(
        pool: sqlx::PgPool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let context = Arc::new(SystemContext::with_pool(pool.clone()).await?);
        let calculator = BackoffCalculator::new(BackoffCalculatorConfig::default(), pool);
        let service =
            ErrorHandlingService::new(ErrorHandlingConfig::default(), calculator, context);
        let rendered = format!("{:?}", service);
        assert!(rendered.contains("ErrorHandlingService"));
        assert!(rendered.contains("has_error_classifier"));
        Ok(())
    }

    /// `with_classifier` accepts an externally-built classifier.
    #[sqlx::test(migrator = "tasker_shared::database::migrator::MIGRATOR")]
    async fn test_error_handling_service_with_custom_classifier(
        pool: sqlx::PgPool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let context = Arc::new(SystemContext::with_pool(pool.clone()).await?);
        let calculator = BackoffCalculator::with_defaults(pool);
        let settings = ErrorHandlingConfig {
            use_error_classification: false,
            use_waiting_for_retry_state: false,
            default_max_attempts: 10,
        };
        let service = ErrorHandlingService::with_classifier(
            settings,
            Arc::new(StandardErrorClassifier::new()),
            calculator,
            context,
        );
        let rendered = format!("{:?}", service);
        assert!(rendered.contains("ErrorHandlingService"));
        Ok(())
    }

    /// Clone preserves the fields we can compare.
    #[test]
    fn test_error_handling_result_clone() {
        let uuid = Uuid::now_v7();
        let original = ErrorHandlingResult {
            step_uuid: uuid,
            action: ErrorHandlingAction::TransitionedToWaitingForRetry,
            final_state: WorkflowStepState::WaitingForRetry,
            backoff_applied: true,
            next_retry_at: Some(Utc::now()),
            classification_summary: "clone test".to_string(),
        };
        let copy = original.clone();
        assert_eq!(copy.step_uuid, original.step_uuid);
        assert_eq!(copy.backoff_applied, original.backoff_applied);
        assert_eq!(copy.classification_summary, original.classification_summary);
    }

    /// Debug output names the struct and the action variant.
    #[test]
    fn test_error_handling_result_debug() {
        let outcome = ErrorHandlingResult {
            step_uuid: Uuid::now_v7(),
            action: ErrorHandlingAction::NoActionTaken,
            final_state: WorkflowStepState::Error,
            backoff_applied: false,
            next_retry_at: None,
            classification_summary: "debug test".to_string(),
        };
        let rendered = format!("{:?}", outcome);
        assert!(rendered.contains("ErrorHandlingResult"));
        assert!(rendered.contains("NoActionTaken"));
    }

    /// A cloned action serializes identically to its source.
    #[test]
    fn test_error_handling_action_clone() {
        let original = ErrorHandlingAction::MarkedAsPermanentFailure;
        let copy = original.clone();
        let json_original = serde_json::to_string(&original).unwrap();
        let json_copy = serde_json::to_string(&copy).unwrap();
        assert_eq!(json_original, json_copy);
    }

    /// Debug output of an action contains its variant name.
    #[test]
    fn test_error_handling_action_debug() {
        let rendered = format!("{:?}", ErrorHandlingAction::TransitionedToWaitingForRetry);
        assert!(rendered.contains("TransitionedToWaitingForRetry"));
    }

    /// Cloned configs carry identical flag values.
    #[test]
    fn test_error_handling_config_clone() {
        let settings = ErrorHandlingConfig {
            use_error_classification: true,
            use_waiting_for_retry_state: false,
            default_max_attempts: 7,
        };
        let copy = settings.clone();
        assert_eq!(copy.default_max_attempts, 7);
        assert!(copy.use_error_classification);
        assert!(!copy.use_waiting_for_retry_state);
    }

    /// Config Debug output names the struct and its fields.
    #[test]
    fn test_error_handling_config_debug() {
        let rendered = format!("{:?}", ErrorHandlingConfig::default());
        assert!(rendered.contains("ErrorHandlingConfig"));
        assert!(rendered.contains("use_error_classification"));
    }

    /// MarkedAsError pairs with the Error final state.
    #[test]
    fn test_error_handling_result_marked_as_error_action() {
        let outcome = ErrorHandlingResult {
            step_uuid: Uuid::now_v7(),
            action: ErrorHandlingAction::MarkedAsError,
            final_state: WorkflowStepState::Error,
            backoff_applied: false,
            next_retry_at: None,
            classification_summary: "retry limit exceeded".to_string(),
        };
        assert!(matches!(outcome.final_state, WorkflowStepState::Error));
        assert!(matches!(outcome.action, ErrorHandlingAction::MarkedAsError));
    }

    /// A result carrying a retry time round-trips with the time preserved.
    #[test]
    fn test_error_handling_result_serialization_with_retry_time() {
        let scheduled = Utc::now() + chrono::Duration::minutes(5);
        let outcome = ErrorHandlingResult {
            step_uuid: Uuid::now_v7(),
            action: ErrorHandlingAction::TransitionedToWaitingForRetry,
            final_state: WorkflowStepState::WaitingForRetry,
            backoff_applied: true,
            next_retry_at: Some(scheduled),
            classification_summary: "timeout".to_string(),
        };
        let encoded = serde_json::to_string(&outcome).expect("serialize");
        assert!(encoded.contains("next_retry_at"));
        assert!(encoded.contains("timeout"));
        let decoded: ErrorHandlingResult = serde_json::from_str(&encoded).expect("deserialize");
        assert!(decoded.backoff_applied);
        assert!(decoded.next_retry_at.is_some());
    }
}