use crate::{
    pipeline::{AuthorizationPolicy, CosmosHeadersPolicy, GatewayPipeline},
    resource_context::{ResourceLink, ResourceType},
    CosmosAccountReference, CosmosClient, CosmosClientOptions, CosmosCredential, RoutingStrategy,
};
#[cfg(feature = "allow_invalid_certificates")]
use azure_data_cosmos_driver::options::{ConnectionPoolOptions, EmulatorServerCertValidation};
use azure_data_cosmos_driver::CosmosDriverRuntimeBuilder;
use std::sync::Arc;
// The logging header allow-list and the circuit-breaker switch are used on every target, so they
// must not share the reqwest-only `cfg` of the connection-tuning constants below.
use crate::constants::{AZURE_COSMOS_PER_PARTITION_CIRCUIT_BREAKER_ENABLED, COSMOS_ALLOWED_HEADERS};
// Only consumed when building the native `reqwest` client.
#[cfg(all(not(target_arch = "wasm32"), feature = "reqwest"))]
use crate::constants::{
    DEFAULT_CONNECTION_TIMEOUT, DEFAULT_MAX_CONNECTION_POOL_SIZE, DEFAULT_REQUEST_TIMEOUT,
};
use crate::models::AccountProperties;
use crate::routing::global_endpoint_manager::GlobalEndpointManager;
use crate::routing::global_partition_endpoint_manager::GlobalPartitionEndpointManager;
use azure_core::http::{ClientOptions, LoggingOptions, RetryOptions};
/// Builder for [`CosmosClient`].
///
/// Configure it with the `with_*` methods, then call [`CosmosClientBuilder::build`] with the
/// account and routing strategy. Every `with_*` method consumes the builder and returns the
/// updated one, so dropping a returned builder silently discards the configuration.
#[derive(Default)]
#[must_use = "builder methods return the updated builder; call `build` to create a client"]
pub struct CosmosClientBuilder {
    /// Client options; forwarded to the gateway pipeline when the client is built.
    options: CosmosClientOptions,
    /// Whether the HTTP transport may honor proxy settings (`false` by default).
    allow_proxy: bool,
    /// When set, server certificate validation is disabled (emulator testing only).
    #[cfg(feature = "allow_invalid_certificates")]
    allow_emulator_invalid_certificates: bool,
    /// Optional fault-injection configuration applied to the transport and driver.
    #[cfg(feature = "fault_injection")]
    fault_injection_builder: Option<crate::fault_injection::FaultInjectionClientBuilder>,
}
impl CosmosClientBuilder {
    /// Creates a builder with default options.
    ///
    /// Proxy usage starts disabled and, when the respective features are compiled in, no fault
    /// injection is configured and certificate validation stays enabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets a suffix for the client's user agent.
    ///
    /// The suffix is handed to `CosmosHeadersPolicy` when [`build`](Self::build) assembles the
    /// pipeline (presumably appended to the `User-Agent` header).
    pub fn with_user_agent_suffix(mut self, suffix: impl Into<String>) -> Self {
        self.options.user_agent_suffix = Some(suffix.into());
        self
    }

    /// Installs a fault-injection configuration (test-only, hidden from docs).
    ///
    /// At build time the fault-injection builder wraps the gateway HTTP transport, and its rules
    /// are also converted and registered with the driver runtime.
    #[doc(hidden)]
    #[cfg(feature = "fault_injection")]
    pub fn with_fault_injection(
        mut self,
        builder: crate::fault_injection::FaultInjectionClientBuilder,
    ) -> Self {
        self.fault_injection_builder = Some(builder);
        self
    }

    /// Controls whether invalid server certificates are accepted.
    ///
    /// When `true`, certificate validation is turned off for both the gateway `reqwest` client
    /// and the driver's connection pool. Meant only for the local emulator; never enable this
    /// against a real account.
    #[doc(hidden)]
    #[cfg(feature = "allow_invalid_certificates")]
    pub fn with_allow_emulator_invalid_certificates(mut self, allow: bool) -> Self {
        self.allow_emulator_invalid_certificates = allow;
        self
    }

    /// Controls whether the HTTP transport may use a proxy.
    ///
    /// By default the `reqwest` client is built with `no_proxy()`. Passing `true` keeps proxy
    /// support and logs a warning at build time, because end-to-end SLAs do not cover proxied
    /// traffic. Only affects native builds with the `reqwest` feature.
    pub fn with_proxy_allowed(mut self, allow: bool) -> Self {
        self.allow_proxy = allow;
        self
    }

