1use azure_security_keyvault_keys::KeyClient;
5use azure_security_keyvault_secrets::SecretClient;
6use std::{collections::HashMap, sync::Arc};
7use tokio::sync::Mutex;
8use url::Url;
9
/// A shared cache of Key Vault clients, keyed by the endpoint URL string
/// produced by `Url::parse(..).to_string()` in [`ClientCache::get`].
///
/// Cloning a `ClientCache` copies the inner `Arc`, so clones share one map.
pub struct ClientCache<T> {
    cache: Arc<Mutex<HashMap<String, Arc<T>>>>,
}

// Hand-written rather than derived: `#[derive(Default)]` would add a
// `T: Default` bound, which client types such as `SecretClient` do not
// satisfy. An empty map needs no bound on `T`.
impl<T> Default for ClientCache<T> {
    fn default() -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
        }
    }
}
15
16impl<T: TypeName> ClientCache<T> {
17 pub fn new() -> Self {
18 Self {
19 cache: Default::default(),
20 }
21 }
22
23 pub async fn get<F>(&self, endpoint: impl AsRef<str>, f: F) -> crate::Result<Arc<T>>
24 where
25 F: FnOnce(&str) -> azure_core::Result<T>,
26 {
27 let endpoint = Url::parse(endpoint.as_ref())?.to_string();
29 let mut cache = self.cache.lock().await;
30 if let Some(c) = cache.get(&endpoint) {
31 tracing::debug!(target: "akv::cache", "found cached {client} for '{vault}'", client = T::type_name(), vault = &endpoint);
32 return Ok(c.clone());
33 };
34
35 let client = Arc::new(f(&endpoint)?);
36
37 tracing::debug!(target: "akv::cache", "caching new {client} for '{vault}'", client = T::type_name(), vault = &endpoint,);
38 cache.insert(endpoint, client.clone());
39 Ok(client)
40 }
41}
42
43impl<T> Clone for ClientCache<T> {
44 fn clone(&self) -> Self {
45 Self {
46 cache: self.cache.clone(),
47 }
48 }
49}
50
/// Supplies a short, human-readable label for a client type.
///
/// [`ClientCache`] includes this label in its `akv::cache` debug log messages.
pub trait TypeName {
    /// Returns the static display name of the implementing type.
    fn type_name() -> &'static str;
}
54
55impl TypeName for KeyClient {
56 fn type_name() -> &'static str {
57 "KeyClient"
58 }
59}
60
61impl TypeName for SecretClient {
62 fn type_name() -> &'static str {
63 "SecretClient"
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use azure_identity::DefaultAzureCredential;
    use azure_security_keyvault_secrets::SecretClient;

    /// Distinct vaults get distinct entries. An endpoint that differs only by a
    /// trailing slash normalizes to the existing key, and on that hit the same
    /// client is returned without invoking the factory.
    #[tokio::test]
    async fn test_client_cache() {
        let credential = DefaultAzureCredential::new().unwrap();

        let cache = ClientCache::<SecretClient>::new();
        let first = cache
            .get("https://vault1.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add first client");
        cache
            .get("https://vault2.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add second client");
        // Same vault as `first` after URL normalization: must be a cache hit,
        // so the factory must never run.
        let again = cache
            .get(
                "https://vault1.vault.azure.net/",
                |_| -> azure_core::Result<SecretClient> {
                    panic!("factory called for an endpoint that should already be cached")
                },
            )
            .await
            .expect("reuse cached first client");

        assert!(Arc::ptr_eq(&first, &again));
        assert_eq!(cache.cache.lock().await.len(), 2);
    }
}