use crate::cipher_suites::AesGcm;
use rustls::crypto::cipher::{
make_tls12_aad, AeadKey, InboundOpaqueMessage, InboundPlainMessage, Iv, KeyBlockShape,
MessageDecrypter, MessageEncrypter, Nonce, OutboundOpaqueMessage, OutboundPlainMessage,
PrefixedPayload, Tls12AeadAlgorithm, UnsupportedOperationError,
};
use rustls::{ConnectionTrafficSecrets, Error};
use symcrypt::cipher::BlockCipherType;
use symcrypt::gcm::GcmExpandedKey;
/// Length in bytes of the AES-GCM authentication tag appended to every record.
const GCM_TAG_LENGTH: usize = 16;
/// Length in bytes of the implicit nonce part (the fixed IV taken from the key block).
const GCM_IMPLICIT_NONCE_LENGTH: usize = 4;
/// Length in bytes of the explicit nonce part carried at the front of every record.
const GCM_EXPLICIT_NONCE_LENGTH: usize = 8;
/// Full AES-GCM nonce length: the implicit part followed by the explicit part (12 bytes).
const GCM_FULL_NONCE_LENGTH: usize = GCM_IMPLICIT_NONCE_LENGTH + GCM_EXPLICIT_NONCE_LENGTH;
#[cfg(feature = "chacha")]
use symcrypt::chacha::{chacha20_poly1305_decrypt_in_place, chacha20_poly1305_encrypt_in_place};
/// ChaCha20-Poly1305 key length in bytes (a 256-bit key).
#[cfg(feature = "chacha")]
const CHACHA_KEY_LENGTH: usize = 256 / 8;
/// ChaCha20-Poly1305 nonce length in bytes; also used as the TLS 1.2 fixed IV length.
#[cfg(feature = "chacha")]
const CHACHA_NONCE_LENGTH: usize = 12;
/// Length in bytes of the Poly1305 tag appended to every record.
#[cfg(feature = "chacha")]
const CHACHA_TAG_LENGTH: usize = 16;
#[cfg(feature = "chacha")]
/// TLS 1.2 ChaCha20-Poly1305 AEAD algorithm.
///
/// This is a stateless unit type; its `Tls12AeadAlgorithm` implementation builds a
/// separate `Tls12ChaCha20Poly1305` record protector for each `encrypter`/`decrypter` call.
pub struct Tls12ChaCha;
/// Per-direction TLS 1.2 ChaCha20-Poly1305 record protector.
///
/// Holds its own copy of the traffic key, which is wiped when the value is dropped
/// (see the `Drop` impl below).
#[cfg(feature = "chacha")]
pub struct Tls12ChaCha20Poly1305 {
    /// Raw 32-byte ChaCha20-Poly1305 key copied out of the `AeadKey` given to the
    /// `Tls12ChaCha` encrypter/decrypter constructors.
    key: [u8; CHACHA_KEY_LENGTH],
    /// 12-byte fixed IV; combined with the record sequence number via `Nonce::new`
    /// to form each record's nonce.
    iv: Iv,
}

#[cfg(feature = "chacha")]
impl Drop for Tls12ChaCha20Poly1305 {
    /// Overwrites the copied key bytes with zeros so the secret does not linger in
    /// freed memory once the record protector is discarded.
    fn drop(&mut self) {
        // SAFETY: `&mut self.key` is a valid, properly aligned, exclusively borrowed
        // pointer to an initialized `[u8; CHACHA_KEY_LENGTH]`, so writing through it is sound.
        // A volatile write is used so the optimizer cannot elide this "dead" store.
        unsafe { std::ptr::write_volatile(&mut self.key, [0u8; CHACHA_KEY_LENGTH]) };
        // Prevent the compiler from reordering the wipe past the deallocation.
        std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
    }
}
#[cfg(feature = "chacha")]
impl Tls12AeadAlgorithm for Tls12ChaCha {
    /// Builds the record encrypter.
    ///
    /// TLS 1.2 ChaCha20-Poly1305 carries no explicit nonce, so the third argument is ignored.
    fn encrypter(&self, key: AeadKey, iv: &[u8], _: &[u8]) -> Box<dyn MessageEncrypter> {
        Box::new(tls12_chacha_state(&key, iv))
    }

