#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};
use tokio::time::sleep;
use tracing::{debug, error, info, warn};
use crate::tls::error::{ErrorCode, RecoveryHint, TlsError};
/// Exponential-backoff retry configuration for fallible TLS operations.
///
/// The delay after failed attempt `n` (1-based) starts from
/// `initial_backoff * backoff_multiplier^(n - 1)` and is bounded by `max_backoff`;
/// see [`RetryPolicy::backoff_for_attempt`] for the exact computation.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of attempts, counting the first one.
    pub max_attempts: u32,
    /// Delay before the first retry.
    pub initial_backoff: Duration,
    /// Upper bound on the exponentially growing delay.
    pub max_backoff: Duration,
    /// Factor by which the delay grows after each failed attempt.
    pub backoff_multiplier: f64,
    /// Whether a random jitter is added to each delay.
    pub jitter: bool,
}

impl Default for RetryPolicy {
    /// Three attempts, 100 ms initial backoff doubling up to 5 s, jitter enabled.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(100);
        let max_backoff = Duration::from_secs(5);
        Self {
            max_attempts: 3,
            initial_backoff,
            max_backoff,
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }
}
impl RetryPolicy {
    /// Policy with fewer, slower retries: 2 attempts, 200 ms initial backoff doubling up
    /// to 2 s, jitter enabled.
    #[must_use]
    pub fn conservative() -> Self {
        Self {
            max_attempts: 2,
            initial_backoff: Duration::from_millis(200),
            max_backoff: Duration::from_secs(2),
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }

    /// Policy with more, faster retries: 5 attempts, 50 ms initial backoff growing by 1.5x
    /// up to 10 s, jitter enabled.
    #[must_use]
    pub fn aggressive() -> Self {
        Self {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(50),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 1.5,
            jitter: true,
        }
    }

    /// Creates a policy with the given limits, a backoff multiplier of 2.0 and jitter enabled.
    #[must_use]
    pub fn new(max_attempts: u32, initial_backoff: Duration, max_backoff: Duration) -> Self {
        Self { max_attempts, initial_backoff, max_backoff, backoff_multiplier: 2.0, jitter: true }
    }

    /// Returns the delay to wait after failed attempt `attempt` (1-based).
    ///
    /// The base delay is `initial_backoff * backoff_multiplier^(attempt - 1)`, computed at
    /// millisecond granularity and capped at `max_backoff`. When `jitter` is enabled, a random
    /// 0–49% of the capped delay is added on top. The result is capped at `max_backoff`
    /// again, so jitter can never push the delay past the configured maximum.
    #[must_use]
    pub fn backoff_for_attempt(&self, attempt: u32) -> Duration {
        let initial_ms = u64::try_from(self.initial_backoff.as_millis()).unwrap_or(u64::MAX);
        let max_ms = u64::try_from(self.max_backoff.as_millis()).unwrap_or(u64::MAX);
        let exponent = i32::try_from(attempt.saturating_sub(1)).unwrap_or(i32::MAX);

        // Precision loss only matters for delays beyond 2^53 ms, far past any sane cap.
        #[allow(clippy::cast_precision_loss)]
        let base_ms = initial_ms as f64;
        let delay = base_ms * self.backoff_multiplier.powi(exponent);

        // `f64 as u64` saturates: +inf becomes u64::MAX, NaN and negatives become 0.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let capped_ms = (delay as u64).min(max_ms);

        let final_ms = if self.jitter {
            let jitter_pct = crate::primitives::rand::csprng::random_u64() % 50;
            let jitter_ms = capped_ms.saturating_mul(jitter_pct) / 100;
            // Re-apply the cap: `max_backoff` is a hard ceiling, jitter included.
            capped_ms.saturating_add(jitter_ms).min(max_ms)
        } else {
            capped_ms
        };
        Duration::from_millis(final_ms)
    }

    /// Returns `true` if `err` looks transient and another attempt is still allowed.
    ///
    /// `attempt` is the 1-based number of the attempt that just failed; once it reaches
    /// `max_attempts` this always returns `false`. Only selected I/O connection, TLS/handshake
    /// and key-exchange error codes are considered retryable; every other error (certificate,
    /// configuration, internal, ...) is not.
    #[must_use]
    pub fn should_retry(&self, err: &TlsError, attempt: u32) -> bool {
        if attempt >= self.max_attempts {
            return false;
        }
        match err {
            TlsError::Io { code, .. } => matches!(
                code,
                ErrorCode::ConnectionRefused
                    | ErrorCode::ConnectionTimeout
                    | ErrorCode::ConnectionReset
            ),
            TlsError::Tls { code, .. } => matches!(
                code,
                ErrorCode::HandshakeFailed
                    | ErrorCode::InvalidHandshakeMessage
                    | ErrorCode::HandshakeTimeout
            ),
            TlsError::Handshake { code, .. } => matches!(
                code,
                ErrorCode::HandshakeFailed
                    | ErrorCode::ProtocolVersionMismatch
                    | ErrorCode::HandshakeTimeout
            ),
            TlsError::KeyExchange { code, .. } => {
                matches!(code, ErrorCode::KeyExchangeFailed | ErrorCode::EncapsulationFailed)
            }
            _ => false,
        }
    }
}
/// Current state of a [`CircuitBreaker`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests flow normally while failures are counted.
    Closed,
    /// Requests are rejected until the breaker's timeout has elapsed since the last failure.
    Open,
    /// Trial requests are allowed; enough successes close the breaker, a failure reopens it.
    HalfOpen,
}
/// Circuit breaker that stops admitting requests after repeated failures.
///
/// All state lives in atomics and a mutex, so every method takes `&self` and the breaker can
/// be shared between threads. Individual check-then-update sequences are not atomic as a
/// whole.
#[derive(Debug)]
pub struct CircuitBreaker {
    // Encoded `CircuitState`: 0 = Closed, 1 = Open, 2 = HalfOpen.
    state: Arc<AtomicU32>,
    // Failures counted toward `failure_threshold`.
    failure_count: Arc<AtomicU32>,
    // Successes counted toward `success_threshold`.
    success_count: Arc<AtomicU32>,
    // Time of the most recent failure; decides when Open may move to HalfOpen.
    last_failure_time: Arc<std::sync::Mutex<Option<Instant>>>,
    // Number of consecutive failures that opens the breaker.
    failure_threshold: u32,
    // Number of half-open successes that closes the breaker again.
    success_threshold: u32,
    // How long to stay Open after the last failure before admitting a trial request.
    timeout: Duration,
}
impl CircuitBreaker {
#[must_use]
pub fn new(failure_threshold: u32, timeout: Duration) -> Self {
Self {
state: Arc::new(AtomicU32::new(0)), failure_count: Arc::new(AtomicU32::new(0)),
success_count: Arc::new(AtomicU32::new(0)),
last_failure_time: Arc::new(std::sync::Mutex::new(None)),
failure_threshold,
success_threshold: 3, timeout,
}
}
#[must_use]
pub fn state(&self) -> CircuitState {
match self.state.load(Ordering::SeqCst) {
0 => CircuitState::Closed,
1 => CircuitState::Open,
2 => CircuitState::HalfOpen,
_ => CircuitState::Open,
}
}
#[must_use]
pub fn allow_request(&self) -> bool {
match self.state() {
CircuitState::Closed => true,
CircuitState::Open => {
let Ok(last_failure) = self.last_failure_time.lock() else {
warn!("Failed to acquire circuit breaker lock, assuming no timeout");
return false;
};
if let Some(last) = *last_failure
&& last.elapsed() >= self.timeout
{
self.set_state(CircuitState::HalfOpen);
info!("Circuit breaker transitioning to half-open state");
return true;
}
warn!("Circuit breaker is open, request blocked");
false
}
CircuitState::HalfOpen => true,
}
}
pub fn record_success(&self) {
self.success_count.fetch_add(1, Ordering::SeqCst);
if self.state() == CircuitState::HalfOpen {
let success = self.success_count.load(Ordering::SeqCst);
if success >= self.success_threshold {
self.set_state(CircuitState::Closed);
self.failure_count.store(0, Ordering::SeqCst);
self.success_count.store(0, Ordering::SeqCst);
info!("Circuit breaker closed after {} successful operations", success);
}
} else if self.state() == CircuitState::Closed {
self.failure_count.store(0, Ordering::SeqCst);
}
}
pub fn record_failure(&self) {
let failures = self.failure_count.fetch_add(1, Ordering::SeqCst).saturating_add(1);
if let Ok(mut guard) = self.last_failure_time.lock() {
*guard = Some(Instant::now());
} else {
warn!("Failed to record failure time due to lock contention");
}
if self.state() == CircuitState::HalfOpen {
self.set_state(CircuitState::Open);
self.success_count.store(0, Ordering::SeqCst);
warn!("Circuit breaker returned to open state after failure in half-open");
} else if failures >= self.failure_threshold {
self.set_state(CircuitState::Open);
error!("Circuit breaker opened after {} consecutive failures", failures);
}
debug!("Circuit breaker failure count: {}", failures);
}
fn set_state(&self, state: CircuitState) {
let value = match state {
CircuitState::Closed => 0,
CircuitState::Open => 1,
CircuitState::HalfOpen => 2,
};
self.state.store(value, Ordering::SeqCst);
}
pub fn reset(&self) {
self.set_state(CircuitState::Closed);
self.failure_count.store(0, Ordering::SeqCst);
self.success_count.store(0, Ordering::SeqCst);
if let Ok(mut guard) = self.last_failure_time.lock() {
*guard = None;
} else {
warn!("Failed to reset failure time due to lock contention");
}
info!("Circuit breaker reset to closed state");
}
}
/// How to react when the primary TLS configuration fails.
///
/// The built-in non-`None` variants move to a weaker configuration (PQ-only to hybrid, or
/// hybrid to classical), so enabling them trades security for availability.
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub enum FallbackStrategy {
    /// Never fall back (the default).
    #[default]
    None,
    /// Fall back from hybrid to classical TLS on `PqNotAvailable` or `HybridKemFailed` errors.
    HybridToClassical,
    /// Fall back from PQ-only to hybrid TLS on `HybridKemFailed` errors.
    PqToHybrid,
    /// Caller-defined fallback that triggers on every error.
    Custom {
        /// Human-readable text returned by `FallbackStrategy::description`.
        description: String,
    },
}
impl FallbackStrategy {
    /// Convenience constructor for [`FallbackStrategy::HybridToClassical`].
    #[must_use]
    pub fn hybrid_to_classical() -> Self {
        FallbackStrategy::HybridToClassical
    }

