use std::time::{Duration, Instant};
use azure_core::http::{headers::Headers, Method};
use url::Url;
use crate::{
diagnostics::{ExecutionContext, RequestSentStatus},
driver::{
jitter::with_jitter,
routing::{CosmosEndpoint, LocationIndex},
transport::AuthorizationContext,
},
models::{CosmosResponseHeaders, CosmosStatus},
options::Region,
};
/// Transport protocol used to reach the selected endpoint for an attempt
/// (carried in `RoutingDecision::transport_mode`).
///
/// NOTE(review): `Gateway20` presumably denotes the Gateway 2.0 / thin-client
/// protocol as opposed to the classic gateway — confirm against the transport
/// implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum TransportMode {
    /// Classic gateway transport (presumably; confirm).
    Gateway,
    /// Gateway 2.0 transport (presumably; confirm).
    Gateway20,
}
/// Result of endpoint selection for a single attempt: the chosen endpoint,
/// the concrete URL to send to, and the transport protocol to use.
#[derive(Clone, Debug)]
pub(crate) struct RoutingDecision {
    /// Endpoint chosen for this attempt; its region (when it has one) is
    /// included in the `Display` output.
    pub endpoint: CosmosEndpoint,
    /// URL the request will be sent to.
    pub selected_url: Url,
    /// Transport protocol to use for this attempt.
    pub transport_mode: TransportMode,
}
impl std::fmt::Display for RoutingDecision {
    /// Renders the decision as `region(url)` when the endpoint reports a
    /// region, and as the bare URL otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.endpoint.region() {
            Some(region) => write!(f, "{}({})", region, self.selected_url),
            None => write!(f, "{}", self.selected_url),
        }
    }
}
/// Per-operation retry bookkeeping for failover retries and session-token
/// retries.
///
/// Treated as a value: the `advance_*` methods consume `self` and return an
/// updated copy.
#[derive(Clone, Debug)]
pub(crate) struct OperationRetryState {
    /// Current position in the endpoint list; moved by `advance_location`.
    pub location: LocationIndex,
    /// Failover retries performed so far.
    pub failover_retry_count: u32,
    /// Session-token retries performed so far.
    pub session_token_retry_count: u32,
    /// Upper bound for `failover_retry_count`, checked by `can_retry_failover`.
    pub max_failover_retries: u32,
    /// Upper bound for `session_token_retry_count`, checked by `can_retry_session`.
    pub max_session_retries: u32,
    /// Decides how session retries are routed (see `advance_session_retry`).
    /// Presumably `true` for multi-write-region accounts — confirm.
    pub can_use_multiple_write_locations: bool,
    /// Regions to avoid when selecting endpoints (presumably honored by the
    /// router; not consulted in this impl — confirm).
    pub excluded_regions: Vec<Region>,
    /// Which endpoint list session retries should be routed against.
    pub session_retry_routing: SessionRetryRouting,
}
/// Endpoint list that a session-token retry should be routed against.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum SessionRetryRouting {
    /// Use the operation's normal preferred endpoints.
    PreferredEndpoints,
    /// Send reads to the preferred write endpoints;
    /// `OperationRetryState::route_reads_to_write_endpoints` returns `true`
    /// only for this variant.
    PreferredWriteEndpoints,
}
impl OperationRetryState {
    /// Creates the retry state for a fresh operation.
    ///
    /// Both retry counters start at zero, session retries start on the
    /// regular preferred endpoints, and `generation` seeds the starting
    /// [`LocationIndex`].
    pub fn initial(
        generation: u64,
        can_use_multiple_write_locations: bool,
        excluded_regions: Vec<Region>,
        max_failover_retries: u32,
        max_session_retries: u32,
    ) -> Self {
        let location = LocationIndex::initial(generation);
        Self {
            location,
            failover_retry_count: 0,
            session_token_retry_count: 0,
            max_failover_retries,
            max_session_retries,
            can_use_multiple_write_locations,
            excluded_regions,
            session_retry_routing: SessionRetryRouting::PreferredEndpoints,
        }
    }

    /// `true` while the failover budget has not been used up.
    pub fn can_retry_failover(&self) -> bool {
        self.max_failover_retries > self.failover_retry_count
    }

    /// `true` while the session-retry budget has not been used up.
    pub fn can_retry_session(&self) -> bool {
        self.max_session_retries > self.session_token_retry_count
    }

