1use std::{collections::HashMap, sync::Arc};
5use tokio::sync::Mutex;
6use url::Url;
7
8#[derive(Default)]
9pub struct ClientCache<T> {
10 cache: Arc<Mutex<HashMap<String, Arc<T>>>>,
12}
13
14impl<T> ClientCache<T> {
15 pub fn new() -> Self {
16 Self {
17 cache: Default::default(),
18 }
19 }
20
21 pub async fn get<F>(&self, endpoint: impl AsRef<str>, f: F) -> crate::Result<Arc<T>>
22 where
23 F: FnOnce(&str) -> azure_core::Result<T>,
24 {
25 let endpoint = Url::parse(endpoint.as_ref())?.to_string();
27 let mut cache = self.cache.lock().await;
28 if let Some(c) = cache.get(&endpoint) {
29 tracing::debug!(target: "akv::cache", "found cached client for '{vault}'", vault = &endpoint);
30 return Ok(c.clone());
31 };
32
33 let client = Arc::new(f(&endpoint)?);
34
35 tracing::debug!(target: "akv::cache", "caching new client for '{vault}'", vault = &endpoint,);
36 cache.insert(endpoint, client.clone());
37 Ok(client)
38 }
39}
40
41impl<T> Clone for ClientCache<T> {
42 fn clone(&self) -> Self {
43 Self {
44 cache: self.cache.clone(),
45 }
46 }
47}
#[cfg(test)]
mod tests {
    use super::*;
    use azure_identity::DefaultAzureCredential;
    use azure_security_keyvault_secrets::SecretClient;

    /// Distinct endpoints get distinct entries. An endpoint that differs only
    /// by a trailing `/` hits the existing entry, returns the same `Arc`, and
    /// does not invoke the factory.
    #[tokio::test]
    async fn test_client_cache() {
        let credential = DefaultAzureCredential::new().unwrap();

        let cache = ClientCache::<SecretClient>::new();
        let first = cache
            .get("https://vault1.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add first client");
        cache
            .get("https://vault2.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add second client");
        let again = cache
            .get("https://vault1.vault.azure.net/", |_| {
                panic!("factory must not run on a cache hit")
            })
            .await
            .expect("fetch cached first client");

        // Both spellings of vault1 normalize to the same key and yield the same instance.
        assert!(Arc::ptr_eq(&first, &again));
        assert_eq!(cache.cache.lock().await.len(), 2);
    }
}