use super::{
get_substatus_code_from_error, get_substatus_code_from_response, is_non_retryable_status_code,
RetryResult,
};
use crate::constants::SubStatusCode;
use crate::cosmos_request::CosmosRequest;
use crate::operation_context::OperationType;
use crate::regions::Region;
use crate::retry_policies::resource_throttle_retry_policy::ResourceThrottleRetryPolicy;
use crate::routing::global_endpoint_manager::GlobalEndpointManager;
use azure_core::http::{RawResponse, StatusCode};
use azure_core::time::Duration;
use std::sync::Arc;
use tracing::trace;
/// Retry policy for metadata requests.
///
/// A failed request whose status is not classified as non-retryable is retried
/// immediately (zero delay) against another location, at most once per endpoint the
/// endpoint manager reports as applicable for reads (after excluded regions are
/// removed). Anything this endpoint failover does not retry is delegated to an inner
/// [`ResourceThrottleRetryPolicy`].
#[derive(Debug)]
pub(crate) struct MetadataRequestRetryPolicy {
/// Source of location information: refreshed and used to resolve the endpoint before
/// each send, and queried for the number of applicable read endpoints when failing over.
global_endpoint_manager: Arc<GlobalEndpointManager>,
/// Fallback policy consulted when endpoint failover does not produce a retry.
throttling_retry_policy: ResourceThrottleRetryPolicy,
/// Routing decision produced by the most recent failover; `None` until the first one.
retry_context: Option<MetadataRetryContext>,
/// Number of endpoint-failover attempts made so far; also used as the next location index.
unavailable_endpoint_retry_count: usize,
/// Regions excluded by the request being sent; captured in `before_send_request` and
/// used to narrow the set of endpoints that bound failover.
excluded_regions: Option<Vec<Region>>,
}
/// Location selection to apply to the next attempt after an endpoint failover.
#[derive(Clone, Debug)]
struct MetadataRetryContext {
/// Location index passed to `route_to_location_index` on the next send.
retry_location_index: usize,
/// Flag passed to `route_to_location_index`; this policy always sets it to `true`.
retry_request_on_preferred_locations: bool,
}
impl MetadataRequestRetryPolicy {
    /// Creates a policy that fails metadata requests over across the endpoints known to
    /// `global_endpoint_manager`, falling back to a [`ResourceThrottleRetryPolicy`] for
    /// anything the failover logic does not retry.
    pub fn new(global_endpoint_manager: Arc<GlobalEndpointManager>) -> Self {
        Self {
            global_endpoint_manager,
            // NOTE(review): the meaning of these arguments is defined by
            // `ResourceThrottleRetryPolicy::new`; confirm there before tuning.
            throttling_retry_policy: ResourceThrottleRetryPolicy::new(5, 200, 10),
            retry_context: None,
            unavailable_endpoint_retry_count: 0,
            excluded_regions: None,
        }
    }

    /// Prepares `request` for (re)sending.
    ///
    /// Steps:
    /// 1. Refreshes the endpoint manager's location state.
    /// 2. Records the request's excluded regions, which bound how many failovers are allowed.
    /// 3. Clears any previous routing decision.
    /// 4. Applies the location index chosen by the last failover, if any.
    /// 5. Pins the request to the endpoint the manager resolves for it.
    pub(crate) async fn before_send_request(&mut self, request: &mut CosmosRequest) {
        // The refresh outcome is not inspected: routing proceeds with whatever location
        // state the manager holds afterwards.
        let _stat = self.global_endpoint_manager.refresh_location(false).await;
        self.excluded_regions = request.excluded_regions.clone().map(|e| e.0);

        request.request_context.clear_route_to_location();
        if let Some(ctx) = &self.retry_context {
            // Route in place instead of cloning the context and writing it back.
            request.request_context.route_to_location_index(
                ctx.retry_location_index,
                ctx.retry_request_on_preferred_locations,
            );
        }

        let metadata_location_endpoint = self
            .global_endpoint_manager
            .resolve_service_endpoint(request);
        trace!(
            "MetadataRequestThrottleRetryPolicy: Routing the metadata request to: {:?} for operation type: {:?} and resource type: {:?}.",
            metadata_location_endpoint,
            request.operation_type,
            request.resource_type
        );
        request
            .request_context
            .route_to_location_endpoint(metadata_location_endpoint);
    }

