1use std::borrow::Cow;
16#[cfg(feature = "tls")]
17use std::convert;
18use std::fmt;
19#[cfg(feature = "tls")]
20use std::fmt::Formatter;
21use std::str::FromStr;
22use std::sync::Arc;
23use std::sync::Mutex;
24use std::time::Duration;
25
26#[cfg(feature = "tls")]
27use tokio_native_tls::native_tls;
28use url::Url;
29
30use crate::errors::Error;
31use crate::errors::Result;
32use crate::errors::UrlError;
33
/// Default value of [`Options::pool_min`] (used by `Options::default`).
const DEFAULT_MIN_CONNS: usize = 10;

/// Default value of [`Options::pool_max`] (used by `Options::default`).
const DEFAULT_MAX_CONNS: usize = 20;
37
38#[derive(Debug)]
39#[allow(clippy::large_enum_variant)]
40enum State {
41 Raw(Options),
42 Url(String),
43}
44
45#[derive(Clone)]
46pub struct OptionsSource {
47 state: Arc<Mutex<State>>,
48}
49
50impl fmt::Debug for OptionsSource {
51 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52 let guard = self.state.lock().unwrap();
53 match &*guard {
54 State::Url(ref url) => write!(f, "Url({})", url),
55 State::Raw(ref options) => write!(f, "{:?}", options),
56 }
57 }
58}
59
60#[allow(dead_code)]
61impl OptionsSource {
62 pub(crate) fn get(&self) -> Result<Cow<Options>> {
63 let mut state = self.state.lock().unwrap();
64 loop {
65 let new_state = match &*state {
66 State::Raw(ref options) => {
67 let ptr = options as *const Options;
68 return unsafe { Ok(Cow::Borrowed(ptr.as_ref().unwrap())) };
69 }
70 State::Url(url) => {
71 let options = from_url(url)?;
72 State::Raw(options)
73 }
74 };
75 *state = new_state;
76 }
77 }
78}
79
80impl Default for OptionsSource {
81 fn default() -> Self {
82 Self {
83 state: Arc::new(Mutex::new(State::Raw(Options::default()))),
84 }
85 }
86}
87
88pub trait IntoOptions {
89 fn into_options_src(self) -> OptionsSource;
90}
91
92impl IntoOptions for Options {
93 fn into_options_src(self) -> OptionsSource {
94 OptionsSource {
95 state: Arc::new(Mutex::new(State::Raw(self))),
96 }
97 }
98}
99
100impl IntoOptions for &str {
101 fn into_options_src(self) -> OptionsSource {
102 OptionsSource {
103 state: Arc::new(Mutex::new(State::Url(self.into()))),
104 }
105 }
106}
107
108impl IntoOptions for String {
109 fn into_options_src(self) -> OptionsSource {
110 OptionsSource {
111 state: Arc::new(Mutex::new(State::Url(self))),
112 }
113 }
114}
115
/// An X.509 certificate for TLS connections (available with the `tls` feature).
///
/// The parsed `native_tls::Certificate` sits behind an `Arc`, so cloning is
/// cheap and clones share a single parsed certificate.
#[cfg(feature = "tls")]
#[derive(Clone)]
pub struct Certificate(Arc<native_tls::Certificate>);

#[cfg(feature = "tls")]
impl Certificate {
    /// Parses a DER-encoded certificate.
    ///
    /// # Errors
    ///
    /// Returns `Error::Other` carrying the `native_tls` error text if `der`
    /// cannot be parsed.
    pub fn from_der(der: &[u8]) -> Result<Certificate> {
        native_tls::Certificate::from_der(der)
            .map(|inner| Certificate(Arc::new(inner)))
            .map_err(tls_error)
    }

    /// Parses a PEM-encoded certificate.
    ///
    /// # Errors
    ///
    /// Returns `Error::Other` carrying the `native_tls` error text if `pem`
    /// cannot be parsed.
    pub fn from_pem(pem: &[u8]) -> Result<Certificate> {
        native_tls::Certificate::from_pem(pem)
            .map(|inner| Certificate(Arc::new(inner)))
            .map_err(tls_error)
    }
}

