1#[cfg(feature = "tls")]
2use std::convert;
3
4use std::{
5 borrow::Cow,
6 fmt,
7 str::FromStr,
8 sync::{Arc, Mutex},
9 time::Duration,
10 fmt::Debug,
11};
12
13use crate::errors::{Error, Result, UrlError};
14use url::Url;
15
/// Value assigned to `Options::pool_min` by `Options::default()`.
const DEFAULT_MIN_CONNS: usize = 10;

/// Value assigned to `Options::pool_max` by `Options::default()`.
const DEFAULT_MAX_CONNS: usize = 20;
19
20#[derive(Debug)]
21#[allow(clippy::large_enum_variant)]
22enum State {
23 Raw(Options),
24 Url(String),
25}
26
27#[derive(Clone)]
28pub struct OptionsSource {
29 state: Arc<Mutex<State>>,
30}
31
32impl fmt::Debug for OptionsSource {
33 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34 let guard = self.state.lock().unwrap();
35 match *guard {
36 State::Url(ref url) => write!(f, "Url({})", url),
37 State::Raw(ref options) => write!(f, "{:?}", options),
38 }
39 }
40}
41
42impl OptionsSource {
43 pub(crate) fn get(&self) -> Result<Cow<Options>> {
44 let mut state = self.state.lock().unwrap();
45 loop {
46 let new_state;
47 match &*state {
48 State::Raw(ref options) => {
49 let ptr = options as *const Options;
50 return unsafe { Ok(Cow::Borrowed(ptr.as_ref().unwrap())) };
51 }
52 State::Url(url) => {
53 let options = from_url(&url)?;
54 new_state = State::Raw(options);
55 }
56 }
57 *state = new_state;
58 }
59 }
60}
61
62impl Default for OptionsSource {
63 fn default() -> Self {
64 Self {
65 state: Arc::new(Mutex::new(State::Raw(Options::default()))),
66 }
67 }
68}
69
70pub trait IntoOptions {
71 fn into_options_src(self) -> OptionsSource;
72}
73
74impl IntoOptions for Options {
75 fn into_options_src(self) -> OptionsSource {
76 OptionsSource {
77 state: Arc::new(Mutex::new(State::Raw(self))),
78 }
79 }
80}
81
82impl IntoOptions for &str {
83 fn into_options_src(self) -> OptionsSource {
84 OptionsSource {
85 state: Arc::new(Mutex::new(State::Url(self.into()))),
86 }
87 }
88}
89
90impl IntoOptions for String {
91 fn into_options_src(self) -> OptionsSource {
92 OptionsSource {
93 state: Arc::new(Mutex::new(State::Url(self))),
94 }
95 }
96}
97
/// A TLS certificate that can be stored in [`Options`].
///
/// The wrapped `native_tls::Certificate` sits behind an `Arc`, so cloning
/// this type shares the same inner certificate.
#[cfg(feature = "tls")]
#[derive(Clone)]
pub struct Certificate(Arc<native_tls::Certificate>);
102
#[cfg(feature = "tls")]
impl Certificate {
    /// Builds a certificate from DER input via
    /// `native_tls::Certificate::from_der`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Other` carrying the text of the `native_tls` error.
    pub fn from_der(der: &[u8]) -> Result<Certificate> {
        native_tls::Certificate::from_der(der)
            .map(|inner| Certificate(Arc::new(inner)))
            .map_err(|err| Error::Other(err.to_string().into()))
    }

