use crate::encryption::IndexTerm;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum AccumulatorError {
#[error("Invalid term length")]
InvalidTermLength,
#[error("Empty accumulator")]
EmptyAccumulator,
#[error("Multiple terms found")]
MultipleTermsFound,
}
/// Largest length (in bytes) that terms may be truncated to via `truncate`;
/// larger requests are rejected with [`AccumulatorError::InvalidTermLength`].
pub const MAX_TERM_LENGTH: usize = 32;

/// A collection of byte-string terms: either a single term or a list of terms.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Accumulator {
    /// A single term.
    Term(Vec<u8>),
    /// A list of zero or more terms.
    Terms(Vec<Vec<u8>>),
}

/// An [`Accumulator`] that has been checked to hold a single term.
///
/// Created via [`Accumulator::exactly_one`], which only accepts the `Term` variant.
/// Deriving `Debug` lets callers use `Result::unwrap_err` / `expect_err` on results
/// that carry this type in their `Ok` position.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExactlyOneAccumulator(Accumulator);
impl ExactlyOneAccumulator {
pub fn term(self) -> Result<Vec<u8>, AccumulatorError> {
match self.0 {
Accumulator::Term(term) => Ok(term),
Accumulator::Terms(terms) => {
if terms.is_empty() {
Err(AccumulatorError::EmptyAccumulator)
} else {
Err(AccumulatorError::MultipleTermsFound)
}
}
}
}
pub fn into_inner(self) -> Accumulator {
self.0
}
pub fn truncate(self, term_length: usize) -> Result<Self, AccumulatorError> {
Ok(Self(self.0.truncate(term_length)?))
}
}
impl Accumulator {
    /// Creates a single-term accumulator holding a copy of the bytes of `salt`.
    pub fn from_salt<S: AsRef<[u8]>>(salt: S) -> Self {
        Self::Term(salt.as_ref().to_vec())
    }

    /// Truncates every term to at most `term_length` bytes.
    ///
    /// Terms that are already `term_length` bytes or shorter are left unchanged, and the
    /// variant (`Term` or `Terms`) is preserved.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulatorError::InvalidTermLength`] if `term_length` exceeds
    /// [`MAX_TERM_LENGTH`].
    pub fn truncate(mut self, term_length: usize) -> Result<Self, AccumulatorError> {
        if term_length > MAX_TERM_LENGTH {
            return Err(AccumulatorError::InvalidTermLength);
        }

        match &mut self {
            Accumulator::Term(term) => term.truncate(term_length),
            Accumulator::Terms(terms) => {
                for term in terms.iter_mut() {
                    term.truncate(term_length);
                }
            }
        }

        Ok(self)
    }

    /// Consumes the accumulator and returns its terms as a list.
    ///
    /// A `Term` yields a one-element vector; `Terms` yields its vector as-is.
    pub fn terms(self) -> Vec<Vec<u8>> {
        match self {
            Accumulator::Term(term) => vec![term],
            Accumulator::Terms(terms) => terms,
        }
    }

    /// Returns an accumulator holding no terms (an empty `Terms` list).
    pub fn empty() -> Self {
        Self::Terms(vec![])
    }

    /// Checks that this accumulator holds a single term and wraps it as an
    /// [`ExactlyOneAccumulator`].
    ///
    /// Only the `Term` variant qualifies. A `Terms` accumulator is rejected even when its
    /// list happens to contain exactly one element.
    ///
    /// # Errors
    ///
    /// For the `Terms` variant, returns [`AccumulatorError::EmptyAccumulator`] when the
    /// list is empty and [`AccumulatorError::MultipleTermsFound`] otherwise.
    pub fn exactly_one(self) -> Result<ExactlyOneAccumulator, AccumulatorError> {
        match self {
            Accumulator::Term(_) => Ok(ExactlyOneAccumulator(self)),
            Accumulator::Terms(terms) if terms.is_empty() => {
                Err(AccumulatorError::EmptyAccumulator)
            }
            Accumulator::Terms(_) => Err(AccumulatorError::MultipleTermsFound),
        }
    }

