use crate::{
clients::ClientContext, options::ThroughputControlGroupOptions, AccountReference, CosmosClient,
CosmosClientOptions, CosmosCredential, RoutingStrategy,
};
use azure_data_cosmos_driver::options::ConnectionPoolOptions;
#[cfg(all(feature = "allow_invalid_certificates", feature = "__tls",))]
use azure_data_cosmos_driver::options::EmulatorServerCertValidation;
use azure_data_cosmos_driver::CosmosDriverRuntimeBuilder;
use crate::constants::AZURE_COSMOS_PER_PARTITION_CIRCUIT_BREAKER_ENABLED;
/// Builder for [`CosmosClient`].
///
/// Settings are collected through the consuming `with_*` methods; `build` then turns them
/// into a driver runtime plus a driver for a specific account.
#[derive(Default)]
pub struct CosmosClientBuilder {
    /// Client options. `build` reads `user_agent_suffix` from these.
    options: CosmosClientOptions,
    /// When `true`, `build` enables proxy use on the driver's connection pool options.
    allow_proxy: bool,
    /// Throughput control groups registered with the driver runtime by `build`, in the
    /// order they were added.
    throughput_control_groups: Vec<ThroughputControlGroupOptions>,
    /// When `true`, `build` sets emulator server certificate validation to
    /// `EmulatorServerCertValidation::DangerousDisabled`.
    #[cfg(all(feature = "allow_invalid_certificates", feature = "__tls",))]
    allow_emulator_invalid_certificates: bool,
    /// Fault-injection rules; `build` hands them to the driver runtime only when non-empty.
    #[cfg(feature = "fault_injection")]
    fault_injection_rules:
        Vec<std::sync::Arc<azure_data_cosmos_driver::fault_injection::FaultInjectionRule>>,
    /// Additional account endpoints attached to the driver account reference by `build`.
    backup_endpoints: Vec<azure_core::http::Url>,
    /// Optional caller-supplied runtime builder; `build` starts from it instead of a
    /// default-constructed one.
    #[cfg(feature = "__internal_in_memory_emulator")]
    driver_runtime_builder: Option<CosmosDriverRuntimeBuilder>,
}
impl CosmosClientBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn with_user_agent_suffix(mut self, suffix: crate::UserAgentSuffix) -> Self {
self.options.user_agent_suffix = Some(suffix);
self
}
#[doc(hidden)]
#[cfg(feature = "fault_injection")]
pub fn with_fault_injection(
mut self,
rules: Vec<std::sync::Arc<azure_data_cosmos_driver::fault_injection::FaultInjectionRule>>,
) -> Self {
self.fault_injection_rules = rules;
self
}
#[doc(hidden)]
#[cfg(all(feature = "allow_invalid_certificates", feature = "__tls",))]
pub fn with_allow_emulator_invalid_certificates(mut self, allow: bool) -> Self {
self.allow_emulator_invalid_certificates = allow;
self
}
pub fn with_proxy_allowed(mut self, allow: bool) -> Self {
self.allow_proxy = allow;
self
}
pub fn with_throughput_control_group(mut self, group: ThroughputControlGroupOptions) -> Self {
self.throughput_control_groups.push(group);
self
}
pub fn with_backup_endpoints(mut self, endpoints: Vec<crate::AccountEndpoint>) -> Self {
self.backup_endpoints = endpoints.into_iter().map(|e| e.into_url()).collect();
self
}
#[doc(hidden)]
#[cfg(feature = "__internal_in_memory_emulator")]
pub fn with_driver_runtime_builder(mut self, builder: CosmosDriverRuntimeBuilder) -> Self {
self.driver_runtime_builder = Some(builder);
self
}
pub async fn build(
self,
account: AccountReference,
routing_strategy: RoutingStrategy,
) -> crate::Result<CosmosClient> {
let (account_endpoint, credential) = account.into_parts();
let endpoint = account_endpoint.into_url();
let driver_credential = credential.clone();
#[cfg(feature = "fault_injection")]
let driver_fi_rules: Vec<
std::sync::Arc<azure_data_cosmos_driver::fault_injection::FaultInjectionRule>,
> = self.fault_injection_rules;
let ppcb_enabled = std::env::var(AZURE_COSMOS_PER_PARTITION_CIRCUIT_BREAKER_ENABLED)
.ok()
.and_then(|v| v.parse::<bool>().ok())
.unwrap_or(true);
let driver_user_agent_suffix = self.options.user_agent_suffix.clone();
let driver_account =
build_driver_account(endpoint, driver_credential, self.backup_endpoints);
#[cfg(feature = "__internal_in_memory_emulator")]
let mut driver_runtime_builder = self.driver_runtime_builder.unwrap_or_default();
#[cfg(not(feature = "__internal_in_memory_emulator"))]
let mut driver_runtime_builder = CosmosDriverRuntimeBuilder::new();
let mut pool_builder = ConnectionPoolOptions::builder();
if self.allow_proxy {
pool_builder = pool_builder.with_proxy_allowed(true);
}
#[cfg(all(feature = "allow_invalid_certificates", feature = "__tls",))]
if self.allow_emulator_invalid_certificates {
pool_builder = pool_builder.with_emulator_server_cert_validation(
EmulatorServerCertValidation::DangerousDisabled,
);
}
driver_runtime_builder = driver_runtime_builder.with_connection_pool(pool_builder.build()?);
driver_runtime_builder = driver_runtime_builder.with_wrapping_sdk_identifier(format!(
"azsdk-rust-cosmos/{}",
env!("CARGO_PKG_VERSION")
));
if let Some(suffix) = driver_user_agent_suffix {
driver_runtime_builder = driver_runtime_builder.with_user_agent_suffix(suffix);
}
let runtime_operation_options =
azure_data_cosmos_driver::options::OperationOptionsBuilder::new()
.with_per_partition_circuit_breaker_enabled(ppcb_enabled)
.build();
driver_runtime_builder =
driver_runtime_builder.with_operation_options(runtime_operation_options);
#[cfg(feature = "fault_injection")]
if !driver_fi_rules.is_empty() {
driver_runtime_builder =
driver_runtime_builder.with_fault_injection_rules(driver_fi_rules)?;
}
for group in self.throughput_control_groups {
driver_runtime_builder = driver_runtime_builder
.register_throughput_control_group(group)
.map_err(|e| {
crate::DriverCosmosError::builder()
.with_status(crate::CosmosStatus::CLIENT_THROUGHPUT_CONTROL_GROUP_REGISTRATION_FAILED)
.with_message(format!("failed to register throughput control group: {e}"))
.build()
})?;
}
let driver_runtime = driver_runtime_builder.build().await?;
let driver_options = build_driver_options(driver_account, routing_strategy);
let driver = driver_runtime
.get_or_create_driver(driver_options.account().clone(), Some(driver_options))
.await?;
Ok(CosmosClient {
context: ClientContext { driver },
})
}
}
/// Translates the public [`RoutingStrategy`] into driver options for `account`.
///
/// - `ProximityTo(region)`: preferred regions come from
///   `region_proximity::generate_preferred_region_list`. If that helper does not recognize
///   the region (returns `None`), a warning is logged and the list is left empty, which
///   falls back to the account-defined region order.
/// - `PreferredRegions(regions)`: the list is passed through unchanged.
fn build_driver_options(
    account: azure_data_cosmos_driver::models::AccountReference,
    strategy: RoutingStrategy,
) -> azure_data_cosmos_driver::options::DriverOptions {
    let preferred_regions = match strategy {
        RoutingStrategy::ProximityTo(region) => {
            match crate::region_proximity::generate_preferred_region_list(&region) {
                Some(ordered) => ordered.to_vec(),
                None => {
                    tracing::warn!(
                        region = %region,
                        "unrecognized application region; falling back to account-defined region order"
                    );
                    Vec::new()
                }
            }
        }
        RoutingStrategy::PreferredRegions(regions) => regions,
    };
    azure_data_cosmos_driver::options::DriverOptions::builder(account)
        .with_preferred_regions(preferred_regions)
        .build()
}
/// Creates the driver-level account reference for `endpoint`.
///
/// A token credential maps to `AccountReference::with_credential`. A master key (only
/// available with the `key_auth` feature) maps to `AccountReference::with_master_key`.
/// `backup_endpoints` are attached to the resulting reference in either case.
fn build_driver_account(
    endpoint: azure_core::http::Url,
    credential: CosmosCredential,
    backup_endpoints: Vec<azure_core::http::Url>,
) -> azure_data_cosmos_driver::models::AccountReference {
    use azure_data_cosmos_driver::models::AccountReference as DriverAccount;

    match credential {
        CosmosCredential::TokenCredential(token) => DriverAccount::with_credential(endpoint, token),
        #[cfg(feature = "key_auth")]
        CosmosCredential::MasterKey(master_key) => {
            DriverAccount::with_master_key(endpoint, master_key)
        }
    }
    .with_backup_endpoints(backup_endpoints)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Region, UserAgentSuffix};

    // Verifies that a suffix copied from `CosmosClientOptions` into the runtime builder is
    // stored by the driver runtime and appears in its computed user-agent string.
    // NOTE(review): the forwarding step is re-implemented here rather than going through
    // `CosmosClientBuilder::build`, so a regression inside `build` would not be caught.
    #[tokio::test]
    async fn user_agent_suffix_is_forwarded_to_driver_runtime() {
        let suffix = UserAgentSuffix::new("myapp-westus2");
        let options = CosmosClientOptions {
            user_agent_suffix: Some(suffix.clone()),
            ..Default::default()
        };
        let mut driver_builder = CosmosDriverRuntimeBuilder::new();
        if let Some(s) = options.user_agent_suffix.clone() {
            driver_builder = driver_builder.with_user_agent_suffix(s);
        }
        let runtime = driver_builder.build().await.expect("runtime builds");
        assert_eq!(
            runtime.user_agent_suffix(),
            Some(&suffix),
            "driver runtime did not receive the user-agent suffix"
        );
        assert!(
            runtime.user_agent().as_str().contains(suffix.as_str()),
            "computed driver user-agent {:?} does not contain suffix {:?}",
            runtime.user_agent().as_str(),
            suffix.as_str(),
        );
    }

    // A runtime built without a suffix must report none.
    #[tokio::test]
    async fn no_user_agent_suffix_yields_no_driver_suffix() {
        let runtime = CosmosDriverRuntimeBuilder::new()
            .build()
            .await
            .expect("runtime builds");
        assert!(runtime.user_agent_suffix().is_none());
    }

    // Models the "env var unset" case: no value resolves to `true`, and that flag must reach
    // the runtime's operation options as `Some(true)`.
    // NOTE(review): the env var is never read and the parse expression is duplicated here
    // instead of shared with `build`, so drift between the two is not detected.
    #[tokio::test]
    async fn ppcb_default_is_enabled_when_env_var_unset() {
        let ppcb_enabled = Option::<String>::None
            .and_then(|v: String| v.parse::<bool>().ok())
            .unwrap_or(true);
        assert!(
            ppcb_enabled,
            "SDK's PPCB default must be `true` when env var is unset"
        );
        let runtime_op_options = azure_data_cosmos_driver::options::OperationOptionsBuilder::new()
            .with_per_partition_circuit_breaker_enabled(ppcb_enabled)
            .build();
        let runtime = CosmosDriverRuntimeBuilder::new()
            .with_operation_options(runtime_op_options)
            .build()
            .await
            .expect("runtime builds");
        assert_eq!(
            runtime
                .operation_options()
                .per_partition_circuit_breaker_enabled,
            Some(true),
            "PPCB must be enabled by default on a CosmosClient-built runtime"
        );
    }

    // Models an explicit `"false"` env value: it must disable PPCB, and the runtime must
    // report `Some(false)`.
    // NOTE(review): like the test above, this mirrors the parsing inline and does not set or
    // read the actual environment variable.
    #[tokio::test]
    async fn ppcb_can_be_opted_out_via_env_var() {
        let ppcb_enabled = Some("false".to_string())
            .and_then(|v| v.parse::<bool>().ok())
            .unwrap_or(true);
        assert!(!ppcb_enabled, "env var `false` must opt out of PPCB");
        let runtime_op_options = azure_data_cosmos_driver::options::OperationOptionsBuilder::new()
            .with_per_partition_circuit_breaker_enabled(ppcb_enabled)
            .build();
        let runtime = CosmosDriverRuntimeBuilder::new()
            .with_operation_options(runtime_op_options)
            .build()
            .await
            .expect("runtime builds");
        assert_eq!(
            runtime
                .operation_options()
                .per_partition_circuit_breaker_enabled,
            Some(false),
            "explicit env-var opt-out must propagate to the driver runtime"
        );
    }

    // Dummy driver account for the option-building tests below. The key is base64 for
    // "test"; `build_driver_options` only assembles options, so nothing is contacted.
    fn test_account() -> azure_data_cosmos_driver::models::AccountReference {
        azure_data_cosmos_driver::models::AccountReference::with_master_key(
            "https://test.documents.azure.com/".parse().unwrap(),
            "dGVzdA==",
        )
    }

    // A known source region yields a non-empty proximity list that starts with that region.
    #[test]
    fn proximity_to_known_region_starts_with_source() {
        let opts = build_driver_options(
            test_account(),
            RoutingStrategy::ProximityTo(Region::EAST_US),
        );
        let regions = opts.preferred_regions();
        assert!(
            !regions.is_empty(),
            "should produce a non-empty list for a known region"
        );
        assert_eq!(regions[0], Region::EAST_US, "source region should be first");
    }

    // An unrecognized region falls back to an empty preference list.
    #[test]
    fn proximity_to_unknown_region_returns_empty_list() {
        let opts = build_driver_options(
            test_account(),
            RoutingStrategy::ProximityTo(Region::from("not-a-real-region")),
        );
        assert!(
            opts.preferred_regions().is_empty(),
            "unrecognized region should yield an empty list"
        );
    }

    // An explicit preferred-region list is forwarded verbatim, order included.
    #[test]
    fn preferred_regions_passes_through_unchanged() {
        let input = vec![Region::WEST_US, Region::EAST_US, Region::WEST_EUROPE];
        let opts = build_driver_options(
            test_account(),
            RoutingStrategy::PreferredRegions(input.clone()),
        );
        assert_eq!(opts.preferred_regions(), input.as_slice());
    }
}