use super::AsyncCache;
use crate::models::{AccountEndpoint, DefaultConsistencyLevel};
use crate::options::Region;
use serde::Deserialize;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// One entry of an account's location lists (`writableLocations`, `readableLocations`,
/// and the optional thin-client variants).
#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
// NOTE(review): presumably allowed because not every field is read yet — confirm.
#[allow(dead_code)]
pub(crate) struct AccountRegion {
    /// Region name as parsed into `Region` (e.g. the payload value `"West US 2"`).
    pub name: Region,
    /// Regional endpoint, deserialized from `databaseAccountEndpoint`.
    pub database_account_endpoint: AccountEndpoint,
}
/// Replica-set sizing reported by the account (used for both `userReplicationPolicy`
/// and `systemReplicationPolicy`).
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
// NOTE(review): presumably allowed because these fields are not read yet — confirm.
#[allow(dead_code)]
pub(crate) struct ReplicationPolicy {
    /// Deserialized from `minReplicaSetSize`.
    pub min_replica_set_size: i32,
    // The payload key is `maxReplicasetSize` (lower-case "set"), which
    // `rename_all = "camelCase"` would not produce, hence the explicit rename.
    #[serde(rename = "maxReplicasetSize")]
    pub max_replica_set_size: i32,
}
/// The account's consistency configuration (`userConsistencyPolicy`).
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
// NOTE(review): presumably allowed because this type is not fully consumed yet — confirm.
#[allow(dead_code)]
pub(crate) struct ConsistencyPolicy {
    /// Account default consistency level, from `defaultConsistencyLevel`.
    pub default_consistency_level: DefaultConsistencyLevel,
}
/// Read coefficients reported under the account's `readPolicy` key.
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
// NOTE(review): presumably allowed because these fields are not read yet — confirm.
#[allow(dead_code)]
pub(crate) struct ReadPolicy {
    /// Deserialized from `primaryReadCoefficient`.
    pub primary_read_coefficient: i32,
    /// Deserialized from `secondaryReadCoefficient`.
    pub secondary_read_coefficient: i32,
}
/// Account metadata document deserialized from the account's JSON payload.
///
/// Fields without `#[serde(default)]` are required: a payload missing any of them fails
/// to deserialize. Fields marked `#[serde(default)]` fall back to `false` / empty.
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
// NOTE(review): presumably allowed because not every field is read yet — confirm.
#[allow(dead_code)]
pub(crate) struct AccountProperties {
    /// Self link (`_self`).
    #[serde(rename = "_self")]
    pub self_link: String,
    /// Account id.
    pub id: String,
    /// Resource id (`_rid`).
    #[serde(rename = "_rid")]
    pub rid: String,
    /// Media link as returned in the payload (e.g. `//media/`).
    pub media: String,
    /// Addresses link as returned in the payload (e.g. `//addresses/`).
    pub addresses: String,
    /// Databases link (`_dbs`, e.g. `//dbs/`).
    #[serde(rename = "_dbs")]
    pub dbs: String,
    /// Regions that accept writes; the first entry is what `write_region` reports.
    pub writable_locations: Vec<AccountRegion>,
    /// Regions that serve reads.
    pub readable_locations: Vec<AccountRegion>,
    /// Whether multiple write locations are enabled for the account.
    pub enable_multiple_write_locations: bool,
    /// Optional; `false` when absent.
    #[serde(default)]
    pub continuous_backup_enabled: bool,
    /// Optional; `false` when absent.
    #[serde(default)]
    pub enable_n_region_synchronous_commit: bool,
    /// Optional; `false` when absent.
    #[serde(default)]
    pub enable_per_partition_failover_behavior: bool,
    /// Replica-set sizing configured for the account.
    pub user_replication_policy: ReplicationPolicy,
    /// Consistency configuration of the account.
    pub user_consistency_policy: ConsistencyPolicy,
    /// Replica-set sizing reported as the system policy.
    pub system_replication_policy: ReplicationPolicy,
    /// Read coefficients (`readPolicy`).
    pub read_policy: ReadPolicy,
    /// Query engine configuration kept as a raw JSON string; it is not parsed here.
    pub query_engine_configuration: String,
    /// Optional thin-client write regions; empty when absent.
    #[serde(default)]
    pub thin_client_writable_locations: Vec<AccountRegion>,
    /// Optional thin-client read regions; empty when absent.
    #[serde(default)]
    pub thin_client_readable_locations: Vec<AccountRegion>,
    /// Entity tag (`_etag`); empty string when absent.
    #[serde(rename = "_etag", default)]
    pub etag: String,
}
// NOTE(review): presumably allowed because some accessors have no callers yet — confirm.
#[allow(dead_code)]
impl AccountProperties {
    /// Collects the region names of `locations`, preserving their order.
    fn region_names(locations: &[AccountRegion]) -> Vec<Region> {
        locations.iter().map(|region| region.name.clone()).collect()
    }

