rtc_srtp/context/mod.rs

1#[cfg(test)]
2mod context_test;
3#[cfg(test)]
4mod srtcp_test;
5#[cfg(test)]
6mod srtp_test;
7
8use crate::{
9    cipher::cipher_aead_aes_gcm::*, cipher::cipher_aes_cm_hmac_sha1::*, cipher::*, option::*,
10    protection_profile::*,
11};
12use shared::{
13    error::{Error, Result},
14    replay_detector::*,
15};
16
17use std::collections::HashMap;
18
19pub mod srtcp;
20pub mod srtp;
21
/// How many sequence numbers on either side of the 16-bit wrap point are treated as
/// reordering (jitter) rather than as a new position in the sequence space when the
/// SRTP rollover counter is estimated.
const MAX_ROC_DISORDER: u16 = 100;
23
/// Encrypt/Decrypt state for a single SRTP SSRC
#[derive(Default)]
pub(crate) struct SrtpSsrcState {
    /// SSRC this state belongs to.
    ssrc: u32,
    /// Current rollover counter (ROC). Advanced by `update_rollover_count` and
    /// overwritable through `Context::set_roc`.
    rollover_counter: u32,
    /// `false` until `update_rollover_count` has seen its first sequence number;
    /// while `false`, no wrap detection is performed.
    rollover_has_processed: bool,
    /// Sequence number most recently passed to `update_rollover_count`.
    last_sequence_number: u16,
    /// Replay detector built from the context's SRTP factory when this state is
    /// created by `Context`. NOTE(review): presumably consulted by the SRTP
    /// decrypt path in `srtp.rs` — confirm.
    replay_detector: Option<Box<dyn ReplayDetector>>,
}
33
/// Encrypt/Decrypt state for a single SRTCP SSRC
#[derive(Default)]
pub(crate) struct SrtcpSsrcState {
    /// SRTCP index for this SSRC; read and written through `Context::get_index` /
    /// `Context::set_index`. NOTE(review): presumably advanced per packet in
    /// `srtcp.rs` — confirm.
    srtcp_index: usize,
    /// SSRC this state belongs to.
    ssrc: u32,
    /// Replay detector built from the context's SRTCP factory when this state is
    /// created by `Context`.
    replay_detector: Option<Box<dyn ReplayDetector>>,
}
41
42impl SrtpSsrcState {
43    pub fn next_rollover_count(&self, sequence_number: u16) -> u32 {
44        let mut roc = self.rollover_counter;
45
46        if !self.rollover_has_processed {
47        } else if sequence_number == 0 {
48            // We exactly hit the rollover count
49
50            // Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER
51            // otherwise we already incremented for disorder
52            if self.last_sequence_number > MAX_ROC_DISORDER {
53                roc += 1;
54            }
55        } else if self.last_sequence_number < MAX_ROC_DISORDER
56            && sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
57        {
58            // Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max
59            // So we fell behind, drop to account for jitter
60            roc -= 1;
61        } else if sequence_number < MAX_ROC_DISORDER
62            && self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
63        {
64            // our current is within a MAX_ROCDISORDER of 0
65            // and our last sequence number was a high sequence number, increment to account for jitter
66            roc += 1;
67        }
68
69        roc
70    }
71
72    /// https://tools.ietf.org/html/rfc3550#appendix-A.1
73    pub fn update_rollover_count(&mut self, sequence_number: u16) {
74        if !self.rollover_has_processed {
75            self.rollover_has_processed = true;
76        } else if sequence_number == 0 {
77            // We exactly hit the rollover count
78
79            // Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER
80            // otherwise we already incremented for disorder
81            if self.last_sequence_number > MAX_ROC_DISORDER {
82                self.rollover_counter += 1;
83            }
84        } else if self.last_sequence_number < MAX_ROC_DISORDER
85            && sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
86        {
87            // Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max
88            // So we fell behind, drop to account for jitter
89            self.rollover_counter -= 1;
90        } else if sequence_number < MAX_ROC_DISORDER
91            && self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
92        {
93            // our current is within a MAX_ROCDISORDER of 0
94            // and our last sequence number was a high sequence number, increment to account for jitter
95            self.rollover_counter += 1;
96        }
97        self.last_sequence_number = sequence_number;
98    }
99}
100
/// Context represents an SRTP cryptographic context.
///
/// A Context can only be used for one-way operations: it must be used either ONLY
/// for encryption or ONLY for decryption.
pub struct Context {
    /// Cipher selected from the `ProtectionProfile` passed to `Context::new`.
    cipher: Box<dyn Cipher>,

