// cipherstash-client 0.34.1-alpha.1
//
// The official CipherStash SDK
// Documentation
use super::{
    accumulator::{Accumulator, AccumulatorError},
    composable_plaintext::ComposablePlaintext,
    ComposableIndex,
};
use crate::{
    encryption::{
        text::{process_edge_ngram_raw, TokenFilter},
        EncryptionError, Plaintext,
    },
    zerokms::IndexKey,
};
use hmac::{Hmac, Mac};
use sha2::Sha256;

/// HMAC instantiated with SHA-256; the MAC used by this module to derive index and query terms.
type HmacSha256 = Hmac<Sha256>;

/// Derives HMAC-SHA256 index and query terms for prefix matching over UTF-8 string plaintexts.
///
/// The indexer borrows its [`IndexKey`] for the lifetime `'k` and owns its list of token filters.
pub struct PrefixIndexer<'k> {
    /// Key whose bytes are used to key every HMAC created by `create_hmac`.
    index_key: &'k IndexKey,
    /// Filters applied to the tokens, in order, before they are hashed.
    token_filters: Vec<TokenFilter>,
    /// Forwarded to `process_edge_ngram_raw` when indexing
    /// (presumably the shortest prefix length that is indexed — confirm against that helper).
    min_length: usize,
    /// Forwarded to `process_edge_ngram_raw` when indexing
    /// (presumably the longest prefix length that is indexed — confirm against that helper).
    max_length: usize,
}

impl<'k> PrefixIndexer<'k> {
    /// Creates an indexer that borrows `index_key` and takes ownership of `token_filters`.
    ///
    /// `min_length` and `max_length` are stored as given and forwarded unchanged to
    /// `process_edge_ngram_raw` by [`Self::encrypt_with_salt`]; they are not validated here.
    pub fn new(
        index_key: &'k IndexKey,
        token_filters: Vec<TokenFilter>,
        min_length: usize,
        max_length: usize,
    ) -> Self {
        Self {
            index_key,
            token_filters,
            min_length,
            max_length,
        }
    }

    /// Returns a fresh HMAC-SHA256 instance keyed with the bytes of the index key.
    ///
    /// # Errors
    ///
    /// Propagates any key-initialisation error reported by `new_from_slice`, converted into
    /// an [`EncryptionError`].
    pub fn create_hmac(&self) -> Result<HmacSha256, EncryptionError> {
        Ok(HmacSha256::new_from_slice(self.index_key.key())?)
    }

    /// Builds the error returned for any plaintext variant other than `Plaintext::Utf8Str`.
    ///
    /// NOTE(review): the message embeds the `Debug` rendering of the plaintext — confirm that
    /// this rendering is safe to surface in errors and logs.
    fn unsupported(plaintext: &Plaintext) -> EncryptionError {
        EncryptionError::IndexingError(format!(
            "{plaintext:?} is not supported by prefix indexes"
        ))
    }

    /// Produces the index terms for `plaintext`.
    ///
    /// For `Plaintext::Utf8Str(Some(_))` the value is tokenised with `process_edge_ngram_raw`
    /// (using `min_length` and `max_length`), the tokens are passed through every token filter
    /// in order, and each resulting term is hashed as `HMAC(index_key, salt || term)`.
    /// The digests are returned as `Accumulator::Terms`.
    ///
    /// `Plaintext::Utf8Str(None)` yields `Accumulator::empty()`.
    ///
    /// # Errors
    ///
    /// Returns `EncryptionError::IndexingError` for every other plaintext variant, and
    /// propagates any error from [`Self::create_hmac`].
    pub fn encrypt_with_salt<S>(
        &self,
        plaintext: &Plaintext,
        salt: S,
    ) -> Result<Accumulator, EncryptionError>
    where
        S: AsRef<[u8]>,
    {
        match plaintext {
            Plaintext::Utf8Str(Some(value)) => {
                let tokens =
                    process_edge_ngram_raw(value.to_string(), self.min_length, self.max_length);

                let terms = self
                    .token_filters
                    .iter()
                    .fold(tokens, |tokens, filter| filter.process(tokens));

                // The key and the salt are the same for every term, so key the MAC and absorb
                // the salt once, then clone that state per term. Cloning the state and feeding
                // it the term yields the same digest as rebuilding the MAC from scratch.
                let mut salted = self.create_hmac()?;
                salted.update(salt.as_ref());

                let out = terms
                    .into_iter()
                    .map(|term| {
                        let mut mac = salted.clone();
                        mac.update(term.as_bytes());
                        mac.finalize().into_bytes().to_vec()
                    })
                    .collect::<Vec<Vec<u8>>>();

                Ok(Accumulator::Terms(out))
            }

            // Empty case
            Plaintext::Utf8Str(None) => Ok(Accumulator::empty()),

            // Error case
            _ => Err(Self::unsupported(plaintext)),
        }
    }

