use super::AsyncCache;
use crate::models::ContainerReference;
use std::sync::Arc;
/// Key for the name-based index: a container addressed by account endpoint,
/// database name, and container name.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ContainerNameKey {
    /// Account endpoint, held as an owned string.
    account_endpoint: String,
    /// Name of the database that holds the container.
    db_name: String,
    /// Name of the container itself.
    container_name: String,
}
impl ContainerNameKey {
    /// Derives the name-based cache key for `c`.
    ///
    /// Copies the account endpoint string, the database name, and the
    /// container name out of the reference into owned values.
    fn from_container(c: &ContainerReference) -> Self {
        let account_endpoint = c.account().endpoint().as_str().to_owned();
        let db_name = c.database_name().to_owned();
        let container_name = c.name().to_owned();
        Self {
            account_endpoint,
            db_name,
            container_name,
        }
    }
}
/// Key for the RID-based index: a container addressed by account endpoint
/// and container RID.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ContainerRidKey {
    /// Account endpoint, held as an owned string.
    account_endpoint: String,
    /// RID of the container.
    container_rid: String,
}
impl ContainerRidKey {
    /// Derives the RID-based cache key for `c`.
    ///
    /// Copies the account endpoint string and the container RID out of the
    /// reference into owned values.
    fn from_container(c: &ContainerReference) -> Self {
        let account_endpoint = c.account().endpoint().as_str().to_owned();
        let container_rid = c.rid().to_owned();
        Self {
            account_endpoint,
            container_rid,
        }
    }
}
/// Cache of container references with two indexes over the same containers:
/// one keyed by name and one keyed by RID.
///
/// Both indexes store `azure_core::Result<ContainerReference>` values, so an
/// entry can hold either a resolved container or the error from resolving it.
#[derive(Debug)]
pub(crate) struct ContainerCache {
    /// Index keyed by account endpoint, database name, and container name.
    by_name: AsyncCache<ContainerNameKey, azure_core::Result<ContainerReference>>,
    /// Index keyed by account endpoint and container RID.
    by_rid: AsyncCache<ContainerRidKey, azure_core::Result<ContainerReference>>,
}
impl ContainerCache {
pub(crate) fn new() -> Self {
Self {
by_name: AsyncCache::new(),
by_rid: AsyncCache::new(),
}
}
pub(crate) async fn get_or_fetch_by_name<F, Fut>(
&self,
account_endpoint: &str,
db_name: &str,
container_name: &str,
fetch_fn: F,
) -> azure_core::Result<Arc<ContainerReference>>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = azure_core::Result<ContainerReference>>,
{
let key = ContainerNameKey {
account_endpoint: account_endpoint.to_owned(),
db_name: db_name.to_owned(),
container_name: container_name.to_owned(),
};
self.get_or_fetch_impl(&self.by_name, key, fetch_fn).await
}
pub(crate) async fn get_or_fetch_by_rid<F, Fut>(
&self,
account_endpoint: &str,
container_rid: &str,
fetch_fn: F,
) -> azure_core::Result<Arc<ContainerReference>>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = azure_core::Result<ContainerReference>>,
{
let key = ContainerRidKey {
account_endpoint: account_endpoint.to_owned(),
container_rid: container_rid.to_owned(),
};
self.get_or_fetch_impl(&self.by_rid, key, fetch_fn).await
}
#[allow(dead_code)] pub(crate) async fn get_by_name(
&self,
account_endpoint: &str,
db_name: &str,
container_name: &str,
) -> Option<Arc<ContainerReference>> {
let key = ContainerNameKey {
account_endpoint: account_endpoint.to_owned(),
db_name: db_name.to_owned(),
container_name: container_name.to_owned(),
};
self.get_from(&self.by_name, &key).await
}
#[allow(dead_code)] pub(crate) async fn get_by_rid(
&self,
account_endpoint: &str,
container_rid: &str,
) -> Option<Arc<ContainerReference>> {
let key = ContainerRidKey {
account_endpoint: account_endpoint.to_owned(),
container_rid: container_rid.to_owned(),
};
self.get_from(&self.by_rid, &key).await
}
async fn get_or_fetch_impl<K, F, Fut>(
&self,
cache: &AsyncCache<K, azure_core::Result<ContainerReference>>,
key: K,
fetch_fn: F,
) -> azure_core::Result<Arc<ContainerReference>>
where
K: Eq + std::hash::Hash + Clone,
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = azure_core::Result<ContainerReference>>,
{
if let Some(cached) = self.get_from(cache, &key).await {
return Ok(cached);
}
let resolved = cache.get_or_insert_with(key.clone(), fetch_fn).await;
match resolved.as_ref() {
Ok(container) => {
self.put(container.clone()).await;
Ok(Arc::new(container.clone()))
}
Err(error) => {
cache.invalidate(&key).await;
Err(azure_core::Error::with_message(
error.kind().clone(),
crate::driver::error_chain_summary(error),
))
}
}
}
async fn get_from<K>(
&self,
cache: &AsyncCache<K, azure_core::Result<ContainerReference>>,
key: &K,
) -> Option<Arc<ContainerReference>>
where
K: Eq + std::hash::Hash + Clone,
{
cache
.get(key)
.await
.and_then(|entry| entry.as_ref().as_ref().ok().map(|c| Arc::new(c.clone())))
}
pub(crate) async fn put(&self, container: ContainerReference) {
let name_key = ContainerNameKey::from_container(&container);
let rid_key = ContainerRidKey::from_container(&container);
let container_for_rid = container.clone();
self.by_name
.get_or_insert_with(name_key, || async { Ok(container) })
.await;
self.by_rid
.get_or_insert_with(rid_key, || async { Ok(container_for_rid) })
.await;
}
}
impl Default for ContainerCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{
        AccountReference, ContainerProperties, ContainerReference, PartitionKeyDefinition,
        SystemProperties,
    };
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use url::Url;

    /// Account reference shared by every container built in these tests.
    fn test_account() -> AccountReference {
        AccountReference::with_master_key(
            Url::parse("https://myaccount.documents.azure.com:443/").unwrap(),
            "test-key",
        )
    }

    /// Endpoint string used for lookups against containers from `test_account`.
    // NOTE(review): the URL above is parsed with an explicit `:443` while this
    // constant omits it; this relies on `Url` dropping the scheme's default
    // port when serialized.
    const ACCOUNT_ENDPOINT: &str = "https://myaccount.documents.azure.com/";

    /// Deserializes a partition key definition with a single path.
    fn test_partition_key_definition(path: &str) -> PartitionKeyDefinition {
        serde_json::from_str(&format!(r#"{{"paths":["{path}"]}}"#)).unwrap()
    }

    /// Container properties with a fixed id and a `/pk` partition key.
    fn test_container_props() -> ContainerProperties {
        ContainerProperties {
            id: "testcontainer".into(),
            partition_key: test_partition_key_definition("/pk"),
            system_properties: SystemProperties::default(),
        }
    }

    /// Builds a container reference whose database RID is `"{db}_rid"` and
    /// whose container RID is `"{db}_{container}_rid"`.
    fn test_container(db: &str, container: &str) -> ContainerReference {
        ContainerReference::new(
            test_account(),
            db.to_owned(),
            format!("{db}_rid"),
            container.to_owned(),
            format!("{db}_{container}_rid"),
            &test_container_props(),
        )
    }

    #[tokio::test]
    async fn fetch_by_name_caches_and_cross_populates_rid() {
        let cache = ContainerCache::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let container = test_container("mydb", "mycoll");

        let container_clone = container.clone();
        let counter_clone = counter.clone();
        let resolved = cache
            .get_or_fetch_by_name(ACCOUNT_ENDPOINT, "mydb", "mycoll", || async move {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok(container_clone)
            })
            .await
            .unwrap();
        assert_eq!(resolved.name(), "mycoll");
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let by_name = cache
            .get_by_name(ACCOUNT_ENDPOINT, "mydb", "mycoll")
            .await
            .expect("name index should hold the fetched container");
        assert_eq!(by_name.name(), "mycoll");

        let by_rid = cache
            .get_by_rid(ACCOUNT_ENDPOINT, container.rid())
            .await
            .expect("rid index should be cross-populated by a name fetch");
        assert_eq!(by_rid.name(), "mycoll");
    }

    /// Two sequential fetches for the same name must invoke the fetch
    /// function only once; the second call is served from the cache.
    #[tokio::test]
    async fn fetch_by_name_deduplicates() {
        let cache = ContainerCache::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let container = test_container("mydb", "mycoll");

        let c1 = container.clone();
        let counter1 = counter.clone();
        cache
            .get_or_fetch_by_name(ACCOUNT_ENDPOINT, "mydb", "mycoll", || async move {
                counter1.fetch_add(1, Ordering::SeqCst);
                Ok(c1)
            })
            .await
            .unwrap();

        let c2 = container.clone();
        let counter2 = counter.clone();
        let resolved = cache
            .get_or_fetch_by_name(ACCOUNT_ENDPOINT, "mydb", "mycoll", || async move {
                counter2.fetch_add(1, Ordering::SeqCst);
                Ok(c2)
            })
            .await
            .unwrap();

        assert_eq!(resolved.name(), "mycoll");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn fetch_by_rid_caches_and_cross_populates_name() {
        let cache = ContainerCache::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let container = test_container("mydb", "mycoll");
        let container_rid = container.rid().to_owned();

        let container_clone = container.clone();
        let counter_clone = counter.clone();
        let resolved = cache
            .get_or_fetch_by_rid(ACCOUNT_ENDPOINT, &container_rid, || async move {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok(container_clone)
            })
            .await
            .unwrap();
        assert_eq!(resolved.name(), "mycoll");
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let by_name = cache
            .get_by_name(ACCOUNT_ENDPOINT, "mydb", "mycoll")
            .await
            .expect("name index should be cross-populated by a rid fetch");
        assert_eq!(by_name.rid(), container_rid);

        // Check the content of the RID entry, not just its presence.
        let by_rid = cache
            .get_by_rid(ACCOUNT_ENDPOINT, &container_rid)
            .await
            .expect("rid index should hold the fetched container");
        assert_eq!(by_rid.name(), "mycoll");
    }

    /// A failing fetch must surface as an error, must not be observable
    /// through the lookup helpers, and must not block a later retry.
    #[tokio::test]
    async fn failed_fetch_is_not_cached() {
        let cache = ContainerCache::new();
        let counter = Arc::new(AtomicUsize::new(0));

        let failing_counter = counter.clone();
        let failed = cache
            .get_or_fetch_by_name(ACCOUNT_ENDPOINT, "mydb", "mycoll", || async move {
                failing_counter.fetch_add(1, Ordering::SeqCst);
                Err::<ContainerReference, _>(azure_core::Error::with_message(
                    azure_core::error::ErrorKind::Other,
                    "fetch failed".to_owned(),
                ))
            })
            .await;
        assert!(failed.is_err());
        assert!(cache
            .get_by_name(ACCOUNT_ENDPOINT, "mydb", "mycoll")
            .await
            .is_none());

        let container = test_container("mydb", "mycoll");
        let retry_counter = counter.clone();
        let resolved = cache
            .get_or_fetch_by_name(ACCOUNT_ENDPOINT, "mydb", "mycoll", || async move {
                retry_counter.fetch_add(1, Ordering::SeqCst);
                Ok(container)
            })
            .await
            .unwrap();
        assert_eq!(resolved.name(), "mycoll");
        // The retry must have run its own fetch instead of reusing the failure.
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn put_populates_both_caches() {
        let cache = ContainerCache::new();
        let container = test_container("mydb", "mycoll");
        let rid = container.rid().to_owned();

        cache.put(container).await;

        assert!(cache
            .get_by_name(ACCOUNT_ENDPOINT, "mydb", "mycoll")
            .await
            .is_some());
        assert!(cache.get_by_rid(ACCOUNT_ENDPOINT, &rid).await.is_some());
    }

    #[tokio::test]
    async fn different_containers_cached_separately() {
        let cache = ContainerCache::new();
        let c1 = test_container("db1", "coll1");
        let c2 = test_container("db1", "coll2");

        cache.put(c1).await;
        cache.put(c2).await;

        let r1 = cache
            .get_by_name(ACCOUNT_ENDPOINT, "db1", "coll1")
            .await
            .unwrap();
        let r2 = cache
            .get_by_name(ACCOUNT_ENDPOINT, "db1", "coll2")
            .await
            .unwrap();
        assert_eq!(r1.name(), "coll1");
        assert_eq!(r2.name(), "coll2");
    }

    #[tokio::test]
    async fn same_container_different_databases() {
        let cache = ContainerCache::new();
        let c1 = test_container("db1", "coll");
        let c2 = test_container("db2", "coll");

        cache.put(c1).await;
        cache.put(c2).await;

        let r1 = cache
            .get_by_name(ACCOUNT_ENDPOINT, "db1", "coll")
            .await
            .unwrap();
        let r2 = cache
            .get_by_name(ACCOUNT_ENDPOINT, "db2", "coll")
            .await
            .unwrap();
        assert_eq!(r1.database_name(), "db1");
        assert_eq!(r2.database_name(), "db2");
    }

    #[tokio::test]
    async fn get_by_name_returns_none_before_fetch() {
        let cache = ContainerCache::new();
        assert!(cache
            .get_by_name(ACCOUNT_ENDPOINT, "db", "unknown")
            .await
            .is_none());
    }

    #[tokio::test]
    async fn get_by_rid_returns_none_before_fetch() {
        let cache = ContainerCache::new();
        assert!(cache
            .get_by_rid(ACCOUNT_ENDPOINT, "unknown_rid")
            .await
            .is_none());
    }

    #[tokio::test]
    async fn clear_removes_all() {
        let cache = ContainerCache::new();
        let c1 = test_container("db", "coll1");
        let c2 = test_container("db", "coll2");
        let rid1 = c1.rid().to_owned();
        let rid2 = c2.rid().to_owned();

        cache.put(c1).await;
        cache.put(c2).await;

        // Confirm the entries exist first, so the post-clear assertions
        // cannot pass vacuously.
        assert!(cache
            .get_by_name(ACCOUNT_ENDPOINT, "db", "coll1")
            .await
            .is_some());
        assert!(cache
            .get_by_name(ACCOUNT_ENDPOINT, "db", "coll2")
            .await
            .is_some());
        assert!(cache.get_by_rid(ACCOUNT_ENDPOINT, &rid1).await.is_some());
        assert!(cache.get_by_rid(ACCOUNT_ENDPOINT, &rid2).await.is_some());

        cache.by_name.clear().await;
        cache.by_rid.clear().await;

        assert!(cache
            .get_by_name(ACCOUNT_ENDPOINT, "db", "coll1")
            .await
            .is_none());
        assert!(cache
            .get_by_name(ACCOUNT_ENDPOINT, "db", "coll2")
            .await
            .is_none());
        // The RID index must be emptied as well.
        assert!(cache.get_by_rid(ACCOUNT_ENDPOINT, &rid1).await.is_none());
        assert!(cache.get_by_rid(ACCOUNT_ENDPOINT, &rid2).await.is_none());
    }
}