    /// The first writable location, or `None` when the account lists none.
    pub(crate) fn write_account_region(&self) -> Option<&AccountRegion> {
        self.writable_locations.first()
    }

    /// Name of the first writable location, or `None` when the account lists none.
    pub(crate) fn write_region(&self) -> Option<Region> {
        self.write_account_region()
            .map(|region| region.name.clone())
    }

    /// Names of all readable locations, in payload order.
    pub(crate) fn readable_regions(&self) -> Vec<Region> {
        Self::region_names(&self.readable_locations)
    }

    /// `true` if the account advertises at least one thin-client location (read or write).
    pub(crate) fn has_thin_client_endpoints(&self) -> bool {
        let none_advertised = self.thin_client_writable_locations.is_empty()
            && self.thin_client_readable_locations.is_empty();
        !none_advertised
    }

    /// Names of the thin-client writable locations, in payload order.
    pub(crate) fn thin_client_writable_regions(&self) -> Vec<Region> {
        Self::region_names(&self.thin_client_writable_locations)
    }

    /// Names of the thin-client readable locations, in payload order.
    pub(crate) fn thin_client_readable_regions(&self) -> Vec<Region> {
        Self::region_names(&self.thin_client_readable_locations)
    }
}
/// Default age (600 s = 10 minutes) after which a cached account entry is considered stale.
const DEFAULT_STALENESS_THRESHOLD: Duration = Duration::from_secs(600);
/// Cache of [`AccountProperties`] keyed by account endpoint.
///
/// Besides the cached values, it records when each endpoint was last fetched successfully
/// so entries older than `staleness_threshold` can be refreshed.
#[derive(Debug)]
pub(crate) struct AccountMetadataCache {
    /// Cached account properties, one entry per endpoint.
    cache: AsyncCache<AccountEndpoint, AccountProperties>,
    /// Instant of the most recent successful fetch for each endpoint.
    last_refresh: async_lock::RwLock<std::collections::HashMap<AccountEndpoint, Instant>>,
    /// Age after which an entry counts as stale.
    // Presumably only read by staleness logic not yet used outside tests — confirm.
    #[allow(dead_code)]
    staleness_threshold: Duration,
    /// Serializes refreshes; a single lock shared by all endpoints.
    #[allow(dead_code)]
    refresh_mutex: async_lock::Mutex<()>,
}
impl AccountMetadataCache {
    /// Creates an empty cache using `DEFAULT_STALENESS_THRESHOLD` (600 seconds).
    pub(crate) fn new() -> Self {
        Self {
            cache: AsyncCache::new(),
            last_refresh: async_lock::RwLock::new(std::collections::HashMap::new()),
            staleness_threshold: DEFAULT_STALENESS_THRESHOLD,
            refresh_mutex: async_lock::Mutex::new(()),
        }
    }

    /// Returns the cached properties for `endpoint`, calling `fetch_fn` only on a miss.
    ///
    /// The fetch runs outside the cache, so concurrent misses for the same endpoint may each
    /// invoke their own `fetch_fn`. The returned value is whatever
    /// `AsyncCache::get_or_insert_with` yields for the key (presumably an entry inserted first
    /// by a racing caller wins — verify against `AsyncCache`). A successful fetch records the
    /// current instant as the endpoint's refresh time.
    ///
    /// # Errors
    ///
    /// Propagates the error from `fetch_fn`; nothing is cached in that case.
    pub(crate) async fn get_or_fetch<F, Fut>(
        &self,
        endpoint: AccountEndpoint,
        fetch_fn: F,
    ) -> azure_core::Result<Arc<AccountProperties>>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = azure_core::Result<AccountProperties>>,
    {
        if let Some(cached) = self.cache.get(&endpoint).await {
            return Ok(cached);
        }
        let properties = fetch_fn().await?;
        let result = self
            .cache
            .get_or_insert_with(endpoint.clone(), || async { properties })
            .await;
        self.last_refresh
            .write()
            .await
            .insert(endpoint, Instant::now());
        Ok(result)
    }

