#[cfg(feature = "fault_injection")]
use azure_data_cosmos_driver::fault_injection::FaultInjectionRule;
use azure_data_cosmos_driver::{
diagnostics::{DiagnosticsContext, PipelineType, TransportSecurity},
driver::CosmosDriverRuntime,
models::{
AccountReference, ConnectionString, ContainerReference, CosmosOperation, CosmosResponse,
DatabaseReference, ItemReference, PartitionKey,
},
options::{ConnectionPoolOptions, EmulatorServerCertValidation, OperationOptions},
};
use std::{error::Error, future::Future, sync::Arc};
use uuid::Uuid;
use super::env::{
get_test_mode, is_azure_pipelines, CosmosTestMode, CONNECTION_STRING_ENV_VAR,
EMULATOR_CONNECTION_STRING,
};
/// A Cosmos DB driver runtime paired with the account the integration tests
/// run against.
pub struct DriverTestClient {
    // Runtime from which drivers for `account` are obtained via
    // `get_or_create_driver`.
    runtime: Arc<CosmosDriverRuntime>,
    // Account resolved from the test environment.
    account: AccountReference,
}
/// Account and connection settings resolved from the environment by
/// `resolve_test_env`.
pub struct TestEnv {
    /// Account reference built from the connection string's endpoint and key.
    pub account: AccountReference,
    /// Connection-pool options; emulator server-certificate validation is
    /// disabled when the emulator connection string is in use.
    pub connection_pool: ConnectionPoolOptions,
}
/// Resolves the Cosmos DB account and connection-pool settings for a test run
/// from the environment.
///
/// As a side effect this tries to install a `tracing` subscriber configured
/// from the default environment filter; any error from `try_init` (e.g. a
/// subscriber is already installed) is ignored.
///
/// The connection string is read from [`CONNECTION_STRING_ENV_VAR`]. The value
/// `emulator` (ASCII case-insensitive) selects [`EMULATOR_CONNECTION_STRING`];
/// when that emulator connection string is in use, the emulator's server
/// certificate validation is disabled.
///
/// Returns `Ok(None)` when the tests should be skipped: the test mode is
/// [`CosmosTestMode::Skipped`], or the variable is unset (or not valid Unicode)
/// while the mode is not required and the run is not in Azure Pipelines.
///
/// # Panics
///
/// Panics if the variable is unset while the test mode is
/// [`CosmosTestMode::Required`] or the tests are running in Azure Pipelines.
///
/// # Errors
///
/// Returns an error if the connection string or its account endpoint fails to
/// parse, or if the connection-pool options fail to build.
pub fn resolve_test_env() -> Result<Option<TestEnv>, Box<dyn Error>> {
    let _ = tracing_subscriber::fmt::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .try_init();

    let test_mode = get_test_mode();
    if test_mode == CosmosTestMode::Skipped {
        return Ok(None);
    }

    let connection_string = match std::env::var(CONNECTION_STRING_ENV_VAR) {
        // ASCII comparison avoids allocating a lowercased copy of the value.
        Ok(val) if val.eq_ignore_ascii_case("emulator") => {
            EMULATOR_CONNECTION_STRING.to_string()
        }
        Ok(val) => val,
        Err(_) => {
            // Report the condition that actually made the variable mandatory,
            // so a pipeline run is not misreported as "test mode is required".
            let reason = if test_mode == CosmosTestMode::Required {
                Some("test mode is required")
            } else if is_azure_pipelines() {
                Some("tests are running in Azure Pipelines")
            } else {
                None
            };
            if let Some(reason) = reason {
                panic!("{} is not set but {}", CONNECTION_STRING_ENV_VAR, reason);
            }
            return Ok(None);
        }
    };

    let conn_str: ConnectionString = connection_string.parse()?;
    let endpoint = conn_str.account_endpoint().parse()?;
    let key = conn_str.account_key().secret().to_string();
    let account = AccountReference::with_master_key(endpoint, key);

    let mut connection_pool_builder = ConnectionPoolOptions::builder();
    // The local emulator is reached with a self-signed certificate.
    if connection_string.eq_ignore_ascii_case(EMULATOR_CONNECTION_STRING) {
        connection_pool_builder = connection_pool_builder
            .with_emulator_server_cert_validation(EmulatorServerCertValidation::DangerousDisabled);
    }
    let connection_pool = connection_pool_builder.build()?;

    Ok(Some(TestEnv {
        account,
        connection_pool,
    }))
}
impl DriverTestClient {
    /// Creates a client from the environment.
    ///
    /// Returns `Ok(None)` when [`resolve_test_env`] reports that the Cosmos DB
    /// environment is not configured.
    ///
    /// # Errors
    ///
    /// Returns an error if the environment cannot be resolved or the driver
    /// runtime fails to build.
    ///
    /// # Panics
    ///
    /// Propagates the panic raised by [`resolve_test_env`] when a connection
    /// string is mandatory but missing.
    pub async fn from_env() -> Result<Option<Self>, Box<dyn Error>> {
        let Some(env) = resolve_test_env()? else {
            return Ok(None);
        };
        let runtime = CosmosDriverRuntime::builder()
            .with_connection_pool(env.connection_pool)
            .build()
            .await?;
        Ok(Some(Self {
            runtime,
            account: env.account,
        }))
    }

