1use std::ops::Deref;
2use std::sync::Arc;
3
4#[cfg(feature = "auth")]
5pub use google_cloud_auth;
6use google_cloud_gax::conn::{ConnectionOptions, Environment, Error};
7
8use google_cloud_token::{NopeTokenSourceProvider, TokenSourceProvider};
9
10use crate::grpc::apiv1::conn_pool::{ConnectionManager, KMS, SCOPES};
11use crate::grpc::apiv1::kms_client::Client as KmsGrpcClient;
12
/// Settings used by [`Client::new`] to open connections to the KMS service.
///
/// `ClientConfig::default()` targets the `KMS` endpoint with a no-op token source
/// (`NopeTokenSourceProvider`) and a pool of one connection. With the `auth` feature,
/// `with_auth` / `with_credentials` swap in a real token source.
#[derive(Debug)]
pub struct ClientConfig {
    /// Service endpoint passed to the connection manager.
    pub endpoint: String,
    /// Supplies tokens for outgoing requests; wrapped in `Environment::GoogleCloud` by `Client::new`.
    pub token_source_provider: Box<dyn TokenSourceProvider>,
    /// Number of pooled connections; `None` lets `Client::new` choose (see that function).
    pub pool_size: Option<usize>,
    /// Transport options forwarded unchanged to the connection manager.
    pub connection_option: ConnectionOptions,
}
20
#[cfg(feature = "auth")]
impl ClientConfig {
    /// Replaces the token source with a `DefaultTokenSourceProvider` scoped to the KMS `SCOPES`.
    ///
    /// Credential discovery is delegated entirely to `google_cloud_auth`.
    ///
    /// # Errors
    /// Returns the `google_cloud_auth` error if the token source provider cannot be created.
    pub async fn with_auth(self) -> Result<Self, google_cloud_auth::error::Error> {
        let ts = google_cloud_auth::token::DefaultTokenSourceProvider::new(Self::auth_config()).await?;
        Ok(self.with_token_source(ts))
    }

    /// Replaces the token source with one built from an explicit credentials file,
    /// scoped to the KMS `SCOPES`.
    ///
    /// # Errors
    /// Returns the `google_cloud_auth` error if the token source provider cannot be created
    /// from `credentials`.
    pub async fn with_credentials(
        self,
        credentials: google_cloud_auth::credentials::CredentialsFile,
    ) -> Result<Self, google_cloud_auth::error::Error> {
        let ts = google_cloud_auth::token::DefaultTokenSourceProvider::new_with_credentials(
            Self::auth_config(),
            Box::new(credentials),
        )
        .await?;
        Ok(self.with_token_source(ts))
    }

    /// Installs `ts` as the token source provider.
    ///
    /// This only moves a value into a field, so it is a plain synchronous function; it used to be
    /// `async` without awaiting anything, which forced a pointless `.await` on every caller.
    fn with_token_source(mut self, ts: google_cloud_auth::token::DefaultTokenSourceProvider) -> Self {
        self.token_source_provider = Box::new(ts);
        self
    }

    /// Auth configuration requesting the KMS OAuth scopes.
    fn auth_config() -> google_cloud_auth::project::Config<'static> {
        google_cloud_auth::project::Config::default().with_scopes(&SCOPES)
    }
}
49
50impl Default for ClientConfig {
51 fn default() -> Self {
52 Self {
53 endpoint: KMS.to_string(),
54 token_source_provider: Box::new(NopeTokenSourceProvider {}),
55 pool_size: Some(1),
56 connection_option: ConnectionOptions::default(),
57 }
58 }
59}
60
/// KMS client; all RPC methods come from the wrapped gRPC client via `Deref`.
///
/// Cloning copies the inner `KmsGrpcClient`, which holds the connection manager behind an `Arc`
/// (see `Client::new`).
#[derive(Clone, Debug)]
pub struct Client {
    /// Underlying generated gRPC client.
    kms_client: KmsGrpcClient,
}
65
66impl Client {
67 pub async fn new(config: ClientConfig) -> Result<Self, Error> {
68 let pool_size = config.pool_size.unwrap_or_default();
69 let cm = ConnectionManager::new(
70 pool_size,
71 config.endpoint.as_str(),
72 &Environment::GoogleCloud(config.token_source_provider),
73 &config.connection_option,
74 )
75 .await?;
76 Ok(Self {
77 kms_client: KmsGrpcClient::new(Arc::new(cm)),
78 })
79 }
80}
81
82impl Deref for Client {
83 type Target = KmsGrpcClient;
84
85 fn deref(&self) -> &Self::Target {
86 &self.kms_client
87 }
88}
89
#[cfg(test)]
mod tests {
    //! Integration tests against a live KMS project.
    //!
    //! They need a credentials file that `CredentialsFile::new()` can find and pre-existing
    //! key rings/keys in that project (names are hard-coded in each test).

    use serial_test::serial;

    use crate::grpc::kms::v1::{
        AsymmetricSignRequest, DecryptRequest, EncryptRequest, GenerateRandomBytesRequest, GetKeyRingRequest,
        GetPublicKeyRequest, ListKeyRingsRequest, MacSignRequest, MacVerifyRequest, ProtectionLevel,
    };

    use crate::client::{Client, ClientConfig};

