cipherstash-client 0.34.1-alpha.1

The official CipherStash SDK
Documentation
use crate::encryption::IndexTerm;
use thiserror::Error;

#[derive(Debug, Error)]
pub enum AccumulatorError {
    #[error("Invalid term length")]
    InvalidTermLength,

    #[error("Empty accumulator")]
    EmptyAccumulator,

    #[error("Multiple terms found")]
    MultipleTermsFound,
}

/// Largest `term_length` (in bytes) accepted when truncating the terms of an [`Accumulator`].
pub const MAX_TERM_LENGTH: usize = 32;

/// The [`Accumulator`] type is used to represent one or many binary "terms" that can be collected
/// from a compound index.
///
/// [`Accumulator::Term`] always holds exactly one term. [`Accumulator::Terms`] holds any number
/// of terms, including none (see [`Accumulator::empty`]) or just one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Accumulator {
    /// A single binary term.
    Term(Vec<u8>),
    /// Zero or more binary terms, in insertion order.
    Terms(Vec<Vec<u8>>),
}

/// [`Accumulator`] that will only return if it has exactly one index term.
///
/// Values of this type are produced by [`Accumulator::exactly_one`].
#[derive(Debug)]
pub struct ExactlyOneAccumulator(Accumulator);

impl ExactlyOneAccumulator {
    /// Consume the wrapper and return its single binary term.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulatorError::EmptyAccumulator`] if the wrapped accumulator holds no terms,
    /// or [`AccumulatorError::MultipleTermsFound`] if it holds more than one.
    pub fn term(self) -> Result<Vec<u8>, AccumulatorError> {
        match self.0 {
            Accumulator::Term(term) => Ok(term),
            Accumulator::Terms(mut terms) => match terms.len() {
                0 => Err(AccumulatorError::EmptyAccumulator),
                // A one-element `Terms` still holds exactly one term, so it must not be
                // reported as "multiple".
                1 => Ok(terms.remove(0)),
                _ => Err(AccumulatorError::MultipleTermsFound),
            },
        }
    }

    /// Unwrap and return the underlying [`Accumulator`].
    pub fn into_inner(self) -> Accumulator {
        self.0
    }

    /// Truncate the wrapped term(s) to at most `term_length` bytes.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Accumulator::truncate`], which rejects a `term_length`
    /// greater than [`MAX_TERM_LENGTH`].
    pub fn truncate(self, term_length: usize) -> Result<Self, AccumulatorError> {
        self.0.truncate(term_length).map(ExactlyOneAccumulator)
    }
}

impl Accumulator {
    /// Create a single-term accumulator holding a copy of the bytes of `salt`.
    pub fn from_salt<S: AsRef<[u8]>>(salt: S) -> Self {
        Self::Term(salt.as_ref().to_vec())
    }

    /// Truncate every term in the accumulator to at most `term_length` bytes.
    ///
    /// Terms already shorter than `term_length` are left unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulatorError::InvalidTermLength`] if `term_length` is greater than
    /// [`MAX_TERM_LENGTH`].
    pub fn truncate(mut self, term_length: usize) -> Result<Self, AccumulatorError> {
        if term_length > MAX_TERM_LENGTH {
            return Err(AccumulatorError::InvalidTermLength);
        }

        match &mut self {
            Accumulator::Term(term) => term.truncate(term_length),
            Accumulator::Terms(terms) => terms
                .iter_mut()
                .for_each(|term| term.truncate(term_length)),
        }

        Ok(self)
    }

    /// Return a vector of the binary terms in the [`Accumulator`].
    ///
    /// A single [`Accumulator::Term`] becomes a one-element vector.
    pub fn terms(self) -> Vec<Vec<u8>> {
        match self {
            Accumulator::Term(term) => vec![term],
            Accumulator::Terms(terms) => terms,
        }
    }

    /// Create an accumulator that holds no terms.
    pub fn empty() -> Self {
        Self::Terms(vec![])
    }

    /// Return the accumulator wrapped as an [`ExactlyOneAccumulator`], only if it holds exactly
    /// one term.
    ///
    /// A [`Accumulator::Terms`] containing a single term (e.g. the result of combining
    /// [`Accumulator::empty`] with one term) is accepted and normalised to
    /// [`Accumulator::Term`].
    ///
    /// # Errors
    ///
    /// Returns [`AccumulatorError::EmptyAccumulator`] if there are no terms, or
    /// [`AccumulatorError::MultipleTermsFound`] if there is more than one.
    pub fn exactly_one(self) -> Result<ExactlyOneAccumulator, AccumulatorError> {
        match self {
            Accumulator::Term(_) => Ok(ExactlyOneAccumulator(self)),
            Accumulator::Terms(mut terms) => match terms.len() {
                0 => Err(AccumulatorError::EmptyAccumulator),
                1 => Ok(ExactlyOneAccumulator(Accumulator::Term(terms.remove(0)))),
                _ => Err(AccumulatorError::MultipleTermsFound),
            },
        }
    }

