Skip to main content

rtc_dtls/
config.rs

1use crate::cipher_suite::*;
2use crate::conn::{DEFAULT_REPLAY_PROTECTION_WINDOW, INITIAL_TICKER_INTERVAL};
3use crate::crypto::*;
4use crate::extension::extension_use_srtp::SrtpProtectionProfile;
5use crate::signature_hash_algorithm::{
6    SignatureHashAlgorithm, SignatureScheme, parse_signature_schemes,
7};
8use log::warn;
9use shared::error::*;
10use std::collections::HashMap;
11use std::fmt;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::time::Duration;
15
16use rustls::client::danger::ServerCertVerifier;
17use rustls::pki_types::CertificateDer;
18use rustls::server::danger::ClientCertVerifier;
19
20/// Config is used to configure a DTLS client or server.
21/// After a Config is passed to a DTLS function it must not be modified.
22#[derive(Clone)]
23pub struct ConfigBuilder {
24    certificates: Vec<Certificate>,
25    cipher_suites: Vec<CipherSuiteId>,
26    signature_schemes: Vec<SignatureScheme>,
27    srtp_protection_profiles: Vec<SrtpProtectionProfile>,
28    client_auth: ClientAuthType,
29    extended_master_secret: ExtendedMasterSecretType,
30    flight_interval: Duration,
31    psk: Option<PskCallback>,
32    psk_identity_hint: Option<Vec<u8>>,
33    insecure_skip_verify: bool,
34    insecure_hashes: bool,
35    insecure_verification: bool,
36    verify_peer_certificate: Option<VerifyPeerCertificateFn>,
37    roots_cas: rustls::RootCertStore,
38    client_cas: rustls::RootCertStore,
39    server_name: String,
40    mtu: usize,
41    replay_protection_window: usize,
42}
43
44impl Default for ConfigBuilder {
45    fn default() -> Self {
46        Self {
47            certificates: vec![],
48            cipher_suites: vec![],
49            signature_schemes: vec![],
50            srtp_protection_profiles: vec![],
51            client_auth: ClientAuthType::default(),
52            extended_master_secret: ExtendedMasterSecretType::default(),
53            flight_interval: Duration::default(),
54            psk: None,
55            psk_identity_hint: None,
56            insecure_skip_verify: false,
57            insecure_hashes: false,
58            insecure_verification: false,
59            verify_peer_certificate: None,
60            roots_cas: rustls::RootCertStore::empty(),
61            client_cas: rustls::RootCertStore::empty(),
62            server_name: String::default(),
63            mtu: 0,
64            replay_protection_window: 0,
65        }
66    }
67}
68
69impl ConfigBuilder {
70    /// certificates contains certificate chain to present to the other side of the connection.
71    /// Server MUST set this if psk is non-nil
72    /// client SHOULD sets this so CertificateRequests can be handled if psk is non-nil
73    pub fn with_certificates(mut self, certificates: Vec<Certificate>) -> Self {
74        self.certificates = certificates;
75        self
76    }
77
78    /// cipher_suites is a list of supported cipher suites.
79    /// If cipher_suites is nil, a default list is used
80    pub fn with_cipher_suites(mut self, cipher_suites: Vec<CipherSuiteId>) -> Self {
81        self.cipher_suites = cipher_suites;
82        self
83    }
84
85    /// signature_schemes contains the signature and hash schemes that the peer requests to verify.
86    pub fn with_signature_schemes(mut self, signature_schemes: Vec<SignatureScheme>) -> Self {
87        self.signature_schemes = signature_schemes;
88        self
89    }
90
91    /// srtp_protection_profiles are the supported protection profiles
92    /// Clients will send this via use_srtp and assert that the server properly responds
93    /// Servers will assert that clients send one of these profiles and will respond as needed
94    pub fn with_srtp_protection_profiles(
95        mut self,
96        srtp_protection_profiles: Vec<SrtpProtectionProfile>,
97    ) -> Self {
98        self.srtp_protection_profiles = srtp_protection_profiles;
99        self
100    }
101
102    /// client_auth determines the server's policy for
103    /// TLS Client Authentication. The default is NoClientCert.
104    pub fn with_client_auth(mut self, client_auth: ClientAuthType) -> Self {
105        self.client_auth = client_auth;
106        self
107    }
108
109    /// extended_master_secret determines if the "Extended Master Secret" extension
110    /// should be disabled, requested, or required (default requested).
111    pub fn with_extended_master_secret(
112        mut self,
113        extended_master_secret: ExtendedMasterSecretType,
114    ) -> Self {
115        self.extended_master_secret = extended_master_secret;
116        self
117    }
118
119    /// flight_interval controls how often we send outbound handshake messages
120    /// defaults to time.Second
121    pub fn with_flight_interval(mut self, flight_interval: Duration) -> Self {
122        self.flight_interval = flight_interval;
123        self
124    }
125
126    /// psk sets the pre-shared key used by this DTLS connection
127    /// If psk is non-nil only psk cipher_suites will be used
128    pub fn with_psk(mut self, psk: Option<PskCallback>) -> Self {
129        self.psk = psk;
130        self
131    }
132
133    /// psk_identity_hint sets the pre-shared key hint
134    pub fn with_psk_identity_hint(mut self, psk_identity_hint: Option<Vec<u8>>) -> Self {
135        self.psk_identity_hint = psk_identity_hint;
136        self
137    }
138
139    /// insecure_skip_verify controls whether a client verifies the
140    /// server's certificate chain and host name.
141    /// If insecure_skip_verify is true, TLS accepts any certificate
142    /// presented by the server and any host name in that certificate.
143    /// In this mode, TLS is susceptible to man-in-the-middle attacks.
144    /// This should be used only for testing.
145    pub fn with_insecure_skip_verify(mut self, insecure_skip_verify: bool) -> Self {
146        self.insecure_skip_verify = insecure_skip_verify;
147        self
148    }
149
150    /// insecure_hashes allows the use of hashing algorithms that are known
151    /// to be vulnerable.
152    pub fn with_insecure_hashes(mut self, insecure_hashes: bool) -> Self {
153        self.insecure_hashes = insecure_hashes;
154        self
155    }
156
157    /// insecure_verification allows the use of verification algorithms that are
158    /// known to be vulnerable or deprecated
159    pub fn with_insecure_verification(mut self, insecure_verification: bool) -> Self {
160        self.insecure_verification = insecure_verification;
161        self
162    }
163
164    /// VerifyPeerCertificate, if not nil, is called after normal
165    /// certificate verification by either a client or server. It
166    /// receives the certificate provided by the peer and also a flag
167    /// that tells if normal verification has succeeded. If it returns a
168    /// non-nil error, the handshake is aborted and that error results.
169    ///
170    /// If normal verification fails then the handshake will abort before
171    /// considering this callback. If normal verification is disabled by
172    /// setting insecure_skip_verify, or (for a server) when client_auth is
173    /// RequestClientCert or RequireAnyClientCert, then this callback will
174    /// be considered but the verifiedChains will always be nil.
175    pub fn with_verify_peer_certificate(
176        mut self,
177        verify_peer_certificate: Option<VerifyPeerCertificateFn>,
178    ) -> Self {
179        self.verify_peer_certificate = verify_peer_certificate;
180        self
181    }
182
183    /// roots_cas defines the set of root certificate authorities
184    /// that one peer uses when verifying the other peer's certificates.
185    /// If RootCAs is nil, TLS uses the host's root CA set.
186    /// Used by Client to verify server's certificate
187    pub fn with_roots_cas(mut self, roots_cas: rustls::RootCertStore) -> Self {
188        self.roots_cas = roots_cas;
189        self
190    }
191
192    /// client_cas defines the set of root certificate authorities
193    /// that servers use if required to verify a client certificate
194    /// by the policy in client_auth.
195    /// Used by Server to verify client's certificate
196    pub fn with_client_cas(mut self, client_cas: rustls::RootCertStore) -> Self {
197        self.client_cas = client_cas;
198        self
199    }
200
201    /// server_name is used to verify the hostname on the returned
202    /// certificates unless insecure_skip_verify is given.
203    pub fn with_server_name(mut self, server_name: String) -> Self {
204        self.server_name = server_name;
205        self
206    }
207
208    /// mtu is the length at which handshake messages will be fragmented to
209    /// fit within the maximum transmission unit (default is 1200 bytes)
210    pub fn with_mtu(mut self, mtu: usize) -> Self {
211        self.mtu = mtu;
212        self
213    }
214
215    /// replay_protection_window is the size of the replay attack protection window.
216    /// Duplication of the sequence number is checked in this window size.
217    /// Packet with sequence number older than this value compared to the latest
218    /// accepted packet will be discarded. (default is 64)
219    pub fn with_replay_protection_window(mut self, replay_protection_window: usize) -> Self {
220        self.replay_protection_window = replay_protection_window;
221        self
222    }
223}
224
/// Default maximum transmission unit in bytes, used by `ConfigBuilder::build`
/// when `mtu` is left at zero.
pub(crate) const DEFAULT_MTU: usize = 1_200;
226
227/// PSKCallback is called once we have the remote's psk_identity_hint.
228/// If the remote provided none it will be nil
229pub(crate) type PskCallback = Arc<dyn (Fn(&[u8]) -> Result<Vec<u8>>) + Send + Sync>;
230
/// Policy a server follows for TLS client authentication.
///
/// The policy decides whether the server asks for a client certificate and
/// whether it verifies the certificate it receives. Discriminants are explicit,
/// so each variant has a fixed numeric value when cast.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ClientAuthType {
    /// Do not ask for a client certificate (the default).
    #[default]
    NoClientCert = 0,
    /// Ask for a client certificate; normal verification is not performed.
    RequestClientCert = 1,
    /// Require some client certificate; normal verification is not performed.
    RequireAnyClientCert = 2,
    /// Verify the client certificate if the client sends one.
    VerifyClientCertIfGiven = 3,
    /// Require a client certificate and verify it.
    RequireAndVerifyClientCert = 4,
}
242
/// Policy the client and server follow for the Extended Master Secret
/// extension.
///
/// `ConfigBuilder` defaults to [`ExtendedMasterSecretType::Request`].
#[derive(Debug, Default, PartialEq, Eq, Copy, Clone)]
pub enum ExtendedMasterSecretType {
    /// Request the extension (the default).
    #[default]
    Request = 0,
    /// Require the extension.
    Require = 1,
    /// Disable the extension.
    Disable = 2,
}
252
253impl ConfigBuilder {
254    fn validate(&self, is_client: bool) -> Result<()> {
255        if is_client && self.psk.is_some() && self.psk_identity_hint.is_none() {
256            return Err(Error::ErrPskAndIdentityMustBeSetForClient);
257        }
258
259        if !is_client && self.psk.is_none() && self.certificates.is_empty() {
260            return Err(Error::ErrServerMustHaveCertificate);
261        }
262
263        if !self.certificates.is_empty() && self.psk.is_some() {
264            return Err(Error::ErrPskAndCertificate);
265        }
266
267        if self.psk_identity_hint.is_some() && self.psk.is_none() {
268            return Err(Error::ErrIdentityNoPsk);
269        }
270
271        for cert in &self.certificates {
272            match cert.private_key.kind {
273                CryptoPrivateKeyKind::Ed25519(_) => {}
274                CryptoPrivateKeyKind::Ecdsa256(_) => {}
275                _ => return Err(Error::ErrInvalidPrivateKey),
276            }
277        }
278
279        parse_cipher_suites(&self.cipher_suites, self.psk.is_none(), self.psk.is_some())?;
280
281        Ok(())
282    }
283
284    /// build handshake config
285    pub fn build(
286        mut self,
287        is_client: bool,
288        remote_addr: Option<SocketAddr>,
289    ) -> Result<HandshakeConfig> {
290        self.validate(is_client)?;
291
292        let local_cipher_suites: Vec<CipherSuiteId> =
293            parse_cipher_suites(&self.cipher_suites, self.psk.is_none(), self.psk.is_some())?
294                .iter()
295                .map(|cs| cs.id())
296                .collect();
297
298        let sigs: Vec<u16> = self.signature_schemes.iter().map(|x| *x as u16).collect();
299        let local_signature_schemes = parse_signature_schemes(&sigs, self.insecure_hashes)?;
300
301        let retransmit_interval = if self.flight_interval != Duration::from_secs(0) {
302            self.flight_interval
303        } else {
304            INITIAL_TICKER_INTERVAL
305        };
306
307        let maximum_transmission_unit = if self.mtu == 0 { DEFAULT_MTU } else { self.mtu };
308
309        let replay_protection_window = if self.replay_protection_window == 0 {
310            DEFAULT_REPLAY_PROTECTION_WINDOW
311        } else {
312            self.replay_protection_window
313        };
314
315        let mut server_name = self.server_name.clone();
316
317        // Use host from conn address when server_name is not provided
318        if is_client && server_name.is_empty() {
319            if let Some(remote_addr) = remote_addr {
320                server_name = remote_addr.ip().to_string();
321            } else {
322                warn!(
323                    "conn.remote_addr is empty, please set explicitly server_name in Config! Use default \"localhost\" as server_name now"
324                );
325                "localhost".clone_into(&mut server_name);
326            }
327        }
328
329        Ok(HandshakeConfig {
330            local_psk_callback: self.psk.take(),
331            local_psk_identity_hint: self.psk_identity_hint.take(),
332            local_cipher_suites,
333            local_signature_schemes,
334            extended_master_secret: self.extended_master_secret,
335            local_srtp_protection_profiles: self.srtp_protection_profiles,
336            server_name,
337            client_auth: self.client_auth,
338            local_certificates: self.certificates,
339            insecure_skip_verify: self.insecure_skip_verify,
340            insecure_verification: self.insecure_verification,
341            verify_peer_certificate: self.verify_peer_certificate.take(),
342            roots_cas: self.roots_cas,
343            server_cert_verifier: rustls::client::WebPkiServerVerifier::builder(Arc::new(
344                gen_self_signed_root_cert(),
345            ))
346            .build()
347            .unwrap(),
348            client_cert_verifier: None,
349            retransmit_interval,
350            initial_epoch: 0,
351            maximum_transmission_unit,
352            replay_protection_window,
353            ..Default::default()
354        })
355    }
356}
357
358pub type VerifyPeerCertificateFn =
359    Arc<dyn (Fn(&[Vec<u8>], &[CertificateDer<'static>]) -> Result<()>) + Send + Sync>;
360
361pub fn gen_self_signed_root_cert() -> rustls::RootCertStore {
362    let mut certs = rustls::RootCertStore::empty();
363    certs
364        .add(
365            rcgen::generate_simple_self_signed(vec![])
366                .unwrap()
367                .cert
368                .der()
369                .to_owned(),
370        )
371        .unwrap();
372    certs
373}
374
375#[derive(Clone)]
376pub struct HandshakeConfig {
377    pub(crate) local_psk_callback: Option<PskCallback>,
378    pub(crate) local_psk_identity_hint: Option<Vec<u8>>,
379    pub(crate) local_cipher_suites: Vec<CipherSuiteId>, // Available CipherSuites
380    pub(crate) local_signature_schemes: Vec<SignatureHashAlgorithm>, // Available signature schemes
381    pub(crate) extended_master_secret: ExtendedMasterSecretType, // Policy for the Extended Master Support extension
382    pub(crate) local_srtp_protection_profiles: Vec<SrtpProtectionProfile>, // Available SRTPProtectionProfiles, if empty no SRTP support
383    pub(crate) server_name: String,
384    pub(crate) client_auth: ClientAuthType, // If we are a client should we request a client certificate
385    pub(crate) local_certificates: Vec<Certificate>,
386    pub(crate) name_to_certificate: HashMap<String, Certificate>,
387    pub(crate) insecure_skip_verify: bool,
388    pub(crate) insecure_verification: bool,
389    pub(crate) verify_peer_certificate: Option<VerifyPeerCertificateFn>,
390    pub(crate) roots_cas: rustls::RootCertStore,
391    pub(crate) server_cert_verifier: Arc<dyn ServerCertVerifier>,
392    pub(crate) client_cert_verifier: Option<Arc<dyn ClientCertVerifier>>,
393    pub(crate) retransmit_interval: std::time::Duration,
394    pub(crate) initial_epoch: u16,
395    pub(crate) maximum_transmission_unit: usize,
396    pub(crate) maximum_retransmit_number: usize,
397    pub(crate) replay_protection_window: usize,
398}
399
400impl fmt::Debug for HandshakeConfig {
401    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
402        fmt.debug_struct("HandshakeConfig<T>")
403            .field("local_psk_identity_hint", &self.local_psk_identity_hint)
404            .field("local_cipher_suites", &self.local_cipher_suites)
405            .field("local_signature_schemes", &self.local_signature_schemes)
406            .field("extended_master_secret", &self.extended_master_secret)
407            .field(
408                "local_srtp_protection_profiles",
409                &self.local_srtp_protection_profiles,
410            )
411            .field("server_name", &self.server_name)
412            .field("client_auth", &self.client_auth)
413            .field("local_certificates", &self.local_certificates)
414            .field("name_to_certificate", &self.name_to_certificate)
415            .field("insecure_skip_verify", &self.insecure_skip_verify)
416            .field("insecure_verification", &self.insecure_verification)
417            .field("roots_cas", &self.roots_cas)
418            .field("retransmit_interval", &self.retransmit_interval)
419            .field("initial_epoch", &self.initial_epoch)
420            .field("maximum_transmission_unit", &self.maximum_transmission_unit)
421            .field("maximum_retransmit_number", &self.maximum_retransmit_number)
422            .field("replay_protection_window", &self.replay_protection_window)
423            .finish()
424    }
425}
426
427impl Default for HandshakeConfig {
428    fn default() -> Self {
429        HandshakeConfig {
430            local_psk_callback: None,
431            local_psk_identity_hint: None,
432            local_cipher_suites: vec![],
433            local_signature_schemes: vec![],
434            extended_master_secret: ExtendedMasterSecretType::Disable,
435            local_srtp_protection_profiles: vec![],
436            server_name: String::new(),
437            client_auth: ClientAuthType::NoClientCert,
438            local_certificates: vec![],
439            name_to_certificate: HashMap::new(),
440            insecure_skip_verify: false,
441            insecure_verification: false,
442            verify_peer_certificate: None,
443            roots_cas: rustls::RootCertStore::empty(),
444            server_cert_verifier: rustls::client::WebPkiServerVerifier::builder(Arc::new(
445                gen_self_signed_root_cert(),
446            ))
447            .build()
448            .unwrap(),
449            client_cert_verifier: None,
450            retransmit_interval: std::time::Duration::from_secs(0),
451            initial_epoch: 0,
452            maximum_transmission_unit: DEFAULT_MTU,
453            maximum_retransmit_number: 7,
454            replay_protection_window: DEFAULT_REPLAY_PROTECTION_WINDOW,
455        }
456    }
457}
458
459impl HandshakeConfig {
460    pub(crate) fn get_certificate(&self, server_name: &str) -> Result<Certificate> {
461        if self.local_certificates.is_empty() {
462            return Err(Error::ErrNoCertificates);
463        }
464
465        if self.local_certificates.len() == 1 {
466            // There's only one choice, so no point doing any work.
467            return Ok(self.local_certificates[0].clone());
468        }
469
470        if server_name.is_empty() {
471            return Ok(self.local_certificates[0].clone());
472        }
473
474        let lower = server_name.to_lowercase();
475        let name = lower.trim_end_matches('.');
476
477        if let Some(cert) = self.name_to_certificate.get(name) {
478            return Ok(cert.clone());
479        }
480
481        // try replacing labels in the name with wildcards until we get a
482        // match.
483        let mut labels: Vec<&str> = name.split_terminator('.').collect();
484        for i in 0..labels.len() {
485            labels[i] = "*";
486            let candidate = labels.join(".");
487            if let Some(cert) = self.name_to_certificate.get(&candidate) {
488                return Ok(cert.clone());
489            }
490        }
491
492        // If nothing matches, return the first certificate.
493        Ok(self.local_certificates[0].clone())
494    }
495}