1use azure_security_keyvault_keys::KeyClient;
7use azure_security_keyvault_secrets::SecretClient;
8use std::{collections::HashMap, sync::Arc};
9use tokio::sync::Mutex;
10use url::Url;
11
12#[derive(Default)]
14pub struct ClientCache<T> {
15 cache: Arc<Mutex<HashMap<String, Arc<T>>>>,
17}
18
19impl<T: TypeName> ClientCache<T> {
20 pub fn new() -> Self {
22 Self {
23 cache: Default::default(),
24 }
25 }
26
27 pub async fn get<F>(&self, endpoint: impl AsRef<str>, f: F) -> crate::Result<Arc<T>>
29 where
30 F: FnOnce(&str) -> azure_core::Result<T>,
31 {
32 let endpoint = Url::parse(endpoint.as_ref())?.to_string();
34 let mut cache = self.cache.lock().await;
35 if let Some(c) = cache.get(&endpoint) {
36 tracing::debug!(target: "akv::cache", "found cached {client} for '{vault}'", client = T::type_name(), vault = &endpoint);
37 return Ok(c.clone());
38 };
39
40 let client = Arc::new(f(&endpoint)?);
41
42 tracing::debug!(target: "akv::cache", "caching new {client} for '{vault}'", client = T::type_name(), vault = &endpoint,);
43 cache.insert(endpoint, client.clone());
44 Ok(client)
45 }
46}
47
48impl<T> Clone for ClientCache<T> {
49 fn clone(&self) -> Self {
50 Self {
51 cache: self.cache.clone(),
52 }
53 }
54}
55
/// A short, human-readable name for a type.
///
/// `ClientCache` uses it to label the kind of client in its debug log messages.
pub trait TypeName {
    /// Returns the display name of the implementing type.
    fn type_name() -> &'static str;
}
61
62impl TypeName for KeyClient {
63 fn type_name() -> &'static str {
64 "KeyClient"
65 }
66}
67
68impl TypeName for SecretClient {
69 fn type_name() -> &'static str {
70 "SecretClient"
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use azure_identity::AzureDeveloperCliCredential;
    use azure_security_keyvault_secrets::SecretClient;

    /// Distinct endpoints get distinct entries; an equivalent spelling of an existing endpoint
    /// reuses the already-cached client instance.
    #[tokio::test]
    async fn test_client_cache() {
        let credential = AzureDeveloperCliCredential::new(None).unwrap();

        let cache = ClientCache::<SecretClient>::new();
        let first = cache
            .get("https://vault1.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add first client");
        cache
            .get("https://vault2.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add second client");
        // This differs from the first endpoint only by a trailing slash. URL normalization must
        // map it onto the existing entry instead of constructing a new client.
        let again = cache
            .get("https://vault1.vault.azure.net/", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add first client again");

        assert!(
            Arc::ptr_eq(&first, &again),
            "expected the cached client instance to be reused"
        );
        assert_eq!(cache.cache.lock().await.len(), 2);
    }

    /// An endpoint that is not a URL is rejected before the factory runs, and nothing is cached.
    #[tokio::test]
    async fn test_client_cache_rejects_invalid_endpoint() {
        let cache = ClientCache::<SecretClient>::new();
        let result = cache
            .get("not a url", |_| -> azure_core::Result<SecretClient> {
                unreachable!("factory must not run for an unparsable endpoint")
            })
            .await;

        assert!(result.is_err());
        assert!(cache.cache.lock().await.is_empty());
    }
}