1#![allow(clippy::cast_possible_truncation)]
42#![allow(clippy::result_large_err)]
44#![allow(clippy::needless_pass_by_value)]
46
47use crate::config::{SslMode, TlsConfig};
48use crate::protocol::{PacketWriter, capabilities};
49use sqlmodel_core::Error;
50use sqlmodel_core::error::{ConnectionError, ConnectionErrorKind};
51
52#[cfg(feature = "tls")]
53use std::io::{Read, Write};
54#[cfg(feature = "tls")]
55use std::sync::Arc;
56
57pub fn build_ssl_request_packet(
72 client_caps: u32,
73 max_packet_size: u32,
74 character_set: u8,
75 sequence_id: u8,
76) -> Vec<u8> {
77 let mut writer = PacketWriter::with_capacity(32);
78
79 let caps_with_ssl = client_caps | capabilities::CLIENT_SSL;
81 writer.write_u32_le(caps_with_ssl);
82
83 writer.write_u32_le(max_packet_size);
85
86 writer.write_u8(character_set);
88
89 writer.write_zeros(23);
91
92 writer.build_packet(sequence_id)
93}
94
95pub const fn server_supports_ssl(server_caps: u32) -> bool {
105 server_caps & capabilities::CLIENT_SSL != 0
106}
107
108pub fn validate_ssl_mode(ssl_mode: SslMode, server_caps: u32) -> Result<bool, Error> {
116 let server_supports = server_supports_ssl(server_caps);
117
118 match ssl_mode {
119 SslMode::Disable => Ok(false),
120 SslMode::Preferred => Ok(server_supports),
121 SslMode::Required | SslMode::VerifyCa | SslMode::VerifyIdentity => {
122 if server_supports {
123 Ok(true)
124 } else {
125 Err(tls_error("SSL required but server does not support it"))
126 }
127 }
128 }
129}
130
131pub fn validate_tls_config(ssl_mode: SslMode, tls_config: &TlsConfig) -> Result<(), Error> {
142 match ssl_mode {
143 SslMode::Disable | SslMode::Preferred | SslMode::Required => {
144 Ok(())
146 }
147 SslMode::VerifyCa | SslMode::VerifyIdentity => {
148 if tls_config.ca_cert_path.is_none() && !tls_config.danger_skip_verify {
150 return Err(tls_error(
151 "CA certificate required for VerifyCa/VerifyIdentity mode. \
152 Set ca_cert_path or danger_skip_verify.",
153 ));
154 }
155
156 if tls_config.client_cert_path.is_some() && tls_config.client_key_path.is_none() {
158 return Err(tls_error(
159 "Client certificate provided without client key. \
160 Both must be set for mutual TLS.",
161 ));
162 }
163
164 Ok(())
165 }
166 }
167}
168
169fn tls_error(message: impl Into<String>) -> Error {
171 Error::Connection(ConnectionError {
172 kind: ConnectionErrorKind::Ssl,
173 message: message.into(),
174 source: None,
175 })
176}
177
#[cfg(feature = "tls")]
/// A client-side TLS session wrapped around the byte stream `S`.
///
/// `TlsStream::new` performs the handshake. Afterwards the `Read`/`Write`
/// impls pass plaintext through `conn`, which encrypts and decrypts the TLS
/// records exchanged over `stream`.
pub struct TlsStream<S: Read + Write> {
    /// rustls client connection state (handshake progress, record layer).
    conn: rustls::ClientConnection,
    /// Underlying transport the TLS records are read from and written to.
    stream: S,
}
202
#[cfg(feature = "tls")]
impl<S: Read + Write> std::fmt::Debug for TlsStream<S> {
    /// Shows the negotiated protocol version and handshake state. The
    /// transport itself is omitted, since `S` is not required to be `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TlsStream");
        out.field("protocol_version", &self.conn.protocol_version());
        out.field("is_handshaking", &self.conn.is_handshaking());
        out.finish_non_exhaustive()
    }
}
212
#[cfg(feature = "tls")]
impl<S: Read + Write> TlsStream<S> {
    /// Performs a TLS client handshake over `stream` and returns the
    /// established session.
    ///
    /// The name used for SNI and certificate checks is
    /// `tls_config.server_name` when set, otherwise `server_name`. The rustls
    /// configuration comes from `build_client_config(tls_config, ssl_mode)`.
    ///
    /// # Errors
    /// Returns an SSL connection error if the configuration cannot be built,
    /// the server name is invalid, an I/O error occurs, the peer closes the
    /// connection before the handshake completes, the handshake stalls, or
    /// rustls rejects the handshake.
    pub fn new(
        mut stream: S,
        tls_config: &TlsConfig,
        server_name: &str,
        ssl_mode: SslMode,
    ) -> Result<Self, Error> {
        let config = build_client_config(tls_config, ssl_mode)?;

        let sni_name = tls_config.server_name.as_deref().unwrap_or(server_name);

        let server_name = sni_name
            .to_string()
            .try_into()
            .map_err(|e| tls_error(format!("Invalid server name '{}': {}", sni_name, e)))?;

        let mut conn = rustls::ClientConnection::new(Arc::new(config), server_name)
            .map_err(|e| tls_error(format!("Failed to create TLS connection: {}", e)))?;

        while conn.is_handshaking() {
            Self::send_handshake_records(&mut conn, &mut stream)?;

            // Nothing left to send and nothing wanted: looping again would
            // spin forever without making progress.
            if !conn.wants_read() {
                return Err(tls_error("TLS handshake stalled before completion"));
            }

            let received = conn
                .read_tls(&mut stream)
                .map_err(|e| tls_error(format!("TLS handshake read error: {}", e)))?;
            if received == 0 {
                // EOF: the peer went away mid-handshake. Without this check
                // the loop would keep reading 0 bytes forever.
                return Err(tls_error(
                    "TLS handshake read error: connection closed by server",
                ));
            }

            if let Err(e) = conn.process_new_packets() {
                // Best effort: send any alert rustls queued so the server
                // learns why we are aborting. The original error wins.
                let _ = conn.write_tls(&mut stream);
                let _ = stream.flush();
                return Err(tls_error(format!("TLS handshake error: {}", e)));
            }
        }

        // Records from the final flight (e.g. the TLS 1.3 client Finished)
        // may still be queued once `is_handshaking()` turns false.
        Self::send_handshake_records(&mut conn, &mut stream)?;

        Ok(TlsStream { conn, stream })
    }

