pub(crate) mod adaptive_transport;
mod authorization_policy;
#[cfg(feature = "tokio")]
pub(crate) mod background_task_manager;
pub(crate) mod cosmos_headers;
pub(crate) mod cosmos_transport_client;
mod emulator;
pub(crate) mod http_client_factory;
pub(crate) mod request_signing;
#[cfg(feature = "reqwest")]
pub(crate) mod reqwest_transport_client;
mod sharded_transport;
pub(crate) use sharded_transport::EndpointKey;
mod tracked_transport;
pub(crate) mod transport_pipeline;
use crate::{
driver::pipeline::components::TransportMode,
models::{AccountEndpoint, OperationType, ResourceType},
options::ConnectionPoolOptions,
};
use std::sync::{Arc, OnceLock};
use self::{
adaptive_transport::AdaptiveTransport,
http_client_factory::{HttpClientConfig, HttpClientFactory},
};
use crate::diagnostics::TransportHttpVersion;
#[cfg(test)]
use self::http_client_factory::DefaultHttpClientFactory;
pub(crate) use authorization_policy::generate_authorization;
pub(crate) use authorization_policy::AuthorizationContext;
pub(crate) use emulator::is_emulator_host;
pub(crate) use tracked_transport::infer_request_sent_status;
pub(crate) const COSMOS_API_VERSION: &str = "2020-07-15";
/// Reports whether a request is routed through the dataplane pipeline.
///
/// Returns `true` for every operation on [`ResourceType::Document`] and for
/// [`OperationType::Execute`] on [`ResourceType::StoredProcedure`]; every
/// other resource/operation combination returns `false`.
pub(crate) fn uses_dataplane_pipeline(
    resource_type: ResourceType,
    operation_type: OperationType,
) -> bool {
    matches!(
        (resource_type, operation_type),
        (ResourceType::Document, _)
            | (ResourceType::StoredProcedure, OperationType::Execute)
    )
}
/// Owns the HTTP transports used for metadata and dataplane requests.
///
/// The metadata and dataplane-gateway transports are built eagerly by the
/// constructors; the Gateway 2.0 and insecure-emulator transports are built
/// on first use and cached in `OnceLock` slots.
#[derive(Debug)]
pub(crate) struct CosmosTransport {
/// Pool options passed to every transport constructor; also consulted for the
/// emulator certificate-validation setting and the Gateway 2.0 allowance.
connection_pool: ConnectionPoolOptions,
/// Shared factory handed (cloned) to each `AdaptiveTransport` that is built.
http_client_factory: Arc<dyn HttpClientFactory>,
/// HTTP version fed into the metadata and dataplane-gateway client configs;
/// exposed through `negotiated_version()`.
negotiated_version: TransportHttpVersion,
/// Eagerly built transport returned by `get_metadata_transport` when the
/// insecure emulator path does not apply.
metadata_transport: AdaptiveTransport,
/// Eagerly built transport that `get_dataplane_transport` falls back to.
dataplane_gateway_transport: AdaptiveTransport,
/// Built on the first `TransportMode::Gateway20` request, and only when the
/// pool options allow Gateway 2.0.
dataplane_gateway20_transport: OnceLock<AdaptiveTransport>,
/// Metadata transport whose config has `with_allow_invalid_cert()` applied;
/// built on first use for endpoints that take the insecure emulator path.
insecure_emulator_metadata_transport: OnceLock<AdaptiveTransport>,
/// Dataplane transport whose config has `with_allow_invalid_cert()` applied;
/// built on first use for endpoints that take the insecure emulator path.
insecure_emulator_dataplane_transport: OnceLock<AdaptiveTransport>,
}
impl CosmosTransport {
#[cfg(test)]
pub(crate) fn for_tests(
connection_pool: ConnectionPoolOptions,
negotiated_version: TransportHttpVersion,
) -> azure_core::Result<Self> {
let http_client_factory: Arc<dyn HttpClientFactory> =
Arc::new(DefaultHttpClientFactory::new());
Self::with_factory(connection_pool, http_client_factory, negotiated_version)
}
pub(crate) fn with_factory(
connection_pool: ConnectionPoolOptions,
http_client_factory: Arc<dyn HttpClientFactory>,
negotiated_version: TransportHttpVersion,
) -> azure_core::Result<Self> {
let metadata_config = HttpClientConfig::metadata(&connection_pool, negotiated_version);
let metadata_transport = AdaptiveTransport::from_config(
&connection_pool,
http_client_factory.clone(),
metadata_config,
)?;
let gateway_config =
HttpClientConfig::dataplane_gateway(&connection_pool, negotiated_version);
let dataplane_gateway_transport = AdaptiveTransport::from_config(
&connection_pool,
http_client_factory.clone(),
gateway_config,
)?;
Ok(Self {
connection_pool,
http_client_factory,
negotiated_version,
metadata_transport,
dataplane_gateway_transport,
dataplane_gateway20_transport: OnceLock::new(),
insecure_emulator_metadata_transport: OnceLock::new(),
insecure_emulator_dataplane_transport: OnceLock::new(),
})
}
pub(crate) fn bootstrap_metadata_only(
connection_pool: ConnectionPoolOptions,
http_client_factory: Arc<dyn HttpClientFactory>,
negotiated_version: TransportHttpVersion,
) -> azure_core::Result<Self> {
let metadata_config = HttpClientConfig::metadata(&connection_pool, negotiated_version);
let metadata_transport = AdaptiveTransport::unsharded(
&connection_pool,
http_client_factory.clone(),
metadata_config,
)?;
let gateway_config =
HttpClientConfig::dataplane_gateway(&connection_pool, negotiated_version);
let dataplane_gateway_transport = AdaptiveTransport::unsharded(
&connection_pool,
http_client_factory.clone(),
gateway_config,
)?;
Ok(Self {
connection_pool,
http_client_factory,
negotiated_version,
metadata_transport,
dataplane_gateway_transport,
dataplane_gateway20_transport: OnceLock::new(),
insecure_emulator_metadata_transport: OnceLock::new(),
insecure_emulator_dataplane_transport: OnceLock::new(),
})
}
pub(crate) fn negotiated_version(&self) -> TransportHttpVersion {
self.negotiated_version
}
fn should_use_insecure_emulator_transport(&self, endpoint: &AccountEndpoint) -> bool {
bool::from(self.connection_pool.emulator_server_cert_validation())
&& is_emulator_host(endpoint)
}
pub(crate) fn get_metadata_transport(
&self,
endpoint: &AccountEndpoint,
) -> azure_core::Result<AdaptiveTransport> {
let transport = if self.should_use_insecure_emulator_transport(endpoint) {
match self.insecure_emulator_metadata_transport.get() {
Some(t) => t.clone(),
None => {
let config =
HttpClientConfig::metadata(&self.connection_pool, self.negotiated_version)
.with_allow_invalid_cert();
let t = AdaptiveTransport::from_config(
&self.connection_pool,
self.http_client_factory.clone(),
config,
)?;
self.insecure_emulator_metadata_transport
.get_or_init(|| t)
.clone()
}
}
} else {
self.metadata_transport.clone()
};
Ok(transport)
}
pub(crate) fn get_dataplane_transport(
&self,
endpoint: &AccountEndpoint,
transport_mode: TransportMode,
) -> azure_core::Result<AdaptiveTransport> {
if self.should_use_insecure_emulator_transport(endpoint) {
let transport = match self.insecure_emulator_dataplane_transport.get() {
Some(t) => t.clone(),
None => {
let config = HttpClientConfig::dataplane_gateway(
&self.connection_pool,
self.negotiated_version,
)
.with_allow_invalid_cert();
let t = AdaptiveTransport::from_config(
&self.connection_pool,
self.http_client_factory.clone(),
config,
)?;
self.insecure_emulator_dataplane_transport
.get_or_init(|| t)
.clone()
}
};
return Ok(transport);
}
match transport_mode {
TransportMode::Gateway20 if self.connection_pool.is_gateway20_allowed() => {
let transport = match self.dataplane_gateway20_transport.get() {
Some(t) => t.clone(),
None => {
let config = HttpClientConfig::dataplane_gateway20(&self.connection_pool);
let t = AdaptiveTransport::gateway20(
&self.connection_pool,
self.http_client_factory.clone(),
config,
);
self.dataplane_gateway20_transport.get_or_init(|| t).clone()
}
};
Ok(transport)
}
_ => Ok(self.dataplane_gateway_transport.clone()),
}
}
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::diagnostics::TransportHttpVersion;
    use crate::driver::pipeline::components::TransportMode;
    use crate::options::{ConnectionPoolOptionsBuilder, EmulatorServerCertValidation};

