use crate::constants::{NONCE_LEN, TAG_LEN};
use rand::{rngs::OsRng, thread_rng, RngCore};
pub use strobe_rs::AuthError;
use strobe_rs::{SecParam, Strobe};
/// A plaintext message bundled with the MAC that authenticates it.
///
/// The wire format used by [`AuthPlaintext::into_bytes`] and
/// [`AuthPlaintext::from_bytes`] is the MAC immediately followed by the
/// plaintext.
#[derive(Clone, Debug)]
pub struct AuthPlaintext {
    /// Authentication tag; occupies the first `TAG_LEN` bytes when parsed
    /// by `from_bytes`.
    mac: Vec<u8>,
    /// The message bytes, carried unencrypted.
    pt: Vec<u8>,
}

impl AuthPlaintext {
    /// Consumes the value and serializes it as `mac || pt`.
    pub fn into_bytes(self) -> Vec<u8> {
        [self.mac, self.pt].concat()
    }

    /// Parses a buffer laid out as `mac || pt`.
    ///
    /// Returns `None` when `bytes` is shorter than `TAG_LEN` and therefore
    /// cannot contain a complete MAC. A buffer of exactly `TAG_LEN` bytes
    /// yields an empty plaintext.
    pub fn from_bytes(mut bytes: Vec<u8>) -> Option<AuthPlaintext> {
        if bytes.len() < TAG_LEN {
            return None;
        }
        // `split_off` leaves the leading `TAG_LEN` bytes (the MAC) in `bytes`
        // and hands back everything after them.
        let pt = bytes.split_off(TAG_LEN);
        Some(AuthPlaintext { mac: bytes, pt })
    }
}
/// An encrypted message together with the nonce and MAC needed to open it.
///
/// The wire format used by [`AuthCiphertext::into_bytes`] and
/// [`AuthCiphertext::from_bytes`] is `mac || nonce || ct`.
#[derive(Clone, Debug)]
pub struct AuthCiphertext {
    /// Authentication tag; the first `TAG_LEN` bytes when parsed by
    /// `from_bytes`.
    mac: Vec<u8>,
    /// Per-message nonce; the `NONCE_LEN` bytes following the MAC when parsed
    /// by `from_bytes`.
    nonce: Vec<u8>,
    /// The encrypted payload; everything after the nonce.
    ct: Vec<u8>,
}

impl AuthCiphertext {
    /// Consumes the value and serializes it as `mac || nonce || ct`.
    pub fn into_bytes(self) -> Vec<u8> {
        [self.mac, self.nonce, self.ct].concat()
    }

    /// Parses a buffer laid out as `mac || nonce || ct`.
    ///
    /// Returns `None` when `bytes` is shorter than `TAG_LEN + NONCE_LEN` and
    /// so cannot hold both fixed-size fields. The ciphertext may be empty.
    pub fn from_bytes(mut bytes: Vec<u8>) -> Option<AuthCiphertext> {
        let header_len = TAG_LEN + NONCE_LEN;
        if bytes.len() < header_len {
            return None;
        }
        // Peel the fields off from the back: each `split_off` returns the
        // tail and leaves the prefix in `bytes`.
        let ct = bytes.split_off(header_len);
        let nonce = bytes.split_off(TAG_LEN);
        Some(AuthCiphertext {
            mac: bytes,
            nonce,
            ct,
        })
    }
}
/// Incremental hash built on a Strobe context initialised with the
/// `"DiscoHash"` label.
///
/// Feed data with [`DiscoHash::write`] (any number of times) and finish with
/// [`DiscoHash::sum`].
#[derive(Clone)]
pub struct DiscoHash {
    /// Underlying Strobe state that absorbs the input.
    strobe_ctx: Strobe,
    /// `false` until the first `write`; afterwards `true`.
    initialized: bool,
    /// Number of bytes `sum` will produce.
    output_len: usize,
}

impl DiscoHash {
    /// Creates a hasher that will emit `output_len` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `output_len` is smaller than 32.
    pub fn new(output_len: usize) -> DiscoHash {
        assert!(output_len >= 32);
        let strobe_ctx = Strobe::new(b"DiscoHash", SecParam::B128);
        DiscoHash {
            strobe_ctx,
            initialized: false,
            output_len,
        }
    }

    /// Absorbs `input_data` into the hash state.
    ///
    /// The first call passes `false` as the second argument of `Strobe::ad`
    /// and every later call passes `true`, so repeated writes are presented
    /// to Strobe as a continuation of the same operation.
    pub fn write(&mut self, input_data: &[u8]) {
        let is_continuation = self.initialized;
        self.strobe_ctx.ad(input_data, is_continuation);
        self.initialized = true;
    }

    /// Consumes the hasher and returns `output_len` bytes of PRF output.
    pub fn sum(mut self) -> Vec<u8> {
        let mut digest = vec![0u8; self.output_len];
        self.strobe_ctx.prf(&mut digest, false);
        digest
    }