    /// Builds an authenticated client and returns it with the credentials file's project id.
    ///
    /// Panics if the credentials cannot be loaded, carry no project id, or the client cannot be created.
    async fn new_client() -> (Client, String) {
        let cred = google_cloud_auth::credentials::CredentialsFile::new()
            .await
            .expect("failed to load credentials file");
        let project = cred.project_id.clone().expect("credentials file has no project_id");
        let config = ClientConfig::default()
            .with_credentials(cred)
            .await
            .expect("failed to build token source from credentials");
        let client = Client::new(config).await.expect("failed to create KMS client");
        (client, project)
    }

    #[ctor::ctor]
    fn init() {
        let _ = tracing_subscriber::fmt().try_init();
    }

    #[tokio::test]
    #[serial]
    async fn test_key_ring() {
        let (client, project) = new_client().await;
        // This test does not create the key ring: it must already exist in the project.
        let key_ring_id = "gcpkmskr1714619260";
        let parent = format!("projects/{project}/locations/us-west1");
        let key_ring = format!("{parent}/keyRings/{key_ring_id}");

        let get_key_ring = client
            .get_key_ring(GetKeyRingRequest { name: key_ring.clone() }, None)
            .await
            .unwrap();
        assert_eq!(get_key_ring.name, key_ring);

        // Page through the key rings one at a time; needs at least two key rings under `parent`.
        let list_request = ListKeyRingsRequest {
            parent: parent.clone(),
            page_size: 1,
            page_token: String::new(),
            filter: String::new(),
            order_by: String::new(),
        };
        let list_result = client.list_key_rings(list_request, None).await.unwrap();
        assert_eq!(1, list_result.key_rings.len());

        let list_request = ListKeyRingsRequest {
            parent,
            page_size: 1,
            page_token: list_result.next_page_token.clone(),
            filter: String::new(),
            order_by: String::new(),
        };
        let list_result2 = client.list_key_rings(list_request, None).await.unwrap();
        assert_eq!(1, list_result2.key_rings.len());

        assert_ne!(list_result.key_rings[0].name, list_result2.key_rings[0].name);
    }

    #[tokio::test]
    #[serial]
    async fn test_generate_random_bytes() {
        let (client, project) = new_client().await;

        let request = GenerateRandomBytesRequest {
            location: format!("projects/{project}/locations/us-west1"),
            length_bytes: 128,
            protection_level: ProtectionLevel::Hsm.into(),
        };
        let random_bytes = client
            .generate_random_bytes(request, None)
            .await
            .unwrap_or_else(|e| panic!("Error when generating random bytes: {e:?}"));
        assert_eq!(
            random_bytes.data.len(),
            128,
            "Returned data length was {:?} when it should have been 128",
            random_bytes.data.len()
        );
        assert!(
            random_bytes.data.iter().any(|&b| b != 0),
            "Data returned was all zeros: {:?}",
            random_bytes.data
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_asymmetric_sign() {
        let (client, project) = new_client().await;

        let request = AsymmetricSignRequest {
            name: format!(
                "projects/{project}/locations/asia-northeast1/keyRings/gcr_test/cryptoKeys/eth-sign/cryptoKeyVersions/1"
            ),
            digest: None,
            digest_crc32c: None,
            data: vec![1, 2, 3, 4, 5],
            data_crc32c: None,
        };
        let signature = client.asymmetric_sign(request, None).await.unwrap();
        assert!(!signature.signature.is_empty());
    }

    #[tokio::test]
    #[serial]
    async fn test_get_pubkey() {
        let (client, project) = new_client().await;
        let request = GetPublicKeyRequest {
            name: format!(
                "projects/{project}/locations/asia-northeast1/keyRings/gcr_test/cryptoKeys/eth-sign/cryptoKeyVersions/1"
            ),
        };
        let pubkey = client.get_public_key(request, None).await.unwrap();
        assert!(!pubkey.pem.is_empty());
    }

    #[tokio::test]
    #[serial]
    async fn test_encrypt_decrypt() {
        let (client, project) = new_client().await;

        let key = format!("projects/{project}/locations/asia-northeast1/keyRings/gcr_test/cryptoKeys/gcr_test");
        let data = [1, 2, 3, 4, 5];
        let request = EncryptRequest {
            name: key.clone(),
            plaintext: data.to_vec(),
            additional_authenticated_data: vec![],
            plaintext_crc32c: None,
            additional_authenticated_data_crc32c: None,
        };
        let encrypted = client.encrypt(request, None).await.unwrap();

        let request = DecryptRequest {
            name: key,
            ciphertext: encrypted.ciphertext,
            additional_authenticated_data: vec![],
            ciphertext_crc32c: None,
            additional_authenticated_data_crc32c: None,
        };
        let raw = client.decrypt(request, None).await.unwrap();
        assert_eq!(data.to_vec(), raw.plaintext);
    }

    #[tokio::test]
    #[serial]
    async fn test_mac_sign_verify() {
        let (client, project) = new_client().await;

        let key = format!(
            "projects/{project}/locations/asia-northeast1/keyRings/gcr_test/cryptoKeys/mac-test/cryptoKeyVersions/1"
        );
        let data = [1, 2, 3, 4, 5];
        let request = MacSignRequest {
            name: key.clone(),
            data: data.to_vec(),
            data_crc32c: None,
        };
        let signature = client.mac_sign(request, None).await.unwrap();

        // Verify the MAC just produced (and its checksum) against the same data.
        let request = MacVerifyRequest {
            name: key,
            data: data.to_vec(),
            data_crc32c: None,
            mac: signature.mac,
            mac_crc32c: signature.mac_crc32c,
        };
        let raw = client.mac_verify(request, None).await.unwrap();
        assert!(raw.success);
    }
}