#[cfg(any(feature = "receive", test))]
use super::tasks::error::Error as InternalError;
use aead::AeadCore;
use aes_gcm::{AeadInPlace, Aes256Gcm, Error as CryptoError};
use byteorder::{NetworkEndian, WriteBytesExt};
use chacha20poly1305::XChaCha20Poly1305;
use crypto_common::{InvalidLength, KeyInit};
#[cfg(feature = "receive")]
use discortp::rtcp::MutableRtcpPacket;
use discortp::MutablePacket;
#[cfg(any(feature = "receive", test))]
use discortp::{
rtp::{MutableRtpPacket, RtpExtensionPacket},
Packet,
};
use std::{num::Wrapping, str::FromStr};
use typenum::Unsigned;
use crate::error::ConnectionError;
/// Packet encryption modes supported by this module.
///
/// The default mode is [`CryptoMode::Aes256Gcm`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum CryptoMode {
    /// AES-256 in Galois/Counter Mode (`aes_gcm::Aes256Gcm`).
    #[default]
    Aes256Gcm,
    /// XChaCha20 with a Poly1305 authenticator (`chacha20poly1305::XChaCha20Poly1305`).
    XChaCha20Poly1305,
}
impl From<CryptoState> for CryptoMode {
fn from(val: CryptoState) -> Self {
match val {
CryptoState::Aes256Gcm(_) => Self::Aes256Gcm,
CryptoState::XChaCha20Poly1305(_) => Self::XChaCha20Poly1305,
}
}
}
/// Error returned when a string does not name a supported [`CryptoMode`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct UnrecognisedCryptoMode;

// A public `FromStr::Err` type should be displayable and usable as a
// `std::error::Error` so callers can propagate it with `?`.
impl std::fmt::Display for UnrecognisedCryptoMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("unrecognised encryption mode")
    }
}

impl std::error::Error for UnrecognisedCryptoMode {}
impl FromStr for CryptoMode {
    type Err = UnrecognisedCryptoMode;

    /// Parses the string identifier of an encryption mode.
    ///
    /// Matching is exact and case-sensitive; any other input yields
    /// [`UnrecognisedCryptoMode`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mode = match s {
            "aead_aes256_gcm_rtpsize" => Self::Aes256Gcm,
            "aead_xchacha20_poly1305_rtpsize" => Self::XChaCha20Poly1305,
            _ => return Err(UnrecognisedCryptoMode),
        };
        Ok(mode)
    }
}
impl CryptoMode {
    /// Returns the AEAD algorithm that backs this mode.
    #[must_use]
    pub(crate) const fn algorithm(self) -> EncryptionAlgorithm {
        match self {
            CryptoMode::Aes256Gcm => EncryptionAlgorithm::Aes256Gcm,
            CryptoMode::XChaCha20Poly1305 => EncryptionAlgorithm::XChaCha20Poly1305,
        }
    }

    /// Constructs a [`Cipher`] for this mode from raw key bytes.
    ///
    /// # Errors
    /// Returns [`InvalidLength`] if `key` has the wrong length for the
    /// underlying algorithm.
    pub(crate) fn cipher_from_key(self, key: &[u8]) -> Result<Cipher, InvalidLength> {
        match self.algorithm() {
            EncryptionAlgorithm::Aes256Gcm => Aes256Gcm::new_from_slice(key)
                .map(Box::new)
                .map(Cipher::Aes256Gcm),
            EncryptionAlgorithm::XChaCha20Poly1305 =>
                XChaCha20Poly1305::new_from_slice(key).map(Cipher::XChaCha20Poly1305),
        }
    }

    /// Built-in ranking used by [`Self::negotiate`]; higher values are preferred.
    #[must_use]
    pub(crate) const fn priority(self) -> u64 {
        match self {
            CryptoMode::Aes256Gcm => 1,
            CryptoMode::XChaCha20Poly1305 => 0,
        }
    }

