use crate::xor;
use super::{Mac, MacNotEqual};
use std::sync::atomic::{AtomicU64, Ordering};
use std::fmt;
use zeroize::Zeroize;
use chacha20::{hchacha, XChaCha20};
use chacha20::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use chacha20::cipher::typenum::U10;
use poly1305::Poly1305;
use universal_hash::{KeyInit, UniversalHash};
use generic_array::GenericArray;
/// Byte length of one ChaCha keystream block (16 words of 4 bytes).
///
/// `Cipher::new` seeks to this offset after taking the Poly1305 key from
/// block 0, so message bytes are processed starting at block 1.
const BLOCK_SIZE: u64 = 16 * 4;
/// Sequential, single-owner message key.
///
/// Every `encrypt`/`decrypt` advances `count`, which is mixed into
/// `initial_nonce` to form that message's nonce.
pub struct Key {
    /// 32-byte XChaCha20 key (the HChaCha-derived subkey built in `Key::new`).
    shared_secret: [u8; 32],
    /// Base 24-byte nonce that the per-message counter is mixed into.
    initial_nonce: [u8; 24],
    /// Counter value used by the most recent message (0 = none yet).
    count: u64,
}
impl Key {
pub(crate) fn new(
shared_secret: [u8; 32],
initial_nonce: [u8; 24]
) -> Self {
let shared_secret = hchacha::<U10>(
shared_secret.as_ref().into(),
&GenericArray::default()
).into();
Self {
shared_secret,
initial_nonce,
count: 0
}
}
pub fn encrypt(&mut self, msg: &mut [u8]) -> Mac {
self.new_cipher().encrypt(msg)
}
pub fn decrypt(
&mut self,
msg: &mut [u8],
recv_mac: &Mac
) -> Result<(), MacNotEqual> {
self.new_cipher().decrypt(msg, recv_mac)
}
fn new_cipher(&mut self) -> Cipher {
self.count += 1;
Cipher::new(&self.shared_secret, &self.initial_nonce, self.count)
}
pub fn into_sync(self) -> SyncKey {
SyncKey::new(self.shared_secret, self.initial_nonce, self.count)
}
pub fn dublicate(&self) -> Self {
Self {
shared_secret: self.shared_secret,
initial_nonce: self.initial_nonce,
count: self.count
}
}
}
/// Opaque `Debug`: prints only the type name and never any field.
impl fmt::Debug for Key {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "Key")
    }
}
impl Drop for Key {
fn drop(&mut self) {
self.shared_secret.zeroize();
self.initial_nonce.zeroize();
}
}
/// Shared-reference variant of `Key`.
///
/// The counter is an `AtomicU64`, so `encrypt`/`decrypt` work through `&self`.
pub struct SyncKey {
    /// 32-byte XChaCha20 key carried over from the originating `Key`.
    shared_secret: [u8; 32],
    /// Base 24-byte nonce that per-message counters are mixed into.
    initial_nonce: [u8; 24],
    /// Next counter value to hand out.
    count: AtomicU64,
}
impl SyncKey {
    /// Wraps the state of a sequential key whose counter currently stands at `count`.
    ///
    /// The atomic stores the *next* value to hand out (`count + 1`). This
    /// matches `Key`, which increments its counter before using it.
    fn new(shared_secret: [u8; 32], initial_nonce: [u8; 24], count: u64) -> Self {
        let next_counter = count + 1;
        Self {
            shared_secret,
            initial_nonce,
            count: AtomicU64::new(next_counter),
        }
    }

    /// Encrypts `msg` in place with a freshly reserved counter value and
    /// returns its tag.
    pub fn encrypt(&self, msg: &mut [u8]) -> Mac {
        let per_message = self.new_cipher();
        per_message.encrypt(msg)
    }

    /// Verifies `recv_mac` and, if it matches, decrypts `msg` in place.
    ///
    /// The counter value this call uses depends on the order in which calls
    /// reach `new_cipher`. NOTE(review): presumably concurrent decrypts must
    /// be issued in the peer's encryption order — confirm with callers.
    pub fn decrypt(&self, msg: &mut [u8], recv_mac: &Mac) -> Result<(), MacNotEqual> {
        let per_message = self.new_cipher();
        per_message.decrypt(msg, recv_mac)
    }

    /// Reserves the next counter value and builds the per-message cipher for it.
    fn new_cipher(&self) -> Cipher {
        // Relaxed suffices: fetch_add is an atomic read-modify-write, so each
        // caller receives a distinct counter value regardless of ordering.
        let counter = self.count.fetch_add(1, Ordering::Relaxed);
        Cipher::new(&self.shared_secret, &self.initial_nonce, counter)
    }
}
/// Opaque `Debug`: prints only the type name and never any field.
impl fmt::Debug for SyncKey {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "SyncKey")
    }
}
impl Drop for SyncKey {
fn drop(&mut self) {
self.shared_secret.zeroize();
self.initial_nonce.zeroize();
}
}
/// Finalizes an authenticator state into a `Mac`.
trait ToMac {
    /// Consumes the state and returns the tag; `msg_len` is the byte length
    /// of the message that was fed in.
    fn to_mac(self, msg_len: usize) -> Mac;
}
impl ToMac for Poly1305 {
    /// Absorbs the message length as a final, unpadded 8-byte big-endian
    /// block and wraps the resulting Poly1305 tag.
    ///
    /// NOTE: this length encoding is specific to this module. It differs from
    /// the RFC 8439 AEAD layout (`le64(aad_len) || le64(ct_len)`), so both
    /// peers must use this exact form.
    fn to_mac(self, msg_len: usize) -> Mac {
        let length_block: [u8; 8] = u64::to_be_bytes(msg_len as u64);
        let tag = self.compute_unpadded(&length_block);
        Mac::new(tag)
    }
}
fn xor_nonce_with_u64(nonce: &mut [u8; 24], count: u64) {
let bytes = count.to_be_bytes();
xor(&mut nonce[..8], &bytes);
xor(&mut nonce[8..16], &bytes);
xor(&mut nonce[16..], &bytes);
}
/// One-shot per-message state, consumed by a single `encrypt` or `decrypt`.
///
/// Holds an XChaCha20 keystream plus the Poly1305 authenticator keyed from it.
struct Cipher {
    /// Keystream, positioned at block 1 once `Cipher::new` returns.
    cipher: XChaCha20,
    /// Authenticator keyed with the first 32 bytes of keystream block 0.
    poly: Poly1305,
}
impl Cipher {
fn new(
shared_secret: &[u8; 32],
initial_nonce: &[u8; 24],
count: u64
) -> Self {
let mut iv = *initial_nonce;
xor_nonce_with_u64(&mut iv, count);
let mut cipher = XChaCha20::new(
shared_secret.into(),
iv.as_ref().into()
);
let mut mac_key = [0u8; 32];
cipher.apply_keystream(&mut mac_key);
let poly = Poly1305::new(mac_key.as_ref().into());
mac_key.zeroize();
cipher.seek(BLOCK_SIZE);
Self { cipher, poly }
}
fn encrypt(mut self, msg: &mut [u8]) -> Mac {
self.cipher.apply_keystream(msg);
self.poly.update_padded(msg);
self.poly.to_mac(msg.len())
}
fn decrypt(
mut self,
msg: &mut [u8],
recv_mac: &Mac
) -> Result<(), MacNotEqual> {
self.poly.update_padded(msg);
let mac = self.poly.to_mac(msg.len());
if recv_mac == &mac {
self.cipher.apply_keystream(msg);
Ok(())
} else {
Err(MacNotEqual)
}
}
}