#![allow(dead_code)]
#![allow(clippy::too_many_arguments)]
/// SRTP protection profiles, named after the cipher and authentication tag they use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SrtpProfile {
    /// AES-128 counter mode with an 80-bit (10-byte) HMAC-SHA1 tag.
    AesCm128HmacSha1_80,
    /// AES-128 counter mode with a 32-bit (4-byte) HMAC-SHA1 tag.
    AesCm128HmacSha1_32,
    /// AES-128 in GCM (AEAD) mode.
    AeadAes128Gcm,
    /// AES-256 in GCM (AEAD) mode.
    AeadAes256Gcm,
}

impl SrtpProfile {
    /// Returns `true` for the two AEAD (GCM) profiles.
    const fn is_aead(self) -> bool {
        matches!(self, Self::AeadAes128Gcm | Self::AeadAes256Gcm)
    }

    /// Master key length in bytes: 32 for AES-256-GCM, 16 for every AES-128 profile.
    #[must_use]
    pub const fn key_length(self) -> usize {
        match self {
            Self::AeadAes256Gcm => 32,
            Self::AeadAes128Gcm | Self::AesCm128HmacSha1_32 | Self::AesCm128HmacSha1_80 => 16,
        }
    }

    /// Master salt length in bytes: 12 for the GCM profiles, 14 for the AES-CM profiles.
    #[must_use]
    pub const fn salt_length(self) -> usize {
        if self.is_aead() {
            12
        } else {
            14
        }
    }

    /// Length in bytes of the authentication tag appended to each protected packet.
    #[must_use]
    pub const fn tag_length(self) -> usize {
        match self {
            Self::AeadAes128Gcm | Self::AeadAes256Gcm => 16,
            Self::AesCm128HmacSha1_80 => 10,
            Self::AesCm128HmacSha1_32 => 4,
        }
    }
}
/// Master key and master salt for an SRTP/SRTCP session.
///
/// `Debug` is implemented by hand instead of derived, so that formatting a key (or any type
/// that contains one and derives `Debug`) reports only the byte lengths, never the key
/// material itself.
#[derive(Clone)]
pub struct SrtpKey {
    /// Master key bytes.
    pub key: Vec<u8>,
    /// Master salt bytes.
    pub salt: Vec<u8>,
}

impl std::fmt::Debug for SrtpKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Redact the contents; lengths are still useful when debugging profile mismatches.
        f.debug_struct("SrtpKey")
            .field("key", &format_args!("<{} bytes redacted>", self.key.len()))
            .field("salt", &format_args!("<{} bytes redacted>", self.salt.len()))
            .finish()
    }
}
impl SrtpKey {
    /// Wraps caller-supplied key and salt bytes as-is; lengths are not validated.
    #[must_use]
    pub fn new(key: Vec<u8>, salt: Vec<u8>) -> Self {
        Self { key, salt }
    }

    /// Builds an all-zero key and salt whose lengths match `profile`.
    #[must_use]
    pub fn zeroed(profile: SrtpProfile) -> Self {
        let key_len = profile.key_length();
        let salt_len = profile.salt_length();
        Self::new(vec![0; key_len], vec![0; salt_len])
    }
}
/// Per-stream state for protecting and unprotecting RTP packets.
///
/// NOTE(review): the protection here is a placeholder. Payloads are not encrypted, and the
/// tag is an unkeyed XOR checksum rather than an HMAC/GCM tag, so it gives no real
/// confidentiality or authenticity.
#[derive(Debug)]
pub struct SrtpContext {
    /// Protection profile; determines the tag length.
    pub profile: SrtpProfile,
    /// Master key material. NOTE(review): not consulted by the placeholder tag computation.
    pub master_key: SrtpKey,
    /// Count of packets protected so far; wraps at `u32::MAX`.
    pub index: u32,
    /// Incremented each time `index` wraps, so the `(rollover_counter, index)` pair does not
    /// repeat after 2^32 packets.
    pub rollover_counter: u32,
}

impl SrtpContext {
    /// Creates a context with `index` and `rollover_counter` both starting at zero.
    #[must_use]
    pub fn new(profile: SrtpProfile, master_key: SrtpKey) -> Self {
        Self {
            profile,
            master_key,
            index: 0,
            rollover_counter: 0,
        }
    }