    /// Convenience constructor for [`FallbackStrategy::PqToHybrid`].
    #[must_use]
    pub fn pq_to_hybrid() -> Self {
        FallbackStrategy::PqToHybrid
    }

    /// Returns `true` when `err` should trigger this strategy's fallback path.
    ///
    /// `None` never triggers and `Custom` always does. The built-in downgrades trigger
    /// only on the PQ/hybrid error codes they are meant to work around.
    #[must_use]
    pub fn should_fallback(&self, err: &TlsError) -> bool {
        match self {
            Self::None => false,
            Self::Custom { .. } => true,
            Self::PqToHybrid => matches!(err.code(), ErrorCode::HybridKemFailed),
            Self::HybridToClassical => {
                matches!(err.code(), ErrorCode::PqNotAvailable | ErrorCode::HybridKemFailed)
            }
        }
    }

    /// Returns a human-readable description of the fallback, suitable for logging.
    #[must_use]
    pub fn description(&self) -> String {
        match self {
            Self::None => "No fallback available",
            Self::HybridToClassical => "Falling back from hybrid to classical TLS",
            Self::PqToHybrid => "Falling back from PQ-only to hybrid TLS",
            Self::Custom { description } => description.as_str(),
        }
        .to_owned()
    }
}
/// Settings for graceful degradation of the TLS configuration.
///
/// Nothing in this module reads these fields. In particular, `execute_with_fallback` does
/// not consult them, so callers are presumably expected to check them before choosing a
/// fallback. Everything is disabled by default.
#[derive(Debug, Clone)]
pub struct DegradationConfig {
    /// Opt-in switch for falling back to a weaker configuration.
    pub enable_fallback: bool,
    /// Opt-in switch for accepting a connection with reduced security.
    pub allow_reduced_security: bool,
    /// Maximum number of degradation steps a caller should attempt.
    pub max_degradation_attempts: u32,
}

