pub mod query;
mod table_entry;
pub use self::{
query::QueryBuilder,
table_entry::{TableAttribute, TableEntry, TryFromTableAttr},
};
use crate::{
crypto::*,
errors::*,
traits::{Decryptable, PrimaryKey, PrimaryKeyError, PrimaryKeyParts, Searchable},
Identifiable, IndexType,
};
use aws_sdk_dynamodb::types::{AttributeValue, Delete, Put, TransactWriteItem};
use cipherstash_client::{
config::{console_config::ConsoleConfig, zero_kms_config::ZeroKMSConfig},
credentials::{
auto_refresh::AutoRefresh,
service_credentials::{ServiceCredentials, ServiceToken},
Credentials,
},
encryption::{Encryption, Plaintext},
zero_kms::ZeroKMS,
};
use log::info;
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
ops::Deref,
};
/// Marker backend for an [`EncryptedTable`] that is not attached to DynamoDB.
///
/// A headless table still exposes the backend-independent operations
/// (unsealing, decrypting and building put/delete patches), but the
/// request-issuing methods `get`, `put` and `delete` are only implemented for
/// `EncryptedTable<Dynamo>`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Headless;
/// DynamoDB backend for an [`EncryptedTable`]: an SDK client together with the
/// name of the table that requests are sent to.
pub struct Dynamo {
    /// Client used to issue DynamoDB requests.
    pub(crate) db: aws_sdk_dynamodb::Client,
    /// Table name passed to `get_item` and to transaction item builders.
    pub(crate) table_name: String,
}

/// Exposes the wrapped client directly, so `self.db.get_item()` and similar
/// calls work without naming the inner `db` field.
impl Deref for Dynamo {
    type Target = aws_sdk_dynamodb::Client;

    fn deref(&self) -> &aws_sdk_dynamodb::Client {
        let Self { db, .. } = self;
        db
    }
}
/// A table whose protected attributes are encrypted client-side.
///
/// `D` selects the storage backend: [`Dynamo`] (the default) for a table backed
/// by a DynamoDB client, or [`Headless`] when only the cryptographic operations
/// are needed.
pub struct EncryptedTable<D = Dynamo> {
    db: D,
    cipher: Box<Encryption<AutoRefresh<ServiceCredentials>>>,
}

impl<D> EncryptedTable<D> {
    /// Borrows the encryption client this table uses to seal and unseal records.
    pub fn cipher(&self) -> &Encryption<impl Credentials<Token = ServiceToken>> {
        &*self.cipher
    }
}
impl EncryptedTable<Headless> {
    /// Builds a table that can encrypt and decrypt but is not attached to DynamoDB.
    ///
    /// Both the console and ZeroKMS configuration builders are populated via
    /// `with_env()`, and decryption logging is switched on. The dataset's index
    /// root key is fetched from ZeroKMS before the cipher is constructed.
    ///
    /// # Errors
    ///
    /// Returns [`InitError`] if either configuration fails to build or the
    /// dataset config cannot be loaded.
    pub async fn init_headless() -> Result<Self, InitError> {
        info!("Initializing...");

        let console_config = ConsoleConfig::builder().with_env().build()?;

        let kms_config = ZeroKMSConfig::builder()
            .decryption_log(true)
            .with_env()
            .console_config(&console_config)
            .build_with_client_key()?;

        // Arguments are evaluated in the same order as the constructor expects them.
        let base_url = kms_config.base_url();
        let credentials = AutoRefresh::new(kms_config.credentials());
        let decryption_log_path = kms_config.decryption_log_path();
        let kms = ZeroKMS::new_with_client_key(
            &base_url,
            credentials,
            decryption_log_path.as_deref(),
            kms_config.client_key(),
        );

        info!("Fetching dataset config...");
        let index_root_key = kms.load_dataset_config().await?.index_root_key;
        let cipher = Box::new(Encryption::new(index_root_key, kms));

        info!("Ready!");
        Ok(Self {
            cipher,
            db: Headless,
        })
    }
}
/// A batch of DynamoDB writes, as produced by `EncryptedTable::create_put_patch`
/// or `EncryptedTable::create_delete_patch`.
///
/// Turn it into transaction items with
/// [`DynamoRecordPatch::into_transact_write_items`].
pub struct DynamoRecordPatch {
    /// Items to write with `Put`, already converted to DynamoDB attribute maps.
    pub put_records: Vec<HashMap<String, AttributeValue>>,
    /// `pk`/`sk` pairs identifying items to remove with `Delete`.
    pub delete_records: Vec<PrimaryKeyParts>,
}
/// A record split into the pieces needed for sealing, typically built with
/// `PreparedRecord::prepare_record` and consumed by
/// `EncryptedTable::create_put_patch`.
pub struct PreparedRecord {
    /// `(index name, index type)` for each of the record's protected indexes.
    protected_indexes: Cow<'static, [(Cow<'static, str>, IndexType)]>,
    /// Attribute names handed to `Sealer::seal` as the protected set.
    protected_attributes: Cow<'static, [Cow<'static, str>]>,
    /// Primary key, unsealed attributes and unsealed index values awaiting sealing.
    sealer: Sealer,
}
/// Everything needed to delete a record and its protected index entries;
/// consumed by `EncryptedTable::create_delete_patch`.
pub struct PreparedDelete {
    primary_key: PreparedPrimaryKey,
    protected_indexes: Cow<'static, [(Cow<'static, str>, IndexType)]>,
}

