use alloc::boxed::Box;
use alloc::string::ToString;
use core::fmt;
use zeroize::Zeroize;
use crate::enums::{ContentType, ProtocolVersion};
use crate::error::Error;
use crate::msgs::codec;
pub use crate::msgs::message::{
BorrowedPayload, InboundOpaqueMessage, InboundPlainMessage, OutboundChunks,
OutboundOpaqueMessage, OutboundPlainMessage, PlainMessage, PrefixedPayload,
};
use crate::suites::ConnectionTrafficSecrets;
/// Factory for the record-protection AEAD of a TLS 1.3 cipher suite.
///
/// Implementations turn traffic key material (`AeadKey` + `Iv`) into boxed
/// encrypter/decrypter objects, and may optionally export that material.
pub trait Tls13AeadAlgorithm: Send + Sync {
    /// Builds a boxed [`MessageEncrypter`] from `key` and `iv`.
    fn encrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageEncrypter>;

    /// Builds a boxed [`MessageDecrypter`] from `key` and `iv`.
    fn decrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageDecrypter>;

    /// Key length this algorithm requires (presumably in bytes, i.e. the
    /// length of `AeadKey::as_ref()` it expects — confirm with implementors).
    fn key_len(&self) -> usize;

    /// Converts `key` and `iv` into [`ConnectionTrafficSecrets`].
    ///
    /// # Errors
    ///
    /// Returns [`UnsupportedOperationError`] when the implementation does not
    /// support exporting its key material.
    fn extract_keys(
        &self,
        key: AeadKey,
        iv: Iv,
    ) -> Result<ConnectionTrafficSecrets, UnsupportedOperationError>;

    /// Reports whether this implementation claims FIPS compliance.
    ///
    /// The provided default answers `false`; implementations opt in by overriding it.
    fn fips(&self) -> bool {
        false
    }
}
/// Factory for the record-protection AEAD of a TLS 1.2 cipher suite.
///
/// The caller slices the TLS 1.2 key block according to [`Self::key_block_shape`]
/// and hands the pieces to the constructors below.
pub trait Tls12AeadAlgorithm: Send + Sync + 'static {
    /// Builds a boxed [`MessageEncrypter`].
    ///
    /// `iv` is the fixed IV slice; `extra` is presumably the explicit-nonce
    /// material described by `KeyBlockShape::explicit_nonce_len` — confirm with callers.
    fn encrypter(&self, key: AeadKey, iv: &[u8], extra: &[u8]) -> Box<dyn MessageEncrypter>;

    /// Builds a boxed [`MessageDecrypter`] from `key` and the fixed `iv` slice.
    fn decrypter(&self, key: AeadKey, iv: &[u8]) -> Box<dyn MessageDecrypter>;

    /// Describes how the key block is split up for this algorithm.
    fn key_block_shape(&self) -> KeyBlockShape;

    /// Converts `key`, `iv` and the `explicit` nonce bytes into
    /// [`ConnectionTrafficSecrets`].
    ///
    /// # Errors
    ///
    /// Returns [`UnsupportedOperationError`] when the implementation does not
    /// support exporting its key material.
    fn extract_keys(
        &self,
        key: AeadKey,
        iv: &[u8],
        explicit: &[u8],
    ) -> Result<ConnectionTrafficSecrets, UnsupportedOperationError>;

    /// Reports whether this implementation claims FIPS compliance.
    ///
    /// The provided default answers `false`; implementations opt in by overriding it.
    fn fips(&self) -> bool {
        false
    }
}
/// Error returned when an implementation does not support a requested operation,
/// for example exporting keys via `extract_keys`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct UnsupportedOperationError;

impl From<UnsupportedOperationError> for Error {
    /// Wraps the error's `Display` text in [`Error::General`].
    fn from(value: UnsupportedOperationError) -> Self {
        let message = value.to_string();
        Self::General(message)
    }
}

impl fmt::Display for UnsupportedOperationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("operation not supported")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for UnsupportedOperationError {}
/// Describes how a TLS 1.2 key block is partitioned for one AEAD algorithm.
///
/// This value is returned by `Tls12AeadAlgorithm::key_block_shape`. It holds only
/// lengths, so it is cheap to copy, compare and print.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct KeyBlockShape {
    /// Length of the encryption key taken from the key block.
    pub enc_key_len: usize,
    /// Length of the fixed (implicit) IV portion taken from the key block.
    pub fixed_iv_len: usize,
    /// Length of the explicit nonce portion (may be zero for algorithms without one).
    pub explicit_nonce_len: usize,
}
/// Removes record protection from inbound TLS records.
pub trait MessageDecrypter: Send + Sync {
    /// Decrypts `msg`, the record numbered `seq`.
    ///
    /// The returned plaintext shares the lifetime `'a` of the input message, so it
    /// borrows from the same buffer rather than allocating.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when the record cannot be decrypted.
    fn decrypt<'a>(
        &mut self,
        msg: InboundOpaqueMessage<'a>,
        seq: u64,
    ) -> Result<InboundPlainMessage<'a>, Error>;
}
/// Applies record protection to outbound TLS records.
pub trait MessageEncrypter: Send + Sync {
    /// Encrypts `msg` as the record numbered `seq`, producing an owned opaque message.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when the record cannot be encrypted.
    fn encrypt(
        &mut self,
        msg: OutboundPlainMessage<'_>,
        seq: u64,
    ) -> Result<OutboundOpaqueMessage, Error>;