    /// Like [`DriverTestClient::from_env`], but registers the given fault
    /// injection `rules` on the driver runtime.
    ///
    /// # Errors
    ///
    /// Returns an error if the environment cannot be resolved or the driver
    /// runtime fails to build.
    #[cfg(feature = "fault_injection")]
    pub async fn from_env_with_fault_injection(
        rules: Vec<Arc<FaultInjectionRule>>,
    ) -> Result<Option<Self>, Box<dyn Error>> {
        let Some(env) = resolve_test_env()? else {
            return Ok(None);
        };
        let runtime = CosmosDriverRuntime::builder()
            .with_connection_pool(env.connection_pool)
            .with_fault_injection_rules(rules)
            .build()
            .await?;
        Ok(Some(Self {
            runtime,
            account: env.account,
        }))
    }

    /// Runs the test body `f` with a fresh run context.
    ///
    /// When the environment is not configured, a skip message is printed and
    /// `Ok(())` is returned without calling `f`.
    pub async fn run<F, Fut>(f: F) -> Result<(), Box<dyn Error>>
    where
        F: FnOnce(DriverTestRunContext) -> Fut,
        Fut: Future<Output = Result<(), Box<dyn Error>>>,
    {
        let Some(client) = Self::from_env().await? else {
            println!("Skipping test: Cosmos DB environment not configured");
            return Ok(());
        };
        f(DriverTestRunContext::new(client)).await
    }

    /// Like [`DriverTestClient::run`], but the client's runtime is built with
    /// the given fault injection `rules`.
    #[cfg(feature = "fault_injection")]
    pub async fn run_with_fault_injection<F, Fut>(
        rules: Vec<Arc<FaultInjectionRule>>,
        f: F,
    ) -> Result<(), Box<dyn Error>>
    where
        F: FnOnce(DriverTestRunContext) -> Fut,
        Fut: Future<Output = Result<(), Box<dyn Error>>>,
    {
        let Some(client) = Self::from_env_with_fault_injection(rules).await? else {
            println!("Skipping test: Cosmos DB environment not configured");
            return Ok(());
        };
        f(DriverTestRunContext::new(client)).await
    }

    /// Like [`DriverTestClient::run_with_unique_db`], but the client's runtime
    /// is built with the given fault injection `rules`.
    #[cfg(feature = "fault_injection")]
    pub async fn run_with_unique_db_and_fault_injection<F, Fut>(
        rules: Vec<Arc<FaultInjectionRule>>,
        f: F,
    ) -> Result<(), Box<dyn Error>>
    where
        F: FnOnce(DriverTestRunContext, DatabaseReference) -> Fut,
        Fut: Future<Output = Result<(), Box<dyn Error>>>,
    {
        Self::run_with_fault_injection(rules, |context| Self::with_unique_db(context, f)).await
    }

    /// Runs `f` against a database created for this run and deleted afterward.
    ///
    /// The result of `f` is returned; deletion is best-effort (see the private
    /// `with_unique_db` helper).
    pub async fn run_with_unique_db<F, Fut>(f: F) -> Result<(), Box<dyn Error>>
    where
        F: FnOnce(DriverTestRunContext, DatabaseReference) -> Fut,
        Fut: Future<Output = Result<(), Box<dyn Error>>>,
    {
        Self::run(|context| Self::with_unique_db(context, f)).await
    }