    /// A regular (non-emulator) account endpoint used by most tests.
    const STANDARD_ACCOUNT_URL: &str = "https://myaccount.documents.azure.com:443/";

    /// Parses `url` into an [`AccountEndpoint`], panicking on failure.
    fn parse_endpoint(url: &'static str) -> AccountEndpoint {
        AccountEndpoint::try_from(url).unwrap()
    }

    /// Pool options built with every builder default.
    fn default_pool() -> ConnectionPoolOptions {
        ConnectionPoolOptionsBuilder::new().build().unwrap()
    }

    /// Pool options with the Gateway 2.0 allowance set explicitly.
    fn pool_with_gateway20(allowed: bool) -> ConnectionPoolOptions {
        ConnectionPoolOptionsBuilder::new()
            .with_is_gateway20_allowed(allowed)
            .build()
            .unwrap()
    }

    /// Builds a [`CosmosTransport`] through the test constructor.
    fn build_transport(
        pool: ConnectionPoolOptions,
        version: TransportHttpVersion,
    ) -> CosmosTransport {
        CosmosTransport::for_tests(pool, version).unwrap()
    }

    #[test]
    fn transport_creates_with_http2() {
        let transport = build_transport(default_pool(), TransportHttpVersion::Http2);
        let standard = parse_endpoint(STANDARD_ACCOUNT_URL);
        assert!(!transport.should_use_insecure_emulator_transport(&standard));
    }

