use crate::cipher_suites::AesGcm;
use rustls::crypto::cipher::{
make_tls13_aad, AeadKey, InboundOpaqueMessage, InboundPlainMessage, Iv, MessageDecrypter,
MessageEncrypter, Nonce, OutboundOpaqueMessage, OutboundPlainMessage, PrefixedPayload,
Tls13AeadAlgorithm, UnsupportedOperationError,
};
use rustls::ConnectionTrafficSecrets;
use symcrypt::cipher::BlockCipherType;
use symcrypt::gcm::GcmExpandedKey;
/// Length in bytes of the AES-GCM authentication tag appended to every protected record.
const GCM_TAG_LENGTH: usize = 16;
#[cfg(feature = "chacha")]
use symcrypt::chacha::{chacha20_poly1305_decrypt_in_place, chacha20_poly1305_encrypt_in_place};
/// Length in bytes of the Poly1305 authentication tag appended to every protected record.
#[cfg(feature = "chacha")]
const CHACHA_TAG_LENGTH: usize = 16;
/// Length in bytes of a ChaCha20-Poly1305 traffic key.
#[cfg(feature = "chacha")]
const CHACHA_KEY_LENGTH: usize = 32;
/// TLS 1.3 ChaCha20-Poly1305 AEAD algorithm; hands out per-direction
/// [`Tls13ChaCha20Poly1305`] encrypters/decrypters.
#[cfg(feature = "chacha")]
pub struct Tls13ChaCha;
/// Per-direction ChaCha20-Poly1305 record protection state.
#[cfg(feature = "chacha")]
pub struct Tls13ChaCha20Poly1305 {
    // Raw 32-byte traffic key, copied out of the rustls `AeadKey`.
    key: [u8; CHACHA_KEY_LENGTH],
    // Static IV; combined with the record sequence number via `Nonce::new` per record.
    iv: Iv,
}
#[cfg(feature = "chacha")]
impl Tls13AeadAlgorithm for Tls13ChaCha {
    /// Builds the sealing half for one direction of the connection.
    fn encrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageEncrypter> {
        Box::new(chacha_state_from(&key, iv))
    }

    /// Builds the opening half for one direction of the connection.
    fn decrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageDecrypter> {
        Box::new(chacha_state_from(&key, iv))
    }

    /// ChaCha20-Poly1305 always uses a 32-byte key.
    fn key_len(&self) -> usize {
        CHACHA_KEY_LENGTH
    }

    /// Exposes the raw traffic secrets (e.g. for kTLS offload).
    fn extract_keys(
        &self,
        key: AeadKey,
        iv: Iv,
    ) -> Result<ConnectionTrafficSecrets, UnsupportedOperationError> {
        Ok(ConnectionTrafficSecrets::Chacha20Poly1305 { key, iv })
    }
}

