Skip to main content

rtc_dtls/
state.rs

1use super::cipher_suite::*;
2use super::conn::*;
3use super::curve::named_curve::*;
4use super::extension::extension_use_srtp::SrtpProtectionProfile;
5use super::handshake::handshake_random::*;
6use super::prf::*;
7use rkyv::{Archive, Deserialize, Serialize};
8use shared::crypto::KeyingMaterialExporter;
9use shared::error::*;
10use std::io::{BufWriter, Cursor};
11
12// State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
13pub struct State {
14    pub(crate) local_epoch: u16,
15    pub(crate) remote_epoch: u16,
16    pub(crate) local_sequence_number: Vec<u64>, // uint48
17    pub(crate) local_random: HandshakeRandom,
18    pub(crate) remote_random: HandshakeRandom,
19    pub(crate) master_secret: Vec<u8>,
20    pub(crate) cipher_suite: Option<Box<dyn CipherSuite>>, // nil if a cipher_suite hasn't been chosen
21
22    pub(crate) srtp_protection_profile: SrtpProtectionProfile, // Negotiated srtp_protection_profile
23    pub peer_certificates: Vec<Vec<u8>>,
24    pub identity_hint: Vec<u8>,
25
26    pub(crate) is_client: bool,
27
28    pub(crate) pre_master_secret: Vec<u8>,
29    pub(crate) extended_master_secret: bool,
30
31    pub(crate) named_curve: NamedCurve,
32    pub(crate) local_keypair: Option<NamedCurveKeypair>,
33    pub(crate) cookie: Vec<u8>,
34    pub(crate) handshake_send_sequence: isize,
35    pub(crate) handshake_recv_sequence: isize,
36    pub(crate) server_name: String,
37    pub(crate) remote_requested_certificate: bool, // Did we get a CertificateRequest
38    pub(crate) local_certificates_verify: Vec<u8>, // cache CertificateVerify
39    pub(crate) local_verify_data: Vec<u8>,         // cached VerifyData
40    pub(crate) local_key_signature: Vec<u8>,       // cached keySignature
41    pub(crate) peer_certificates_verified: bool,
42    //pub(crate) replay_detector: Vec<Box<dyn ReplayDetector>>,
43}
44
45#[derive(Archive, Serialize, Deserialize, PartialEq, Debug)]
46struct SerializedState {
47    local_epoch: u16,
48    remote_epoch: u16,
49    local_random: [u8; HANDSHAKE_RANDOM_LENGTH],
50    remote_random: [u8; HANDSHAKE_RANDOM_LENGTH],
51    cipher_suite_id: u16,
52    master_secret: Vec<u8>,
53    sequence_number: u64,
54    srtp_protection_profile: u16,
55    peer_certificates: Vec<Vec<u8>>,
56    identity_hint: Vec<u8>,
57    is_client: bool,
58}
59
60impl Default for State {
61    fn default() -> Self {
62        State {
63            local_epoch: 0,
64            remote_epoch: 0,
65            local_sequence_number: vec![],
66            local_random: HandshakeRandom::default(),
67            remote_random: HandshakeRandom::default(),
68            master_secret: vec![],
69            cipher_suite: None, // nil if a cipher_suite hasn't been chosen
70
71            srtp_protection_profile: SrtpProtectionProfile::Unsupported, // Negotiated srtp_protection_profile
72            peer_certificates: vec![],
73            identity_hint: vec![],
74
75            is_client: false,
76
77            pre_master_secret: vec![],
78            extended_master_secret: false,
79
80            named_curve: NamedCurve::Unsupported,
81            local_keypair: None,
82            cookie: vec![],
83            handshake_send_sequence: 0,
84            handshake_recv_sequence: 0,
85            server_name: "".to_string(),
86            remote_requested_certificate: false, // Did we get a CertificateRequest
87            local_certificates_verify: vec![],   // cache CertificateVerify
88            local_verify_data: vec![],           // cached VerifyData
89            local_key_signature: vec![],         // cached keySignature
90            peer_certificates_verified: false,
91            //replay_detector: vec![],
92        }
93    }
94}
95
96impl State {
97    fn serialize(&self) -> Result<SerializedState> {
98        let mut local_rand = vec![];
99        {
100            let mut writer = BufWriter::<&mut Vec<u8>>::new(local_rand.as_mut());
101            self.local_random.marshal(&mut writer)?;
102        }
103        let mut remote_rand = vec![];
104        {
105            let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_rand.as_mut());
106            self.remote_random.marshal(&mut writer)?;
107        }
108
109        let mut local_random = [0u8; HANDSHAKE_RANDOM_LENGTH];
110        let mut remote_random = [0u8; HANDSHAKE_RANDOM_LENGTH];
111
112        local_random.copy_from_slice(&local_rand);
113        remote_random.copy_from_slice(&remote_rand);
114
115        let local_epoch = self.local_epoch;
116        let remote_epoch = self.remote_epoch;
117        let sequence_number = self.local_sequence_number[local_epoch as usize];
118        let cipher_suite_id = {
119            match &self.cipher_suite {
120                Some(cipher_suite) => cipher_suite.id() as u16,
121                None => return Err(Error::ErrCipherSuiteUnset),
122            }
123        };
124
125        Ok(SerializedState {
126            local_epoch,
127            remote_epoch,
128            local_random,
129            remote_random,
130            cipher_suite_id,
131            master_secret: self.master_secret.clone(),
132            sequence_number,
133            srtp_protection_profile: self.srtp_protection_profile as u16,
134            peer_certificates: self.peer_certificates.clone(),
135            identity_hint: self.identity_hint.clone(),
136            is_client: self.is_client,
137        })
138    }
139
140    fn deserialize(&mut self, serialized: &SerializedState) -> Result<()> {
141        // Set epoch values
142        self.local_epoch = serialized.local_epoch;
143        self.remote_epoch = serialized.remote_epoch;
144        {
145            while self.local_sequence_number.len() <= serialized.local_epoch as usize {
146                self.local_sequence_number.push(0);
147            }
148            self.local_sequence_number[serialized.local_epoch as usize] =
149                serialized.sequence_number;
150        }
151
152        // Set random values
153        let mut reader = Cursor::new(&serialized.local_random);
154        self.local_random = HandshakeRandom::unmarshal(&mut reader)?;
155
156        let mut reader = Cursor::new(&serialized.remote_random);
157        self.remote_random = HandshakeRandom::unmarshal(&mut reader)?;
158
159        self.is_client = serialized.is_client;
160
161        // Set master secret
162        self.master_secret.clone_from(&serialized.master_secret);
163
164        // Set cipher suite
165        self.cipher_suite = Some(cipher_suite_for_id(serialized.cipher_suite_id.into())?);
166
167        self.srtp_protection_profile = serialized.srtp_protection_profile.into();
168
169        // Set remote certificate
170        self.peer_certificates
171            .clone_from(&serialized.peer_certificates);
172        self.identity_hint.clone_from(&serialized.identity_hint);
173
174        Ok(())
175    }
176
177    pub fn init_cipher_suite(&mut self) -> Result<()> {
178        if let Some(cipher_suite) = &mut self.cipher_suite {
179            if cipher_suite.is_initialized() {
180                return Ok(());
181            }
182
183            let mut local_random = vec![];
184            {
185                let mut writer = BufWriter::<&mut Vec<u8>>::new(local_random.as_mut());
186                self.local_random.marshal(&mut writer)?;
187            }
188            let mut remote_random = vec![];
189            {
190                let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_random.as_mut());
191                self.remote_random.marshal(&mut writer)?;
192            }
193
194            if self.is_client {
195                cipher_suite.init(&self.master_secret, &local_random, &remote_random, true)
196            } else {
197                cipher_suite.init(&self.master_secret, &remote_random, &local_random, false)
198            }
199        } else {
200            Err(Error::ErrCipherSuiteUnset)
201        }
202    }
203
204    // marshal_binary is a binary.BinaryMarshaler.marshal_binary implementation
205    pub fn marshal_binary(&self) -> Result<Vec<u8>> {
206        let serialized = self.serialize()?;
207
208        match rkyv::to_bytes::<rkyv::rancor::Error>(&serialized).map(Vec::from) {
209            Ok(enc) => Ok(enc),
210            Err(err) => Err(Error::Other(err.to_string())),
211        }
212    }
213
214    // unmarshal_binary is a binary.BinaryUnmarshaler.unmarshal_binary implementation
215    pub fn unmarshal_binary(&mut self, data: &[u8]) -> Result<()> {
216        let serialized: SerializedState =
217            match rkyv::access::<ArchivedSerializedState, rkyv::rancor::Error>(data)
218                .and_then(rkyv::deserialize)
219            {
220                Ok(dec) => dec,
221                Err(err) => return Err(Error::Other(err.to_string())),
222            };
223        self.deserialize(&serialized)?;
224        self.init_cipher_suite()?;
225
226        Ok(())
227    }
228
229    pub fn srtp_protection_profile(&self) -> SrtpProtectionProfile {
230        self.srtp_protection_profile
231    }
232
233    pub fn is_client(&self) -> bool {
234        self.is_client
235    }
236
237    pub fn cipher_suite(&self) -> Option<&dyn CipherSuite> {
238        self.cipher_suite.as_deref()
239    }
240}
241
242impl KeyingMaterialExporter for State {
243    /// export_keying_material returns length bytes of exported key material in a new
244    /// slice as defined in RFC 5705.
245    /// This allows protocols to use DTLS for key establishment, but
246    /// then use some of the keying material for their own purposes
247    fn export_keying_material(
248        &self,
249        label: &str,
250        context: &[u8],
251        length: usize,
252    ) -> shared::error::Result<Vec<u8>> {
253        if self.local_epoch == 0 {
254            return Err(Error::HandshakeInProgress);
255        } else if !context.is_empty() {
256            return Err(Error::ContextUnsupported);
257        } else if INVALID_KEYING_LABELS.contains(&label) {
258            return Err(Error::ReservedExportKeyingMaterial);
259        }
260
261        let mut local_random = vec![];
262        {
263            let mut writer = BufWriter::<&mut Vec<u8>>::new(local_random.as_mut());
264            self.local_random.marshal(&mut writer)?;
265        }
266        let mut remote_random = vec![];
267        {
268            let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_random.as_mut());
269            self.remote_random.marshal(&mut writer)?;
270        }
271
272        let mut seed = label.as_bytes().to_vec();
273        if self.is_client {
274            seed.extend_from_slice(&local_random);
275            seed.extend_from_slice(&remote_random);
276        } else {
277            seed.extend_from_slice(&remote_random);
278            seed.extend_from_slice(&local_random);
279        }
280
281        if let Some(cipher_suite) = &self.cipher_suite {
282            match prf_p_hash(&self.master_secret, &seed, length, cipher_suite.hash_func()) {
283                Ok(v) => Ok(v),
284                Err(err) => Err(Error::Hash(err.to_string())),
285            }
286        } else {
287            Err(Error::CipherSuiteUnset)
288        }
289    }
290}