    /// Chooses the best supported mode from a list of mode names.
    ///
    /// Names that do not parse as a [`CryptoMode`] are skipped. If `preferred`
    /// appears in `modes` it always wins. Otherwise the mode with the highest
    /// [`Self::priority`] is chosen, and the earliest entry wins a tie.
    ///
    /// # Errors
    /// Returns [`ConnectionError::CryptoModeUnavailable`] if no entry of
    /// `modes` is recognised.
    pub(crate) fn negotiate<It, T>(
        modes: It,
        preferred: Option<Self>,
    ) -> Result<Self, ConnectionError>
    where
        T: AsRef<str>,
        It: IntoIterator<Item = T>,
    {
        modes
            .into_iter()
            .filter_map(|name| CryptoMode::from_str(name.as_ref()).ok())
            .map(|mode| {
                // A caller preference outranks every built-in priority.
                let score = if preferred == Some(mode) {
                    u64::MAX
                } else {
                    mode.priority()
                };
                (mode, score)
            })
            .fold(None, |best: Option<(Self, u64)>, candidate| match best {
                // Only a strictly higher score replaces the incumbent, so the
                // first-seen entry wins ties.
                Some((_, best_score)) if best_score >= candidate.1 => best,
                _ => Some(candidate),
            })
            .map(|(mode, _)| mode)
            .ok_or(ConnectionError::CryptoModeUnavailable)
    }

    /// Returns the string identifier of this mode (the inverse of its `FromStr` impl).
    #[must_use]
    pub const fn to_request_str(self) -> &'static str {
        match self {
            Self::Aes256Gcm => "aead_aes256_gcm_rtpsize",
            Self::XChaCha20Poly1305 => "aead_xchacha20_poly1305_rtpsize",
        }
    }

    /// Full nonce size, in bytes, required by the underlying AEAD algorithm.
    #[must_use]
    pub const fn algorithm_nonce_size(self) -> usize {
        // `typenum::Unsigned` is already in scope from the module-level import.
        match self {
            Self::XChaCha20Poly1305 => <XChaCha20Poly1305 as AeadCore>::NonceSize::USIZE,
            Self::Aes256Gcm => <Aes256Gcm as AeadCore>::NonceSize::USIZE,
        }
    }

    /// Number of nonce bytes carried at the end of each encrypted payload.
    ///
    /// This is fewer bytes than [`Self::algorithm_nonce_size`]. The remaining
    /// bytes of the algorithm nonce are left at their default (zero) when the
    /// full nonce is built.
    #[must_use]
    pub const fn nonce_size(self) -> usize {
        match self {
            Self::Aes256Gcm | Self::XChaCha20Poly1305 => 4,
        }
    }

    /// Bytes reserved before the encrypted body within the payload (zero for every mode).
    #[must_use]
    pub(crate) const fn payload_prefix_len(self) -> usize {
        match self {
            CryptoMode::Aes256Gcm | CryptoMode::XChaCha20Poly1305 => 0,
        }
    }

    /// Authentication tag length of this mode's algorithm, in bytes.
    #[must_use]
    pub(crate) const fn encryption_tag_len(self) -> usize {
        self.algorithm().encryption_tag_len()
    }

    /// Bytes that follow the encrypted body: the tag, then the nonce.
    #[must_use]
    pub const fn payload_suffix_len(self) -> usize {
        self.nonce_size() + self.encryption_tag_len()
    }

    /// Length of the tag portion of the payload suffix.
    #[must_use]
    pub const fn tag_suffix_len(self) -> usize {
        self.encryption_tag_len()
    }

    /// Total bytes added around a payload by encryption (prefix plus suffix).
    #[must_use]
    pub const fn payload_overhead(self) -> usize {
        self.payload_prefix_len() + self.payload_suffix_len()
    }

    /// Splits `body` into `(nonce, remainder)`, where `nonce` is the trailing
    /// [`Self::nonce_size`] bytes.
    ///
    /// `_header` is currently unused by every mode.
    ///
    /// # Errors
    /// Returns [`CryptoError`] if `body` is shorter than
    /// [`Self::payload_suffix_len`], i.e. too short to hold the tag and nonce.
    fn nonce_slice<'a>(
        self,
        _header: &'a [u8],
        body: &'a mut [u8],
    ) -> Result<(&'a [u8], &'a mut [u8]), CryptoError> {
        match self {
            Self::Aes256Gcm | Self::XChaCha20Poly1305 => {
                let len = body.len();
                if len < self.payload_suffix_len() {
                    Err(CryptoError)
                } else {
                    let (body_left, nonce_loc) = body.split_at_mut(len - self.nonce_size());
                    Ok((nonce_loc, body_left))
                }
            },
        }
    }
}
/// An encryption mode paired with the wrapping `u32` counter that
/// [`CryptoState::write_packet_nonce`] writes into outgoing packets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum CryptoState {
    /// AES-256-GCM, with its nonce counter.
    Aes256Gcm(Wrapping<u32>),
    /// XChaCha20-Poly1305, with its nonce counter.
    XChaCha20Poly1305(Wrapping<u32>),
}
impl From<CryptoMode> for CryptoState {
    /// Creates fresh state for `mode`, seeding the nonce counter with a random `u32`.
    fn from(mode: CryptoMode) -> Self {
        let counter = Wrapping(rand::random::<u32>());
        match mode {
            CryptoMode::Aes256Gcm => Self::Aes256Gcm(counter),
            CryptoMode::XChaCha20Poly1305 => Self::XChaCha20Poly1305(counter),
        }
    }
}
impl CryptoState {
    /// Writes the current nonce counter into `packet` and advances it by one.
    ///
    /// `payload_end` is the length of the payload proper. The counter is
    /// written big-endian into the last [`CryptoMode::nonce_size`] bytes of the
    /// [`CryptoMode::payload_suffix_len`]-byte suffix that follows it. The
    /// counter wraps on overflow.
    ///
    /// Returns the payload length including that suffix.
    ///
    /// # Panics
    /// Panics if the packet payload is shorter than
    /// `payload_end + payload_suffix_len()`.
    pub fn write_packet_nonce(
        &mut self,
        packet: &mut impl MutablePacket,
        payload_end: usize,
    ) -> usize {
        let mode = self.kind();
        let total_len = payload_end + mode.payload_suffix_len();
        let nonce_range = (total_len - mode.nonce_size())..total_len;

        let (Self::Aes256Gcm(counter) | Self::XChaCha20Poly1305(counter)) = self;
        let mut nonce_bytes = &mut packet.payload_mut()[nonce_range];
        nonce_bytes
            .write_u32::<NetworkEndian>(counter.0)
            .expect("Nonce size is guaranteed to be sufficient to write u32 for lite tagging.");
        *counter += Wrapping(1);

        total_len
    }