    /// Re-fetches the properties for `endpoint` when the entry is missing or stale.
    /// Otherwise it returns the cached value without calling `fetch_fn`.
    ///
    /// Refreshes are serialized by one mutex shared by all endpoints. Staleness is re-checked
    /// after the mutex is acquired, so callers that queued behind an in-flight refresh reuse
    /// its result instead of fetching again.
    ///
    /// If the fetch fails while a (stale) value is cached, that value is returned and its
    /// refresh time is left unchanged, so the next call retries the fetch.
    ///
    /// # Errors
    ///
    /// Returns the `fetch_fn` error only when there is no cached value to fall back on.
    // Presumably not yet called outside tests — confirm before removing the allow.
    #[allow(dead_code)]
    pub(crate) async fn refresh_if_stale<F, Fut>(
        &self,
        endpoint: AccountEndpoint,
        fetch_fn: F,
    ) -> azure_core::Result<Option<Arc<AccountProperties>>>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = azure_core::Result<AccountProperties>>,
    {
        if !self.is_stale(&endpoint).await {
            return Ok(self.cache.get(&endpoint).await);
        }
        let _guard = self.refresh_mutex.lock().await;
        // Another caller may have finished a refresh while we waited for the mutex.
        if !self.is_stale(&endpoint).await {
            return Ok(self.cache.get(&endpoint).await);
        }
        let cached = self.cache.get(&endpoint).await;
        let properties = match fetch_fn().await {
            Ok(props) => props,
            // Serving stale metadata beats failing outright; the refresh time is not
            // updated, so the next call tries again.
            Err(_) if cached.is_some() => return Ok(cached),
            Err(e) => return Err(e),
        };
        let result = self
            .cache
            .get_or_refresh_with(endpoint.clone(), |_existing| true, || async { properties })
            .await;
        if result.is_some() {
            self.last_refresh
                .write()
                .await
                .insert(endpoint, Instant::now());
        }
        Ok(result)
    }

    /// Returns `true` in three cases: `endpoint` has no cached entry, it has never been
    /// timestamped, or it was last refreshed at least `staleness_threshold` ago.
    ///
    /// The comparison is `>=` so that a zero threshold deterministically means
    /// "always stale", independent of clock resolution.
    #[allow(dead_code)]
    async fn is_stale(&self, endpoint: &AccountEndpoint) -> bool {
        if self.cache.get(endpoint).await.is_none() {
            return true;
        }
        let timestamps = self.last_refresh.read().await;
        match timestamps.get(endpoint) {
            Some(last) => last.elapsed() >= self.staleness_threshold,
            None => true,
        }
    }