    /// Decides whether the outcome of a metadata request should be retried.
    ///
    /// Responses with a client-error or server-error status, and transport/other errors,
    /// are evaluated. Any other response is never retried.
    pub(crate) async fn should_retry(
        &mut self,
        response: &azure_core::Result<RawResponse>,
    ) -> RetryResult {
        match response {
            Ok(resp) if resp.status().is_server_error() || resp.status().is_client_error() => {
                self.should_retry_response(resp).await
            }
            Ok(_) => RetryResult::DoNotRetry,
            Err(err) => self.should_retry_error(err).await,
        }
    }

    /// Evaluates a failed request that surfaced as an `azure_core::Error`.
    ///
    /// Errors that carry no HTTP status are classified as `StatusCode::UnknownValue(0)`.
    /// Endpoint failover is tried first; if it does not retry, the throttle policy decides.
    pub async fn should_retry_error(&mut self, err: &azure_core::Error) -> RetryResult {
        let status_code = err.http_status().unwrap_or(StatusCode::UnknownValue(0));
        let sub_status_code = get_substatus_code_from_error(err);
        let retry_result = self.should_retry_with_status_code(status_code, sub_status_code);
        if retry_result.is_retry() {
            return retry_result;
        }
        self.throttling_retry_policy.should_retry_error(err)
    }

    /// Evaluates an error response that was returned as a `RawResponse`.
    ///
    /// Endpoint failover is tried first; if it does not retry, the throttle policy decides.
    pub async fn should_retry_response(&mut self, response: &RawResponse) -> RetryResult {
        let status_code = response.status();
        // `response` is already borrowed; reading the sub-status does not require a
        // copy of the whole response (headers and body).
        let sub_status_code = get_substatus_code_from_response(response);
        let retry_result = self.should_retry_with_status_code(status_code, sub_status_code);
        if retry_result.is_retry() {
            return retry_result;
        }
        self.throttling_retry_policy.should_retry_response(response)
    }

    /// Returns an immediate retry when the status is retryable and another endpoint
    /// remains to fail over to; otherwise returns `DoNotRetry`.
    ///
    /// The failover counter is only advanced for retryable statuses, because of `&&`
    /// short-circuiting.
    fn should_retry_with_status_code(
        &mut self,
        status_code: StatusCode,
        sub_status_code: Option<SubStatusCode>,
    ) -> RetryResult {
        if !is_non_retryable_status_code(status_code, sub_status_code)
            && self.increment_retry_index_on_unavailable_endpoint_for_metadata_read()
        {
            return RetryResult::Retry {
                after: Duration::ZERO,
            };
        }
        RetryResult::DoNotRetry
    }

