Skip to main content

sqlmodel_mysql/
config.rs

//! MySQL connection configuration.
//!
//! Provides connection parameters for establishing MySQL connections
//! including authentication, SSL, and connection options.
5
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::time::Duration;
9
/// Certificate, key, and verification settings for TLS-protected MySQL
/// connections.
///
/// Only file locations and flags are stored here. The actual TLS
/// implementation requires the `tls` feature to be enabled.
#[derive(Debug, Clone, Default)]
pub struct TlsConfig {
    /// Location of the PEM-encoded CA certificate used to verify the server.
    /// Must be provided for `SslMode::VerifyCa` and `SslMode::VerifyIdentity`.
    pub ca_cert_path: Option<PathBuf>,

    /// Location of the PEM-encoded client certificate for mutual TLS.
    /// Leave unset unless the server demands a client certificate.
    pub client_cert_path: Option<PathBuf>,

    /// Location of the PEM-encoded client private key for mutual TLS.
    /// Must be provided whenever `client_cert_path` is set.
    pub client_key_path: Option<PathBuf>,

    /// When `true`, the server certificate is not verified.
    ///
    /// # Security Warning
    /// Turning verification off exposes the connection to
    /// man-in-the-middle attacks. Use it only for development or testing
    /// against self-signed certificates.
    pub danger_skip_verify: bool,