    /// Concatenates the terms of `self` followed by the terms of `other`.
    ///
    /// The result is always the `Terms` variant, even when both inputs are single terms.
    pub fn combine(self, other: Self) -> Self {
        // Both inputs are owned, so move the terms instead of going through
        // `[..].concat()`, which cloned every term into a fresh vector.
        let mut terms = self.terms();
        terms.extend(other.terms());
        Accumulator::Terms(terms)
    }
}
/// Moves an [`Accumulator`]'s bytes into an [`IndexTerm`] without copying:
/// a single `Term` becomes `IndexTerm::Binary`, a `Terms` list becomes
/// `IndexTerm::BinaryVec`.
impl From<Accumulator> for IndexTerm {
    fn from(acc: Accumulator) -> Self {
        match acc {
            Accumulator::Terms(list) => Self::BinaryVec(list),
            Accumulator::Term(bytes) => Self::Binary(bytes),
        }
    }
}
impl TryFrom<ExactlyOneAccumulator> for IndexTerm {
type Error = AccumulatorError;
fn try_from(acc: ExactlyOneAccumulator) -> Result<Self, Self::Error> {
Ok(IndexTerm::Binary(acc.term()?))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_combine_two_accumulators() {
        let left = Accumulator::Term(vec![1, 2, 3]);
        let right = Accumulator::Terms(vec![vec![4, 5, 6], vec![7, 8, 9]]);
        assert_eq!(
            left.combine(right).terms(),
            vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]
        );

        let left = Accumulator::Terms(vec![vec![1, 2, 3], vec![4, 5, 6]]);
        let right = Accumulator::Term(vec![7, 8, 9]);
        assert_eq!(
            left.combine(right).terms(),
            vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]
        );
    }

    #[test]
    fn test_combine_two_accumulators_term() {
        let left = Accumulator::Term(vec![1, 2, 3]);
        let right = Accumulator::Term(vec![4, 5, 6]);
        assert_eq!(
            left.combine(right).terms(),
            vec![vec![1, 2, 3], vec![4, 5, 6]]
        );
    }

    #[test]
    fn test_combine_two_accumulators_terms() {
        let left = Accumulator::Terms(vec![vec![1, 2, 3], vec![4, 5, 6]]);
        let right = Accumulator::Terms(vec![vec![7, 8, 9], vec![10, 11, 12]]);
        assert_eq!(
            left.combine(right).terms(),
            vec![
                vec![1, 2, 3],
                vec![4, 5, 6],
                vec![7, 8, 9],
                vec![10, 11, 12]
            ]
        );
    }

    #[test]
    fn test_combine_with_empty() {
        let combined = Accumulator::empty().combine(Accumulator::Term(vec![1, 2, 3]));
        assert_eq!(combined.terms(), vec![vec![1, 2, 3]]);

        let combined = Accumulator::Term(vec![1, 2, 3]).combine(Accumulator::empty());
        assert_eq!(combined.terms(), vec![vec![1, 2, 3]]);

        let combined = Accumulator::empty().combine(Accumulator::empty());
        assert!(combined.terms().is_empty());
    }

    #[test]
    fn test_truncate() {
        let acc = Accumulator::Terms(vec![
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        ]);
        let acc = acc.truncate(5).expect("should truncate");
        assert_eq!(acc.terms(), vec![vec![1, 2, 3, 4, 5], vec![1, 2, 3, 4, 5]])
    }

    #[test]
    fn test_truncate_different() {
        let acc = Accumulator::Terms(vec![
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            vec![1, 2, 3, 4],
            vec![1, 2],
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        ]);
        let acc = acc.truncate(5).expect("should truncate");
        assert_eq!(
            acc.terms(),
            vec![
                vec![1, 2, 3, 4, 5],
                vec![1, 2, 3, 4],
                vec![1, 2],
                vec![1, 2, 3, 4, 5]
            ]
        )
    }

    #[test]
    fn test_truncate_rejects_length_above_max() {
        let acc = Accumulator::Term(vec![0; MAX_TERM_LENGTH + 8]);
        assert!(matches!(
            acc.truncate(MAX_TERM_LENGTH + 1),
            Err(AccumulatorError::InvalidTermLength)
        ));
    }

    #[test]
    fn test_truncate_accepts_max_length() {
        let acc = Accumulator::Term(vec![7; MAX_TERM_LENGTH + 8]);
        let acc = acc
            .truncate(MAX_TERM_LENGTH)
            .expect("MAX_TERM_LENGTH itself is a valid length");
        assert_eq!(acc.terms(), vec![vec![7; MAX_TERM_LENGTH]]);
    }

    #[test]
    fn test_exactly_one() {
        let term = Accumulator::from_salt([1u8, 2, 3])
            .exactly_one()
            .expect("a single Term is accepted")
            .term()
            .expect("a single Term yields its bytes");
        assert_eq!(term, vec![1, 2, 3]);

        assert!(matches!(
            Accumulator::empty().exactly_one(),
            Err(AccumulatorError::EmptyAccumulator)
        ));
        assert!(matches!(
            Accumulator::Terms(vec![vec![1], vec![2]]).exactly_one(),
            Err(AccumulatorError::MultipleTermsFound)
        ));
        // A `Terms` list is rejected even when it contains only one element.
        assert!(matches!(
            Accumulator::Terms(vec![vec![1]]).exactly_one(),
            Err(AccumulatorError::MultipleTermsFound)
        ));
    }

    #[test]
    fn test_exactly_one_truncate() {
        let term = Accumulator::Term(vec![1, 2, 3, 4])
            .exactly_one()
            .expect("a single Term is accepted")
            .truncate(2)
            .expect("should truncate")
            .term()
            .expect("still a single Term after truncation");
        assert_eq!(term, vec![1, 2]);
    }

    #[test]
    fn test_into_index_term() {
        assert!(matches!(
            IndexTerm::from(Accumulator::Term(vec![1, 2])),
            IndexTerm::Binary(ref bytes) if *bytes == vec![1u8, 2]
        ));
        assert!(matches!(
            IndexTerm::from(Accumulator::Terms(vec![vec![1], vec![2]])),
            IndexTerm::BinaryVec(ref terms) if *terms == vec![vec![1u8], vec![2u8]]
        ));

        let single = Accumulator::Term(vec![9])
            .exactly_one()
            .expect("a single Term is accepted");
        assert!(matches!(
            IndexTerm::try_from(single),
            Ok(IndexTerm::Binary(ref bytes)) if *bytes == vec![9u8]
        ));
    }
}