impl PreparedDelete {
    /// Prepares a delete for the record of type `S` identified by `k`.
    ///
    /// The key is split into parts using `S`'s type name and sort key prefix.
    pub fn new<S: Searchable>(k: impl Into<S::PrimaryKey>) -> Self {
        let key: S::PrimaryKey = k.into();
        let type_name = S::type_name();
        let sort_key_prefix = S::sort_key_prefix();
        let parts = key.into_parts(&type_name, sort_key_prefix.as_deref());
        Self::new_from_parts::<S>(parts)
    }

    /// Prepares a delete for type `S` from primary key parts that were already split.
    pub fn new_from_parts<S: Searchable>(k: PrimaryKeyParts) -> Self {
        Self {
            primary_key: PreparedPrimaryKey::new_from_parts::<S>(k),
            protected_indexes: S::protected_indexes(),
        }
    }

    /// Returns a clone of the prepared primary key (before any encryption is applied).
    pub fn prepared_primary_key(&self) -> PreparedPrimaryKey {
        PreparedPrimaryKey::clone(&self.primary_key)
    }
}
impl PreparedRecord {
    /// Assembles a prepared record from its protected index layout, protected
    /// attribute names and sealer.
    pub(crate) fn new(
        protected_indexes: Cow<'static, [(Cow<'static, str>, IndexType)]>,
        protected_attributes: Cow<'static, [Cow<'static, str>]>,
        sealer: Sealer,
    ) -> Self {
        Self {
            protected_indexes,
            protected_attributes,
            sealer,
        }
    }

    /// Iterates over the record's protected attributes as `(name, plaintext)` pairs.
    pub fn protected(&self) -> impl Iterator<Item = (&str, &Plaintext)> {
        self.sealer
            .unsealed
            .protected()
            .iter()
            .map(|(key, (plaintext, _descriptor))| (key.as_str(), plaintext))
    }

    /// Iterates over the record's unprotected attributes as `(name, attribute)` pairs.
    pub fn unprotected(&self) -> impl Iterator<Item = (&str, &TableAttribute)> {
        self.sealer
            .unsealed
            .unprotected()
            .iter()
            .map(|(key, attr)| (key.as_str(), attr))
    }

    /// Splits `record` into a [`PreparedRecord`] that is ready to be sealed.
    ///
    /// The primary key is split using `R`'s type name and optional sort key
    /// prefix. For each of `R`'s protected indexes, the record's attribute and
    /// the matching index definition are collected.
    ///
    /// # Errors
    ///
    /// Returns [`SealError::MissingAttribute`] carrying the index name when the
    /// record provides no attribute for a protected index, or when
    /// `R::index_by_name` returns nothing for it.
    pub fn prepare_record<R>(record: R) -> Result<Self, SealError>
    where
        R: Searchable + Identifiable,
    {
        let type_name = R::type_name();
        let PrimaryKeyParts { pk, sk } = record
            .get_primary_key()
            .into_parts(&type_name, R::sort_key_prefix().as_deref());
        let protected_indexes = R::protected_indexes();
        let protected_attributes = R::protected_attributes();
        let unsealed_indexes = protected_indexes
            .iter()
            .map(|(index_name, index_type)| {
                record
                    .attribute_for_index(index_name, *index_type)
                    .and_then(|attr| {
                        R::index_by_name(index_name, *index_type)
                            .map(|index| (attr, index, index_name.clone(), *index_type))
                    })
                    // Build the error lazily: `ok_or` would allocate the index
                    // name string for every index, even on success.
                    .ok_or_else(|| SealError::MissingAttribute(index_name.to_string()))
            })
            .collect::<Result<Vec<_>, _>>()?;
        // Consumes the record; must come after every borrow of `record` above.
        let unsealed = record.into_unsealed();
        let sealer = Sealer {
            pk,
            sk,
            is_sk_encrypted: R::is_sk_encrypted(),
            is_pk_encrypted: R::is_pk_encrypted(),
            type_name,
            unsealed_indexes,
            unsealed,
        };
        Ok(PreparedRecord::new(
            protected_indexes,
            protected_attributes,
            sealer,
        ))
    }

    /// Returns a copy of the primary key parts held by the sealer. For records
    /// built by [`PreparedRecord::prepare_record`], these are the parts before
    /// any encryption.
    pub fn primary_key_parts(&self) -> PrimaryKeyParts {
        PrimaryKeyParts {
            pk: self.sealer.pk.clone(),
            sk: self.sealer.sk.clone(),
        }
    }

    /// Returns the type name recorded in the sealer.
    pub fn type_name(&self) -> &str {
        &self.sealer.type_name
    }
}
impl DynamoRecordPatch {
    /// Converts the patch into `TransactWriteItems` entries that target `table_name`.
    ///
    /// All puts are emitted first, followed by all deletes. Deletes address
    /// items by their `pk` and `sk` string attributes.
    ///
    /// # Errors
    ///
    /// Returns a [`BuildError`] as soon as building any `Put` or `Delete` fails.
    pub fn into_transact_write_items(
        self,
        table_name: &str,
    ) -> Result<Vec<TransactWriteItem>, BuildError> {
        let Self {
            put_records,
            delete_records,
        } = self;

        let puts = put_records
            .into_iter()
            .map(|item| -> Result<TransactWriteItem, BuildError> {
                let put = Put::builder()
                    .table_name(table_name)
                    .set_item(Some(item))
                    .build()?;
                Ok(TransactWriteItem::builder().put(put).build())
            });

        let deletes = delete_records.into_iter().map(
            |PrimaryKeyParts { pk, sk }| -> Result<TransactWriteItem, BuildError> {
                let delete = Delete::builder()
                    .table_name(table_name)
                    .key("pk", AttributeValue::S(pk))
                    .key("sk", AttributeValue::S(sk))
                    .build()?;
                Ok(TransactWriteItem::builder().delete(delete).build())
            },
        );

        puts.chain(deletes).collect()
    }
}
impl<D> EncryptedTable<D> {
pub fn query<S>(&self) -> QueryBuilder<S, &Self>
where
S: Searchable,
{
QueryBuilder::with_backend(self)
}
pub async fn unseal_all(
&self,
items: impl IntoIterator<Item = HashMap<String, AttributeValue>>,
spec: UnsealSpec<'_>,
) -> Result<Vec<Unsealed>, DecryptError> {
let table_entries = SealedTableEntry::vec_from(items)?;
let results = SealedTableEntry::unseal_all(table_entries, spec, &self.cipher).await?;
Ok(results)
}
pub async fn unseal(
&self,
item: HashMap<String, AttributeValue>,
spec: UnsealSpec<'_>,
) -> Result<Unsealed, DecryptError> {
let table_entry = SealedTableEntry::try_from(item)?;
let result = table_entry.unseal(spec, &self.cipher).await?;
Ok(result)
}
pub async fn decrypt_all<T: Decryptable>(
&self,
items: impl IntoIterator<Item = HashMap<String, AttributeValue>>,
) -> Result<Vec<T>, DecryptError> {
let items = self
.unseal_all(items, UnsealSpec::new_for_decryptable::<T>())
.await?;
Ok(items
.into_iter()
.map(|x| x.into_value::<T>())
.collect::<Result<Vec<_>, _>>()?)
}
pub async fn decrypt<T: Decryptable>(
&self,
item: HashMap<String, AttributeValue>,
) -> Result<T, DecryptError> {
let item = self
.unseal(item, UnsealSpec::new_for_decryptable::<T>())
.await?;
Ok(item.into_value()?)
}
pub async fn create_delete_patch(
&self,
delete: PreparedDelete,
) -> Result<DynamoRecordPatch, DeleteError> {
let PrimaryKeyParts { pk, sk } = self.encrypt_primary_key_parts(delete.primary_key)?;
let delete_records = all_index_keys(&sk, delete.protected_indexes)
.into_iter()
.map(|x| Ok::<_, DeleteError>(b64_encode(hmac(&x, Some(pk.as_str()), &self.cipher)?)))
.chain([Ok(sk)])
.map(|sk| {
let sk = sk?;
Ok::<_, DeleteError>(PrimaryKeyParts { pk: pk.clone(), sk })
})
.collect::<Result<Vec<_>, _>>()?;
Ok(DynamoRecordPatch {
put_records: vec![],
delete_records,
})
}
pub async fn create_put_patch(
&self,
record: PreparedRecord,
index_predicate: impl FnMut(&str, &TableAttribute) -> bool,
) -> Result<DynamoRecordPatch, PutError> {
let mut seen_sk = HashSet::new();
let PreparedRecord {
protected_attributes,
protected_indexes,
sealer,
} = record;
let sealed = sealer.seal(protected_attributes, &self.cipher, 12).await?;
let mut put_records = Vec::with_capacity(sealed.len());
let mut delete_records = vec![];
let PrimaryKeyParts { pk, sk } = sealed.primary_key();
let (root, index_entries) = sealed.into_table_entries(index_predicate);
seen_sk.insert(root.inner().sk.clone());
put_records.push(root.try_into()?);
for entry in index_entries.into_iter() {
seen_sk.insert(entry.inner().sk.clone());
put_records.push(entry.try_into()?);
}
for index_sk in all_index_keys(&sk, protected_indexes) {
let index_sk = b64_encode(hmac(&index_sk, Some(pk.as_str()), &self.cipher)?);
if seen_sk.contains(&index_sk) {
continue;
}
delete_records.push(PrimaryKeyParts {
pk: pk.clone(),
sk: index_sk,
});
}
Ok(DynamoRecordPatch {
put_records,
delete_records,
})
}
pub fn encrypt_primary_key_parts(
&self,
prepared_primary_key: PreparedPrimaryKey,
) -> Result<PrimaryKeyParts, PrimaryKeyError> {
let PrimaryKeyParts { mut pk, mut sk } = prepared_primary_key.primary_key_parts;
if prepared_primary_key.is_pk_encrypted {
pk = b64_encode(hmac(&pk, None, &self.cipher)?);
}
if prepared_primary_key.is_sk_encrypted {
sk = b64_encode(hmac(&sk, Some(pk.as_str()), &self.cipher)?);
}
Ok(PrimaryKeyParts { pk, sk })
}
}
impl EncryptedTable<Dynamo> {
pub async fn init(
db: aws_sdk_dynamodb::Client,
table_name: impl Into<String>,
) -> Result<Self, InitError> {
let table = EncryptedTable::init_headless().await?;
Ok(Self {
db: Dynamo {
table_name: table_name.into(),
db,
},
cipher: table.cipher,
})
}
pub async fn get<T>(&self, k: impl Into<T::PrimaryKey>) -> Result<Option<T>, GetError>
where
T: Decryptable + Identifiable,
{
let PrimaryKeyParts { pk, sk } =
self.encrypt_primary_key_parts(PreparedPrimaryKey::new::<T>(k))?;
let result = self
.db
.get_item()
.table_name(&self.db.table_name)
.key("pk", AttributeValue::S(pk))
.key("sk", AttributeValue::S(sk))
.send()
.await
.map_err(|e| GetError::Aws(format!("{e:?}")))?;
if let Some(item) = result.item {
Ok(Some(self.decrypt(item).await?))
} else {
Ok(None)
}
}
pub async fn delete<E: Searchable + Identifiable>(
&self,
k: impl Into<E::PrimaryKey>,
) -> Result<(), DeleteError> {
let transact_items = self
.create_delete_patch(PreparedDelete::new::<E>(k))
.await?
.into_transact_write_items(&self.db.table_name)?;
for items in transact_items.chunks(100) {
self.db
.transact_write_items()
.set_transact_items(Some(items.to_vec()))
.send()
.await
.map_err(|e| DeleteError::Aws(format!("{e:?}")))?;
}
Ok(())
}
pub async fn put<T>(&self, record: T) -> Result<(), PutError>
where
T: Searchable + Identifiable,
{
let record = PreparedRecord::prepare_record(record)?;
let transact_items = self
.create_put_patch(
record,
|_, _| true,
)
.await?
.into_transact_write_items(&self.db.table_name)?;
for items in transact_items.chunks(100) {
self.db
.transact_write_items()
.set_transact_items(Some(items.to_vec()))
.send()
.await
.map_err(|e| PutError::Aws(format!("{e:?}")))?;
}
Ok(())
}
}