    /// Name to present for SNI (Server Name Indication).
    /// When unset, the connection hostname is used instead.
    pub server_name: Option<String>,
}
40
41impl TlsConfig {
42    /// Create a new TLS configuration with default values.
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Set the CA certificate path.
48    pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
49        self.ca_cert_path = Some(path.into());
50        self
51    }
52
53    /// Set the client certificate path.
54    pub fn client_cert(mut self, path: impl Into<PathBuf>) -> Self {
55        self.client_cert_path = Some(path.into());
56        self
57    }
58
59    /// Set the client key path.
60    pub fn client_key(mut self, path: impl Into<PathBuf>) -> Self {
61        self.client_key_path = Some(path.into());
62        self
63    }
64
65    /// Skip server certificate verification (dangerous!).
66    pub fn skip_verify(mut self, skip: bool) -> Self {
67        self.danger_skip_verify = skip;
68        self
69    }
70
71    /// Set the server name for SNI.
72    pub fn server_name(mut self, name: impl Into<String>) -> Self {
73        self.server_name = Some(name.into());
74        self
75    }
76
77    /// Check if mutual TLS (client certificate) is configured.
78    pub fn has_client_cert(&self) -> bool {
79        self.client_cert_path.is_some() && self.client_key_path.is_some()
80    }
81}
82
/// How strictly SSL/TLS is applied to a MySQL connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// Never use SSL. This is the default mode.
    #[default]
    Disable,
    /// Use SSL when it is available; otherwise fall back to a non-SSL
    /// connection.
    Preferred,
    /// Insist on an SSL connection.
    Required,
    /// Insist on SSL and verify the server certificate.
    VerifyCa,
    /// Insist on SSL and verify that the server certificate matches the
    /// hostname.
    VerifyIdentity,
}
98
99impl SslMode {
100    /// Check if SSL should be attempted.
101    pub const fn should_try_ssl(self) -> bool {
102        !matches!(self, SslMode::Disable)
103    }
104
105    /// Check if SSL is required.
106    pub const fn is_required(self) -> bool {
107        matches!(
108            self,
109            SslMode::Required | SslMode::VerifyCa | SslMode::VerifyIdentity
110        )
111    }
112}
113
114/// MySQL connection configuration.
115#[derive(Debug, Clone)]
116pub struct MySqlConfig {
117    /// Hostname or IP address
118    pub host: String,
119    /// Port number (default: 3306)
120    pub port: u16,
121    /// Username for authentication
122    pub user: String,
123    /// Password for authentication
124    pub password: Option<String>,
125    /// Database name to connect to (optional at connect time)
126    pub database: Option<String>,
127    /// Character set (default: utf8mb4)
128    pub charset: u8,
129    /// Connection timeout
130    pub connect_timeout: Duration,
131    /// SSL mode
132    pub ssl_mode: SslMode,
133    /// TLS configuration (certificates, keys, etc.)
134    pub tls_config: TlsConfig,
135    /// Enable compression (CLIENT_COMPRESS capability)
136    pub compression: bool,
137    /// Additional connection attributes
138    pub attributes: HashMap<String, String>,
139    /// Local infile handling (disabled by default for security)
140    pub local_infile: bool,
141    /// Max allowed packet size (default: 64MB)
142    pub max_packet_size: u32,
143}
144
145impl Default for MySqlConfig {
146    fn default() -> Self {
147        Self {
148            host: "localhost".to_string(),
149            port: 3306,
150            user: String::new(),
151            password: None,
152            database: None,
153            charset: crate::protocol::charset::UTF8MB4_0900_AI_CI,
154            connect_timeout: Duration::from_secs(30),
155            ssl_mode: SslMode::default(),
156            tls_config: TlsConfig::default(),
157            compression: false,
158            attributes: HashMap::new(),
159            local_infile: false,
160            max_packet_size: 64 * 1024 * 1024, // 64MB
161        }
162    }
163}
164
165impl MySqlConfig {
166    /// Create a new configuration with default values.
167    pub fn new() -> Self {
168        Self::default()
169    }
170
171    /// Set the hostname.
172    pub fn host(mut self, host: impl Into<String>) -> Self {
173        self.host = host.into();
174        self
175    }
176
177    /// Set the port.
178    pub fn port(mut self, port: u16) -> Self {
179        self.port = port;
180        self
181    }
182
183    /// Set the username.
184    pub fn user(mut self, user: impl Into<String>) -> Self {
185        self.user = user.into();
186        self
187    }
188
189    /// Set the password.
190    pub fn password(mut self, password: impl Into<String>) -> Self {
191        self.password = Some(password.into());
192        self
193    }
194
195    /// Set the database.
196    pub fn database(mut self, database: impl Into<String>) -> Self {
197        self.database = Some(database.into());
198        self
199    }
200
201    /// Set the character set.
202    pub fn charset(mut self, charset: u8) -> Self {
203        self.charset = charset;
204        self
205    }
206
207    /// Set the connection timeout.
208    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
209        self.connect_timeout = timeout;
210        self
211    }
212
213    /// Set the SSL mode.
214    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
215        self.ssl_mode = mode;
216        self
217    }
218
219    /// Set the TLS configuration.
220    pub fn tls_config(mut self, config: TlsConfig) -> Self {
221        self.tls_config = config;
222        self
223    }
224
225    /// Set the CA certificate path for TLS.
226    ///
227    /// This is a convenience method equivalent to:
228    /// ```ignore
229    /// config.tls_config(TlsConfig::new().ca_cert(path))
230    /// ```
231    pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
232        self.tls_config.ca_cert_path = Some(path.into());
233        self
234    }
235
236    /// Set client certificate and key paths for mutual TLS.
237    ///
238    /// Both cert and key must be provided for client authentication.
239    pub fn client_cert(
240        mut self,
241        cert_path: impl Into<PathBuf>,
242        key_path: impl Into<PathBuf>,
243    ) -> Self {
244        self.tls_config.client_cert_path = Some(cert_path.into());
245        self.tls_config.client_key_path = Some(key_path.into());
246        self
247    }
248
249    /// Enable or disable compression.
250    pub fn compression(mut self, enabled: bool) -> Self {
251        self.compression = enabled;
252        self
253    }
254
255    /// Set a connection attribute.
256    pub fn attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
257        self.attributes.insert(key.into(), value.into());
258        self
259    }
260
261    /// Enable or disable local infile handling.
262    ///
263    /// # Security Warning
264    /// Enabling local infile can be a security risk. Only enable if you
265    /// trust the server and understand the implications.
266    pub fn local_infile(mut self, enabled: bool) -> Self {
267        self.local_infile = enabled;
268        self
269    }
270
271    /// Set the max allowed packet size.
272    pub fn max_packet_size(mut self, size: u32) -> Self {
273        self.max_packet_size = size;
274        self
275    }
276
277    /// Get the socket address string for connection.
278    pub fn socket_addr(&self) -> String {
279        format!("{}:{}", self.host, self.port)
280    }
281
282    /// Build capability flags based on configuration.
283    pub fn capability_flags(&self) -> u32 {
284        use crate::protocol::capabilities::{
285            CLIENT_COMPRESS, CLIENT_CONNECT_ATTRS, CLIENT_CONNECT_WITH_DB, CLIENT_LOCAL_FILES,
286            CLIENT_SSL, DEFAULT_CLIENT_FLAGS,
287        };
288
289        let mut flags = DEFAULT_CLIENT_FLAGS;
290
291        if self.database.is_some() {
292            flags |= CLIENT_CONNECT_WITH_DB;
293        }
294
295        if self.ssl_mode.should_try_ssl() {
296            flags |= CLIENT_SSL;
297        }
298
299        if self.compression {
300            flags |= CLIENT_COMPRESS;
301        }
302
303        if self.local_infile {
304            flags |= CLIENT_LOCAL_FILES;
305        }
306
307        if !self.attributes.is_empty() {
308            flags |= CLIENT_CONNECT_ATTRS;
309        }
310
311        flags
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder call on `MySqlConfig` lands in the matching field.
    #[test]
    fn test_config_builder() {
        let built = MySqlConfig::new()
            .host("db.example.com")
            .port(3307)
            .user("myuser")
            .password("secret")
            .database("testdb")
            .connect_timeout(Duration::from_secs(10))
            .ssl_mode(SslMode::Required)
            .compression(true)
            .attribute("program_name", "myapp");

        assert_eq!(built.host.as_str(), "db.example.com");
        assert_eq!(built.port, 3307);
        assert_eq!(built.user.as_str(), "myuser");
        assert_eq!(built.password.as_deref(), Some("secret"));
        assert_eq!(built.database.as_deref(), Some("testdb"));
        assert_eq!(built.connect_timeout, Duration::from_secs(10));
        assert_eq!(built.ssl_mode, SslMode::Required);
        assert!(built.compression);
        assert_eq!(
            built.attributes.get("program_name").map(String::as_str),
            Some("myapp")
        );
    }

    /// Host and port are joined with a colon.
    #[test]
    fn test_socket_addr() {
        let addr = MySqlConfig::new()
            .host("db.example.com")
            .port(3307)
            .socket_addr();
        assert_eq!(addr, "db.example.com:3307");
    }

    /// `should_try_ssl` / `is_required` for each mode, as a truth table.
    #[test]
    fn test_ssl_mode_properties() {
        // (mode, should_try_ssl, is_required)
        let expectations = [
            (SslMode::Disable, false, false),
            (SslMode::Preferred, true, false),
            (SslMode::Required, true, true),
            (SslMode::VerifyCa, true, true),
            (SslMode::VerifyIdentity, true, true),
        ];

        for &(mode, tries, required) in expectations.iter() {
            assert_eq!(mode.should_try_ssl(), tries, "should_try_ssl for {:?}", mode);
            assert_eq!(mode.is_required(), required, "is_required for {:?}", mode);
        }
    }

    /// Requested options and the baseline protocol flags are all present.
    #[test]
    fn test_capability_flags() {
        use crate::protocol::capabilities::*;

        let flags = MySqlConfig::new()
            .database("test")
            .compression(true)
            .capability_flags();

        let wanted = [
            CLIENT_CONNECT_WITH_DB,
            CLIENT_COMPRESS,
            CLIENT_PROTOCOL_41,
            CLIENT_SECURE_CONNECTION,
        ];
        for &bit in wanted.iter() {
            assert_ne!(flags & bit, 0);
        }
    }

    /// The default configuration targets localhost with optional features off.
    #[test]
    fn test_default_config() {
        let defaults = MySqlConfig::default();

        assert_eq!(defaults.host.as_str(), "localhost");
        assert_eq!(defaults.port, 3306);
        assert_eq!(defaults.ssl_mode, SslMode::Disable);
        assert_eq!(defaults.compression, false);
        assert_eq!(defaults.local_infile, false);
    }

    /// Every `TlsConfig` builder call lands in the matching field.
    #[test]
    fn test_tls_config_builder() {
        let tls = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem")
            .client_key("/path/to/client-key.pem")
            .server_name("db.example.com");

        let ca = PathBuf::from("/path/to/ca.pem");
        let cert = PathBuf::from("/path/to/client.pem");
        let key = PathBuf::from("/path/to/client-key.pem");

        assert_eq!(tls.ca_cert_path, Some(ca));
        assert_eq!(tls.client_cert_path, Some(cert));
        assert_eq!(tls.client_key_path, Some(key));
        assert_eq!(tls.server_name.as_deref(), Some("db.example.com"));
        assert_eq!(tls.danger_skip_verify, false);
        assert!(tls.has_client_cert());
    }

    /// `skip_verify(true)` flips the danger flag on.
    #[test]
    fn test_tls_config_skip_verify() {
        assert!(TlsConfig::new().skip_verify(true).danger_skip_verify);
    }

    /// The TLS convenience builders on `MySqlConfig` reach the nested config.
    #[test]
    fn test_mysql_config_with_tls() {
        let secured = MySqlConfig::new()
            .host("db.example.com")
            .ssl_mode(SslMode::VerifyCa)
            .ca_cert("/etc/ssl/certs/ca.pem")
            .client_cert(
                "/home/user/.mysql/client-cert.pem",
                "/home/user/.mysql/client-key.pem",
            );

        assert_eq!(secured.ssl_mode, SslMode::VerifyCa);
        assert_eq!(
            secured.tls_config.ca_cert_path,
            Some(PathBuf::from("/etc/ssl/certs/ca.pem"))
        );
        assert!(secured.tls_config.has_client_cert());
    }

    /// Mutual TLS is not reported without both certificate and key.
    #[test]
    fn test_tls_config_no_client_cert() {
        let ca_only = TlsConfig::new().ca_cert("/path/to/ca.pem");
        assert_eq!(ca_only.has_client_cert(), false);

        // Only cert, no key
        let cert_without_key = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem");
        assert_eq!(cert_without_key.has_client_cert(), false);
    }
}