Skip to main content

clickhouse_rs/types/
options.rs

1#[cfg(feature = "tls")]
2use std::convert;
3
4use std::{
5    borrow::Cow,
6    fmt,
7    str::FromStr,
8    sync::{Arc, Mutex},
9    time::Duration,
10    fmt::Debug,
11};
12
13use crate::errors::{Error, Result, UrlError};
14use url::Url;
15
16const DEFAULT_MIN_CONNS: usize = 10;
17
18const DEFAULT_MAX_CONNS: usize = 20;
19
20#[derive(Debug)]
21#[allow(clippy::large_enum_variant)]
22enum State {
23    Raw(Options),
24    Url(String),
25}
26
27#[derive(Clone)]
28pub struct OptionsSource {
29    state: Arc<Mutex<State>>,
30}
31
32impl fmt::Debug for OptionsSource {
33    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34        let guard = self.state.lock().unwrap();
35        match *guard {
36            State::Url(ref url) => write!(f, "Url({})", url),
37            State::Raw(ref options) => write!(f, "{:?}", options),
38        }
39    }
40}
41
42impl OptionsSource {
43    pub(crate) fn get(&self) -> Result<Cow<Options>> {
44        let mut state = self.state.lock().unwrap();
45        loop {
46            let new_state;
47            match &*state {
48                State::Raw(ref options) => {
49                    let ptr = options as *const Options;
50                    return unsafe { Ok(Cow::Borrowed(ptr.as_ref().unwrap())) };
51                }
52                State::Url(url) => {
53                    let options = from_url(&url)?;
54                    new_state = State::Raw(options);
55                }
56            }
57            *state = new_state;
58        }
59    }
60}
61
62impl Default for OptionsSource {
63    fn default() -> Self {
64        Self {
65            state: Arc::new(Mutex::new(State::Raw(Options::default()))),
66        }
67    }
68}
69
70pub trait IntoOptions {
71    fn into_options_src(self) -> OptionsSource;
72}
73
74impl IntoOptions for Options {
75    fn into_options_src(self) -> OptionsSource {
76        OptionsSource {
77            state: Arc::new(Mutex::new(State::Raw(self))),
78        }
79    }
80}
81
82impl IntoOptions for &str {
83    fn into_options_src(self) -> OptionsSource {
84        OptionsSource {
85            state: Arc::new(Mutex::new(State::Url(self.into()))),
86        }
87    }
88}
89
90impl IntoOptions for String {
91    fn into_options_src(self) -> OptionsSource {
92        OptionsSource {
93            state: Arc::new(Mutex::new(State::Url(self))),
94        }
95    }
96}
97
/// An X509 certificate.
///
/// Cheap to clone: the parsed `native_tls` certificate is shared through an `Arc`.
#[cfg(feature = "tls")]
#[derive(Clone)]
pub struct Certificate(Arc<native_tls::Certificate>);

#[cfg(feature = "tls")]
impl Certificate {
    /// Parses a DER-formatted X509 certificate.
    ///
    /// # Errors
    /// Returns `Error::Other` carrying the `native_tls` error message when `der` cannot
    /// be parsed.
    pub fn from_der(der: &[u8]) -> Result<Certificate> {
        native_tls::Certificate::from_der(der)
            .map(|inner| Certificate(Arc::new(inner)))
            .map_err(|err| Error::Other(err.to_string().into()))
    }