    /// Consumes one failover retry and resets session-retry routing to the
    /// regular preferred endpoints. `location` is left untouched; use
    /// [`OperationRetryState::advance_location`] to move it.
    pub fn advance_failover(self) -> Self {
        let failover_retry_count = self.failover_retry_count + 1;
        Self {
            failover_retry_count,
            session_retry_routing: SessionRetryRouting::PreferredEndpoints,
            ..self
        }
    }

    /// Consumes one session retry.
    ///
    /// When multiple write locations are not usable, later reads are routed
    /// to the write endpoints; otherwise routing stays on the preferred
    /// endpoints.
    pub fn advance_session_retry(self) -> Self {
        let session_token_retry_count = self.session_token_retry_count + 1;
        let session_retry_routing = if self.can_use_multiple_write_locations {
            SessionRetryRouting::PreferredEndpoints
        } else {
            SessionRetryRouting::PreferredWriteEndpoints
        };
        Self {
            session_token_retry_count,
            session_retry_routing,
            ..self
        }
    }

    /// Moves to the next location via `LocationIndex::next_for_generation`.
    ///
    /// NOTE(review): presumably `list_len` is the length of the active
    /// endpoint list and `generation` its current version — confirm.
    pub fn advance_location(self, list_len: usize, generation: u64) -> Self {
        let location = self.location.next_for_generation(list_len, generation);
        Self { location, ..self }
    }

    /// `true` when the routing mode says reads should go to write endpoints.
    pub fn route_reads_to_write_endpoints(&self) -> bool {
        self.session_retry_routing == SessionRetryRouting::PreferredWriteEndpoints
    }
}
/// A fully prepared request, ready to be handed to the transport layer.
#[derive(Debug)]
pub(crate) struct TransportRequest {
    /// HTTP method.
    pub method: Method,
    /// Endpoint the request targets.
    pub endpoint: CosmosEndpoint,
    /// URL the request is sent to.
    pub url: Url,
    /// Request headers.
    pub headers: Headers,
    /// Request payload; `None` when the request has no body.
    pub body: Option<azure_core::Bytes>,
    /// Authorization data for the request (presumably used to sign it —
    /// confirm against the transport).
    pub auth_context: AuthorizationContext,
    /// Diagnostics context associated with this request.
    pub execution_context: ExecutionContext,
    /// Optional absolute deadline for the attempt; `None` presumably means no
    /// deadline — confirm how the transport enforces it.
    pub deadline: Option<Instant>,
}
/// Bookkeeping for retries of throttled requests.
///
/// Treated as a value: [`ThrottleRetryState::mark_forced_final_retry_used`]
/// returns an updated copy instead of mutating in place.
#[derive(Clone, Debug)]
pub(crate) struct ThrottleRetryState {
    /// Retries already performed; used as the exponent of the fallback backoff.
    pub attempt_count: u32,
    /// Maximum number of throttle retries (default 9). Not consulted in this
    /// impl — presumably enforced by the throttle policy; confirm.
    pub max_attempts: u32,
    /// Presumably the total delay already waited; not updated in this impl.
    pub cumulative_delay: Duration,
    /// Presumably the budget for `cumulative_delay` (default 30 s); not
    /// enforced in this impl.
    pub max_wait_time: Duration,
    /// Cap for a single retry delay (default 5 s). `fallback_delay` does NOT
    /// apply this cap; presumably the caller does — confirm.
    pub max_per_retry_delay: Duration,
    /// Fallback delay at `attempt_count == 0` (default 5 ms).
    pub fallback_base_delay: Duration,
    /// Per-attempt multiplier of the fallback delay (default 2.0).
    pub backoff_factor: f64,
    /// Relative jitter applied to the fallback delay (default 0.25); clamped
    /// to `[0, 1]`, and NaN disables jitter.
    pub backoff_jitter_ratio: f64,
    /// Set once `mark_forced_final_retry_used` has produced this state.
    pub forced_final_retry_used: bool,
}

/// Default for [`ThrottleRetryState::max_attempts`].
const DEFAULT_MAX_THROTTLE_ATTEMPTS: u32 = 9;
/// Default for [`ThrottleRetryState::max_wait_time`].
const DEFAULT_MAX_THROTTLE_WAIT: Duration = Duration::from_secs(30);
/// Default for [`ThrottleRetryState::max_per_retry_delay`].
const DEFAULT_MAX_PER_RETRY_DELAY: Duration = Duration::from_secs(5);
/// Default for [`ThrottleRetryState::fallback_base_delay`].
const DEFAULT_FALLBACK_BASE_DELAY: Duration = Duration::from_millis(5);
/// Default for [`ThrottleRetryState::backoff_factor`].
const DEFAULT_BACKOFF_FACTOR: f64 = 2.0;
/// Default for [`ThrottleRetryState::backoff_jitter_ratio`].
const DEFAULT_BACKOFF_JITTER_RATIO: f64 = 0.25;

