use azure_core::http::ClientOptions;
use std::{
collections::HashMap,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, RwLock,
},
time::Duration,
};
use crate::{
diagnostics::ProxyConfiguration,
models::{AccountReference, ContainerReference, ThroughputControlGroupName, UserAgent},
options::{
parse_duration_millis_from_env, ConnectionPoolOptions, CorrelationId, DriverOptions,
OperationOptions, ThroughputControlGroupOptions, ThroughputControlGroupRegistry,
UserAgentSuffix, WorkloadId,
},
system::{CpuMemoryMonitor, VmMetadataService},
};
use super::cache::{AccountMetadataCache, ContainerCache};
use super::{
transport::{
http_client_factory::{DefaultHttpClientFactory, HttpClientFactory},
CosmosTransport,
},
CosmosDriver,
};
/// Runtime state shared by every [`CosmosDriver`] created through it.
///
/// Holds transport configuration, user-agent/identity settings, throughput control groups,
/// shared caches, and a registry of drivers keyed by account endpoint. Construct it with
/// [`CosmosDriverRuntime::builder`].
#[non_exhaustive]
#[derive(Debug)]
pub struct CosmosDriverRuntime {
    /// Identifier taken from the `NEXT_RUNTIME_ID` counter when the runtime is built.
    id: usize,
    /// `azure_core` client options supplied to the builder (default when not supplied).
    client_options: ClientOptions,
    /// Connection pool settings; also decide proxy allowance and the bootstrap HTTP version.
    connection_pool: ConnectionPoolOptions,
    /// Transport created via `CosmosTransport::bootstrap_metadata_only` at build time.
    bootstrap_transport: Arc<CosmosTransport>,
    /// Factory for HTTP clients; wraps the base factory when fault injection is configured.
    http_client_factory: Arc<dyn HttpClientFactory>,
    /// Operation options read once from the environment (`OperationOptions::from_env`) at build.
    env_operation_options: Arc<OperationOptions>,
    /// Runtime-level operation options; swapped wholesale by `set_operation_options`.
    operation_options: RwLock<Arc<OperationOptions>>,
    /// User agent derived at build time from the suffix, workload id, or correlation id.
    user_agent: UserAgent,
    /// Workload id supplied to the builder, if any.
    workload_id: Option<WorkloadId>,
    /// Correlation id supplied to the builder, if any.
    correlation_id: Option<CorrelationId>,
    /// User-agent suffix supplied to the builder, if any.
    user_agent_suffix: Option<UserAgentSuffix>,
    /// Throughput control groups registered on the builder.
    throughput_control_groups: ThroughputControlGroupRegistry,
    /// Drivers keyed by the string form of the account endpoint; filled by
    /// `get_or_create_driver` only after a driver initializes successfully.
    driver_registry: RwLock<HashMap<String, Arc<CosmosDriver>>>,
    /// Container cache created empty at build time.
    container_cache: ContainerCache,
    /// Account metadata cache created empty at build time.
    account_metadata_cache: Arc<AccountMetadataCache>,
    /// Monitor obtained from `CpuMemoryMonitor::get_or_init` (presumably process-wide —
    /// TODO confirm).
    cpu_monitor: CpuMemoryMonitor,
    /// Machine identifier reported by `VmMetadataService`.
    machine_id: Arc<String>,
    /// `true` when the builder was given fault-injection rules.
    fault_injection_enabled: bool,
    /// Proxy settings read from the environment at build time.
    proxy_configuration: ProxyConfiguration,
}
impl CosmosDriverRuntime {
pub fn builder() -> CosmosDriverRuntimeBuilder {
CosmosDriverRuntimeBuilder::new()
}
#[expect(dead_code, reason = "will be used when tracing spans are re-added")]
pub(crate) fn id(&self) -> usize {
self.id
}
pub fn client_options(&self) -> &ClientOptions {
&self.client_options
}
pub fn connection_pool(&self) -> &ConnectionPoolOptions {
&self.connection_pool
}
pub(crate) fn bootstrap_transport(&self) -> &Arc<CosmosTransport> {
&self.bootstrap_transport
}
pub(crate) fn http_client_factory(&self) -> &Arc<dyn HttpClientFactory> {
&self.http_client_factory
}
pub(crate) fn container_cache(&self) -> &ContainerCache {
&self.container_cache
}
pub(crate) fn account_metadata_cache(&self) -> &Arc<AccountMetadataCache> {
&self.account_metadata_cache
}
pub(crate) fn cpu_monitor(&self) -> &CpuMemoryMonitor {
&self.cpu_monitor
}
pub(crate) fn machine_id(&self) -> &Arc<String> {
&self.machine_id
}
pub(crate) fn fault_injection_enabled(&self) -> bool {
self.fault_injection_enabled
}
pub fn proxy_configuration(&self) -> &ProxyConfiguration {
&self.proxy_configuration
}
pub fn env_operation_options(&self) -> &Arc<OperationOptions> {
&self.env_operation_options
}
pub fn operation_options(&self) -> Arc<OperationOptions> {
self.operation_options
.read()
.unwrap_or_else(|e| e.into_inner())
.clone()
}
pub fn set_operation_options(&self, options: OperationOptions) {
*self
.operation_options
.write()
.unwrap_or_else(|e| e.into_inner()) = Arc::new(options);
}
pub fn user_agent(&self) -> &UserAgent {
&self.user_agent
}
pub fn workload_id(&self) -> Option<WorkloadId> {
self.workload_id
}
pub fn correlation_id(&self) -> Option<&CorrelationId> {
self.correlation_id.as_ref()
}
pub fn user_agent_suffix(&self) -> Option<&UserAgentSuffix> {
self.user_agent_suffix.as_ref()
}
pub fn effective_correlation(&self) -> Option<&str> {
self.correlation_id
.as_ref()
.map(|c| c.as_str())
.or_else(|| self.user_agent_suffix.as_ref().map(|s| s.as_str()))
}
pub(crate) fn get_throughput_control_group(
&self,
container: &ContainerReference,
name: &ThroughputControlGroupName,
) -> Option<&Arc<ThroughputControlGroupOptions>> {
self.throughput_control_groups
.get_by_container_and_name(container, name)
}
pub(crate) fn get_default_throughput_control_group(
&self,
container: &ContainerReference,
) -> Option<&Arc<ThroughputControlGroupOptions>> {
self.throughput_control_groups
.get_default_for_container(container)
}
pub async fn get_or_create_driver(
self: &Arc<Self>,
account: AccountReference,
driver_options: Option<DriverOptions>,
) -> azure_core::Result<Arc<CosmosDriver>> {
let key = account.endpoint().to_string();
{
let registry = self.driver_registry.read().unwrap();
if let Some(driver) = registry.get(&key) {
tracing::trace!("retrieved existing driver");
return Ok(driver.clone());
}
}
tracing::trace!("creating new driver");
let options = driver_options.unwrap_or_else(|| DriverOptions::builder(account).build());
let driver = Arc::new(CosmosDriver::new(Arc::clone(self), options));
driver.initialize().await?;
let mut registry = self.driver_registry.write().unwrap();
let entry = registry.entry(key).or_insert_with(|| driver.clone());
Ok(entry.clone())
}
}
/// Builder for [`CosmosDriverRuntime`].
///
/// Every setting is optional. Unset values fall back to defaults (or the environment, where
/// applicable) when `build` runs.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct CosmosDriverRuntimeBuilder {
    /// `azure_core` client options; `ClientOptions::default()` when unset.
    client_options: Option<ClientOptions>,
    /// Connection pool options; `ConnectionPoolOptions::default()` when unset.
    connection_pool: Option<ConnectionPoolOptions>,
    /// Initial runtime-level operation options; default when unset.
    operation_options: Option<OperationOptions>,
    /// Workload id; second choice for the user agent after the suffix.
    workload_id: Option<WorkloadId>,
    /// Correlation id; last choice for the user agent.
    correlation_id: Option<CorrelationId>,
    /// User-agent suffix; first choice for the user agent.
    user_agent_suffix: Option<UserAgentSuffix>,
    /// Throughput control groups registered so far.
    throughput_control_groups: ThroughputControlGroupRegistry,
    /// Explicit CPU monitor refresh interval, resolved together with
    /// `AZURE_COSMOS_CPU_REFRESH_INTERVAL_MS` at build time.
    cpu_refresh_interval: Option<Duration>,
    /// Fault-injection rules; when present, the HTTP client factory is wrapped.
    #[cfg(feature = "fault_injection")]
    fault_injection_rules: Option<Vec<std::sync::Arc<crate::fault_injection::FaultInjectionRule>>>,
    /// HTTP client factory override for tests/mocking; the default factory is used when unset.
    #[cfg(any(test, feature = "__internal_mocking"))]
    http_client_factory: Option<Arc<dyn HttpClientFactory>>,
}
impl CosmosDriverRuntimeBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn with_client_options(mut self, options: ClientOptions) -> Self {
self.client_options = Some(options);
self
}
pub fn with_connection_pool(mut self, options: ConnectionPoolOptions) -> Self {
self.connection_pool = Some(options);
self
}
pub fn with_operation_options(mut self, options: OperationOptions) -> Self {
self.operation_options = Some(options);
self
}
pub fn with_workload_id(mut self, workload_id: WorkloadId) -> Self {
self.workload_id = Some(workload_id);
self
}
pub fn with_correlation_id(mut self, correlation_id: CorrelationId) -> Self {
self.correlation_id = Some(correlation_id);
self
}
pub fn with_user_agent_suffix(mut self, suffix: UserAgentSuffix) -> Self {
self.user_agent_suffix = Some(suffix);
self
}
pub fn with_cpu_refresh_interval(mut self, interval: Duration) -> Self {
self.cpu_refresh_interval = Some(interval);
self
}
#[cfg(test)]
pub(crate) fn with_http_client_factory(mut self, factory: Arc<dyn HttpClientFactory>) -> Self {
self.http_client_factory = Some(factory);
self
}
#[cfg(feature = "__internal_mocking")]
pub fn with_mock_http_client_factory(mut self, factory: Arc<dyn HttpClientFactory>) -> Self {
self.http_client_factory = Some(factory);
self
}
pub fn register_throughput_control_group(
mut self,
group: ThroughputControlGroupOptions,
) -> azure_core::Result<Self> {
self.throughput_control_groups
.register(group)
.map_err(|e| {
azure_core::Error::with_message(azure_core::error::ErrorKind::Other, e.to_string())
})?;
Ok(self)
}
#[cfg(feature = "fault_injection")]
pub fn with_fault_injection_rules(
mut self,
rules: Vec<std::sync::Arc<crate::fault_injection::FaultInjectionRule>>,
) -> Self {
self.fault_injection_rules = Some(rules);
self
}
pub async fn build(self) -> azure_core::Result<Arc<CosmosDriverRuntime>> {
let user_agent = if let Some(ref suffix) = self.user_agent_suffix {
UserAgent::from_suffix(suffix)
} else if let Some(workload_id) = self.workload_id {
UserAgent::from_workload_id(workload_id)
} else if let Some(ref correlation_id) = self.correlation_id {
UserAgent::from_correlation_id(correlation_id)
} else {
UserAgent::default()
};
let connection_pool = self.connection_pool.unwrap_or_default();
let proxy_configuration = ProxyConfiguration::from_env(connection_pool.proxy_allowed());
#[allow(unused_mut)]
let mut fault_injection_enabled = false;
let http_client_factory: Arc<dyn HttpClientFactory> = {
let base_factory: Arc<dyn HttpClientFactory> = {
#[cfg(any(test, feature = "__internal_mocking"))]
{
self.http_client_factory
.unwrap_or_else(|| Arc::new(DefaultHttpClientFactory::new()))
}
#[cfg(not(any(test, feature = "__internal_mocking")))]
{
Arc::new(DefaultHttpClientFactory::new())
}
};
#[cfg(feature = "fault_injection")]
{
if let Some(rules) = self.fault_injection_rules {
fault_injection_enabled = true;
Arc::new(
crate::fault_injection::FaultInjectingHttpClientFactory::new(
base_factory,
rules,
),
)
} else {
base_factory
}
}
#[cfg(not(feature = "fault_injection"))]
{
base_factory
}
};
let bootstrap_version = if connection_pool.is_http2_allowed() {
crate::diagnostics::TransportHttpVersion::Http2
} else {
crate::diagnostics::TransportHttpVersion::Http11
};
let bootstrap_transport = Arc::new(CosmosTransport::bootstrap_metadata_only(
connection_pool.clone(),
http_client_factory.clone(),
bootstrap_version,
)?);
let refresh_interval = parse_duration_millis_from_env(
self.cpu_refresh_interval,
"AZURE_COSMOS_CPU_REFRESH_INTERVAL_MS",
5_000,
1_000,
60_000,
)?;
let cpu_monitor = CpuMemoryMonitor::get_or_init(refresh_interval);
let vm_metadata = VmMetadataService::get_or_init().await;
Ok(Arc::new(CosmosDriverRuntime {
id: NEXT_RUNTIME_ID.fetch_add(1, Ordering::Relaxed),
client_options: self.client_options.unwrap_or_default(),
connection_pool,
bootstrap_transport,
http_client_factory,
env_operation_options: Arc::new(OperationOptions::from_env()),
operation_options: RwLock::new(Arc::new(self.operation_options.unwrap_or_default())),
user_agent,
workload_id: self.workload_id,
correlation_id: self.correlation_id,
user_agent_suffix: self.user_agent_suffix,
throughput_control_groups: self.throughput_control_groups,
driver_registry: RwLock::new(HashMap::new()),
container_cache: ContainerCache::new(),
account_metadata_cache: Arc::new(AccountMetadataCache::new()),
cpu_monitor,
machine_id: Arc::new(vm_metadata.machine_id().to_owned()),
fault_injection_enabled,
proxy_configuration,
}))
}
}
/// Counter that supplies `CosmosDriverRuntime::id`.
///
/// `build` calls `fetch_add` once, after its last fallible step, so each successfully built
/// runtime gets a distinct value. The atomic increment alone guarantees uniqueness, so
/// `Relaxed` ordering suffices.
static NEXT_RUNTIME_ID: AtomicUsize = AtomicUsize::new(0);
#[cfg(test)]
mod tests {
    use super::*;
    use url::Url;

    /// A driver that fails to initialize must not be cached, and a retry must fail the same
    /// way rather than hitting a poisoned or half-populated registry.
    #[tokio::test]
    async fn get_or_create_driver_removes_failed_initialization_from_registry() {
        let runtime = CosmosDriverRuntimeBuilder::new().build().await.unwrap();
        let endpoint = Url::parse("https://test.documents.azure.com:443/").unwrap();
        // The key below is not valid base64, which makes initialization fail.
        let account = AccountReference::with_master_key(endpoint, "***not-base64***");

        let attempts = [
            "invalid signing key should fail initialization",
            "failed initialization should not poison the driver registry",
        ];
        for why in attempts {
            let failure = runtime
                .get_or_create_driver(account.clone(), None)
                .await
                .expect_err(why);
            assert!(!failure.to_string().is_empty());

            let registry_is_empty = runtime.driver_registry.read().unwrap().is_empty();
            assert!(registry_is_empty);
        }
    }
}