    /// Builds the record decrypter from the same key/IV material as `encrypter`.
    fn decrypter(&self, key: AeadKey, iv: &[u8]) -> Box<dyn MessageDecrypter> {
        Box::new(tls12_chacha_state(&key, iv))
    }

    /// Key block layout: a 32-byte key, a 12-byte fixed IV, and no explicit nonce.
    fn key_block_shape(&self) -> KeyBlockShape {
        KeyBlockShape {
            enc_key_len: CHACHA_KEY_LENGTH,
            fixed_iv_len: CHACHA_NONCE_LENGTH,
            explicit_nonce_len: 0,
        }
    }

    /// Exports the traffic key together with its 12-byte IV.
    ///
    /// # Panics
    ///
    /// Panics if `iv` is not exactly `CHACHA_NONCE_LENGTH` bytes long.
    fn extract_keys(
        &self,
        key: AeadKey,
        iv: &[u8],
        _explicit: &[u8],
    ) -> Result<ConnectionTrafficSecrets, UnsupportedOperationError> {
        assert_eq!(CHACHA_NONCE_LENGTH, iv.len());
        let iv = Iv::copy(iv);
        Ok(ConnectionTrafficSecrets::Chacha20Poly1305 { key, iv })
    }
}

/// Copies `key` and `iv` into a fresh ChaCha20-Poly1305 record protector.
///
/// # Panics
///
/// Panics if `key` is not exactly `CHACHA_KEY_LENGTH` bytes long.
#[cfg(feature = "chacha")]
fn tls12_chacha_state(key: &AeadKey, iv: &[u8]) -> Tls12ChaCha20Poly1305 {
    let key_bytes = key.as_ref();
    assert_eq!(key_bytes.len(), CHACHA_KEY_LENGTH);
    let mut owned_key = [0u8; CHACHA_KEY_LENGTH];
    owned_key.copy_from_slice(key_bytes);
    Tls12ChaCha20Poly1305 {
        key: owned_key,
        iv: Iv::copy(iv),
    }
}
#[cfg(feature = "chacha")]
impl MessageEncrypter for Tls12ChaCha20Poly1305 {
    /// Encrypts one TLS 1.2 ChaCha20-Poly1305 record.
    ///
    /// Output payload layout: `ciphertext || tag (16 bytes)`; no explicit nonce is sent.
    /// The nonce comes from `Nonce::new(&self.iv, seq)`.
    ///
    /// # Errors
    ///
    /// Returns `Error::General` carrying the SymCrypt error text if encryption fails.
    fn encrypt(
        &mut self,
        msg: OutboundPlainMessage,
        seq: u64,
    ) -> Result<OutboundOpaqueMessage, Error> {
        let plaintext_len = msg.payload.len();
        let nonce = Nonce::new(&self.iv, seq);
        let aad = make_tls12_aad(seq, msg.typ, msg.version, plaintext_len);

        let mut out = PrefixedPayload::with_capacity(self.encrypted_payload_len(plaintext_len));
        out.extend_from_chunks(&msg.payload);

        let mut tag = [0u8; CHACHA_TAG_LENGTH];
        chacha20_poly1305_encrypt_in_place(
            &self.key,
            &nonce.0,
            &aad,
            &mut out.as_mut()[..plaintext_len],
            &mut tag,
        )
        .map_err(|err| Error::General(format!("SymCryptError: {}", err)))?;

        // The tag trails the ciphertext on the wire.
        out.extend_from_slice(&tag);
        Ok(OutboundOpaqueMessage::new(msg.typ, msg.version, out))
    }