    /// Creates the run's database, runs `f` against it, then deletes it.
    ///
    /// Deletion is best-effort: a failure is reported on stderr instead of
    /// replacing the result of `f`. If `f` panics, the database is not deleted.
    async fn with_unique_db<F, Fut>(
        context: DriverTestRunContext,
        f: F,
    ) -> Result<(), Box<dyn Error>>
    where
        F: FnOnce(DriverTestRunContext, DatabaseReference) -> Fut,
        Fut: Future<Output = Result<(), Box<dyn Error>>>,
    {
        let db_name = context.unique_database_name();
        let db_ref = context.create_database(&db_name).await?;
        let result = f(context.clone(), db_ref.clone()).await;
        if let Err(error) = context.delete_database(&db_ref).await {
            eprintln!("Warning: failed to delete test database '{db_name}': {error}");
        }
        result
    }
}
/// Per-run handle passed to test bodies; clones share the same client and
/// run id.
#[derive(Clone)]
pub struct DriverTestRunContext {
    // Client shared (via `Arc`) by every clone of this context.
    client: Arc<DriverTestClient>,
    // Short random identifier (the first 8 characters of a UUID string) used
    // to name this run's database.
    run_id: String,
}
impl DriverTestRunContext {
fn new(client: DriverTestClient) -> Self {
Self {
client: Arc::new(client),
run_id: Uuid::new_v4().to_string()[..8].to_string(),
}
}
pub fn unique_database_name(&self) -> String {
format!("test-db-{}", self.run_id)
}
pub fn unique_container_name(&self) -> String {
let uuid_str = Uuid::new_v4().to_string();
format!("test-container-{}", &uuid_str[..8])
}
pub async fn create_database(
&self,
db_name: &str,
) -> Result<DatabaseReference, Box<dyn Error>> {
let driver = self
.client
.runtime
.get_or_create_driver(self.client.account.clone(), None)
.await?;
let body = format!(r#"{{"id": "{}"}}"#, db_name);
let operation = CosmosOperation::create_database(self.client.account.clone())
.with_body(body.into_bytes());
let result = driver
.execute_operation(operation, OperationOptions::default())
.await?;
let status = result.diagnostics().status();
if !status.map(|s| s.is_success()).unwrap_or(false) {
return Err(format!("Failed to create database, status: {:?}", status).into());
}
Ok(DatabaseReference::from_name(
self.client.account.clone(),
db_name.to_string(),
))
}
pub async fn delete_database(
&self,
database: &DatabaseReference,
) -> Result<(), Box<dyn Error>> {
let driver = self
.client
.runtime
.get_or_create_driver(self.client.account.clone(), None)
.await?;
let operation = CosmosOperation::delete_database(database.clone());
let result = driver
.execute_operation(operation, OperationOptions::default())
.await?;
let status = result.diagnostics().status();
if !status.map(|s| s.is_success()).unwrap_or(false) {
return Err(format!("Failed to delete database, status: {:?}", status).into());
}
Ok(())
}
pub async fn create_container(
&self,
database: &DatabaseReference,
container_name: &str,
partition_key_path: &str,
) -> Result<ContainerReference, Box<dyn Error>> {
let driver = self
.client
.runtime
.get_or_create_driver(self.client.account.clone(), None)
.await?;
let body = format!(
r#"{{"id": "{}", "partitionKey": {{"paths": ["{}"], "kind": "Hash", "version": 2}}}}"#,
container_name, partition_key_path
);
let operation =
CosmosOperation::create_container(database.clone()).with_body(body.into_bytes());
let result = driver
.execute_operation(operation, OperationOptions::default())
.await?;
let status = result.diagnostics().status();
if !status.map(|s| s.is_success()).unwrap_or(false) {
return Err(format!("Failed to create container, status: {:?}", status).into());
}
let db_name = database
.name()
.ok_or_else(|| "database reference must be name-based".to_string())?;
let container = driver
.resolve_container_by_name(db_name, container_name)
.await?;
Ok(container)
}
pub async fn create_item(
&self,
container: &ContainerReference,
item_id: &str,
partition_key: impl Into<PartitionKey>,
body: &[u8],
) -> Result<CosmosResponse, Box<dyn Error>> {
let driver = self
.client
.runtime
.get_or_create_driver(self.client.account.clone(), None)
.await?;
let pk = partition_key.into();
let item_ref = ItemReference::from_name(container, pk, item_id.to_owned());
let operation = CosmosOperation::create_item(item_ref).with_body(body.to_vec());
let result = driver
.execute_operation(operation, OperationOptions::default())
.await?;
Ok(result)
}
pub async fn read_item(
&self,
container: &ContainerReference,
item_id: &str,
partition_key: impl Into<PartitionKey>,
) -> Result<CosmosResponse, Box<dyn Error>> {
let driver = self
.client
.runtime
.get_or_create_driver(self.client.account.clone(), None)
.await?;
let pk = partition_key.into();
let item_ref = ItemReference::from_name(container, pk, item_id.to_owned());
let operation = CosmosOperation::read_item(item_ref);
let result = driver
.execute_operation(operation, OperationOptions::default())
.await?;
Ok(result)
}
pub fn validate_data_plane_diagnostics(
&self,
diagnostics: &DiagnosticsContext,
expected_status: u16,
) {
let status = diagnostics.status();
assert!(status.is_some(), "Diagnostics should have a status code");
assert_eq!(
u16::from(status.unwrap().status_code()),
expected_status,
"Status code should match expected"
);
assert!(
!diagnostics.activity_id().as_str().is_empty(),
"Activity ID should not be empty"
);
assert!(
!diagnostics.duration().is_zero(),
"Duration should be non-zero"
);
let requests = diagnostics.requests();
assert!(!requests.is_empty(), "Should have at least one request");
let first_request = &requests[0];
assert_eq!(
first_request.pipeline_type(),
PipelineType::DataPlane,
"Should use data plane pipeline for item operations"
);
if first_request.endpoint().contains("localhost") {
assert_eq!(
first_request.transport_security(),
TransportSecurity::EmulatorWithInsecureCertificates,
"Should use emulator transport security for localhost"
);
}
assert!(
first_request.request_charge().value() >= 0.0,
"Request charge should be non-negative"
);
}
}