navajo_gcp/
kms.rs

1use base64::{engine::general_purpose::STANDARD_NO_PAD as b64, Engine as _};
2use crc32c::crc32c;
3use gcloud_sdk::google::cloud::kms::v1::key_management_service_client::KeyManagementServiceClient;
4use gcloud_sdk::google::cloud::kms::v1::DecryptRequest;
5use gcloud_sdk::proto_ext::kms::EncryptRequest;
6use gcloud_sdk::*;
7use navajo::Envelope;
8use secret_vault_value::SecretValue;
9use std::fmt::{Debug, Display};
10use std::sync::Arc;
11use tokio::sync::OnceCell;
12use tonic::metadata::MetadataValue;
13use tonic::Status;
14
15#[derive(Clone)]
16pub struct Kms {
17    client: Arc<OnceCell<GoogleApi<KeyManagementServiceClient<GoogleAuthMiddleware>>>>,
18}
19
20impl Kms {
21    pub fn new() -> Self {
22        Self {
23            client: Default::default(),
24        }
25    }
26    pub fn key<N: ToString>(&self, name: N) -> CryptoKey {
27        CryptoKey {
28            name: name.to_string(),
29            client: self.client.clone(),
30        }
31    }
32}
33
34impl Default for Kms {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40#[derive(Clone)]
41pub struct CryptoKey {
42    name: String,
43    client: Arc<OnceCell<GoogleApi<KeyManagementServiceClient<GoogleAuthMiddleware>>>>,
44}
45impl Debug for CryptoKey {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        f.debug_struct("GcpKmsKey")
48            .field("name", &self.name)
49            .finish_non_exhaustive()
50    }
51}
52
53impl CryptoKey {
54    async fn try_get_client(
55        &self,
56    ) -> Result<KeyManagementServiceClient<GoogleAuthMiddleware>, KmsError> {
57        Ok(self.client.get_or_try_init(init_client).await?.get())
58    }
59}
60
61impl Envelope for CryptoKey {
62    type EncryptError = KmsError;
63    type DecryptError = KmsError;
64
65    fn encrypt_dek<A, P>(
66        &self,
67        aad: navajo::Aad<A>,
68        plaintext: P,
69    ) -> std::pin::Pin<
70        Box<dyn std::future::Future<Output = Result<Vec<u8>, Self::EncryptError>> + Send + '_>,
71    >
72    where
73        A: 'static + AsRef<[u8]> + Send + Sync,
74        P: 'static + AsRef<[u8]> + Send + Sync,
75    {
76        let plaintext = b64.encode(plaintext.as_ref());
77        let plaintext_crc32c = Some(crc32c(plaintext.as_bytes()) as i64);
78        let plaintext = SecretValue::new(plaintext.as_bytes().to_vec());
79        let additional_authenticated_data = aad.to_vec();
80        let mut request = tonic::Request::new(EncryptRequest {
81            name: self.name.clone(),
82            plaintext,
83            additional_authenticated_data,
84            plaintext_crc32c,
85            ..Default::default()
86        });
87
88        request.metadata_mut().insert(
89            "x-goog-request-params",
90            MetadataValue::<tonic::metadata::Ascii>::try_from(format!("name={}", self.name))
91                .unwrap(),
92        );
93        Box::pin(async move {
94            let response = self
95                .try_get_client()
96                .await?
97                .encrypt(request)
98                .await?
99                .into_inner();
100            Ok(response.ciphertext)
101        })
102    }
103
104    fn decrypt_dek<A, C>(
105        &self,
106        aad: navajo::Aad<A>,
107        ciphertext: C,
108    ) -> std::pin::Pin<
109        Box<dyn std::future::Future<Output = Result<Vec<u8>, Self::DecryptError>> + Send + '_>,
110    >
111    where
112        A: 'static + AsRef<[u8]> + Send + Sync,
113        C: 'static + AsRef<[u8]> + Send + Sync,
114    {
115        // let ciphertext = b64.encode(ciphertext.as_ref());
116        let ciphertext_crc32c = Some(crc32c(ciphertext.as_ref()) as i64);
117        let additional_authenticated_data = aad.to_vec();
118        let mut request = tonic::Request::new(DecryptRequest {
119            name: self.name.clone(),
120            ciphertext: ciphertext.as_ref().to_vec(),
121            additional_authenticated_data,
122            ciphertext_crc32c,
123            ..Default::default()
124        });
125
126        request.metadata_mut().insert(
127            "x-goog-request-params",
128            MetadataValue::<tonic::metadata::Ascii>::try_from(format!("name={}", self.name))
129                .unwrap(),
130        );
131        Box::pin(async move {
132            let response = self
133                .try_get_client()
134                .await?
135                .decrypt(request)
136                .await?
137                .into_inner();
138            let response = b64.decode(response.plaintext.as_sensitive_bytes())?;
139            Ok(response)
140        })
141    }
142}
143
144async fn init_client(
145) -> Result<GoogleApi<KeyManagementServiceClient<GoogleAuthMiddleware>>, gcloud_sdk::error::Error> {
146    GoogleApi::from_function(
147        KeyManagementServiceClient::new,
148        "https://cloudkms.googleapis.com",
149        None,
150    )
151    .await
152}
153
154pub mod sync {
155    use super::Envelope;
156    use std::sync::Arc;
157    #[derive(Clone)]
158    pub struct Kms {
159        kms: Arc<super::Kms>,
160        runtime: Arc<tokio::runtime::Runtime>,
161    }
162
163    impl Kms {
164        pub fn new() -> Self {
165            Self {
166                kms: Arc::new(super::Kms::new()),
167                runtime: Arc::new(tokio::runtime::Runtime::new().unwrap()),
168            }
169        }
170        pub fn key<N: ToString>(&self, name: N) -> CryptoKey {
171            CryptoKey {
172                key: Arc::new(self.kms.key(name)),
173                runtime: self.runtime.clone(),
174            }
175        }
176    }
177
178    impl Default for Kms {
179        fn default() -> Self {
180            Self::new()
181        }
182    }
183
184    #[derive(Clone, Debug)]
185    pub struct CryptoKey {
186        key: Arc<super::CryptoKey>,
187        runtime: Arc<tokio::runtime::Runtime>,
188    }
189
190    impl navajo::envelope::sync::Envelope for CryptoKey {
191        type EncryptError = super::KmsError;
192        type DecryptError = super::KmsError;
193
194        fn encrypt_dek<A, P>(
195            &self,
196            aad: navajo::Aad<A>,
197            plaintext: P,
198        ) -> Result<Vec<u8>, Self::EncryptError>
199        where
200            A: AsRef<[u8]>,
201            P: AsRef<[u8]>,
202        {
203            let aad = navajo::Aad(aad.as_ref().to_vec());
204            let cleartext = plaintext.as_ref().to_vec();
205            self.runtime
206                .block_on(async move { self.key.encrypt_dek(aad, cleartext).await })
207        }
208
209        fn decrypt_dek<A, C>(
210            &self,
211            aad: navajo::Aad<A>,
212            ciphertext: C,
213        ) -> Result<Vec<u8>, Self::DecryptError>
214        where
215            A: AsRef<[u8]>,
216            C: AsRef<[u8]>,
217        {
218            let aad = navajo::Aad(aad.as_ref().to_vec());
219            let ciphertext = ciphertext.as_ref().to_vec();
220            self.runtime
221                .block_on(async move { self.key.decrypt_dek(aad, ciphertext).await })
222        }
223    }
224}
225
226// TODO: Improve error handling, classify KmsError
227
228#[derive(Debug)]
229pub enum KmsError {
230    Tonic(Status),
231    Base64(base64::DecodeError),
232    Client(gcloud_sdk::error::Error),
233}
234
235impl std::error::Error for KmsError {}
236impl Display for KmsError {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        match self {
239            KmsError::Tonic(e) => write!(f, "navajo-gcp: {e}"),
240            KmsError::Base64(e) => write!(f, "navajo-gcp: {e}"),
241            KmsError::Client(e) => write!(f, "navajo-gcp: {e}"),
242        }
243    }
244}
245impl From<Status> for KmsError {
246    fn from(status: Status) -> Self {
247        Self::Tonic(status)
248    }
249}
250impl From<base64::DecodeError> for KmsError {
251    fn from(err: base64::DecodeError) -> Self {
252        Self::Base64(err)
253    }
254}
255impl From<gcloud_sdk::error::Error> for KmsError {
256    fn from(err: gcloud_sdk::error::Error) -> Self {
257        Self::Client(err)
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a DEK through a real Cloud KMS key and checks the result matches.
    ///
    /// Ignored by default. It reads the key resource name from `GCP_KMS_KEY_URI` and
    /// presumably also needs GCP credentials to be available.
    #[tokio::test]
    #[ignore]
    async fn test_encrypt() {
        // todo: need to figure out how to safely run this in CI
        let gcp = Kms::new();
        let key_name = std::env::var("GCP_KMS_KEY_URI")
            .expect("GCP_KMS_KEY_URI must name the Cloud KMS key used by this test");
        let key = gcp.key(key_name);
        let aad = navajo::Aad("test");
        let plaintext = "test";
        let ciphertext = key.encrypt_dek(aad, plaintext).await.unwrap();
        let decrypted = key.decrypt_dek(aad, ciphertext).await.unwrap();
        assert_eq!(decrypted, plaintext.as_bytes());
    }
}