    /// Produces the single query term for `plaintext`.
    ///
    /// Unlike [`Self::encrypt_with_salt`], the value is not split into n-grams: the whole
    /// string is treated as one token, passed through the token filters, and the first
    /// resulting term is hashed as `HMAC(index_key, salt || term)`. The digest is returned as
    /// `Accumulator::Term`.
    ///
    /// `Plaintext::Utf8Str(None)` yields `Accumulator::empty()`.
    ///
    /// # Errors
    ///
    /// Returns `EncryptionError::IndexingError` when the filters leave no term at all or when
    /// the plaintext is not a `Utf8Str`, and propagates any error from [`Self::create_hmac`].
    pub fn query_with_salt<S>(
        &self,
        plaintext: &Plaintext,
        salt: S,
    ) -> Result<Accumulator, EncryptionError>
    where
        S: AsRef<[u8]>,
    {
        match plaintext {
            Plaintext::Utf8Str(Some(value)) => {
                let tokens = vec![value.to_string()];

                let terms = self
                    .token_filters
                    .iter()
                    .fold(tokens, |tokens, filter| filter.process(tokens));

                // Only the first term is used; any further terms produced by the filters are
                // ignored.
                let term = terms.first().ok_or_else(|| {
                    EncryptionError::IndexingError("Encryption terms were empty".to_string())
                })?;

                let mut mac = self.create_hmac()?;
                mac.update(salt.as_ref());
                mac.update(term.as_bytes());

                Ok(Accumulator::Term(mac.finalize().into_bytes().to_vec()))
            }

            // Empty case
            Plaintext::Utf8Str(None) => Ok(Accumulator::empty()),

            // Error case
            _ => Err(Self::unsupported(plaintext)),
        }
    }
}

/// Configuration for a prefix index: the token filters and the length bounds that are handed
/// to a [`PrefixIndexer`] each time the index is composed.
#[derive(Debug)]
pub struct PrefixIndex {
    /// Filters copied into each `PrefixIndexer` built from this index.
    token_filters: Vec<TokenFilter>,
    /// Minimum length passed to each `PrefixIndexer` built from this index.
    min_length: usize,
    /// Maximum length passed to each `PrefixIndexer` built from this index.
    max_length: usize,
}

impl PrefixIndex {
    /// `min_length` used by [`PrefixIndex::new`].
    const DEFAULT_MIN_LENGTH: usize = 3;
    /// `max_length` used by [`PrefixIndex::new`].
    const DEFAULT_MAX_LENGTH: usize = 10;

    /// Creates a prefix index with the given filters, a `min_length` of 3 and a `max_length`
    /// of 10.
    pub fn new(token_filters: Vec<TokenFilter>) -> Self {
        Self {
            token_filters,
            min_length: Self::DEFAULT_MIN_LENGTH,
            max_length: Self::DEFAULT_MAX_LENGTH,
        }
    }

    /// Creates a prefix index with explicit length bounds.
    ///
    /// The bounds are stored exactly as given; no validation is performed here.
    pub fn new_with_opts(
        token_filters: Vec<TokenFilter>,
        min_length: usize,
        max_length: usize,
    ) -> Self {
        Self {
            token_filters,
            min_length,
            max_length,
        }
    }
}