/// Wraps a `native_tls` error as this crate's `Error::Other`, keeping only its message.
#[cfg(feature = "tls")]
fn tls_error(err: native_tls::Error) -> Error {
    Error::Other(err.to_string().into())
}

#[cfg(feature = "tls")]
impl fmt::Debug for Certificate {
    /// Prints the placeholder `[Certificate]` instead of certificate contents.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(f, "[Certificate]")
    }
}

// NOTE: every `Certificate` compares equal to every other one. This impl
// presumably exists only so that `Options` can derive `PartialEq`. As a
// consequence, two `Options` that differ only in `certificate` compare equal.
#[cfg(feature = "tls")]
impl PartialEq for Certificate {
    fn eq(&self, _other: &Self) -> bool {
        true
    }
}

#[cfg(feature = "tls")]
impl convert::From<Certificate> for native_tls::Certificate {
    /// Extracts a clone of the wrapped `native_tls::Certificate`.
    fn from(value: Certificate) -> Self {
        value.0.as_ref().clone()
    }
}
162
/// Client connection options.
///
/// Build them with [`Options::new`] or `Options::default()` plus the
/// builder-style setters, or parse a `tcp://` URL with [`from_url`].
// `Eq` is not derived: under the `tls` feature the `certificate` field holds a
// `Certificate`, which in this file implements `PartialEq` but not `Eq`.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq)]
pub struct Options {
    /// Server address. Default: `tcp://default@127.0.0.1:9000`. `from_url`
    /// stores the parsed URL here with path and query removed and the port set
    /// to 9000 when the URL gives none.
    pub addr: Url,

    /// Database name. Default: `"default"`. `from_url` takes it from the single
    /// URL path segment, if present.
    pub database: String,
    /// User name. Default: `"default"`. `from_url` takes it from the URL.
    pub username: String,
    /// Password. Default: empty. `from_url` takes it from the URL.
    pub password: String,

    /// Compression flag. Default: `false`. URL parameter `compression`: `lz4`
    /// sets it, `none` clears it.
    pub compression: bool,

    /// Minimum pool size. Default: `DEFAULT_MIN_CONNS`. URL parameter: `pool_min`.
    pub pool_min: usize,
    /// Maximum pool size. Default: `DEFAULT_MAX_CONNS`. URL parameter: `pool_max`.
    pub pool_max: usize,

    /// `nodelay` flag (presumably TCP_NODELAY — confirm). Default: `true`.
    /// URL parameter: `nodelay` (`true`/`false`).
    pub nodelay: bool,
    /// Keepalive duration. Default: `None`. URL parameter: `keepalive`
    /// (e.g. `30s`, `500ms`, or `none`).
    pub keepalive: Option<Duration>,

    /// Whether to ping before a query. Default: `true`. URL parameter:
    /// `ping_before_query` (`true`/`false`).
    pub ping_before_query: bool,
    /// Number of send retries. Default: 3. URL parameter: `send_retries`.
    pub send_retries: usize,
    /// Retry timeout. Default: 5 s. URL parameter: `retry_timeout` (`<n>s`/`<n>ms`).
    pub retry_timeout: Duration,
    /// Ping timeout. Default: 500 ms. URL parameter: `ping_timeout` (`<n>s`/`<n>ms`).
    pub ping_timeout: Duration,

    /// Connection timeout. Default: 500 ms. URL parameter: `connection_timeout`.
    pub connection_timeout: Duration,

    /// Query timeout. Default: 180 s. URL parameter: `query_timeout`.
    pub query_timeout: Duration,

    /// Insert timeout. Default: `Some(180 s)`. URL parameter: `insert_timeout`
    /// (a duration or `none`).
    pub insert_timeout: Option<Duration>,

    /// Execute timeout. Default: `Some(180 s)`. URL parameter: `execute_timeout`
    /// (a duration or `none`).
    pub execute_timeout: Option<Duration>,

    /// `secure` flag (`tls` feature only). Default: `false`. URL parameter: `secure`.
    #[cfg(feature = "tls")]
    pub secure: bool,

    /// `skip_verify` flag (`tls` feature only). Default: `false`. URL parameter:
    /// `skip_verify`.
    #[cfg(feature = "tls")]
    pub skip_verify: bool,

