1#[cfg(test)]
2mod context_test;
3#[cfg(test)]
4mod srtcp_test;
5#[cfg(test)]
6mod srtp_test;
7
8use std::collections::HashMap;
9
10use aes::Aes128;
11use aes::Aes256;
12use shared::replay_detector::*;
13
14use crate::cipher::cipher_aead_aes_gcm::*;
15use crate::cipher::cipher_aes_cm_hmac_sha1::*;
16use crate::cipher::*;
17use crate::option::*;
18use crate::protection_profile::*;
19use shared::error::{Error, Result};
20
21pub mod srtcp;
22pub mod srtp;
23
/// Largest value the 32-bit SRTP rollover counter (ROC) can take; a guessed
/// ROC of 0 while the stored ROC equals this value is reported as overflow.
const MAX_ROC: u32 = 0xFFFF_FFFF;
/// Half of the 16-bit RTP sequence-number space (2^15). Sequence-number
/// distances larger than this are interpreted as crossing a ROC boundary.
const SEQ_NUM_MEDIAN: u16 = 0x8000;
/// Largest 16-bit RTP sequence number (2^16 - 1).
const SEQ_NUM_MAX: u16 = 0xFFFF;
27
28#[derive(Default)]
30pub(crate) struct SrtpSsrcState {
31 ssrc: u32,
32 index: u64,
33 rollover_has_processed: bool,
34 replay_detector: Option<Box<dyn ReplayDetector>>,
35}
36
37#[derive(Default)]
39pub(crate) struct SrtcpSsrcState {
40 srtcp_index: usize,
41 ssrc: u32,
42 replay_detector: Option<Box<dyn ReplayDetector>>,
43}
44
45impl SrtpSsrcState {
46 pub fn next_rollover_count(&self, sequence_number: u16) -> (u32, i32, bool) {
47 let local_roc = (self.index >> 16) as u32;
48 let local_seq = self.index as u16;
49
50 let mut guess_roc = local_roc;
51
52 let diff = if self.rollover_has_processed {
53 let seq = (sequence_number as i32).wrapping_sub(local_seq as i32);
54 if self.index > SEQ_NUM_MEDIAN as _ {
57 if local_seq < SEQ_NUM_MEDIAN {
58 if seq > SEQ_NUM_MEDIAN as i32 {
59 guess_roc = local_roc.wrapping_sub(1);
60 seq.wrapping_sub(SEQ_NUM_MAX as i32 + 1)
61 } else {
62 seq
63 }
64 } else if local_seq - SEQ_NUM_MEDIAN > sequence_number {
65 guess_roc = local_roc.wrapping_add(1);
66 seq.wrapping_add(SEQ_NUM_MAX as i32 + 1)
67 } else {
68 seq
69 }
70 } else {
71 seq
73 }
74 } else {
75 0i32
76 };
77
78 (guess_roc, diff, (guess_roc == 0 && local_roc == MAX_ROC))
79 }
80
81 pub fn update_rollover_count(&mut self, sequence_number: u16, diff: i32) {
83 if !self.rollover_has_processed {
84 self.index |= sequence_number as u64;
85 self.rollover_has_processed = true;
86 } else {
87 self.index = self.index.wrapping_add(diff as _);
88 }
89 }
90}
91
92pub struct Context {
96 cipher: Box<dyn Cipher>,
97
98 srtp_ssrc_states: HashMap<u32, SrtpSsrcState>,
99 srtcp_ssrc_states: HashMap<u32, SrtcpSsrcState>,
100
101 new_srtp_replay_detector: ContextOption,
102 new_srtcp_replay_detector: ContextOption,
103}
104
105impl Context {
106 pub fn new(
108 master_key: &[u8],
109 master_salt: &[u8],
110 profile: ProtectionProfile,
111 srtp_ctx_opt: Option<ContextOption>,
112 srtcp_ctx_opt: Option<ContextOption>,
113 ) -> Result<Context> {
114 let key_len = profile.key_len();
115 let salt_len = profile.salt_len();
116
117 if master_key.len() != key_len {
118 return Err(Error::SrtpMasterKeyLength(key_len, master_key.len()));
119 } else if master_salt.len() != salt_len {
120 return Err(Error::SrtpSaltLength(salt_len, master_salt.len()));
121 }
122
123 let cipher: Box<dyn Cipher> = match profile {
124 ProtectionProfile::Aes128CmHmacSha1_32
125 | ProtectionProfile::Aes128CmHmacSha1_80
126 | ProtectionProfile::Aes256CmHmacSha1_80
127 | ProtectionProfile::Aes256CmHmacSha1_32 => {
128 Box::new(CipherAesCmHmacSha1::new(profile, master_key, master_salt)?)
129 }
130
131 ProtectionProfile::AeadAes128Gcm => Box::new(CipherAeadAesGcm::<Aes128>::new(
132 profile,
133 master_key,
134 master_salt,
135 )?),
136
137 ProtectionProfile::AeadAes256Gcm => Box::new(CipherAeadAesGcm::<Aes256>::new(
138 profile,
139 master_key,
140 master_salt,
141 )?),
142 };
143
144 let srtp_ctx_opt = if let Some(ctx_opt) = srtp_ctx_opt {
145 ctx_opt
146 } else {
147 srtp_no_replay_protection()
148 };
149
150 let srtcp_ctx_opt = if let Some(ctx_opt) = srtcp_ctx_opt {
151 ctx_opt
152 } else {
153 srtcp_no_replay_protection()
154 };
155
156 Ok(Context {
157 cipher,
158 srtp_ssrc_states: HashMap::new(),
159 srtcp_ssrc_states: HashMap::new(),
160 new_srtp_replay_detector: srtp_ctx_opt,
161 new_srtcp_replay_detector: srtcp_ctx_opt,
162 })
163 }
164
165 fn get_srtp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtpSsrcState {
166 let s = SrtpSsrcState {
167 ssrc,
168 replay_detector: Some((self.new_srtp_replay_detector)()),
169 ..Default::default()
170 };
171
172 self.srtp_ssrc_states.entry(ssrc).or_insert(s)
173 }
174
175 fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtcpSsrcState {
176 let s = SrtcpSsrcState {
177 ssrc,
178 replay_detector: Some((self.new_srtcp_replay_detector)()),
179 ..Default::default()
180 };
181 self.srtcp_ssrc_states.entry(ssrc).or_insert(s)
182 }
183
184 fn get_roc(&self, ssrc: u32) -> Option<u32> {
186 self.srtp_ssrc_states
187 .get(&ssrc)
188 .map(|s| (s.index >> 16) as _)
189 }
190
191 fn set_roc(&mut self, ssrc: u32, roc: u32) {
193 let state = self.get_srtp_ssrc_state(ssrc);
194 state.index = (roc as u64) << 16;
195 state.rollover_has_processed = false;
196 }
197
198 fn get_index(&self, ssrc: u32) -> Option<usize> {
200 self.srtcp_ssrc_states.get(&ssrc).map(|s| s.srtcp_index)
201 }
202
203 fn set_index(&mut self, ssrc: u32, index: usize) {
205 self.get_srtcp_ssrc_state(ssrc).srtcp_index = index % (MAX_SRTCP_INDEX + 1);
206 }
207}