    /// Record overhead is only the 16-byte Poly1305 tag.
    fn encrypted_payload_len(&self, payload_len: usize) -> usize {
        payload_len + CHACHA_TAG_LENGTH
    }
}
#[cfg(feature = "chacha")]
impl MessageDecrypter for Tls12ChaCha20Poly1305 {
    /// Decrypts and authenticates one TLS 1.2 ChaCha20-Poly1305 record.
    ///
    /// The record payload is `ciphertext || tag (16 bytes)`; there is no explicit nonce.
    /// The nonce is derived from the stored IV and `seq` via `Nonce::new`. On success the
    /// payload is truncated to the plaintext length.
    ///
    /// # Errors
    ///
    /// Returns `Error::DecryptError` if the payload is shorter than a tag, or if SymCrypt
    /// rejects the record (e.g. on a tag mismatch).
    fn decrypt<'a>(
        &mut self,
        mut msg: InboundOpaqueMessage<'a>,
        seq: u64,
    ) -> Result<InboundPlainMessage<'a>, Error> {
        let payload = &mut msg.payload;
        let payload_len = payload.len();
        if payload_len < CHACHA_TAG_LENGTH {
            return Err(Error::DecryptError);
        }
        let message_len = payload_len - CHACHA_TAG_LENGTH;

        let nonce = Nonce::new(&self.iv, seq);
        let auth_data = make_tls12_aad(seq, msg.typ, msg.version, message_len);
        let mut tag = [0u8; CHACHA_TAG_LENGTH];
        tag.copy_from_slice(&payload[message_len..]);

        // A failure here means the record did not authenticate. Report it as
        // `Error::DecryptError` (the same variant the length check uses) rather than an
        // opaque `Error::General`: rustls answers `DecryptError` with a fatal
        // `bad_record_mac` alert, which is the required response to a forged record.
        chacha20_poly1305_decrypt_in_place(
            &self.key,
            &nonce.0,
            &auth_data,
            &mut payload[..message_len],
            &tag,
        )
        .map_err(|_| Error::DecryptError)?;

        payload.truncate(message_len);
        Ok(msg.into_plain_message())
    }
}
/// TLS 1.2 AES-GCM AEAD algorithm backed by SymCrypt.
pub struct Tls12Gcm {
    // AES variant; its `key_size()` drives `key_block_shape` and selects between
    // `Aes128Gcm` (16) and `Aes256Gcm` (32) in `extract_keys`.
    pub(crate) algo_type: AesGcm,
}
/// Decrypting state for TLS 1.2 AES-GCM records.
pub struct Gcm12Decrypt {
    // AES key expanded by SymCrypt for GCM.
    key: GcmExpandedKey,
    // 4-byte implicit nonce prefix; each record supplies the remaining 8 explicit bytes.
    iv: [u8; GCM_IMPLICIT_NONCE_LENGTH],
}
/// Encrypting state for TLS 1.2 AES-GCM records.
pub struct Gcm12Encrypt {
    // AES key expanded by SymCrypt for GCM.
    key: GcmExpandedKey,
    // 4-byte implicit IV followed by the 8 `extra` bytes passed to `encrypter`;
    // `Nonce::new` combines this with the record sequence number for each record.
    full_iv: [u8; GCM_FULL_NONCE_LENGTH],
}
impl Tls12AeadAlgorithm for Tls12Gcm {
fn encrypter(&self, key: AeadKey, iv: &[u8], extra: &[u8]) -> Box<dyn MessageEncrypter> {
assert_eq!(iv.len(), GCM_IMPLICIT_NONCE_LENGTH);
assert_eq!(extra.len(), GCM_EXPLICIT_NONCE_LENGTH);
let mut full_iv = [0u8; GCM_FULL_NONCE_LENGTH];
full_iv[..GCM_IMPLICIT_NONCE_LENGTH].copy_from_slice(iv);
full_iv[GCM_IMPLICIT_NONCE_LENGTH..].copy_from_slice(extra);
Box::new(Gcm12Encrypt {
key: GcmExpandedKey::new(key.as_ref(), BlockCipherType::AesBlock).unwrap(),
full_iv,
})
}
fn decrypter(&self, key: AeadKey, iv: &[u8]) -> Box<dyn MessageDecrypter> {
assert_eq!(iv.len(), GCM_IMPLICIT_NONCE_LENGTH);
let mut implicit_iv = [0u8; GCM_IMPLICIT_NONCE_LENGTH];
implicit_iv.copy_from_slice(iv);
Box::new(Gcm12Decrypt {
key: GcmExpandedKey::new(key.as_ref(), BlockCipherType::AesBlock).unwrap(),
iv: implicit_iv,
})
}
fn key_block_shape(&self) -> KeyBlockShape {
KeyBlockShape {
enc_key_len: self.algo_type.key_size(), fixed_iv_len: GCM_IMPLICIT_NONCE_LENGTH,
explicit_nonce_len: GCM_EXPLICIT_NONCE_LENGTH,
}
}
fn extract_keys(
&self,
key: AeadKey,
iv: &[u8],
explicit: &[u8],
) -> Result<ConnectionTrafficSecrets, UnsupportedOperationError> {
let mut gcm_iv = [0; GCM_FULL_NONCE_LENGTH];
gcm_iv[..GCM_IMPLICIT_NONCE_LENGTH].copy_from_slice(iv);
gcm_iv[GCM_IMPLICIT_NONCE_LENGTH..].copy_from_slice(explicit);
match self.algo_type.key_size() {
16 => Ok(ConnectionTrafficSecrets::Aes128Gcm {
key,
iv: Iv::new(gcm_iv),
}),
32 => Ok(ConnectionTrafficSecrets::Aes256Gcm {
key,
iv: Iv::new(gcm_iv),
}),
_ => Err(UnsupportedOperationError),
}
}
}
impl MessageEncrypter for Gcm12Encrypt {
    /// Encrypts one TLS 1.2 AES-GCM record.
    ///
    /// Output payload layout: `explicit_nonce (8) || ciphertext || tag (16)`. The per-record
    /// nonce is `Nonce::new` applied to the stored 12-byte IV and `seq`; its trailing 8 bytes
    /// are written to the record as the explicit nonce.
    fn encrypt(
        &mut self,
        msg: OutboundPlainMessage,
        seq: u64,
    ) -> Result<OutboundOpaqueMessage, Error> {
        let plaintext_len = msg.payload.len();
        let nonce = Nonce::new(&Iv::copy(&self.full_iv), seq);
        let aad = make_tls12_aad(seq, msg.typ, msg.version, plaintext_len);

        let mut out = PrefixedPayload::with_capacity(self.encrypted_payload_len(plaintext_len));
        out.extend_from_slice(&nonce.0[GCM_IMPLICIT_NONCE_LENGTH..]);
        out.extend_from_chunks(&msg.payload);

        // The plaintext sits right after the explicit nonce; encrypt that span in place.
        let body = GCM_EXPLICIT_NONCE_LENGTH..GCM_EXPLICIT_NONCE_LENGTH + plaintext_len;
        let mut tag = [0u8; GCM_TAG_LENGTH];
        self.key
            .encrypt_in_place(&nonce.0, &aad, &mut out.as_mut()[body], &mut tag);
        out.extend_from_slice(&tag);

        Ok(OutboundOpaqueMessage::new(msg.typ, msg.version, out))
    }

