1use crate::cipher_suite::*;
2use crate::conn::{DEFAULT_REPLAY_PROTECTION_WINDOW, INITIAL_TICKER_INTERVAL};
3use crate::crypto::*;
4use crate::extension::extension_use_srtp::SrtpProtectionProfile;
5use crate::signature_hash_algorithm::{
6 SignatureHashAlgorithm, SignatureScheme, parse_signature_schemes,
7};
8use log::warn;
9use shared::error::*;
10use std::collections::HashMap;
11use std::fmt;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::time::Duration;
15
16use rustls::client::danger::ServerCertVerifier;
17use rustls::pki_types::CertificateDer;
18use rustls::server::danger::ClientCertVerifier;
19
/// Builder collecting the user-facing DTLS options that
/// [`ConfigBuilder::build`] validates and turns into a [`HandshakeConfig`].
///
/// Every field starts at the "unset" value produced by `Default` and is
/// overwritten through the matching `with_*` method.
#[derive(Clone)]
pub struct ConfigBuilder {
    /// Local certificates. Validation rejects combining them with a PSK
    /// callback, requires a server without PSK to have at least one, and
    /// only accepts `Ed25519` / `Ecdsa256` private keys.
    certificates: Vec<Certificate>,
    /// Cipher suite ids handed to `parse_cipher_suites` during validation
    /// and build.
    cipher_suites: Vec<CipherSuiteId>,
    /// Signature schemes; converted to `u16` and run through
    /// `parse_signature_schemes` in `build`.
    signature_schemes: Vec<SignatureScheme>,
    /// SRTP protection profiles, copied unchanged into the handshake config.
    srtp_protection_profiles: Vec<SrtpProtectionProfile>,
    /// Client authentication policy, copied unchanged into the handshake
    /// config.
    client_auth: ClientAuthType,
    /// Extended master secret policy, copied unchanged into the handshake
    /// config.
    extended_master_secret: ExtendedMasterSecretType,
    /// Retransmit interval; a zero duration makes `build` fall back to
    /// `INITIAL_TICKER_INTERVAL`.
    flight_interval: Duration,
    /// Optional PSK callback. Mutually exclusive with `certificates`.
    psk: Option<PskCallback>,
    /// Optional PSK identity hint. Only valid together with `psk`; required
    /// on the client side whenever `psk` is set.
    psk_identity_hint: Option<Vec<u8>>,
    /// Copied unchanged into the handshake config.
    /// NOTE(review): presumably disables peer certificate verification —
    /// confirm against the handshake code.
    insecure_skip_verify: bool,
    /// Passed as the second argument to `parse_signature_schemes` in `build`.
    insecure_hashes: bool,
    /// Copied unchanged into the handshake config.
    insecure_verification: bool,
    /// Optional user callback, moved into the handshake config.
    verify_peer_certificate: Option<VerifyPeerCertificateFn>,
    /// Root certificate store, moved into `HandshakeConfig::roots_cas`.
    roots_cas: rustls::RootCertStore,
    /// Client CA store.
    /// NOTE(review): `build` in this file never reads this field — confirm
    /// where (or whether) it is consumed.
    client_cas: rustls::RootCertStore,
    /// Server name. When empty on the client side, `build` substitutes the
    /// remote IP address or, failing that, `"localhost"`.
    server_name: String,
    /// Maximum transmission unit; `0` makes `build` use `DEFAULT_MTU`.
    mtu: usize,
    /// Replay protection window; `0` makes `build` use
    /// `DEFAULT_REPLAY_PROTECTION_WINDOW`.
    replay_protection_window: usize,
}
43
44impl Default for ConfigBuilder {
45 fn default() -> Self {
46 Self {
47 certificates: vec![],
48 cipher_suites: vec![],
49 signature_schemes: vec![],
50 srtp_protection_profiles: vec![],
51 client_auth: ClientAuthType::default(),
52 extended_master_secret: ExtendedMasterSecretType::default(),
53 flight_interval: Duration::default(),
54 psk: None,
55 psk_identity_hint: None,
56 insecure_skip_verify: false,
57 insecure_hashes: false,
58 insecure_verification: false,
59 verify_peer_certificate: None,
60 roots_cas: rustls::RootCertStore::empty(),
61 client_cas: rustls::RootCertStore::empty(),
62 server_name: String::default(),
63 mtu: 0,
64 replay_protection_window: 0,
65 }
66 }
67}
68
69impl ConfigBuilder {
70 pub fn with_certificates(mut self, certificates: Vec<Certificate>) -> Self {
74 self.certificates = certificates;
75 self
76 }
77
78 pub fn with_cipher_suites(mut self, cipher_suites: Vec<CipherSuiteId>) -> Self {
81 self.cipher_suites = cipher_suites;
82 self
83 }
84
85 pub fn with_signature_schemes(mut self, signature_schemes: Vec<SignatureScheme>) -> Self {
87 self.signature_schemes = signature_schemes;
88 self
89 }
90
91 pub fn with_srtp_protection_profiles(
95 mut self,
96 srtp_protection_profiles: Vec<SrtpProtectionProfile>,
97 ) -> Self {
98 self.srtp_protection_profiles = srtp_protection_profiles;
99 self
100 }
101
102 pub fn with_client_auth(mut self, client_auth: ClientAuthType) -> Self {
105 self.client_auth = client_auth;
106 self
107 }
108
109 pub fn with_extended_master_secret(
112 mut self,
113 extended_master_secret: ExtendedMasterSecretType,
114 ) -> Self {
115 self.extended_master_secret = extended_master_secret;
116 self
117 }
118
119 pub fn with_flight_interval(mut self, flight_interval: Duration) -> Self {
122 self.flight_interval = flight_interval;
123 self
124 }
125
126 pub fn with_psk(mut self, psk: Option<PskCallback>) -> Self {
129 self.psk = psk;
130 self
131 }
132
133 pub fn with_psk_identity_hint(mut self, psk_identity_hint: Option<Vec<u8>>) -> Self {
135 self.psk_identity_hint = psk_identity_hint;
136 self
137 }
138
139 pub fn with_insecure_skip_verify(mut self, insecure_skip_verify: bool) -> Self {
146 self.insecure_skip_verify = insecure_skip_verify;
147 self
148 }
149
150 pub fn with_insecure_hashes(mut self, insecure_hashes: bool) -> Self {
153 self.insecure_hashes = insecure_hashes;
154 self
155 }
156
157 pub fn with_insecure_verification(mut self, insecure_verification: bool) -> Self {
160 self.insecure_verification = insecure_verification;
161 self
162 }
163
164 pub fn with_verify_peer_certificate(
176 mut self,
177 verify_peer_certificate: Option<VerifyPeerCertificateFn>,
178 ) -> Self {
179 self.verify_peer_certificate = verify_peer_certificate;
180 self
181 }
182
183 pub fn with_roots_cas(mut self, roots_cas: rustls::RootCertStore) -> Self {
188 self.roots_cas = roots_cas;
189 self
190 }
191
192 pub fn with_client_cas(mut self, client_cas: rustls::RootCertStore) -> Self {
197 self.client_cas = client_cas;
198 self
199 }
200
201 pub fn with_server_name(mut self, server_name: String) -> Self {
204 self.server_name = server_name;
205 self
206 }
207
208 pub fn with_mtu(mut self, mtu: usize) -> Self {
211 self.mtu = mtu;
212 self
213 }
214
215 pub fn with_replay_protection_window(mut self, replay_protection_window: usize) -> Self {
220 self.replay_protection_window = replay_protection_window;
221 self
222 }
223}
224
/// MTU substituted by `ConfigBuilder::build` when the configured `mtu` is `0`;
/// also the value used by `HandshakeConfig::default()`.
pub(crate) const DEFAULT_MTU: usize = 1200;