    /// Builds a certificate from PEM input via
    /// `native_tls::Certificate::from_pem`.
    ///
    /// The parameter is named `der` but is handed to the PEM parser.
    ///
    /// # Errors
    ///
    /// Returns `Error::Other` carrying the text of the `native_tls` error.
    pub fn from_pem(der: &[u8]) -> Result<Certificate> {
        native_tls::Certificate::from_pem(der)
            .map(|inner| Certificate(Arc::new(inner)))
            .map_err(|err| Error::Other(err.to_string().into()))
    }
}
123
#[cfg(feature = "tls")]
impl fmt::Debug for Certificate {
    /// Writes the fixed placeholder `[Certificate]`; the certificate contents
    /// are never printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("[Certificate]")
    }
}
130
#[cfg(feature = "tls")]
impl PartialEq for Certificate {
    /// Reports every pair of certificates as equal.
    ///
    /// NOTE(review): the certificate contents are never inspected. This
    /// presumably exists only so that `Options` can derive `PartialEq` —
    /// confirm that no caller relies on it to tell certificates apart.
    fn eq(&self, _: &Self) -> bool {
        true
    }
}
137
#[cfg(feature = "tls")]
impl convert::From<Certificate> for native_tls::Certificate {
    /// Produces an owned `native_tls::Certificate` by cloning the shared
    /// inner value.
    fn from(value: Certificate) -> Self {
        let Certificate(shared) = value;
        (*shared).clone()
    }
}
144
145#[derive(Clone, PartialEq)]
147pub struct Options {
148 pub(crate) addr: Url,
150
151 pub(crate) database: String,
153 pub(crate) username: String,
155 pub(crate) password: String,
157
158 pub(crate) compression: bool,
160
161 pub(crate) pool_min: usize,
163 pub(crate) pool_max: usize,
165
166 pub(crate) nodelay: bool,
168 pub(crate) keepalive: Option<Duration>,
170
171 pub(crate) ping_before_query: bool,
173 pub(crate) send_retries: usize,
175 pub(crate) retry_timeout: Duration,
177 pub(crate) ping_timeout: Duration,
179
180 pub(crate) connection_timeout: Duration,
182
183 pub(crate) query_timeout: Option<Duration>,
185
186 pub(crate) query_block_timeout: Option<Duration>,
188
189 pub(crate) insert_timeout: Option<Duration>,
191
192 pub(crate) execute_timeout: Option<Duration>,
194
195 #[cfg(feature = "tls")]
197 pub(crate) secure: bool,
198
199 #[cfg(feature = "tls")]
201 pub(crate) skip_verify: bool,
202
203 #[cfg(feature = "tls")]
205 pub(crate) certificate: Option<Certificate>,
206
207 pub(crate) readonly: Option<u8>,
209
210 pub(crate) alt_hosts: Vec<Url>,
212}
213
214impl fmt::Debug for Options {
215 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216 f.debug_struct("Options")
217 .field("addr", &self.addr)
218 .field("database", &self.database)
219 .field("compression", &self.compression)
220 .field("pool_min", &self.pool_min)
221 .field("pool_max", &self.pool_max)
222 .field("nodelay", &self.nodelay)
223 .field("keepalive", &self.keepalive)
224 .field("ping_before_query", &self.ping_before_query)
225 .field("send_retries", &self.send_retries)
226 .field("retry_timeout", &self.retry_timeout)
227 .field("ping_timeout", &self.ping_timeout)
228 .field("connection_timeout", &self.connection_timeout)
229 .field("query_timeout", &self.query_timeout)
230 .field("query_block_timeout", &self.query_block_timeout)
231 .field("insert_timeout", &self.insert_timeout)
232 .field("execute_timeout", &self.execute_timeout)
233 .field("readonly", &self.readonly)
234 .field("alt_hosts", &self.alt_hosts)
235 .finish()
236 }
237}
238
239impl Default for Options {
240 fn default() -> Self {
241 Self {
242 addr: Url::parse("tcp://default@127.0.0.1:9000").unwrap(),
243 database: "default".into(),
244 username: "default".into(),
245 password: "".into(),
246 compression: false,
247 pool_min: DEFAULT_MIN_CONNS,
248 pool_max: DEFAULT_MAX_CONNS,
249 nodelay: true,
250 keepalive: None,
251 ping_before_query: true,
252 send_retries: 3,
253 retry_timeout: Duration::from_secs(5),
254 ping_timeout: Duration::from_millis(500),
255 connection_timeout: Duration::from_millis(500),
256 query_timeout: Some(Duration::from_secs(180)),
257 query_block_timeout: Some(Duration::from_secs(180)),
258 insert_timeout: Some(Duration::from_secs(180)),
259 execute_timeout: Some(Duration::from_secs(180)),
260 #[cfg(feature = "tls")]
261 secure: false,
262 #[cfg(feature = "tls")]
263 skip_verify: false,
264 #[cfg(feature = "tls")]
265 certificate: None,
266 readonly: None,
267 alt_hosts: Vec::new(),
268 }
269 }
270}
271
// Generates a builder-style setter `pub fn <field>(self, <field>: <type>) -> Self`
// that stores `<field>.into()` in the field of the same name and returns the
// updated value.
macro_rules! property {
    // Bare form: identical to the attributed form with no attributes.
    ( $k:ident: $t:ty ) => {
        property! { => $k: $t }
    };
    // Attributed form: any attributes (doc comments included) written before
    // `=>` are attached to the generated method.
    ( $(#[$attr:meta])* => $k:ident: $t:ty ) => {
        $(#[$attr])*
        pub fn $k(mut self, $k: $t) -> Self {
            self.$k = $k.into();
            self
        }
    };
}
291
292impl Options {
293 pub fn new<A>(addr: A) -> Self
295 where
296 A: Into<Url>,
297 {
298 Self {
299 addr: addr.into(),
300 ..Self::default()
301 }
302 }
303
304 property! {
305 => database: &str
307 }
308
309 property! {
310 => username: &str
312 }
313
314 property! {
315 => password: &str
317 }
318
319 pub fn with_compression(self) -> Self {
321 Self {
322 compression: true,
323 ..self
324 }
325 }
326
327 property! {
328 => pool_min: usize
330 }
331
332 property! {
333 => pool_max: usize
335 }
336
337 property! {
338 => nodelay: bool
340 }
341
342 property! {
343 => keepalive: Option<Duration>
345 }
346
347 property! {
348 => ping_before_query: bool
350 }
351
352 property! {
353 => send_retries: usize
355 }
356
357 property! {
358 => retry_timeout: Duration
360 }
361
362 property! {
363 => ping_timeout: Duration
365 }
366
367 property! {
368 => connection_timeout: Duration
370 }
371
372 property! {
373 => query_timeout: Duration
375 }
376
377 property! {
378 => query_block_timeout: Duration
380 }
381
382 property! {
383 => insert_timeout: Option<Duration>
385 }
386
387 property! {
388 => execute_timeout: Option<Duration>
390 }
391
392 #[cfg(feature = "tls")]
393 property! {
394 => secure: bool
396 }
397
398 #[cfg(feature = "tls")]
399 property! {
400 => skip_verify: bool
402 }
403
404 #[cfg(feature = "tls")]
405 property! {
406 => certificate: Option<Certificate>
408 }
409
410 property! {
411 => readonly: Option<u8>
413 }
414
415 property! {
416 => alt_hosts: Vec<Url>
418 }
419}
420
421impl FromStr for Options {
422 type Err = Error;
423
424 fn from_str(url: &str) -> Result<Self> {
425 from_url(url)
426 }
427}
428
429fn from_url(url_str: &str) -> Result<Options> {
430 let url = Url::parse(url_str)?;
431
432 if url.scheme() != "tcp" {
433 return Err(UrlError::UnsupportedScheme {
434 scheme: url.scheme().to_string(),
435 }
436 .into());
437 }
438
439 if url.cannot_be_a_base() || !url.has_host() {
440 return Err(UrlError::Invalid.into());
441 }
442
443 let mut options = Options::default();
444
445 if let Some(username) = get_username_from_url(&url)? {
446 options.username = username.into();
447 }
448
449 if let Some(password) = get_password_from_url(&url)? {
450 options.password = password.into()
451 }
452
453 let mut addr = url.clone();
454 addr.set_path("");
455 addr.set_query(None);
456
457 let port = url.port().or(Some(9000));
458 addr.set_port(port).map_err(|_| UrlError::Invalid)?;
459 options.addr = addr;
460
461 if let Some(database) = get_database_from_url(&url)? {
462 options.database = database.into();
463 }
464
465 set_params(&mut options, url.query_pairs())?;
466
467 Ok(options)
468}
469
470fn set_params<'a, I>(options: &mut Options, iter: I) -> std::result::Result<(), UrlError>
471where
472 I: Iterator<Item = (Cow<'a, str>, Cow<'a, str>)>,
473{
474 for (key, value) in iter {
475 match key.as_ref() {
476 "pool_min" => options.pool_min = parse_param(key, value, usize::from_str)?,
477 "pool_max" => options.pool_max = parse_param(key, value, usize::from_str)?,
478 "nodelay" => options.nodelay = parse_param(key, value, bool::from_str)?,
479 "keepalive" => options.keepalive = parse_param(key, value, parse_opt_duration)?,
480 "ping_before_query" => {
481 options.ping_before_query = parse_param(key, value, bool::from_str)?
482 }
483 "send_retries" => options.send_retries = parse_param(key, value, usize::from_str)?,
484 "retry_timeout" => options.retry_timeout = parse_param(key, value, parse_duration)?,
485 "ping_timeout" => options.ping_timeout = parse_param(key, value, parse_duration)?,
486 "connection_timeout" => {
487 options.connection_timeout = parse_param(key, value, parse_duration)?
488 }
489 "query_timeout" => options.query_timeout = parse_param(key, value, parse_opt_duration)?,
490 "query_block_timeout" => {
491 options.query_block_timeout = parse_param(key, value, parse_opt_duration)?
492 }
493 "insert_timeout" => {
494 options.insert_timeout = parse_param(key, value, parse_opt_duration)?
495 }
496 "execute_timeout" => {
497 options.execute_timeout = parse_param(key, value, parse_opt_duration)?
498 }
499 "compression" => options.compression = parse_param(key, value, parse_compression)?,
500 #[cfg(feature = "tls")]
501 "secure" => options.secure = parse_param(key, value, bool::from_str)?,
502 #[cfg(feature = "tls")]
503 "skip_verify" => options.skip_verify = parse_param(key, value, bool::from_str)?,
504 "readonly" => options.readonly = parse_param(key, value, parse_opt_u8)?,
505 "alt_hosts" => options.alt_hosts = parse_param(key, value, parse_hosts)?,
506 _ => return Err(UrlError::UnknownParameter { param: key.into() }),
507 };
508 }
509
510 Ok(())
511}
512
513fn parse_param<'a, F, T, E>(
514 param: Cow<'a, str>,
515 value: Cow<'a, str>,
516 parse: F,
517) -> std::result::Result<T, UrlError>
518where
519 F: Fn(&str) -> std::result::Result<T, E>,
520{
521 match parse(value.as_ref()) {
522 Ok(value) => Ok(value),
523 Err(_) => Err(UrlError::InvalidParamValue {
524 param: param.into(),
525 value: value.into(),
526 }),
527 }
528}
529
530fn get_username_from_url(url: &Url) -> Result<Option<&str>> {
531 let user = url.username();
532 if user.is_empty() {
533 return Ok(None);
534 }
535 Ok(Some(user))
536}
537
538fn get_password_from_url(url: &Url) -> Result<Option<&str>> {
539 match url.password() {
540 None => Ok(None),
541 Some(password) => Ok(Some(password)),
542 }
543}
544
545fn get_database_from_url(url: &Url) -> Result<Option<&str>> {
546 match url.path_segments() {
547 None => Ok(None),
548 Some(mut segments) => {
549 let head = segments.next();
550
551 if segments.next().is_some() {
552 return Err(Error::Url(UrlError::Invalid));
553 }
554
555 match head {
556 Some(database) if !database.is_empty() => Ok(Some(database)),
557 _ => Ok(None),
558 }
559 }
560 }
561}
562
/// Parses a duration written as an unsigned integer followed directly by a
/// unit: `s` for seconds or `ms` for milliseconds (for example `3s`, `123ms`).
///
/// # Errors
///
/// Returns `Err(())` when the number is missing, does not fit into a `u64`,
/// or is followed by anything other than `s` or `ms`.
fn parse_duration(source: &str) -> std::result::Result<Duration, ()> {
    // ASCII digits are one byte each, so the count of leading digit bytes is
    // always a valid char boundary to split at.
    let digits_len = source.bytes().take_while(u8::is_ascii_digit).count();
    let (number, unit) = source.split_at(digits_len);

    let num: u64 = number.parse().map_err(|_| ())?;

    match unit {
        "s" => Ok(Duration::from_secs(num)),
        "ms" => Ok(Duration::from_millis(num)),
        _ => Err(()),
    }
}
580
581fn parse_opt_duration(source: &str) -> std::result::Result<Option<Duration>, ()> {
582 if source == "none" {
583 return Ok(None);
584 }
585
586 let duration = parse_duration(source)?;
587 Ok(Some(duration))
588}
589
/// Parses an optional `u8`: the literal `none` yields `None`, anything else
/// must be a valid `u8`.
///
/// # Errors
///
/// Returns `Err(())` when the text is neither `none` nor a valid `u8`.
fn parse_opt_u8(source: &str) -> std::result::Result<Option<u8>, ()> {
    match source {
        "none" => Ok(None),
        digits => digits.parse::<u8>().map(Some).map_err(|_| ()),
    }
}
602
/// Maps the compression setting to a flag: `lz4` turns it on, `none` turns
/// it off.
///
/// # Errors
///
/// Returns `Err(())` for any other value.
fn parse_compression(source: &str) -> std::result::Result<bool, ()> {
    if source == "lz4" {
        Ok(true)
    } else if source == "none" {
        Ok(false)
    } else {
        Err(())
    }
}
610
611fn parse_hosts(source: &str) -> std::result::Result<Vec<Url>, ()> {
612 let mut result = Vec::new();
613 for host in source.split(',') {
614 match Url::from_str(&format!("tcp://{}", host)) {
615 Ok(url) => result.push(url),
616 Err(_) => return Err(()),
617 }
618 }
619 Ok(result)
620}
621
#[cfg(test)]
mod test {
    use super::*;

    /// A comma-separated host list becomes one `tcp://` URL per entry.
    #[test]
    fn test_parse_hosts() {
        let expected: Vec<Url> = ["tcp://host2:9000", "tcp://host3:9000"]
            .iter()
            .map(|raw| Url::from_str(raw).unwrap())
            .collect();

        let actual = parse_hosts("host2:9000,host3:9000").unwrap();

        assert_eq!(actual, expected);
    }

    /// A URL without credentials or path keeps the default database, user
    /// and password.
    #[test]
    fn test_parse_default() {
        let options = from_url("tcp://host1").unwrap();

        assert_eq!(options.database, "default");
        assert_eq!(options.username, "default");
        assert_eq!(options.password, "");
    }

    /// Same as `test_parse_options`, plus the `secure` flag of the `tls` feature.
    #[test]
    #[cfg(feature = "tls")]
    fn test_parse_secure_options() {
        let url = "tcp://username:password@host1:9001/database?ping_timeout=42ms&keepalive=99s&compression=lz4&connection_timeout=10s&secure=true";
        let expected = Options {
            username: "username".into(),
            password: "password".into(),
            addr: Url::parse("tcp://username:password@host1:9001").unwrap(),
            database: "database".into(),
            keepalive: Some(Duration::from_secs(99)),
            ping_timeout: Duration::from_millis(42),
            connection_timeout: Duration::from_secs(10),
            compression: true,
            secure: true,
            ..Options::default()
        };

        assert_eq!(expected, from_url(url).unwrap());
    }

    /// Credentials, port, database and query parameters are all picked up.
    #[test]
    fn test_parse_options() {
        let url = "tcp://username:password@host1:9001/database?ping_timeout=42ms&keepalive=99s&compression=lz4&connection_timeout=10s";
        let expected = Options {
            username: "username".into(),
            password: "password".into(),
            addr: Url::parse("tcp://username:password@host1:9001").unwrap(),
            database: "database".into(),
            keepalive: Some(Duration::from_secs(99)),
            ping_timeout: Duration::from_millis(42),
            connection_timeout: Duration::from_secs(10),
            compression: true,
            ..Options::default()
        };

        assert_eq!(expected, from_url(url).unwrap());
    }

    // The three rejection tests check for a returned error instead of using
    // `#[should_panic]`, so an unrelated panic inside the parser cannot make
    // them pass.

    /// Text that is not a URL at all is rejected.
    #[test]
    fn test_parse_invalid_url() {
        assert!(from_url("ʘ_ʘ").is_err());
    }

    /// An unknown query parameter is rejected.
    #[test]
    fn test_parse_with_unknown_url() {
        assert!(from_url("tcp://localhost:9000/foo?bar=baz").is_err());
    }

    /// A path with more than one segment is rejected.
    #[test]
    fn test_parse_with_multi_databases() {
        assert!(from_url("tcp://localhost:9000/foo/bar").is_err());
    }

    /// Seconds and milliseconds parse; a missing number or unknown unit does not.
    #[test]
    fn test_parse_duration() {
        assert_eq!(parse_duration("3s").unwrap(), Duration::from_secs(3));
        assert_eq!(parse_duration("123ms").unwrap(), Duration::from_millis(123));

        assert!(parse_duration("ms").is_err());
        assert!(parse_duration("1ss").is_err());
    }

    /// `none` maps to `None`; a duration maps to `Some`.
    #[test]
    fn test_parse_opt_duration() {
        assert_eq!(
            parse_opt_duration("3s").unwrap(),
            Some(Duration::from_secs(3))
        );
        assert_eq!(parse_opt_duration("none").unwrap(), None::<Duration>);
    }

    /// `none` and `lz4` are the only accepted compression values.
    #[test]
    fn test_parse_compression() {
        assert!(!parse_compression("none").unwrap());
        assert!(parse_compression("lz4").unwrap());
        assert!(parse_compression("?").is_err());
    }
}