    /// Length of the protected payload produced for a plaintext payload of
    /// `payload_len` bytes.
    fn encrypted_payload_len(&self, payload_len: usize) -> usize;
}
impl dyn MessageEncrypter {
    /// Builds a placeholder encrypter that rejects every message.
    ///
    /// Its `encrypt` always returns `Err(Error::EncryptError)`.
    pub(crate) fn invalid() -> Box<dyn MessageEncrypter> {
        let placeholder = InvalidMessageEncrypter {};
        Box::new(placeholder)
    }
}
impl dyn MessageDecrypter {
    /// Builds a placeholder decrypter that rejects every message.
    ///
    /// Its `decrypt` always returns `Err(Error::DecryptError)`.
    pub(crate) fn invalid() -> Box<dyn MessageDecrypter> {
        let placeholder = InvalidMessageDecrypter {};
        Box::new(placeholder)
    }
}
/// Initialization vector for record-layer AEADs, exactly `NONCE_LEN` bytes long.
///
/// `Iv::default()` is the all-zero IV.
#[derive(Default)]
pub struct Iv([u8; NONCE_LEN]);

impl Iv {
    /// Wraps an existing `NONCE_LEN`-byte array.
    #[cfg(feature = "tls12")]
    pub fn new(value: [u8; NONCE_LEN]) -> Self {
        Self(value)
    }

    /// Copies `value` into a new `Iv`.
    ///
    /// # Panics
    ///
    /// Panics if `value.len() != NONCE_LEN`. Debug builds report this through
    /// `debug_assert_eq!`; release builds panic in `copy_from_slice`.
    #[cfg(feature = "tls12")]
    pub fn copy(value: &[u8]) -> Self {
        debug_assert_eq!(value.len(), NONCE_LEN);
        let mut bytes = [0u8; NONCE_LEN];
        bytes.copy_from_slice(value);
        Self::new(bytes)
    }
}

impl From<[u8; NONCE_LEN]> for Iv {
    fn from(bytes: [u8; NONCE_LEN]) -> Self {
        Iv(bytes)
    }
}

impl AsRef<[u8]> for Iv {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}
/// A per-record AEAD nonce of `NONCE_LEN` bytes.
pub struct Nonce(pub [u8; NONCE_LEN]);

impl Nonce {
    /// Derives the nonce for record `seq`.
    ///
    /// `seq` is written with `codec::put_u64` into bytes `4..12` of a zeroed
    /// `NONCE_LEN`-byte buffer (bytes `0..4` stay zero). The buffer is then XORed
    /// with `iv`.
    #[inline]
    pub fn new(iv: &Iv, seq: u64) -> Self {
        let mut padded = [0u8; NONCE_LEN];
        codec::put_u64(seq, &mut padded[4..]);
        Self::new_from_seq(iv, padded)
    }

    /// Like [`Nonce::new`], but also stores `path_id` big-endian in bytes `0..4`
    /// before XORing with `iv`. `pn` goes in bytes `4..12`.
    ///
    /// Presumably intended for per-path packet protection (e.g. multipath QUIC); confirm with callers.
    pub fn for_path(path_id: u32, iv: &Iv, pn: u64) -> Self {
        let mut padded = [0u8; NONCE_LEN];
        let (path_bytes, pn_bytes) = padded.split_at_mut(4);
        path_bytes.copy_from_slice(&path_id.to_be_bytes());
        codec::put_u64(pn, pn_bytes);
        Self::new_from_seq(iv, padded)
    }

    /// XORs the prepared `seq` buffer with the IV, byte by byte.
    #[inline]
    fn new_from_seq(iv: &Iv, mut seq: [u8; NONCE_LEN]) -> Self {
        for (byte, mask) in seq.iter_mut().zip(iv.0.iter()) {
            *byte ^= *mask;
        }
        Self(seq)
    }
}

