webrtc_srtp/context/
mod.rs1#[cfg(test)]
2mod context_test;
3#[cfg(test)]
4mod srtcp_test;
5#[cfg(test)]
6mod srtp_test;
7
8use std::collections::HashMap;
9
10use aes::Aes128;
11use aes::Aes256;
12use util::replay_detector::*;
13
14use crate::cipher::cipher_aead_aes_gcm::*;
15use crate::cipher::cipher_aes_cm_hmac_sha1::*;
16use crate::cipher::*;
17use crate::error::{Error, Result};
18use crate::option::*;
19use crate::protection_profile::*;
20
21pub mod srtcp;
22pub mod srtp;
23
/// Largest rollover counter (ROC) value; a ROC guess that wraps from here to 0 is
/// reported as overflow by `SrtpSsrcState::next_rollover_count`.
const MAX_ROC: u32 = 0xFFFF_FFFF;
/// Half of the 16-bit RTP sequence-number space (2^15).
///
/// A sequence number more than this far from the tracked one is attributed to the
/// neighbouring rollover period when guessing the ROC.
const SEQ_NUM_MEDIAN: u16 = 0x8000;
/// Largest 16-bit RTP sequence number; `SEQ_NUM_MAX + 1` is the length of one
/// rollover period.
const SEQ_NUM_MAX: u16 = 0xFFFF;
27
28#[derive(Default)]
30pub(crate) struct SrtpSsrcState {
31 ssrc: u32,
32 index: u64,
33 rollover_has_processed: bool,
34 replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
35}
36
37#[derive(Default)]
39pub(crate) struct SrtcpSsrcState {
40 srtcp_index: usize,
41 ssrc: u32,
42 replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
43}
44
45impl SrtpSsrcState {
46 pub fn next_rollover_count(&self, sequence_number: u16) -> (u32, i32, bool) {
47 let local_roc = (self.index >> 16) as u32;
48 let local_seq = self.index as u16;
49
50 let mut guess_roc = local_roc;
51
52 let diff = if self.rollover_has_processed {
53 let seq = (sequence_number as i32).wrapping_sub(local_seq as i32);
54 if self.index > SEQ_NUM_MEDIAN as _ {
57 if local_seq < SEQ_NUM_MEDIAN {
58 if seq > SEQ_NUM_MEDIAN as i32 {
59 guess_roc = local_roc.wrapping_sub(1);
60 seq.wrapping_sub(SEQ_NUM_MAX as i32 + 1)
61 } else {
62 seq
63 }
64 } else if local_seq - SEQ_NUM_MEDIAN > sequence_number {
65 guess_roc = local_roc.wrapping_add(1);
66 seq.wrapping_add(SEQ_NUM_MAX as i32 + 1)
67 } else {
68 seq
69 }
70 } else {
71 seq
73 }
74 } else {
75 0i32
76 };
77
78 (guess_roc, diff, (guess_roc == 0 && local_roc == MAX_ROC))
79 }
80
81 pub fn update_rollover_count(&mut self, sequence_number: u16, diff: i32) {
83 if !self.rollover_has_processed {
84 self.index |= sequence_number as u64;
85 self.rollover_has_processed = true;
86 } else {
87 self.index = self.index.wrapping_add(diff as _);
88 }
89 }
90}
91
92pub struct Context {
96 cipher: Box<dyn Cipher + Send>,
97
98 srtp_ssrc_states: HashMap<u32, SrtpSsrcState>,
99 srtcp_ssrc_states: HashMap<u32, SrtcpSsrcState>,
100
101 new_srtp_replay_detector: ContextOption,
102 new_srtcp_replay_detector: ContextOption,
103}
104
105impl Context {
106 pub fn new(
108 master_key: &[u8],
109 master_salt: &[u8],
110 profile: ProtectionProfile,
111 srtp_ctx_opt: Option<ContextOption>,
112 srtcp_ctx_opt: Option<ContextOption>,
113 ) -> Result<Context> {
114 let key_len = profile.key_len();
115 let salt_len = profile.salt_len();
116
117 if master_key.len() != key_len {
118 return Err(Error::SrtpMasterKeyLength(key_len, master_key.len()));
119 } else if master_salt.len() != salt_len {
120 return Err(Error::SrtpSaltLength(salt_len, master_salt.len()));
121 }
122
123 let cipher: Box<dyn Cipher + Send> = match profile {
124 ProtectionProfile::Aes128CmHmacSha1_32 | ProtectionProfile::Aes128CmHmacSha1_80 => {
125 Box::new(CipherAesCmHmacSha1::new(profile, master_key, master_salt)?)
126 }
127
128 ProtectionProfile::AeadAes128Gcm => Box::new(CipherAeadAesGcm::<Aes128>::new(
129 profile,
130 master_key,
131 master_salt,
132 )?),
133
134 ProtectionProfile::AeadAes256Gcm => Box::new(CipherAeadAesGcm::<Aes256>::new(
135 profile,
136 master_key,
137 master_salt,
138 )?),
139 };
140
141 let srtp_ctx_opt = if let Some(ctx_opt) = srtp_ctx_opt {
142 ctx_opt
143 } else {
144 srtp_no_replay_protection()
145 };
146
147 let srtcp_ctx_opt = if let Some(ctx_opt) = srtcp_ctx_opt {
148 ctx_opt
149 } else {
150 srtcp_no_replay_protection()
151 };
152
153 Ok(Context {
154 cipher,
155 srtp_ssrc_states: HashMap::new(),
156 srtcp_ssrc_states: HashMap::new(),
157 new_srtp_replay_detector: srtp_ctx_opt,
158 new_srtcp_replay_detector: srtcp_ctx_opt,
159 })
160 }
161
162 fn get_srtp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtpSsrcState {
163 let s = SrtpSsrcState {
164 ssrc,
165 replay_detector: Some((self.new_srtp_replay_detector)()),
166 ..Default::default()
167 };
168
169 self.srtp_ssrc_states.entry(ssrc).or_insert(s)
170 }
171
172 fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtcpSsrcState {
173 let s = SrtcpSsrcState {
174 ssrc,
175 replay_detector: Some((self.new_srtcp_replay_detector)()),
176 ..Default::default()
177 };
178 self.srtcp_ssrc_states.entry(ssrc).or_insert(s)
179 }
180
181 fn get_roc(&self, ssrc: u32) -> Option<u32> {
183 self.srtp_ssrc_states
184 .get(&ssrc)
185 .map(|s| (s.index >> 16) as _)
186 }
187
188 fn set_roc(&mut self, ssrc: u32, roc: u32) {
190 let state = self.get_srtp_ssrc_state(ssrc);
191 state.index = (roc as u64) << 16;
192 state.rollover_has_processed = false;
193 }
194
195 fn get_index(&self, ssrc: u32) -> Option<usize> {
197 self.srtcp_ssrc_states.get(&ssrc).map(|s| s.srtcp_index)
198 }
199
200 fn set_index(&mut self, ssrc: u32, index: usize) {
202 self.get_srtcp_ssrc_state(ssrc).srtcp_index = index % (MAX_SRTCP_INDEX + 1);
203 }
204}