    /// Writes every TLS record rustls has queued, then flushes `stream` so a
    /// buffered transport cannot hold them back while we block on a read.
    fn send_handshake_records(
        conn: &mut rustls::ClientConnection,
        stream: &mut S,
    ) -> Result<(), Error> {
        while conn.wants_write() {
            let written = conn
                .write_tls(stream)
                .map_err(|e| tls_error(format!("TLS handshake write error: {}", e)))?;
            if written == 0 {
                // The transport accepts no bytes; retrying would loop forever.
                return Err(tls_error(
                    "TLS handshake write error: transport accepted no data",
                ));
            }
        }
        stream
            .flush()
            .map_err(|e| tls_error(format!("TLS handshake write error: {}", e)))
    }

    /// Negotiated TLS protocol version, or `None` before negotiation.
    pub fn protocol_version(&self) -> Option<rustls::ProtocolVersion> {
        self.conn.protocol_version()
    }

    /// Negotiated cipher suite, or `None` before negotiation.
    pub fn negotiated_cipher_suite(&self) -> Option<rustls::SupportedCipherSuite> {
        self.conn.negotiated_cipher_suite()
    }

    /// Returns `true` when the session negotiated TLS 1.3.
    pub fn is_tls13(&self) -> bool {
        self.conn.protocol_version() == Some(rustls::ProtocolVersion::TLSv1_3)
    }
}
286
#[cfg(feature = "tls")]
impl<S: Read + Write> Read for TlsStream<S> {
    /// Reads decrypted application data into `buf`.
    ///
    /// Plaintext rustls has already decrypted is returned first. Otherwise,
    /// TLS records are pulled from the underlying stream and processed until
    /// plaintext becomes available. Returns `Ok(0)` in three cases: `buf` is
    /// empty, the underlying stream reports EOF, or rustls wants no more input.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // An empty buffer can never receive data; don't block on the socket
        // for it (std readers return Ok(0) immediately here too).
        if buf.is_empty() {
            return Ok(0);
        }