impl ThrottleRetryState {
    /// Creates a state with no attempts recorded and all limits at their defaults.
    pub fn new() -> Self {
        Self {
            attempt_count: 0,
            max_attempts: DEFAULT_MAX_THROTTLE_ATTEMPTS,
            cumulative_delay: Duration::ZERO,
            max_wait_time: DEFAULT_MAX_THROTTLE_WAIT,
            max_per_retry_delay: DEFAULT_MAX_PER_RETRY_DELAY,
            fallback_base_delay: DEFAULT_FALLBACK_BASE_DELAY,
            backoff_factor: DEFAULT_BACKOFF_FACTOR,
            backoff_jitter_ratio: DEFAULT_BACKOFF_JITTER_RATIO,
            forced_final_retry_used: false,
        }
    }

    /// `true` until the forced final retry has been marked as used.
    pub fn can_use_forced_final_retry(&self) -> bool {
        !self.forced_final_retry_used
    }

    /// Returns a copy of this state with the forced final retry consumed.
    pub fn mark_forced_final_retry_used(&self) -> Self {
        Self {
            forced_final_retry_used: true,
            ..self.clone()
        }
    }

    /// Un-jittered fallback delay:
    /// `fallback_base_delay * backoff_factor ^ attempt_count`.
    ///
    /// Never panics. A result too large for `Duration` (including +inf)
    /// saturates at `Duration::MAX`. A NaN or negative result, which is only
    /// possible with a misconfigured `backoff_factor`, falls back to
    /// `fallback_base_delay`.
    fn fallback_exponential_delay(&self) -> Duration {
        // A plain `as i32` would wrap counts above `i32::MAX` to a negative
        // exponent and shrink the delay toward zero; clamp first so the cast
        // is lossless.
        let exponent = self.attempt_count.min(i32::MAX as u32) as i32;
        let multiplier = self.backoff_factor.powi(exponent);
        // Same arithmetic as `Duration::mul_f64`, but without its panic on
        // overflow / non-finite / negative products.
        let secs = self.fallback_base_delay.as_secs_f64() * multiplier;
        match Duration::try_from_secs_f64(secs) {
            Ok(delay) => delay,
            Err(_) if secs > 0.0 => Duration::MAX,
            Err(_) => self.fallback_base_delay,
        }
    }

    /// Fallback delay for the current attempt, with jitter applied via
    /// `with_jitter` when the (clamped) jitter ratio is non-zero.
    ///
    /// Never panics. A NaN ratio disables jitter, and a jittered value that
    /// cannot be represented as a `Duration` is replaced by the un-jittered
    /// delay.
    pub fn fallback_delay(&self) -> Duration {
        let base_delay = self.fallback_exponential_delay();
        // `f64::clamp` passes NaN through unchanged, which would slip past the
        // zero check below, so normalize it to "no jitter" first.
        let ratio = if self.backoff_jitter_ratio.is_nan() {
            0.0
        } else {
            self.backoff_jitter_ratio.clamp(0.0, 1.0)
        };
        if ratio == 0.0 || base_delay.is_zero() {
            return base_delay;
        }
        let jittered_secs = with_jitter(base_delay.as_secs_f64(), ratio);
        Duration::try_from_secs_f64(jittered_secs).unwrap_or(base_delay)
    }
}
/// Result of a single transport attempt, wrapping a [`TransportOutcome`].
#[derive(Debug)]
pub(crate) struct TransportResult {
    /// Classified outcome of the attempt.
    pub outcome: TransportOutcome,
}
impl TransportResult {
    /// Builds a result for an attempt whose deadline elapsed.
    ///
    /// `request_sent` records whether the request had already gone out.
    pub fn deadline_exceeded(request_sent: RequestSentStatus) -> Self {
        let outcome = TransportOutcome::DeadlineExceeded { request_sent };
        Self { outcome }
    }

