use crate::keccak::{AlignedKeccakState, KECCAK_BLOCK_BITLEN_STR, KECCAK_BLOCK_SIZE, keccakf_u8};
use bitflags::bitflags;
use subtle::{self, ConstantTimeEq};
use zeroize::{Zeroize, ZeroizeOnDrop};
#[cfg(feature = "serialize_secret_state")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// STROBE version string as ASCII bytes. It is absorbed into the initial state right after
/// `"STROBEv"` in `Strobe::new` and is reported in `Strobe::version_str`.
pub const STROBE_VERSION: &[u8] = b"1.0.2";
/// Template for `Strobe::version_str`. The placeholders `sss` (bytes 14..17), `bbbb` (bytes
/// 18..22) and `X.Y.Z` (bytes 24..29) are overwritten with the security level, the Keccak bit
/// width string, and `STROBE_VERSION`.
const TEMPLATE_VERSION_STR: [u8; 29] = *b"Strobe-Keccak-sss/bbbb-vX.Y.Z";
bitflags! {
/// Operation flags. Every STROBE operation (`ad`, `key`, `prf`, `send_enc`, ...) is a fixed
/// combination of these bits, and the combination is absorbed into the state by `begin_op`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct OpFlags: u8 {
/// Inbound: set on receiving operations. For transport operations `begin_op` rewrites this
/// bit relative to the direction of this party's first transport operation.
const I = 1<<0;
/// Application: per the STROBE spec, data moves to or from the application. Set for `ad`,
/// `key`, `prf`, `send/recv_clr` and `send/recv_enc`.
const A = 1<<1;
/// Cipher: the operation uses or writes cipher state. `begin_op` forces a permutation after
/// the operation header when this bit is set.
const C = 1<<2;
/// Transport: the operation sends or receives on the transport (`*_clr`, `*_enc`, `*_mac`).
const T = 1<<3;
/// Meta: added by the `meta_*` variant of each operation.
const M = 1<<4;
/// Reserved and not implemented here; `operate` and `operate_no_mutate` panic if it is set.
const K = 1<<5;
}
}
/// Serializes `OpFlags` through the `bitflags_serde_legacy` helpers, which keep the encoding
/// compatible with the bitflags 1.x serde format.
#[cfg(feature = "serialize_secret_state")]
impl Serialize for OpFlags {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        bitflags_serde_legacy::serialize(self, "OpFlags", serializer)
    }
}

/// Deserializes `OpFlags` from the encoding produced by the `Serialize` impl, via
/// `bitflags_serde_legacy`.
#[cfg(feature = "serialize_secret_state")]
impl<'de> Deserialize<'de> for OpFlags {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        bitflags_serde_legacy::deserialize("OpFlags", deserializer)
    }
}
/// Zeroizes the raw flag bits in place.
///
/// NOTE(review): `self.0.0` reaches through the bitflags 2.x generated wrapper (`OpFlags` wraps
/// an internal type that wraps the raw `u8`). This is not bitflags' public API and may break on
/// an upgrade. It is presumably used because the public API (`bits` / `from_bits_retain`) cannot
/// hand `zeroize` a `&mut u8` to write through — confirm before changing.
impl Zeroize for OpFlags {
fn zeroize(&mut self) {
self.0.0.zeroize();
}
}
/// Security parameter: the target security level, in bits.
///
/// The discriminant is the bit count (`SecParam::B128 as usize == 128`). `Strobe::new` uses it
/// to size the sponge capacity, and `Strobe::version_str` reports it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize_secret_state", derive(Serialize, Deserialize))]
#[repr(usize)]
pub enum SecParam {
    /// 128-bit security level.
    B128 = 128,
    /// 256-bit security level.
    B256 = 256,
}
/// Error returned by `recv_mac` / `meta_recv_mac` when the received MAC does not match the
/// value expected from the current state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AuthError;