    /// Optional certificate (`tls` feature only). Default: `None`. It cannot be
    /// set from a URL; use the `certificate` builder method.
    #[cfg(feature = "tls")]
    pub certificate: Option<Certificate>,

    /// `readonly` setting. Default: `None`. URL parameter: `readonly` (an
    /// integer 0–255, or `none`).
    pub readonly: Option<u8>,

    /// Alternative hosts. Default: empty. URL parameter: `alt_hosts`, a
    /// comma-separated list whose entries are each parsed as `tcp://<entry>`.
    pub alt_hosts: Vec<Url>,
}
231
232impl fmt::Debug for Options {
233 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
234 f.debug_struct("Options")
235 .field("addr", &self.addr)
236 .field("database", &self.database)
237 .field("compression", &self.compression)
238 .field("pool_min", &self.pool_min)
239 .field("pool_max", &self.pool_max)
240 .field("nodelay", &self.nodelay)
241 .field("keepalive", &self.keepalive)
242 .field("ping_before_query", &self.ping_before_query)
243 .field("send_retries", &self.send_retries)
244 .field("retry_timeout", &self.retry_timeout)
245 .field("ping_timeout", &self.ping_timeout)
246 .field("connection_timeout", &self.connection_timeout)
247 .field("readonly", &self.readonly)
248 .field("alt_hosts", &self.alt_hosts)
249 .finish()
250 }
251}
252
253impl Default for Options {
254 fn default() -> Self {
255 Self {
256 addr: Url::parse("tcp://default@127.0.0.1:9000").unwrap(),
257 database: "default".into(),
258 username: "default".into(),
259 password: "".into(),
260 compression: false,
261 pool_min: DEFAULT_MIN_CONNS,
262 pool_max: DEFAULT_MAX_CONNS,
263 nodelay: true,
264 keepalive: None,
265 ping_before_query: true,
266 send_retries: 3,
267 retry_timeout: Duration::from_secs(5),
268 ping_timeout: Duration::from_millis(500),
269 connection_timeout: Duration::from_millis(500),
270 query_timeout: Duration::from_secs(180),
271 insert_timeout: Some(Duration::from_secs(180)),
272 execute_timeout: Some(Duration::from_secs(180)),
273 #[cfg(feature = "tls")]
274 secure: false,
275 #[cfg(feature = "tls")]
276 skip_verify: false,
277 #[cfg(feature = "tls")]
278 certificate: None,
279 readonly: None,
280 alt_hosts: Vec::new(),
281 }
282 }
283}
284
/// Generates a builder-style setter `pub fn $k(self, $k: $t) -> Self`.
///
/// The setter stores `$k.into()` in the field of the same name and returns the
/// updated value. Struct-update syntax keeps every other field.
///
/// Two invocation forms are accepted:
/// - `property!(name: Type)`
/// - `property! { #[attrs...] => name: Type }`, where the leading attributes
///   (typically `///` doc comments) are copied onto the generated method.
///
/// Both forms mark the setter `#[must_use]`: it consumes `self`, so ignoring
/// its result silently discards the whole value. (Previously only the second
/// form did.)
macro_rules! property {
    ( $k:ident: $t:ty ) => {
        property! { => $k: $t }
    };
    ( $(#[$attr:meta])* => $k:ident: $t:ty ) => {
        $(#[$attr])*
        #[must_use]
        pub fn $k(self, $k: $t) -> Self {
            Self {
                $k: $k.into(),
                ..self
            }
        }
    };
}
305
306impl Options {
307 pub fn new<A>(addr: A) -> Self
309 where
310 A: Into<Url>,
311 {
312 Self {
313 addr: addr.into(),
314 ..Self::default()
315 }
316 }
317
318 property! {
319 => database: &str
321 }
322
323 property! {
324 => username: &str
326 }
327
328 property! {
329 => password: &str
331 }
332
333 #[must_use]
335 pub fn with_compression(self) -> Self {
336 Self {
337 compression: true,
338 ..self
339 }
340 }
341
342 property! {
343 => pool_min: usize
345 }
346
347 property! {
348 => pool_max: usize
350 }
351
352 property! {
353 => nodelay: bool
355 }
356
357 property! {
358 => keepalive: Option<Duration>
360 }
361
362 property! {
363 => ping_before_query: bool
365 }
366
367 property! {
368 => send_retries: usize
370 }
371
372 property! {
373 => retry_timeout: Duration
375 }
376
377 property! {
378 => ping_timeout: Duration
380 }
381
382 property! {
383 => connection_timeout: Duration
385 }
386
387 property! {
388 => query_timeout: Duration
390 }
391
392 property! {
393 => insert_timeout: Option<Duration>
395 }
396
397 property! {
398 => execute_timeout: Option<Duration>
400 }
401
402 #[cfg(feature = "tls")]
403 property! {
404 => secure: bool
406 }
407
408 #[cfg(feature = "tls")]
409 property! {
410 => skip_verify: bool
412 }
413
414 #[cfg(feature = "tls")]
415 property! {
416 => certificate: Option<Certificate>
418 }
419
420 property! {
421 => readonly: Option<u8>
423 }
424
425 property! {
426 => alt_hosts: Vec<Url>
428 }
429}
430
431impl FromStr for Options {
432 type Err = Error;
433
434 fn from_str(url: &str) -> Result<Self> {
435 from_url(url)
436 }
437}
438
439pub fn from_url(url_str: &str) -> Result<Options> {
440 let url = Url::parse(url_str)?;
441
442 if url.scheme() != "tcp" {
443 return Err(UrlError::UnsupportedScheme {
444 scheme: url.scheme().to_string(),
445 }
446 .into());
447 }
448
449 if url.cannot_be_a_base() || !url.has_host() {
450 return Err(UrlError::Invalid.into());
451 }
452
453 let mut options = Options::default();
454
455 if let Some(username) = get_username_from_url(&url) {
456 options.username = username.into();
457 }
458
459 if let Some(password) = get_password_from_url(&url) {
460 options.password = password.into()
461 }
462
463 let mut addr = url.clone();
464 addr.set_path("");
465 addr.set_query(None);
466
467 let port = url.port().or(Some(9000));
468 addr.set_port(port).map_err(|_| UrlError::Invalid)?;
469 options.addr = addr;
470
471 if let Some(database) = get_database_from_url(&url)? {
472 options.database = database.into();
473 }
474
475 set_params(&mut options, url.query_pairs())?;
476
477 Ok(options)
478}
479
480fn set_params<'a, I>(options: &mut Options, iter: I) -> std::result::Result<(), UrlError>
481where
482 I: Iterator<Item = (Cow<'a, str>, Cow<'a, str>)>,
483{
484 for (key, value) in iter {
485 match key.as_ref() {
486 "pool_min" => options.pool_min = parse_param(key, value, usize::from_str)?,
487 "pool_max" => options.pool_max = parse_param(key, value, usize::from_str)?,
488 "nodelay" => options.nodelay = parse_param(key, value, bool::from_str)?,
489 "keepalive" => options.keepalive = parse_param(key, value, parse_opt_duration)?,
490 "ping_before_query" => {
491 options.ping_before_query = parse_param(key, value, bool::from_str)?
492 }
493 "send_retries" => options.send_retries = parse_param(key, value, usize::from_str)?,
494 "retry_timeout" => options.retry_timeout = parse_param(key, value, parse_duration)?,
495 "ping_timeout" => options.ping_timeout = parse_param(key, value, parse_duration)?,
496 "connection_timeout" => {
497 options.connection_timeout = parse_param(key, value, parse_duration)?
498 }
499 "query_timeout" => options.query_timeout = parse_param(key, value, parse_duration)?,
500 "insert_timeout" => {
501 options.insert_timeout = parse_param(key, value, parse_opt_duration)?
502 }
503 "execute_timeout" => {
504 options.execute_timeout = parse_param(key, value, parse_opt_duration)?
505 }
506 "compression" => options.compression = parse_param(key, value, parse_compression)?,
507 #[cfg(feature = "tls")]
508 "secure" => options.secure = parse_param(key, value, bool::from_str)?,
509 #[cfg(feature = "tls")]
510 "skip_verify" => options.skip_verify = parse_param(key, value, bool::from_str)?,
511 "readonly" => options.readonly = parse_param(key, value, parse_opt_u8)?,
512 "alt_hosts" => options.alt_hosts = parse_param(key, value, parse_hosts)?,
513 _ => return Err(UrlError::UnknownParameter { param: key.into() }),
514 };
515 }
516
517 Ok(())
518}
519
520fn parse_param<'a, F, T, E>(
521 param: Cow<'a, str>,
522 value: Cow<'a, str>,
523 parse: F,
524) -> std::result::Result<T, UrlError>
525where
526 F: Fn(&str) -> std::result::Result<T, E>,
527{
528 match parse(value.as_ref()) {
529 Ok(value) => Ok(value),
530 Err(_) => Err(UrlError::InvalidParamValue {
531 param: param.into(),
532 value: value.into(),
533 }),
534 }
535}
536
537pub fn get_username_from_url(url: &Url) -> Option<&str> {
538 let user = url.username();
539 if user.is_empty() {
540 return None;
541 }
542 Some(user)
543}
544
545pub fn get_password_from_url(url: &Url) -> Option<&str> {
546 url.password()
547}
548
549pub fn get_database_from_url(url: &Url) -> Result<Option<&str>> {
550 match url.path_segments() {
551 None => Ok(None),
552 Some(mut segments) => {
553 let head = segments.next();
554
555 if segments.next().is_some() {
556 return Err(Error::Url(UrlError::Invalid));
557 }
558
559 match head {
560 Some(database) if !database.is_empty() => Ok(Some(database)),
561 _ => Ok(None),
562 }
563 }
564 }
565}
566
/// Parses a duration written as an unsigned integer followed by a unit:
/// `s` (seconds) or `ms` (milliseconds), e.g. `5s`, `500ms`.
///
/// Returns `Err(())` for a missing number, an unknown or missing unit, or a
/// number that overflows `u64`.
// A unit error is enough here: `parse_param` discards the error value.
#[allow(clippy::result_unit_err)]
pub fn parse_duration(source: &str) -> std::result::Result<Duration, ()> {
    // The digit prefix is pure ASCII, so the char position equals the byte offset.
    let boundary = source
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(source.len());
    let (digits, unit) = source.split_at(boundary);

    let amount: u64 = digits.parse().map_err(|_| ())?;

    match unit {
        "s" => Ok(Duration::from_secs(amount)),
        "ms" => Ok(Duration::from_millis(amount)),
        _ => Err(()),
    }
}

/// Like [`parse_duration`], but the literal `none` yields `Ok(None)`.
// A unit error is enough here: `parse_param` discards the error value.
#[allow(clippy::result_unit_err)]
pub fn parse_opt_duration(source: &str) -> std::result::Result<Option<Duration>, ()> {
    match source {
        "none" => Ok(None),
        other => parse_duration(other).map(Some),
    }
}
595
/// Parses an optional `u8`: the literal `none` yields `Ok(None)`, anything
/// else must parse as a `u8`.
// A unit error is enough here: `parse_param` discards the error value.
#[allow(clippy::result_unit_err)]
pub fn parse_opt_u8(source: &str) -> std::result::Result<Option<u8>, ()> {
    if source == "none" {
        Ok(None)
    } else {
        source.parse::<u8>().map(Some).map_err(|_| ())
    }
}
609
/// Maps the `compression` URL parameter to a flag: `lz4` → `true`,
/// `none` → `false`; any other value is rejected.
// A unit error is enough here: `parse_param` discards the error value.
#[allow(clippy::result_unit_err)]
pub fn parse_compression(source: &str) -> std::result::Result<bool, ()> {
    if source == "lz4" {
        Ok(true)
    } else if source == "none" {
        Ok(false)
    } else {
        Err(())
    }
}
618
619#[allow(clippy::result_unit_err)]
620pub fn parse_hosts(source: &str) -> std::result::Result<Vec<Url>, ()> {
621 let mut result = Vec::new();
622 for host in source.split(',') {
623 match Url::from_str(&format!("tcp://{}", host)) {
624 Ok(url) => result.push(url),
625 Err(_) => return Err(()),
626 }
627 }
628 Ok(result)
629}