    /// Classifies a received HTTP response.
    ///
    /// A status for which [`CosmosStatus::is_success`] holds becomes
    /// [`TransportOutcome::Success`], and the raw `headers` are dropped. Any
    /// other status becomes [`TransportOutcome::HttpError`], which keeps the
    /// raw headers and is marked [`RequestSentStatus::Sent`].
    pub fn from_http_response(
        status: CosmosStatus,
        headers: Headers,
        cosmos_headers: CosmosResponseHeaders,
        body: Vec<u8>,
    ) -> Self {
        let outcome = if status.is_success() {
            TransportOutcome::Success {
                status,
                cosmos_headers,
                body,
            }
        } else {
            TransportOutcome::HttpError {
                status,
                headers,
                cosmos_headers,
                body,
                request_sent: RequestSentStatus::Sent,
            }
        };
        Self { outcome }
    }

    /// Cosmos response headers; present only for `Success` and `HttpError`
    /// outcomes.
    pub fn cosmos_headers(&self) -> Option<&CosmosResponseHeaders> {
        match &self.outcome {
            TransportOutcome::Success { cosmos_headers, .. }
            | TransportOutcome::HttpError { cosmos_headers, .. } => Some(cosmos_headers),
            TransportOutcome::TransportError { .. }
            | TransportOutcome::DeadlineExceeded { .. } => None,
        }
    }

    /// Raw HTTP response headers; only `HttpError` outcomes retain them.
    pub fn response_headers(&self) -> Option<&Headers> {
        if let TransportOutcome::HttpError { headers, .. } = &self.outcome {
            Some(headers)
        } else {
            None
        }
    }
}
/// Classified result of sending one request.
///
/// `Debug` and `Display` are implemented by hand; neither prints response
/// bodies.
pub(crate) enum TransportOutcome {
    /// A response whose status satisfies `CosmosStatus::is_success`.
    Success {
        /// Response status.
        status: CosmosStatus,
        /// Parsed Cosmos response headers.
        cosmos_headers: CosmosResponseHeaders,
        /// Raw response body.
        body: Vec<u8>,
    },
    /// A response was received, but its status is not a success.
    HttpError {
        /// Response status.
        status: CosmosStatus,
        /// Raw HTTP response headers.
        headers: Headers,
        /// Parsed Cosmos response headers.
        cosmos_headers: CosmosResponseHeaders,
        /// Raw response body.
        body: Vec<u8>,
        /// Whether the request reached the wire; `TransportResult::from_http_response`
        /// always sets `Sent`.
        request_sent: RequestSentStatus,
    },
    /// The attempt produced an error instead of an HTTP response.
    TransportError {
        /// Status associated with the failure (presumably synthesized by the
        /// driver — confirm).
        status: CosmosStatus,
        /// Underlying error.
        error: azure_core::Error,
        /// Whether the request may have reached the service before failing.
        request_sent: RequestSentStatus,
    },
    /// The attempt's deadline elapsed (see `TransportResult::deadline_exceeded`).
    DeadlineExceeded { request_sent: RequestSentStatus },
}
impl std::fmt::Display for TransportOutcome {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportOutcome::Success { status, .. } => write!(f, "Success({})", status),
TransportOutcome::HttpError { status, .. } => write!(f, "HttpError({})", status),
TransportOutcome::TransportError { error, .. } => {
write!(f, "TransportError({})", error)
}
TransportOutcome::DeadlineExceeded { .. } => write!(f, "DeadlineExceeded"),
}
}
}
impl std::fmt::Debug for TransportOutcome {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportOutcome::Success { status, .. } => f
.debug_struct("Success")
.field("status", status)
.field("body", &"...")
.finish(),
TransportOutcome::HttpError {
status, headers, ..
} => f
.debug_struct("HttpError")
.field("status", status)
.field("headers", headers)
.field("body", &"...")
.finish(),
TransportOutcome::TransportError {
error,
request_sent,
..
} => f
.debug_struct("TransportError")
.field("error", error)
.field("request_sent", request_sent)
.finish(),
TransportOutcome::DeadlineExceeded { request_sent } => f
.debug_struct("DeadlineExceeded")
.field("request_sent", request_sent)
.finish(),
}
}
}
/// Decision produced by the operation-level retry logic after an attempt.
#[derive(Debug)]
pub(crate) enum OperationAction {
    /// Finish the operation with this result (boxed, presumably to keep the
    /// enum small — confirm).
    Complete(Box<TransportResult>),
    /// Retry with the updated failover state, optionally waiting `delay` first.
    FailoverRetry {
        new_state: OperationRetryState,
        delay: Option<Duration>,
    },
    /// Retry with the updated session-retry state.
    SessionRetry { new_state: OperationRetryState },
    /// Give up with `error`; `status` carries the Cosmos status when one is known.
    Abort {
        error: azure_core::Error,
        status: Option<CosmosStatus>,
    },
}
/// Decision produced by the throttle retry logic.
#[derive(Debug)]
pub(crate) enum ThrottleAction {
    /// Wait `delay`, then retry with the updated `new_state`.
    Retry {
        delay: Duration,
        new_state: ThrottleRetryState,
    },
    /// Do not retry here; propagate the outcome to the caller (presumably
    /// unchanged — confirm).
    Propagate,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default throttle state with only `attempt_count` overridden.
    fn throttle_at(attempt_count: u32) -> ThrottleRetryState {
        ThrottleRetryState {
            attempt_count,
            ..ThrottleRetryState::new()
        }
    }

