// akv_cli/cache.rs

// Copyright 2024 Heath Stewart.
// Licensed under the MIT License. See LICENSE.txt in the project root for license information.

use azure_security_keyvault_keys::KeyClient;
use azure_security_keyvault_secrets::SecretClient;
use std::{
    collections::{hash_map::Entry, HashMap},
    sync::Arc,
};
use tokio::sync::Mutex;
use url::Url;
9
10#[derive(Default)]
11pub struct ClientCache<T> {
12    // Mutex should be fast enough for our needs of a CLI.
13    cache: Arc<Mutex<HashMap<String, Arc<T>>>>,
14}
15
16impl<T: TypeName> ClientCache<T> {
17    pub fn new() -> Self {
18        Self {
19            cache: Default::default(),
20        }
21    }
22
23    pub async fn get<F>(&self, endpoint: impl AsRef<str>, f: F) -> crate::Result<Arc<T>>
24    where
25        F: FnOnce(&str) -> azure_core::Result<T>,
26    {
27        // Canonicalize the URL.
28        let endpoint = Url::parse(endpoint.as_ref())?.to_string();
29        let mut cache = self.cache.lock().await;
30        if let Some(c) = cache.get(&endpoint) {
31            tracing::debug!(target: "akv::cache", "found cached {client} for '{vault}'", client = T::type_name(), vault = &endpoint);
32            return Ok(c.clone());
33        };
34
35        let client = Arc::new(f(&endpoint)?);
36
37        tracing::debug!(target: "akv::cache", "caching new {client} for '{vault}'", client = T::type_name(), vault = &endpoint,);
38        cache.insert(endpoint, client.clone());
39        Ok(client)
40    }
41}
42
43impl<T> Clone for ClientCache<T> {
44    fn clone(&self) -> Self {
45        Self {
46            cache: self.cache.clone(),
47        }
48    }
49}
50
/// Supplies a short, human-readable name for a type, used by `ClientCache` in its
/// debug log messages.
pub trait TypeName {
    /// Returns the display name of the implementing type.
    fn type_name() -> &'static str;
}
54
55impl TypeName for KeyClient {
56    fn type_name() -> &'static str {
57        "KeyClient"
58    }
59}
60
61impl TypeName for SecretClient {
62    fn type_name() -> &'static str {
63        "SecretClient"
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use azure_identity::AzureDeveloperCliCredential;
    use azure_security_keyvault_secrets::SecretClient;

    /// Distinct vaults get distinct clients. Re-requesting a vault with a trailing `/`
    /// canonicalizes to the same key and returns the originally cached instance.
    #[tokio::test]
    async fn test_client_cache() {
        let credential = AzureDeveloperCliCredential::new(None).expect("create credential");

        let cache = ClientCache::<SecretClient>::new();
        let first = cache
            .get("https://vault1.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add first client");
        let second = cache
            .get("https://vault2.vault.azure.net", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add second client");
        // Same vault as `first`, spelled with a trailing slash.
        let first_again = cache
            .get("https://vault1.vault.azure.net/", |endpoint| {
                SecretClient::new(endpoint, credential.clone(), None)
            })
            .await
            .expect("add first client again");

        // The trailing-slash lookup must hit the existing entry, not build a new client.
        assert!(Arc::ptr_eq(&first, &first_again));
        assert!(!Arc::ptr_eq(&first, &second));
        assert_eq!(cache.cache.lock().await.len(), 2);
    }
}