use super::entry::Entry;
use crate::{
constant_time::verify_slices_are_equal,
error::{MacVerificationError, TruncationError},
};
use alloc::{sync::Arc, vec::Vec};
/// A computed MAC tag.
///
/// Holds every entry that took part in producing the tag, plus the rendered
/// bytes of the primary entry (its header followed by its output). How the
/// tag is rendered by `as_bytes` is controlled by the `omit_header` and
/// `truncate_to` settings.
#[derive(Clone, Debug)]
pub struct Tag {
    /// All entries that make up this tag; shared (not copied) between clones.
    entries: Arc<Vec<Entry>>,
    /// Index into `entries` of the primary entry.
    primary_idx: usize,
    /// Header bytes of the primary entry followed by its output bytes.
    pub(super) primary_tag: Arc<Vec<u8>>,
    /// Length of the header prefix at the start of `primary_tag`.
    primary_tag_header_len: usize,
    /// When `Some(n)`, `as_bytes` returns only the first `n` bytes.
    truncate_to: Option<usize>,
    /// When `true`, `as_bytes` skips the header prefix of `primary_tag`.
    omit_header: bool,
}
/// Identity conversion so APIs generic over `AsRef<Tag>` accept a `Tag`
/// directly.
impl AsRef<Tag> for Tag {
    #[inline]
    fn as_ref(&self) -> &Self {
        self
    }
}
impl Tag {
    /// Minimum length, in bytes, accepted by [`Tag::truncate_to`] when the
    /// header is omitted from the rendered tag.
    const MIN_TRUNCATED_LEN: usize = 10;
    /// Minimum length, in bytes, accepted by [`Tag::truncate_to`] when the
    /// header is part of the rendered tag.
    const MIN_TRUNCATED_LEN_WITH_HEADER: usize = 14;

    /// Returns a copy of this tag with any truncation removed.
    ///
    /// The header setting is preserved.
    pub fn remove_truncation(&self) -> Self {
        Self {
            truncate_to: None,
            ..self.clone()
        }
    }

    /// Returns a copy of this tag whose [`as_bytes`](Tag::as_bytes) output is
    /// truncated to `len` bytes.
    ///
    /// A `len` of `0` removes any existing truncation, exactly like
    /// [`remove_truncation`](Tag::remove_truncation).
    ///
    /// # Errors
    ///
    /// - [`TruncationError::MinLengthNotMet`] if `len` is below 10 bytes when
    ///   the header is omitted, or below 14 bytes when the header is included.
    /// - [`TruncationError::LengthExceeded`] if `len` is longer than the tag as
    ///   currently rendered (with the header included or omitted).
    pub fn truncate_to(&self, len: usize) -> Result<Self, TruncationError> {
        if len == 0 {
            return Ok(self.remove_truncation());
        }
        let min_len = if self.omit_header {
            Self::MIN_TRUNCATED_LEN
        } else {
            Self::MIN_TRUNCATED_LEN_WITH_HEADER
        };
        if len < min_len {
            return Err(TruncationError::MinLengthNotMet);
        }
        // Compare against the *rendered* length, not the full tag: with the
        // header omitted, `as_bytes` slices `len` bytes out of the header-less
        // suffix and would panic if `len` exceeded it.
        if len > self.untruncated_len(self.omit_header) {
            return Err(TruncationError::LengthExceeded);
        }
        Ok(Self {
            truncate_to: Some(len),
            ..self.clone()
        })
    }

    /// Returns a copy of this tag whose [`as_bytes`](Tag::as_bytes) output
    /// excludes the header prefix.
    ///
    /// # Errors
    ///
    /// Only when the tag is truncated:
    /// - [`TruncationError::MinLengthNotMet`] if the truncated length is below
    ///   the 10-byte minimum for header-less tags.
    /// - [`TruncationError::LengthExceeded`] if the truncated length is longer
    ///   than the tag without its header (rendering it would otherwise panic).
    pub fn omit_header(&self) -> Result<Self, TruncationError> {
        if let Some(len) = self.truncate_to {
            if len < Self::MIN_TRUNCATED_LEN {
                return Err(TruncationError::MinLengthNotMet);
            }
            if len > self.untruncated_len(true) {
                return Err(TruncationError::LengthExceeded);
            }
        }
        Ok(Self {
            omit_header: true,
            ..self.clone()
        })
    }

    /// Returns a copy of this tag whose [`as_bytes`](Tag::as_bytes) output
    /// includes the header prefix.
    ///
    /// Any truncation length is kept unchanged.
    // NOTE(review): a truncation chosen while the header was omitted may be
    // shorter than the 14-byte minimum enforced for header-included tags.
    // This method returns `Self`, so it cannot report that; confirm whether
    // callers rely on this.
    pub fn include_header(&self) -> Self {
        Self {
            omit_header: false,
            ..self.clone()
        }
    }