    /// Returns `output_len` bytes filled by `OsRng`.
    ///
    /// This does not involve any hash state; it is an associated function.
    #[inline]
    pub fn random(output_len: usize) -> Vec<u8> {
        let mut bytes = vec![0u8; output_len];
        OsRng.fill_bytes(&mut bytes);
        bytes
    }
}
/// One-shot convenience wrapper around [`DiscoHash`]: hashes `input_data`
/// and returns `output_len` bytes.
///
/// # Panics
///
/// Panics if `output_len` is smaller than 32 (enforced by `DiscoHash::new`).
pub fn hash(input_data: &[u8], output_len: usize) -> Vec<u8> {
    let mut hasher = DiscoHash::new(output_len);
    hasher.write(input_data);
    hasher.sum()
}
/// Derives `output_len` bytes of key material from `input_key`.
///
/// A fresh Strobe context labelled `"DiscoKDF"` absorbs `input_key` and the
/// result is read out with the PRF operation.
///
/// # Panics
///
/// Panics if `input_key` is shorter than 16 bytes.
pub fn derive_keys(input_key: &[u8], output_len: usize) -> Vec<u8> {
    assert!(input_key.len() >= 16);

    let mut derived = vec![0u8; output_len];

    let mut kdf = Strobe::new(b"DiscoKDF", SecParam::B128);
    kdf.ad(input_key, false);
    kdf.prf(&mut derived, false);

    derived
}
/// Computes a MAC over `plaintext` under `key` and returns both together.
///
/// A Strobe context labelled `"DiscoMAC"` absorbs the key, then the
/// plaintext, and emits a `TAG_LEN`-byte tag. The plaintext itself is stored
/// unchanged in the returned [`AuthPlaintext`].
///
/// # Panics
///
/// Panics if `key` is shorter than 16 bytes.
pub fn protect_integrity(key: &[u8], plaintext: Vec<u8>) -> AuthPlaintext {
    assert!(key.len() >= 16);

    let mut ctx = Strobe::new(b"DiscoMAC", SecParam::B128);
    ctx.ad(key, false);
    ctx.ad(&plaintext, false);

    let mut tag = vec![0u8; TAG_LEN];
    ctx.send_mac(&mut tag, false);

    AuthPlaintext {
        mac: tag,
        pt: plaintext,
    }
}
/// Checks the MAC carried by `input` and, on success, returns the plaintext.
///
/// The Strobe transcript mirrors [`protect_integrity`]: label `"DiscoMAC"`,
/// then the key, then the plaintext, followed by the MAC check.
///
/// # Errors
///
/// Returns the [`AuthError`] reported by `Strobe::recv_mac` when the tag does
/// not verify.
///
/// # Panics
///
/// Panics if `key` is shorter than 16 bytes.
pub fn verify_integrity(key: &[u8], input: AuthPlaintext) -> Result<Vec<u8>, AuthError> {
    assert!(key.len() >= 16);

    let AuthPlaintext { pt, mut mac } = input;

    let mut ctx = Strobe::new(b"DiscoMAC", SecParam::B128);
    ctx.ad(key, false);
    ctx.ad(&pt, false);

    // Only release the plaintext once the tag has been accepted.
    ctx.recv_mac(&mut mac, false)?;
    Ok(pt)
}
/// Encrypts and authenticates `plaintext` under `key`.
///
/// A random `NONCE_LEN`-byte nonce is drawn from `thread_rng`. A Strobe
/// context labelled `"DiscoAEAD"` absorbs the key and the nonce, encrypts the
/// plaintext in place, and then emits a `TAG_LEN`-byte MAC. The nonce and MAC
/// are returned alongside the ciphertext.
///
/// # Panics
///
/// Panics if `key` is shorter than 16 bytes.
pub fn encrypt(key: &[u8], mut plaintext: Vec<u8>) -> AuthCiphertext {
    assert!(key.len() >= 16);

    // Fresh nonce for this message.
    let mut nonce = vec![0u8; NONCE_LEN];
    thread_rng().fill_bytes(&mut nonce);

    let mut aead = Strobe::new(b"DiscoAEAD", SecParam::B128);
    aead.ad(key, false);
    aead.ad(&nonce, false);

    // `send_enc` overwrites the buffer, so `plaintext` holds ciphertext
    // from here on.
    aead.send_enc(&mut plaintext, false);

    let mut mac = vec![0u8; TAG_LEN];
    aead.send_mac(&mut mac, false);

    AuthCiphertext {
        mac,
        nonce,
        ct: plaintext,
    }
}
/// Decrypts `ciphertext` under `key` and verifies its MAC.
///
/// The Strobe transcript mirrors [`encrypt`]: label `"DiscoAEAD"`, then the
/// key, then the nonce, followed by in-place decryption and the MAC check.
///
/// # Errors
///
/// Returns the [`AuthError`] reported by `Strobe::recv_mac` when the tag does
/// not verify; the decrypted bytes are not returned in that case.
///
/// # Panics
///
/// Panics if `key` is shorter than 16 bytes.
pub fn decrypt(key: &[u8], ciphertext: AuthCiphertext) -> Result<Vec<u8>, AuthError> {
    assert!(key.len() >= 16);

    let AuthCiphertext {
        mut mac,
        nonce,
        ct: mut payload,
    } = ciphertext;

    let mut aead = Strobe::new(b"DiscoAEAD", SecParam::B128);
    aead.ad(key, false);
    aead.ad(&nonce, false);

    // `recv_enc` overwrites the buffer, so `payload` holds the candidate
    // plaintext from here on.
    aead.recv_enc(&mut payload, false);

    // Hand the plaintext back only after the tag has been accepted.
    aead.recv_mac(&mut mac, false)?;
    Ok(payload)
}