    /// Removes and returns the cached properties for `endpoint`, if present.
    ///
    /// The endpoint's refresh timestamp is dropped as well, so invalidated endpoints do not
    /// accumulate in `last_refresh`. The next lookup treats the endpoint as missing.
    pub(crate) async fn invalidate(
        &self,
        endpoint: &AccountEndpoint,
    ) -> Option<Arc<AccountProperties>> {
        let removed = self.cache.invalidate(endpoint).await;
        self.last_refresh.write().await.remove(endpoint);
        removed
    }
}
impl Default for AccountMetadataCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Endpoint URL for a test account named `name`.
    fn test_endpoint(name: &str) -> AccountEndpoint {
        AccountEndpoint::from(
            url::Url::parse(&format!("https://{name}.documents.azure.com:443/")).unwrap(),
        )
    }

    /// Minimal account properties with a single read/write region named `region_name`.
    fn test_properties(region_name: &str) -> AccountProperties {
        let endpoint = format!("https://test-{region_name}.documents.azure.com:443/");
        serde_json::from_value(serde_json::json!({
            "_self": "",
            "id": "test",
            "_rid": "test.documents.azure.com",
            "media": "//media/",
            "addresses": "//addresses/",
            "_dbs": "//dbs/",
            "writableLocations": [{ "name": region_name, "databaseAccountEndpoint": endpoint }],
            "readableLocations": [{ "name": region_name, "databaseAccountEndpoint": endpoint }],
            "enableMultipleWriteLocations": false,
            "userReplicationPolicy": { "minReplicaSetSize": 3, "maxReplicasetSize": 4 },
            "userConsistencyPolicy": { "defaultConsistencyLevel": "Session" },
            "systemReplicationPolicy": { "minReplicaSetSize": 3, "maxReplicasetSize": 4 },
            "readPolicy": { "primaryReadCoefficient": 1, "secondaryReadCoefficient": 1 },
            "queryEngineConfiguration": "{}"
        }))
        .expect("test JSON is valid")
    }

    /// Empty cache with a custom staleness threshold (instead of the 10-minute default).
    fn cache_with_threshold(staleness_threshold: Duration) -> AccountMetadataCache {
        AccountMetadataCache {
            cache: AsyncCache::new(),
            last_refresh: async_lock::RwLock::new(std::collections::HashMap::new()),
            staleness_threshold,
            refresh_mutex: async_lock::Mutex::new(()),
        }
    }

    /// Error used by fetch closures that simulate a failed metadata request.
    fn network_failure() -> azure_core::Error {
        azure_core::Error::with_message(azure_core::error::ErrorKind::Other, "network failure")
    }

    #[tokio::test]
    async fn caches_account_properties() {
        let cache = AccountMetadataCache::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let endpoint = test_endpoint("myaccount");
        let counter_clone = counter.clone();
        let props = cache
            .get_or_fetch(endpoint.clone(), || async move {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok(test_properties("westus"))
            })
            .await
            .unwrap();
        assert_eq!(props.write_region().unwrap().as_str(), "westus");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        let counter_clone = counter.clone();
        let props2 = cache
            .get_or_fetch(endpoint, || async move {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok(test_properties("eastus"))
            })
            .await
            .unwrap();
        assert_eq!(props2.write_region().unwrap().as_str(), "westus");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn different_accounts_cached_separately() {
        let cache = AccountMetadataCache::new();
        let props1 = cache
            .get_or_fetch(test_endpoint("account1"), || async {
                Ok(test_properties("westus"))
            })
            .await
            .unwrap();
        let props2 = cache
            .get_or_fetch(test_endpoint("account2"), || async {
                Ok(test_properties("eastus"))
            })
            .await
            .unwrap();
        assert_eq!(props1.write_region().unwrap().as_str(), "westus");
        assert_eq!(props2.write_region().unwrap().as_str(), "eastus");
    }

    #[tokio::test]
    async fn get_returns_none_before_fetch() {
        let cache = AccountMetadataCache::new();
        let endpoint = test_endpoint("myaccount");
        assert!(cache.cache.get(&endpoint).await.is_none());
    }

    #[tokio::test]
    async fn invalidate_removes_entry() {
        let cache = AccountMetadataCache::new();
        let endpoint = test_endpoint("myaccount");
        cache
            .get_or_fetch(endpoint.clone(), || async { Ok(test_properties("westus")) })
            .await
            .unwrap();
        // Exercise the public API rather than the inner AsyncCache.
        let removed = cache.invalidate(&endpoint).await;
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().write_region().unwrap().as_str(), "westus");
        assert!(cache.cache.get(&endpoint).await.is_none());
    }

    #[tokio::test]
    async fn get_or_fetch_refetches_after_invalidate() {
        let cache = AccountMetadataCache::new();
        let endpoint = test_endpoint("myaccount");
        cache
            .get_or_fetch(endpoint.clone(), || async { Ok(test_properties("westus")) })
            .await
            .unwrap();
        cache.invalidate(&endpoint).await;
        let props = cache
            .get_or_fetch(endpoint, || async { Ok(test_properties("eastus")) })
            .await
            .unwrap();
        assert_eq!(props.write_region().unwrap().as_str(), "eastus");
    }

    #[tokio::test]
    async fn clear_removes_all() {
        let cache = AccountMetadataCache::new();
        cache
            .get_or_fetch(test_endpoint("account1"), || async {
                Ok(test_properties("westus"))
            })
            .await
            .unwrap();
        cache
            .get_or_fetch(test_endpoint("account2"), || async {
                Ok(test_properties("eastus"))
            })
            .await
            .unwrap();
        cache.cache.clear().await;
        assert!(cache.cache.get(&test_endpoint("account1")).await.is_none());
        assert!(cache.cache.get(&test_endpoint("account2")).await.is_none());
    }

    #[test]
    fn deserialize_full_account_payload() {
        let json = r#"{
            "_self": "",
            "id": "testaccount",
            "_rid": "testaccount.documents.azure.com",
            "media": "//media/",
            "addresses": "//addresses/",
            "_dbs": "//dbs/",
            "writableLocations": [
                { "name": "West US 2", "databaseAccountEndpoint": "https://test-westus2.documents.azure.com:443/" }
            ],
            "readableLocations": [
                { "name": "West US 2", "databaseAccountEndpoint": "https://test-westus2.documents.azure.com:443/" },
                { "name": "East US 2", "databaseAccountEndpoint": "https://test-eastus2.documents.azure.com:443/" }
            ],
            "enableMultipleWriteLocations": false,
            "continuousBackupEnabled": false,
            "enableNRegionSynchronousCommit": false,
            "enablePerPartitionFailoverBehavior": false,
            "userReplicationPolicy": { "minReplicaSetSize": 3, "maxReplicasetSize": 4 },
            "userConsistencyPolicy": { "defaultConsistencyLevel": "Session" },
            "systemReplicationPolicy": { "minReplicaSetSize": 3, "maxReplicasetSize": 4 },
            "readPolicy": { "primaryReadCoefficient": 1, "secondaryReadCoefficient": 1 },
            "queryEngineConfiguration": "{\"allowNewKeywords\":true}"
        }"#;
        let props: AccountProperties = serde_json::from_str(json).expect("deserialize");
        assert_eq!(props.id, "testaccount");
        assert_eq!(props.write_region().unwrap().as_str(), "westus2");
        assert_eq!(props.readable_regions().len(), 2);
        assert_eq!(props.writable_locations.len(), 1);
        assert_eq!(props.readable_locations.len(), 2);
        assert_eq!(props.user_replication_policy.min_replica_set_size, 3);
        assert_eq!(
            props.user_consistency_policy.default_consistency_level,
            DefaultConsistencyLevel::Session
        );
        assert!(!props.enable_multiple_write_locations);
    }

    #[test]
    fn write_region_is_none_when_empty() {
        let props: AccountProperties = serde_json::from_value(serde_json::json!({
            "_self": "",
            "id": "",
            "_rid": "",
            "media": "",
            "addresses": "",
            "_dbs": "",
            "writableLocations": [],
            "readableLocations": [],
            "enableMultipleWriteLocations": false,
            "userReplicationPolicy": { "minReplicaSetSize": 0, "maxReplicasetSize": 0 },
            "userConsistencyPolicy": { "defaultConsistencyLevel": "Session" },
            "systemReplicationPolicy": { "minReplicaSetSize": 0, "maxReplicasetSize": 0 },
            "readPolicy": { "primaryReadCoefficient": 0, "secondaryReadCoefficient": 0 },
            "queryEngineConfiguration": "{}"
        }))
        .unwrap();
        assert!(props.write_region().is_none());
        assert!(props.readable_regions().is_empty());
    }

    #[tokio::test]
    async fn refresh_if_stale_returns_cached_value_when_fresh() {
        let cache = AccountMetadataCache::new();
        let endpoint = test_endpoint("myaccount");
        cache
            .get_or_fetch(endpoint.clone(), || async { Ok(test_properties("westus")) })
            .await
            .unwrap();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let result = cache
            .refresh_if_stale(endpoint, || async move {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok(test_properties("eastus"))
            })
            .await
            .unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().write_region().unwrap().as_str(), "westus");
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn refresh_if_stale_refreshes_when_threshold_exceeded() {
        let cache = cache_with_threshold(Duration::ZERO);
        let endpoint = test_endpoint("myaccount");
        cache
            .get_or_fetch(endpoint.clone(), || async { Ok(test_properties("westus")) })
            .await
            .unwrap();
        let result = cache
            .refresh_if_stale(endpoint, || async { Ok(test_properties("eastus")) })
            .await
            .unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().write_region().unwrap().as_str(), "eastus");
    }

    #[tokio::test]
    async fn refresh_if_stale_returns_cached_on_fetch_failure() {
        let cache = cache_with_threshold(Duration::ZERO);
        let endpoint = test_endpoint("myaccount");
        cache
            .get_or_fetch(endpoint.clone(), || async { Ok(test_properties("westus")) })
            .await
            .unwrap();
        let result = cache
            .refresh_if_stale(endpoint, || async { Err(network_failure()) })
            .await
            .unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().write_region().unwrap().as_str(), "westus");
    }

    #[tokio::test]
    async fn refresh_if_stale_propagates_error_when_no_cached_value() {
        let cache = cache_with_threshold(Duration::ZERO);
        let endpoint = test_endpoint("myaccount");
        let result = cache
            .refresh_if_stale(endpoint, || async { Err(network_failure()) })
            .await;
        assert!(result.is_err());
    }
}