cipherstash-client 0.34.1-alpha.1

The official CipherStash SDK
Documentation
use crate::zerokms::{
    self, DataKeyWithTag, Decryptable, EncryptPayload, EncryptedRecord, Error, GenerateKeyPayload,
    IndexKey, RecordDecryptError, ZeroKMSWithClientKey,
};
use stack_auth::AuthStrategy;
use std::{borrow::Cow, fmt::Debug, sync::Arc};
use uuid::Uuid;
use zeroize::{Zeroize, ZeroizeOnDrop};
use zerokms_protocol::{IdentifiedBy, UnverifiedContext};

use crate::credentials::ServiceToken;

use super::{
    compound_indexer::{Accumulator, ComposableIndex, ComposablePlaintext, CompoundIndex},
    EncryptionError, IndexTerm,
};

/// Length passed to `truncate` for every compound index and query term produced by
/// `ScopedCipher` (presumably a byte count — confirm against `truncate`).
const DEFAULT_INDEX_TERM_SIZE: usize = 12;

/// A Scoped Cipher is one which has been initialized for a specific keyset.
///
/// Encryption and data-key generation always target that keyset. `decrypt` also uses it
/// unless the caller overrides it through `DecryptOptions::keyset_id`.
///
/// On drop only `index_key` is zeroized; `client` and `keyset_id` are marked
/// `#[zeroize(skip)]`.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct ScopedCipher<C> {
    /// Shared ZeroKMS client that encryption, decryption and key generation are delegated to.
    #[zeroize(skip)]
    client: Arc<ZeroKMSWithClientKey<C>>,
    /// Id of the keyset this cipher was initialized for.
    #[zeroize(skip)]
    keyset_id: Uuid,
    /// Index key loaded for the keyset; keys the MACs and compound index/query terms.
    index_key: IndexKey,
}

impl<C> Debug for ScopedCipher<C> {
    /// Writes `ScopedCipher { keyset_id: .. }`.
    ///
    /// Only the keyset id is printed; the client handle and `index_key` are never formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ScopedCipher");
        out.field("keyset_id", &self.keyset_id);
        out.finish()
    }
}

