use azure_security_keyvault_keys::KeyClient;
use azure_security_keyvault_secrets::SecretClient;
use std::{
    collections::{hash_map::Entry, HashMap},
    sync::Arc,
};
use tokio::sync::Mutex;
use url::Url;
/// A cache of service clients keyed by normalized endpoint URL.
///
/// The map lives behind an `Arc<Mutex<..>>`, so handles produced by `Clone` share the
/// same underlying map: a client cached through one handle is visible through all of them.
pub struct ClientCache<T> {
    cache: Arc<Mutex<HashMap<String, Arc<T>>>>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add a `T: Default`
// bound, which an empty cache does not need, and which client types are not required to
// satisfy.
impl<T> Default for ClientCache<T> {
    fn default() -> Self {
        Self {
            cache: Default::default(),
        }
    }
}
impl<T: TypeName> ClientCache<T> {
pub fn new() -> Self {
Self {
cache: Default::default(),
}
}
pub async fn get<F>(&self, endpoint: impl AsRef<str>, f: F) -> crate::Result<Arc<T>>
where
F: FnOnce(&str) -> azure_core::Result<T>,
{
let endpoint = Url::parse(endpoint.as_ref())?.to_string();
let mut cache = self.cache.lock().await;
if let Some(c) = cache.get(&endpoint) {
tracing::debug!(target: "akv::cache", "found cached {client} for '{vault}'", client = T::type_name(), vault = &endpoint);
return Ok(c.clone());
};
let client = Arc::new(f(&endpoint)?);
tracing::debug!(target: "akv::cache", "caching new {client} for '{vault}'", client = T::type_name(), vault = &endpoint,);
cache.insert(endpoint, client.clone());
Ok(client)
}
}
impl<T> Clone for ClientCache<T> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
}
}
}
/// Supplies a short, human-readable name for a client type; `ClientCache::get` uses it
/// in its debug log messages.
pub trait TypeName {
    /// Returns the display name of the implementing type.
    fn type_name() -> &'static str;
}
impl TypeName for KeyClient {
    /// Returns the literal identifier `KeyClient`.
    fn type_name() -> &'static str {
        stringify!(KeyClient)
    }
}
impl TypeName for SecretClient {
    /// Returns the literal identifier `SecretClient`.
    fn type_name() -> &'static str {
        stringify!(SecretClient)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use azure_identity::AzureDeveloperCliCredential;
    use azure_security_keyvault_secrets::SecretClient;

    /// Clients are cached per normalized endpoint: distinct vaults get distinct entries,
    /// and a trailing-slash spelling of an already cached vault reuses its client.
    #[tokio::test]
    async fn test_client_cache() {
        let credential = AzureDeveloperCliCredential::new(None).unwrap();
        let cache = ClientCache::<SecretClient>::new();

        let first = cache
            .get("https://vault1.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add first client");
        cache
            .get("https://vault2.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add second client");
        // Same vault as `first` once normalized, so the factory must not run again.
        let again = cache
            .get("https://vault1.vault.azure.net/", |_| {
                unreachable!("vault1 should already be cached")
            })
            .await
            .expect("add first client again");

        assert!(Arc::ptr_eq(&first, &again), "expected the cached vault1 client");
        assert_eq!(cache.cache.lock().await.len(), 2);
    }
}