akv_cli/
cache.rs

1// Copyright 2024 Heath Stewart.
2// Licensed under the MIT License. See LICENSE.txt in the project root for license information.
3
4use std::{collections::HashMap, sync::Arc};
5use tokio::sync::Mutex;
6use url::Url;
7
/// A shareable cache of clients keyed by endpoint URL.
///
/// Each client is stored behind an [`Arc`], and the map itself lives behind an
/// `Arc<Mutex<..>>`, so clones of the cache share the same entries.
pub struct ClientCache<T> {
    // Mutex should be fast enough for our needs of a CLI.
    cache: Arc<Mutex<HashMap<String, Arc<T>>>>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add a
// `T: Default` bound, but an empty cache never needs to construct a `T`.
impl<T> Default for ClientCache<T> {
    /// Creates an empty cache.
    fn default() -> Self {
        Self {
            cache: Default::default(),
        }
    }
}
13
14impl<T> ClientCache<T> {
15    pub fn new() -> Self {
16        Self {
17            cache: Default::default(),
18        }
19    }
20
21    pub async fn get<F>(&self, endpoint: impl AsRef<str>, f: F) -> crate::Result<Arc<T>>
22    where
23        F: FnOnce(&str) -> azure_core::Result<T>,
24    {
25        // Canonicalize the URL.
26        let endpoint = Url::parse(endpoint.as_ref())?.to_string();
27        let mut cache = self.cache.lock().await;
28        if let Some(c) = cache.get(&endpoint) {
29            tracing::debug!(target: "akv::cache", "found cached client for '{vault}'", vault = &endpoint);
30            return Ok(c.clone());
31        };
32
33        let client = Arc::new(f(&endpoint)?);
34
35        tracing::debug!(target: "akv::cache", "caching new client for '{vault}'", vault = &endpoint,);
36        cache.insert(endpoint, client.clone());
37        Ok(client)
38    }
39}
40
41impl<T> Clone for ClientCache<T> {
42    fn clone(&self) -> Self {
43        Self {
44            cache: self.cache.clone(),
45        }
46    }
47}
#[cfg(test)]
mod tests {
    use super::*;
    use azure_identity::DefaultAzureCredential;
    use azure_security_keyvault_secrets::SecretClient;

    /// Two distinct vaults produce two entries, and a URL that differs only by
    /// a trailing slash resolves to the already-cached client.
    #[tokio::test]
    async fn test_client_cache() {
        let credential = DefaultAzureCredential::new().unwrap();

        let cache = ClientCache::<SecretClient>::new();
        let first = cache
            .get("https://vault1.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add first client");
        let second = cache
            .get("https://vault2.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add second client");
        // Same vault as the first lookup, spelled with a trailing slash.
        let first_again = cache
            .get("https://vault1.vault.azure.net/", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add first client again");

        assert_eq!(cache.cache.lock().await.len(), 2);
        // The repeated lookup must return the cached instance, not a new one.
        assert!(Arc::ptr_eq(&first, &first_again));
        assert!(!Arc::ptr_eq(&first, &second));
    }
}