        loop {
            match self.conn.reader().read(buf) {
                Ok(0) => {}
                Ok(n) => return Ok(n),
                // No plaintext buffered yet: fall through and read more records.
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
                Err(e) => return Err(e),
            }

            if !self.conn.wants_read() {
                return Ok(0);
            }

            if self.conn.read_tls(&mut self.stream)? == 0 {
                // Underlying stream reached EOF.
                return Ok(0);
            }

            self.conn
                .process_new_packets()
                .map_err(|e| std::io::Error::other(format!("TLS error: {}", e)))?;
        }
    }
}
317
#[cfg(feature = "tls")]
impl<S: Read + Write> Write for TlsStream<S> {
    /// Hands `buf` to rustls for encryption and sends every resulting TLS
    /// record to the underlying stream. Returns how many plaintext bytes
    /// rustls accepted.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = self.conn.writer().write(buf)?;
        write_pending_tls(&mut self.conn, &mut self.stream)?;
        Ok(accepted)
    }

    /// Flushes rustls' plaintext writer, sends all queued TLS records, then
    /// flushes the underlying stream.
    fn flush(&mut self) -> std::io::Result<()> {
        self.conn.writer().flush()?;
        write_pending_tls(&mut self.conn, &mut self.stream)?;
        self.stream.flush()
    }
}

/// Writes every TLS record queued in `conn` to `stream`.
///
/// # Errors
/// Propagates I/O errors from `stream`. Returns `ErrorKind::WriteZero` if
/// `stream` accepts no bytes, which would otherwise loop forever (mirrors
/// `Write::write_all`).
#[cfg(feature = "tls")]
fn write_pending_tls<W: Write>(
    conn: &mut rustls::ClientConnection,
    stream: &mut W,
) -> std::io::Result<()> {
    while conn.wants_write() {
        if conn.write_tls(stream)? == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::WriteZero,
                "failed to write TLS record to underlying stream",
            ));
        }
    }
    Ok(())
}
340
#[cfg(feature = "tls")]
/// Builds the rustls client configuration used by `TlsStream::new`.
///
/// For every mode except `Disable`, the server-certificate check is chosen
/// from `tls_config` in this order:
/// 1. `danger_skip_verify` set: accept any certificate;
/// 2. `ca_cert_path` set: trust only the CA certificates in that PEM file;
/// 3. otherwise: trust the roots bundled by the `webpki_roots` crate.
///
/// In every case a configured client certificate/key pair is presented.
///
/// NOTE(review): `VerifyCa` and `VerifyIdentity` produce identical
/// configurations here. Any server-name check performed by rustls' verifier
/// therefore applies to both — confirm this is intended.
///
/// # Errors
/// Returns an SSL connection error for `SslMode::Disable`, or when building
/// the selected configuration fails.
pub(crate) fn build_client_config(
    tls_config: &TlsConfig,
    ssl_mode: SslMode,
) -> Result<rustls::ClientConfig, Error> {
    match ssl_mode {
        SslMode::Disable => {
            return Err(tls_error("TlsStream created with SslMode::Disable"));
        }
        // A configured `ca_cert_path` is honoured in Preferred/Required mode
        // as well. Otherwise a private CA set for these modes is silently
        // ignored and only the public webpki roots are trusted.
        SslMode::Preferred | SslMode::Required | SslMode::VerifyCa | SslMode::VerifyIdentity => {}
    }

    let provider = Arc::new(rustls::crypto::ring::default_provider());

    if tls_config.danger_skip_verify {
        build_no_verify_config(&provider, tls_config)
    } else if let Some(ca_path) = &tls_config.ca_cert_path {
        build_custom_ca_config(&provider, tls_config, ca_path)
    } else {
        build_webpki_config(&provider, tls_config)
    }
}