impl Default for PrefixIndex {
    /// Builds the default index: `TokenFilter::Downcase` as the only filter, with the same
    /// length bounds as `PrefixIndex::new` (`min_length` 3, `max_length` 10).
    ///
    /// Because of the downcase filter, the default index is case-insensitive.
    fn default() -> Self {
        let downcase_only = vec![TokenFilter::Downcase];
        Self::new(downcase_only)
    }
}

impl ComposableIndex for PrefixIndex {
    fn compose_index(
        &self,
        key: &IndexKey,
        plaintext: ComposablePlaintext,
        accumulator: Accumulator,
    ) -> Result<Accumulator, EncryptionError> {
        let indexer = PrefixIndexer::new(
            key,
            self.token_filters.to_vec(),
            self.min_length,
            self.max_length,
        );
        let plaintext: Plaintext = plaintext.try_into()?;

        match accumulator {
            Accumulator::Term(term) => indexer.encrypt_with_salt(&plaintext, term),
            Accumulator::Terms(terms) => {
                terms
                    .into_iter()
                    .try_fold(Accumulator::empty(), |acc, term| {
                        indexer
                            .encrypt_with_salt(&plaintext, term)
                            .map(|out| acc.combine(out))
                    })
            }
        }
    }

    fn compose_query(
        &self,
        key: &IndexKey,
        plaintext: ComposablePlaintext,
        accumulator: Accumulator,
    ) -> Result<Accumulator, EncryptionError> {
        let indexer = PrefixIndexer::new(
            key,
            self.token_filters.to_vec(),
            self.min_length,
            self.max_length,
        );
        let plaintext: Plaintext = plaintext.try_into()?;

        match accumulator {
            Accumulator::Term(term) => indexer.query_with_salt(&plaintext, term),
            _ => Err(AccumulatorError::MultipleTermsFound)?,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, NaiveDate};
    use rust_decimal::Decimal;

    /// Builds an indexer with a single `Downcase` filter and the given length bounds.
    fn prefixer(key: &IndexKey, min: usize, max: usize) -> PrefixIndexer<'_> {
        PrefixIndexer::new(key, vec![TokenFilter::Downcase], min, max)
    }

    /// Indexing with bounds 2..4 yields three terms for "Hello World".
    #[test]
    fn test_encrypt_term() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let result =
            prefixer(&index_key, 2, 4).encrypt_with_salt(&Plaintext::new("Hello World"), [])?;
        assert_eq!(result.terms().len(), 3);

