use anyhow::{Context, bail};
use libsrtp::{MasterKey, ProtectionProfile, RecvSession, SendSession, SrtpError, StreamConfig};
use std::sync::{Arc, Mutex};
use test_utils::*;
/// Protects `packet_num` RTP and RTCP packets per MKI on a single SSRC, then
/// unprotects them and checks the round trip and replay detection.
///
/// Sender and receiver share one `StreamConfig` holding three master keys,
/// each tagged with its own MKI. Packets are protected MKI by MKI: all
/// `packet_num` packets for the first MKI, then the second, then the third.
/// When `unordered_decrypt` is set, the receiver's replay window is widened
/// to span every packet and the packets are unprotected in the order returned
/// by `rnd_range`. Otherwise they are unprotected in send order.
///
/// After every packet has been accepted, the packet that was unprotected
/// second-to-last is fed again to both the RTP and RTCP paths. It must be
/// rejected with `SrtpError::InvalidPacketIndex`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - stream setup, protection or unprotection fails;
/// - a decrypted packet differs from its plaintext;
/// - a replayed packet is accepted;
/// - fewer than two packets are sent (the replay check needs a penultimate packet);
/// - the required replay window does not fit in a `u16`.
fn one_stream_mki(
    packet_num: usize, payload_size: usize, seq_num: u16, ssrc: u32, rtp_profile: ProtectionProfile,
    rtcp_profile: ProtectionProfile,
    unordered_decrypt: bool,
) -> anyhow::Result<()> {
    let mut s = SendSession::new();
    let mut r = RecvSession::new();
    let (master_key1, master_salt1) = generate_keys(&rtp_profile);
    let (master_key2, master_salt2) = generate_keys(&rtp_profile);
    let (master_key3, master_salt3) = generate_keys(&rtp_profile);
    let mkis = [
        Some(vec![0x01, 0x23, 0x45]),
        Some(vec![0x67, 0x89, 0xab]),
        Some(vec![0xcd, 0xef, 0x01]),
    ];
    let total_packets = packet_num * mkis.len();
    let mut config = StreamConfig::new(
        vec![
            MasterKey::new(&master_key1, &master_salt1, &mkis[0]),
            MasterKey::new(&master_key2, &master_salt2, &mkis[1]),
            MasterKey::new(&master_key3, &master_salt3, &mkis[2]),
        ],
        &rtp_profile,
        &rtcp_profile,
    );
    s.add_stream(Some(ssrc), &config)
        .map_err(anyhow::Error::from)
        .context("failed to add send stream")?;
    if unordered_decrypt {
        // Packets are unprotected in random order below, so widen the
        // receiver's replay window to span every packet that will be sent.
        let required_window = total_packets + 1;
        if required_window > usize::from(u16::MAX) {
            bail!("replay window size {} does not fit in u16", required_window);
        }
        if (config.get_replay_window_size() as usize) < required_window {
            // Cannot truncate: bounded by u16::MAX just above.
            config.set_replay_window_size(required_window as u16);
        }
    }
    r.add_stream(Some(ssrc), &config)
        .map_err(anyhow::Error::from)
        .context("failed to add recv stream")?;

    let mut rtp_stream = Vec::<Vec<u8>>::with_capacity(total_packets);
    let mut rtcp_stream = Vec::<Vec<u8>>::with_capacity(total_packets);
    let mut srtp_stream = Vec::<Vec<u8>>::with_capacity(total_packets);
    let mut srtcp_stream = Vec::<Vec<u8>>::with_capacity(total_packets);
    let mut seq = seq_num;
    for mki in &mkis {
        for _ in 0..packet_num {
            let rtp = create_rtp_packet(payload_size, ssrc, seq);
            let rtcp = create_rtcp_packet(payload_size / 4, ssrc);
            srtp_stream.push(
                s.rtp_protect_mki(rtp.clone(), mki)
                    .map_err(anyhow::Error::from)
                    .context("rtp protect failed")?,
            );
            rtp_stream.push(rtp);
            srtcp_stream.push(
                s.rtcp_protect_mki(rtcp.clone(), mki)
                    .map_err(anyhow::Error::from)
                    .context("rtcp protect failed")?,
            );
            rtcp_stream.push(rtcp);
            seq = seq.wrapping_add(1);
        }
    }

    let order: Vec<usize> = if unordered_decrypt {
        rnd_range(0, total_packets)
    } else {
        (0..total_packets).collect()
    };
    // The packet unprotected second-to-last is replayed after the loop. A
    // checked lookup avoids an index underflow/panic when < 2 packets exist.
    let penultimate_index = order
        .iter()
        .rev()
        .nth(1)
        .copied()
        .context("at least two packets are needed to check replay detection")?;
    for i in order {
        let rtp = r
            .rtp_unprotect(srtp_stream[i].clone())
            .map_err(anyhow::Error::from)
            .context("rtp unprotect failed")?;
        if rtp != rtp_stream[i] {
            bail!("rtp decrypt didn't match plain");
        }
        let rtcp = r
            .rtcp_unprotect(srtcp_stream[i].clone())
            .map_err(anyhow::Error::from)
            .context("rtcp unprotect failed")?;
        if rtcp != rtcp_stream[i] {
            bail!("rtcp decrypt didn't match plain");
        }
    }

    if r.rtp_unprotect(srtp_stream[penultimate_index].clone()) != Err(SrtpError::InvalidPacketIndex)
    {
        bail!("rtp fails to detect replay");
    }
    if r.rtcp_unprotect(srtcp_stream[penultimate_index].clone())
        != Err(SrtpError::InvalidPacketIndex)
    {
        bail!("rtcp fails to detect replay");
    }
    Ok(())
}
/// Runs the single-stream MKI round trip for every valid profile pair,
/// unprotecting packets in the order they were sent.
#[test]
fn simple_stream() -> anyhow::Result<()> {
    let (packet_num, payload_size): (usize, usize) = (42, 123);
    let ssrc: u32 = 0xcafe_babe;
    let seq_num: u16 = 0x0123;
    IntoIterator::into_iter(VALID_PROFILES).try_for_each(|(rtp_profile, rtcp_profile)| {
        let outcome = one_stream_mki(
            packet_num,
            payload_size,
            seq_num,
            ssrc,
            rtp_profile,
            rtcp_profile,
            false,
        );
        outcome.with_context(|| {
            format!(
                "failed with profiles {:?}/{:?} packet num {packet_num}",
                rtp_profile, rtcp_profile,
            )
        })
    })
}
/// Same as the simple single-stream round trip, but every RTP packet carries
/// a zero-length payload (and so the RTCP payload size, `0 / 4`, is zero too).
#[test]
fn empty_payload() -> anyhow::Result<()> {
    let (packet_num, payload_size): (usize, usize) = (42, 0);
    let ssrc: u32 = 0xcafe_babe;
    let seq_num: u16 = 0x0123;
    IntoIterator::into_iter(VALID_PROFILES).try_for_each(|(rtp_profile, rtcp_profile)| {
        let outcome = one_stream_mki(
            packet_num,
            payload_size,
            seq_num,
            ssrc,
            rtp_profile,
            rtcp_profile,
            false,
        );
        outcome.with_context(|| {
            format!(
                "failed with profiles {:?}/{:?} packet num {packet_num}",
                rtp_profile, rtcp_profile,
            )
        })
    })
}
/// Runs the single-stream MKI round trip for every valid profile pair with
/// `unordered_decrypt` enabled, so packets are unprotected in random order.
#[test]
fn simple_stream_unordered_decrypt() -> anyhow::Result<()> {
    let (packet_num, payload_size): (usize, usize) = (45, 123);
    let ssrc: u32 = 0xcafe_babe;
    let seq_num: u16 = 0x0123;
    IntoIterator::into_iter(VALID_PROFILES).try_for_each(|(rtp_profile, rtcp_profile)| {
        let outcome = one_stream_mki(
            packet_num,
            payload_size,
            seq_num,
            ssrc,
            rtp_profile,
            rtcp_profile,
            true,
        );
        outcome.with_context(|| {
            format!(
                "failed with profiles {:?}/{:?} packet num {packet_num}",
                rtp_profile, rtcp_profile,
            )
        })
    })
}
/// Exercises supplying additional master keys to an existing send stream.
///
/// Setup:
/// - The receiver is configured up front with all three MKI-tagged master
///   keys and a replay window wide enough for every packet sent.
/// - The sender starts with only the first key.
///
/// Sequence:
/// 1. `packet_num` packets are protected with the first key.
/// 2. Protecting with the second MKI must fail with `SrtpError::InvalidMki`.
/// 3. The second and third keys are passed to a second `add_stream` call for
///    the same SSRC.
/// 4. `packet_num` rounds of packets are protected, cycling through all three
///    MKIs. The first key must therefore stay usable after that call,
///    otherwise the `?` on protection fails.
///
/// Every protected packet (both phases) is then unprotected in random order
/// and compared with its plaintext. Finally the packet unprotected
/// second-to-last is replayed and must be rejected with
/// `SrtpError::InvalidPacketIndex`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - setup, protection or unprotection fails;
/// - a decrypted packet differs from its plaintext;
/// - a replay is accepted;
/// - fewer than two packets are sent;
/// - the required replay window does not fit in a `u16`.
///
/// # Panics
///
/// Panics if protecting with the not-yet-provisioned MKI does not return
/// `SrtpError::InvalidMki`.
fn one_stream_mki_update(
    packet_num: usize, payload_size: usize, seq_num: u16, ssrc: u32, rtp_profile: ProtectionProfile,
    rtcp_profile: ProtectionProfile,
) -> anyhow::Result<()> {
    let mut s = SendSession::new();
    let mut r = RecvSession::new();
    let (master_key1, master_salt1) = generate_keys(&rtp_profile);
    let (master_key2, master_salt2) = generate_keys(&rtp_profile);
    let (master_key3, master_salt3) = generate_keys(&rtp_profile);
    let mkis = [
        Some(vec![0x01, 0x23, 0x45]),
        Some(vec![0x67, 0x89, 0xab]),
        Some(vec![0xcd, 0xef, 0x01]),
    ];
    // `packet_num` packets on the first key, then `packet_num` per MKI.
    let total_packets = packet_num * (mkis.len() + 1);

    let mut r_config = StreamConfig::new(
        vec![
            MasterKey::new(&master_key1, &master_salt1, &mkis[0]),
            MasterKey::new(&master_key2, &master_salt2, &mkis[1]),
            MasterKey::new(&master_key3, &master_salt3, &mkis[2]),
        ],
        &rtp_profile,
        &rtcp_profile,
    );
    // Packets are unprotected in random order below, so the receiver's replay
    // window is widened to span every packet that will be sent.
    let required_window = total_packets + 1;
    if required_window > usize::from(u16::MAX) {
        bail!("replay window size {} does not fit in u16", required_window);
    }
    if (r_config.get_replay_window_size() as usize) < required_window {
        // Cannot truncate: bounded by u16::MAX just above.
        r_config.set_replay_window_size(required_window as u16);
    }
    r.add_stream(Some(ssrc), &r_config)
        .map_err(anyhow::Error::from)
        .context("failed to add recv stream")?;

    // The sender initially knows only the first key.
    let s_config = StreamConfig::new(
        vec![MasterKey::new(&master_key1, &master_salt1, &mkis[0])],
        &rtp_profile,
        &rtcp_profile,
    );
    s.add_stream(Some(ssrc), &s_config)
        .map_err(anyhow::Error::from)
        .context("failed to add send stream")?;

    let mut rtp_stream = Vec::<Vec<u8>>::with_capacity(total_packets);
    let mut rtcp_stream = Vec::<Vec<u8>>::with_capacity(total_packets);
    let mut srtp_stream = Vec::<Vec<u8>>::with_capacity(total_packets);
    let mut srtcp_stream = Vec::<Vec<u8>>::with_capacity(total_packets);
    let mut seq = seq_num;
    for _ in 0..packet_num {
        let rtp = create_rtp_packet(payload_size, ssrc, seq);
        let rtcp = create_rtcp_packet(payload_size / 4, ssrc);
        srtp_stream.push(
            s.rtp_protect_mki(rtp.clone(), &mkis[0])
                .map_err(anyhow::Error::from)
                .context("rtp protect failed")?,
        );
        rtp_stream.push(rtp);
        srtcp_stream.push(
            s.rtcp_protect_mki(rtcp.clone(), &mkis[0])
                .map_err(anyhow::Error::from)
                .context("rtcp protect failed")?,
        );
        rtcp_stream.push(rtcp);
        seq = seq.wrapping_add(1);
    }

    // The second key has not been given to the sender yet.
    let rtp = create_rtp_packet(payload_size, ssrc, seq);
    assert_eq!(s.rtp_protect_mki(rtp, &mkis[1]), Err(SrtpError::InvalidMki));

    let s_config = StreamConfig::new(
        vec![
            MasterKey::new(&master_key2, &master_salt2, &mkis[1]),
            MasterKey::new(&master_key3, &master_salt3, &mkis[2]),
        ],
        &rtp_profile,
        &rtcp_profile,
    );
    s.add_stream(Some(ssrc), &s_config)
        .map_err(anyhow::Error::from)
        .context("failed to add send stream")?;
    for _ in 0..packet_num {
        for mki in &mkis {
            let rtp = create_rtp_packet(payload_size, ssrc, seq);
            let rtcp = create_rtcp_packet(payload_size / 4, ssrc);
            srtp_stream.push(
                s.rtp_protect_mki(rtp.clone(), mki)
                    .map_err(anyhow::Error::from)
                    .context("rtp protect failed")?,
            );
            rtp_stream.push(rtp);
            srtcp_stream.push(
                s.rtcp_protect_mki(rtcp.clone(), mki)
                    .map_err(anyhow::Error::from)
                    .context("rtcp protect failed")?,
            );
            rtcp_stream.push(rtcp);
            seq = seq.wrapping_add(1);
        }
    }

    // Verify every protected packet, including those sent after the key
    // update (the previous range stopped at packet_num * mkis.len()).
    let order = rnd_range(0, srtp_stream.len());
    let penultimate_index = order
        .iter()
        .rev()
        .nth(1)
        .copied()
        .context("at least two packets are needed to check replay detection")?;
    for i in order {
        let rtp = r
            .rtp_unprotect(srtp_stream[i].clone())
            .map_err(anyhow::Error::from)
            .context("rtp unprotect failed")?;
        if rtp != rtp_stream[i] {
            bail!("rtp decrypt didn't match plain");
        }
        let rtcp = r
            .rtcp_unprotect(srtcp_stream[i].clone())
            .map_err(anyhow::Error::from)
            .context("rtcp unprotect failed")?;
        if rtcp != rtcp_stream[i] {
            bail!("rtcp decrypt didn't match plain");
        }
    }

    if r.rtp_unprotect(srtp_stream[penultimate_index].clone()) != Err(SrtpError::InvalidPacketIndex)
    {
        bail!("rtp fails to detect replay");
    }
    if r.rtcp_unprotect(srtcp_stream[penultimate_index].clone())
        != Err(SrtpError::InvalidPacketIndex)
    {
        bail!("rtcp fails to detect replay");
    }
    Ok(())
}
/// Runs the send-side key-update scenario for every valid profile pair.
#[test]
fn simple_stream_with_update() -> anyhow::Result<()> {
    let (packet_num, payload_size): (usize, usize) = (45, 123);
    let ssrc: u32 = 0xcafe_babe;
    let seq_num: u16 = 0x0123;
    IntoIterator::into_iter(VALID_PROFILES).try_for_each(|(rtp_profile, rtcp_profile)| {
        let outcome = one_stream_mki_update(
            packet_num,
            payload_size,
            seq_num,
            ssrc,
            rtp_profile,
            rtcp_profile,
        );
        outcome.with_context(|| {
            format!(
                "failed with profiles {:?}/{:?} packet num {packet_num}",
                rtp_profile, rtcp_profile,
            )
        })
    })
}
/// Exercises key lifetimes (soft and hard `SrtpError::KeyLimit` notifications)
/// across several SSRCs. Both sessions use streams added with `None` as the
/// SSRC, and every generated SSRC is protected through them.
///
/// The sender's key-limit handler does the following:
/// - counts soft (`is_dead == false`) and hard (`is_dead == true`)
///   notifications;
/// - records the SSRC and MKI of the most recent one;
/// - panics on any other error it receives.
///
/// Scenario:
/// 1. With only the first key, `packet_num - 1` packets per SSRC are protected.
/// 2. The second and third keys are added. `packet_num - 1` rounds of packets
///    per SSRC are then protected with each of them.
/// 3. Further packets are sent. The test checks which ones succeed, which
///    fail with a hard `KeyLimit` error, and how many notifications the
///    handler saw.
///
/// Finally every successfully protected RTP and RTCP packet is unprotected
/// in random order by a receiver that holds all three keys.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `packet_num == 0` or `stream_number < 2`;
/// - setup, protection of an expected-good packet, or unprotection fails;
/// - a decrypted packet differs from its plaintext;
/// - the required replay window does not fit in a `u16`.
///
/// # Panics
///
/// Panics (via `assert!`/`assert_eq!`) when the observed notifications or
/// key-limit errors differ from the expected ones.
fn multi_stream_mki_update_with_key_limit(
    packet_num: usize, payload_size: usize, stream_number: usize, rtp_profile: ProtectionProfile,
    rtcp_profile: ProtectionProfile,
) -> anyhow::Result<()> {
    // `packet_num - 1` below would underflow for 0, and the hard-limit checks
    // address `ssrcs[1]`.
    if packet_num == 0 || stream_number < 2 {
        bail!(
            "key limit scenario needs packet_num >= 1 and stream_number >= 2, got {} and {}",
            packet_num,
            stream_number
        );
    }
    let mut s = SendSession::new();
    let mut r = RecvSession::new();

    // State written by the key-limit handler and inspected by this function.
    let soft_limit_count = Arc::new(Mutex::new(0u32));
    let hard_limit_count = Arc::new(Mutex::new(0u32));
    let last_ssrc_alert = Arc::new(Mutex::new(0u32));
    let last_mki_alert: Arc<Mutex<Option<Vec<u8>>>> = Arc::new(Mutex::new(None));
    let soft_limit_clone = Arc::clone(&soft_limit_count);
    let hard_limit_clone = Arc::clone(&hard_limit_count);
    let last_ssrc_clone = Arc::clone(&last_ssrc_alert);
    let last_mki_clone = Arc::clone(&last_mki_alert);
    let handler = move |err: SrtpError| match err {
        SrtpError::KeyLimit {
            is_dead, ssrc, mki, ..
        } => {
            if is_dead {
                *hard_limit_clone.lock().unwrap() += 1;
            } else {
                *soft_limit_clone.lock().unwrap() += 1;
            }
            *last_ssrc_clone.lock().unwrap() = ssrc;
            *last_mki_clone.lock().unwrap() = mki;
        }
        _ => {
            panic!("unexpected error received by key limit handler : {:?}", err);
        }
    };
    s.set_key_limit_handler(handler).unwrap();

    let (master_key1, master_salt1) = generate_keys(&rtp_profile);
    let (master_key2, master_salt2) = generate_keys(&rtp_profile);
    let (master_key3, master_salt3) = generate_keys(&rtp_profile);
    let mkis = [
        Some(vec![0x01, 0x23, 0x45]),
        Some(vec![0x67, 0x89, 0xab]),
        Some(vec![0xcd, 0xef, 0x01]),
    ];

    // The receiver knows all three keys from the start.
    let mut r_config = StreamConfig::new(
        vec![
            MasterKey::new(&master_key1, &master_salt1, &mkis[0]),
            MasterKey::new(&master_key2, &master_salt2, &mkis[1]),
            MasterKey::new(&master_key3, &master_salt3, &mkis[2]),
        ],
        &rtp_profile,
        &rtcp_profile,
    );
    // Packets are unprotected in random order at the end, so widen the
    // receiver's replay window.
    let required_window = packet_num * mkis.len() + 1;
    if required_window > usize::from(u16::MAX) {
        bail!("replay window size {} does not fit in u16", required_window);
    }
    if (r_config.get_replay_window_size() as usize) < required_window {
        // Cannot truncate: bounded by u16::MAX just above.
        r_config.set_replay_window_size(required_window as u16);
    }
    r.add_stream(None, &r_config)
        .map_err(anyhow::Error::from)
        .context("failed to add recv stream")?;

    // Phase 1: the sender only has the first key.
    let mut s_config = StreamConfig::new(
        vec![MasterKey::new(&master_key1, &master_salt1, &mkis[0])],
        &rtp_profile,
        &rtcp_profile,
    );
    // NOTE(review): arguments presumably are (RTP lifetime, RTP soft-limit
    // margin, RTCP lifetime, RTCP soft-limit margin) — confirm against
    // `StreamConfig::set_keys_lifetime`.
    s_config
        .set_keys_lifetime(packet_num as u64, 5, packet_num as u32, 2)
        .expect("fail to set keys lifetime");
    s.add_stream(None, &s_config)
        .map_err(anyhow::Error::from)
        .context("failed to add send stream")?;

    let mut rtp_stream = Vec::<Vec<u8>>::new();
    let mut rtcp_stream = Vec::<Vec<u8>>::new();
    let mut srtp_stream = Vec::<Vec<u8>>::new();
    let mut srtcp_stream = Vec::<Vec<u8>>::new();
    let mut seq_nums = get_seq_nums(stream_number);
    let ssrcs = get_ssrcs(stream_number);
    for _ in 0..(packet_num - 1) {
        for (&ssrc, seq) in ssrcs.iter().zip(seq_nums.iter_mut()) {
            let rtp = create_rtp_packet(payload_size, ssrc, *seq);
            let rtcp = create_rtcp_packet(payload_size / 4, ssrc);
            srtp_stream.push(
                s.rtp_protect_mki(rtp.clone(), &mkis[0])
                    .map_err(anyhow::Error::from)
                    .context("rtp protect failed")?,
            );
            rtp_stream.push(rtp);
            srtcp_stream.push(
                s.rtcp_protect_mki(rtcp.clone(), &mkis[0])
                    .map_err(anyhow::Error::from)
                    .context("rtcp protect failed")?,
            );
            rtcp_stream.push(rtcp);
            *seq = seq.wrapping_add(1);
        }
    }
    assert_eq!(
        *soft_limit_count.lock().unwrap(),
        (5 * stream_number) as u32
    );
    assert_eq!(*last_ssrc_alert.lock().unwrap(), *ssrcs.last().unwrap());
    assert_eq!(*last_mki_alert.lock().unwrap(), mkis[0]);

    // Phase 2: add the second and third keys and use them.
    let mut s_config = StreamConfig::new(
        vec![
            MasterKey::new(&master_key2, &master_salt2, &mkis[1]),
            MasterKey::new(&master_key3, &master_salt3, &mkis[2]),
        ],
        &rtp_profile,
        &rtcp_profile,
    );
    s_config
        .set_keys_lifetime(packet_num as u64, 5, packet_num as u32, 2)
        .expect("fail to set keys lifetime");
    s.add_stream(None, &s_config)
        .map_err(anyhow::Error::from)
        .context("failed to add send stream")?;
    for _ in 0..packet_num - 1 {
        for mki in mkis.iter().skip(1) {
            for (&ssrc, seq) in ssrcs.iter().zip(seq_nums.iter_mut()) {
                let rtp = create_rtp_packet(payload_size, ssrc, *seq);
                let rtcp = create_rtcp_packet(payload_size / 4, ssrc);
                srtp_stream.push(
                    s.rtp_protect_mki(rtp.clone(), mki)
                        .map_err(anyhow::Error::from)
                        .context("rtp protect failed")?,
                );
                rtp_stream.push(rtp);
                srtcp_stream.push(
                    s.rtcp_protect_mki(rtcp.clone(), mki)
                        .map_err(anyhow::Error::from)
                        .context("rtcp protect failed")?,
                );
                rtcp_stream.push(rtcp);
                *seq = seq.wrapping_add(1);
            }
        }
    }
    assert_eq!(
        *soft_limit_count.lock().unwrap(),
        (5 * stream_number * mkis.len()) as u32
    );
    assert_eq!(*last_ssrc_alert.lock().unwrap(), *ssrcs.last().unwrap());
    assert_eq!(*last_mki_alert.lock().unwrap(), mkis[2]);

    // Phase 3: one more RTP packet on the first key / first SSRC succeeds and
    // adds one soft-limit notification.
    let rtp = create_rtp_packet(payload_size, ssrcs[0], seq_nums[0]);
    let rtcp = create_rtcp_packet(payload_size / 4, ssrcs[0]);
    srtp_stream.push(
        s.rtp_protect_mki(rtp.clone(), &mkis[0])
            .map_err(anyhow::Error::from)
            .context("rtp protect failed")?,
    );
    rtp_stream.push(rtp);
    // Wrapping, like every other sequence-number increment: a plain `+= 1`
    // panics in debug builds when the generated sequence number is 0xFFFF.
    seq_nums[0] = seq_nums[0].wrapping_add(1);
    assert_eq!(
        *soft_limit_count.lock().unwrap(),
        (5 * stream_number * mkis.len() + 1) as u32
    );

    // Protecting with the first key on the first SSRC now fails hard, for
    // both RTP and RTCP; both failures also reach the handler.
    let rtp = create_rtp_packet(payload_size, ssrcs[0], seq_nums[0]);
    assert!(matches!(
        s.rtp_protect_mki(rtp, &mkis[0]),
        Err(SrtpError::KeyLimit { is_dead: true, is_rtp: true, mki: ref err_mki, ssrc})
            if err_mki.as_slice() == mkis[0].as_slice() && ssrc == ssrcs[0]
    ));
    assert!(matches!(
        s.rtcp_protect_mki(rtcp, &mkis[0]),
        Err(SrtpError::KeyLimit { is_dead: true, is_rtp: false, mki: ref err_mki, ssrc})
            if err_mki.as_slice() == mkis[0].as_slice() && ssrc == ssrcs[0]
    ));
    assert_eq!(*hard_limit_count.lock().unwrap(), 2);

    // The first key also fails hard for RTP on the second SSRC.
    let rtp = create_rtp_packet(payload_size, ssrcs[1], seq_nums[1]);
    assert!(matches!(
        s.rtp_protect_mki(rtp, &mkis[0]),
        Err(SrtpError::KeyLimit { is_dead: true, is_rtp: true, mki: ref err_mki, ssrc})
            if err_mki.as_slice() == mkis[0].as_slice() && ssrc == ssrcs[1]
    ));

    // The second key accepts exactly one more RTCP packet on the second SSRC.
    let rtcp = create_rtcp_packet(payload_size / 4, ssrcs[1]);
    srtcp_stream.push(
        s.rtcp_protect_mki(rtcp.clone(), &mkis[1])
            .map_err(anyhow::Error::from)
            .context("rtcp protect failed")?,
    );
    rtcp_stream.push(rtcp);
    let rtcp = create_rtcp_packet(payload_size / 4, ssrcs[1]);
    assert!(matches!(
        s.rtcp_protect_mki(rtcp, &mkis[1]),
        Err(SrtpError::KeyLimit { is_dead: true, is_rtp: false, mki: ref err_mki, ssrc})
            if err_mki.as_slice() == mkis[1].as_slice() && ssrc == ssrcs[1]
    ));

    // RTP with the second key fails hard on the first SSRC and on a new SSRC.
    let rtp = create_rtp_packet(payload_size, ssrcs[0], seq_nums[0]);
    assert!(matches!(
        s.rtp_protect_mki(rtp, &mkis[1]),
        Err(SrtpError::KeyLimit { is_dead: true, is_rtp: true, mki: ref err_mki, ssrc})
            if err_mki.as_slice() == mkis[1].as_slice() && ssrc == ssrcs[0]
    ));
    let new_ssrcs = get_ssrcs(1);
    let rtp = create_rtp_packet(payload_size, new_ssrcs[0], 0x1234);
    assert!(matches!(
        s.rtp_protect_mki(rtp, &mkis[1]),
        Err(SrtpError::KeyLimit { is_dead: true, is_rtp: true, mki: ref err_mki, ssrc})
            if err_mki.as_slice() == mkis[1].as_slice() && ssrc == new_ssrcs[0]
    ));

    // Every successfully protected packet must still decrypt on the receiver.
    for i in rnd_range(0, srtp_stream.len()) {
        let rtp = r
            .rtp_unprotect(srtp_stream[i].clone())
            .map_err(anyhow::Error::from)
            .context("rtp unprotect failed")?;
        if rtp != rtp_stream[i] {
            bail!("rtp decrypt didn't match plain");
        }
    }
    for i in rnd_range(0, srtcp_stream.len()) {
        let rtcp = r
            .rtcp_unprotect(srtcp_stream[i].clone())
            .map_err(anyhow::Error::from)
            .context("rtcp unprotect failed")?;
        if rtcp != rtcp_stream[i] {
            bail!("rtcp decrypt didn't match plain");
        }
    }
    Ok(())
}
/// Runs the multi-SSRC key-limit scenario (two streams, six packets per key)
/// for every valid profile pair.
#[test]
fn multi_stream_mki_key_limit() -> anyhow::Result<()> {
    let (packet_num, payload_size): (usize, usize) = (6, 42);
    let stream_number: usize = 2;
    IntoIterator::into_iter(VALID_PROFILES).try_for_each(|(rtp_profile, rtcp_profile)| {
        let outcome = multi_stream_mki_update_with_key_limit(
            packet_num,
            payload_size,
            stream_number,
            rtp_profile,
            rtcp_profile,
        );
        outcome.with_context(|| {
            format!(
                "failed with profiles {:?}/{:?} packet num {packet_num}",
                rtp_profile, rtcp_profile,
            )
        })
    })
}