    /// Returns `packet` followed by the profile's authentication tag, and advances the index.
    ///
    /// The packet bytes are copied unchanged (no encryption is applied). When `index` wraps
    /// past `u32::MAX`, `rollover_counter` is incremented.
    #[must_use]
    pub fn protect_rtp(&mut self, packet: &[u8]) -> Vec<u8> {
        let tag = Self::placeholder_tag(packet, self.profile.tag_length());
        let mut out = Vec::with_capacity(packet.len() + tag.len());
        out.extend_from_slice(packet);
        out.extend_from_slice(&tag);
        self.advance_index();
        out
    }

    /// Verifies and strips the trailing authentication tag, returning the preceding bytes.
    ///
    /// # Errors
    ///
    /// Returns an error if `packet` is shorter than the profile's tag length, or if any byte
    /// of the trailing tag differs from the expected tag.
    pub fn unprotect_rtp(&self, packet: &[u8]) -> Result<Vec<u8>, &'static str> {
        let tag_len = self.profile.tag_length();
        if packet.len() < tag_len {
            return Err("packet too short to contain auth tag");
        }
        let (payload, tag) = packet.split_at(packet.len() - tag_len);
        // Compare the whole tag. Checking only the first byte let edits to the rest of the
        // tag pass unnoticed.
        if !Self::tags_match(tag, &Self::placeholder_tag(payload, tag_len)) {
            return Err("SRTP auth tag mismatch");
        }
        Ok(payload.to_vec())
    }

    /// Advances `index` by one, carrying a wrap-around into `rollover_counter`.
    fn advance_index(&mut self) {
        let (next, wrapped) = self.index.overflowing_add(1);
        self.index = next;
        if wrapped {
            self.rollover_counter = self.rollover_counter.wrapping_add(1);
        }
    }

    /// Builds the placeholder tag: the XOR of all `data` bytes, followed by zero padding up to
    /// `tag_len` bytes.
    fn placeholder_tag(data: &[u8], tag_len: usize) -> Vec<u8> {
        let mut tag = vec![0u8; tag_len];
        if let Some(first) = tag.first_mut() {
            *first = data.iter().fold(0u8, |acc, &b| acc ^ b);
        }
        tag
    }

    /// Compares two tags by accumulating the differences instead of exiting at the first
    /// mismatching byte.
    ///
    /// NOTE(review): this is not a guaranteed constant-time primitive.
    fn tags_match(actual: &[u8], expected: &[u8]) -> bool {
        actual.len() == expected.len()
            && actual
                .iter()
                .zip(expected)
                .fold(0u8, |acc, (a, b)| acc | (a ^ b))
                == 0
    }
}
/// Per-stream state for protecting and unprotecting RTCP packets.
///
/// A protected packet is laid out as `packet || (E-flag | 31-bit index) || tag`, where the
/// index word is big-endian. NOTE(review): as with the RTP context, nothing is encrypted, and
/// the tag is an unkeyed XOR checksum placeholder.
#[derive(Debug)]
pub struct SrtcpContext {
    /// Protection profile; determines the tag length.
    pub profile: SrtpProfile,
    /// Master key material. NOTE(review): not consulted by the placeholder tag computation.
    pub master_key: SrtpKey,
    /// Index written into the next protected packet. It is kept within 31 bits, the width of
    /// the on-wire field.
    pub srtcp_index: u32,
}

impl SrtcpContext {
    /// Top bit of the index word; always set by `protect_rtcp`.
    const E_FLAG: u32 = 0x8000_0000;
    /// The SRTCP index occupies the low 31 bits of the index word.
    const INDEX_MASK: u32 = 0x7FFF_FFFF;
    /// Size in bytes of the index word inserted before the tag.
    const INDEX_FIELD_LEN: usize = 4;

    /// Creates a context whose first protected packet carries index 0.
    #[must_use]
    pub fn new(profile: SrtpProfile, master_key: SrtpKey) -> Self {
        Self {
            profile,
            master_key,
            srtcp_index: 0,
        }
    }