    /// Returns the [`CryptoMode`] this state belongs to.
    #[must_use]
    pub fn kind(self) -> CryptoMode {
        self.into()
    }
}
/// The AEAD primitive underlying a [`CryptoMode`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) enum EncryptionAlgorithm {
    /// AES-256-GCM.
    Aes256Gcm,
    /// XChaCha20-Poly1305.
    XChaCha20Poly1305,
}
impl EncryptionAlgorithm {
#[must_use]
pub(crate) const fn encryption_tag_len(self) -> usize {
match self {
Self::Aes256Gcm => <Aes256Gcm as AeadCore>::TagSize::USIZE, Self::XChaCha20Poly1305 => <XChaCha20Poly1305 as AeadCore>::TagSize::USIZE, }
}
}
impl From<&Cipher> for EncryptionAlgorithm {
fn from(value: &Cipher) -> Self {
match value {
Cipher::XChaCha20Poly1305(_) => EncryptionAlgorithm::XChaCha20Poly1305,
Cipher::Aes256Gcm(_) => EncryptionAlgorithm::Aes256Gcm,
}
}
}
/// A keyed AEAD cipher instance for one of the supported [`CryptoMode`]s.
#[derive(Clone)]
pub enum Cipher {
    /// XChaCha20-Poly1305 cipher state.
    XChaCha20Poly1305(XChaCha20Poly1305),
    /// AES-256-GCM cipher state, stored behind a `Box`.
    Aes256Gcm(Box<Aes256Gcm>),
}
impl Cipher {
#[must_use]
pub(crate) fn mode(&self) -> CryptoMode {
match self {
Cipher::XChaCha20Poly1305(_) => CryptoMode::XChaCha20Poly1305,
Cipher::Aes256Gcm(_) => CryptoMode::Aes256Gcm,
}
}
#[must_use]
pub(crate) fn encryption_tag_len(&self) -> usize {
EncryptionAlgorithm::from(self).encryption_tag_len()
}
#[inline]
pub fn encrypt_pkt_in_place(
&self,
packet: &mut impl MutablePacket,
payload_len: usize,
) -> Result<(), CryptoError> {
let mode = self.mode();
let header_len = packet.packet().len() - packet.payload().len();
let (header, body) = packet.packet_mut().split_at_mut(header_len);
let (slice_to_use, body_remaining) = mode.nonce_slice(header, &mut body[..payload_len])?;
let tag_size = self.encryption_tag_len();
let (_, body_remaining) = body_remaining.split_at_mut(mode.payload_prefix_len());
let (body, post_payload) =
body_remaining.split_at_mut(body_remaining.len() - mode.tag_suffix_len());
match self {
Self::Aes256Gcm(aes_gcm) => {
let mut nonce = aes_gcm::Nonce::default();
nonce[..mode.nonce_size()].copy_from_slice(slice_to_use);
let tag = aes_gcm.encrypt_in_place_detached(&nonce, header, body)?;
post_payload[..tag_size].copy_from_slice(&tag[..]);
},
Self::XChaCha20Poly1305(cha_cha_poly1305) => {
let mut nonce = chacha20poly1305::XNonce::default();
nonce[..mode.nonce_size()].copy_from_slice(slice_to_use);
let tag = cha_cha_poly1305.encrypt_in_place_detached(&nonce, header, body)?;
post_payload[..tag_size].copy_from_slice(&tag[..]);
},
}
Ok(())
}
#[cfg(any(feature = "receive", test))]
pub(crate) fn decrypt_rtp_in_place(
&self,
packet: &mut MutableRtpPacket<'_>,
) -> Result<(usize, usize), InternalError> {
let has_extension = packet.get_extension() != 0;
let plain_bytes = if has_extension {
RtpExtensionPacket::minimum_packet_size()
} else {
0
};
let (_, end) = self.decrypt_pkt_in_place(packet, plain_bytes)?;
let payload_offset = if has_extension {
let payload = packet.payload();
let extension =
RtpExtensionPacket::new(payload).ok_or(InternalError::IllegalVoicePacket)?;
extension.packet().len() - extension.payload().len()
} else {
0
};
Ok((payload_offset, end))
}
#[cfg(feature = "receive")]
pub(crate) fn decrypt_rtcp_in_place(
&self,
packet: &mut MutableRtcpPacket<'_>,
) -> Result<(usize, usize), InternalError> {
self.decrypt_pkt_in_place(packet, 0)
}
#[inline]
#[cfg(any(feature = "receive", test))]
pub(crate) fn decrypt_pkt_in_place(
&self,
packet: &mut impl MutablePacket,
n_plaintext_body_bytes: usize,
) -> Result<(usize, usize), InternalError> {
let mode = self.mode();
let header_len = packet.packet().len() - packet.payload().len();
let plaintext_end = header_len + n_plaintext_body_bytes;
let (plaintext, ciphertext) = packet
.packet_mut()
.split_at_mut_checked(plaintext_end)
.ok_or(CryptoError)?;
let (slice_to_use, body_remaining) = mode.nonce_slice(plaintext, ciphertext)?;
let (pre_payload, body_remaining) = body_remaining
.split_at_mut_checked(mode.payload_prefix_len())
.ok_or(CryptoError)?;
let suffix_split_point = body_remaining
.len()
.checked_sub(mode.tag_suffix_len())
.ok_or(CryptoError)?;
let (body, post_payload) = body_remaining
.split_at_mut_checked(suffix_split_point)
.ok_or(CryptoError)?;
let tag_size = self.encryption_tag_len();
match self {
Self::Aes256Gcm(aes_gcm) => {
let mut nonce = aes_gcm::Nonce::default();
nonce[..mode.nonce_size()].copy_from_slice(slice_to_use);
let tag = aes_gcm::Tag::from_slice(&post_payload[..tag_size]);
aes_gcm.decrypt_in_place_detached(&nonce, plaintext, body, tag)?;
},
Self::XChaCha20Poly1305(cha_cha_poly1305) => {
let mut nonce = chacha20poly1305::XNonce::default();
nonce[..mode.nonce_size()].copy_from_slice(slice_to_use);
let tag = chacha20poly1305::Tag::from_slice(&post_payload[..tag_size]);
cha_cha_poly1305.decrypt_in_place_detached(&nonce, plaintext, body, tag)?;
},
}
Ok((
plaintext_end + pre_payload.len(),
post_payload.len() + slice_to_use.len(),
))
}
}
#[cfg(test)]
mod test {
    use super::*;
    use discortp::rtp::MutableRtpPacket;