    #[test]
    fn throttle_retry_state_defaults() {
        let ThrottleRetryState {
            attempt_count,
            max_attempts,
            max_wait_time,
            fallback_base_delay,
            max_per_retry_delay,
            backoff_jitter_ratio,
            forced_final_retry_used,
            ..
        } = ThrottleRetryState::new();
        assert_eq!(attempt_count, 0);
        assert_eq!(max_attempts, 9);
        assert_eq!(max_wait_time, Duration::from_secs(30));
        assert_eq!(fallback_base_delay, Duration::from_millis(5));
        assert_eq!(max_per_retry_delay, Duration::from_secs(5));
        assert_eq!(backoff_jitter_ratio, 0.25);
        assert!(!forced_final_retry_used);
    }

    #[test]
    fn throttle_retry_state_marks_forced_final_retry_as_used() {
        let fresh = ThrottleRetryState::new();
        assert!(fresh.can_use_forced_final_retry());
        let consumed = fresh.mark_forced_final_retry_used();
        assert!(!consumed.can_use_forced_final_retry());
    }

    #[test]
    fn throttle_retry_fallback_backoff_with_jitter_bounds() {
        // (attempt, lower ns, upper ns): 5 ms * 2^attempt with +/-25% jitter.
        let cases: [(u32, u64, u64); 3] = [
            (0, 3_750_000, 6_250_000),
            (1, 7_500_000, 12_500_000),
            (5, 120_000_000, 200_000_000),
        ];
        for (attempt, low, high) in cases {
            let delay = throttle_at(attempt).fallback_delay();
            assert!(
                delay >= Duration::from_nanos(low),
                "attempt {attempt}: {delay:?} below {low}ns"
            );
            assert!(
                delay <= Duration::from_nanos(high),
                "attempt {attempt}: {delay:?} above {high}ns"
            );
        }
    }

    #[test]
    fn throttle_retry_fallback_exponential_when_jitter_disabled() {
        for (attempt, expected_ms) in [(0, 5), (1, 10), (5, 160)] {
            let state = ThrottleRetryState {
                backoff_jitter_ratio: 0.0,
                ..throttle_at(attempt)
            };
            assert_eq!(state.fallback_delay(), Duration::from_millis(expected_ms));
        }
    }

    #[test]
    fn operation_retry_state_budget() {
        let fresh = OperationRetryState::initial(0, false, Vec::new(), 1, 1);
        assert!(fresh.can_retry_failover());
        let exhausted = fresh.advance_failover();
        assert!(!exhausted.can_retry_failover());
    }

    #[test]
    fn advance_session_retry_single_write_routes_to_write_endpoints() {
        let before = OperationRetryState::initial(0, false, Vec::new(), 3, 2);
        assert_eq!(
            before.session_retry_routing,
            SessionRetryRouting::PreferredEndpoints
        );

        let after = before.advance_session_retry();
        assert_eq!(after.session_token_retry_count, 1);
        assert_eq!(
            after.session_retry_routing,
            SessionRetryRouting::PreferredWriteEndpoints
        );
    }

    #[test]
    fn advance_session_retry_multi_write_stays_on_preferred_endpoints() {
        let after = OperationRetryState::initial(0, true, Vec::new(), 3, 2)
            .advance_session_retry();
        assert_eq!(after.session_token_retry_count, 1);
        assert_eq!(
            after.session_retry_routing,
            SessionRetryRouting::PreferredEndpoints
        );
    }
}