    /// Per-SSRC SRTP state, created on first lookup of an SSRC.
    srtp_ssrc_states: HashMap<u32, SrtpSsrcState>,
    /// Per-SSRC SRTCP state, created on first lookup of an SSRC.
    srtcp_ssrc_states: HashMap<u32, SrtcpSsrcState>,

    /// Factory called to build the replay detector of a new SRTP SSRC state.
    new_srtp_replay_detector: ContextOption,
    /// Factory called to build the replay detector of a new SRTCP SSRC state.
    new_srtcp_replay_detector: ContextOption,
}
113
114impl Context {
115    /// CreateContext creates a new SRTP Context
116    pub fn new(
117        master_key: &[u8],
118        master_salt: &[u8],
119        profile: ProtectionProfile,
120        srtp_ctx_opt: Option<ContextOption>,
121        srtcp_ctx_opt: Option<ContextOption>,
122    ) -> Result<Context> {
123        let key_len = profile.key_len();
124        let salt_len = profile.salt_len();
125
126        if master_key.len() != key_len {
127            return Err(Error::SrtpMasterKeyLength(key_len, master_key.len()));
128        } else if master_salt.len() != salt_len {
129            return Err(Error::SrtpSaltLength(salt_len, master_salt.len()));
130        }
131
132        let cipher: Box<dyn Cipher> = match profile {
133            ProtectionProfile::Aes128CmHmacSha1_80 => {
134                Box::new(CipherAesCmHmacSha1::new(master_key, master_salt)?)
135            }
136
137            ProtectionProfile::AeadAes128Gcm => {
138                Box::new(CipherAeadAesGcm::new(master_key, master_salt)?)
139            }
140        };
141
142        let srtp_ctx_opt = if let Some(ctx_opt) = srtp_ctx_opt {
143            ctx_opt
144        } else {
145            srtp_no_replay_protection()
146        };
147
148        let srtcp_ctx_opt = if let Some(ctx_opt) = srtcp_ctx_opt {
149            ctx_opt
150        } else {
151            srtcp_no_replay_protection()
152        };
153
154        Ok(Context {
155            cipher,
156            srtp_ssrc_states: HashMap::new(),
157            srtcp_ssrc_states: HashMap::new(),
158            new_srtp_replay_detector: srtp_ctx_opt,
159            new_srtcp_replay_detector: srtcp_ctx_opt,
160        })
161    }
162
163    fn get_srtp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtpSsrcState> {
164        let s = SrtpSsrcState {
165            ssrc,
166            replay_detector: Some((self.new_srtp_replay_detector)()),
167            ..Default::default()
168        };
169
170        self.srtp_ssrc_states.entry(ssrc).or_insert(s);
171        self.srtp_ssrc_states.get_mut(&ssrc)
172    }
173
174    fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtcpSsrcState> {
175        let s = SrtcpSsrcState {
176            ssrc,
177            replay_detector: Some((self.new_srtcp_replay_detector)()),
178            ..Default::default()
179        };
180        self.srtcp_ssrc_states.entry(ssrc).or_insert(s);
181        self.srtcp_ssrc_states.get_mut(&ssrc)
182    }
183
184    /// roc returns SRTP rollover counter value of specified SSRC.
185    fn get_roc(&self, ssrc: u32) -> Option<u32> {
186        self.srtp_ssrc_states.get(&ssrc).map(|s| s.rollover_counter)
187    }
188
189    /// set_roc sets SRTP rollover counter value of specified SSRC.
190    fn set_roc(&mut self, ssrc: u32, roc: u32) {
191        if let Some(s) = self.get_srtp_ssrc_state(ssrc) {
192            s.rollover_counter = roc;
193        }
194    }
195
196    /// index returns SRTCP index value of specified SSRC.
197    fn get_index(&self, ssrc: u32) -> Option<usize> {
198        self.srtcp_ssrc_states.get(&ssrc).map(|s| s.srtcp_index)
199    }
200
201    /// set_index sets SRTCP index value of specified SSRC.
202    fn set_index(&mut self, ssrc: u32, index: usize) {
203        if let Some(s) = self.get_srtcp_ssrc_state(ssrc) {
204            s.srtcp_index = index;
205        }
206    }
207}