/// Length in bytes of record-layer AEAD nonces and IVs.
pub const NONCE_LEN: usize = 12;
/// Builds the 5-byte additional data for TLS 1.3 record protection.
///
/// The layout is: content type `ApplicationData`, the two bytes of version
/// `TLSv1_2`, and then the two low-order bytes of `payload_len`, high byte first.
/// Bits of `payload_len` above the low 16 are discarded.
#[inline]
pub fn make_tls13_aad(payload_len: usize) -> [u8; 5] {
    let [major, minor] = ProtocolVersion::TLSv1_2.to_array();
    let len_hi = (payload_len >> 8) as u8;
    let len_lo = (payload_len & 0xff) as u8;
    [
        ContentType::ApplicationData.into(),
        major,
        minor,
        len_hi,
        len_lo,
    ]
}
/// Builds the additional data for TLS 1.2 AEAD record protection.
///
/// The layout is `seq` (8 bytes) || `typ` (1 byte) || `vers` (2 bytes) ||
/// `len` (2 bytes). Every field is encoded with the `codec` helpers.
/// `len` is narrowed with `as u16`, so values above `u16::MAX` wrap silently.
/// Callers are expected to pass record-sized lengths.
#[inline]
pub fn make_tls12_aad(
    seq: u64,
    typ: ContentType,
    vers: ProtocolVersion,
    len: usize,
) -> [u8; TLS12_AAD_SIZE] {
    let mut aad = [0u8; TLS12_AAD_SIZE];
    let (seq_field, rest) = aad.split_at_mut(8);
    codec::put_u64(seq, seq_field);
    let (type_field, rest) = rest.split_at_mut(1);
    type_field[0] = typ.into();
    let (version_field, length_field) = rest.split_at_mut(2);
    codec::put_u16(vers.into(), version_field);
    codec::put_u16(len as u16, length_field);
    aad
}

/// Size of the TLS 1.2 AEAD additional data: sequence number (8) + type (1) +
/// version (2) + length (2).
const TLS12_AAD_SIZE: usize = 8 + 1 + 2 + 2;
/// Symmetric AEAD key material stored inline in a fixed `MAX_LEN`-byte buffer.
///
/// Only the first `used` bytes are meaningful; `as_ref()` exposes exactly that prefix.
pub struct AeadKey {
    buf: [u8; Self::MAX_LEN],
    used: usize,
}

impl AeadKey {
    /// Largest key, in bytes, that an `AeadKey` can hold.
    pub(crate) const MAX_LEN: usize = 32;

    /// Copies `buf` into a zero-padded key whose used length is `buf.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `buf.len() > MAX_LEN`. Debug builds report this through
    /// `debug_assert!`; release builds panic on the slice index.
    #[cfg(feature = "tls12")]
    pub(crate) fn new(buf: &[u8]) -> Self {
        debug_assert!(buf.len() <= Self::MAX_LEN);
        let len = buf.len();
        let mut key = Self {
            buf: [0u8; Self::MAX_LEN],
            used: len,
        };
        key.buf[..len].copy_from_slice(buf);
        key
    }

    /// Shrinks the key to its first `len` bytes. The buffer contents are unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `len` exceeds the current used length.
    pub(crate) fn with_length(self, len: usize) -> Self {
        assert!(len <= self.used);
        // `buf` is `Copy`, so this copies it out; `self` is then dropped normally.
        let buf = self.buf;
        Self { buf, used: len }
    }
}
impl Drop for AeadKey {
fn drop(&mut self) {
self.buf.zeroize();
}
}
impl AsRef<[u8]> for AeadKey {
    /// Returns only the meaningful prefix of the key buffer.
    fn as_ref(&self) -> &[u8] {
        let (live, _padding) = self.buf.split_at(self.used);
        live
    }
}
impl From<[u8; Self::MAX_LEN]> for AeadKey {
    /// Takes a full-size key; every byte of `bytes` counts as used.
    fn from(bytes: [u8; Self::MAX_LEN]) -> Self {
        Self {
            used: bytes.len(),
            buf: bytes,
        }
    }
}
/// Encrypter placeholder that refuses to encrypt anything.
struct InvalidMessageEncrypter {}

impl MessageEncrypter for InvalidMessageEncrypter {
    /// Reports the payload length unchanged, because this encrypter adds no overhead.
    fn encrypted_payload_len(&self, payload_len: usize) -> usize {
        payload_len
    }

    /// Always fails with `Error::EncryptError`.
    fn encrypt(
        &mut self,
        _m: OutboundPlainMessage<'_>,
        _seq: u64,
    ) -> Result<OutboundOpaqueMessage, Error> {
        Err(Error::EncryptError)
    }
}
/// Decrypter placeholder that refuses to decrypt anything.
struct InvalidMessageDecrypter {}

impl MessageDecrypter for InvalidMessageDecrypter {
    /// Always fails with `Error::DecryptError`, whatever the input.
    fn decrypt<'a>(
        &mut self,
        _m: InboundOpaqueMessage<'a>,
        _seq: u64,
    ) -> Result<InboundPlainMessage<'a>, Error> {
        Err(Error::DecryptError)
    }
}