/// Copies the rustls traffic key into a fixed-size array and pairs it with `iv`.
///
/// # Panics
/// Panics if `key` is not exactly `CHACHA_KEY_LENGTH` bytes long.
#[cfg(feature = "chacha")]
fn chacha_state_from(key: &AeadKey, iv: Iv) -> Tls13ChaCha20Poly1305 {
    let raw = key.as_ref();
    assert_eq!(raw.len(), CHACHA_KEY_LENGTH);
    let mut key_bytes = [0u8; CHACHA_KEY_LENGTH];
    key_bytes.copy_from_slice(raw);
    Tls13ChaCha20Poly1305 {
        key: key_bytes,
        iv,
    }
}
#[cfg(feature = "chacha")]
impl MessageEncrypter for Tls13ChaCha20Poly1305 {
    /// Seals `msg` into a TLS 1.3 ChaCha20-Poly1305 record for sequence number `seq`.
    ///
    /// The plaintext is followed by its one-byte content type, encrypted in place, and the
    /// Poly1305 tag is appended. The outer record is labelled `ApplicationData` / `TLSv1_2`.
    ///
    /// # Errors
    /// Returns [`rustls::Error::General`] carrying the SymCrypt error text if encryption fails.
    fn encrypt(
        &mut self,
        msg: OutboundPlainMessage,
        seq: u64,
    ) -> Result<OutboundOpaqueMessage, rustls::Error> {
        let nonce = Nonce::new(&self.iv, seq);
        let record_len = self.encrypted_payload_len(msg.payload.len());
        // The AAD is the record header, whose length covers ciphertext + tag.
        let auth_data = make_tls13_aad(record_len);
        // Bytes that get encrypted: the plaintext plus the trailing content-type byte.
        let inner_len = msg.payload.len() + 1;

        let mut payload = PrefixedPayload::with_capacity(record_len);
        payload.extend_from_chunks(&msg.payload);
        payload.extend_from_slice(&msg.typ.to_array());

        let mut tag = [0u8; CHACHA_TAG_LENGTH];
        chacha20_poly1305_encrypt_in_place(
            &self.key,
            &nonce.0,
            &auth_data,
            &mut payload.as_mut()[..inner_len],
            &mut tag,
        )
        .map_err(|symcrypt_error| {
            rustls::Error::General(format!("SymCryptError: {}", symcrypt_error))
        })?;
        payload.extend_from_slice(&tag);

        Ok(OutboundOpaqueMessage::new(
            rustls::ContentType::ApplicationData,
            rustls::ProtocolVersion::TLSv1_2,
            payload,
        ))
    }

    /// Ciphertext length for `payload_len` plaintext bytes: adds the content-type byte and tag.
    fn encrypted_payload_len(&self, payload_len: usize) -> usize {
        payload_len + 1 + CHACHA_TAG_LENGTH
    }
}
#[cfg(feature = "chacha")]
impl MessageDecrypter for Tls13ChaCha20Poly1305 {
    /// Opens a TLS 1.3 ChaCha20-Poly1305 record protected with sequence number `seq`.
    ///
    /// The trailing Poly1305 tag is split off, the remaining ciphertext is decrypted in
    /// place, and the inner plaintext (content type + padding) is unwrapped by rustls.
    ///
    /// # Errors
    /// Returns [`rustls::Error::DecryptError`] when the record is shorter than a tag or fails
    /// authentication, plus whatever `into_tls13_unpadded_message` reports for a malformed
    /// inner plaintext.
    fn decrypt<'a>(
        &mut self,
        mut msg: InboundOpaqueMessage<'a>,
        seq: u64,
    ) -> Result<InboundPlainMessage<'a>, rustls::Error> {
        let payload = &mut msg.payload;
        let payload_len = payload.len();
        // A record shorter than the tag cannot be authentic.
        if payload_len < CHACHA_TAG_LENGTH {
            return Err(rustls::Error::DecryptError);
        }
        let message_length = payload_len - CHACHA_TAG_LENGTH;
        let nonce = Nonce::new(&self.iv, seq);
        // The AAD is the record header, whose length covers ciphertext + tag.
        let auth_data = make_tls13_aad(payload_len);
        let mut tag = [0u8; CHACHA_TAG_LENGTH];
        tag.copy_from_slice(&payload[message_length..]);
        // rustls matches on `DecryptError` specifically (bad_record_mac alert, and skipping
        // records during trial decryption of rejected early data), so a failed open must not
        // be reported as `General`. The SymCrypt detail is intentionally dropped.
        chacha20_poly1305_decrypt_in_place(
            &self.key,
            &nonce.0,
            &auth_data,
            &mut payload[..message_length],
            &tag,
        )
        .map_err(|_| rustls::Error::DecryptError)?;
        payload.truncate(message_length);
        msg.into_tls13_unpadded_message()
    }
}
/// TLS 1.3 AES-GCM AEAD algorithm; hands out per-direction [`Tls13GcmState`]
/// encrypters/decrypters.
pub struct Tls13Gcm {
    // AES-GCM variant; its `key_size()` is reported as this algorithm's key length
    // (presumably distinguishing AES-128 from AES-256 — confirm in `cipher_suites`).
    pub(crate) algo_type: AesGcm,
}
/// Per-direction AES-GCM record protection state.
pub struct Tls13GcmState {
    // SymCrypt AES-GCM key, expanded once when the encrypter/decrypter is built.
    key: GcmExpandedKey,
    // Static IV; combined with the record sequence number via `Nonce::new` per record.
    iv: Iv,
}
impl Tls13AeadAlgorithm for Tls13Gcm {
fn encrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageEncrypter> {
Box::new(Tls13GcmState {
key: GcmExpandedKey::new(key.as_ref(), BlockCipherType::AesBlock).unwrap(),
iv,
})
}
fn decrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageDecrypter> {
Box::new(Tls13GcmState {
key: GcmExpandedKey::new(key.as_ref(), BlockCipherType::AesBlock).unwrap(),
iv,
})
}
fn key_len(&self) -> usize {
self.algo_type.key_size()
}
fn extract_keys(
&self,
key: AeadKey,
iv: Iv,
) -> Result<ConnectionTrafficSecrets, UnsupportedOperationError> {
match self.key_len() {
16 => Ok(ConnectionTrafficSecrets::Aes128Gcm { key, iv }),
32 => Ok(ConnectionTrafficSecrets::Aes256Gcm { key, iv }),
_ => Err(UnsupportedOperationError),
}
}
}
impl MessageEncrypter for Tls13GcmState {
    /// Seals `msg` into a TLS 1.3 AES-GCM record for sequence number `seq`.
    ///
    /// The plaintext is followed by its one-byte content type, encrypted in place, and the
    /// GCM tag is appended. The outer record is labelled `ApplicationData` / `TLSv1_2`.
    fn encrypt(
        &mut self,
        msg: OutboundPlainMessage,
        seq: u64,
    ) -> Result<OutboundOpaqueMessage, rustls::Error> {
        let nonce = Nonce::new(&self.iv, seq);
        let record_len = self.encrypted_payload_len(msg.payload.len());
        // The AAD is the record header, whose length covers ciphertext + tag.
        let auth_data = make_tls13_aad(record_len);
        // Bytes that get encrypted: the plaintext plus the trailing content-type byte.
        let inner_len = msg.payload.len() + 1;

        let mut payload = PrefixedPayload::with_capacity(record_len);
        payload.extend_from_chunks(&msg.payload);
        payload.extend_from_slice(&msg.typ.to_array());

        let mut tag = [0u8; GCM_TAG_LENGTH];
        let sealed = &mut payload.as_mut()[..inner_len];
        self.key
            .encrypt_in_place(&nonce.0, &auth_data, sealed, &mut tag);
        payload.extend_from_slice(&tag);

        Ok(OutboundOpaqueMessage::new(
            rustls::ContentType::ApplicationData,
            rustls::ProtocolVersion::TLSv1_2,
            payload,
        ))
    }