#[cfg(feature = "tls")]
/// Builds a configuration that accepts ANY server certificate and signature.
///
/// DANGER: this disables server authentication, so the connection is open to
/// man-in-the-middle attacks. A configured client certificate is still
/// presented, because skipping server verification is not a request to
/// drop mutual TLS.
fn build_no_verify_config(
    provider: &Arc<rustls::crypto::CryptoProvider>,
    tls_config: &TlsConfig,
) -> Result<rustls::ClientConfig, Error> {
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
    use rustls::{DigitallySignedStruct, Error as RustlsError, SignatureScheme};

    /// Verifier that unconditionally accepts certificates and signatures.
    #[derive(Debug)]
    struct NoVerifier;

    impl ServerCertVerifier for NoVerifier {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, RustlsError> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, RustlsError> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, RustlsError> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            vec![
                SignatureScheme::RSA_PKCS1_SHA256,
                SignatureScheme::RSA_PKCS1_SHA384,
                SignatureScheme::RSA_PKCS1_SHA512,
                SignatureScheme::ECDSA_NISTP256_SHA256,
                SignatureScheme::ECDSA_NISTP384_SHA384,
                SignatureScheme::ECDSA_NISTP521_SHA512,
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA512,
                SignatureScheme::ED25519,
            ]
        }
    }

    let builder = rustls::ClientConfig::builder_with_provider(provider.clone())
        .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13])
        .map_err(|e| tls_error(format!("Failed to set TLS versions: {}", e)))?
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(NoVerifier));

    // Present the client certificate exactly as the verifying paths do.
    add_client_auth(builder, tls_config)
}
450
#[cfg(feature = "tls")]
/// Builds a configuration that verifies the server against the root
/// certificates bundled by the `webpki_roots` crate, and presents a client
/// certificate if one is configured.
///
/// # Errors
/// Returns an SSL connection error if the TLS versions cannot be set or
/// client authentication cannot be configured.
fn build_webpki_config(
    provider: &Arc<rustls::crypto::CryptoProvider>,
    tls_config: &TlsConfig,
) -> Result<rustls::ClientConfig, Error> {
    let trusted_roots = {
        let mut store = rustls::RootCertStore::empty();
        store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        store
    };

    let versions = [&rustls::version::TLS12, &rustls::version::TLS13];
    let builder = rustls::ClientConfig::builder_with_provider(Arc::clone(provider))
        .with_protocol_versions(&versions)
        .map_err(|e| tls_error(format!("Failed to set TLS versions: {}", e)))?
        .with_root_certificates(trusted_roots);

    add_client_auth(builder, tls_config)
}
472
#[cfg(feature = "tls")]
/// Builds a configuration that trusts only the CA certificates stored in the
/// PEM file at `ca_path`, and presents a client certificate if one is
/// configured.
///
/// # Errors
/// Returns an SSL connection error in any of these cases:
/// - the file cannot be opened or parsed, or holds no certificates;
/// - a certificate is rejected by the root store;
/// - the TLS versions cannot be set;
/// - client authentication cannot be configured.
fn build_custom_ca_config(
    provider: &Arc<rustls::crypto::CryptoProvider>,
    tls_config: &TlsConfig,
    ca_path: &std::path::Path,
) -> Result<rustls::ClientConfig, Error> {
    let shown_path = ca_path.display();

    let pem_file = std::fs::File::open(ca_path).map_err(|e| {
        tls_error(format!(
            "Failed to open CA certificate '{}': {}",
            shown_path, e
        ))
    })?;
    let mut pem_reader = std::io::BufReader::new(pem_file);

    let ca_certs = rustls_pemfile::certs(&mut pem_reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| tls_error(format!("Failed to parse CA certificate: {}", e)))?;

    if ca_certs.is_empty() {
        return Err(tls_error(format!(
            "No certificates found in CA file '{}'",
            shown_path
        )));
    }

    let mut trusted = rustls::RootCertStore::empty();
    ca_certs.into_iter().try_for_each(|cert| {
        trusted
            .add(cert)
            .map_err(|e| tls_error(format!("Failed to add CA certificate: {}", e)))
    })?;

    let versions = [&rustls::version::TLS12, &rustls::version::TLS13];
    let builder = rustls::ClientConfig::builder_with_provider(Arc::clone(provider))
        .with_protocol_versions(&versions)
        .map_err(|e| tls_error(format!("Failed to set TLS versions: {}", e)))?
        .with_root_certificates(trusted);

    add_client_auth(builder, tls_config)
}
522
#[cfg(feature = "tls")]
/// Finishes `builder` with the client-authentication settings from `tls_config`.
///
/// - Both `client_cert_path` and `client_key_path` set: the PEM certificate
///   chain and private key are loaded and presented for mutual TLS.
/// - Neither set: client authentication is disabled.
///
/// # Errors
/// Returns an SSL connection error in any of these cases:
/// - only one of the two paths is set;
/// - a file cannot be opened or parsed;
/// - the certificate file holds no certificates, or the key file no private key;
/// - rustls rejects the certificate/key pair.
fn add_client_auth(
    builder: rustls::ConfigBuilder<rustls::ClientConfig, rustls::client::WantsClientCert>,
    tls_config: &TlsConfig,
) -> Result<rustls::ClientConfig, Error> {
    use std::fs::File;
    use std::io::BufReader;

    let paths = (&tls_config.client_cert_path, &tls_config.client_key_path);
    let (cert_path, key_path) = match paths {
        (Some(cert_path), Some(key_path)) => (cert_path, key_path),
        (None, None) => return Ok(builder.with_no_client_auth()),
        // A half-configured pair must not silently downgrade to "no client
        // auth". The caller asked for mutual TLS and would never find out.
        (Some(_), None) => {
            return Err(tls_error(
                "Client certificate provided without client key. \
                 Both must be set for mutual TLS.",
            ));
        }
        (None, Some(_)) => {
            return Err(tls_error(
                "Client key provided without client certificate. \
                 Both must be set for mutual TLS.",
            ));
        }
    };

    let cert_file = File::open(cert_path).map_err(|e| {
        tls_error(format!(
            "Failed to open client cert '{}': {}",
            cert_path.display(),
            e
        ))
    })?;
    let mut cert_reader = BufReader::new(cert_file);

    let certs = rustls_pemfile::certs(&mut cert_reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| tls_error(format!("Failed to parse client certificate: {}", e)))?;

    if certs.is_empty() {
        return Err(tls_error(format!(
            "No certificates found in client cert file '{}'",
            cert_path.display()
        )));
    }

    let key_file = File::open(key_path).map_err(|e| {
        tls_error(format!(
            "Failed to open client key '{}': {}",
            key_path.display(),
            e
        ))
    })?;
    let mut key_reader = BufReader::new(key_file);

    // `private_key` yields Ok(None) when the file parses but holds no key.
    let key = rustls_pemfile::private_key(&mut key_reader)
        .map_err(|e| tls_error(format!("Failed to parse client key: {}", e)))?
        .ok_or_else(|| tls_error(format!("No private key found in '{}'", key_path.display())))?;

    builder
        .with_client_auth_cert(certs, key)
        .map_err(|e| tls_error(format!("Failed to configure client auth: {}", e)))
}
579
#[cfg(not(feature = "tls"))]
/// Placeholder `TlsStream` used when the crate is built without the `tls`
/// feature.
///
/// The constructor visible for this configuration, `TlsStream::new`, always
/// returns an error.
#[derive(Debug)]
pub struct TlsStream<S> {
    // Never read; `allow(dead_code)` silences the unused-field lint.
    #[allow(dead_code)]
    inner: S,
}
592
593#[cfg(not(feature = "tls"))]
594impl<S> TlsStream<S> {
595 #[allow(unused_variables)]
602 pub fn new(
603 stream: S,
604 tls_config: &TlsConfig,
605 server_name: &str,
606 ssl_mode: SslMode,
607 ) -> Result<Self, Error> {
608 Err(tls_error(
609 "TLS support requires the 'tls' feature. \
610 Add `sqlmodel-mysql = { features = [\"tls\"] }` to your Cargo.toml.",
611 ))
612 }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::charset;

    #[test]
    fn test_build_ssl_request_packet() {
        let max_packet_size = 16 * 1024 * 1024;
        let packet = build_ssl_request_packet(
            capabilities::DEFAULT_CLIENT_FLAGS,
            max_packet_size,
            charset::UTF8MB4_0900_AI_CI,
            1,
        );

        // 4-byte header followed by the 32-byte payload.
        assert_eq!(packet.len(), 36);

        let (header, payload) = packet.split_at(4);
        // Payload length 32 (3 bytes), then the sequence id we passed in.
        assert_eq!(header, &[32, 0, 0, 1]);

        let caps = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
        assert_ne!(caps & capabilities::CLIENT_SSL, 0);
    }

    #[test]
    fn test_server_supports_ssl() {
        let with_ssl = [
            capabilities::CLIENT_SSL,
            capabilities::CLIENT_SSL | capabilities::CLIENT_PROTOCOL_41,
        ];
        for caps in with_ssl {
            assert!(server_supports_ssl(caps), "caps {:#x} should advertise SSL", caps);
        }

        let without_ssl = [0, capabilities::CLIENT_PROTOCOL_41];
        for caps in without_ssl {
            assert!(!server_supports_ssl(caps), "caps {:#x} should not advertise SSL", caps);
        }
    }

    #[test]
    fn test_validate_ssl_mode_disable() {
        // Disable never upgrades, whatever the server offers.
        for caps in [0, capabilities::CLIENT_SSL] {
            assert!(!validate_ssl_mode(SslMode::Disable, caps).unwrap());
        }
    }

    #[test]
    fn test_validate_ssl_mode_preferred() {
        assert!(!validate_ssl_mode(SslMode::Preferred, 0).unwrap());
        assert!(validate_ssl_mode(SslMode::Preferred, capabilities::CLIENT_SSL).unwrap());
    }

    #[test]
    fn test_validate_ssl_mode_required() {
        assert!(validate_ssl_mode(SslMode::Required, 0).is_err());
        assert!(validate_ssl_mode(SslMode::Required, capabilities::CLIENT_SSL).unwrap());
    }

    #[test]
    fn test_validate_ssl_mode_verify() {
        for mode in [SslMode::VerifyCa, SslMode::VerifyIdentity] {
            assert!(validate_ssl_mode(mode, 0).is_err());
        }
        for mode in [SslMode::VerifyCa, SslMode::VerifyIdentity] {
            assert!(validate_ssl_mode(mode, capabilities::CLIENT_SSL).unwrap());
        }
    }

    #[test]
    fn test_validate_tls_config_basic_modes() {
        let config = TlsConfig::new();
        for mode in [SslMode::Disable, SslMode::Preferred, SslMode::Required] {
            assert!(validate_tls_config(mode, &config).is_ok());
        }
    }

    #[test]
    fn test_validate_tls_config_verify_modes() {
        // No CA and no skip-verify: both verify modes are rejected.
        let bare = TlsConfig::new();
        assert!(validate_tls_config(SslMode::VerifyCa, &bare).is_err());
        assert!(validate_tls_config(SslMode::VerifyIdentity, &bare).is_err());

        // A CA certificate path satisfies both verify modes.
        let with_ca = TlsConfig::new().ca_cert("/path/to/ca.pem");
        assert!(validate_tls_config(SslMode::VerifyCa, &with_ca).is_ok());
        assert!(validate_tls_config(SslMode::VerifyIdentity, &with_ca).is_ok());

        // Explicitly skipping verification is also accepted.
        let skip_verify = TlsConfig::new().skip_verify(true);
        assert!(validate_tls_config(SslMode::VerifyCa, &skip_verify).is_ok());
    }

    #[test]
    fn test_validate_tls_config_client_cert() {
        let cert_only = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem");
        assert!(validate_tls_config(SslMode::VerifyCa, &cert_only).is_err());

        let cert_and_key = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem")
            .client_key("/path/to/client-key.pem");
        assert!(validate_tls_config(SslMode::VerifyCa, &cert_and_key).is_ok());
    }
}