    /// Appends the E-flag/index word and the authentication tag to `packet`, then advances
    /// the index.
    ///
    /// The tag covers both the packet and the index word. The stored index wraps within
    /// 31 bits, so it always equals the value that goes on the wire.
    #[must_use]
    pub fn protect_rtcp(&mut self, packet: &[u8]) -> Vec<u8> {
        let tag_len = self.profile.tag_length();
        let index_field = Self::E_FLAG | (self.srtcp_index & Self::INDEX_MASK);
        let mut out = Vec::with_capacity(packet.len() + Self::INDEX_FIELD_LEN + tag_len);
        out.extend_from_slice(packet);
        out.extend_from_slice(&index_field.to_be_bytes());
        let tag = Self::placeholder_tag(&out, tag_len);
        out.extend_from_slice(&tag);
        self.srtcp_index = self.srtcp_index.wrapping_add(1) & Self::INDEX_MASK;
        out
    }

    /// Verifies the trailing tag, then strips the tag and the index word.
    ///
    /// The index word is discarded without being inspected.
    ///
    /// # Errors
    ///
    /// Returns an error if `packet` cannot hold the index word plus the tag, or if any byte
    /// of the trailing tag differs from the expected tag.
    pub fn unprotect_rtcp(&self, packet: &[u8]) -> Result<Vec<u8>, &'static str> {
        let tag_len = self.profile.tag_length();
        if packet.len() < Self::INDEX_FIELD_LEN + tag_len {
            return Err("RTCP packet too short");
        }
        let (authenticated, tag) = packet.split_at(packet.len() - tag_len);
        // Compare the whole tag, not just its first byte.
        if !Self::tags_match(tag, &Self::placeholder_tag(authenticated, tag_len)) {
            return Err("SRTCP auth tag mismatch");
        }
        let (payload, _index_field) =
            authenticated.split_at(authenticated.len() - Self::INDEX_FIELD_LEN);
        Ok(payload.to_vec())
    }

    /// Builds the placeholder tag: the XOR of all `data` bytes, followed by zero padding up to
    /// `tag_len` bytes.
    fn placeholder_tag(data: &[u8], tag_len: usize) -> Vec<u8> {
        let mut tag = vec![0u8; tag_len];
        if let Some(first) = tag.first_mut() {
            *first = data.iter().fold(0u8, |acc, &b| acc ^ b);
        }
        tag
    }

    /// Compares two tags by accumulating the differences instead of exiting at the first
    /// mismatching byte.
    ///
    /// NOTE(review): this is not a guaranteed constant-time primitive.
    fn tags_match(actual: &[u8], expected: &[u8]) -> bool {
        actual.len() == expected.len()
            && actual
                .iter()
                .zip(expected)
                .fold(0u8, |acc, (a, b)| acc | (a ^ b))
                == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_profile_key_length() {
        assert_eq!(SrtpProfile::AesCm128HmacSha1_80.key_length(), 16);
        assert_eq!(SrtpProfile::AeadAes256Gcm.key_length(), 32);
    }

    #[test]
    fn test_profile_salt_length() {
        assert_eq!(SrtpProfile::AesCm128HmacSha1_80.salt_length(), 14);
        assert_eq!(SrtpProfile::AeadAes128Gcm.salt_length(), 12);
    }

    #[test]
    fn test_profile_remaining_lengths() {
        assert_eq!(SrtpProfile::AesCm128HmacSha1_32.key_length(), 16);
        assert_eq!(SrtpProfile::AeadAes128Gcm.key_length(), 16);
        assert_eq!(SrtpProfile::AesCm128HmacSha1_32.salt_length(), 14);
        assert_eq!(SrtpProfile::AeadAes256Gcm.salt_length(), 12);
    }

    #[test]
    fn test_profile_tag_length() {
        assert_eq!(SrtpProfile::AesCm128HmacSha1_80.tag_length(), 10);
        assert_eq!(SrtpProfile::AesCm128HmacSha1_32.tag_length(), 4);
        assert_eq!(SrtpProfile::AeadAes128Gcm.tag_length(), 16);
        assert_eq!(SrtpProfile::AeadAes256Gcm.tag_length(), 16);
    }

    #[test]
    fn test_srtp_key_zeroed() {
        let key = SrtpKey::zeroed(SrtpProfile::AesCm128HmacSha1_80);
        assert_eq!(key.key.len(), 16);
        assert_eq!(key.salt.len(), 14);
    }

    #[test]
    fn test_srtp_protect_rtp() {
        let key = SrtpKey::zeroed(SrtpProfile::AesCm128HmacSha1_80);
        let mut ctx = SrtpContext::new(SrtpProfile::AesCm128HmacSha1_80, key);
        let packet = vec![0x80u8, 0x60, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0xAA, 0xBB];
        let protected = ctx.protect_rtp(&packet);
        assert_eq!(protected.len(), packet.len() + 10);
    }

    #[test]
    fn test_srtp_protect_unprotect_roundtrip() {
        let key = SrtpKey::zeroed(SrtpProfile::AesCm128HmacSha1_80);
        let mut ctx = SrtpContext::new(SrtpProfile::AesCm128HmacSha1_80, key.clone());
        let ctx2 = SrtpContext::new(SrtpProfile::AesCm128HmacSha1_80, key);
        let packet = b"Hello SRTP packet!";
        let protected = ctx.protect_rtp(packet);
        let recovered = ctx2
            .unprotect_rtp(&protected)
            .expect("should succeed in test");
        assert_eq!(recovered, packet.to_vec());
    }

    #[test]
    fn test_srtp_short_tag_roundtrip() {
        let profile = SrtpProfile::AesCm128HmacSha1_32;
        let key = SrtpKey::zeroed(profile);
        let mut tx = SrtpContext::new(profile, key.clone());
        let rx = SrtpContext::new(profile, key);
        let packet = b"short tag packet";
        let protected = tx.protect_rtp(packet);
        assert_eq!(protected.len(), packet.len() + 4);
        let recovered = rx.unprotect_rtp(&protected).expect("should succeed in test");
        assert_eq!(recovered, packet.to_vec());
    }

    #[test]
    fn test_srtp_empty_payload_roundtrip() {
        let profile = SrtpProfile::AesCm128HmacSha1_80;
        let key = SrtpKey::zeroed(profile);
        let mut tx = SrtpContext::new(profile, key.clone());
        let rx = SrtpContext::new(profile, key);
        let packet: &[u8] = b"";
        let protected = tx.protect_rtp(packet);
        assert_eq!(protected.len(), 10);
        let recovered = rx.unprotect_rtp(&protected).expect("should succeed in test");
        assert!(recovered.is_empty());
    }

    #[test]
    fn test_srtp_unprotect_tampered() {
        let key = SrtpKey::zeroed(SrtpProfile::AesCm128HmacSha1_80);
        let mut ctx = SrtpContext::new(SrtpProfile::AesCm128HmacSha1_80, key.clone());
        let ctx2 = SrtpContext::new(SrtpProfile::AesCm128HmacSha1_80, key);
        let packet = b"test packet data";
        let mut protected = ctx.protect_rtp(packet);
        let len = protected.len();
        protected[len - 1] ^= 0xFF;
        let tag_start = len - 10;
        protected[tag_start] ^= 0xFF;
        let result = ctx2.unprotect_rtp(&protected);
        assert!(result.is_err());
    }

    #[test]
    fn test_srtp_unprotect_tampered_payload() {
        let profile = SrtpProfile::AesCm128HmacSha1_80;
        let key = SrtpKey::zeroed(profile);
        let mut tx = SrtpContext::new(profile, key.clone());
        let rx = SrtpContext::new(profile, key);
        let mut protected = tx.protect_rtp(b"payload bytes");
        protected[0] ^= 0x01;
        assert!(rx.unprotect_rtp(&protected).is_err());
    }

    #[test]
    fn test_srtcp_protect_unprotect_roundtrip() {
        let key = SrtpKey::zeroed(SrtpProfile::AesCm128HmacSha1_80);
        let mut ctx = SrtcpContext::new(SrtpProfile::AesCm128HmacSha1_80, key.clone());
        let ctx2 = SrtcpContext::new(SrtpProfile::AesCm128HmacSha1_80, key);
        let rtcp = b"RTCP sender report payload";
        let protected = ctx.protect_rtcp(rtcp);
        let recovered = ctx2
            .unprotect_rtcp(&protected)
            .expect("should succeed in test");
        assert_eq!(recovered, rtcp.to_vec());
    }

    #[test]
    fn test_srtcp_packet_overhead() {
        let key = SrtpKey::zeroed(SrtpProfile::AesCm128HmacSha1_80);
        let mut ctx = SrtcpContext::new(SrtpProfile::AesCm128HmacSha1_80, key);
        let rtcp = b"RTCP";
        let protected = ctx.protect_rtcp(rtcp);
        assert_eq!(protected.len(), rtcp.len() + 4 + 10);
    }

    #[test]
    fn test_srtcp_index_field_encoding() {
        let profile = SrtpProfile::AesCm128HmacSha1_80;
        let mut ctx = SrtcpContext::new(profile, SrtpKey::zeroed(profile));
        let rtcp = b"RTCP";
        let first = ctx.protect_rtcp(rtcp);
        assert_eq!(&first[rtcp.len()..rtcp.len() + 4], &[0x80u8, 0x00, 0x00, 0x00][..]);
        let second = ctx.protect_rtcp(rtcp);
        assert_eq!(&second[rtcp.len()..rtcp.len() + 4], &[0x80u8, 0x00, 0x00, 0x01][..]);
    }

    #[test]
    fn test_srtcp_unprotect_too_short() {
        let profile = SrtpProfile::AesCm128HmacSha1_80;
        let ctx = SrtcpContext::new(profile, SrtpKey::zeroed(profile));
        // The minimum is the 4-byte index word plus a 10-byte tag, i.e. 14 bytes.
        assert!(ctx.unprotect_rtcp(&[0u8; 13]).is_err());
    }

    #[test]
    fn test_srtcp_unprotect_tampered_index() {
        let profile = SrtpProfile::AesCm128HmacSha1_80;
        let key = SrtpKey::zeroed(profile);
        let mut tx = SrtcpContext::new(profile, key.clone());
        let rx = SrtcpContext::new(profile, key);
        let mut protected = tx.protect_rtcp(b"RTCP receiver report");
        // Flip the last byte of the index word, which sits just before the 10-byte tag.
        let index_last = protected.len() - 10 - 1;
        protected[index_last] ^= 0x01;
        assert!(rx.unprotect_rtcp(&protected).is_err());
    }

    #[test]
    fn test_srtp_index_increments() {
        let key = SrtpKey::zeroed(SrtpProfile::AesCm128HmacSha1_80);
        let mut ctx = SrtpContext::new(SrtpProfile::AesCm128HmacSha1_80, key);
        assert_eq!(ctx.index, 0);
        let _ = ctx.protect_rtp(b"p1");
        assert_eq!(ctx.index, 1);
        let _ = ctx.protect_rtp(b"p2");
        assert_eq!(ctx.index, 2);
    }

    #[test]
    fn test_srtcp_index_increments() {
        let key = SrtpKey::zeroed(SrtpProfile::AesCm128HmacSha1_80);
        let mut ctx = SrtcpContext::new(SrtpProfile::AesCm128HmacSha1_80, key);
        assert_eq!(ctx.srtcp_index, 0);
        let _ = ctx.protect_rtcp(b"r1");
        assert_eq!(ctx.srtcp_index, 1);
    }

    #[test]
    fn test_srtp_aead_gcm_profiles() {
        let key = SrtpKey::zeroed(SrtpProfile::AeadAes128Gcm);
        let mut ctx = SrtpContext::new(SrtpProfile::AeadAes128Gcm, key.clone());
        let ctx2 = SrtpContext::new(SrtpProfile::AeadAes128Gcm, key);
        let packet = b"GCM test packet";
        let protected = ctx.protect_rtp(packet);
        assert_eq!(protected.len(), packet.len() + 16);
        let recovered = ctx2
            .unprotect_rtp(&protected)
            .expect("should succeed in test");
        assert_eq!(recovered, packet.to_vec());
    }

    #[test]
    fn test_srtp_unprotect_too_short() {
        let key = SrtpKey::zeroed(SrtpProfile::AesCm128HmacSha1_80);
        let ctx = SrtpContext::new(SrtpProfile::AesCm128HmacSha1_80, key);
        let result = ctx.unprotect_rtp(b"short");
        assert!(result.is_err());
    }
}