1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use super::credentials::{CredentialsProvider, CredentialsProxy};
use async_trait::async_trait;
use aws_sdk_kms::{Client, Config, Region};
use aws_types::Credentials;
use envelopers::{
    DataKey, KMSKeyProvider, Key, KeyDecryptionError, KeyGenerationError, KeyProvider, U16,
};
use std::borrow::Cow;

/// A [`KeyProvider`] that wraps a [`KMSKeyProvider`] to make sure that it always has fresh
/// credentials.
///
/// The wrapped `KMSKeyProvider` is obtained through a [`CredentialsProxy`], which produces it
/// from [`Credentials`] supplied by a credentials provider.
pub struct AutoRenewAwsKMSKeyProvider {
    /// Proxy that yields a `KMSKeyProvider` built from the current credentials.
    credentials: CredentialsProxy<KMSKeyProvider, Credentials>,
}

impl AutoRenewAwsKMSKeyProvider {
    /// Builds an auto-renewing wrapper around [`KMSKeyProvider`].
    ///
    /// # Arguments
    ///
    /// * `key_id` - KMS key identifier; converted to a `String` once and cloned into every
    ///   `KMSKeyProvider` the factory closure builds.
    /// * `region` - AWS region placed in the KMS client [`Config`] for every client built.
    /// * `provider` - source of [`Credentials`], handed to [`CredentialsProxy::new`] together
    ///   with the factory closure that turns credentials into a `KMSKeyProvider`.
    pub fn new<P: CredentialsProvider<Credentials = Credentials> + 'static>(
        key_id: impl Into<String>,
        region: impl Into<Cow<'static, str>>,
        provider: P,
    ) -> Self {
        let kms_key_id: String = key_id.into();
        let kms_region: Cow<'static, str> = region.into();

        // The factory may be invoked more than once, so the captured id and region are
        // cloned on each call instead of being moved out.
        let credentials = CredentialsProxy::new(provider, move |creds| {
            let kms_config = Config::builder()
                .credentials_provider(creds)
                .region(Some(Region::new(kms_region.clone())))
                .build();
            let kms_client = Client::from_conf(kms_config);

            KMSKeyProvider::new(kms_client, kms_key_id.clone())
        });

        Self { credentials }
    }
}

#[async_trait(?Send)]
impl KeyProvider for AutoRenewAwsKMSKeyProvider {
    /// Generates a data key using the `KMSKeyProvider` obtained from the credentials proxy.
    ///
    /// # Errors
    ///
    /// A failure to load credentials becomes `KeyGenerationError::Other` with a
    /// "Failed to load credentials: ..." message. Errors from the inner provider are
    /// returned unchanged.
    async fn generate_data_key(&self) -> Result<DataKey, KeyGenerationError> {
        let key_provider = self.credentials.get().await.map_err(|err| {
            KeyGenerationError::Other(format!("Failed to load credentials: {}", err))
        })?;

        key_provider.generate_data_key().await
    }

    /// Decrypts `encrypted_key` using the `KMSKeyProvider` obtained from the credentials proxy.
    ///
    /// The `&Vec<u8>` parameter type is dictated by the `KeyProvider` trait signature.
    ///
    /// # Errors
    ///
    /// A failure to load credentials becomes `KeyDecryptionError::Other` with a
    /// "Failed to load credentials: ..." message. Errors from the inner provider are
    /// returned unchanged.
    async fn decrypt_data_key(
        &self,
        encrypted_key: &Vec<u8>,
    ) -> Result<Key<U16>, KeyDecryptionError> {
        let key_provider = self.credentials.get().await.map_err(|err| {
            KeyDecryptionError::Other(format!("Failed to load credentials: {}", err))
        })?;

        key_provider.decrypt_data_key(encrypted_key).await
    }
}