use crate::traits::{Cipher, U8Array};
#[cfg(feature = "use_alloc")]
use alloc::vec::Vec;
/// Noise `CipherState`: a key for cipher `C` paired with the nonce counter `n`
/// that the next message will use.
pub struct CipherState<C>
where
    C: Cipher,
{
    /// Current key for cipher `C`.
    key: C::Key,
    /// Nonce for the next encryption or decryption.
    n: u64,
}
/// Clones the key and copies the nonce counter.
///
/// Written by hand because `#[derive(Clone)]` would add a `C: Clone` bound;
/// only `C::Key` actually needs to be cloned.
impl<C: Cipher> Clone for CipherState<C> {
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field forces this impl to be updated.
        let Self { key, n } = self;
        Self {
            key: key.clone(),
            n: *n,
        }
    }
}
impl<C> CipherState<C>
where
    C: Cipher,
{
    /// Length in bytes of the authentication tag appended to every ciphertext
    /// (Noise fixes this at 16 bytes).
    const TAG_LEN: usize = 16;

    /// Nonce value reserved by the Noise specification (§5.1). This type never
    /// uses it to encrypt or decrypt a message.
    const RESERVED_NONCE: u64 = u64::MAX;

    /// Returns the cipher name reported by `C::name()`.
    pub fn name() -> &'static str {
        C::name()
    }

    /// Creates a cipher state from raw key bytes and an initial nonce `n`.
    ///
    /// `key` is converted with `U8Array::from_slice`. Behaviour on a
    /// wrong-length key is up to that implementation (presumably a panic —
    /// confirm against `U8Array`).
    pub fn new(key: &[u8], n: u64) -> Self {
        CipherState {
            key: C::Key::from_slice(key),
            n,
        }
    }

    /// Replaces the current key with `C::rekey(&key)`.
    ///
    /// The nonce counter is left unchanged.
    pub fn rekey(&mut self) {
        self.key = C::rekey(&self.key);
    }

    /// Panics if only the reserved nonce is left.
    ///
    /// Called *before* encrypting so that no ciphertext is ever produced
    /// under the reserved nonce.
    fn check_encrypt_nonce(&self) {
        assert!(
            self.n != Self::RESERVED_NONCE,
            "CipherState nonce exhausted: nonce 2^64-1 is reserved"
        );
    }

    /// Fails with `Err(())` if only the reserved nonce is left.
    fn check_decrypt_nonce(&self) -> Result<(), ()> {
        if self.n == Self::RESERVED_NONCE {
            Err(())
        } else {
            Ok(())
        }
    }

    /// Encrypts `plaintext` with associated data `authtext` into `out`, then
    /// advances the nonce.
    ///
    /// `out` is expected to be exactly `plaintext.len() + 16` bytes: the
    /// plaintext length plus the tag.
    ///
    /// # Panics
    ///
    /// Panics if the nonce has reached the reserved value `u64::MAX`. The
    /// check runs before encryption, so `out` is left untouched in that case.
    pub fn encrypt_ad(&mut self, authtext: &[u8], plaintext: &[u8], out: &mut [u8]) {
        self.check_encrypt_nonce();
        C::encrypt(&self.key, self.n, authtext, plaintext, out);
        // Self-test hook: with `use_std` and `NOISE_RUST_TEST_IN_PLACE` set at
        // *compile* time, cross-check against the in-place implementation.
        #[cfg(feature = "use_std")]
        if option_env!("NOISE_RUST_TEST_IN_PLACE").is_some() {
            let mut inout = plaintext.to_vec();
            inout.resize(plaintext.len() + Self::TAG_LEN, 0);
            let l = C::encrypt_in_place(&self.key, self.n, authtext, &mut inout, plaintext.len());
            assert_eq!(inout, out);
            assert_eq!(l, out.len());
        }
        // Cannot overflow: `check_encrypt_nonce` guaranteed `n < u64::MAX`.
        self.n += 1;
    }

    /// Encrypts the first `plaintext_len` bytes of `in_out` in place with
    /// associated data `authtext`, then advances the nonce.
    ///
    /// Returns the length reported by `C::encrypt_in_place`. This is
    /// presumably `plaintext_len + 16`, so `in_out` needs room for the tag —
    /// confirm against the `Cipher` implementation.
    ///
    /// # Panics
    ///
    /// Panics if the nonce has reached the reserved value `u64::MAX`. In that
    /// case `in_out` is left untouched.
    pub fn encrypt_ad_in_place(
        &mut self,
        authtext: &[u8],
        in_out: &mut [u8],
        plaintext_len: usize,
    ) -> usize {
        self.check_encrypt_nonce();
        let size = C::encrypt_in_place(&self.key, self.n, authtext, in_out, plaintext_len);
        // Cannot overflow: `check_encrypt_nonce` guaranteed `n < u64::MAX`.
        self.n += 1;
        size
    }

    /// Decrypts and authenticates `ciphertext` with associated data `authtext`
    /// into `out`. The nonce advances only on success.
    ///
    /// `out` is expected to be `ciphertext.len() - 16` bytes.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` in two cases:
    /// - `C::decrypt` fails;
    /// - the nonce has reached the reserved value `u64::MAX`. No decryption is
    ///   attempted then.
    ///
    /// The nonce is not advanced on error.
    pub fn decrypt_ad(
        &mut self,
        authtext: &[u8],
        ciphertext: &[u8],
        out: &mut [u8],
    ) -> Result<(), ()> {
        self.check_decrypt_nonce()?;
        let r = C::decrypt(&self.key, self.n, authtext, ciphertext, out);
        // Self-test hook: see `encrypt_ad`.
        #[cfg(feature = "use_std")]
        if option_env!("NOISE_RUST_TEST_IN_PLACE").is_some() {
            let mut inout = ciphertext.to_vec();
            let r2 = C::decrypt_in_place(&self.key, self.n, authtext, &mut inout, ciphertext.len());
            assert_eq!(r.map(|_| out.len()), r2);
            if r.is_ok() {
                assert_eq!(&inout[..out.len()], out);
            }
        }
        r?;
        // Cannot overflow: `check_decrypt_nonce` guaranteed `n < u64::MAX`.
        self.n += 1;
        Ok(())
    }

    /// Decrypts the first `ciphertext_len` bytes of `in_out` in place with
    /// associated data `authtext`. The nonce advances only on success.
    ///
    /// Returns the size reported by `C::decrypt_in_place`, presumably the
    /// plaintext length.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` in two cases:
    /// - `C::decrypt_in_place` fails;
    /// - the nonce has reached the reserved value `u64::MAX`.
    ///
    /// The nonce is not advanced on error.
    pub fn decrypt_ad_in_place(
        &mut self,
        authtext: &[u8],
        in_out: &mut [u8],
        ciphertext_len: usize,
    ) -> Result<usize, ()> {
        self.check_decrypt_nonce()?;
        let size = C::decrypt_in_place(&self.key, self.n, authtext, in_out, ciphertext_len)?;
        // Cannot overflow: `check_decrypt_nonce` guaranteed `n < u64::MAX`.
        self.n += 1;
        Ok(size)
    }

    /// Same as [`Self::encrypt_ad`] with empty associated data.
    pub fn encrypt(&mut self, plaintext: &[u8], out: &mut [u8]) {
        self.encrypt_ad(&[], plaintext, out)
    }

    /// Same as [`Self::encrypt_ad_in_place`] with empty associated data.
    pub fn encrypt_in_place(&mut self, in_out: &mut [u8], plaintext_len: usize) -> usize {
        self.encrypt_ad_in_place(&[], in_out, plaintext_len)
    }

    /// Encrypts `plaintext` with empty associated data into a newly allocated
    /// buffer of `plaintext.len() + 16` bytes.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`Self::encrypt_ad`].
    #[cfg(any(feature = "use_std", feature = "use_alloc"))]
    pub fn encrypt_vec(&mut self, plaintext: &[u8]) -> Vec<u8> {
        let mut out = vec![0u8; plaintext.len() + Self::TAG_LEN];
        self.encrypt(plaintext, &mut out);
        out
    }

    /// Same as [`Self::decrypt_ad`] with empty associated data.
    pub fn decrypt(&mut self, ciphertext: &[u8], out: &mut [u8]) -> Result<(), ()> {
        self.decrypt_ad(&[], ciphertext, out)
    }

    /// Same as [`Self::decrypt_ad_in_place`] with empty associated data.
    pub fn decrypt_in_place(
        &mut self,
        in_out: &mut [u8],
        ciphertext_len: usize,
    ) -> Result<usize, ()> {
        self.decrypt_ad_in_place(&[], in_out, ciphertext_len)
    }

    /// Decrypts `ciphertext` with empty associated data into a newly allocated
    /// buffer.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` in three cases:
    /// - the ciphertext is shorter than the 16-byte tag (the nonce is not
    ///   touched);
    /// - [`Self::decrypt`] fails for any reason;
    /// - the nonce is exhausted.
    #[cfg(any(feature = "use_std", feature = "use_alloc"))]
    pub fn decrypt_vec(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, ()> {
        let plaintext_len = ciphertext.len().checked_sub(Self::TAG_LEN).ok_or(())?;
        let mut out = vec![0u8; plaintext_len];
        self.decrypt(ciphertext, &mut out)?;
        Ok(out)
    }

    /// Returns the nonce that the next encryption or decryption will use.
    pub fn get_next_n(&self) -> u64 {
        self.n
    }

    /// Consumes the state, returning its key and next nonce.
    pub fn extract(self) -> (C::Key, u64) {
        (self.key, self.n)
    }
}