Skip to main content

akv_cli/
cache.rs

1// Copyright 2024 Heath Stewart.
2// Licensed under the MIT License. See LICENSE.txt in the project root for license information.
3
4//! Cache clients based on the vault URL.
5
6use azure_security_keyvault_keys::KeyClient;
7use azure_security_keyvault_secrets::SecretClient;
8use std::{collections::HashMap, sync::Arc};
9use tokio::sync::Mutex;
10use url::Url;
11
/// Caches Key Vault clients based on the vault URL.
///
/// Clients are stored in a shared, reference-counted map keyed by the
/// canonicalized endpoint string. Each cached client is handed out as an
/// `Arc<T>`, so callers share one instance per endpoint.
pub struct ClientCache<T> {
    // Mutex should be fast enough for our needs of a CLI.
    cache: Arc<Mutex<HashMap<String, Arc<T>>>>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add a
// `T: Default` bound, but an empty cache never needs to construct a `T`.
impl<T> Default for ClientCache<T> {
    /// Creates an empty cache.
    fn default() -> Self {
        Self {
            cache: Arc::default(),
        }
    }
}
18
19impl<T: TypeName> ClientCache<T> {
20    /// Create a new `ClientCache`.
21    pub fn new() -> Self {
22        Self {
23            cache: Default::default(),
24        }
25    }
26
27    /// Gets, or creates and caches, a Key Vault client based on the `endpoint`.
28    pub async fn get<F>(&self, endpoint: impl AsRef<str>, f: F) -> crate::Result<Arc<T>>
29    where
30        F: FnOnce(&str) -> azure_core::Result<T>,
31    {
32        // Canonicalize the URL.
33        let endpoint = Url::parse(endpoint.as_ref())?.to_string();
34        let mut cache = self.cache.lock().await;
35        if let Some(c) = cache.get(&endpoint) {
36            tracing::debug!(target: "akv::cache", "found cached {client} for '{vault}'", client = T::type_name(), vault = &endpoint);
37            return Ok(c.clone());
38        };
39
40        let client = Arc::new(f(&endpoint)?);
41
42        tracing::debug!(target: "akv::cache", "caching new {client} for '{vault}'", client = T::type_name(), vault = &endpoint,);
43        cache.insert(endpoint, client.clone());
44        Ok(client)
45    }
46}
47
48impl<T> Clone for ClientCache<T> {
49    fn clone(&self) -> Self {
50        Self {
51            cache: self.cache.clone(),
52        }
53    }
54}
55
/// Gets a friendly name of the implementing type.
///
/// In this file the name is interpolated into the `akv::cache` debug log
/// messages emitted by `ClientCache::get`.
pub trait TypeName {
    /// The friendly name of the implementing type.
    ///
    /// This is an associated function (no `self`), so it is called as
    /// `T::type_name()` without needing an instance.
    fn type_name() -> &'static str;
}
61
62impl TypeName for KeyClient {
63    fn type_name() -> &'static str {
64        "KeyClient"
65    }
66}
67
68impl TypeName for SecretClient {
69    fn type_name() -> &'static str {
70        "SecretClient"
71    }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use azure_identity::AzureDeveloperCliCredential;
    use azure_security_keyvault_secrets::SecretClient;

    /// Two distinct vaults produce two entries, and a differently spelled URL
    /// for the first vault (trailing slash) resolves to the already-cached client.
    #[tokio::test]
    async fn test_client_cache() {
        let credential = AzureDeveloperCliCredential::new(None).unwrap();

        let cache = ClientCache::<SecretClient>::new();
        let first = cache
            .get("https://vault1.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add first client");
        let second = cache
            .get("https://vault2.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add second client");
        let first_again = cache
            .get("https://vault1.vault.azure.net/", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add first client again");

        // Only two distinct canonical endpoints were requested.
        assert_eq!(cache.cache.lock().await.len(), 2);
        // The trailing-slash spelling must hit the cache, not build a new client.
        assert!(Arc::ptr_eq(&first, &first_again));
        // Different vaults must not share a client.
        assert!(!Arc::ptr_eq(&first, &second));
    }
}