    /// Returns the rendered tag bytes.
    ///
    /// The result starts from the primary entry's header followed by its
    /// output. The header is stripped if it has been omitted, and the result
    /// is then cut to the truncation length if one is set.
    pub fn as_bytes(&self) -> &[u8] {
        let untruncated = if self.omit_header {
            &self.primary_tag[self.primary_tag_header_len..]
        } else {
            self.primary_tag.as_slice()
        };
        match self.truncate_to {
            Some(len) => &untruncated[..len],
            None => untruncated,
        }
    }

    /// Length of the tag as rendered by `as_bytes` before truncation, with
    /// the header omitted or included according to `omit_header`.
    fn untruncated_len(&self, omit_header: bool) -> usize {
        if omit_header {
            // `primary_tag` is built as header ++ output in `new`, so the
            // header length never exceeds the total length.
            self.primary_tag.len() - self.primary_tag_header_len
        } else {
            self.primary_tag.len()
        }
    }

    /// Builds a tag from the entries produced for one MAC computation.
    ///
    /// The primary entry is the last entry whose `is_primary()` is true. If
    /// no entry is primary, the last entry is used. The primary entry's
    /// header and output bytes are concatenated to form the rendered tag.
    ///
    /// # Panics
    ///
    /// Panics if `entries_iter` yields no entries.
    pub(super) fn new(entries_iter: impl Iterator<Item = Entry>) -> Self {
        // Reserve from the lower bound: an iterator's upper bound may be
        // missing or far larger than what it actually yields.
        let mut entries = Vec::with_capacity(entries_iter.size_hint().0);
        let mut primary: Option<usize> = None;
        for (i, entry) in entries_iter.enumerate() {
            if entry.is_primary() {
                primary = Some(i);
            }
            entries.push(entry);
        }
        let primary_idx = primary
            .or_else(|| entries.len().checked_sub(1))
            .expect("Tag::new requires at least one entry");
        let primary_entry = &entries[primary_idx];
        let primary_tag = Arc::new([primary_entry.header(), primary_entry.output_bytes()].concat());
        let primary_tag_header_len = primary_entry.header().len();
        Self {
            entries: Arc::new(entries),
            primary_idx,
            primary_tag,
            primary_tag_header_len,
            truncate_to: None,
            omit_header: false,
        }
    }