        Ok(())
    }

    /// `PrefixIndex::default()` matches an explicit `Downcase`-only index and differs from an
    /// index without filters.
    #[test]
    fn test_encrypt_term_default() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let result_default = PrefixIndex::default().compose_index(
            &index_key,
            "Hello World".into(),
            Accumulator::from_salt("salt"),
        )?;

        let result = PrefixIndex::new(vec![TokenFilter::Downcase]).compose_index(
            &index_key,
            "Hello World".into(),
            Accumulator::from_salt("salt"),
        )?;

        let result_no_filters = PrefixIndex::new(vec![]).compose_index(
            &index_key,
            "Hello World".into(),
            Accumulator::from_salt("salt"),
        )?;

        let terms = result_default.terms();
        assert_eq!(terms.to_vec(), result.terms());
        assert_ne!(terms.to_vec(), result_no_filters.terms());

        Ok(())
    }

    /// A non-empty salt changes every term compared with the unsalted output.
    #[test]
    fn test_encrypt_with_salt() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let result_no_salt =
            prefixer(&index_key, 2, 4).encrypt_with_salt(&Plaintext::new("Hello World"), [])?;
        let result_salt = prefixer(&index_key, 2, 4)
            .encrypt_with_salt(&Plaintext::new("Hello World"), "somesalt")?;

        let unsalted_terms = result_no_salt.terms();
        let salted_terms = result_salt.terms();

        // Guard against a vacuous pass: `zip` stops at the shorter side, so an empty or
        // truncated term list would otherwise skip the comparison below.
        assert!(!unsalted_terms.is_empty());
        assert_eq!(unsalted_terms.len(), salted_terms.len());

        // No term should be the same when a salt is used
        unsalted_terms
            .into_iter()
            .zip(salted_terms)
            .for_each(|(no_salt, salt)| {
                assert_ne!(no_salt, salt);
            });

        Ok(())
    }

    /// A query produces exactly one term, 32 bytes long (the SHA-256 digest size).
    #[test]
    fn test_query_single_word() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let result = prefixer(&index_key, 2, 4).query_with_salt(&Plaintext::new("Hello"), [])?;
        assert_eq!(result.exactly_one()?.term()?.len(), 32);

        Ok(())
    }

    /// The query term for a shared prefix appears among the index terms of both words.
    #[test]
    fn test_common_prefix_single_word() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let prefixer = prefixer(&index_key, 3, 6);
        let data_1 = prefixer.encrypt_with_salt(&Plaintext::new("supervise"), [])?;
        let data_2 = prefixer.encrypt_with_salt(&Plaintext::new("superhero"), [])?;
        let query = prefixer.query_with_salt(&Plaintext::new("super"), [])?;
        let q = query.exactly_one()?.term()?;

        assert!(data_1.terms().contains(&q));
        assert!(data_2.terms().contains(&q));

        Ok(())
    }

    /// A `None` string yields no terms for both indexing and querying, with or without a salt.
    #[test]
    fn test_encrypt_none() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let indexer = prefixer(&index_key, 2, 4);

        assert!(indexer
            .encrypt_with_salt(&Plaintext::Utf8Str(None), [])?
            .terms()
            .is_empty());

        assert!(indexer
            .encrypt_with_salt(&Plaintext::Utf8Str(None), "somesalt")?
            .terms()
            .is_empty());

        assert!(indexer
            .query_with_salt(&Plaintext::Utf8Str(None), [])?
            .terms()
            .is_empty());

        assert!(indexer
            .query_with_salt(&Plaintext::Utf8Str(None), "somesalt")?
            .terms()
            .is_empty());

        Ok(())
    }

    /// Every non-string plaintext is rejected with an `IndexingError` by both entry points.
    #[test]
    fn test_unsupported_plaintext() -> Result<(), Box<dyn std::error::Error>> {
        /// Asserts that indexing and querying `plaintext` both fail with the expected error.
        fn test_unsupported(indexer: &PrefixIndexer<'_>, plaintext: Plaintext) {
            let result = indexer.encrypt_with_salt(&plaintext, []);
            assert!(
                matches!(result, Err(EncryptionError::IndexingError(_))),
                "Expected IndexingError for {plaintext:?}, got {result:?}",
            );
            assert!(
                result.is_err_and(|e| e.to_string().contains("is not supported by prefix indexes"))
            );

            let result = indexer.query_with_salt(&plaintext, []);
            assert!(
                matches!(result, Err(EncryptionError::IndexingError(_))),
                "Expected IndexingError for {plaintext:?}, got {result:?}",
            );
            assert!(
                result.is_err_and(|e| e.to_string().contains("is not supported by prefix indexes"))
            );
        }

        let index_key = IndexKey::from([1; 32]);
        let indexer = prefixer(&index_key, 2, 4);
        test_unsupported(&indexer, Plaintext::new(100i32));
        test_unsupported(&indexer, Plaintext::new(100i16));
        test_unsupported(&indexer, Plaintext::new(100u64));
        test_unsupported(&indexer, Plaintext::new(100i64));
        test_unsupported(&indexer, Plaintext::new(100.77f64));
        test_unsupported(&indexer, Plaintext::new(Decimal::new(202, 2)));
        test_unsupported(&indexer, Plaintext::new(true));
        test_unsupported(
            &indexer,
            Plaintext::new(DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00")?.to_utc()),
        );
        test_unsupported(
            &indexer,
            Plaintext::new(NaiveDate::from_ymd_opt(2002, 9, 6)),
        );
        test_unsupported(&indexer, Plaintext::new(serde_json::to_value("some json")?));

        Ok(())
    }
}