impl core::fmt::Display for AuthError {
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(formatter, "MAC verification failed")
    }
}
/// A STROBE protocol instance.
///
/// `Clone` duplicates the full secret state. The state is zeroized on drop (`ZeroizeOnDrop`).
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serialize_secret_state", derive(Serialize, Deserialize))]
pub struct Strobe {
/// Keccak sponge state, permuted with `keccakf_u8`.
pub(crate) st: AlignedKeccakState,
/// Security level this instance was created with. Excluded from zeroization.
#[zeroize(skip)]
sec: SecParam,
/// Number of state bytes processed between permutations (the STROBE rate `R`).
rate: usize,
/// Current byte offset into `st`; always `< rate` between operations.
pos: usize,
/// One past the offset where the current operation began. It is mixed into the next
/// operation header and into the padding in `run_f`, and reset to 0 by each permutation.
pos_begin: usize,
/// `None` until the first transport (`T`) operation. After that, whether that operation was
/// a receive; used to normalize the `I` flag.
is_receiver: Option<bool>,
/// Flags of the most recent operation, checked when a call passes `more = true`.
prev_flags: Option<OpFlags>,
}
/// Generates a public `Strobe` method pair, `$name` and `$meta_name`, for an operation that
/// writes its output into the caller's buffer. Both forward to `Strobe::operate` with `$flags`;
/// the meta variant also sets `OpFlags::M`.
macro_rules! def_op_mut {
    ($name:ident, $meta_name:ident, $flags:expr, $doc_str:expr) => {
        #[doc = $doc_str]
        #[doc = ""]
        #[doc = "`data` is processed in place. Pass `more = true` to continue the immediately"]
        #[doc = "preceding call of this same operation (streaming)."]
        #[doc = ""]
        #[doc = "# Panics"]
        #[doc = ""]
        #[doc = "Panics if `more` is `true` but the previous operation used different flags."]
        pub fn $name(&mut self, data: &mut [u8], more: bool) {
            self.operate($flags, data, more);
        }

        #[doc = $doc_str]
        #[doc = ""]
        #[doc = "Meta variant: identical to the non-meta operation except that the `M` flag is"]
        #[doc = "also set. `data` is processed in place. Pass `more = true` to continue the"]
        #[doc = "immediately preceding call of this same meta operation."]
        #[doc = ""]
        #[doc = "# Panics"]
        #[doc = ""]
        #[doc = "Panics if `more` is `true` but the previous operation used different flags"]
        #[doc = "(meta and non-meta variants count as different)."]
        pub fn $meta_name(&mut self, data: &mut [u8], more: bool) {
            self.operate($flags | OpFlags::M, data, more);
        }
    };
}
/// Generates a public `Strobe` method pair, `$name` and `$meta_name`, for an operation that only
/// reads the caller's buffer. Both forward to `Strobe::operate_no_mutate` with `$flags`; the
/// meta variant also sets `OpFlags::M`.
macro_rules! def_op_no_mut {
    ($name:ident, $meta_name:ident, $flags:expr, $doc_str:expr) => {
        #[doc = $doc_str]
        #[doc = ""]
        #[doc = "`data` is only read. Pass `more = true` to continue the immediately preceding"]
        #[doc = "call of this same operation (streaming)."]
        #[doc = ""]
        #[doc = "# Panics"]
        #[doc = ""]
        #[doc = "Panics if `more` is `true` but the previous operation used different flags."]
        pub fn $name(&mut self, data: &[u8], more: bool) {
            self.operate_no_mutate($flags, data, more);
        }

        #[doc = $doc_str]
        #[doc = ""]
        #[doc = "Meta variant: identical to the non-meta operation except that the `M` flag is"]
        #[doc = "also set. `data` is only read. Pass `more = true` to continue the immediately"]
        #[doc = "preceding call of this same meta operation."]
        #[doc = ""]
        #[doc = "# Panics"]
        #[doc = ""]
        #[doc = "Panics if `more` is `true` but the previous operation used different flags"]
        #[doc = "(meta and non-meta variants count as different)."]
        pub fn $meta_name(&mut self, data: &[u8], more: bool) {
            self.operate_no_mutate($flags | OpFlags::M, data, more);
        }
    };
}
impl Strobe {
pub fn new(proto: &[u8], sec: SecParam) -> Strobe {
let rate = KECCAK_BLOCK_SIZE * 8 - (sec as usize) / 4 - 2;
assert!(rate >= 1);
assert!(rate < 254);
let mut st_buf = [0u8; KECCAK_BLOCK_SIZE * 8];
st_buf[0..6].copy_from_slice(&[0x01, (rate as u8) + 2, 0x01, 0x00, 0x01, 0x60]);
st_buf[6..13].copy_from_slice(b"STROBEv");
st_buf[13..18].copy_from_slice(STROBE_VERSION);
let mut st = AlignedKeccakState(st_buf);
keccakf_u8(&mut st);
let mut strobe = Strobe {
st,
sec,
rate,
pos: 0,
pos_begin: 0,
is_receiver: None,
prev_flags: None,
};
strobe.meta_ad(proto, false);
strobe
}
pub fn version_str(&self) -> [u8; TEMPLATE_VERSION_STR.len()] {
let mut buf = TEMPLATE_VERSION_STR;
match self.sec {
SecParam::B128 => buf[14..17].copy_from_slice(b"128"),
SecParam::B256 => buf[14..17].copy_from_slice(b"256"),
}
buf[18..22].copy_from_slice(KECCAK_BLOCK_BITLEN_STR);
buf[24..29].copy_from_slice(STROBE_VERSION);
buf
}
fn validate_streaming(&mut self, flags: OpFlags, more: bool) {
if more {
assert_eq!(
self.prev_flags,
Some(flags),
"`more` can only be used when this operation is the same as the previous operation"
);
}
self.prev_flags = Some(flags);
}
fn run_f(&mut self) {
self.st.0[self.pos] ^= self.pos_begin as u8;
self.st.0[self.pos + 1] ^= 0x04;
self.st.0[self.rate + 1] ^= 0x80;
keccakf_u8(&mut self.st);
self.pos = 0;
self.pos_begin = 0;
}
fn absorb(&mut self, data: &[u8]) {
for b in data {
self.st.0[self.pos] ^= *b;
self.pos += 1;
if self.pos == self.rate {
self.run_f();
}
}
}
fn absorb_and_set(&mut self, data: &mut [u8]) {
for b in data {
let state_byte = self.st.0.get_mut(self.pos).unwrap();
*state_byte ^= *b;
*b = *state_byte;
self.pos += 1;
if self.pos == self.rate {
self.run_f();
}
}
}
fn copy_state(&mut self, data: &mut [u8]) {
for b in data {
*b = self.st.0[self.pos];
self.pos += 1;
if self.pos == self.rate {
self.run_f();
}
}
}
fn exchange(&mut self, data: &mut [u8]) {
for b in data {
let state_byte = self.st.0.get_mut(self.pos).unwrap();
*b ^= *state_byte;
*state_byte ^= *b;
self.pos += 1;
if self.pos == self.rate {
self.run_f();
}
}
}
fn overwrite(&mut self, data: &[u8]) {
for b in data {
self.st.0[self.pos] = *b;
self.pos += 1;
if self.pos == self.rate {
self.run_f();
}
}
}
fn squeeze(&mut self, data: &mut [u8]) {
for b in data {
let state_byte = self.st.0.get_mut(self.pos).unwrap();
*b = *state_byte;
*state_byte = 0;
self.pos += 1;
if self.pos == self.rate {
self.run_f();
}
}
}
fn zero_state(&mut self, mut bytes_to_zero: usize) {
static ZEROS: [u8; 8 * KECCAK_BLOCK_SIZE] = [0u8; 8 * KECCAK_BLOCK_SIZE];
while bytes_to_zero > 0 {
let slice_len = core::cmp::min(self.rate - self.pos, bytes_to_zero);
self.st.0[self.pos..(self.pos + slice_len)].copy_from_slice(&ZEROS[..slice_len]);
self.pos += slice_len;
bytes_to_zero -= slice_len;
if self.pos == self.rate {
self.run_f();
}
}
}
fn begin_op(&mut self, mut flags: OpFlags) {
if flags.contains(OpFlags::T) {
let is_op_receiving = flags.contains(OpFlags::I);
if self.is_receiver.is_none() {
self.is_receiver = Some(is_op_receiving);
}
flags.set(OpFlags::I, self.is_receiver.unwrap() != is_op_receiving);
}
let old_pos_begin = self.pos_begin;
self.pos_begin = self.pos + 1;
let to_mix = &mut [old_pos_begin as u8, flags.bits()];
self.absorb(&to_mix[..]);
let force_f = flags.contains(OpFlags::C) || flags.contains(OpFlags::K);
if force_f && self.pos != 0 {
self.run_f();
}
}
pub(crate) fn operate(&mut self, flags: OpFlags, data: &mut [u8], more: bool) {
assert!(!flags.contains(OpFlags::K), "Op flag K not implemented");
self.validate_streaming(flags, more);
if !more {
self.begin_op(flags);
}
let flags = flags & !OpFlags::M;
if flags.contains(OpFlags::C) && flags.contains(OpFlags::T) && !flags.contains(OpFlags::I) {
if flags == OpFlags::C | OpFlags::T {
self.copy_state(data)
} else {
self.absorb_and_set(data);
}
} else if flags == OpFlags::I | OpFlags::A | OpFlags::C {
self.squeeze(data);
} else if flags.contains(OpFlags::C) {
self.exchange(data);
} else {
panic!("operate should not be called for operations that do not require mutation");
}
}
pub(crate) fn operate_no_mutate(&mut self, flags: OpFlags, data: &[u8], more: bool) {
assert!(!flags.contains(OpFlags::K), "Op flag K not implemented");
self.validate_streaming(flags, more);
if !more {
self.begin_op(flags);
}
if flags.contains(OpFlags::C) && flags.contains(OpFlags::T) && !flags.contains(OpFlags::I) {
panic!("operate_no_mutate called on something that requires mutation");
} else if flags.contains(OpFlags::C) {
self.overwrite(data);
} else {
self.absorb(data);
};
}
fn generalized_recv_mac<const N: usize>(
&mut self,
mac: &[u8; N],
is_meta: bool,
) -> Result<(), AuthError> {
let mut mac_copy = *mac;
let flags = if is_meta {
OpFlags::I | OpFlags::C | OpFlags::T | OpFlags::M
} else {
OpFlags::I | OpFlags::C | OpFlags::T
};
self.operate(flags, &mut mac_copy, false);
let mut all_zero = subtle::Choice::from(1u8);
for b in mac_copy {
all_zero &= b.ct_eq(&0u8);
}
mac_copy.zeroize();
if !bool::from(all_zero) {
Err(AuthError)
} else {
Ok(())
}
}
pub fn recv_mac<const N: usize>(&mut self, mac: &[u8; N]) -> Result<(), AuthError> {
self.generalized_recv_mac(mac, false)
}
pub fn meta_recv_mac<const N: usize>(&mut self, mac: &[u8; N]) -> Result<(), AuthError> {
self.generalized_recv_mac(mac, true)
}
fn generalized_ratchet(&mut self, num_bytes_to_zero: usize, more: bool, is_meta: bool) {
let flags = if is_meta {
OpFlags::C | OpFlags::M
} else {
OpFlags::C
};
self.validate_streaming(flags, more);
if !more {
self.begin_op(flags);
}
self.zero_state(num_bytes_to_zero);
}
pub fn ratchet(&mut self, num_bytes_to_zero: usize, more: bool) {
self.generalized_ratchet(num_bytes_to_zero, more, false)
}
pub fn meta_ratchet(&mut self, num_bytes_to_zero: usize, more: bool) {
self.generalized_ratchet(num_bytes_to_zero, more, true)
}
def_op_mut!(
send_enc,
meta_send_enc,
OpFlags::A | OpFlags::C | OpFlags::T,
"Sends an encrypted message."
);
def_op_mut!(
recv_enc,
meta_recv_enc,
OpFlags::I | OpFlags::A | OpFlags::C | OpFlags::T,
"Receives an encrypted message."
);
def_op_mut!(
send_mac,
meta_send_mac,
OpFlags::C | OpFlags::T,
"Sends a MAC of the internal state. \
The output is independent of the initial contents of the input buffer."
);
def_op_mut!(
prf,
meta_prf,
OpFlags::I | OpFlags::A | OpFlags::C,
"Extracts pseudorandom data as a function of the internal state. \
The output is independent of the initial contents of the input buffer."
);
def_op_no_mut!(
send_clr,
meta_send_clr,
OpFlags::A | OpFlags::T,
"Sends a plaintext message."
);
def_op_no_mut!(
recv_clr,
meta_recv_clr,
OpFlags::I | OpFlags::A | OpFlags::T,
"Receives a plaintext message."
);
def_op_no_mut!(
ad,
meta_ad,
OpFlags::A,
"Mixes associated data into the internal state."
);
def_op_no_mut!(
key,
meta_key,
OpFlags::A | OpFlags::C,
"Sets a symmetric cipher key."
);
}
/// Checks the rendered version string for each security level.
#[test]
fn version_str() {
    let cases: [(SecParam, &[u8]); 2] = [
        (SecParam::B128, b"Strobe-Keccak-128/1600-v1.0.2"),
        (SecParam::B256, b"Strobe-Keccak-256/1600-v1.0.2"),
    ];
    for (sec, expected) in cases {
        let strobe = Strobe::new(b"version_str test", sec);
        assert_eq!(&strobe.version_str()[..], expected);
    }
}