1use std::collections::HashMap;
7use std::path::PathBuf;
8use std::time::Duration;
9
/// TLS settings used when the connection is upgraded to SSL.
///
/// Every field is optional or off by default; build one with the chained
/// setters, e.g. `TlsConfig::new().ca_cert("ca.pem").server_name("db")`.
#[derive(Debug, Clone, Default)]
pub struct TlsConfig {
    /// Path to a CA certificate used to verify the server
    /// (presumably PEM-encoded — confirm against the TLS loader).
    pub ca_cert_path: Option<PathBuf>,

    /// Path to the client certificate for mutual TLS. Only takes effect
    /// together with `client_key_path` (see [`TlsConfig::has_client_cert`]).
    pub client_cert_path: Option<PathBuf>,

    /// Path to the private key matching `client_cert_path`.
    pub client_key_path: Option<PathBuf>,

    /// When `true`, server certificate verification is presumably skipped.
    /// Dangerous: only intended for testing against trusted servers.
    pub danger_skip_verify: bool,

    /// Optional override for the server name used during the TLS handshake
    /// (presumably for SNI / hostname verification — confirm in the TLS setup).
    pub server_name: Option<String>,
}

impl TlsConfig {
    /// Returns an empty configuration; identical to [`TlsConfig::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the CA certificate path.
    pub fn ca_cert(self, path: impl Into<PathBuf>) -> Self {
        Self {
            ca_cert_path: Some(path.into()),
            ..self
        }
    }

    /// Sets the client certificate path.
    pub fn client_cert(self, path: impl Into<PathBuf>) -> Self {
        Self {
            client_cert_path: Some(path.into()),
            ..self
        }
    }

    /// Sets the client private-key path.
    pub fn client_key(self, path: impl Into<PathBuf>) -> Self {
        Self {
            client_key_path: Some(path.into()),
            ..self
        }
    }

    /// Enables or disables skipping server certificate verification.
    pub fn skip_verify(self, skip: bool) -> Self {
        Self {
            danger_skip_verify: skip,
            ..self
        }
    }

    /// Sets the server name override.
    pub fn server_name(self, name: impl Into<String>) -> Self {
        Self {
            server_name: Some(name.into()),
            ..self
        }
    }

    /// Returns `true` only when both a client certificate and a client key
    /// path are configured; either one alone is not enough.
    pub fn has_client_cert(&self) -> bool {
        matches!(
            (&self.client_cert_path, &self.client_key_path),
            (Some(_), Some(_))
        )
    }
}
82
/// How the client negotiates SSL/TLS with the server.
///
/// The variant names presumably mirror MySQL's `--ssl-mode` values
/// (DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY) — confirm
/// against the connection code for the exact verification behavior.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// Never attempt SSL. This is the default.
    #[default]
    Disable,
    /// Attempt SSL, but it is not mandatory.
    Preferred,
    /// SSL is mandatory.
    Required,
    /// SSL is mandatory; presumably the server certificate is also checked
    /// against a CA.
    VerifyCa,
    /// SSL is mandatory; presumably the CA and the server identity are checked.
    VerifyIdentity,
}

impl SslMode {
    /// Returns `true` for every mode except [`SslMode::Disable`].
    pub const fn should_try_ssl(self) -> bool {
        match self {
            SslMode::Disable => false,
            SslMode::Preferred
            | SslMode::Required
            | SslMode::VerifyCa
            | SslMode::VerifyIdentity => true,
        }
    }

