opensrv_clickhouse/types/
options.rs

1// Copyright 2021 Datafuse Labs.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16#[cfg(feature = "tls")]
17use std::convert;
18use std::fmt;
19#[cfg(feature = "tls")]
20use std::fmt::Formatter;
21use std::str::FromStr;
22use std::sync::Arc;
23use std::sync::Mutex;
24use std::time::Duration;
25
26#[cfg(feature = "tls")]
27use tokio_native_tls::native_tls;
28use url::Url;
29
30use crate::errors::Error;
31use crate::errors::Result;
32use crate::errors::UrlError;
33
/// Default lower bound of opened connections for `Pool` (used by `Options::default`).
const DEFAULT_MIN_CONNS: usize = 10;
/// Default upper bound of opened connections for `Pool` (used by `Options::default`).
const DEFAULT_MAX_CONNS: usize = 20;
37
38#[derive(Debug)]
39#[allow(clippy::large_enum_variant)]
40enum State {
41    Raw(Options),
42    Url(String),
43}
44
45#[derive(Clone)]
46pub struct OptionsSource {
47    state: Arc<Mutex<State>>,
48}
49
50impl fmt::Debug for OptionsSource {
51    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52        let guard = self.state.lock().unwrap();
53        match &*guard {
54            State::Url(ref url) => write!(f, "Url({})", url),
55            State::Raw(ref options) => write!(f, "{:?}", options),
56        }
57    }
58}
59
60#[allow(dead_code)]
61impl OptionsSource {
62    pub(crate) fn get(&self) -> Result<Cow<Options>> {
63        let mut state = self.state.lock().unwrap();
64        loop {
65            let new_state = match &*state {
66                State::Raw(ref options) => {
67                    let ptr = options as *const Options;
68                    return unsafe { Ok(Cow::Borrowed(ptr.as_ref().unwrap())) };
69                }
70                State::Url(url) => {
71                    let options = from_url(url)?;
72                    State::Raw(options)
73                }
74            };
75            *state = new_state;
76        }
77    }
78}
79
80impl Default for OptionsSource {
81    fn default() -> Self {
82        Self {
83            state: Arc::new(Mutex::new(State::Raw(Options::default()))),
84        }
85    }
86}
87
88pub trait IntoOptions {
89    fn into_options_src(self) -> OptionsSource;
90}
91
92impl IntoOptions for Options {
93    fn into_options_src(self) -> OptionsSource {
94        OptionsSource {
95            state: Arc::new(Mutex::new(State::Raw(self))),
96        }
97    }
98}
99
100impl IntoOptions for &str {
101    fn into_options_src(self) -> OptionsSource {
102        OptionsSource {
103            state: Arc::new(Mutex::new(State::Url(self.into()))),
104        }
105    }
106}
107
108impl IntoOptions for String {
109    fn into_options_src(self) -> OptionsSource {
110        OptionsSource {
111            state: Arc::new(Mutex::new(State::Url(self))),
112        }
113    }
114}
115
/// An X509 certificate.
///
/// The `native_tls` certificate is kept behind an `Arc`, so cloning this wrapper only bumps
/// a reference count.
#[cfg(feature = "tls")]
#[derive(Clone)]
pub struct Certificate(Arc<native_tls::Certificate>);
120
#[cfg(feature = "tls")]
impl Certificate {
    /// Parses a DER-formatted X509 certificate.
    ///
    /// # Errors
    ///
    /// Returns `Error::Other` with the TLS backend's message if `der` is not a valid
    /// DER-encoded certificate.
    pub fn from_der(der: &[u8]) -> Result<Certificate> {
        let inner = native_tls::Certificate::from_der(der)
            .map_err(|err| Error::Other(err.to_string().into()))?;
        Ok(Certificate(Arc::new(inner)))
    }