impl Default for DegradationConfig {
    /// Fallback and reduced security disabled; at most two degradation attempts.
    fn default() -> Self {
        Self {
            max_degradation_attempts: 2,
            allow_reduced_security: false,
            enable_fallback: false,
        }
    }
}
pub async fn retry_with_policy<F, Fut, T>(
policy: &RetryPolicy,
operation: F,
operation_name: &str,
) -> Result<T, TlsError>
where
F: Fn() -> Fut,
Fut: Future<Output = Result<T, TlsError>>,
{
let mut last_error = None;
for attempt in 1..=policy.max_attempts {
debug!("{} attempt {} of {}", operation_name, attempt, policy.max_attempts);
match operation().await {
Ok(result) => {
if attempt > 1 {
info!("{} succeeded on attempt {} after retry", operation_name, attempt);
}
return Ok(result);
}
Err(err) => {
let error_info = match &err {
TlsError::Io { .. } => "IO error".to_string(),
TlsError::Tls { message, .. } => format!("TLS error: {}", message),
TlsError::Certificate { .. } => "Certificate error".to_string(),
TlsError::KeyExchange { .. } => "Key exchange error".to_string(),
TlsError::CryptoProvider { .. } => "Crypto provider error".to_string(),
TlsError::Config { .. } => "Configuration error".to_string(),
_ => "Unknown error".to_string(),
};
last_error = Some(TlsError::Config {
message: format!("Circuit breaker failure: {}", error_info),
field: Some("circuit_breaker".to_string()),
code: ErrorCode::InvalidConfig,
context: Box::default(),
recovery: Box::new(RecoveryHint::Retry { max_attempts: 3, backoff_ms: 1000 }),
});
if !policy.should_retry(&err, attempt) {
warn!("{} error not retryable: {:?}", operation_name, err);
return Err(err);
}
if attempt < policy.max_attempts {
let backoff = policy.backoff_for_attempt(attempt);
info!(
"{} failed on attempt {}, retrying after {:?}",
operation_name, attempt, backoff
);
sleep(backoff).await;
}
}
}
}
error!("{} failed after {} attempts", operation_name, policy.max_attempts);
Err(last_error.unwrap_or_else(|| TlsError::Internal {
message: "Operation failed with unknown error".to_string(),
code: ErrorCode::InternalError,
context: Box::default(),
recovery: Box::new(RecoveryHint::NoRecovery),
}))
}
/// Runs `operation` once, guarded by `circuit_breaker`, and records the outcome on the breaker.
///
/// The operation is invoked at most once, so any `FnOnce` closure is accepted. Every `Fn`
/// closure still qualifies.
///
/// # Errors
///
/// If the breaker rejects the request, the operation is not run. In that case this returns
/// a [`TlsError::Internal`] with code [`ErrorCode::TooManyConnections`]. Otherwise it
/// returns the operation's own error, after recording a failure.
pub async fn execute_with_circuit_breaker<F, Fut, T>(
    circuit_breaker: &CircuitBreaker,
    operation: F,
    operation_name: &str,
) -> Result<T, TlsError>
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = Result<T, TlsError>>,
{
    if !circuit_breaker.allow_request() {
        return Err(TlsError::Internal {
            message: format!("Circuit breaker is open, {} operation blocked", operation_name),
            code: ErrorCode::TooManyConnections,
            context: Box::default(),
            recovery: Box::new(RecoveryHint::Retry { max_attempts: 1, backoff_ms: 5000 }),
        });
    }

    let outcome = operation().await;
    match &outcome {
        Ok(_) => circuit_breaker.record_success(),
        Err(_) => circuit_breaker.record_failure(),
    }
    outcome
}
pub async fn execute_with_fallback<F1, Fut1, F2, Fut2, T>(
strategy: &FallbackStrategy,
primary: F1,
fallback: F2,
operation_name: &str,
) -> Result<T, TlsError>
where
F1: Fn() -> Fut1,
Fut1: Future<Output = Result<T, TlsError>>,
F2: Fn() -> Fut2,
Fut2: Future<Output = Result<T, TlsError>>,
{
match primary().await {
Ok(result) => Ok(result),
Err(err) => {
if strategy.should_fallback(&err) {
warn!(
"{} primary failed, attempting fallback: {}",
operation_name,
strategy.description()
);
fallback().await
} else {
Err(err)
}
}
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[allow(clippy::expect_used)]
mod tests {
    //! Unit tests for retry policies, the circuit breaker, fallback strategies and the async
    //! helpers built on top of them.
    use super::*;

    // ---- fixtures ---------------------------------------------------------------------

    /// Boxed `RecoveryHint::Retry { max_attempts: 3, backoff_ms: 1000 }`.
    fn retry_hint() -> Box<RecoveryHint> {
        Box::new(RecoveryHint::Retry { max_attempts: 3, backoff_ms: 1000 })
    }

    /// Boxed `RecoveryHint::NoRecovery`.
    fn no_recovery() -> Box<RecoveryHint> {
        Box::new(RecoveryHint::NoRecovery)
    }

    /// `TlsError::Io` with the given message, code and recovery hint.
    fn io_error(message: &str, code: ErrorCode, recovery: Box<RecoveryHint>) -> TlsError {
        TlsError::Io {
            message: message.to_string(),
            source: None,
            code,
            context: Box::default(),
            recovery,
        }
    }

    /// `TlsError::Tls` with the given message, code and recovery hint.
    fn tls_error(message: &str, code: ErrorCode, recovery: Box<RecoveryHint>) -> TlsError {
        TlsError::Tls { message: message.to_string(), code, context: Box::default(), recovery }
    }

    /// `TlsError::Handshake` in handshake state `state`.
    fn handshake_error(
        message: &str,
        state: &str,
        code: ErrorCode,
        recovery: Box<RecoveryHint>,
    ) -> TlsError {
        TlsError::Handshake {
            message: message.to_string(),
            state: state.to_string(),
            code,
            context: Box::default(),
            recovery,
        }
    }

    /// `TlsError::KeyExchange` for `method`, without a recovery hint.
    fn key_exchange_error(
        message: &str,
        method: &str,
        operation: Option<&str>,
        code: ErrorCode,
    ) -> TlsError {
        TlsError::KeyExchange {
            message: message.to_string(),
            method: method.to_string(),
            operation: operation.map(str::to_string),
            code,
            context: Box::default(),
            recovery: no_recovery(),
        }
    }

    /// `TlsError::Certificate` without subject/issuer or recovery hint.
    fn certificate_error(message: &str, code: ErrorCode) -> TlsError {
        TlsError::Certificate {
            message: message.to_string(),
            subject: None,
            issuer: None,
            code,
            context: Box::default(),
            recovery: no_recovery(),
        }
    }

    /// `TlsError::Config` without a field name or recovery hint.
    fn config_error(message: &str, code: ErrorCode) -> TlsError {
        TlsError::Config {
            message: message.to_string(),
            field: None,
            code,
            context: Box::default(),
            recovery: no_recovery(),
        }
    }

    /// `TlsError::Internal` with `ErrorCode::InternalError`.
    fn internal_error(message: &str) -> TlsError {
        TlsError::Internal {
            message: message.to_string(),
            code: ErrorCode::InternalError,
            context: Box::default(),
            recovery: no_recovery(),
        }
    }

    /// Policy with a 2.0 multiplier and jitter disabled, so delays are reproducible.
    fn deterministic_policy(
        max_attempts: u32,
        initial_backoff: Duration,
        max_backoff: Duration,
    ) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            initial_backoff,
            max_backoff,
            backoff_multiplier: 2.0,
            jitter: false,
        }
    }

    /// Breaker that has already recorded exactly `failure_threshold` failures.
    fn tripped_breaker(failure_threshold: u32, timeout: Duration) -> CircuitBreaker {
        let breaker = CircuitBreaker::new(failure_threshold, timeout);
        (0..failure_threshold).for_each(|_| breaker.record_failure());
        breaker
    }

    // ---- RetryPolicy ------------------------------------------------------------------

    #[test]
    fn test_retry_policy_default_has_correct_values_succeeds() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_backoff, Duration::from_millis(100));
    }

    #[test]
    fn test_retry_policy_backoff_increases_with_attempts_succeeds() {
        let policy = RetryPolicy::default();
        let (first, second) = (policy.backoff_for_attempt(1), policy.backoff_for_attempt(2));
        assert!(second > first);
        assert!(first >= Duration::from_millis(100));
    }

    #[test]
    fn test_retry_policy_conservative_has_correct_values_succeeds() {
        let policy = RetryPolicy::conservative();
        assert_eq!(policy.max_attempts, 2);
        assert_eq!(policy.initial_backoff, Duration::from_millis(200));
        assert_eq!(policy.max_backoff, Duration::from_secs(2));
    }

    #[test]
    fn test_retry_policy_aggressive_has_correct_values_succeeds() {
        let policy = RetryPolicy::aggressive();
        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_backoff, Duration::from_millis(50));
        assert_eq!(policy.max_backoff, Duration::from_secs(10));
    }

    #[test]
    fn test_retry_policy_custom_has_correct_values_succeeds() {
        let policy = RetryPolicy::new(10, Duration::from_millis(500), Duration::from_secs(30));
        assert_eq!(policy.max_attempts, 10);
        assert_eq!(policy.initial_backoff, Duration::from_millis(500));
        assert_eq!(policy.max_backoff, Duration::from_secs(30));
    }

    #[test]
    fn test_retry_policy_backoff_is_capped_at_max_succeeds() {
        let policy =
            deterministic_policy(10, Duration::from_millis(100), Duration::from_millis(500));
        assert!(policy.backoff_for_attempt(10) <= Duration::from_millis(500));
    }

    #[test]
    fn test_retry_policy_backoff_without_jitter_is_deterministic() {
        let policy = deterministic_policy(3, Duration::from_millis(100), Duration::from_secs(5));
        let first = policy.backoff_for_attempt(1);
        assert_eq!(first, policy.backoff_for_attempt(1));
        assert_eq!(first, Duration::from_millis(100));
        assert_eq!(policy.backoff_for_attempt(2), Duration::from_millis(200));
    }

    #[test]
    fn test_retry_policy_retries_io_errors_fails() {
        let policy = RetryPolicy::default();
        let transient = [
            ("refused", ErrorCode::ConnectionRefused),
            ("timeout", ErrorCode::ConnectionTimeout),
            ("reset", ErrorCode::ConnectionReset),
        ];
        for (message, code) in transient {
            assert!(policy.should_retry(&io_error(message, code, retry_hint()), 1));
        }
    }

    #[test]
    fn test_retry_policy_does_not_retry_at_max_attempts_succeeds() {
        let policy = RetryPolicy::default();
        let refused = io_error("refused", ErrorCode::ConnectionRefused, retry_hint());
        assert!(!policy.should_retry(&refused, 3));
    }

    #[test]
    fn test_retry_policy_retries_tls_errors_fails() {
        let policy = RetryPolicy::default();
        let err = tls_error("handshake failed", ErrorCode::HandshakeFailed, retry_hint());
        assert!(policy.should_retry(&err, 1));
    }

    #[test]
    fn test_retry_policy_retries_handshake_errors_fails() {
        let policy = RetryPolicy::default();
        let err = handshake_error(
            "handshake failed",
            "ClientHello",
            ErrorCode::HandshakeFailed,
            retry_hint(),
        );
        assert!(policy.should_retry(&err, 1));
    }

    #[test]
    fn test_retry_policy_retries_key_exchange_errors_fails() {
        let policy = RetryPolicy::default();
        let err = key_exchange_error(
            "key exchange failed",
            "X25519",
            None,
            ErrorCode::KeyExchangeFailed,
        );
        assert!(policy.should_retry(&err, 1));
    }

    #[test]
    fn test_retry_policy_does_not_retry_cert_errors_fails() {
        let policy = RetryPolicy::default();
        let err = certificate_error("cert expired", ErrorCode::CertificateExpired);
        assert!(!policy.should_retry(&err, 1));
    }

    #[test]
    fn test_retry_policy_retries_tls_invalid_handshake_fails() {
        let policy = RetryPolicy::default();
        let err = tls_error("invalid handshake", ErrorCode::InvalidHandshakeMessage, no_recovery());
        assert!(policy.should_retry(&err, 1));
    }

    #[test]
    fn test_retry_policy_retries_tls_handshake_timeout_succeeds() {
        let policy = RetryPolicy::default();
        let err = tls_error("timeout", ErrorCode::HandshakeTimeout, no_recovery());
        assert!(policy.should_retry(&err, 1));
    }

    #[test]
    fn test_retry_policy_retries_handshake_protocol_version_succeeds() {
        let policy = RetryPolicy::default();
        let err = handshake_error(
            "version mismatch",
            "ClientHello",
            ErrorCode::ProtocolVersionMismatch,
            no_recovery(),
        );
        assert!(policy.should_retry(&err, 1));
    }

    #[test]
    fn test_retry_policy_retries_handshake_timeout_succeeds() {
        let policy = RetryPolicy::default();
        let err =
            handshake_error("timeout", "ServerHello", ErrorCode::HandshakeTimeout, no_recovery());
        assert!(policy.should_retry(&err, 1));
    }

    #[test]
    fn test_retry_policy_retries_kex_encapsulation_succeeds() {
        let policy = RetryPolicy::default();
        let err = key_exchange_error(
            "encap failed",
            "ML-KEM",
            Some("encapsulate"),
            ErrorCode::EncapsulationFailed,
        );
        assert!(policy.should_retry(&err, 1));
    }

    #[test]
    fn test_retry_policy_does_not_retry_non_retryable_io_succeeds() {
        let policy = RetryPolicy::default();
        let err = io_error("not found", ErrorCode::IoError, no_recovery());
        assert!(!policy.should_retry(&err, 1));
    }

    #[test]
    fn test_retry_policy_does_not_retry_config_error_fails() {
        let policy = RetryPolicy::default();
        let err = config_error("invalid", ErrorCode::InvalidConfig);
        assert!(!policy.should_retry(&err, 1));
    }

    #[test]
    fn test_retry_policy_clone_debug_succeeds() {
        let policy = RetryPolicy::default();
        let cloned = policy.clone();
        assert_eq!(cloned.max_attempts, policy.max_attempts);
        assert!(format!("{:?}", policy).contains("RetryPolicy"));
    }

    // ---- CircuitBreaker ---------------------------------------------------------------

    #[test]
    fn test_circuit_breaker_initial_state_is_closed_succeeds() {
        let breaker = CircuitBreaker::new(5, Duration::from_secs(60));
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_opens_after_failures_succeeds() {
        let breaker = tripped_breaker(3, Duration::from_secs(60));
        assert_eq!(breaker.state(), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_allows_requests_when_closed_succeeds() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        assert!(breaker.allow_request());
    }

    #[test]
    fn test_circuit_breaker_blocks_requests_when_open_fails() {
        let breaker = tripped_breaker(3, Duration::from_secs(60));
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.allow_request());
    }

    #[test]
    fn test_circuit_breaker_transitions_to_half_open_after_timeout_succeeds() {
        let breaker = tripped_breaker(3, Duration::from_millis(1));
        assert_eq!(breaker.state(), CircuitState::Open);
        std::thread::sleep(Duration::from_millis(5));
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_breaker_closes_after_success_in_half_open_succeeds() {
        let breaker = tripped_breaker(3, Duration::from_millis(1));
        std::thread::sleep(Duration::from_millis(5));
        let _ = breaker.allow_request();
        (0..3).for_each(|_| breaker.record_success());
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_reopens_on_failure_in_half_open_fails() {
        let breaker = tripped_breaker(3, Duration::from_millis(1));
        std::thread::sleep(Duration::from_millis(5));
        let _ = breaker.allow_request();
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_reset_succeeds() {
        let breaker = tripped_breaker(3, Duration::from_secs(60));
        assert_eq!(breaker.state(), CircuitState::Open);
        breaker.reset();
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.allow_request());
    }

    #[test]
    fn test_circuit_breaker_success_resets_failure_count_when_closed_succeeds() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        breaker.record_failure();
        breaker.record_failure();
        // The success clears the streak, so two more failures stay below the threshold.
        breaker.record_success();
        breaker.record_failure();
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_state_equality_is_correct() {
        assert_eq!(CircuitState::Closed, CircuitState::Closed);
        assert_ne!(CircuitState::Open, CircuitState::Closed);
        assert_ne!(CircuitState::HalfOpen, CircuitState::Open);
    }

    #[test]
    fn test_circuit_state_debug_has_correct_format() {
        let rendered = format!("{:?}", CircuitState::HalfOpen);
        assert!(rendered.contains("HalfOpen"));
    }

    #[test]
    fn test_circuit_state_clone_copy_succeeds() {
        let state = CircuitState::Open;
        let copied = state;
        #[allow(clippy::clone_on_copy)]
        let cloned = state.clone();
        assert_eq!(state, cloned);
        assert_eq!(state, copied);
    }

    // ---- FallbackStrategy / DegradationConfig -----------------------------------------

    #[test]
    fn test_fallback_strategy_description_has_correct_format() {
        let strategy = FallbackStrategy::hybrid_to_classical();
        assert!(strategy.description().contains("hybrid to classical"));
    }

    #[test]
    fn test_fallback_strategy_none_does_not_trigger_succeeds() {
        let strategy = FallbackStrategy::None;
        assert!(strategy.description().contains("No fallback"));
        assert!(!strategy.should_fallback(&internal_error("test")));
    }

    #[test]
    fn test_fallback_strategy_pq_to_hybrid_has_correct_description_is_documented() {
        let strategy = FallbackStrategy::pq_to_hybrid();
        assert!(strategy.description().contains("PQ-only to hybrid"));
    }

    #[test]
    fn test_fallback_strategy_hybrid_to_classical_triggers_on_pq_error_fails() {
        let strategy = FallbackStrategy::hybrid_to_classical();
        let err = config_error("PQ not available", ErrorCode::PqNotAvailable);
        assert!(strategy.should_fallback(&err));
    }

    #[test]
    fn test_fallback_strategy_custom_always_triggers_succeeds() {
        let strategy = FallbackStrategy::Custom { description: "My fallback".to_string() };
        assert_eq!(strategy.description(), "My fallback");
        assert!(strategy.should_fallback(&internal_error("any error")));
    }

    #[test]
    fn test_fallback_strategy_default_is_none_variant_succeeds() {
        assert!(matches!(FallbackStrategy::default(), FallbackStrategy::None));
    }

    #[test]
    fn test_fallback_pq_to_hybrid_triggers_on_hybrid_kem_failed_succeeds() {
        let strategy = FallbackStrategy::pq_to_hybrid();
        let err = config_error("kem failed", ErrorCode::HybridKemFailed);
        assert!(strategy.should_fallback(&err));
    }

    #[test]
    fn test_fallback_pq_to_hybrid_does_not_trigger_on_pq_not_available_error_fails() {
        let strategy = FallbackStrategy::pq_to_hybrid();
        let err = config_error("PQ not available", ErrorCode::PqNotAvailable);
        assert!(!strategy.should_fallback(&err));
    }

    #[test]
    fn test_fallback_hybrid_to_classical_triggers_on_hybrid_kem_failed_succeeds() {
        let strategy = FallbackStrategy::hybrid_to_classical();
        let err = config_error("kem failed", ErrorCode::HybridKemFailed);
        assert!(strategy.should_fallback(&err));
    }

    #[test]
    fn test_fallback_strategy_clone_debug_succeeds() {
        let strategy = FallbackStrategy::HybridToClassical;
        let cloned = strategy.clone();
        assert!(matches!(cloned, FallbackStrategy::HybridToClassical));
        assert!(format!("{:?}", strategy).contains("HybridToClassical"));
    }

    #[test]
    fn test_degradation_config_default_has_correct_values_succeeds() {
        let config = DegradationConfig::default();
        assert!(!config.enable_fallback);
        assert!(!config.allow_reduced_security);
        assert_eq!(config.max_degradation_attempts, 2);
    }

    #[test]
    fn test_degradation_config_custom_has_correct_values_succeeds() {
        let config = DegradationConfig {
            enable_fallback: false,
            allow_reduced_security: true,
            max_degradation_attempts: 5,
        };
        assert!(!config.enable_fallback);
        assert!(config.allow_reduced_security);
        assert_eq!(config.max_degradation_attempts, 5);
    }

    // ---- async helpers ----------------------------------------------------------------

    #[tokio::test]
    async fn test_retry_with_policy_succeeds_on_first_try_succeeds() {
        let policy = RetryPolicy::default();
        let result = retry_with_policy(&policy, || async { Ok::<_, TlsError>(42) }, "test").await;
        assert_eq!(result.expect("should succeed"), 42);
    }

    #[tokio::test]
    async fn test_retry_with_policy_returns_error_for_non_retryable_fails() {
        let policy = RetryPolicy::default();
        let result: Result<i32, TlsError> = retry_with_policy(
            &policy,
            || async { Err(certificate_error("expired", ErrorCode::CertificateExpired)) },
            "cert_test",
        )
        .await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), ErrorCode::CertificateExpired);
    }

    #[tokio::test]
    async fn test_retry_with_policy_returns_error_when_retries_exhausted_fails() {
        let attempts = Arc::new(AtomicU32::new(0));
        let policy = deterministic_policy(2, Duration::from_millis(1), Duration::from_millis(10));
        let result: Result<i32, TlsError> = retry_with_policy(
            &policy,
            || {
                let counter = Arc::clone(&attempts);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(io_error("refused", ErrorCode::ConnectionRefused, retry_hint()))
                }
            },
            "retry_test",
        )
        .await;
        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_with_policy_succeeds_on_second_attempt_succeeds() {
        let attempts = Arc::new(AtomicU32::new(0));
        let policy = deterministic_policy(3, Duration::from_millis(1), Duration::from_millis(10));
        let result = retry_with_policy(
            &policy,
            || {
                let counter = Arc::clone(&attempts);
                async move {
                    // Only the very first call fails.
                    if counter.fetch_add(1, Ordering::SeqCst) == 0 {
                        Err(io_error("refused", ErrorCode::ConnectionRefused, retry_hint()))
                    } else {
                        Ok(99)
                    }
                }
            },
            "retry_success_test",
        )
        .await;
        assert_eq!(result.expect("should succeed on retry"), 99);
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_execute_with_circuit_breaker_succeeds() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        let result =
            execute_with_circuit_breaker(&breaker, || async { Ok::<_, TlsError>(42) }, "test")
                .await;
        assert_eq!(result.expect("should succeed"), 42);
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_execute_with_circuit_breaker_records_failure_fails() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        let result: Result<i32, TlsError> =
            execute_with_circuit_breaker(&breaker, || async { Err(internal_error("fail")) }, "test")
                .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_execute_with_circuit_breaker_open_returns_error() {
        let breaker = tripped_breaker(2, Duration::from_secs(60));
        assert_eq!(breaker.state(), CircuitState::Open);
        let result: Result<i32, TlsError> =
            execute_with_circuit_breaker(&breaker, || async { Ok(42) }, "blocked_test").await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), ErrorCode::TooManyConnections);
    }

    #[tokio::test]
    async fn test_execute_with_fallback_primary_succeeds_without_fallback_succeeds() {
        let strategy = FallbackStrategy::hybrid_to_classical();
        let result = execute_with_fallback(
            &strategy,
            || async { Ok::<_, TlsError>(42) },
            || async { Ok(99) },
            "test",
        )
        .await;
        assert_eq!(result.expect("primary should succeed"), 42);
    }

    #[tokio::test]
    async fn test_execute_with_fallback_uses_fallback_on_error_succeeds() {
        let strategy = FallbackStrategy::hybrid_to_classical();
        let result = execute_with_fallback(
            &strategy,
            || async {
                Err::<i32, _>(config_error("PQ not available", ErrorCode::PqNotAvailable))
            },
            || async { Ok(99) },
            "fallback_test",
        )
        .await;
        assert_eq!(result.expect("fallback should succeed"), 99);
    }

    #[tokio::test]
    async fn test_execute_with_fallback_does_not_trigger_on_non_matching_error_fails() {
        let strategy = FallbackStrategy::None;
        let result: Result<i32, TlsError> = execute_with_fallback(
            &strategy,
            || async { Err(internal_error("fail")) },
            || async { Ok(99) },
            "no_fallback_test",
        )
        .await;
        assert!(result.is_err());
    }
}