/// Shared, thread-safe callback that maps a byte slice to a byte vector or an
/// error.
///
/// NOTE(review): presumably receives the PSK identity hint and returns the
/// pre-shared key — confirm against the handshake code.
pub(crate) type PskCallback = Arc<dyn (Fn(&[u8]) -> Result<Vec<u8>>) + Send + Sync>;
230
/// Client authentication policy stored in the configuration.
///
/// The default is [`ClientAuthType::NoClientCert`]. Discriminants are fixed
/// explicitly (0 through 4).
///
/// NOTE(review): the variant names suggest increasing strictness (no
/// certificate, request, require any, verify if given, require and verify) —
/// the enforcement itself is not in this file; confirm against the handshake
/// code.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)]
pub enum ClientAuthType {
    #[default]
    NoClientCert = 0,
    RequestClientCert = 1,
    RequireAnyClientCert = 2,
    VerifyClientCertIfGiven = 3,
    RequireAndVerifyClientCert = 4,
}
242
/// Extended master secret policy stored in the configuration.
///
/// The enum default is [`ExtendedMasterSecretType::Request`]. Note that
/// `HandshakeConfig::default()` uses `Disable` instead; a config produced by
/// `ConfigBuilder::build` always carries the builder's value.
///
/// NOTE(review): presumably `Request` offers the extension, `Require` makes
/// it mandatory and `Disable` turns it off — confirm against the handshake
/// code.
#[derive(Debug, Default, PartialEq, Eq, Copy, Clone)]
pub enum ExtendedMasterSecretType {
    #[default]
    Request = 0,
    Require = 1,
    Disable = 2,
}
252
253impl ConfigBuilder {
254 fn validate(&self, is_client: bool) -> Result<()> {
255 if is_client && self.psk.is_some() && self.psk_identity_hint.is_none() {
256 return Err(Error::ErrPskAndIdentityMustBeSetForClient);
257 }
258
259 if !is_client && self.psk.is_none() && self.certificates.is_empty() {
260 return Err(Error::ErrServerMustHaveCertificate);
261 }
262
263 if !self.certificates.is_empty() && self.psk.is_some() {
264 return Err(Error::ErrPskAndCertificate);
265 }
266
267 if self.psk_identity_hint.is_some() && self.psk.is_none() {
268 return Err(Error::ErrIdentityNoPsk);
269 }
270
271 for cert in &self.certificates {
272 match cert.private_key.kind {
273 CryptoPrivateKeyKind::Ed25519(_) => {}
274 CryptoPrivateKeyKind::Ecdsa256(_) => {}
275 _ => return Err(Error::ErrInvalidPrivateKey),
276 }
277 }
278
279 parse_cipher_suites(&self.cipher_suites, self.psk.is_none(), self.psk.is_some())?;
280
281 Ok(())
282 }
283
284 pub fn build(
286 mut self,
287 is_client: bool,
288 remote_addr: Option<SocketAddr>,
289 ) -> Result<HandshakeConfig> {
290 self.validate(is_client)?;
291
292 let local_cipher_suites: Vec<CipherSuiteId> =
293 parse_cipher_suites(&self.cipher_suites, self.psk.is_none(), self.psk.is_some())?
294 .iter()
295 .map(|cs| cs.id())
296 .collect();
297
298 let sigs: Vec<u16> = self.signature_schemes.iter().map(|x| *x as u16).collect();
299 let local_signature_schemes = parse_signature_schemes(&sigs, self.insecure_hashes)?;
300
301 let retransmit_interval = if self.flight_interval != Duration::from_secs(0) {
302 self.flight_interval
303 } else {
304 INITIAL_TICKER_INTERVAL
305 };
306
307 let maximum_transmission_unit = if self.mtu == 0 { DEFAULT_MTU } else { self.mtu };
308
309 let replay_protection_window = if self.replay_protection_window == 0 {
310 DEFAULT_REPLAY_PROTECTION_WINDOW
311 } else {
312 self.replay_protection_window
313 };
314
315 let mut server_name = self.server_name.clone();
316
317 if is_client && server_name.is_empty() {
319 if let Some(remote_addr) = remote_addr {
320 server_name = remote_addr.ip().to_string();
321 } else {
322 warn!(
323 "conn.remote_addr is empty, please set explicitly server_name in Config! Use default \"localhost\" as server_name now"
324 );
325 "localhost".clone_into(&mut server_name);
326 }
327 }
328
329 Ok(HandshakeConfig {
330 local_psk_callback: self.psk.take(),
331 local_psk_identity_hint: self.psk_identity_hint.take(),
332 local_cipher_suites,
333 local_signature_schemes,
334 extended_master_secret: self.extended_master_secret,
335 local_srtp_protection_profiles: self.srtp_protection_profiles,
336 server_name,
337 client_auth: self.client_auth,
338 local_certificates: self.certificates,
339 insecure_skip_verify: self.insecure_skip_verify,
340 insecure_verification: self.insecure_verification,
341 verify_peer_certificate: self.verify_peer_certificate.take(),
342 roots_cas: self.roots_cas,
343 server_cert_verifier: rustls::client::WebPkiServerVerifier::builder(Arc::new(
344 gen_self_signed_root_cert(),
345 ))
346 .build()
347 .unwrap(),
348 client_cert_verifier: None,
349 retransmit_interval,
350 initial_epoch: 0,
351 maximum_transmission_unit,
352 replay_protection_window,
353 ..Default::default()
354 })
355 }
356}
357
/// Shared, thread-safe callback type for user-defined peer certificate
/// checks. It receives a slice of byte vectors and a slice of
/// `CertificateDer<'static>` values and returns `Ok(())` or an error.
///
/// NOTE(review): presumably the first argument holds the raw peer
/// certificates and the second the parsed/verified chain — confirm against
/// the call site in the handshake code.
pub type VerifyPeerCertificateFn =
    Arc<dyn (Fn(&[Vec<u8>], &[CertificateDer<'static>]) -> Result<()>) + Send + Sync>;
360
361pub fn gen_self_signed_root_cert() -> rustls::RootCertStore {
362 let mut certs = rustls::RootCertStore::empty();
363 certs
364 .add(
365 rcgen::generate_simple_self_signed(vec![])
366 .unwrap()
367 .cert
368 .der()
369 .to_owned(),
370 )
371 .unwrap();
372 certs
373}
374
/// Resolved handshake settings, as produced by `ConfigBuilder::build` or
/// `HandshakeConfig::default()`.
#[derive(Clone)]
pub struct HandshakeConfig {
    /// PSK callback taken from the builder, if any.
    pub(crate) local_psk_callback: Option<PskCallback>,
    /// PSK identity hint taken from the builder, if any.
    pub(crate) local_psk_identity_hint: Option<Vec<u8>>,
    /// Ids of the cipher suites returned by `parse_cipher_suites`.
    pub(crate) local_cipher_suites: Vec<CipherSuiteId>,
    /// Signature/hash algorithms returned by `parse_signature_schemes`.
    pub(crate) local_signature_schemes: Vec<SignatureHashAlgorithm>,
    /// Extended master secret policy.
    pub(crate) extended_master_secret: ExtendedMasterSecretType,
    /// SRTP protection profiles taken from the builder.
    pub(crate) local_srtp_protection_profiles: Vec<SrtpProtectionProfile>,
    /// Server name; for clients built without one this is the remote IP
    /// address or `"localhost"`.
    pub(crate) server_name: String,
    /// Client authentication policy.
    pub(crate) client_auth: ClientAuthType,
    /// Local certificates; the first entry is the fallback returned by
    /// `get_certificate`.
    pub(crate) local_certificates: Vec<Certificate>,
    /// Name-to-certificate lookup used by `get_certificate`. Lookups use the
    /// lower-cased name without trailing dots, and patterns whose leading
    /// labels are replaced by `*`. Not populated by `ConfigBuilder::build`.
    pub(crate) name_to_certificate: HashMap<String, Certificate>,
    /// Copied from the builder.
    /// NOTE(review): presumably disables peer certificate verification —
    /// confirm.
    pub(crate) insecure_skip_verify: bool,
    /// Copied from the builder.
    pub(crate) insecure_verification: bool,
    /// Optional user callback for peer certificate checks.
    pub(crate) verify_peer_certificate: Option<VerifyPeerCertificateFn>,
    /// Root certificate store taken from the builder.
    pub(crate) roots_cas: rustls::RootCertStore,
    /// Server certificate verifier. Both `build` and `default()` create it
    /// from a freshly generated self-signed root.
    pub(crate) server_cert_verifier: Arc<dyn ServerCertVerifier>,
    /// Client certificate verifier; `None` from both `build` and `default()`.
    pub(crate) client_cert_verifier: Option<Arc<dyn ClientCertVerifier>>,
    /// Retransmit interval; `build` never leaves this at zero.
    pub(crate) retransmit_interval: std::time::Duration,
    /// Initial epoch; `0` from both `build` and `default()`.
    pub(crate) initial_epoch: u16,
    /// Maximum transmission unit; `DEFAULT_MTU` unless configured.
    pub(crate) maximum_transmission_unit: usize,
    /// Maximum retransmit number; `7` by default and not configurable through
    /// the builder in this file.
    pub(crate) maximum_retransmit_number: usize,
    /// Replay protection window; `DEFAULT_REPLAY_PROTECTION_WINDOW` unless
    /// configured.
    pub(crate) replay_protection_window: usize,
}
399
400impl fmt::Debug for HandshakeConfig {
401 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
402 fmt.debug_struct("HandshakeConfig<T>")
403 .field("local_psk_identity_hint", &self.local_psk_identity_hint)
404 .field("local_cipher_suites", &self.local_cipher_suites)
405 .field("local_signature_schemes", &self.local_signature_schemes)
406 .field("extended_master_secret", &self.extended_master_secret)
407 .field(
408 "local_srtp_protection_profiles",
409 &self.local_srtp_protection_profiles,
410 )
411 .field("server_name", &self.server_name)
412 .field("client_auth", &self.client_auth)
413 .field("local_certificates", &self.local_certificates)
414 .field("name_to_certificate", &self.name_to_certificate)
415 .field("insecure_skip_verify", &self.insecure_skip_verify)
416 .field("insecure_verification", &self.insecure_verification)
417 .field("roots_cas", &self.roots_cas)
418 .field("retransmit_interval", &self.retransmit_interval)
419 .field("initial_epoch", &self.initial_epoch)
420 .field("maximum_transmission_unit", &self.maximum_transmission_unit)
421 .field("maximum_retransmit_number", &self.maximum_retransmit_number)
422 .field("replay_protection_window", &self.replay_protection_window)
423 .finish()
424 }
425}
426
427impl Default for HandshakeConfig {
428 fn default() -> Self {
429 HandshakeConfig {
430 local_psk_callback: None,
431 local_psk_identity_hint: None,
432 local_cipher_suites: vec![],
433 local_signature_schemes: vec![],
434 extended_master_secret: ExtendedMasterSecretType::Disable,
435 local_srtp_protection_profiles: vec![],
436 server_name: String::new(),
437 client_auth: ClientAuthType::NoClientCert,
438 local_certificates: vec![],
439 name_to_certificate: HashMap::new(),
440 insecure_skip_verify: false,
441 insecure_verification: false,
442 verify_peer_certificate: None,
443 roots_cas: rustls::RootCertStore::empty(),
444 server_cert_verifier: rustls::client::WebPkiServerVerifier::builder(Arc::new(
445 gen_self_signed_root_cert(),
446 ))
447 .build()
448 .unwrap(),
449 client_cert_verifier: None,
450 retransmit_interval: std::time::Duration::from_secs(0),
451 initial_epoch: 0,
452 maximum_transmission_unit: DEFAULT_MTU,
453 maximum_retransmit_number: 7,
454 replay_protection_window: DEFAULT_REPLAY_PROTECTION_WINDOW,
455 }
456 }
457}
458
459impl HandshakeConfig {
460 pub(crate) fn get_certificate(&self, server_name: &str) -> Result<Certificate> {
461 if self.local_certificates.is_empty() {
462 return Err(Error::ErrNoCertificates);
463 }
464
465 if self.local_certificates.len() == 1 {
466 return Ok(self.local_certificates[0].clone());
468 }
469
470 if server_name.is_empty() {
471 return Ok(self.local_certificates[0].clone());
472 }
473
474 let lower = server_name.to_lowercase();
475 let name = lower.trim_end_matches('.');
476
477 if let Some(cert) = self.name_to_certificate.get(name) {
478 return Ok(cert.clone());
479 }
480
481 let mut labels: Vec<&str> = name.split_terminator('.').collect();
484 for i in 0..labels.len() {
485 labels[i] = "*";
486 let candidate = labels.join(".");
487 if let Some(cert) = self.name_to_certificate.get(&candidate) {
488 return Ok(cert.clone());
489 }
490 }
491
492 Ok(self.local_certificates[0].clone())
494 }
495}