    /// Parses a PEM-formatted X509 certificate.
    ///
    /// # Errors
    ///
    /// Returns `Error::Other` with the TLS backend's message if `pem` is not a valid
    /// PEM-encoded certificate.
    pub fn from_pem(pem: &[u8]) -> Result<Certificate> {
        let inner = native_tls::Certificate::from_pem(pem)
            .map_err(|err| Error::Other(err.to_string().into()))?;
        Ok(Certificate(Arc::new(inner)))
    }
}
141
#[cfg(feature = "tls")]
impl fmt::Debug for Certificate {
    /// Prints a fixed placeholder; the certificate contents are never shown.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.write_str("[Certificate]")
    }
}
148
#[cfg(feature = "tls")]
impl PartialEq for Certificate {
    /// Two certificates are equal when they share the same allocation or have identical
    /// DER encodings.
    ///
    /// If either DER encoding cannot be produced, the certificates are treated as different
    /// (unless they are the same shared instance, which keeps `a == a` true).
    fn eq(&self, other: &Self) -> bool {
        if Arc::ptr_eq(&self.0, &other.0) {
            return true;
        }
        match (self.0.to_der(), other.0.to_der()) {
            (Ok(lhs), Ok(rhs)) => lhs == rhs,
            _ => false,
        }
    }
}
155
#[cfg(feature = "tls")]
impl convert::From<Certificate> for native_tls::Certificate {
    /// Extracts a copy of the wrapped `native_tls` certificate.
    fn from(value: Certificate) -> Self {
        let Certificate(shared) = value;
        (*shared).clone()
    }
}
162
163/// Clickhouse connection options.
164// the trait `std::cmp::Eq` is not implemented for `types::options::Certificate`
165// so we can't use `derive(Eq)` in tls feature.
166#[allow(clippy::derive_partial_eq_without_eq)]
167#[derive(Clone, PartialEq)]
168pub struct Options {
169    /// Address of clickhouse server (defaults to `127.0.0.1:9000`).
170    pub addr: Url,
171
172    /// Database name. (defaults to `default`).
173    pub database: String,
174    /// User name (defaults to `default`).
175    pub username: String,
176    /// Access password (defaults to `""`).
177    pub password: String,
178
179    /// Enable compression (defaults to `false`).
180    pub compression: bool,
181
182    /// Lower bound of opened connections for `Pool` (defaults to 10).
183    pub pool_min: usize,
184    /// Upper bound of opened connections for `Pool` (defaults to 20).
185    pub pool_max: usize,
186
187    /// Whether to enable `TCP_NODELAY` (defaults to `true`).
188    pub nodelay: bool,
189    /// TCP keep alive timeout in milliseconds (defaults to `None`).
190    pub keepalive: Option<Duration>,
191
192    /// Ping server every time before execute any query. (defaults to `true`)
193    pub ping_before_query: bool,
194    /// Count of retry to send request to server. (defaults to `3`)
195    pub send_retries: usize,
196    /// Amount of time to wait before next retry. (defaults to `5 sec`)
197    pub retry_timeout: Duration,
198    /// Timeout for ping (defaults to `500 ms`)
199    pub ping_timeout: Duration,
200
201    /// Timeout for connection (defaults to `500 ms`)
202    pub connection_timeout: Duration,
203
204    /// Timeout for queries (defaults to `180 sec`)
205    pub query_timeout: Duration,
206
207    /// Timeout for inserts (defaults to `180 sec`)
208    pub insert_timeout: Option<Duration>,
209
210    /// Timeout for execute (defaults to `180 sec`)
211    pub execute_timeout: Option<Duration>,
212
213    /// Enable TLS encryption (defaults to `false`)
214    #[cfg(feature = "tls")]
215    pub secure: bool,
216
217    /// Skip certificate verification (default is `false`).
218    #[cfg(feature = "tls")]
219    pub skip_verify: bool,
220
221    /// An X509 certificate.
222    #[cfg(feature = "tls")]
223    pub certificate: Option<Certificate>,
224
225    /// Restricts permissions for read data, write data and change settings queries.
226    pub readonly: Option<u8>,
227
228    /// Comma separated list of single address host for load-balancing.
229    pub alt_hosts: Vec<Url>,
230}
231
232impl fmt::Debug for Options {
233    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
234        f.debug_struct("Options")
235            .field("addr", &self.addr)
236            .field("database", &self.database)
237            .field("compression", &self.compression)
238            .field("pool_min", &self.pool_min)
239            .field("pool_max", &self.pool_max)
240            .field("nodelay", &self.nodelay)
241            .field("keepalive", &self.keepalive)
242            .field("ping_before_query", &self.ping_before_query)
243            .field("send_retries", &self.send_retries)
244            .field("retry_timeout", &self.retry_timeout)
245            .field("ping_timeout", &self.ping_timeout)
246            .field("connection_timeout", &self.connection_timeout)
247            .field("readonly", &self.readonly)
248            .field("alt_hosts", &self.alt_hosts)
249            .finish()
250    }
251}
252
253impl Default for Options {
254    fn default() -> Self {
255        Self {
256            addr: Url::parse("tcp://default@127.0.0.1:9000").unwrap(),
257            database: "default".into(),
258            username: "default".into(),
259            password: "".into(),
260            compression: false,
261            pool_min: DEFAULT_MIN_CONNS,
262            pool_max: DEFAULT_MAX_CONNS,
263            nodelay: true,
264            keepalive: None,
265            ping_before_query: true,
266            send_retries: 3,
267            retry_timeout: Duration::from_secs(5),
268            ping_timeout: Duration::from_millis(500),
269            connection_timeout: Duration::from_millis(500),
270            query_timeout: Duration::from_secs(180),
271            insert_timeout: Some(Duration::from_secs(180)),
272            execute_timeout: Some(Duration::from_secs(180)),
273            #[cfg(feature = "tls")]
274            secure: false,
275            #[cfg(feature = "tls")]
276            skip_verify: false,
277            #[cfg(feature = "tls")]
278            certificate: None,
279            readonly: None,
280            alt_hosts: Vec::new(),
281        }
282    }
283}
284
/// Generates a consuming builder setter `fn $k(self, $k: $t) -> Self` that stores
/// `$k.into()` into the field of the same name.
///
/// Both forms produce a `#[must_use]` setter; the `=>` form additionally forwards the
/// attributes (typically doc comments) written before it.
macro_rules! property {
    // Shorthand without attributes: forward to the attributed form so that every generated
    // setter gets `#[must_use]`.
    ( $k:ident: $t:ty ) => {
        property! { => $k: $t }
    };
    ( $(#[$attr:meta])* => $k:ident: $t:ty ) => {
        $(#[$attr])*
        #[must_use]
        pub fn $k(self, $k: $t) -> Self {
            Self {
                $k: $k.into(),
                ..self
            }
        }
    };
}
305
306impl Options {
307    /// Constructs a new Options.
308    pub fn new<A>(addr: A) -> Self
309    where
310        A: Into<Url>,
311    {
312        Self {
313            addr: addr.into(),
314            ..Self::default()
315        }
316    }
317
318    property! {
319        /// Database name. (defaults to `default`).
320        => database: &str
321    }
322
323    property! {
324        /// User name (defaults to `default`).
325        => username: &str
326    }
327
328    property! {
329        /// Access password (defaults to `""`).
330        => password: &str
331    }
332
333    /// Enable compression (defaults to `false`).
334    #[must_use]
335    pub fn with_compression(self) -> Self {
336        Self {
337            compression: true,
338            ..self
339        }
340    }
341
342    property! {
343        /// Lower bound of opened connections for `Pool` (defaults to `10`).
344        => pool_min: usize
345    }
346
347    property! {
348        /// Upper bound of opened connections for `Pool` (defaults to `20`).
349        => pool_max: usize
350    }
351
352    property! {
353        /// Whether to enable `TCP_NODELAY` (defaults to `true`).
354        => nodelay: bool
355    }
356
357    property! {
358        /// TCP keep alive timeout in milliseconds (defaults to `None`).
359        => keepalive: Option<Duration>
360    }
361
362    property! {
363        /// Ping server every time before execute any query. (defaults to `true`).
364        => ping_before_query: bool
365    }
366
367    property! {
368        /// Count of retry to send request to server. (defaults to `3`).
369        => send_retries: usize
370    }
371
372    property! {
373        /// Amount of time to wait before next retry. (defaults to `5 sec`).
374        => retry_timeout: Duration
375    }
376
377    property! {
378        /// Timeout for ping (defaults to `500 ms`).
379        => ping_timeout: Duration
380    }
381
382    property! {
383        /// Timeout for connection (defaults to `500 ms`).
384        => connection_timeout: Duration
385    }
386
387    property! {
388        /// Timeout for query (defaults to `180,000 ms`).
389        => query_timeout: Duration
390    }
391
392    property! {
393        /// Timeout for insert (defaults to `180,000 ms`).
394        => insert_timeout: Option<Duration>
395    }
396
397    property! {
398        /// Timeout for execute (defaults to `180 sec`).
399        => execute_timeout: Option<Duration>
400    }
401
402    #[cfg(feature = "tls")]
403    property! {
404        /// Establish secure connection (default is `false`).
405        => secure: bool
406    }
407
408    #[cfg(feature = "tls")]
409    property! {
410        /// Skip certificate verification (default is `false`).
411        => skip_verify: bool
412    }
413
414    #[cfg(feature = "tls")]
415    property! {
416        /// An X509 certificate.
417        => certificate: Option<Certificate>
418    }
419
420    property! {
421        /// Restricts permissions for read data, write data and change settings queries.
422        => readonly: Option<u8>
423    }
424
425    property! {
426        /// Comma separated list of single address host for load-balancing.
427        => alt_hosts: Vec<Url>
428    }
429}
430
431impl FromStr for Options {
432    type Err = Error;
433
434    fn from_str(url: &str) -> Result<Self> {
435        from_url(url)
436    }
437}
438
439pub fn from_url(url_str: &str) -> Result<Options> {
440    let url = Url::parse(url_str)?;
441
442    if url.scheme() != "tcp" {
443        return Err(UrlError::UnsupportedScheme {
444            scheme: url.scheme().to_string(),
445        }
446        .into());
447    }
448
449    if url.cannot_be_a_base() || !url.has_host() {
450        return Err(UrlError::Invalid.into());
451    }
452
453    let mut options = Options::default();
454
455    if let Some(username) = get_username_from_url(&url) {
456        options.username = username.into();
457    }
458
459    if let Some(password) = get_password_from_url(&url) {
460        options.password = password.into()
461    }
462
463    let mut addr = url.clone();
464    addr.set_path("");
465    addr.set_query(None);
466
467    let port = url.port().or(Some(9000));
468    addr.set_port(port).map_err(|_| UrlError::Invalid)?;
469    options.addr = addr;
470
471    if let Some(database) = get_database_from_url(&url)? {
472        options.database = database.into();
473    }
474
475    set_params(&mut options, url.query_pairs())?;
476
477    Ok(options)
478}
479
480fn set_params<'a, I>(options: &mut Options, iter: I) -> std::result::Result<(), UrlError>
481where
482    I: Iterator<Item = (Cow<'a, str>, Cow<'a, str>)>,
483{
484    for (key, value) in iter {
485        match key.as_ref() {
486            "pool_min" => options.pool_min = parse_param(key, value, usize::from_str)?,
487            "pool_max" => options.pool_max = parse_param(key, value, usize::from_str)?,
488            "nodelay" => options.nodelay = parse_param(key, value, bool::from_str)?,
489            "keepalive" => options.keepalive = parse_param(key, value, parse_opt_duration)?,
490            "ping_before_query" => {
491                options.ping_before_query = parse_param(key, value, bool::from_str)?
492            }
493            "send_retries" => options.send_retries = parse_param(key, value, usize::from_str)?,
494            "retry_timeout" => options.retry_timeout = parse_param(key, value, parse_duration)?,
495            "ping_timeout" => options.ping_timeout = parse_param(key, value, parse_duration)?,
496            "connection_timeout" => {
497                options.connection_timeout = parse_param(key, value, parse_duration)?
498            }
499            "query_timeout" => options.query_timeout = parse_param(key, value, parse_duration)?,
500            "insert_timeout" => {
501                options.insert_timeout = parse_param(key, value, parse_opt_duration)?
502            }
503            "execute_timeout" => {
504                options.execute_timeout = parse_param(key, value, parse_opt_duration)?
505            }
506            "compression" => options.compression = parse_param(key, value, parse_compression)?,
507            #[cfg(feature = "tls")]
508            "secure" => options.secure = parse_param(key, value, bool::from_str)?,
509            #[cfg(feature = "tls")]
510            "skip_verify" => options.skip_verify = parse_param(key, value, bool::from_str)?,
511            "readonly" => options.readonly = parse_param(key, value, parse_opt_u8)?,
512            "alt_hosts" => options.alt_hosts = parse_param(key, value, parse_hosts)?,
513            _ => return Err(UrlError::UnknownParameter { param: key.into() }),
514        };
515    }
516
517    Ok(())
518}
519
520fn parse_param<'a, F, T, E>(
521    param: Cow<'a, str>,
522    value: Cow<'a, str>,
523    parse: F,
524) -> std::result::Result<T, UrlError>
525where
526    F: Fn(&str) -> std::result::Result<T, E>,
527{
528    match parse(value.as_ref()) {
529        Ok(value) => Ok(value),
530        Err(_) => Err(UrlError::InvalidParamValue {
531            param: param.into(),
532            value: value.into(),
533        }),
534    }
535}
536
537pub fn get_username_from_url(url: &Url) -> Option<&str> {
538    let user = url.username();
539    if user.is_empty() {
540        return None;
541    }
542    Some(user)
543}
544
545pub fn get_password_from_url(url: &Url) -> Option<&str> {
546    url.password()
547}
548
549pub fn get_database_from_url(url: &Url) -> Result<Option<&str>> {
550    match url.path_segments() {
551        None => Ok(None),
552        Some(mut segments) => {
553            let head = segments.next();
554
555            if segments.next().is_some() {
556                return Err(Error::Url(UrlError::Invalid));
557            }
558
559            match head {
560                Some(database) if !database.is_empty() => Ok(Some(database)),
561                _ => Ok(None),
562            }
563        }
564    }
565}
566
/// Parses a duration written as an unsigned integer followed by a unit suffix.
///
/// Accepted suffixes are `s` (seconds) and `ms` (milliseconds), e.g. `"5s"` or `"500ms"`.
///
/// # Errors
///
/// Returns `Err(())` when the numeric part is missing or does not fit in a `u64`,
/// or when the suffix is anything other than `s` / `ms`.
#[allow(clippy::result_unit_err)]
pub fn parse_duration(source: &str) -> std::result::Result<Duration, ()> {
    // Only ASCII digits form the number, so the byte offset of the first non-digit is
    // also a valid char boundary to split at.
    let unit_start = source
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(source.len());
    let (amount, unit) = source.split_at(unit_start);
    let amount: u64 = amount.parse().map_err(|_| ())?;

    match unit {
        "ms" => Ok(Duration::from_millis(amount)),
        "s" => Ok(Duration::from_secs(amount)),
        _ => Err(()),
    }
}

/// Parses an optional duration: the literal `none` yields `None`, anything else is
/// parsed with [`parse_duration`].
///
/// # Errors
///
/// Returns `Err(())` when the value is neither `none` nor a valid duration.
#[allow(clippy::result_unit_err)]
pub fn parse_opt_duration(source: &str) -> std::result::Result<Option<Duration>, ()> {
    match source {
        "none" => Ok(None),
        other => parse_duration(other).map(Some),
    }
}
595
/// Parses an optional `u8`: the literal `none` yields `None`, anything else must be a
/// valid `u8`.
///
/// # Errors
///
/// Returns `Err(())` when the value is neither `none` nor a number in `0..=255`.
#[allow(clippy::result_unit_err)]
pub fn parse_opt_u8(source: &str) -> std::result::Result<Option<u8>, ()> {
    match source {
        "none" => Ok(None),
        digits => digits.parse::<u8>().map(Some).map_err(|_| ()),
    }
}
609
/// Maps a compression method name to the `compression` flag.
///
/// `lz4` enables compression and `none` disables it.
///
/// # Errors
///
/// Returns `Err(())` for any other method name.
#[allow(clippy::result_unit_err)]
pub fn parse_compression(source: &str) -> std::result::Result<bool, ()> {
    if source == "lz4" {
        Ok(true)
    } else if source == "none" {
        Ok(false)
    } else {
        Err(())
    }
}
618
619#[allow(clippy::result_unit_err)]
620pub fn parse_hosts(source: &str) -> std::result::Result<Vec<Url>, ()> {
621    let mut result = Vec::new();
622    for host in source.split(',') {
623        match Url::from_str(&format!("tcp://{}", host)) {
624            Ok(url) => result.push(url),
625            Err(_) => return Err(()),
626        }
627    }
628    Ok(result)
629}