impl<C> ScopedCipher<C> {
    /// Initialize a new ScopedCipher for the default keyset_id.
    ///
    /// Equivalent to [`ScopedCipher::init`] with `None`.
    ///
    /// # Errors
    ///
    /// Returns any [`zerokms::Error`] raised while loading the keyset's index key.
    pub async fn init_default(client: Arc<ZeroKMSWithClientKey<C>>) -> Result<Self, zerokms::Error>
    where
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        Self::init(client, None).await
    }

    /// Initialize a new ScopedCipher for the given keyset_id.
    /// If the keyset_id is None, the ScopedCipher will be initialized with the default keyset_id for the client.
    ///
    /// The keyset id stored on the cipher is the one returned by
    /// `load_keyset_index_key`, together with that keyset's index key.
    ///
    /// # Errors
    ///
    /// Propagates the [`zerokms::Error`] from `load_keyset_index_key`.
    pub async fn init(
        client: Arc<ZeroKMSWithClientKey<C>>,
        keyset_id: Option<IdentifiedBy>,
    ) -> Result<Self, zerokms::Error>
    where
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        client
            .load_keyset_index_key(keyset_id)
            .await
            .map(|(keyset_id, index_key)| Self {
                client,
                keyset_id,
                index_key,
            })
    }

    /// This value is used for term index keys and "encrypted" partition / sort keys
    ///
    /// Computes an `N`-byte keyed BLAKE3 output (via the XOF) over `prefix` (if any)
    /// followed by `value`, keyed with this cipher's index key.
    ///
    /// Note: `prefix` and `value` are fed to the hasher back to back with no separator or
    /// length, so e.g. (`"a"`, `Some("b")`) and (`"ab"`, `None`) produce the same output.
    pub fn mac<const N: usize>(&self, value: &str, prefix: Option<&str>) -> [u8; N] {
        // TODO: Make sure this uses an unambiguous encoding scheme
        // see https://linear.app/cipherstash/issue/BUG-95/wip-ensure-unambiguous-coding-for-mac-prefixes
        let mut hasher = blake3::Hasher::new_keyed(self.index_key.key());
        if let Some(prefix) = prefix {
            hasher.update(prefix.as_bytes());
        }
        hasher.update(value.as_bytes());

        let mut result = [0; N];
        // The XOF reader can keep producing output that depends on the key, so it carries
        // key-derived state: wipe it as well as the hasher instead of letting it drop unwiped.
        let mut reader = hasher.finalize_xof();
        reader.fill(&mut result);
        reader.zeroize();
        hasher.zeroize();
        result
    }

    /// Builds the compound index term for `plaintext` using `index`.
    ///
    /// The accumulator is salted with `info`, and the composed term is truncated to
    /// `DEFAULT_INDEX_TERM_SIZE`.
    ///
    /// # Errors
    ///
    /// Returns an [`EncryptionError`] if composing the index or truncating the term fails.
    pub fn compound_index<I>(
        &self,
        index: I,
        plaintext: ComposablePlaintext,
        info: String,
    ) -> Result<IndexTerm, EncryptionError>
    where
        I: ComposableIndex + Send,
    {
        let index = CompoundIndex::new(index);
        let accumulator = Accumulator::from_salt(info);

        let term = index
            .compose_index(&self.index_key, plaintext, accumulator)?
            .truncate(DEFAULT_INDEX_TERM_SIZE)?;

        Ok(term.into())
    }

    /// Builds the query term for `input` using `index`, salted with `info`.
    ///
    /// The query composition must yield exactly one term (`exactly_one`). That term is
    /// truncated to `DEFAULT_INDEX_TERM_SIZE` and converted into an [`IndexTerm`].
    ///
    /// # Errors
    ///
    /// Returns an [`EncryptionError`] if composition fails, does not yield exactly one
    /// term, truncation fails, or the term cannot be converted.
    pub fn compound_query<I>(
        &self,
        index: I,
        input: ComposablePlaintext,
        info: String,
    ) -> Result<IndexTerm, EncryptionError>
    where
        I: ComposableIndex + Send,
    {
        let index = CompoundIndex::new(index);
        let accumulator = Accumulator::from_salt(info);

        let term = index
            .compose_query(&self.index_key, input, accumulator)?
            .exactly_one()?
            .truncate(DEFAULT_INDEX_TERM_SIZE)?;

        Ok(term.try_into()?)
    }

    /// Encrypt a stream of [`EncryptPayload`] and return them as an [`EncryptedRecord`].
    /// This function wraps the [`ZeroKMSWithClientKey::encrypt`] function but with the keyset_id set.
    pub async fn encrypt(
        &self,
        payloads: impl IntoIterator<Item = EncryptPayload<'_>>,
    ) -> Result<Vec<EncryptedRecord>, Error>
    where
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        self.client.encrypt(payloads, Some(self.keyset_id)).await
    }

    /// Decrypt a stream of encrypted values (of type `D` where `D: Decryptable`) and return the raw decrypted binary
    /// blob.
    ///
    /// `opts.keyset_id` takes precedence when set; otherwise this cipher's keyset id is used.
    /// The service token and unverified context from `opts` are forwarded to the client.
    pub async fn decrypt<'a, D>(
        &self,
        payloads: impl IntoIterator<Item = D>,
        opts: &DecryptOptions<'a>,
    ) -> Result<Vec<Vec<u8>>, Error>
    where
        D: Decryptable,
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        self.client
            .decrypt(
                payloads,
                opts.keyset_id.or(Some(self.keyset_id)),
                opts.service_token.clone(),
                opts.unverified_context.as_deref(),
            )
            .await
    }

    /// Decrypt a stream of encrypted values, reporting failures per record.
    ///
    /// Each element of the returned vector is either the decrypted bytes or a
    /// [`RecordDecryptError`] for that record. The outer [`Error`] is whatever the client's
    /// `decrypt_fallible` call returns for the call as a whole.
    ///
    /// NOTE(review): unlike [`ScopedCipher::decrypt`], neither `opts.keyset_id` nor this
    /// cipher's keyset id is passed to the client here — confirm this is intentional.
    pub async fn decrypt_fallible<'a, D>(
        &self,
        payloads: impl IntoIterator<Item = D>,
        opts: &DecryptOptions<'a>,
    ) -> Result<Vec<Result<Vec<u8>, RecordDecryptError>>, Error>
    where
        D: Decryptable,
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        // Re-borrow the options as `Cow::Borrowed` so nothing is cloned.
        let unverified_context = opts.unverified_context.as_deref().map(Cow::Borrowed);
        let service_token = opts.service_token.as_deref().map(Cow::Borrowed);
        self.client
            .decrypt_fallible(payloads, service_token, unverified_context)
            .await
    }

    /// Returns the index key loaded for this cipher's keyset.
    // TODO: Dan to make this pub(crate) again
    pub fn index_key(&self) -> &IndexKey {
        &self.index_key
    }

    /// Returns the id of the keyset this cipher was initialized for.
    pub fn keyset_id(&self) -> Uuid {
        self.keyset_id
    }

    /// Generate data keys for a stream of [`GenerateKeyPayload`] and return them as a [`DataKeyWithTag`].
    /// Scoped to the keyset_id of this [`ScopedCipher`].
    // NOTE(review): `dead_code` is allowed presumably because nothing in the crate calls this
    // yet — confirm and record why it is kept.
    #[allow(dead_code)]
    pub(crate) async fn generate_data_keys<'a>(
        &self,
        payloads: impl IntoIterator<Item = GenerateKeyPayload<'_>>,
        service_token: Option<Cow<'a, ServiceToken>>,
        unverified_context: Option<Cow<'a, UnverifiedContext>>,
    ) -> Result<Vec<DataKeyWithTag>, Error>
    where
        C: Send + Sync + 'static,
        for<'b> &'b C: AuthStrategy,
    {
        self.client
            .generate_data_keys(
                payloads,
                Some(self.keyset_id),
                service_token,
                unverified_context,
            )
            .await
    }
}

/// Per-call options accepted by `ScopedCipher::decrypt` and `ScopedCipher::decrypt_fallible`.
///
/// `Default` leaves every field as `None`.
// NOTE(review): the derived `Debug` prints `service_token` — confirm `ServiceToken`'s own
// `Debug` impl redacts the secret.
#[derive(Debug, Default)]
pub struct DecryptOptions<'a> {
    /// Keyset to decrypt with. When `None`, `ScopedCipher::decrypt` uses the cipher's own
    /// keyset id. `ScopedCipher::decrypt_fallible` does not read this field.
    pub keyset_id: Option<Uuid>,
    /// Optional service token forwarded to the ZeroKMS client.
    pub service_token: Option<Cow<'a, ServiceToken>>,
    /// Optional unverified context forwarded to the ZeroKMS client.
    pub unverified_context: Option<Cow<'a, UnverifiedContext>>,
}