use aws_lc_rs::aead::chacha20_poly1305_openssh::OpeningKey;
use aws_lc_rs::aead::chacha20_poly1305_openssh::PACKET_LENGTH_LEN;
use aws_lc_rs::aead::chacha20_poly1305_openssh::SealingKey;
use aws_lc_rs::aead::chacha20_poly1305_openssh::TAG_LEN;
use crate::crypto::cipher::Cipher;
use crate::types::Error;
use crate::types::Result;
use crate::wire::Packet;
/// Byte length of the key material this cipher hands to aws-lc-rs; equal to
/// `aws_lc_rs::aead::chacha20_poly1305_openssh::KEY_LEN`.
const KEY_MATERIAL_LEN: usize = 64;

/// `chacha20-poly1305@openssh.com` packet cipher backed by aws-lc-rs.
///
/// Only the raw key bytes are kept; the aws-lc-rs sealing/opening keys are
/// built from them when a packet is processed.
pub struct ChaCha20Poly1305 {
    /// Raw key bytes passed to `SealingKey::new` / `OpeningKey::new`.
    key_material: [u8; KEY_MATERIAL_LEN],
}

impl ChaCha20Poly1305 {
    /// Creates a cipher that takes ownership of `key_material`.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn new(key_material: [u8; KEY_MATERIAL_LEN]) -> Self {
        Self { key_material }
    }
}
impl Cipher for ChaCha20Poly1305 {
const AEAD_LENGTH: Option<usize> = Some(TAG_LEN);
fn encrypt_packet<'buf, B>(
&self,
packet: &'buf mut Packet<&'buf mut B>,
sequence_number: u32,
) -> Result<()>
where
B: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
{
let key = SealingKey::new(&self.key_material);
let packet_len = packet.packet_length()?;
let (data, mac) = packet.packet_mut()?.split_at_mut(4 + packet_len as usize);
let mac = mac.as_mut_array().ok_or(Error::Crypto)?;
key.seal_in_place(sequence_number, data, mac);
Ok(())
}
fn decrypt_packet_length<B>(&self, packet: &Packet<B>, sequence_number: u32) -> Result<u32>
where
B: AsRef<[u8]>,
{
let key = OpeningKey::new(&self.key_material);
let packet_len_encrypted = packet.packet_length_bytes()?;
let packet_len_decrypted = key.decrypt_packet_length(sequence_number, packet_len_encrypted);
Ok(u32::from_be_bytes(packet_len_decrypted))
}
fn decrypt_packet<'buf, B>(
&self,
packet: &'buf mut Packet<&'buf mut B>,
sequence_number: u32,
) -> Result<()>
where
B: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
{
let packet_len = self.decrypt_packet_length(packet, sequence_number)?;
let (data, mac) = packet.packet_mut()?.split_at_mut(4 + packet_len as usize);
let packet_len_bytes = packet_len.to_be_bytes();
let mut tag = [0u8; TAG_LEN];
tag.copy_from_slice(mac);
let key = OpeningKey::new(&self.key_material);
key.open_in_place(sequence_number, data, &tag)
.map_err(|_unspecified| Error::Crypto)?;
data[..PACKET_LENGTH_LEN].copy_from_slice(&packet_len_bytes);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use aws_lc_rs::aead::chacha20_poly1305_openssh::KEY_LEN;
    use rstest::rstest;

    use super::*;

    /// Deterministic test key: all zero bytes except the last, which is 1.
    const KEY_MATERIAL: [u8; KEY_LEN] = {
        let mut bytes = [0u8; KEY_LEN];
        bytes[KEY_LEN - 1] = 1;
        bytes
    };

    /// Cleartext packet (length 8, padding length 6, payload `0x15`) with a
    /// zeroed 16-byte tag area.
    const CLEAR_PACKET: &str = "00000008061500010203040500000000000000000000000000000000";

    /// `CLEAR_PACKET` sealed with `KEY_MATERIAL` at sequence number 0.
    const SEALED_PACKET: &str = "4540f0529912e7bf57523c7f66022017cfefd3278ac13f40f8523faf";

    #[rstest]
    #[case(
        0,
        "00000008061500010203040500000000000000000000000000000000",
        "4540f0529912e7bf57523c7f66022017cfefd3278ac13f40f8523faf"
    )]
    fn encrypt_works(
        #[case] sequence_number: u32,
        #[case] packet_clear: &str,
        #[case] packet_cipher_should: &str,
    ) {
        let mut data = hex::decode(packet_clear).unwrap();
        let data_len_should = data.len();
        {
            let mut packet = Packet::new(&mut data, TAG_LEN as u8);
            let cipher = ChaCha20Poly1305::new(KEY_MATERIAL);
            cipher.encrypt_packet(&mut packet, sequence_number).unwrap();
        }
        assert_eq!(data.len(), data_len_should);
        let data = hex::encode(&data);
        assert_eq!(data, packet_cipher_should);
    }

    #[rstest]
    #[case(
        "4540f0529912e7bf57523c7f66022017cfefd3278ac13f40f8523faf",
        0,
        "00000008061500010203040566022017cfefd3278ac13f40f8523faf",
        8,
        6,
        &[0x15],
        &[0x00, 0x01, 0x02, 0x03, 0x04, 0x05]
    )]
    fn decrypt_works(
        #[case] input_data: &str,
        #[case] input_sequence_number: u32,
        #[case] should_data: &str,
        #[case] should_packet_length: u32,
        #[case] should_padding_length: u8,
        #[case] should_payload: &[u8],
        #[case] should_padding: &[u8],
    ) {
        let mut data = hex::decode(input_data).unwrap();
        {
            let mut packet = Packet::new(&mut data, TAG_LEN as u8);
            let cipher = ChaCha20Poly1305::new(KEY_MATERIAL);
            cipher
                .decrypt_packet(&mut packet, input_sequence_number)
                .unwrap();
        }
        let got_dec_data = hex::encode(&data);
        assert_eq!(got_dec_data, should_data);
        let packet = Packet::new(&data, TAG_LEN as u8);
        assert_eq!(packet.packet_length().unwrap(), should_packet_length);
        assert_eq!(packet.padding_length().unwrap(), should_padding_length);
        assert_eq!(packet.payload().unwrap(), should_payload);
        assert_eq!(packet.padding().unwrap(), should_padding);
    }

    #[test]
    fn decrypt_rejects_tampered_tag() {
        let mut data = hex::decode(SEALED_PACKET).unwrap();
        let last = data.len() - 1;
        data[last] ^= 0x01;
        let result = {
            let mut packet = Packet::new(&mut data, TAG_LEN as u8);
            let cipher = ChaCha20Poly1305::new(KEY_MATERIAL);
            cipher.decrypt_packet(&mut packet, 0)
        };
        assert!(matches!(result, Err(Error::Crypto)));
    }

    #[test]
    fn round_trip_restores_plaintext() {
        let sequence_number = 7;
        let mut data = hex::decode(CLEAR_PACKET).unwrap();
        let cipher = ChaCha20Poly1305::new(KEY_MATERIAL);
        {
            let mut packet = Packet::new(&mut data, TAG_LEN as u8);
            cipher.encrypt_packet(&mut packet, sequence_number).unwrap();
        }
        {
            let mut packet = Packet::new(&mut data, TAG_LEN as u8);
            cipher.decrypt_packet(&mut packet, sequence_number).unwrap();
        }
        // Length field + packet body (4 + 8 bytes) must match the cleartext;
        // the trailing tag bytes stay as written by the encryption step.
        let sealed_hex_len = 2 * (PACKET_LENGTH_LEN + 8);
        assert_eq!(
            hex::encode(&data[..PACKET_LENGTH_LEN + 8]),
            &CLEAR_PACKET[..sealed_hex_len]
        );
    }
}