1#[cfg(test)]
2mod context_test;
3#[cfg(test)]
4mod srtcp_test;
5#[cfg(test)]
6mod srtp_test;
7
8use crate::{
9 cipher::cipher_aead_aes_gcm::*, cipher::cipher_aes_cm_hmac_sha1::*, cipher::*, option::*,
10 protection_profile::*,
11};
12use shared::{
13 error::{Error, Result},
14 replay_detector::*,
15};
16
17use std::collections::HashMap;
18
19pub mod srtcp;
20pub mod srtp;
21
22const MAX_ROC_DISORDER: u16 = 100;
23
24#[derive(Default)]
26pub(crate) struct SrtpSsrcState {
27 ssrc: u32,
28 rollover_counter: u32,
29 rollover_has_processed: bool,
30 last_sequence_number: u16,
31 replay_detector: Option<Box<dyn ReplayDetector>>,
32}
33
34#[derive(Default)]
36pub(crate) struct SrtcpSsrcState {
37 srtcp_index: usize,
38 ssrc: u32,
39 replay_detector: Option<Box<dyn ReplayDetector>>,
40}
41
42impl SrtpSsrcState {
43 pub fn next_rollover_count(&self, sequence_number: u16) -> u32 {
44 let mut roc = self.rollover_counter;
45
46 if !self.rollover_has_processed {
47 } else if sequence_number == 0 {
48 if self.last_sequence_number > MAX_ROC_DISORDER {
53 roc += 1;
54 }
55 } else if self.last_sequence_number < MAX_ROC_DISORDER
56 && sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
57 {
58 roc -= 1;
61 } else if sequence_number < MAX_ROC_DISORDER
62 && self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
63 {
64 roc += 1;
67 }
68
69 roc
70 }
71
72 pub fn update_rollover_count(&mut self, sequence_number: u16) {
74 if !self.rollover_has_processed {
75 self.rollover_has_processed = true;
76 } else if sequence_number == 0 {
77 if self.last_sequence_number > MAX_ROC_DISORDER {
82 self.rollover_counter += 1;
83 }
84 } else if self.last_sequence_number < MAX_ROC_DISORDER
85 && sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
86 {
87 self.rollover_counter -= 1;
90 } else if sequence_number < MAX_ROC_DISORDER
91 && self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
92 {
93 self.rollover_counter += 1;
96 }
97 self.last_sequence_number = sequence_number;
98 }
99}
100
101pub struct Context {
105 cipher: Box<dyn Cipher>,
106
107 srtp_ssrc_states: HashMap<u32, SrtpSsrcState>,
108 srtcp_ssrc_states: HashMap<u32, SrtcpSsrcState>,
109
110 new_srtp_replay_detector: ContextOption,
111 new_srtcp_replay_detector: ContextOption,
112}
113
114impl Context {
115 pub fn new(
117 master_key: &[u8],
118 master_salt: &[u8],
119 profile: ProtectionProfile,
120 srtp_ctx_opt: Option<ContextOption>,
121 srtcp_ctx_opt: Option<ContextOption>,
122 ) -> Result<Context> {
123 let key_len = profile.key_len();
124 let salt_len = profile.salt_len();
125
126 if master_key.len() != key_len {
127 return Err(Error::SrtpMasterKeyLength(key_len, master_key.len()));
128 } else if master_salt.len() != salt_len {
129 return Err(Error::SrtpSaltLength(salt_len, master_salt.len()));
130 }
131
132 let cipher: Box<dyn Cipher> = match profile {
133 ProtectionProfile::Aes128CmHmacSha1_80 => {
134 Box::new(CipherAesCmHmacSha1::new(master_key, master_salt)?)
135 }
136
137 ProtectionProfile::AeadAes128Gcm => {
138 Box::new(CipherAeadAesGcm::new(master_key, master_salt)?)
139 }
140 };
141
142 let srtp_ctx_opt = if let Some(ctx_opt) = srtp_ctx_opt {
143 ctx_opt
144 } else {
145 srtp_no_replay_protection()
146 };
147
148 let srtcp_ctx_opt = if let Some(ctx_opt) = srtcp_ctx_opt {
149 ctx_opt
150 } else {
151 srtcp_no_replay_protection()
152 };
153
154 Ok(Context {
155 cipher,
156 srtp_ssrc_states: HashMap::new(),
157 srtcp_ssrc_states: HashMap::new(),
158 new_srtp_replay_detector: srtp_ctx_opt,
159 new_srtcp_replay_detector: srtcp_ctx_opt,
160 })
161 }
162
163 fn get_srtp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtpSsrcState> {
164 let s = SrtpSsrcState {
165 ssrc,
166 replay_detector: Some((self.new_srtp_replay_detector)()),
167 ..Default::default()
168 };
169
170 self.srtp_ssrc_states.entry(ssrc).or_insert(s);
171 self.srtp_ssrc_states.get_mut(&ssrc)
172 }
173
174 fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtcpSsrcState> {
175 let s = SrtcpSsrcState {
176 ssrc,
177 replay_detector: Some((self.new_srtcp_replay_detector)()),
178 ..Default::default()
179 };
180 self.srtcp_ssrc_states.entry(ssrc).or_insert(s);
181 self.srtcp_ssrc_states.get_mut(&ssrc)
182 }
183
184 fn get_roc(&self, ssrc: u32) -> Option<u32> {
186 self.srtp_ssrc_states.get(&ssrc).map(|s| s.rollover_counter)
187 }
188
189 fn set_roc(&mut self, ssrc: u32, roc: u32) {
191 if let Some(s) = self.get_srtp_ssrc_state(ssrc) {
192 s.rollover_counter = roc;
193 }
194 }
195
196 fn get_index(&self, ssrc: u32) -> Option<usize> {
198 self.srtcp_ssrc_states.get(&ssrc).map(|s| s.srtcp_index)
199 }
200
201 fn set_index(&mut self, ssrc: u32, index: usize) {
203 if let Some(s) = self.get_srtcp_ssrc_state(ssrc) {
204 s.srtcp_index = index;
205 }
206 }
207}