1mod attribute_name;
2pub mod query;
3mod table_attribute;
4mod table_attributes;
5mod table_entry;
6pub use self::{
7 attribute_name::AttributeName,
8 query::QueryBuilder,
9 table_attribute::{TableAttribute, TryFromTableAttr},
10 table_attributes::TableAttributes,
11 table_entry::TableEntry,
12};
13use crate::{
14 crypto::*,
15 errors::*,
16 traits::{Decryptable, PrimaryKey, PrimaryKeyError, PrimaryKeyParts, Searchable},
17 Identifiable, IndexType,
18};
19use aws_sdk_dynamodb::types::{AttributeValue, Delete, Put, TransactWriteItem};
20use cipherstash_client::{
21 config::{console_config::ConsoleConfig, cts_config::CtsConfig},
22 credentials::{
23 auto_refresh::AutoRefresh,
24 service_credentials::{ServiceCredentials, ServiceToken},
25 Credentials,
26 },
27 encryption::Encryption,
28 zerokms::ZeroKMS,
29 ZeroKMSConfig,
30};
31use log::info;
32use std::{
33 borrow::Cow,
34 collections::{HashMap, HashSet},
35 ops::Deref,
36};
37
/// Term size handed to `Sealer::seal` when a record is sealed in
/// `EncryptedTable::create_put_patch`.
///
/// NOTE(review): the precise meaning of "term size" is defined by `Sealer::seal`, which is
/// not visible here — presumably the length of generated index terms; confirm.
const DEFAULT_TERM_SIZE: usize = 12;
40
/// Backend marker for an [`EncryptedTable`] that owns a cipher but no DynamoDB client.
///
/// Produced by `EncryptedTable::init_headless`. A headless table can prepare patches and
/// encrypt/decrypt items. The methods that issue DynamoDB requests (`get`, `put`, `delete`)
/// are only defined for `EncryptedTable<Dynamo>` in this module.
pub struct Headless;
42
/// Backend for an [`EncryptedTable`] that performs its own DynamoDB requests.
pub struct Dynamo {
    /// AWS SDK client. `EncryptedTable<Dynamo>` uses it for `get_item` and
    /// `transact_write_items` requests.
    pub(crate) db: aws_sdk_dynamodb::Client,
    /// Name of the DynamoDB table targeted by `get`, `put` and `delete`.
    pub(crate) table_name: String,
}
47
48impl Deref for Dynamo {
49 type Target = aws_sdk_dynamodb::Client;
50
51 fn deref(&self) -> &Self::Target {
52 &self.db
53 }
54}
55
/// Encryption front end for DynamoDB records: a ZeroKMS-backed cipher plus a backend `D`.
///
/// In this module `D` is either [`Dynamo`] (the table issues DynamoDB requests itself) or
/// [`Headless`] (the table only prepares patches and encrypts/decrypts; the caller does I/O).
pub struct EncryptedTable<D = Dynamo> {
    /// Backend value; see [`Dynamo`] and [`Headless`].
    db: D,
    /// Cipher used to seal/unseal records and to HMAC primary and index keys.
    cipher: Box<Encryption<AutoRefresh<ServiceCredentials>>>,
}
60
61impl<D> EncryptedTable<D> {
62 pub fn cipher(&self) -> &Encryption<impl Credentials<Token = ServiceToken>> {
63 self.cipher.as_ref()
64 }
65}
66
67impl EncryptedTable<Headless> {
68 pub async fn init_headless() -> Result<Self, InitError> {
69 info!("Initializing...");
70
71 let console_config = ConsoleConfig::builder().with_env().build()?;
72
73 let cts_config = CtsConfig::builder().with_env().build()?;
74
75 let zerokms_config = ZeroKMSConfig::builder()
76 .decryption_log(true)
77 .with_env()
78 .console_config(&console_config)
79 .cts_config(&cts_config)
80 .build_with_client_key()?;
81
82 let zerokms_client = ZeroKMS::new_with_client_key(
83 &zerokms_config.base_url(),
84 AutoRefresh::new(zerokms_config.credentials()),
85 zerokms_config.decryption_log_path().as_deref(),
86 zerokms_config.client_key(),
87 );
88
89 info!("Fetching dataset config...");
90 let dataset_config = zerokms_client.load_dataset_config().await?;
91
92 let cipher = Box::new(Encryption::new(
93 dataset_config.index_root_key,
94 zerokms_client,
95 ));
96
97 info!("Ready!");
98
99 Ok(Self {
100 db: Headless,
101 cipher,
102 })
103 }
104}
105
/// DynamoDB writes that together apply one put or delete of an encrypted record.
///
/// Convert it into transaction items with [`DynamoRecordPatch::into_transact_write_items`].
pub struct DynamoRecordPatch {
    /// Complete items to write with `Put` (for a put: the root record and its index entries).
    pub put_records: Vec<HashMap<String, AttributeValue>>,
    /// Keys of items to remove with `Delete` (e.g. index entries not rewritten by a put).
    pub delete_records: Vec<PrimaryKeyParts>,
}
117
118pub struct PreparedRecord {
119 protected_indexes: Cow<'static, [(Cow<'static, str>, IndexType)]>,
120 protected_attributes: Cow<'static, [Cow<'static, str>]>,
121 sealer: Sealer,
122}
123
124pub struct PreparedDelete {
125 primary_key: PreparedPrimaryKey,
126 protected_indexes: Cow<'static, [(Cow<'static, str>, IndexType)]>,
127}
128
129impl PreparedDelete {
130 pub fn new<S: Searchable>(k: impl Into<S::PrimaryKey>) -> Self {
131 Self::new_from_parts::<S>(
132 k.into()
133 .into_parts(&S::type_name(), S::sort_key_prefix().as_deref()),
134 )
135 }
136
137 pub fn new_from_parts<S: Searchable>(k: PrimaryKeyParts) -> Self {
138 let primary_key = PreparedPrimaryKey::new_from_parts::<S>(k);
139 let protected_indexes = S::protected_indexes();
140
141 Self {
142 primary_key,
143 protected_indexes,
144 }
145 }
146
147 pub fn prepared_primary_key(&self) -> PreparedPrimaryKey {
148 self.primary_key.clone()
149 }
150
151 pub fn protected_indexes(&self) -> &[(Cow<'static, str>, IndexType)] {
152 &self.protected_indexes
153 }
154}
155
156impl PreparedRecord {
157 pub(crate) fn new(
158 protected_indexes: Cow<'static, [(Cow<'static, str>, IndexType)]>,
159 protected_attributes: Cow<'static, [Cow<'static, str>]>,
160 sealer: Sealer,
161 ) -> Self {
162 Self {
163 protected_indexes,
164 protected_attributes,
165 sealer,
166 }
167 }
168
169 pub fn prepare_record<R>(record: R) -> Result<Self, SealError>
170 where
171 R: Searchable + Identifiable,
172 {
173 let type_name = R::type_name();
174
175 let PrimaryKeyParts { pk, sk } = record
176 .get_primary_key()
177 .into_parts(&type_name, R::sort_key_prefix().as_deref());
178
179 let protected_indexes = R::protected_indexes();
180 let protected_attributes = R::protected_attributes();
181
182 let unsealed_indexes = protected_indexes
184 .iter()
185 .map(|(index_name, index_type)| {
186 record
187 .attribute_for_index(index_name, *index_type)
188 .and_then(|attr| {
189 R::index_by_name(index_name, *index_type)
190 .map(|index| (attr, index, index_name.clone(), *index_type))
191 })
192 .ok_or(SealError::MissingAttribute(index_name.to_string()))
193 })
194 .collect::<Result<Vec<_>, _>>()?;
195
196 let unsealed = record.into_unsealed();
197
198 let sealer = Sealer {
199 pk,
200 sk,
201
202 is_sk_encrypted: R::is_sk_encrypted(),
203 is_pk_encrypted: R::is_pk_encrypted(),
204
205 type_name,
206
207 unsealed_indexes,
208
209 unsealed,
210 };
211
212 Ok(PreparedRecord::new(
213 protected_indexes,
214 protected_attributes,
215 sealer,
216 ))
217 }
218
219 pub fn primary_key_parts(&self) -> PrimaryKeyParts {
220 PrimaryKeyParts {
221 pk: self.sealer.pk.clone(),
222 sk: self.sealer.sk.clone(),
223 }
224 }
225
226 pub fn type_name(&self) -> &str {
227 &self.sealer.type_name
228 }
229
230 pub fn protected_indexes(&self) -> &[(Cow<'static, str>, IndexType)] {
231 &self.protected_indexes
232 }
233}
234
235impl DynamoRecordPatch {
236 pub fn into_transact_write_items(
241 self,
242 table_name: &str,
243 ) -> Result<Vec<TransactWriteItem>, BuildError> {
244 let mut items = Vec::with_capacity(self.put_records.len() + self.delete_records.len());
245
246 for insert in self.put_records.into_iter() {
247 items.push(
248 TransactWriteItem::builder()
249 .put(
250 Put::builder()
251 .table_name(table_name)
252 .set_item(Some(insert))
253 .build()?,
254 )
255 .build(),
256 );
257 }
258
259 for PrimaryKeyParts { pk, sk } in self.delete_records.into_iter() {
260 items.push(
261 TransactWriteItem::builder()
262 .delete(
263 Delete::builder()
264 .table_name(table_name)
265 .key("pk", AttributeValue::S(pk))
266 .key("sk", AttributeValue::S(sk))
267 .build()?,
268 )
269 .build(),
270 );
271 }
272
273 Ok(items)
274 }
275}
276
277impl<D> EncryptedTable<D> {
278 pub fn query<S>(&self) -> QueryBuilder<S, &Self>
279 where
280 S: Searchable,
281 {
282 QueryBuilder::with_backend(self)
283 }
284
285 pub async fn unseal_all(
286 &self,
287 items: impl IntoIterator<Item = HashMap<String, AttributeValue>>,
288 spec: UnsealSpec<'_>,
289 ) -> Result<Vec<Unsealed>, DecryptError> {
290 let table_entries = SealedTableEntry::vec_from(items)?;
291 let results = SealedTableEntry::unseal_all(table_entries, spec, &self.cipher).await?;
292 Ok(results)
293 }
294
295 pub async fn unseal(
296 &self,
297 item: HashMap<String, AttributeValue>,
298 spec: UnsealSpec<'_>,
299 ) -> Result<Unsealed, DecryptError> {
300 let table_entry = SealedTableEntry::try_from(item)?;
301 let result = table_entry.unseal(spec, &self.cipher).await?;
302 Ok(result)
303 }
304
305 pub async fn decrypt_all<T>(
306 &self,
307 items: impl IntoIterator<Item = HashMap<String, AttributeValue>>,
308 ) -> Result<Vec<T>, DecryptError>
309 where
310 T: Decryptable + Identifiable,
311 {
312 let items = self
313 .unseal_all(items, UnsealSpec::new_for_decryptable::<T>())
314 .await?;
315
316 Ok(items
317 .into_iter()
318 .map(|x| x.into_value::<T>())
319 .collect::<Result<Vec<_>, _>>()?)
320 }
321
322 pub async fn decrypt<T>(&self, item: HashMap<String, AttributeValue>) -> Result<T, DecryptError>
323 where
324 T: Decryptable + Identifiable,
325 {
326 let uspec = UnsealSpec::new_for_decryptable::<T>();
327 let item = self.unseal(item, uspec).await?;
328
329 Ok(item.into_value()?)
330 }
331
332 pub async fn create_delete_patch(
333 &self,
334 delete: PreparedDelete,
335 ) -> Result<DynamoRecordPatch, DeleteError> {
336 let PrimaryKeyParts { pk, sk } = self.encrypt_primary_key_parts(delete.primary_key)?;
337
338 let delete_records = all_index_keys(&sk, delete.protected_indexes)
339 .into_iter()
340 .map(|x| Ok::<_, DeleteError>(b64_encode(hmac(&x, Some(pk.as_str()), &self.cipher)?)))
341 .chain([Ok(sk)])
342 .map(|sk| {
343 let sk = sk?;
344 Ok::<_, DeleteError>(PrimaryKeyParts { pk: pk.clone(), sk })
345 })
346 .collect::<Result<Vec<_>, _>>()?;
347
348 Ok(DynamoRecordPatch {
349 put_records: vec![],
350 delete_records,
351 })
352 }
353
354 pub async fn create_put_patch(
363 &self,
364 record: PreparedRecord,
365 index_predicate: impl FnMut(&AttributeName, &TableAttribute) -> bool,
367 ) -> Result<DynamoRecordPatch, PutError> {
368 let mut seen_sk = HashSet::new();
369
370 let PreparedRecord {
371 protected_attributes,
372 protected_indexes,
373 sealer,
374 } = record;
375
376 let sealed = sealer
377 .seal(protected_attributes, &self.cipher, DEFAULT_TERM_SIZE)
378 .await?;
379
380 let mut put_records = Vec::with_capacity(sealed.len());
381
382 let mut delete_records = vec![];
385
386 let PrimaryKeyParts { pk, sk } = sealed.primary_key();
387
388 let (root, index_entries) = sealed.into_table_entries(index_predicate);
389
390 seen_sk.insert(root.inner().sk.clone());
391 put_records.push(root.try_into()?);
392
393 for entry in index_entries.into_iter() {
394 seen_sk.insert(entry.inner().sk.clone());
395 put_records.push(entry.try_into()?);
396 }
397
398 for index_sk in all_index_keys(&sk, protected_indexes) {
399 let index_sk = b64_encode(hmac(&index_sk, Some(pk.as_str()), &self.cipher)?);
400
401 if seen_sk.contains(&index_sk) {
403 continue;
404 }
405
406 delete_records.push(PrimaryKeyParts {
407 pk: pk.clone(),
408 sk: index_sk,
409 });
410 }
411
412 Ok(DynamoRecordPatch {
413 put_records,
414 delete_records,
415 })
416 }
417
418 pub fn encrypt_primary_key_parts(
421 &self,
422 prepared_primary_key: PreparedPrimaryKey,
423 ) -> Result<PrimaryKeyParts, PrimaryKeyError> {
424 let PrimaryKeyParts { mut pk, mut sk } = prepared_primary_key.primary_key_parts;
425
426 if prepared_primary_key.is_pk_encrypted {
427 pk = b64_encode(hmac(&pk, None, &self.cipher)?);
428 }
429
430 if prepared_primary_key.is_sk_encrypted {
431 sk = b64_encode(hmac(&sk, Some(pk.as_str()), &self.cipher)?);
432 }
433
434 Ok(PrimaryKeyParts { pk, sk })
435 }
436}
437
438impl EncryptedTable<Dynamo> {
439 pub async fn init(
440 db: aws_sdk_dynamodb::Client,
441 table_name: impl Into<String>,
442 ) -> Result<Self, InitError> {
443 let table = EncryptedTable::init_headless().await?;
444
445 Ok(Self {
446 db: Dynamo {
447 table_name: table_name.into(),
448 db,
449 },
450 cipher: table.cipher,
451 })
452 }
453
454 pub async fn get<T>(&self, k: impl Into<T::PrimaryKey>) -> Result<Option<T>, GetError>
455 where
456 T: Decryptable + Identifiable,
457 {
458 let PrimaryKeyParts { pk, sk } =
459 self.encrypt_primary_key_parts(PreparedPrimaryKey::new::<T>(k))?;
460
461 let result = self
462 .db
463 .get_item()
464 .table_name(&self.db.table_name)
465 .key("pk", AttributeValue::S(pk))
466 .key("sk", AttributeValue::S(sk))
467 .send()
468 .await
469 .map_err(|e| GetError::Aws(format!("{e:?}")))?;
470
471 if let Some(item) = result.item {
472 Ok(Some(self.decrypt(item).await?))
473 } else {
474 Ok(None)
475 }
476 }
477
478 pub async fn delete<E: Searchable + Identifiable>(
479 &self,
480 k: impl Into<E::PrimaryKey>,
481 ) -> Result<(), DeleteError> {
482 let transact_items = self
483 .create_delete_patch(PreparedDelete::new::<E>(k))
484 .await?
485 .into_transact_write_items(&self.db.table_name)?;
486
487 for items in transact_items.chunks(100) {
489 self.db
490 .transact_write_items()
491 .set_transact_items(Some(items.to_vec()))
492 .send()
493 .await
494 .map_err(|e| DeleteError::Aws(format!("{e:?}")))?;
495 }
496
497 Ok(())
498 }
499
500 pub async fn put<T>(&self, record: T) -> Result<(), PutError>
501 where
502 T: Searchable + Identifiable,
503 {
504 let record = PreparedRecord::prepare_record(record)?;
505
506 let transact_items = self
507 .create_put_patch(
508 record,
509 |_, _| true,
511 )
512 .await?
513 .into_transact_write_items(&self.db.table_name)?;
514
515 for items in transact_items.chunks(100) {
517 self.db
518 .transact_write_items()
519 .set_transact_items(Some(items.to_vec()))
520 .send()
521 .await?;
522 }
523
524 Ok(())
525 }
526}