    /// Every supported mode, so each test exercises all ciphers.
    const ALL_MODES: [CryptoMode; 2] = [CryptoMode::Aes256Gcm, CryptoMode::XChaCha20Poly1305];

    #[test]
    fn small_packet_decrypts_error() {
        // A bare RTP header leaves no room for the nonce and tag.
        let mut buf = [0u8; MutableRtpPacket::minimum_packet_size()];
        let mut pkt = MutableRtpPacket::new(&mut buf[..]).unwrap();
        for mode in ALL_MODES {
            let cipher = mode.cipher_from_key(&[1u8; 32]).unwrap();
            assert!(cipher.decrypt_rtp_in_place(&mut pkt).is_err());
        }
    }

    #[test]
    fn symmetric_encrypt_decrypt_tag_after_data() {
        const TRUE_PAYLOAD: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
        let header_len = MutableRtpPacket::minimum_packet_size();
        for mode in ALL_MODES {
            // Offsets include the mode's prefix, so the test stays correct if a
            // mode ever reserves bytes before the body.
            let body_start = mode.payload_prefix_len();
            let body_end = body_start + TRUE_PAYLOAD.len();
            // `vec!` zero-fills the buffer, so the RTP extension bit is clear.
            let mut buf = vec![0u8; header_len + body_end + mode.payload_suffix_len()];

            let cipher = mode.cipher_from_key(&[7u8; 32]).unwrap();
            let mut crypto_state = CryptoState::from(mode);

            let mut pkt = MutableRtpPacket::new(&mut buf[..]).unwrap();
            pkt.payload_mut()[body_start..body_end].copy_from_slice(&TRUE_PAYLOAD);
            let final_payload_size = crypto_state.write_packet_nonce(&mut pkt, body_end);
            cipher
                .encrypt_pkt_in_place(&mut pkt, final_payload_size)
                .expect("encryption should succeed");
            assert_ne!(&pkt.payload()[body_start..body_end], &TRUE_PAYLOAD[..]);

            let final_pkt_len = header_len + final_payload_size;
            let mut pkt = MutableRtpPacket::new(&mut buf[..final_pkt_len]).unwrap();
            let (_, trailer_len) = cipher
                .decrypt_rtp_in_place(&mut pkt)
                .expect("decrypting a freshly encrypted packet should succeed");

            // The decrypted body must sit exactly before the tag+nonce trailer
            // and match the original plaintext.
            let payload = pkt.payload();
            assert_eq!(payload.len() - trailer_len, body_end);
            assert_eq!(&payload[body_start..body_end], &TRUE_PAYLOAD[..]);
        }
    }