    /// Parses a PEM-formatted X509 certificate.
    ///
    /// # Errors
    /// Returns `Error::Other` carrying the `native_tls` error message when `pem` cannot
    /// be parsed.
    pub fn from_pem(pem: &[u8]) -> Result<Certificate> {
        native_tls::Certificate::from_pem(pem)
            .map(|inner| Certificate(Arc::new(inner)))
            .map_err(|err| Error::Other(err.to_string().into()))
    }
}
123
#[cfg(feature = "tls")]
impl fmt::Debug for Certificate {
    /// Prints a fixed placeholder; the certificate contents are never rendered.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("[Certificate]")
    }
}
130
#[cfg(feature = "tls")]
impl PartialEq for Certificate {
    /// Two certificates are equal when they share the same allocation or, failing that,
    /// when their DER encodings are byte-for-byte identical.
    ///
    /// `native_tls::Certificate` has no equality of its own; comparing DER encodings
    /// keeps `Options` equality from ignoring which certificate is configured.
    fn eq(&self, other: &Self) -> bool {
        if Arc::ptr_eq(&self.0, &other.0) {
            return true;
        }
        match (self.0.to_der(), other.0.to_der()) {
            (Ok(lhs), Ok(rhs)) => lhs == rhs,
            // Without both encodings, distinct certificates cannot be shown to be equal.
            _ => false,
        }
    }
}
137
#[cfg(feature = "tls")]
impl convert::From<Certificate> for native_tls::Certificate {
    /// Produces an owned `native_tls` certificate by cloning it out of the shared `Arc`.
    fn from(value: Certificate) -> Self {
        let Certificate(shared) = value;
        native_tls::Certificate::clone(&shared)
    }
}
144
145/// Clickhouse connection options.
146#[derive(Clone, PartialEq)]
147pub struct Options {
148    /// Address of clickhouse server (defaults to `127.0.0.1:9000`).
149    pub(crate) addr: Url,
150
151    /// Database name. (defaults to `default`).
152    pub(crate) database: String,
153    /// User name (defaults to `default`).
154    pub(crate) username: String,
155    /// Access password (defaults to `""`).
156    pub(crate) password: String,
157
158    /// Enable compression (defaults to `false`).
159    pub(crate) compression: bool,
160
161    /// Lower bound of opened connections for `Pool` (defaults to 10).
162    pub(crate) pool_min: usize,
163    /// Upper bound of opened connections for `Pool` (defaults to 20).
164    pub(crate) pool_max: usize,
165
166    /// Whether to enable `TCP_NODELAY` (defaults to `true`).
167    pub(crate) nodelay: bool,
168    /// TCP keep alive timeout in milliseconds (defaults to `None`).
169    pub(crate) keepalive: Option<Duration>,
170
171    /// Ping server every time before execute any query. (defaults to `true`)
172    pub(crate) ping_before_query: bool,
173    /// Count of retry to send request to server. (defaults to `3`)
174    pub(crate) send_retries: usize,
175    /// Amount of time to wait before next retry. (defaults to `5 sec`)
176    pub(crate) retry_timeout: Duration,
177    /// Timeout for ping (defaults to `500 ms`)
178    pub(crate) ping_timeout: Duration,
179
180    /// Timeout for connection (defaults to `500 ms`)
181    pub(crate) connection_timeout: Duration,
182
183    /// Timeout for queries (defaults to `180 sec`)
184    pub(crate) query_timeout: Option<Duration>,
185
186    /// Timeout for each block in a query (defaults to `180 sec`)
187    pub(crate) query_block_timeout: Option<Duration>,
188
189    /// Timeout for inserts (defaults to `180 sec`)
190    pub(crate) insert_timeout: Option<Duration>,
191
192    /// Timeout for execute (defaults to `180 sec`)
193    pub(crate) execute_timeout: Option<Duration>,
194
195    /// Enable TLS encryption (defaults to `false`)
196    #[cfg(feature = "tls")]
197    pub(crate) secure: bool,
198
199    /// Skip certificate verification (default is `false`).
200    #[cfg(feature = "tls")]
201    pub(crate) skip_verify: bool,
202
203    /// An X509 certificate.
204    #[cfg(feature = "tls")]
205    pub(crate) certificate: Option<Certificate>,
206
207    /// Restricts permissions for read data, write data and change settings queries.
208    pub(crate) readonly: Option<u8>,
209
210    /// Comma separated list of single address host for load-balancing.
211    pub(crate) alt_hosts: Vec<Url>,
212}
213
214impl fmt::Debug for Options {
215    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216        f.debug_struct("Options")
217            .field("addr", &self.addr)
218            .field("database", &self.database)
219            .field("compression", &self.compression)
220            .field("pool_min", &self.pool_min)
221            .field("pool_max", &self.pool_max)
222            .field("nodelay", &self.nodelay)
223            .field("keepalive", &self.keepalive)
224            .field("ping_before_query", &self.ping_before_query)
225            .field("send_retries", &self.send_retries)
226            .field("retry_timeout", &self.retry_timeout)
227            .field("ping_timeout", &self.ping_timeout)
228            .field("connection_timeout", &self.connection_timeout)
229            .field("query_timeout", &self.query_timeout)
230            .field("query_block_timeout", &self.query_block_timeout)
231            .field("insert_timeout", &self.insert_timeout)
232            .field("execute_timeout", &self.execute_timeout)
233            .field("readonly", &self.readonly)
234            .field("alt_hosts", &self.alt_hosts)
235            .finish()
236    }
237}
238
239impl Default for Options {
240    fn default() -> Self {
241        Self {
242            addr: Url::parse("tcp://default@127.0.0.1:9000").unwrap(),
243            database: "default".into(),
244            username: "default".into(),
245            password: "".into(),
246            compression: false,
247            pool_min: DEFAULT_MIN_CONNS,
248            pool_max: DEFAULT_MAX_CONNS,
249            nodelay: true,
250            keepalive: None,
251            ping_before_query: true,
252            send_retries: 3,
253            retry_timeout: Duration::from_secs(5),
254            ping_timeout: Duration::from_millis(500),
255            connection_timeout: Duration::from_millis(500),
256            query_timeout: Some(Duration::from_secs(180)),
257            query_block_timeout: Some(Duration::from_secs(180)),
258            insert_timeout: Some(Duration::from_secs(180)),
259            execute_timeout: Some(Duration::from_secs(180)),
260            #[cfg(feature = "tls")]
261            secure: false,
262            #[cfg(feature = "tls")]
263            skip_verify: false,
264            #[cfg(feature = "tls")]
265            certificate: None,
266            readonly: None,
267            alt_hosts: Vec::new(),
268        }
269    }
270}
271
/// Generates a by-value builder setter `fn name(self, name: Type) -> Self` that stores
/// `name.into()` into the field of the same name.
///
/// Two forms are accepted: `property! { name: Type }`, and
/// `property! { #[attrs]... => name: Type }`, which forwards the attributes (typically
/// doc comments) onto the generated method.
macro_rules! property {
    ( $name:ident: $ty:ty ) => {
        property! { => $name: $ty }
    };
    ( $(#[$meta:meta])* => $name:ident: $ty:ty ) => {
        $(#[$meta])*
        pub fn $name(self, $name: $ty) -> Self {
            let mut updated = self;
            updated.$name = $name.into();
            updated
        }
    };
}
291
292impl Options {
293    /// Constructs a new Options.
294    pub fn new<A>(addr: A) -> Self
295    where
296        A: Into<Url>,
297    {
298        Self {
299            addr: addr.into(),
300            ..Self::default()
301        }
302    }
303
304    property! {
305        /// Database name. (defaults to `default`).
306        => database: &str
307    }
308
309    property! {
310        /// User name (defaults to `default`).
311        => username: &str
312    }
313
314    property! {
315        /// Access password (defaults to `""`).
316        => password: &str
317    }
318
319    /// Enable compression (defaults to `false`).
320    pub fn with_compression(self) -> Self {
321        Self {
322            compression: true,
323            ..self
324        }
325    }
326
327    property! {
328        /// Lower bound of opened connections for `Pool` (defaults to `10`).
329        => pool_min: usize
330    }
331
332    property! {
333        /// Upper bound of opened connections for `Pool` (defaults to `20`).
334        => pool_max: usize
335    }
336
337    property! {
338        /// Whether to enable `TCP_NODELAY` (defaults to `true`).
339        => nodelay: bool
340    }
341
342    property! {
343        /// TCP keep alive timeout in milliseconds (defaults to `None`).
344        => keepalive: Option<Duration>
345    }
346
347    property! {
348        /// Ping server every time before execute any query. (defaults to `true`).
349        => ping_before_query: bool
350    }
351
352    property! {
353        /// Count of retry to send request to server. (defaults to `3`).
354        => send_retries: usize
355    }
356
357    property! {
358        /// Amount of time to wait before next retry. (defaults to `5 sec`).
359        => retry_timeout: Duration
360    }
361
362    property! {
363        /// Timeout for ping (defaults to `500 ms`).
364        => ping_timeout: Duration
365    }
366
367    property! {
368        /// Timeout for connection (defaults to `500 ms`).
369        => connection_timeout: Duration
370    }
371
372    property! {
373        /// Timeout for query (defaults to `180,000 ms`).
374        => query_timeout: Duration
375    }
376
377    property! {
378        /// Timeout for each block in a query (defaults to `180,000 ms`).
379        => query_block_timeout: Duration
380    }
381
382    property! {
383        /// Timeout for insert (defaults to `180,000 ms`).
384        => insert_timeout: Option<Duration>
385    }
386
387    property! {
388        /// Timeout for execute (defaults to `180 sec`).
389        => execute_timeout: Option<Duration>
390    }
391
392    #[cfg(feature = "tls")]
393    property! {
394        /// Establish secure connection (default is `false`).
395        => secure: bool
396    }
397
398    #[cfg(feature = "tls")]
399    property! {
400        /// Skip certificate verification (default is `false`).
401        => skip_verify: bool
402    }
403
404    #[cfg(feature = "tls")]
405    property! {
406        /// An X509 certificate.
407        => certificate: Option<Certificate>
408    }
409
410    property! {
411        /// Restricts permissions for read data, write data and change settings queries.
412        => readonly: Option<u8>
413    }
414
415    property! {
416        /// Comma separated list of single address host for load-balancing.
417        => alt_hosts: Vec<Url>
418    }
419}
420
421impl FromStr for Options {
422    type Err = Error;
423
424    fn from_str(url: &str) -> Result<Self> {
425        from_url(url)
426    }
427}
428
429fn from_url(url_str: &str) -> Result<Options> {
430    let url = Url::parse(url_str)?;
431
432    if url.scheme() != "tcp" {
433        return Err(UrlError::UnsupportedScheme {
434            scheme: url.scheme().to_string(),
435        }
436        .into());
437    }
438
439    if url.cannot_be_a_base() || !url.has_host() {
440        return Err(UrlError::Invalid.into());
441    }
442
443    let mut options = Options::default();
444
445    if let Some(username) = get_username_from_url(&url)? {
446        options.username = username.into();
447    }
448
449    if let Some(password) = get_password_from_url(&url)? {
450        options.password = password.into()
451    }
452
453    let mut addr = url.clone();
454    addr.set_path("");
455    addr.set_query(None);
456
457    let port = url.port().or(Some(9000));
458    addr.set_port(port).map_err(|_| UrlError::Invalid)?;
459    options.addr = addr;
460
461    if let Some(database) = get_database_from_url(&url)? {
462        options.database = database.into();
463    }
464
465    set_params(&mut options, url.query_pairs())?;
466
467    Ok(options)
468}
469
470fn set_params<'a, I>(options: &mut Options, iter: I) -> std::result::Result<(), UrlError>
471where
472    I: Iterator<Item = (Cow<'a, str>, Cow<'a, str>)>,
473{
474    for (key, value) in iter {
475        match key.as_ref() {
476            "pool_min" => options.pool_min = parse_param(key, value, usize::from_str)?,
477            "pool_max" => options.pool_max = parse_param(key, value, usize::from_str)?,
478            "nodelay" => options.nodelay = parse_param(key, value, bool::from_str)?,
479            "keepalive" => options.keepalive = parse_param(key, value, parse_opt_duration)?,
480            "ping_before_query" => {
481                options.ping_before_query = parse_param(key, value, bool::from_str)?
482            }
483            "send_retries" => options.send_retries = parse_param(key, value, usize::from_str)?,
484            "retry_timeout" => options.retry_timeout = parse_param(key, value, parse_duration)?,
485            "ping_timeout" => options.ping_timeout = parse_param(key, value, parse_duration)?,
486            "connection_timeout" => {
487                options.connection_timeout = parse_param(key, value, parse_duration)?
488            }
489            "query_timeout" => options.query_timeout = parse_param(key, value, parse_opt_duration)?,
490            "query_block_timeout" => {
491                options.query_block_timeout = parse_param(key, value, parse_opt_duration)?
492            }
493            "insert_timeout" => {
494                options.insert_timeout = parse_param(key, value, parse_opt_duration)?
495            }
496            "execute_timeout" => {
497                options.execute_timeout = parse_param(key, value, parse_opt_duration)?
498            }
499            "compression" => options.compression = parse_param(key, value, parse_compression)?,
500            #[cfg(feature = "tls")]
501            "secure" => options.secure = parse_param(key, value, bool::from_str)?,
502            #[cfg(feature = "tls")]
503            "skip_verify" => options.skip_verify = parse_param(key, value, bool::from_str)?,
504            "readonly" => options.readonly = parse_param(key, value, parse_opt_u8)?,
505            "alt_hosts" => options.alt_hosts = parse_param(key, value, parse_hosts)?,
506            _ => return Err(UrlError::UnknownParameter { param: key.into() }),
507        };
508    }
509
510    Ok(())
511}
512
513fn parse_param<'a, F, T, E>(
514    param: Cow<'a, str>,
515    value: Cow<'a, str>,
516    parse: F,
517) -> std::result::Result<T, UrlError>
518where
519    F: Fn(&str) -> std::result::Result<T, E>,
520{
521    match parse(value.as_ref()) {
522        Ok(value) => Ok(value),
523        Err(_) => Err(UrlError::InvalidParamValue {
524            param: param.into(),
525            value: value.into(),
526        }),
527    }
528}
529
530fn get_username_from_url(url: &Url) -> Result<Option<&str>> {
531    let user = url.username();
532    if user.is_empty() {
533        return Ok(None);
534    }
535    Ok(Some(user))
536}
537
538fn get_password_from_url(url: &Url) -> Result<Option<&str>> {
539    match url.password() {
540        None => Ok(None),
541        Some(password) => Ok(Some(password)),
542    }
543}
544
545fn get_database_from_url(url: &Url) -> Result<Option<&str>> {
546    match url.path_segments() {
547        None => Ok(None),
548        Some(mut segments) => {
549            let head = segments.next();
550
551            if segments.next().is_some() {
552                return Err(Error::Url(UrlError::Invalid));
553            }
554
555            match head {
556                Some(database) if !database.is_empty() => Ok(Some(database)),
557                _ => Ok(None),
558            }
559        }
560    }
561}
562
/// Parses a duration written as an unsigned decimal number immediately followed by a
/// unit: `s` (seconds) or `ms` (milliseconds), e.g. `"5s"` or `"500ms"`.
///
/// Returns `Err(())` for a missing or out-of-range number, or for any other suffix.
fn parse_duration(source: &str) -> std::result::Result<Duration, ()> {
    // The leading digits are ASCII, so the byte offset of the first non-digit is also
    // where the unit starts.
    let unit_start = source
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or_else(|| source.len());
    let (amount, unit) = source.split_at(unit_start);

    let amount: u64 = amount.parse().map_err(|_| ())?;

    match unit {
        "s" => Ok(Duration::from_secs(amount)),
        "ms" => Ok(Duration::from_millis(amount)),
        _ => Err(()),
    }
}
580
581fn parse_opt_duration(source: &str) -> std::result::Result<Option<Duration>, ()> {
582    if source == "none" {
583        return Ok(None);
584    }
585
586    let duration = parse_duration(source)?;
587    Ok(Some(duration))
588}
589
/// Parses an optional `u8`: the literal `none` yields `Ok(None)`, and anything else must
/// parse as a `u8`. Otherwise the result is `Err(())`.
fn parse_opt_u8(source: &str) -> std::result::Result<Option<u8>, ()> {
    if source == "none" {
        Ok(None)
    } else {
        source.parse::<u8>().map(Some).map_err(|_| ())
    }
}
602
/// Maps the `compression` URL value to a flag: `lz4` enables compression, `none`
/// disables it, and anything else is `Err(())`.
fn parse_compression(source: &str) -> std::result::Result<bool, ()> {
    if source == "lz4" {
        Ok(true)
    } else if source == "none" {
        Ok(false)
    } else {
        Err(())
    }
}
610
611fn parse_hosts(source: &str) -> std::result::Result<Vec<Url>, ()> {
612    let mut result = Vec::new();
613    for host in source.split(',') {
614        match Url::from_str(&format!("tcp://{}", host)) {
615            Ok(url) => result.push(url),
616            Err(_) => return Err(()),
617        }
618    }
619    Ok(result)
620}
621
#[cfg(test)]
mod test {
    use super::*;

    /// Query string shared by the option-parsing tests.
    const EXAMPLE_QUERY: &str =
        "ping_timeout=42ms&keepalive=99s&compression=lz4&connection_timeout=10s";

    /// The `Options` expected from
    /// `tcp://username:password@host1:9001/database?` followed by `EXAMPLE_QUERY`.
    fn example_options() -> Options {
        Options {
            username: "username".into(),
            password: "password".into(),
            addr: Url::parse("tcp://username:password@host1:9001").unwrap(),
            database: "database".into(),
            keepalive: Some(Duration::from_secs(99)),
            ping_timeout: Duration::from_millis(42),
            connection_timeout: Duration::from_secs(10),
            compression: true,
            ..Options::default()
        }
    }

    #[test]
    fn test_parse_hosts() {
        let expected: Vec<Url> = ["tcp://host2:9000", "tcp://host3:9000"]
            .iter()
            .map(|url| Url::from_str(url).unwrap())
            .collect();
        assert_eq!(parse_hosts("host2:9000,host3:9000").unwrap(), expected);
    }

    #[test]
    fn test_parse_default() {
        let Options {
            database,
            username,
            password,
            ..
        } = from_url("tcp://host1").unwrap();
        assert_eq!(database, "default");
        assert_eq!(username, "default");
        assert_eq!(password, "");
    }

    #[test]
    #[cfg(feature = "tls")]
    fn test_parse_secure_options() {
        let url = format!(
            "tcp://username:password@host1:9001/database?{}&secure=true",
            EXAMPLE_QUERY
        );
        let expected = Options {
            secure: true,
            ..example_options()
        };
        assert_eq!(expected, from_url(&url).unwrap());
    }

    #[test]
    fn test_parse_options() {
        let url = format!(
            "tcp://username:password@host1:9001/database?{}",
            EXAMPLE_QUERY
        );
        assert_eq!(example_options(), from_url(&url).unwrap());
    }

    #[test]
    fn test_parse_invalid_url() {
        assert!(from_url("ʘ_ʘ").is_err());
    }

    #[test]
    fn test_parse_with_unknown_url() {
        assert!(from_url("tcp://localhost:9000/foo?bar=baz").is_err());
    }

    #[test]
    fn test_parse_with_multi_databases() {
        assert!(from_url("tcp://localhost:9000/foo/bar").is_err());
    }

    #[test]
    fn test_parse_duration() {
        assert_eq!(parse_duration("3s"), Ok(Duration::from_secs(3)));
        assert_eq!(parse_duration("123ms"), Ok(Duration::from_millis(123)));
        assert!(parse_duration("ms").is_err());
        assert!(parse_duration("1ss").is_err());
    }

    #[test]
    fn test_parse_opt_duration() {
        assert_eq!(parse_opt_duration("3s"), Ok(Some(Duration::from_secs(3))));
        assert_eq!(parse_opt_duration("none"), Ok(None));
    }

    #[test]
    fn test_parse_compression() {
        assert_eq!(parse_compression("none"), Ok(false));
        assert_eq!(parse_compression("lz4"), Ok(true));
        assert!(parse_compression("?").is_err());
    }
}