use crate::zerokms::{
self, DataKeyWithTag, Decryptable, EncryptPayload, EncryptedRecord, Error, GenerateKeyPayload,
IndexKey, RecordDecryptError, ZeroKMSWithClientKey,
};
use stack_auth::AuthStrategy;
use std::{borrow::Cow, fmt::Debug, sync::Arc};
use uuid::Uuid;
use zeroize::{Zeroize, ZeroizeOnDrop};
use zerokms_protocol::{IdentifiedBy, UnverifiedContext};
use crate::credentials::ServiceToken;
use super::{
compound_indexer::{Accumulator, ComposableIndex, ComposablePlaintext, CompoundIndex},
EncryptionError, IndexTerm,
};
/// Length passed to `truncate` for every compound index term and compound query term
/// produced by `ScopedCipher::compound_index` / `ScopedCipher::compound_query`.
const DEFAULT_INDEX_TERM_SIZE: usize = 12;
/// A cipher bound to a single ZeroKMS keyset.
///
/// Holds the keyset's index key locally (for deriving index terms) and a shared client
/// for remote encrypt/decrypt/data-key operations. Only `index_key` is wiped on drop;
/// the client handle and keyset id are excluded via `#[zeroize(skip)]`.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct ScopedCipher<C> {
    /// Shared ZeroKMS client used for all remote calls (not zeroized).
    #[zeroize(skip)]
    client: Arc<ZeroKMSWithClientKey<C>>,
    /// Identifier of the keyset this cipher is scoped to (not zeroized).
    #[zeroize(skip)]
    keyset_id: Uuid,
    /// Index key material used to derive index terms; zeroized on drop.
    index_key: IndexKey,
}
/// Redacting `Debug`: only the keyset id is printed.
///
/// The client handle and the index key (which the struct zeroizes on drop, i.e. treats
/// as sensitive) are intentionally omitted. The output ends with `..` so readers of
/// logs can tell that fields were hidden rather than absent.
impl<C> Debug for ScopedCipher<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScopedCipher")
            .field("keyset_id", &self.keyset_id)
            // `..` marks the deliberately redacted `client` and `index_key` fields.
            .finish_non_exhaustive()
    }
}
/// Keyset-scoped operations.
///
/// Index-term derivation (`mac`, `compound_index`, `compound_query`) runs locally with
/// the loaded index key. Encryption, decryption and data-key generation are forwarded to
/// the shared ZeroKMS client.
impl<C> ScopedCipher<C> {
    /// Initialises a cipher without naming a keyset.
    ///
    /// Equivalent to [`ScopedCipher::init`] with `keyset_id = None`. Which keyset is
    /// then selected is decided by the client's `load_keyset_index_key`.
    ///
    /// # Errors
    /// Propagates any [`zerokms::Error`] from loading the keyset's index key.
    pub async fn init_default(client: Arc<ZeroKMSWithClientKey<C>>) -> Result<Self, zerokms::Error>
    where
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        Self::init(client, None).await
    }

    /// Loads the index key for `keyset_id` and builds a cipher bound to that keyset.
    ///
    /// The stored keyset id is the one returned by the client, not the (possibly `None`)
    /// identifier passed in.
    ///
    /// # Errors
    /// Propagates any [`zerokms::Error`] from `load_keyset_index_key`.
    pub async fn init(
        client: Arc<ZeroKMSWithClientKey<C>>,
        keyset_id: Option<IdentifiedBy>,
    ) -> Result<Self, zerokms::Error>
    where
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        client
            .load_keyset_index_key(keyset_id)
            .await
            .map(|(keyset_id, index_key)| Self {
                client,
                keyset_id,
                index_key,
            })
    }

    /// Computes an `N`-byte keyed BLAKE3 MAC over `prefix` (if any) followed by `value`,
    /// keyed with this keyset's index key.
    ///
    /// The output is read from BLAKE3's extendable output, so any `N` is accepted.
    /// Both the hasher and the XOF reader are zeroized before returning.
    ///
    /// NOTE(review): `prefix` and `value` are fed back-to-back with no separator or
    /// length framing. As a result `(Some("ab"), "c")` and `(Some("a"), "bc")` produce
    /// the same MAC, as do `None` and `Some("")`. This is left unchanged because fixing
    /// it would alter every previously computed MAC.
    pub fn mac<const N: usize>(&self, value: &str, prefix: Option<&str>) -> [u8; N] {
        let mut hasher = blake3::Hasher::new_keyed(self.index_key.key());
        if let Some(prefix) = prefix {
            hasher.update(prefix.as_bytes());
        }
        hasher.update(value.as_bytes());
        let mut result = [0; N];
        // Bind the reader so it can be wiped: it is produced from the keyed hasher and
        // carries state derived from it. Previously it was an un-zeroized temporary.
        let mut reader = hasher.finalize_xof();
        reader.fill(&mut result);
        reader.zeroize();
        hasher.zeroize();
        result
    }

    /// Builds a stored index term for `plaintext` using the compound `index` definition.
    ///
    /// `info` salts the accumulator. The composed term is passed through
    /// `truncate(DEFAULT_INDEX_TERM_SIZE)` before conversion to an [`IndexTerm`].
    ///
    /// # Errors
    /// Returns an [`EncryptionError`] if composing or truncating the term fails.
    pub fn compound_index<I>(
        &self,
        index: I,
        plaintext: ComposablePlaintext,
        info: String,
    ) -> Result<IndexTerm, EncryptionError>
    where
        I: ComposableIndex + Send,
    {
        let index = CompoundIndex::new(index);
        let accumulator = Accumulator::from_salt(info);
        let term = index
            .compose_index(&self.index_key, plaintext, accumulator)?
            .truncate(DEFAULT_INDEX_TERM_SIZE)?;
        Ok(term.into())
    }

    /// Builds a query term for `input` using the compound `index` definition.
    ///
    /// It uses the same `info` salting and truncation as [`ScopedCipher::compound_index`],
    /// presumably so query terms can be compared against stored index terms.
    ///
    /// # Errors
    /// Returns an [`EncryptionError`] in any of these cases:
    /// - composing the query fails;
    /// - the composed query does not reduce to a single term (`exactly_one`);
    /// - truncating or converting the term fails.
    pub fn compound_query<I>(
        &self,
        index: I,
        input: ComposablePlaintext,
        info: String,
    ) -> Result<IndexTerm, EncryptionError>
    where
        I: ComposableIndex + Send,
    {
        let index = CompoundIndex::new(index);
        let accumulator = Accumulator::from_salt(info);
        let term = index
            .compose_query(&self.index_key, input, accumulator)?
            .exactly_one()?
            .truncate(DEFAULT_INDEX_TERM_SIZE)?;
        Ok(term.try_into()?)
    }

    /// Encrypts `payloads` under this cipher's keyset.
    ///
    /// # Errors
    /// Propagates any [`Error`] from the client's `encrypt` call.
    pub async fn encrypt(
        &self,
        payloads: impl IntoIterator<Item = EncryptPayload<'_>>,
    ) -> Result<Vec<EncryptedRecord>, Error>
    where
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        self.client.encrypt(payloads, Some(self.keyset_id)).await
    }

    /// Decrypts `payloads` as a batch, returning one plaintext per payload.
    ///
    /// Uses `opts.keyset_id` when set, otherwise this cipher's keyset. Because the return
    /// type is a single `Result`, any failure is reported for the whole batch; see
    /// [`ScopedCipher::decrypt_fallible`] for per-record results.
    ///
    /// # Errors
    /// Propagates any [`Error`] from the client's `decrypt` call.
    pub async fn decrypt<'a, D>(
        &self,
        payloads: impl IntoIterator<Item = D>,
        opts: &DecryptOptions<'a>,
    ) -> Result<Vec<Vec<u8>>, Error>
    where
        D: Decryptable,
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        self.client
            .decrypt(
                payloads,
                opts.keyset_id.or(Some(self.keyset_id)),
                opts.service_token.clone(),
                opts.unverified_context.as_deref(),
            )
            .await
    }

    /// Decrypts `payloads`, returning a separate `Result` for each record.
    ///
    /// The outer `Err` presumably covers request-level failures, while per-record
    /// failures surface as [`RecordDecryptError`].
    ///
    /// NOTE(review): unlike [`ScopedCipher::decrypt`], `opts.keyset_id` is NOT forwarded.
    /// The underlying `decrypt_fallible` call takes no keyset argument. Confirm that this
    /// is intended.
    ///
    /// # Errors
    /// Propagates any [`Error`] from the client's `decrypt_fallible` call.
    pub async fn decrypt_fallible<'a, D>(
        &self,
        payloads: impl IntoIterator<Item = D>,
        opts: &DecryptOptions<'a>,
    ) -> Result<Vec<Result<Vec<u8>, RecordDecryptError>>, Error>
    where
        D: Decryptable,
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        // Re-borrow rather than clone so owned tokens/contexts are not copied.
        let unverified_context = opts.unverified_context.as_deref().map(Cow::Borrowed);
        let service_token = opts.service_token.as_deref().map(Cow::Borrowed);
        self.client
            .decrypt_fallible(payloads, service_token, unverified_context)
            .await
    }

    /// Borrows the index key. This exposes key material, so avoid logging or persisting it.
    pub fn index_key(&self) -> &IndexKey {
        &self.index_key
    }

    /// Returns the id of the keyset this cipher is bound to.
    pub fn keyset_id(&self) -> Uuid {
        self.keyset_id
    }

    /// Generates data keys under this cipher's keyset.
    ///
    /// `service_token` and `unverified_context` are forwarded to the client unchanged.
    ///
    /// # Errors
    /// Propagates any [`Error`] from the client's `generate_data_keys` call.
    // No in-crate caller at present (hence the lint allowance); kept as a crate-internal
    // entry point. TODO(review): remove the allow once a caller exists.
    #[allow(dead_code)]
    pub(crate) async fn generate_data_keys<'a>(
        &self,
        payloads: impl IntoIterator<Item = GenerateKeyPayload<'_>>,
        service_token: Option<Cow<'a, ServiceToken>>,
        unverified_context: Option<Cow<'a, UnverifiedContext>>,
    ) -> Result<Vec<DataKeyWithTag>, Error>
    where
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        self.client
            .generate_data_keys(
                payloads,
                Some(self.keyset_id),
                service_token,
                unverified_context,
            )
            .await
    }
}
/// Per-call options for `ScopedCipher::decrypt` / `ScopedCipher::decrypt_fallible`.
///
/// `Default` leaves every option unset.
#[derive(Debug, Default)]
pub struct DecryptOptions<'a> {
    /// Keyset override. `None` makes `ScopedCipher::decrypt` use the cipher's own keyset.
    /// (`decrypt_fallible` does not currently forward this value.)
    pub keyset_id: Option<Uuid>,
    /// Optional service token passed through to the ZeroKMS client.
    pub service_token: Option<Cow<'a, ServiceToken>>,
    /// Optional unverified context passed through to the ZeroKMS client.
    pub unverified_context: Option<Cow<'a, UnverifiedContext>>,
}