    #[test]
    fn negotiate_cryptomode() {
        let test_set =
            [CryptoMode::XChaCha20Poly1305, CryptoMode::Aes256Gcm].map(CryptoMode::to_request_str);
        assert_eq!(
            CryptoMode::negotiate(test_set, None).unwrap(),
            CryptoMode::Aes256Gcm
        );
        let test_set_missing = [CryptoMode::XChaCha20Poly1305].map(CryptoMode::to_request_str);
        assert_eq!(
            CryptoMode::negotiate(test_set_missing, Some(CryptoMode::Aes256Gcm)).unwrap(),
            CryptoMode::XChaCha20Poly1305
        );
        assert_eq!(
            CryptoMode::negotiate(test_set, Some(CryptoMode::XChaCha20Poly1305)).unwrap(),
            CryptoMode::XChaCha20Poly1305
        );
        let bad_modes = ["not_real", "des", "rc5"];
        assert!(CryptoMode::negotiate(bad_modes, None).is_err());
        assert!(CryptoMode::negotiate(bad_modes, Some(CryptoMode::Aes256Gcm)).is_err());
        let no_modes: [&str; 0] = [];
        assert!(CryptoMode::negotiate(no_modes, None).is_err());
        assert!(CryptoMode::negotiate(no_modes, Some(CryptoMode::Aes256Gcm)).is_err());
    }

    #[test]
    fn request_str_round_trips() {
        for mode in ALL_MODES {
            assert_eq!(mode.to_request_str().parse::<CryptoMode>(), Ok(mode));
        }
        assert_eq!("des".parse::<CryptoMode>(), Err(UnrecognisedCryptoMode));
    }
}