    #[test]
    fn transport_detects_emulator_when_disabled() {
        let pool = ConnectionPoolOptionsBuilder::new()
            .with_emulator_server_cert_validation(EmulatorServerCertValidation::DangerousDisabled)
            .build()
            .unwrap();
        let transport = build_transport(pool, TransportHttpVersion::Http2);

        let localhost = parse_endpoint("https://localhost:8081/");
        assert!(transport.should_use_insecure_emulator_transport(&localhost));

        let loopback = parse_endpoint("https://127.0.0.1:8081/");
        assert!(transport.should_use_insecure_emulator_transport(&loopback));

        let standard = parse_endpoint(STANDARD_ACCOUNT_URL);
        assert!(!transport.should_use_insecure_emulator_transport(&standard));
    }

    #[test]
    fn transport_ignores_emulator_hosts_when_validation_enabled() {
        let transport = build_transport(default_pool(), TransportHttpVersion::Http2);
        let localhost = parse_endpoint("https://localhost:8081/");
        assert!(!transport.should_use_insecure_emulator_transport(&localhost));
    }

    #[test]
    fn metadata_transport_is_sharded_when_http2_negotiated() {
        let transport = build_transport(default_pool(), TransportHttpVersion::Http2);
        let standard = parse_endpoint(STANDARD_ACCOUNT_URL);
        let selected = transport.get_metadata_transport(&standard).unwrap();
        assert!(matches!(selected, AdaptiveTransport::ShardedGateway(_)));
    }