    /// Advances the failover counter and, if it still fits within the number of endpoints
    /// applicable for reads (minus excluded regions), records a retry context that routes
    /// the next attempt to location index `counter`.
    ///
    /// Returns `false` once the counter exceeds the number of applicable endpoints. The
    /// counter is incremented even in that case.
    fn increment_retry_index_on_unavailable_endpoint_for_metadata_read(&mut self) -> bool {
        self.unavailable_endpoint_retry_count += 1;
        let endpoint_count = self
            .global_endpoint_manager
            .applicable_endpoints(OperationType::Read, self.excluded_regions.as_ref())
            .len();
        if self.unavailable_endpoint_retry_count > endpoint_count {
            trace!(
                "MetadataRequestThrottleRetryPolicy: Retry count: {} has exceeded the number of applicable endpoints: {}.",
                self.unavailable_endpoint_retry_count,
                endpoint_count
            );
            return false;
        }
        trace!(
            "MetadataRequestThrottleRetryPolicy: Incrementing the metadata retry location index to: {}.",
            self.unavailable_endpoint_retry_count
        );
        self.retry_context = Some(MetadataRetryContext {
            retry_location_index: self.unavailable_endpoint_retry_count,
            retry_request_on_preferred_locations: true,
        });
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::operation_context::OperationType;
    use crate::options::ExcludedRegions;
    use crate::partition_key::PartitionKey;
    use crate::regions::Region;
    use crate::resource_context::{ResourceLink, ResourceType};
    use crate::routing::global_endpoint_manager::GlobalEndpointManager;
    use azure_core::http::headers::Headers;
    use azure_core::http::ClientOptions;
    use azure_core::Bytes;
    use std::sync::Arc;

    /// Builds an endpoint manager for the test account with the given preferred locations.
    ///
    /// All manager fixtures share this pipeline setup; only the location list differs.
    fn create_endpoint_manager(preferred_locations: Vec<Region>) -> Arc<GlobalEndpointManager> {
        let pipeline = azure_core::http::Pipeline::new(
            option_env!("CARGO_PKG_NAME"),
            option_env!("CARGO_PKG_VERSION"),
            ClientOptions::default(),
            Vec::new(),
            Vec::new(),
            None,
        );
        GlobalEndpointManager::new(
            "https://test.documents.azure.com".parse().unwrap(),
            preferred_locations,
            vec![],
            pipeline,
        )
    }

    /// Manager with two preferred locations (West US, East US).
    fn create_test_endpoint_manager() -> Arc<GlobalEndpointManager> {
        create_endpoint_manager(vec![Region::from("West US"), Region::from("East US")])
    }

    /// Manager with no preferred locations.
    fn create_test_endpoint_manager_no_locations() -> Arc<GlobalEndpointManager> {
        create_endpoint_manager(vec![])
    }

    /// Manager with three preferred locations (East Asia, West US, North Central US).
    fn create_test_endpoint_manager_with_preferred_locations() -> Arc<GlobalEndpointManager> {
        create_endpoint_manager(vec![
            Region::EAST_ASIA,
            Region::WEST_US,
            Region::NORTH_CENTRAL_US,
        ])
    }

    fn create_test_policy() -> MetadataRequestRetryPolicy {
        MetadataRequestRetryPolicy::new(create_test_endpoint_manager())
    }

    fn create_test_policy_no_locations() -> MetadataRequestRetryPolicy {
        MetadataRequestRetryPolicy::new(create_test_endpoint_manager_no_locations())
    }

    fn create_test_policy_with_preferred_locations() -> MetadataRequestRetryPolicy {
        MetadataRequestRetryPolicy::new(create_test_endpoint_manager_with_preferred_locations())
    }

    /// A read request that is already pinned to the account endpoint.
    fn create_test_request() -> CosmosRequest {
        let resource_link = ResourceLink::root(ResourceType::Documents);
        let mut request = CosmosRequest::builder(OperationType::Read, resource_link.clone())
            .partition_key(PartitionKey::from("test"))
            .build()
            .unwrap();
        request.request_context.location_endpoint_to_route =
            Some("https://test.documents.azure.com".parse().unwrap());
        request
    }

    /// An empty-bodied response with no headers and the given status.
    fn create_raw_response(status_code: StatusCode) -> RawResponse {
        let headers = Headers::new();
        RawResponse::from_bytes(status_code, headers, Bytes::new())
    }

    /// An `HttpResponse` error wrapping an empty response with the given status.
    fn create_error_with_status(status: StatusCode) -> azure_core::Error {
        let response = create_raw_response(status);
        azure_core::Error::new(
            azure_core::error::ErrorKind::HttpResponse {
                status: response.status(),
                error_code: None,
                raw_response: Some(Box::new(response)),
            },
            "Test error",
        )
    }

    #[tokio::test]
    async fn test_new_policy_initialization() {
        let policy = create_test_policy_with_preferred_locations();
        assert_eq!(policy.unavailable_endpoint_retry_count, 0);
        assert!(policy.excluded_regions.is_none());
    }

    #[tokio::test]
    async fn test_retry_context_none_initially() {
        let policy = create_test_policy();
        assert!(policy.retry_context.is_none());
    }

    #[tokio::test]
    async fn test_should_retry_service_unavailable_error() {
        let mut policy = create_test_policy_no_locations();
        let error = create_error_with_status(StatusCode::ServiceUnavailable);
        let result = policy.should_retry_error(&error).await;
        assert!(result.is_retry());
        if let RetryResult::Retry { after } = result {
            assert_eq!(after, Duration::ZERO);
        }
    }

    #[tokio::test]
    async fn test_should_retry_internal_server_error() {
        let mut policy = create_test_policy_with_preferred_locations();
        let error = create_error_with_status(StatusCode::InternalServerError);
        let result = policy.should_retry_error(&error).await;
        assert!(result.is_retry());
    }

    #[tokio::test]
    async fn test_should_retry_service_unavailable_response() {
        let mut policy = create_test_policy_with_preferred_locations();
        let response = create_raw_response(StatusCode::ServiceUnavailable);
        let result = policy.should_retry_response(&response).await;
        assert!(result.is_retry());
    }

    #[tokio::test]
    async fn test_should_retry_internal_server_error_response() {
        let mut policy = create_test_policy_with_preferred_locations();
        let response = create_raw_response(StatusCode::InternalServerError);
        let result = policy.should_retry_response(&response).await;
        assert!(result.is_retry());
    }

    #[tokio::test]
    async fn test_should_not_retry_ok_response() {
        let mut policy = create_test_policy();
        let response = create_raw_response(StatusCode::Ok);
        let result = policy.should_retry(&Ok(response)).await;
        assert!(!result.is_retry());
    }

    #[tokio::test]
    async fn test_should_not_retry_created_response() {
        let mut policy = create_test_policy();
        let response = create_raw_response(StatusCode::Created);
        let result = policy.should_retry(&Ok(response)).await;
        assert!(!result.is_retry());
    }

    #[tokio::test]
    async fn test_increment_retry_index_on_unavailable_endpoint() {
        let mut policy = create_test_policy_with_preferred_locations();
        let initial_count = policy.unavailable_endpoint_retry_count;
        let result = policy.increment_retry_index_on_unavailable_endpoint_for_metadata_read();
        assert!(result);
        assert_eq!(policy.unavailable_endpoint_retry_count, initial_count + 1);
        assert!(policy.retry_context.is_some());
    }

    #[tokio::test]
    async fn test_increment_retry_exceeds_max_count() {
        let mut policy = create_test_policy_no_locations();
        assert!(policy.increment_retry_index_on_unavailable_endpoint_for_metadata_read());
        let result = policy.increment_retry_index_on_unavailable_endpoint_for_metadata_read();
        assert!(!result);
    }

    #[tokio::test]
    async fn test_retry_context_set_after_increment() {
        let mut policy = create_test_policy_no_locations();
        policy.increment_retry_index_on_unavailable_endpoint_for_metadata_read();
        assert!(policy.retry_context.is_some());
        if let Some(ctx) = &policy.retry_context {
            assert!(ctx.retry_request_on_preferred_locations);
            assert_eq!(
                ctx.retry_location_index,
                policy.unavailable_endpoint_retry_count
            );
        }
    }

    #[tokio::test]
    async fn test_should_retry_with_ok_result() {
        let mut policy = create_test_policy();
        let response = create_raw_response(StatusCode::Ok);
        let result = policy.should_retry(&Ok(response)).await;
        assert!(!result.is_retry());
    }

    #[tokio::test]
    async fn test_should_retry_with_server_error_result() {
        let mut policy = create_test_policy_no_locations();
        let response = create_raw_response(StatusCode::InternalServerError);
        let result = policy.should_retry(&Ok(response)).await;
        assert!(result.is_retry());
    }

    #[tokio::test]
    async fn test_should_retry_with_error_result() {
        let mut policy = create_test_policy_no_locations();
        let error = create_error_with_status(StatusCode::ServiceUnavailable);
        let result = policy.should_retry(&Err(error)).await;
        assert!(result.is_retry());
    }

    #[tokio::test]
    async fn test_should_not_retry_bad_request() {
        let mut policy = create_test_policy();
        let response = create_raw_response(StatusCode::BadRequest);
        let result = policy.should_retry_response(&response).await;
        assert!(!result.is_retry());
    }

    #[tokio::test]
    async fn test_should_not_retry_not_found() {
        let mut policy = create_test_policy();
        let response = create_raw_response(StatusCode::NotFound);
        let result = policy.should_retry_response(&response).await;
        assert!(!result.is_retry());
    }

    #[tokio::test]
    async fn test_should_not_retry_unauthorized() {
        let mut policy = create_test_policy();
        let response = create_raw_response(StatusCode::Unauthorized);
        let result = policy.should_retry_response(&response).await;
        assert!(!result.is_retry());
    }

    #[tokio::test]
    async fn test_should_not_retry_conflict() {
        let mut policy = create_test_policy();
        let response = create_raw_response(StatusCode::Conflict);
        let result = policy.should_retry_response(&response).await;
        assert!(!result.is_retry());
    }

    #[tokio::test]
    async fn test_should_not_retry_precondition_failed() {
        let mut policy = create_test_policy();
        let response = create_raw_response(StatusCode::PreconditionFailed);
        let result = policy.should_retry_response(&response).await;
        assert!(!result.is_retry());
    }

    #[tokio::test]
    async fn test_should_retry_forbidden_on_another_endpoint() {
        let mut policy = create_test_policy_no_locations();
        let response = create_raw_response(StatusCode::Forbidden);
        let result = policy.should_retry_response(&response).await;
        assert!(result.is_retry());
    }

    #[tokio::test]
    async fn test_should_retry_gone_on_another_endpoint() {
        let mut policy = create_test_policy_no_locations();
        let response = create_raw_response(StatusCode::Gone);
        let result = policy.should_retry_response(&response).await;
        assert!(result.is_retry());
    }

    /// Every retryable error advances the failover counter, whether or not the
    /// failover itself is still permitted.
    #[tokio::test]
    async fn test_multiple_retries_increment_counter() {
        let mut policy = create_test_policy_no_locations();
        let initial_count = policy.unavailable_endpoint_retry_count;
        let error = create_error_with_status(StatusCode::ServiceUnavailable);

        let _first = policy.should_retry_error(&error).await;
        assert_eq!(policy.unavailable_endpoint_retry_count, initial_count + 1);

        let _second = policy.should_retry_error(&error).await;
        assert_eq!(policy.unavailable_endpoint_retry_count, initial_count + 2);
    }

    #[tokio::test]
    async fn test_before_send_request_clears_routing() {
        let mut policy = create_test_policy();
        let mut request = create_test_request();
        request.request_context.location_index_to_route = Some(5);
        policy.before_send_request(&mut request).await;
        assert!(request.request_context.location_endpoint_to_route.is_some());
    }

    #[tokio::test]
    async fn test_retry_context_affects_routing() {
        let mut policy = create_test_policy();
        let mut request = create_test_request();
        policy.retry_context = Some(MetadataRetryContext {
            retry_location_index: 1,
            retry_request_on_preferred_locations: true,
        });
        policy.before_send_request(&mut request).await;
        assert!(request.request_context.location_endpoint_to_route.is_some());
    }

    #[tokio::test]
    async fn test_policy_debug_format() {
        let policy = create_test_policy();
        let debug_str = format!("{:?}", policy);
        assert!(debug_str.contains("MetadataRequestRetryPolicy"));
    }

    #[test]
    fn test_retry_context_clone() {
        let ctx = MetadataRetryContext {
            retry_location_index: 3,
            retry_request_on_preferred_locations: false,
        };
        let cloned = ctx.clone();
        assert_eq!(ctx.retry_location_index, cloned.retry_location_index);
        assert_eq!(
            ctx.retry_request_on_preferred_locations,
            cloned.retry_request_on_preferred_locations
        );
    }

    #[tokio::test]
    async fn test_before_send_request_captures_excluded_regions() {
        let mut policy = create_test_policy_with_preferred_locations();
        let resource_link = ResourceLink::root(ResourceType::Databases);
        let mut request = CosmosRequest::builder(OperationType::Read, resource_link)
            .partition_key(PartitionKey::from("test"))
            .excluded_regions(Some(ExcludedRegions::from_iter([Region::EAST_ASIA])))
            .build()
            .unwrap();
        request.request_context.location_endpoint_to_route =
            Some("https://test.documents.azure.com".parse().unwrap());
        policy.before_send_request(&mut request).await;
        assert!(policy.excluded_regions.is_some());
        assert_eq!(policy.excluded_regions.as_ref().unwrap().len(), 1);
        assert_eq!(
            policy.excluded_regions.as_ref().unwrap()[0],
            Region::EAST_ASIA
        );
    }

    #[tokio::test]
    async fn test_excluded_regions_reduce_retry_attempts() {
        let mut policy = create_test_policy_with_preferred_locations();
        policy.excluded_regions = Some(vec![Region::EAST_ASIA, Region::WEST_US]);
        let error = create_error_with_status(StatusCode::ServiceUnavailable);
        let result = policy.should_retry_error(&error).await;
        assert!(result.is_retry());
        let result = policy.should_retry_error(&error).await;
        assert!(
            !result.is_retry(),
            "Expected DoNotRetry after exhausting non-excluded endpoints"
        );
    }
}