    /// Builds a [`CosmosClient`] for `account`.
    ///
    /// Assembles the gateway HTTP pipeline (transport, Cosmos headers, authorization), the
    /// global and partition-level endpoint managers, and the driver backing the client.
    ///
    /// # Errors
    ///
    /// Returns an error if the `reqwest` client cannot be built, if the emulator
    /// connection-pool options are rejected, or if the driver runtime or driver cannot be
    /// created.
    pub async fn build(
        mut self,
        account: impl Into<CosmosAccountReference>,
        routing_strategy: RoutingStrategy,
    ) -> azure_core::Result<CosmosClient> {
        match routing_strategy {
            RoutingStrategy::ProximityTo(region) => {
                self.options.application_region = Some(region);
            }
        }

        let (account_endpoint, credential) = account.into().into_parts();
        let endpoint = account_endpoint.into_url();
        // The gateway authorization policy consumes `credential`; the driver needs its own copy.
        let driver_credential = credential.clone();

        #[cfg(feature = "fault_injection")]
        let fault_injection_enabled = self.fault_injection_builder.is_some();
        #[cfg(not(feature = "fault_injection"))]
        let fault_injection_enabled = false;

        // Native builds with `reqwest` get a tuned HTTP/1.1 client. Elsewhere no client is
        // built here and `transport` stays `None` in the `ClientOptions` below.
        #[cfg(all(not(target_arch = "wasm32"), feature = "reqwest"))]
        let base_client: Option<Arc<dyn azure_core::http::HttpClient>> = {
            let mut builder = reqwest::ClientBuilder::new()
                .http1_only()
                .pool_max_idle_per_host(DEFAULT_MAX_CONNECTION_POOL_SIZE)
                .connect_timeout(DEFAULT_CONNECTION_TIMEOUT)
                .timeout(DEFAULT_REQUEST_TIMEOUT);
            if self.allow_proxy {
                tracing::warn!(
                    "Proxy usage is enabled. Azure Cosmos DB does not provide end-to-end SLAs \
                     when a proxy is in use. Full backend support is provided, but client/proxy \
                     interactions are supported on a best-effort basis only."
                );
            } else {
                builder = builder.no_proxy();
            }
            #[cfg(feature = "allow_invalid_certificates")]
            if self.allow_emulator_invalid_certificates {
                builder = builder.danger_accept_invalid_certs(true);
            }
            let client = builder
                .build()
                .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Other, e))?;
            Some(Arc::new(client))
        };
        #[cfg(not(all(not(target_arch = "wasm32"), feature = "reqwest")))]
        let base_client: Option<Arc<dyn azure_core::http::HttpClient>> = None;

        // With fault injection, the fault-injecting transport wraps the base client (if any),
        // and the same rules are translated for the driver runtime.
        #[cfg(feature = "fault_injection")]
        let (transport, driver_fi_rules): (
            Option<azure_core::http::Transport>,
            Vec<std::sync::Arc<azure_data_cosmos_driver::fault_injection::FaultInjectionRule>>,
        ) = if let Some(fault_builder) = self.fault_injection_builder {
            let driver_rules =
                crate::driver_bridge::sdk_fi_rules_to_driver_fi_rules(fault_builder.rules());
            let fault_builder = match base_client {
                Some(client) => fault_builder.with_inner_client(client),
                None => fault_builder,
            };
            (Some(fault_builder.build()), driver_rules)
        } else {
            (
                base_client.map(azure_core::http::Transport::new),
                Vec::new(),
            )
        };
        #[cfg(not(feature = "fault_injection"))]
        let transport: Option<azure_core::http::Transport> =
            base_client.map(azure_core::http::Transport::new);

        // Retries are disabled at the azure_core layer (`RetryOptions::none()`); Cosmos-specific
        // headers are added to the logging allow-list so they are not redacted.
        let client_options = ClientOptions {
            retry: RetryOptions::none(),
            logging: LoggingOptions {
                additional_allowed_header_names: COSMOS_ALLOWED_HEADERS
                    .iter()
                    .map(|h| std::borrow::Cow::Borrowed(h.as_str()))
                    .collect(),
                additional_allowed_query_params: vec![],
            },
            transport,
            ..Default::default()
        };

        let auth_policy: Arc<AuthorizationPolicy> = match credential {
            CosmosCredential::TokenCredential(cred) => {
                Arc::new(AuthorizationPolicy::from_token_credential(cred))
            }
            #[cfg(feature = "key_auth")]
            CosmosCredential::MasterKey(key) => Arc::new(AuthorizationPolicy::from_shared_key(key)),
        };

        let crate_version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
        let cosmos_headers_policy: Arc<dyn azure_core::http::policies::Policy> = Arc::new(
            CosmosHeadersPolicy::new(crate_version, self.options.user_agent_suffix.as_deref()),
        );
        let pipeline_core = azure_core::http::Pipeline::new(
            option_env!("CARGO_PKG_NAME"),
            option_env!("CARGO_PKG_VERSION"),
            client_options,
            vec![cosmos_headers_policy],
            vec![auth_policy],
            None,
        );

        // An unknown application region is not fatal: an empty list leaves region ordering to
        // the account's own configuration.
        let preferred_regions = if let Some(ref region) = self.options.application_region {
            crate::region_proximity::generate_preferred_region_list(region)
                .map(|s| s.to_vec())
                .unwrap_or_else(|| {
                    tracing::warn!(
                        region = %region,
                        "unrecognized application region; falling back to account-defined region order"
                    );
                    Vec::new()
                })
        } else {
            Vec::new()
        };

        let global_endpoint_manager = GlobalEndpointManager::new(
            endpoint.clone(),
            preferred_regions,
            Vec::new(),
            pipeline_core.clone(),
        );

        let enable_partition_level_circuit_breaker =
            Self::per_partition_circuit_breaker_enabled_from_env();
        // NOTE(review): the `false` presumably means partition-level automatic failover starts
        // off until the first account refresh below — confirm against
        // `GlobalPartitionEndpointManager::new`.
        let global_partition_endpoint_manager: Arc<GlobalPartitionEndpointManager> =
            GlobalPartitionEndpointManager::new(
                global_endpoint_manager.clone(),
                false,
                enable_partition_level_circuit_breaker,
            );

        // On every account refresh, partition-level failover follows the account flag. The
        // circuit breaker stays on if either that flag or the environment toggle asks for it.
        let partition_manager_clone = Arc::clone(&global_partition_endpoint_manager);
        global_endpoint_manager.set_on_account_refresh_callback(Arc::new(
            move |account_props: &AccountProperties| {
                partition_manager_clone.configure_partition_level_automatic_failover(
                    account_props.enable_per_partition_failover_behavior,
                );
                partition_manager_clone.configure_per_partition_circuit_breaker(
                    account_props.enable_per_partition_failover_behavior
                        || enable_partition_level_circuit_breaker,
                );
            },
        ));

        let pipeline = Arc::new(GatewayPipeline::new(
            endpoint.clone(),
            pipeline_core,
            global_endpoint_manager.clone(),
            global_partition_endpoint_manager.clone(),
            self.options,
            fault_injection_enabled,
        ));

        let driver_account = build_driver_account(endpoint, driver_credential);
        // `mut` is only exercised when `allow_invalid_certificates` or `fault_injection` is on.
        #[allow(unused_mut)]
        let mut driver_runtime_builder = CosmosDriverRuntimeBuilder::new();
        #[cfg(feature = "allow_invalid_certificates")]
        if self.allow_emulator_invalid_certificates {
            let connection_pool = ConnectionPoolOptions::builder()
                .with_emulator_server_cert_validation(
                    EmulatorServerCertValidation::DangerousDisabled,
                )
                .build()?;
            driver_runtime_builder = driver_runtime_builder.with_connection_pool(connection_pool);
        }
        #[cfg(feature = "fault_injection")]
        if !driver_fi_rules.is_empty() {
            driver_runtime_builder =
                driver_runtime_builder.with_fault_injection_rules(driver_fi_rules);
        }
        let driver_runtime = driver_runtime_builder.build().await?;
        let driver = driver_runtime
            .get_or_create_driver(driver_account, None)
            .await?;

        Ok(CosmosClient {
            databases_link: ResourceLink::root(ResourceType::Databases),
            pipeline,
            driver,
            global_endpoint_manager,
            global_partition_endpoint_manager,
        })
    }

    /// Reads the per-partition circuit-breaker toggle from the environment.
    ///
    /// A missing (or non-Unicode) variable enables the breaker. The value is trimmed and matched
    /// case-insensitively, so `FALSE` or ` false ` disable it just like `false`. The plain
    /// `str::parse::<bool>` used previously treated those as unparseable and silently left the
    /// breaker on. Any other value keeps the breaker enabled and logs a warning.
    fn per_partition_circuit_breaker_enabled_from_env() -> bool {
        let Ok(raw) = std::env::var(AZURE_COSMOS_PER_PARTITION_CIRCUIT_BREAKER_ENABLED) else {
            return true;
        };
        match raw.trim().to_ascii_lowercase().as_str() {
            "true" => true,
            "false" => false,
            _ => {
                tracing::warn!(
                    value = %raw,
                    "unrecognized per-partition circuit breaker setting; expected `true` or `false`, keeping it enabled"
                );
                true
            }
        }
    }
}
/// Converts the SDK-level endpoint and credential into the driver's account reference.
///
/// Token credentials map to `AccountReference::with_credential`. With the `key_auth` feature,
/// master keys map to `AccountReference::with_master_key`.
fn build_driver_account(
    endpoint: azure_core::http::Url,
    credential: CosmosCredential,
) -> azure_data_cosmos_driver::models::AccountReference {
    use azure_data_cosmos_driver::models::AccountReference;

    match credential {
        CosmosCredential::TokenCredential(token_credential) => {
            AccountReference::with_credential(endpoint, token_credential)
        }
        #[cfg(feature = "key_auth")]
        CosmosCredential::MasterKey(master_key) => {
            AccountReference::with_master_key(endpoint, master_key)
        }
    }
}