    /// Record overhead is the 8-byte explicit nonce plus the 16-byte tag.
    fn encrypted_payload_len(&self, payload_len: usize) -> usize {
        payload_len + GCM_EXPLICIT_NONCE_LENGTH + GCM_TAG_LENGTH
    }
}
impl MessageDecrypter for Gcm12Decrypt {
fn decrypt<'a>(
&mut self,
mut msg: InboundOpaqueMessage<'a>,
seq: u64,
) -> Result<InboundPlainMessage<'a>, Error> {
let payload = &mut msg.payload; let payload_len = payload.len(); if payload_len < GCM_TAG_LENGTH + GCM_EXPLICIT_NONCE_LENGTH {
return Err(Error::DecryptError);
}
let mut nonce = [0u8; GCM_FULL_NONCE_LENGTH];
nonce[..GCM_IMPLICIT_NONCE_LENGTH].copy_from_slice(&self.iv);
nonce[GCM_IMPLICIT_NONCE_LENGTH..].copy_from_slice(&payload[..GCM_EXPLICIT_NONCE_LENGTH]);
let mut tag = [0u8; GCM_TAG_LENGTH];
tag.copy_from_slice(&payload[payload_len - GCM_TAG_LENGTH..]);
let auth_data = make_tls12_aad(
seq,
msg.typ,
msg.version,
payload_len - GCM_TAG_LENGTH - GCM_EXPLICIT_NONCE_LENGTH,
);
match self.key.decrypt_in_place(
&nonce,
&auth_data,
&mut payload[GCM_EXPLICIT_NONCE_LENGTH..payload_len - GCM_TAG_LENGTH],
&tag,
) {
Ok(()) => {
payload.copy_within(GCM_EXPLICIT_NONCE_LENGTH..(payload_len - GCM_TAG_LENGTH), 0);
payload.truncate(payload_len - (GCM_EXPLICIT_NONCE_LENGTH + GCM_TAG_LENGTH));
Ok(msg.into_plain_message())
}
Err(symcrypt_error) => {
let custom_error_message = format!(
"SymCryptError: {}",
symcrypt_error );
Err(Error::General(custom_error_message))
}
}
}
}