    /// Combine this accumulator with another.
    ///
    /// The terms of `self` come first, followed by the terms of `other`. The result is always an
    /// [`Accumulator::Terms`], even when both inputs are single terms.
    pub fn combine(self, other: Self) -> Self {
        // Move the term vectors instead of `concat`, which would clone every inner `Vec<u8>`.
        let mut terms = self.terms();
        terms.extend(other.terms());
        Accumulator::Terms(terms)
    }
}

impl From<Accumulator> for IndexTerm {
    /// Map [`Accumulator::Term`] to [`IndexTerm::Binary`] and [`Accumulator::Terms`] (of any
    /// length) to [`IndexTerm::BinaryVec`], moving the term bytes across unchanged.
    fn from(accumulator: Accumulator) -> Self {
        match accumulator {
            Accumulator::Terms(many) => Self::BinaryVec(many),
            Accumulator::Term(single) => Self::Binary(single),
        }
    }
}

impl TryFrom<ExactlyOneAccumulator> for IndexTerm {
    type Error = AccumulatorError;

    fn try_from(acc: ExactlyOneAccumulator) -> Result<Self, Self::Error> {
        Ok(IndexTerm::Binary(acc.term()?))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_combine_two_accumulators() {
        let left = Accumulator::Term(vec![1, 2, 3]);
        let right = Accumulator::Terms(vec![vec![4, 5, 6], vec![7, 8, 9]]);

        assert_eq!(
            left.combine(right).terms(),
            vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]
        );

        let left = Accumulator::Terms(vec![vec![1, 2, 3], vec![4, 5, 6]]);
        let right = Accumulator::Term(vec![7, 8, 9]);

        assert_eq!(
            left.combine(right).terms(),
            vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]
        );
    }

    #[test]
    fn test_combine_two_accumulators_term() {
        let left = Accumulator::Term(vec![1, 2, 3]);
        let right = Accumulator::Term(vec![4, 5, 6]);

        assert_eq!(
            left.combine(right).terms(),
            vec![vec![1, 2, 3], vec![4, 5, 6]]
        );
    }

    #[test]
    fn test_combine_two_accumulators_terms() {
        let left = Accumulator::Terms(vec![vec![1, 2, 3], vec![4, 5, 6]]);
        let right = Accumulator::Terms(vec![vec![7, 8, 9], vec![10, 11, 12]]);

        assert_eq!(
            left.combine(right).terms(),
            vec![
                vec![1, 2, 3],
                vec![4, 5, 6],
                vec![7, 8, 9],
                vec![10, 11, 12]
            ]
        );
    }

    #[test]
    fn test_combine_with_empty() {
        let acc = Accumulator::empty().combine(Accumulator::Term(vec![1, 2, 3]));

        assert_eq!(acc.terms(), vec![vec![1, 2, 3]]);
    }

    #[test]
    fn test_truncate() {
        let acc = Accumulator::Terms(vec![
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        ]);

        let acc = acc.truncate(5).expect("should truncate");

        assert_eq!(acc.terms(), vec![vec![1, 2, 3, 4, 5], vec![1, 2, 3, 4, 5]])
    }

    #[test]
    fn test_truncate_different() {
        let acc = Accumulator::Terms(vec![
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            vec![1, 2, 3, 4],
            vec![1, 2],
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        ]);

        let acc = acc.truncate(5).expect("should truncate");

        assert_eq!(
            acc.terms(),
            vec![
                vec![1, 2, 3, 4, 5],
                vec![1, 2, 3, 4],
                vec![1, 2],
                vec![1, 2, 3, 4, 5]
            ]
        )
    }

    #[test]
    fn test_truncate_rejects_length_above_max() {
        let acc = Accumulator::Term(vec![1, 2, 3]);

        assert!(matches!(
            acc.truncate(MAX_TERM_LENGTH + 1),
            Err(AccumulatorError::InvalidTermLength)
        ));
    }

    #[test]
    fn test_truncate_accepts_max_length() {
        let acc = Accumulator::Term(vec![7; MAX_TERM_LENGTH + 8]);

        let acc = acc.truncate(MAX_TERM_LENGTH).expect("should truncate");

        assert_eq!(acc.terms(), vec![vec![7; MAX_TERM_LENGTH]]);
    }

    #[test]
    fn test_exactly_one_single_term() {
        let term = Accumulator::from_salt(vec![1u8, 2, 3])
            .exactly_one()
            .expect("should hold exactly one term")
            .term()
            .expect("should return the term");

        assert_eq!(term, vec![1, 2, 3]);
    }

    #[test]
    fn test_exactly_one_errors() {
        assert!(matches!(
            Accumulator::empty().exactly_one(),
            Err(AccumulatorError::EmptyAccumulator)
        ));

        let many = Accumulator::Terms(vec![vec![1], vec![2]]);
        assert!(matches!(
            many.exactly_one(),
            Err(AccumulatorError::MultipleTermsFound)
        ));
    }

    #[test]
    fn test_exactly_one_truncate() {
        let term = Accumulator::Term(vec![1, 2, 3, 4])
            .exactly_one()
            .expect("should hold exactly one term")
            .truncate(2)
            .expect("should truncate")
            .term()
            .expect("should return the term");

        assert_eq!(term, vec![1, 2]);
    }
}