1use crate::ErrorKind;
5use azure_security_keyvault_secrets::SecretClient;
6use std::{collections::HashMap, sync::Arc};
7use tokio::sync::Mutex;
8
9#[derive(Clone, Default)]
10pub struct ClientCache {
11 cache: Arc<Mutex<HashMap<String, Arc<SecretClient>>>>,
13}
14
15impl ClientCache {
16 pub fn new() -> Self {
17 Default::default()
18 }
19
20 pub async fn get(&mut self, client: Arc<SecretClient>) -> crate::Result<Arc<SecretClient>> {
21 let endpoint = client
22 .endpoint()
23 .host_str()
24 .ok_or_else(|| {
25 crate::Error::with_message(ErrorKind::InvalidData, "no host for SecretClient")
26 })?
27 .to_string();
28
29 let mut cache = self.cache.lock().await;
30 if let Some(c) = cache.get(&endpoint) {
31 tracing::debug!(
32 "found cached client for '{vault}'",
33 vault = c.endpoint().as_str()
34 );
35 return Ok(c.clone());
36 };
37
38 tracing::debug!("caching new client for '{vault}'", vault = &endpoint,);
39 cache.insert(endpoint, client.clone());
40 Ok(client)
41 }
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use azure_identity::DefaultAzureCredential;

    #[tokio::test]
    async fn test_client_cache() {
        let credential = DefaultAzureCredential::new().unwrap();
        let new_client = |endpoint: &str| {
            Arc::new(SecretClient::new(endpoint, credential.clone(), None).unwrap())
        };

        let cache = &mut ClientCache::new();
        let first = cache
            .get(new_client("https://vault1.vault.azure.net"))
            .await
            .expect("add first client");
        cache
            .get(new_client("https://vault2.vault.azure.net"))
            .await
            .expect("add second client");
        // Same host as the first client (differs only by a trailing slash), so the
        // previously cached instance must be handed back rather than the new one.
        let again = cache
            .get(new_client("https://vault1.vault.azure.net/"))
            .await
            .expect("add first client again");

        assert!(Arc::ptr_eq(&first, &again));
        assert_eq!(cache.cache.lock().await.len(), 2);
    }
}