use std::any::Any;
use std::convert::TryInto;
use std::ffi::c_void;
use std::fmt;
use std::mem::MaybeUninit;
use std::os::raw::c_int;
use std::ptr;
use foreign_types::{ForeignType, ForeignTypeRef};
use srtp2_sys as sys;
use crate::crypto_policy::CryptoPolicy;
use crate::error::{Error, Result};
use crate::vec_like::VecLike;
foreign_types::foreign_type! {
pub unsafe type Session: Send + Sync {
type CType = sys::srtp_ctx_t;
fn drop = sys::srtp_dealloc;
}
}
/// Settings for one SRTP stream, or for a session-wide inbound/outbound template.
///
/// It is converted into libsrtp's `srtp_policy_t` when a session is created or a
/// stream is added or updated.
#[derive(Clone, Copy)]
#[derive(Debug, Default)]
pub struct StreamPolicy<'a> {
    /// Crypto policy used for RTP packets.
    pub rtp: CryptoPolicy,
    /// Crypto policy used for RTCP packets.
    pub rtcp: CryptoPolicy,
    /// Key material handed to libsrtp.
    ///
    /// It must be at least as long as the larger of the RTP and RTCP cipher key
    /// lengths; otherwise the conversion fails with `Error::BAD_PARAM`.
    pub key: &'a [u8],
    /// Copied verbatim into `srtp_policy_t::window_size`.
    pub window_size: u64,
    /// Copied into `srtp_policy_t::allow_repeat_tx` as `1` or `0`.
    pub allow_repeat_tx: bool,
    /// Values passed to libsrtp as `enc_xtn_hdr`, with the count in
    /// `enc_xtn_hdr_count`. An empty slice is passed as a null pointer with count 0.
    pub encrypt_extension_headers: &'a [i32],
}
/// Optional boxed callback for one kind of libsrtp session event.
///
/// The callback receives the session, a `u32` (presumably the SSRC carried by the
/// event; confirm against the dispatch code) and, presumably, the session's user
/// data if any is set.
type EventHandler = Option<
    Box<dyn FnMut(&mut SessionRef, u32, Option<&mut (dyn Any + Send + 'static)>) + 'static + Send>,
>;
/// Per-session Rust state stored behind libsrtp's user-data pointer.
///
/// Holds an arbitrary user value plus one optional callback per supported event.
#[derive(Default)]
pub(crate) struct UserDataWrapper {
    /// Value managed by `set_user_data`, `user_data` and `take_user_data`.
    pub(crate) user_data: Option<Box<dyn Any + Send>>,
    /// Handler installed by `SessionRef::on_ssrc_collision`.
    on_ssrc_collision: EventHandler,
    /// Handler installed by `SessionRef::on_key_hard_limit`.
    on_key_hard_limit: EventHandler,
    /// Handler installed by `SessionRef::on_key_soft_limit`.
    on_key_soft_limit: EventHandler,
}
impl Session {
pub fn new() -> Result<Self> {
crate::ensure_init();
let mut session: MaybeUninit<sys::srtp_t> = MaybeUninit::uninit();
unsafe {
Error::check(sys::srtp_create(session.as_mut_ptr(), ptr::null_mut()))?;
Ok(Session::from_ptr(session.assume_init()))
}
}
pub fn with_inbound_template(policy: StreamPolicy<'_>) -> Result<Self> {
crate::ensure_init();
let mut session: MaybeUninit<sys::srtp_t> = MaybeUninit::uninit();
let mut policy = policy.sys_policy()?;
policy.ssrc.type_ = sys::srtp_ssrc_type_t_ssrc_any_inbound;
unsafe {
Error::check(sys::srtp_create(session.as_mut_ptr(), &policy))?;
Ok(Session::from_ptr(session.assume_init()))
}
}
pub fn with_outbound_template(policy: StreamPolicy<'_>) -> Result<Self> {
crate::ensure_init();
let mut session: MaybeUninit<sys::srtp_t> = MaybeUninit::uninit();
let mut policy = policy.sys_policy()?;
policy.ssrc.type_ = sys::srtp_ssrc_type_t_ssrc_any_outbound;
unsafe {
Error::check(sys::srtp_create(session.as_mut_ptr(), &policy))?;
Ok(Session::from_ptr(session.assume_init()))
}
}
}
impl SessionRef {
unsafe fn overwrite<T: VecLike>(
&mut self,
buf: &mut T,
reserve: bool,
func: unsafe extern "C" fn(sys::srtp_t, *mut c_void, *mut c_int) -> sys::srtp_err_status_t,
) -> Result<()> {
if reserve {
if let Err(err) = buf.reserve(sys::SRTP_MAX_TRAILER_LEN as usize) {
error!("`buf.reserve()` failed: {}", err);
return Err(Error::BAD_PARAM);
}
}
let bytes = buf.as_mut_bytes();
let orig_length = bytes.len();
let head_ptr = bytes.as_mut_ptr() as *mut c_void;
let mut length: c_int = match orig_length.try_into() {
Ok(len) => len,
Err(err) => {
error!("Cannot convert the length of the `key` into c_int: {}", err);
return Err(Error::BAD_PARAM);
}
};
let res = Error::check(func(self.as_ptr(), head_ptr, &mut length));
if let Err(err) = res {
buf.set_len(0);
return Err(err);
}
#[cfg(debug_assertions)]
if reserve {
assert!(length as usize <= orig_length + sys::SRTP_MAX_TRAILER_LEN as usize)
} else {
assert!(length as usize <= orig_length)
}
buf.set_len(length as usize);
Ok(())
}
pub fn protect<T: VecLike>(&mut self, buf: &mut T) -> Result<()> {
unsafe { self.overwrite(buf, true, sys::srtp_protect) }
}
pub fn protect_rtcp<T: VecLike>(&mut self, buf: &mut T) -> Result<()> {
unsafe { self.overwrite(buf, true, sys::srtp_protect_rtcp) }
}
pub fn unprotect<T: VecLike>(&mut self, buf: &mut T) -> Result<()> {
unsafe { self.overwrite(buf, false, sys::srtp_unprotect) }
}
pub fn unprotect_rtcp<T: VecLike>(&mut self, buf: &mut T) -> Result<()> {
unsafe { self.overwrite(buf, false, sys::srtp_unprotect_rtcp) }
}
pub fn add_stream(&mut self, ssrc: u32, policy: StreamPolicy<'_>) -> Result<()> {
let policy = policy.sys_policy_ssrc(ssrc)?;
unsafe {
Error::check(sys::srtp_add_stream(self.as_ptr(), &policy))?;
Ok(())
}
}
pub fn remove_stream(&mut self, ssrc: u32) -> Result<()> {
unsafe {
Error::check(sys::srtp_remove_stream(self.as_ptr(), ssrc))?;
Ok(())
}
}
pub fn update_stream(&mut self, ssrc: u32, policy: StreamPolicy<'_>) -> Result<()> {
let policy = policy.sys_policy_ssrc(ssrc)?;
unsafe {
Error::check(sys::srtp_update_stream(self.as_ptr(), &policy))?;
Ok(())
}
}
pub fn update_inbound_template(&mut self, policy: StreamPolicy<'_>) -> Result<()> {
let mut policy = policy.sys_policy()?;
policy.ssrc.type_ = sys::srtp_ssrc_type_t_ssrc_any_inbound;
unsafe {
Error::check(sys::srtp_update_stream(self.as_ptr(), &policy))?;
Ok(())
}
}
pub fn update_outbound_template(&mut self, policy: StreamPolicy<'_>) -> Result<()> {
let mut policy = policy.sys_policy()?;
policy.ssrc.type_ = sys::srtp_ssrc_type_t_ssrc_any_outbound;
unsafe {
Error::check(sys::srtp_update_stream(self.as_ptr(), &policy))?;
Ok(())
}
}
pub fn get_stream_roc(&mut self, ssrc: u32) -> Result<u32> {
unsafe {
let mut roc = 0;
Error::check(sys::srtp_get_stream_roc(self.as_ptr(), ssrc, &mut roc))?;
Ok(roc)
}
}
pub fn set_stream_roc(&mut self, ssrc: u32, roc: u32) -> Result<()> {
unsafe {
Error::check(sys::srtp_set_stream_roc(self.as_ptr(), ssrc, roc))?;
Ok(())
}
}
pub(crate) fn user_data_wrapper(&mut self) -> &mut UserDataWrapper {
unsafe {
match (sys::srtp_get_user_data(self.as_ptr()) as *mut UserDataWrapper).as_mut() {
Some(wrapper) => wrapper,
None => {
let wrapper = Box::into_raw(Box::new(UserDataWrapper::default()));
sys::srtp_set_user_data(self.as_ptr(), wrapper as *mut c_void);
&mut *wrapper
}
}
}
}
pub fn set_user_data<T>(&mut self, data: T)
where
T: Any + Send + 'static,
{
self.user_data_wrapper().user_data = Some(Box::new(data))
}
pub fn user_data(&mut self) -> Option<&mut (dyn Any + Send + 'static)> {
self.user_data_wrapper().user_data.as_deref_mut()
}
pub fn take_user_data(&mut self) -> Option<Box<dyn Any + Send + 'static>> {
self.user_data_wrapper().user_data.take()
}
pub fn on_ssrc_collision<F>(&mut self, f: F)
where
F: FnMut(&mut SessionRef, u32, Option<&mut (dyn Any + Send + 'static)>) + Send + 'static,
{
self.user_data_wrapper().on_ssrc_collision = Some(Box::new(f))
}
pub fn on_key_hard_limit<F>(&mut self, f: F)
where
F: FnMut(&mut SessionRef, u32, Option<&mut (dyn Any + Send + 'static)>) + Send + 'static,
{
self.user_data_wrapper().on_key_hard_limit = Some(Box::new(f))
}
pub fn on_key_soft_limit<F>(&mut self, f: F)
where
F: FnMut(&mut SessionRef, u32, Option<&mut (dyn Any + Send + 'static)>) + Send + 'static,
{
self.user_data_wrapper().on_key_soft_limit = Some(Box::new(f))
}
}
impl StreamPolicy<'_> {
fn sys_policy(&self) -> Result<sys::srtp_policy_t> {
let mut policy: sys::srtp_policy_t = unsafe { MaybeUninit::zeroed().assume_init() };
policy.ssrc.type_ = sys::srtp_ssrc_type_t_ssrc_undefined;
policy.rtp = self.rtp.into_raw();
policy.rtcp = self.rtcp.into_raw();
let key_length =
std::cmp::max(policy.rtp.cipher_key_len, policy.rtcp.cipher_key_len) as usize;
if self.key.len() < key_length {
error!(
"StreamPolicy key is too short, required: {}, provided: {}",
key_length,
self.key.len(),
);
return Err(Error::BAD_PARAM);
}
policy.key = self.key.as_ptr() as *mut u8;
policy.window_size = self.window_size;
policy.allow_repeat_tx = if self.allow_repeat_tx { 1 } else { 0 };
policy.enc_xtn_hdr = if self.encrypt_extension_headers.is_empty() {
ptr::null_mut()
} else {
self.encrypt_extension_headers.as_ptr() as *mut i32
};
policy.enc_xtn_hdr_count = match self.encrypt_extension_headers.len().try_into() {
Ok(len) => len,
Err(err) => {
error!(
"Cannot convert the length of the `enc_xtn_hdr_count` into c_int: {}",
err
);
return Err(Error::BAD_PARAM);
}
};
Ok(policy)
}
fn sys_policy_ssrc(&self, ssrc: u32) -> Result<sys::srtp_policy_t> {
let mut policy = self.sys_policy()?;
policy.ssrc.type_ = sys::srtp_ssrc_type_t_ssrc_specific;
policy.ssrc.value = ssrc;
Ok(policy)
}
}
impl UserDataWrapper {
    /// Returns the handler slot that corresponds to the libsrtp event `kind`.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) for any event other than SSRC collision, key hard
    /// limit and key soft limit.
    pub(crate) fn event_handler(&mut self, kind: sys::srtp_event_t) -> &mut EventHandler {
        // NOTE(review): libsrtp's headers also declare `event_packet_index_limit`. If such
        // an event can be dispatched here, this panics, possibly inside an FFI callback.
        // Confirm that the caller filters it out.
        if kind == sys::srtp_event_t_event_ssrc_collision {
            &mut self.on_ssrc_collision
        } else if kind == sys::srtp_event_t_event_key_hard_limit {
            &mut self.on_key_hard_limit
        } else if kind == sys::srtp_event_t_event_key_soft_limit {
            &mut self.on_key_soft_limit
        } else {
            unreachable!()
        }
    }
}
impl fmt::Debug for Session {
    // Prints a fixed opaque placeholder; no session state is formatted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Session {{ .. }}")
    }
}
impl fmt::Debug for SessionRef {
    // Prints a fixed opaque placeholder; no session state is formatted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "SessionRef {{ .. }}")
    }
}