akv_cli/
cache.rs

1// Copyright 2024 Heath Stewart.
2// Licensed under the MIT License. See LICENSE.txt in the project root for license information.
3
4use crate::ErrorKind;
5use azure_security_keyvault_secrets::SecretClient;
6use std::{collections::HashMap, sync::Arc};
7use tokio::sync::Mutex;
8
9#[derive(Clone, Default)]
10pub struct ClientCache {
11    // Mutex should be fast enough for our needs of a CLI.
12    cache: Arc<Mutex<HashMap<String, Arc<SecretClient>>>>,
13}
14
15impl ClientCache {
16    pub fn new() -> Self {
17        Default::default()
18    }
19
20    pub async fn get(&mut self, client: Arc<SecretClient>) -> crate::Result<Arc<SecretClient>> {
21        let endpoint = client
22            .endpoint()
23            .host_str()
24            .ok_or_else(|| {
25                crate::Error::with_message(ErrorKind::InvalidData, "no host for SecretClient")
26            })?
27            .to_string();
28
29        let mut cache = self.cache.lock().await;
30        if let Some(c) = cache.get(&endpoint) {
31            tracing::debug!(
32                "found cached client for '{vault}'",
33                vault = c.endpoint().as_str()
34            );
35            return Ok(c.clone());
36        };
37
38        tracing::debug!("caching new client for '{vault}'", vault = &endpoint,);
39        cache.insert(endpoint, client.clone());
40        Ok(client)
41    }
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use azure_identity::DefaultAzureCredential;

    #[tokio::test]
    async fn test_client_cache() {
        let credential = DefaultAzureCredential::new().unwrap();
        // Builds a fresh client for `endpoint`; every call yields a distinct `Arc`.
        let new_client = |endpoint: &str| {
            Arc::new(
                SecretClient::new(endpoint, credential.clone(), None)
                    .expect("valid vault endpoint"),
            )
        };

        let mut cache = ClientCache::new();
        let first = cache
            .get(new_client("https://vault1.vault.azure.net"))
            .await
            .expect("add first client");
        cache
            .get(new_client("https://vault2.vault.azure.net"))
            .await
            .expect("add second client");
        let again = cache
            .get(new_client("https://vault1.vault.azure.net/"))
            .await
            .expect("add first client again");

        // A trailing slash must not create a new entry: the originally cached client comes back.
        assert!(Arc::ptr_eq(&first, &again));
        assert_eq!(cache.cache.lock().await.len(), 2);
    }
}
78}