    /// Returns the number of bytes returned by [`as_bytes`](Tag::as_bytes).
    // `is_empty` is intentionally not offered; the rendered tag is expected
    // to be non-empty.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.as_bytes().len()
    }

    /// Checks the bytes in `other` against this tag.
    ///
    /// The first check compares `other` with the full primary tag using
    /// `verify_slices_are_equal`. That comparison uses the header followed
    /// by the output, untruncated, regardless of the header and truncation
    /// settings. If it fails, each entry is asked to verify `other` using
    /// this tag's truncation length.
    ///
    /// # Errors
    ///
    /// Returns [`MacVerificationError`] if neither check succeeds.
    fn eq_slice(&self, other: &[u8]) -> Result<(), MacVerificationError> {
        if other.len() == self.primary_tag.len()
            && verify_slices_are_equal(self.primary_tag.as_ref(), other).is_ok()
        {
            return Ok(());
        }
        for entry in self.entries.iter() {
            if entry.verify(other, self.truncate_to).is_ok() {
                return Ok(());
            }
        }
        Err(MacVerificationError)
    }

    /// Checks whether this tag and `other` match.
    ///
    /// - Both tags have several entries: succeeds when any entry of `self`
    ///   verifies the output bytes of any entry of `other`.
    /// - Only `self` has several entries: succeeds when any of its entries
    ///   verifies `other`'s primary tag bytes.
    /// - Only `other` has several entries: delegates to `other.eq_tag(self)`.
    /// - Both have a single entry: delegates to `other.eq_slice` with this
    ///   tag's primary tag bytes.
    ///
    /// In the first two cases `self`'s truncation length is passed to
    /// `Entry::verify`. In the delegated cases `other`'s settings apply.
    /// With the `rayon` feature enabled, the entry searches run in parallel.
    ///
    /// # Errors
    ///
    /// Returns [`MacVerificationError`] if no match is found.
    fn eq_tag(&self, other: &Tag) -> Result<(), MacVerificationError> {
        #[cfg(feature = "rayon")]
        use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
        if self.entries.len() > 1 {
            if other.entries.len() > 1 {
                #[cfg(feature = "rayon")]
                let result = self
                    .entries
                    .par_iter()
                    .find_any(|entry| {
                        other
                            .entries
                            .par_iter()
                            .find_any(|other_entry| {
                                entry
                                    .verify(other_entry.output_bytes(), self.truncate_to)
                                    .is_ok()
                            })
                            .is_some()
                    })
                    .map(|_| ())
                    .ok_or(MacVerificationError);
                #[cfg(not(feature = "rayon"))]
                let result = self
                    .entries
                    .iter()
                    .find(|entry| {
                        other.entries.iter().any(|other_entry| {
                            entry
                                .verify(other_entry.output_bytes(), self.truncate_to)
                                .is_ok()
                        })
                    })
                    .map(|_| ())
                    .ok_or(MacVerificationError);
                result
            } else {
                #[cfg(feature = "rayon")]
                let result = self
                    .entries
                    .par_iter()
                    .find_any(|entry| {
                        entry
                            .verify(other.primary_tag.as_ref(), self.truncate_to)
                            .is_ok()
                    })
                    .map(|_| ())
                    .ok_or(MacVerificationError);
                #[cfg(not(feature = "rayon"))]
                let result = self
                    .entries
                    .iter()
                    .find(|entry| {
                        entry
                            .verify(other.primary_tag.as_ref(), self.truncate_to)
                            .is_ok()
                    })
                    .map(|_| ())
                    .ok_or(MacVerificationError);
                result
            }
        } else if other.entries.len() > 1 {
            other.eq_tag(self)
        } else {
            other.eq_slice(self.primary_tag.as_ref())
        }
    }
}
impl AsRef<[u8]> for Tag {
fn as_ref(&self) -> &[u8] {
self.as_bytes()
}
}
impl PartialEq<[u8]> for Tag {
fn eq(&self, other: &[u8]) -> bool {
self.eq_slice(other).is_ok()
}
}
impl PartialEq<&[u8]> for Tag {
fn eq(&self, other: &&[u8]) -> bool {
self.eq_slice(other).is_ok()
}
}
impl PartialEq<Vec<u8>> for Tag {
fn eq(&self, other: &Vec<u8>) -> bool {
self.eq_slice(other.as_slice()).is_ok()
}
}
impl PartialEq<Tag> for Tag {
fn eq(&self, other: &Self) -> bool {
self.eq_tag(other).is_ok()
}
}
impl PartialEq<&Tag> for Tag {
fn eq(&self, other: &&Self) -> bool {
self.eq_tag(other).is_ok()
}
}
impl PartialEq<Tag> for &Tag {
fn eq(&self, other: &Tag) -> bool {
other.eq_tag(self).is_ok()
}
}
impl PartialEq<Tag> for [u8] {
fn eq(&self, other: &Tag) -> bool {
other.eq_slice(self).is_ok()
}
}
// NOTE(review): `Eq` promises an equivalence relation, but tag equality
// succeeds when *any* entry verifies. Equality may therefore not be
// transitive across tags with overlapping key sets (a matches b via one key,
// b matches c via another). Confirm this is acceptable for callers relying
// on `Eq`.
impl Eq for Tag {}
#[cfg(test)]
mod tests {
    /// Tags built from overlapping key sets compare equal.
    #[cfg(feature = "blake3")]
    #[test]
    fn test_verify_blake3() {
        use super::*;
        use crate::mac::output::{Blake3Output, Output};
        use crate::SystemRng;

        let rng = SystemRng::new();
        // Fresh key id, as the big-endian bytes stored in an entry.
        let random_id_bytes = || crate::keyring::gen_id(&rng).to_be_bytes().to_vec();
        // Fresh random 32-byte BLAKE3 output standing in for a MAC result.
        let random_output = || {
            let mut hash_arr = [0; 32];
            rng.fill(&mut hash_arr).unwrap();
            Output::Blake3(Blake3Output::from(blake3::Hash::from(hash_arr)))
        };

        // A single-entry tag equals its own clone.
        let id_bytes_1 = random_id_bytes();
        let output_1 = random_output();
        let tag_1 = Tag::new(core::iter::once(Entry::new(
            true,
            id_bytes_1.clone(),
            output_1.clone(),
        )));
        assert_eq!(tag_1, tag_1.clone());

        // A tag holding the first key's entry (now non-primary) next to a new
        // primary entry still matches the single-entry tag.
        let entry_1 = Entry::new(false, id_bytes_1, output_1);
        let id_bytes_2 = random_id_bytes();
        let output_2 = random_output();
        let tag_2 = Tag::new(
            [
                entry_1.clone(),
                Entry::new(true, id_bytes_2.clone(), output_2.clone()),
            ]
            .iter()
            .cloned(),
        );
        assert_eq!(tag_1, tag_2);

        // Adding a third, primary entry keeps matching the two-entry tag.
        let entry_2 = Entry::new(false, id_bytes_2, output_2);
        let entry_3 = Entry::new(true, random_id_bytes(), random_output());
        let tag_3 = Tag::new([entry_1, entry_2, entry_3].iter().cloned());
        assert_eq!(tag_2, tag_3);
    }
}