    #[test]
    fn metadata_transport_is_unsharded_when_http11_negotiated() {
        let transport = build_transport(default_pool(), TransportHttpVersion::Http11);
        let standard = parse_endpoint(STANDARD_ACCOUNT_URL);
        let selected = transport.get_metadata_transport(&standard).unwrap();
        assert!(matches!(selected, AdaptiveTransport::Gateway(_)));
    }

    #[test]
    fn dataplane_transport_is_unsharded_when_http11_negotiated() {
        let transport = build_transport(default_pool(), TransportHttpVersion::Http11);
        let standard = parse_endpoint(STANDARD_ACCOUNT_URL);
        let selected = transport
            .get_dataplane_transport(&standard, TransportMode::Gateway)
            .unwrap();
        assert!(matches!(selected, AdaptiveTransport::Gateway(_)));
    }

    #[test]
    fn dataplane_transport_uses_gateway20_when_selected() {
        let transport = build_transport(pool_with_gateway20(true), TransportHttpVersion::Http2);
        let standard = parse_endpoint(STANDARD_ACCOUNT_URL);
        let selected = transport
            .get_dataplane_transport(&standard, TransportMode::Gateway20)
            .unwrap();
        assert!(matches!(selected, AdaptiveTransport::ShardedGateway20(_)));
    }

    #[test]
    fn dataplane_transport_falls_back_to_sharded_gateway_when_endpoint_is_standard() {
        let transport = build_transport(pool_with_gateway20(true), TransportHttpVersion::Http2);
        let standard = parse_endpoint(STANDARD_ACCOUNT_URL);
        let selected = transport
            .get_dataplane_transport(&standard, TransportMode::Gateway)
            .unwrap();
        assert!(matches!(selected, AdaptiveTransport::ShardedGateway(_)));
    }

    #[test]
    fn dataplane_transport_ignores_gateway20_when_gateway20_disabled() {
        let transport = build_transport(pool_with_gateway20(false), TransportHttpVersion::Http2);
        let standard = parse_endpoint(STANDARD_ACCOUNT_URL);
        let selected = transport
            .get_dataplane_transport(&standard, TransportMode::Gateway20)
            .unwrap();
        assert!(matches!(selected, AdaptiveTransport::ShardedGateway(_)));
    }

    #[test]
    fn uses_dataplane_for_document_operations() {
        // Every document operation is expected on the dataplane pipeline.
        let document_operations = [
            OperationType::Read,
            OperationType::Create,
            OperationType::Replace,
            OperationType::Delete,
            OperationType::Upsert,
        ];
        for operation in document_operations {
            assert!(uses_dataplane_pipeline(ResourceType::Document, operation));
        }
    }

    #[test]
    fn uses_dataplane_for_stored_procedure_execute() {
        assert!(uses_dataplane_pipeline(
            ResourceType::StoredProcedure,
            OperationType::Execute
        ));

        // Any other stored-procedure operation stays off the dataplane.
        let non_execute_operations = [
            OperationType::Read,
            OperationType::Create,
            OperationType::Delete,
        ];
        for operation in non_execute_operations {
            assert!(!uses_dataplane_pipeline(
                ResourceType::StoredProcedure,
                operation
            ));
        }
    }

    #[test]
    fn uses_metadata_for_other_resources() {
        let metadata_cases = [
            (ResourceType::Database, OperationType::Read),
            (ResourceType::Database, OperationType::Create),
            (ResourceType::Database, OperationType::Delete),
            (ResourceType::DocumentCollection, OperationType::Read),
            (ResourceType::DocumentCollection, OperationType::Create),
            (ResourceType::DocumentCollection, OperationType::Delete),
            (ResourceType::DatabaseAccount, OperationType::Read),
            (ResourceType::Trigger, OperationType::Read),
            (ResourceType::UserDefinedFunction, OperationType::Create),
            (ResourceType::Offer, OperationType::Read),
            (ResourceType::Offer, OperationType::Replace),
        ];
        for (resource, operation) in metadata_cases {
            assert!(!uses_dataplane_pipeline(resource, operation));
        }
    }
}