    /// Ciphertext length for `payload_len` plaintext bytes: adds the content-type byte and tag.
    fn encrypted_payload_len(&self, payload_len: usize) -> usize {
        payload_len + 1 + GCM_TAG_LENGTH
    }
}
impl MessageDecrypter for Tls13GcmState {
fn decrypt<'a>(
&mut self,
mut msg: InboundOpaqueMessage<'a>,
seq: u64,
) -> Result<InboundPlainMessage<'a>, rustls::Error> {
let payload = &mut msg.payload; let payload_len = payload.len(); if payload_len < GCM_TAG_LENGTH {
return Err(rustls::Error::DecryptError);
}
let message_length = payload_len - GCM_TAG_LENGTH;
let nonce = Nonce::new(&self.iv, seq);
let auth_data = make_tls13_aad(payload_len); let mut tag = [0u8; GCM_TAG_LENGTH];
tag.copy_from_slice(&payload[message_length..]);
match self
.key
.decrypt_in_place(&nonce.0, &auth_data, &mut payload[..message_length], &tag)
{
Ok(()) => {
payload.truncate(message_length);
msg.into_tls13_unpadded_message() }
Err(symcrypt_error) => {
let custom_error_message = format!(
"SymCryptError: {}",
symcrypt_error );
Err(rustls::Error::General(custom_error_message))
}
}
}
}