    /// Returns `true` when SSL is mandatory, i.e. for `Required`, `VerifyCa`
    /// and `VerifyIdentity`.
    pub const fn is_required(self) -> bool {
        match self {
            SslMode::Disable | SslMode::Preferred => false,
            SslMode::Required | SslMode::VerifyCa | SslMode::VerifyIdentity => true,
        }
    }
}
113
114#[derive(Debug, Clone)]
116pub struct MySqlConfig {
117 pub host: String,
119 pub port: u16,
121 pub user: String,
123 pub password: Option<String>,
125 pub database: Option<String>,
127 pub charset: u8,
129 pub connect_timeout: Duration,
131 pub ssl_mode: SslMode,
133 pub tls_config: TlsConfig,
135 pub compression: bool,
137 pub attributes: HashMap<String, String>,
139 pub local_infile: bool,
141 pub max_packet_size: u32,
143}
144
145impl Default for MySqlConfig {
146 fn default() -> Self {
147 Self {
148 host: "localhost".to_string(),
149 port: 3306,
150 user: String::new(),
151 password: None,
152 database: None,
153 charset: crate::protocol::charset::UTF8MB4_0900_AI_CI,
154 connect_timeout: Duration::from_secs(30),
155 ssl_mode: SslMode::default(),
156 tls_config: TlsConfig::default(),
157 compression: false,
158 attributes: HashMap::new(),
159 local_infile: false,
160 max_packet_size: 64 * 1024 * 1024, }
162 }
163}
164
165impl MySqlConfig {
166 pub fn new() -> Self {
168 Self::default()
169 }
170
171 pub fn host(mut self, host: impl Into<String>) -> Self {
173 self.host = host.into();
174 self
175 }
176
177 pub fn port(mut self, port: u16) -> Self {
179 self.port = port;
180 self
181 }
182
183 pub fn user(mut self, user: impl Into<String>) -> Self {
185 self.user = user.into();
186 self
187 }
188
189 pub fn password(mut self, password: impl Into<String>) -> Self {
191 self.password.replace(password.into());
193 self
194 }
195
196 pub(crate) fn password_str(&self) -> &str {
201 self.password.as_deref().unwrap_or_default()
202 }
203
204 pub(crate) fn password_owned(&self) -> String {
206 self.password.clone().unwrap_or_default()
207 }
208
209 pub fn database(mut self, database: impl Into<String>) -> Self {
211 self.database = Some(database.into());
212 self
213 }
214
215 pub fn charset(mut self, charset: u8) -> Self {
217 self.charset = charset;
218 self
219 }
220
221 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
223 self.connect_timeout = timeout;
224 self
225 }
226
227 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
229 self.ssl_mode = mode;
230 self
231 }
232
233 pub fn tls_config(mut self, config: TlsConfig) -> Self {
235 self.tls_config = config;
236 self
237 }
238
239 pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
246 self.tls_config.ca_cert_path = Some(path.into());
247 self
248 }
249
250 pub fn client_cert(
254 mut self,
255 cert_path: impl Into<PathBuf>,
256 key_path: impl Into<PathBuf>,
257 ) -> Self {
258 self.tls_config.client_cert_path = Some(cert_path.into());
259 self.tls_config.client_key_path = Some(key_path.into());
260 self
261 }
262
263 pub fn compression(mut self, enabled: bool) -> Self {
265 self.compression = enabled;
266 self
267 }
268
269 pub fn attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
271 self.attributes.insert(key.into(), value.into());
272 self
273 }
274
275 pub fn local_infile(mut self, enabled: bool) -> Self {
281 self.local_infile = enabled;
282 self
283 }
284
285 pub fn max_packet_size(mut self, size: u32) -> Self {
287 self.max_packet_size = size;
288 self
289 }
290
291 pub fn socket_addr(&self) -> String {
293 format!("{}:{}", self.host, self.port)
294 }
295
296 pub fn capability_flags(&self) -> u32 {
298 use crate::protocol::capabilities::{
299 CLIENT_COMPRESS, CLIENT_CONNECT_ATTRS, CLIENT_CONNECT_WITH_DB, CLIENT_LOCAL_FILES,
300 CLIENT_SSL, DEFAULT_CLIENT_FLAGS,
301 };
302
303 let mut flags = DEFAULT_CLIENT_FLAGS;
304
305 if self.database.is_some() {
306 flags |= CLIENT_CONNECT_WITH_DB;
307 }
308
309 if self.ssl_mode.should_try_ssl() {
310 flags |= CLIENT_SSL;
311 }
312
313 if self.compression {
314 flags |= CLIENT_COMPRESS;
315 }
316
317 if self.local_infile {
318 flags |= CLIENT_LOCAL_FILES;
319 }
320
321 if !self.attributes.is_empty() {
322 flags |= CLIENT_CONNECT_ATTRS;
323 }
324
325 flags
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder setter should land in the matching public field.
    #[test]
    fn test_config_builder() {
        let cfg = MySqlConfig::new()
            .host("db.example.com")
            .port(3307)
            .user("myuser")
            .password("test")
            .database("testdb")
            .connect_timeout(Duration::from_secs(10))
            .ssl_mode(SslMode::Required)
            .compression(true)
            .attribute("program_name", "myapp");

        assert_eq!(cfg.host, "db.example.com");
        assert_eq!(cfg.port, 3307);
        assert_eq!(cfg.user, "myuser");
        assert_eq!(cfg.password.as_deref(), Some("test"));
        assert_eq!(cfg.database.as_deref(), Some("testdb"));
        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
        assert_eq!(cfg.ssl_mode, SslMode::Required);
        assert!(cfg.compression);
        assert_eq!(
            cfg.attributes.get("program_name").map(String::as_str),
            Some("myapp")
        );
    }

    #[test]
    fn test_socket_addr() {
        let addr = MySqlConfig::new()
            .host("db.example.com")
            .port(3307)
            .socket_addr();
        assert_eq!(addr, "db.example.com:3307");
    }

    /// Table of (mode, should_try_ssl, is_required).
    #[test]
    fn test_ssl_mode_properties() {
        let expectations = [
            (SslMode::Disable, false, false),
            (SslMode::Preferred, true, false),
            (SslMode::Required, true, true),
            (SslMode::VerifyCa, true, true),
            (SslMode::VerifyIdentity, true, true),
        ];
        for (mode, tries, required) in expectations {
            assert_eq!(mode.should_try_ssl(), tries, "should_try_ssl for {mode:?}");
            assert_eq!(mode.is_required(), required, "is_required for {mode:?}");
        }
    }

    #[test]
    fn test_capability_flags() {
        use crate::protocol::capabilities::*;

        let flags = MySqlConfig::new()
            .database("test")
            .compression(true)
            .capability_flags();

        for bit in [
            CLIENT_CONNECT_WITH_DB,
            CLIENT_COMPRESS,
            CLIENT_PROTOCOL_41,
            CLIENT_SECURE_CONNECTION,
        ] {
            assert_ne!(flags & bit, 0);
        }
    }

    #[test]
    fn test_default_config() {
        let cfg = MySqlConfig::default();

        assert_eq!(cfg.host, "localhost");
        assert_eq!(cfg.port, 3306);
        assert_eq!(cfg.ssl_mode, SslMode::Disable);
        assert!(!cfg.compression);
        assert!(!cfg.local_infile);
    }

    #[test]
    fn test_tls_config_builder() {
        let tls = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem")
            .client_key("/path/to/client-key.pem")
            .server_name("db.example.com");

        let expected_paths = [
            (&tls.ca_cert_path, "/path/to/ca.pem"),
            (&tls.client_cert_path, "/path/to/client.pem"),
            (&tls.client_key_path, "/path/to/client-key.pem"),
        ];
        for (actual, want) in expected_paths {
            assert_eq!(actual, &Some(PathBuf::from(want)));
        }
        assert_eq!(tls.server_name.as_deref(), Some("db.example.com"));
        assert!(!tls.danger_skip_verify);
        assert!(tls.has_client_cert());
    }

    #[test]
    fn test_tls_config_skip_verify() {
        assert!(TlsConfig::new().skip_verify(true).danger_skip_verify);
    }

    #[test]
    fn test_mysql_config_with_tls() {
        let cfg = MySqlConfig::new()
            .host("db.example.com")
            .ssl_mode(SslMode::VerifyCa)
            .ca_cert("/etc/ssl/certs/ca.pem")
            .client_cert(
                "/home/user/.mysql/client-cert.pem",
                "/home/user/.mysql/client-key.pem",
            );

        assert_eq!(cfg.ssl_mode, SslMode::VerifyCa);
        assert_eq!(
            cfg.tls_config.ca_cert_path,
            Some(PathBuf::from("/etc/ssl/certs/ca.pem"))
        );
        assert!(cfg.tls_config.has_client_cert());
    }

    /// A CA alone, or a certificate without its key, is not a client cert.
    #[test]
    fn test_tls_config_no_client_cert() {
        let ca_only = TlsConfig::new().ca_cert("/path/to/ca.pem");
        assert!(!ca_only.has_client_cert());

        let cert_without_key = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem");
        assert!(!cert_without_key.has_client_cert());
    }
}