use alloc::vec::Vec;
use crate::{Call, Error, Safe, Sponge};
use zeroize::Zeroize;
/// Field operations, in addition to those of [`Safe`], that [`encrypt`] and
/// [`decrypt`] require from the permutation type.
pub trait Encryption<T, const W: usize> {
    /// Compares two elements for equality.
    ///
    /// [`decrypt`] uses this to check the received tag against the one it
    /// recomputes. NOTE(review): because the result gates authentication,
    /// implementations should probably compare in constant time — confirm.
    fn is_equal(&mut self, left: &T, right: &T) -> bool;

    /// Returns `minuend - subtrahend`.
    ///
    /// [`decrypt`] uses this to remove the squeezed elements from the cipher.
    fn subtract(&mut self, minuend: &T, subtrahend: &T) -> T;
}
/// Builds a sponge that has already processed the key material.
///
/// The sponge is started with the I/O pattern from `io_pattern(message_len)`
/// and `domain_sep`. It then absorbs the two shared-secret elements, absorbs
/// the nonce, and squeezes `message_len` elements. The remaining two calls of
/// the pattern (absorbing the message, squeezing one element) are left to the
/// caller.
///
/// # Errors
///
/// Propagates any [`Error`] returned by the sponge operations.
fn prepare_sponge<E, T, const W: usize>(
    safe: E,
    domain_sep: u64,
    message_len: usize,
    shared_secret: &[T; 2],
    nonce: &T,
) -> Result<Sponge<E, T, W>, Error>
where
    E: Safe<T, W> + Encryption<T, W>,
    T: Default + Copy + Zeroize,
{
    let pattern = io_pattern(message_len);
    let mut state = Sponge::start(safe, pattern, domain_sep)?;

    // Key material first: both secret elements, then the single nonce.
    state.absorb(shared_secret.len(), shared_secret)?;
    let nonce_block = [*nonce];
    state.absorb(nonce_block.len(), nonce_block)?;

    // Squeeze the elements that later mask (encrypt) or unmask (decrypt) the
    // message.
    state.squeeze(message_len)?;

    Ok(state)
}
/// Encrypts `message` under `shared_secret` and `nonce`.
///
/// The result has `message.len() + 1` elements:
/// - The first `message.len()` are the message elements, each added to the
///   corresponding element squeezed after the key material.
/// - The last is one extra squeezed element, taken after the message was
///   absorbed. [`decrypt`] recomputes and checks it as an authentication tag.
///
/// # Errors
///
/// - Propagates any [`Error`] from the sponge operations.
/// - Returns [`Error::EncryptionFailed`] if the sponge output does not have
///   the expected `message.len() + 1` elements.
///
/// On every error path, the intermediate cipher buffer and the sponge are
/// zeroized before returning.
pub fn encrypt<E, T, const W: usize>(
    safe: E,
    domain_sep: impl Into<u64>,
    message: impl AsRef<[T]>,
    shared_secret: &[T; 2],
    nonce: &T,
) -> Result<Vec<T>, Error>
where
    E: Safe<T, W> + Encryption<T, W>,
    T: Default + Copy + Zeroize,
{
    let message = message.as_ref();
    let message_len = message.len();
    let mut sponge = prepare_sponge(
        safe,
        domain_sep.into(),
        message_len,
        shared_secret,
        nonce,
    )?;

    // Absorb the plaintext, then squeeze the one trailing (tag) element.
    if let Err(e) = sponge.absorb(message_len, message) {
        sponge.zeroize();
        return Err(e);
    }
    if let Err(e) = sponge.squeeze(1) {
        sponge.zeroize();
        return Err(e);
    }

    let mut cipher = Vec::from(&sponge.output[..]);
    // Validate the length *before* combining with the message, so a malformed
    // output is reported as an error instead of panicking on an index. The
    // buffer holds squeezed (key-derived) elements, so scrub it on failure.
    if cipher.len() != message_len + 1 {
        cipher.zeroize();
        sponge.zeroize();
        return Err(Error::EncryptionFailed);
    }

    // Mask each message element; `zip` stops at `message_len`, leaving the
    // trailing tag element untouched.
    for (c, m) in cipher.iter_mut().zip(message) {
        *c = sponge.safe.add(c, m);
    }

    match sponge.finish() {
        Ok(mut output) => {
            output.zeroize();
            Ok(cipher)
        }
        Err(e) => {
            cipher.zeroize();
            Err(e)
        }
    }
}
pub fn decrypt<E, T, const W: usize>(
safe: E,
domain_sep: impl Into<u64>,
cipher: impl AsRef<[T]>,
shared_secret: &[T; 2],
nonce: &T,
) -> Result<Vec<T>, Error>
where
E: Safe<T, W> + Encryption<T, W>,
T: Default + Copy + Zeroize,
{
let cipher = cipher.as_ref();
let message_len = cipher.len() - 1;
let mut sponge = prepare_sponge(
safe,
domain_sep.into(),
message_len,
shared_secret,
nonce,
)?;
let mut message = Vec::from(&sponge.output[..]);
for i in 0..message_len {
message[i] = sponge.safe.subtract(&cipher[i], &message[i]);
}
sponge.absorb(message_len, &message)?;
sponge.squeeze(1)?;
let s = sponge.output[message_len];
if !sponge.safe.is_equal(&s, &cipher[message_len]) {
message.zeroize();
sponge.zeroize();
return Err(Error::DecryptionFailed);
};
match sponge.finish() {
Ok(mut output) => {
output.zeroize();
Ok(message)
}
Err(e) => {
message.zeroize();
Err(e)
}
}
}
/// Returns the I/O pattern shared by [`encrypt`] and [`decrypt`].
///
/// In order:
/// 1. Absorb the 2 shared-secret elements.
/// 2. Absorb the 1-element nonce.
/// 3. Squeeze `message_len` masking elements.
/// 4. Absorb the `message_len` message elements.
/// 5. Squeeze the single tag element.
const fn io_pattern(message_len: usize) -> [Call; 5] {
    let absorb_secret = Call::Absorb(2);
    let absorb_nonce = Call::Absorb(1);
    let squeeze_mask = Call::Squeeze(message_len);
    let absorb_message = Call::Absorb(message_len);
    let squeeze_tag = Call::Squeeze(1);

    [
        absorb_secret,
        absorb_nonce,
        squeeze_mask,
        absorb_message,
        squeeze_tag,
    ]
}