1use base64::{engine::general_purpose::STANDARD_NO_PAD as b64, Engine as _};
2use crc32c::crc32c;
3use gcloud_sdk::google::cloud::kms::v1::key_management_service_client::KeyManagementServiceClient;
4use gcloud_sdk::google::cloud::kms::v1::DecryptRequest;
5use gcloud_sdk::proto_ext::kms::EncryptRequest;
6use gcloud_sdk::*;
7use navajo::Envelope;
8use secret_vault_value::SecretValue;
9use std::fmt::{Debug, Display};
10use std::sync::Arc;
11use tokio::sync::OnceCell;
12use tonic::metadata::MetadataValue;
13use tonic::Status;
14
15#[derive(Clone)]
16pub struct Kms {
17 client: Arc<OnceCell<GoogleApi<KeyManagementServiceClient<GoogleAuthMiddleware>>>>,
18}
19
20impl Kms {
21 pub fn new() -> Self {
22 Self {
23 client: Default::default(),
24 }
25 }
26 pub fn key<N: ToString>(&self, name: N) -> CryptoKey {
27 CryptoKey {
28 name: name.to_string(),
29 client: self.client.clone(),
30 }
31 }
32}
33
34impl Default for Kms {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40#[derive(Clone)]
41pub struct CryptoKey {
42 name: String,
43 client: Arc<OnceCell<GoogleApi<KeyManagementServiceClient<GoogleAuthMiddleware>>>>,
44}
45impl Debug for CryptoKey {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("GcpKmsKey")
48 .field("name", &self.name)
49 .finish_non_exhaustive()
50 }
51}
52
53impl CryptoKey {
54 async fn try_get_client(
55 &self,
56 ) -> Result<KeyManagementServiceClient<GoogleAuthMiddleware>, KmsError> {
57 Ok(self.client.get_or_try_init(init_client).await?.get())
58 }
59}
60
61impl Envelope for CryptoKey {
62 type EncryptError = KmsError;
63 type DecryptError = KmsError;
64
65 fn encrypt_dek<A, P>(
66 &self,
67 aad: navajo::Aad<A>,
68 plaintext: P,
69 ) -> std::pin::Pin<
70 Box<dyn std::future::Future<Output = Result<Vec<u8>, Self::EncryptError>> + Send + '_>,
71 >
72 where
73 A: 'static + AsRef<[u8]> + Send + Sync,
74 P: 'static + AsRef<[u8]> + Send + Sync,
75 {
76 let plaintext = b64.encode(plaintext.as_ref());
77 let plaintext_crc32c = Some(crc32c(plaintext.as_bytes()) as i64);
78 let plaintext = SecretValue::new(plaintext.as_bytes().to_vec());
79 let additional_authenticated_data = aad.to_vec();
80 let mut request = tonic::Request::new(EncryptRequest {
81 name: self.name.clone(),
82 plaintext,
83 additional_authenticated_data,
84 plaintext_crc32c,
85 ..Default::default()
86 });
87
88 request.metadata_mut().insert(
89 "x-goog-request-params",
90 MetadataValue::<tonic::metadata::Ascii>::try_from(format!("name={}", self.name))
91 .unwrap(),
92 );
93 Box::pin(async move {
94 let response = self
95 .try_get_client()
96 .await?
97 .encrypt(request)
98 .await?
99 .into_inner();
100 Ok(response.ciphertext)
101 })
102 }
103
104 fn decrypt_dek<A, C>(
105 &self,
106 aad: navajo::Aad<A>,
107 ciphertext: C,
108 ) -> std::pin::Pin<
109 Box<dyn std::future::Future<Output = Result<Vec<u8>, Self::DecryptError>> + Send + '_>,
110 >
111 where
112 A: 'static + AsRef<[u8]> + Send + Sync,
113 C: 'static + AsRef<[u8]> + Send + Sync,
114 {
115 let ciphertext_crc32c = Some(crc32c(ciphertext.as_ref()) as i64);
117 let additional_authenticated_data = aad.to_vec();
118 let mut request = tonic::Request::new(DecryptRequest {
119 name: self.name.clone(),
120 ciphertext: ciphertext.as_ref().to_vec(),
121 additional_authenticated_data,
122 ciphertext_crc32c,
123 ..Default::default()
124 });
125
126 request.metadata_mut().insert(
127 "x-goog-request-params",
128 MetadataValue::<tonic::metadata::Ascii>::try_from(format!("name={}", self.name))
129 .unwrap(),
130 );
131 Box::pin(async move {
132 let response = self
133 .try_get_client()
134 .await?
135 .decrypt(request)
136 .await?
137 .into_inner();
138 let response = b64.decode(response.plaintext.as_sensitive_bytes())?;
139 Ok(response)
140 })
141 }
142}
143
144async fn init_client(
145) -> Result<GoogleApi<KeyManagementServiceClient<GoogleAuthMiddleware>>, gcloud_sdk::error::Error> {
146 GoogleApi::from_function(
147 KeyManagementServiceClient::new,
148 "https://cloudkms.googleapis.com",
149 None,
150 )
151 .await
152}
153
154pub mod sync {
155 use super::Envelope;
156 use std::sync::Arc;
157 #[derive(Clone)]
158 pub struct Kms {
159 kms: Arc<super::Kms>,
160 runtime: Arc<tokio::runtime::Runtime>,
161 }
162
163 impl Kms {
164 pub fn new() -> Self {
165 Self {
166 kms: Arc::new(super::Kms::new()),
167 runtime: Arc::new(tokio::runtime::Runtime::new().unwrap()),
168 }
169 }
170 pub fn key<N: ToString>(&self, name: N) -> CryptoKey {
171 CryptoKey {
172 key: Arc::new(self.kms.key(name)),
173 runtime: self.runtime.clone(),
174 }
175 }
176 }
177
178 impl Default for Kms {
179 fn default() -> Self {
180 Self::new()
181 }
182 }
183
184 #[derive(Clone, Debug)]
185 pub struct CryptoKey {
186 key: Arc<super::CryptoKey>,
187 runtime: Arc<tokio::runtime::Runtime>,
188 }
189
190 impl navajo::envelope::sync::Envelope for CryptoKey {
191 type EncryptError = super::KmsError;
192 type DecryptError = super::KmsError;
193
194 fn encrypt_dek<A, P>(
195 &self,
196 aad: navajo::Aad<A>,
197 plaintext: P,
198 ) -> Result<Vec<u8>, Self::EncryptError>
199 where
200 A: AsRef<[u8]>,
201 P: AsRef<[u8]>,
202 {
203 let aad = navajo::Aad(aad.as_ref().to_vec());
204 let cleartext = plaintext.as_ref().to_vec();
205 self.runtime
206 .block_on(async move { self.key.encrypt_dek(aad, cleartext).await })
207 }
208
209 fn decrypt_dek<A, C>(
210 &self,
211 aad: navajo::Aad<A>,
212 ciphertext: C,
213 ) -> Result<Vec<u8>, Self::DecryptError>
214 where
215 A: AsRef<[u8]>,
216 C: AsRef<[u8]>,
217 {
218 let aad = navajo::Aad(aad.as_ref().to_vec());
219 let ciphertext = ciphertext.as_ref().to_vec();
220 self.runtime
221 .block_on(async move { self.key.decrypt_dek(aad, ciphertext).await })
222 }
223 }
224}
225
226#[derive(Debug)]
229pub enum KmsError {
230 Tonic(Status),
231 Base64(base64::DecodeError),
232 Client(gcloud_sdk::error::Error),
233}
234
235impl std::error::Error for KmsError {}
236impl Display for KmsError {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 match self {
239 KmsError::Tonic(e) => write!(f, "navajo-gcp: {e}"),
240 KmsError::Base64(e) => write!(f, "navajo-gcp: {e}"),
241 KmsError::Client(e) => write!(f, "navajo-gcp: {e}"),
242 }
243 }
244}
245impl From<Status> for KmsError {
246 fn from(status: Status) -> Self {
247 Self::Tonic(status)
248 }
249}
250impl From<base64::DecodeError> for KmsError {
251 fn from(err: base64::DecodeError) -> Self {
252 Self::Base64(err)
253 }
254}
255impl From<gcloud_sdk::error::Error> for KmsError {
256 fn from(err: gcloud_sdk::error::Error) -> Self {
257 Self::Client(err)
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a DEK through a real Cloud KMS key and checks that
    /// decryption returns the original bytes.
    ///
    /// Ignored by default: it needs network access to Cloud KMS (and,
    /// presumably, usable GCP credentials), plus the key resource name in the
    /// `GCP_KMS_KEY_URI` environment variable.
    #[tokio::test]
    #[ignore]
    async fn test_encrypt() {
        let key_name = std::env::var("GCP_KMS_KEY_URI")
            .expect("GCP_KMS_KEY_URI must be set to a Cloud KMS crypto key name");
        let key = Kms::new().key(key_name);
        let plaintext = "test";

        // A fresh Aad per call, so this does not rely on `Aad` being `Copy`.
        let ciphertext = key
            .encrypt_dek(navajo::Aad("test"), plaintext)
            .await
            .unwrap();
        println!("{ciphertext:?}");

        let decrypted = key
            .decrypt_dek(navajo::Aad("test"), ciphertext)
            .await
            .unwrap();
        assert_eq!(decrypted, plaintext.as_bytes());
    }
}