// clickhouse_driver/pool/options.rs

1use crate::errors::UrlError;
2use std::convert::TryFrom;
3use std::fmt;
4use std::{borrow::Cow, str::FromStr, time::Duration};
5use url::Url;
6
7type Result<T> = std::result::Result<T, UrlError>;
8
/// Default lower bound of opened connections kept by a `Pool` (`Options::pool_min`).
const DEFAULT_MIN_POOL_SIZE: u16 = 2;
/// Default upper bound of opened connections kept by a `Pool` (`Options::pool_max`).
const DEFAULT_MAX_POOL_SIZE: u16 = 10;
11
/// Block compression method.
///
/// At the moment clickhouse_driver supports only the LZ4 compression method,
/// which is the one used by default in Clickhouse 20.x.
#[derive(Copy, Clone, PartialEq, Debug)]
pub enum CompressionMethod {
    /// No compression.
    None,
    /// LZ4 compression.
    LZ4,
}

impl CompressionMethod {
    /// Returns `true` only for `CompressionMethod::None`.
    #[inline(always)]
    pub fn is_none(&self) -> bool {
        match *self {
            CompressionMethod::None => true,
            CompressionMethod::LZ4 => false,
        }
    }
}
26
27/// Clickhouse connection options.
28#[derive(Clone)]
29pub struct Options {
30    /// Address of clickhouse server (defaults to `127.0.0.1:9000`).
31    pub(crate) addr: Vec<String>,
32    /// Database name. (defaults to `default`).
33    pub(crate) database: String,
34    /// User name (defaults to `default`).
35    pub(crate) username: String,
36    /// Access password (defaults to `""`).
37    pub(crate) password: String,
38    /// Enable compression (defaults to `false`).
39    pub(crate) compression: CompressionMethod,
40    /// Lower bound of opened connections for `Pool` (defaults to 10).
41    pub(crate) pool_min: u16,
42    /// Upper bound of opened connections for `Pool` (defaults to 20).
43    pub(crate) pool_max: u16,
44    /// Whether to enable `TCP_NODELAY` (defaults to `true`).
45    //pub(crate) nodelay: bool,
46    /// TCP keep alive timeout in milliseconds (defaults to `None`).
47    pub(crate) keepalive: Option<Duration>,
48    /// Ping server every time before execute any query. (defaults to `true`)
49    pub(crate) ping_before_query: bool,
50    /// Timeout for ping (defaults to `500 ms`)
51    pub(crate) ping_timeout: Duration,
52    /// Timeout for connection (defaults to `500 ms`)
53    pub(crate) connection_timeout: Duration,
54    /// Timeout for queries (defaults to `180 sec`)
55    pub(crate) query_timeout: Duration,
56    /// Timeout for each block in a query (defaults to `180 sec`)
57    pub(crate) query_block_timeout: Duration,
58    /// Timeout for inserts (defaults to `180 sec`)
59    pub(crate) insert_timeout: Duration,
60    /// Timeout for execute (defaults to `180 sec`)
61    pub(crate) execute_timeout: Duration,
62    /// Enable TLS encryption (defaults to `false`)
63    #[cfg(feature = "tls")]
64    pub(crate) secure: bool,
65    /// Skip certificate verification (default is `false`).
66    #[cfg(feature = "tls")]
67    pub(crate) skip_verify: bool,
68
69    /// An X509 certificate.
70    //#[cfg(feature = "tls")]
71    //pub(crate) certificate: Option<Certificate>,
72
73    /// Restricts permissions for read data, write data and change settings queries.
74    pub(crate) readonly: u8,
75
76    /// The number of retries to send request to server. (defaults to `3`)
77    pub(crate) send_retries: u8,
78
79    /// Amount of time to wait before next retry. (defaults to `1 sec`)
80    pub(crate) retry_timeout: Duration,
81}
82
83// FIXME: replace with macro
84fn parse_param<'a, F, T, E>(param: Cow<'a, str>, value: Cow<'a, str>, parse: F) -> Result<T>
85where
86    F: Fn(&str) -> std::result::Result<T, E>,
87{
88    match parse(value.as_ref()) {
89        Ok(value) => Ok(value),
90        Err(_) => Err(UrlError::InvalidParamValue {
91            param: param.into(),
92            value: value.into(),
93        }),
94    }
95}
96
97fn get_database_from_url(url: &Url) -> Result<Option<&str>> {
98    match url.path_segments() {
99        None => Ok(None),
100        Some(mut segments) => {
101            let head = segments.next();
102
103            if segments.next().is_some() {
104                return Err(UrlError::Invalid);
105            }
106
107            match head {
108                Some(database) if !database.is_empty() => Ok(Some(database)),
109                _ => Ok(None),
110            }
111        }
112    }
113}
114
115fn parse_duration(source: &str) -> Result<Duration> {
116    let (num, unit) = match source.find(|c: char| !c.is_digit(10)) {
117        Some(pos) if pos > 0 => (u64::from_str(&source[0..pos]), &source[pos..]),
118        None => (u64::from_str(source), "s"),
119        _ => {
120            return Err(UrlError::Invalid);
121        }
122    };
123
124    let num = match num {
125        Ok(value) => value,
126        Err(_) => return Err(UrlError::Invalid),
127    };
128
129    match unit {
130        "s" => Ok(Duration::from_secs(num)),
131        "ms" => Ok(Duration::from_millis(num)),
132        _ => Err(UrlError::Invalid),
133    }
134}
135
136fn parse_opt_duration(source: &str) -> Result<Option<Duration>> {
137    if source == "none" {
138        return Ok(None);
139    }
140
141    let duration = parse_duration(source)?;
142    Ok(Some(duration))
143}
144
145fn parse_u8(source: &str) -> Result<u8> {
146    let duration: u8 = match source.parse() {
147        Ok(value) => value,
148        Err(_) => return Err(UrlError::Invalid),
149    };
150
151    Ok(duration)
152}
153
154fn parse_compression(source: &str) -> Result<CompressionMethod> {
155    match source {
156        "none" => Ok(CompressionMethod::None),
157        "lz4" => Ok(CompressionMethod::LZ4),
158        _ => Err(UrlError::Invalid),
159    }
160}
161
162impl Options {
163    fn new(url: Url) -> Result<Options> {
164        let defport = match url.scheme() {
165            "tcp" => 9000,
166            "tls" => 9009,
167            _ => {
168                return Err(UrlError::UnsupportedScheme {
169                    scheme: url.scheme().to_string(),
170                })
171            }
172        };
173
174        let mut options = crate::DEF_OPTIONS.clone(); // Options ::default();
175
176        let user = url.username();
177        if !user.is_empty() {
178            options.username = user.into();
179        }
180
181        if let Some(password) = url.password() {
182            options.password = password.into();
183        }
184
185        let port = url.port().unwrap_or(defport);
186        if url.cannot_be_a_base() || !url.has_host() {
187            return Err(UrlError::Invalid);
188        }
189
190        options.addr.clear();
191        options.addr.push(format!(
192            "{}:{}",
193            url.host_str().unwrap_or("localhost"),
194            port
195        ));
196
197        if let Some(database) = get_database_from_url(&url)? {
198            options.database = database.into();
199        }
200
201        for (key, value) in url.query_pairs() {
202            match key.as_ref() {
203                "pool_min" => options.pool_min = parse_param(key, value, u16::from_str)?,
204                "pool_max" => options.pool_max = parse_param(key, value, u16::from_str)?,
205                "keepalive" => options.keepalive = parse_param(key, value, parse_opt_duration)?,
206                "ping_before_query" => {
207                    options.ping_before_query = parse_param(key, value, bool::from_str)?
208                }
209                "send_retries" => options.send_retries = parse_param(key, value, u8::from_str)?,
210                "retry_timeout" => options.retry_timeout = parse_param(key, value, parse_duration)?,
211                "ping_timeout" => options.ping_timeout = parse_param(key, value, parse_duration)?,
212                "connection_timeout" => {
213                    options.connection_timeout = parse_param(key, value, parse_duration)?
214                }
215                "query_timeout" => options.query_timeout = parse_param(key, value, parse_duration)?,
216                "query_block_timeout" => {
217                    options.query_block_timeout = parse_param(key, value, parse_duration)?
218                }
219                "insert_timeout" => {
220                    options.insert_timeout = parse_param(key, value, parse_duration)?
221                }
222                "execute_timeout" => {
223                    options.execute_timeout = parse_param(key, value, parse_duration)?
224                }
225                "compression" => options.compression = parse_param(key, value, parse_compression)?,
226                #[cfg(feature = "tls")]
227                "secure" => options.secure = parse_param(key, value, bool::from_str)?,
228                #[cfg(feature = "tls")]
229                "skip_verify" => options.skip_verify = parse_param(key, value, bool::from_str)?,
230                "readonly" => options.readonly = parse_param(key, value, parse_u8)?,
231                "host" => options.addr.push(value.into_owned()),
232                _ => return Err(UrlError::UnknownParameter { param: key.into() }),
233            };
234        }
235
236        Ok(options)
237    }
238
239    pub fn set_compression(mut self, compression: CompressionMethod) -> Self {
240        self.compression = compression;
241        self
242    }
243
244    pub fn set_timeout(mut self, timeout: Duration) -> Self {
245        self.ping_timeout = timeout;
246        self.execute_timeout = timeout;
247        self.query_timeout = timeout;
248        self.insert_timeout = timeout;
249        self
250    }
251}
252
253impl fmt::Debug for Options {
254    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
255        f.debug_struct("Options")
256            .field("addr", &self.addr)
257            .field("database", &self.database)
258            .field("compression", &self.compression)
259            .field("pool_min", &self.pool_min)
260            .field("pool_max", &self.pool_max)
261            .field("keepalive", &self.keepalive)
262            .field("ping_before_query", &self.ping_before_query)
263            .field("send_retries", &self.send_retries)
264            .field("retry_timeout", &self.retry_timeout)
265            .field("ping_timeout", &self.ping_timeout)
266            .field("connection_timeout", &self.connection_timeout)
267            .field("query_timeout", &self.query_timeout)
268            .field("query_block_timeout", &self.query_block_timeout)
269            .field("insert_timeout", &self.insert_timeout)
270            .field("execute_timeout", &self.execute_timeout)
271            .field("readonly", &self.readonly)
272            .finish()
273    }
274}
275
276impl Default for Options {
277    fn default() -> Self {
278        let default_duration = Duration::from_secs(180);
279        Self {
280            addr: vec!["localhost:9000".into()],
281            database: "default".into(),
282            username: "default".into(),
283            password: "".into(),
284            compression: CompressionMethod::LZ4,
285            pool_min: DEFAULT_MIN_POOL_SIZE,
286            pool_max: DEFAULT_MAX_POOL_SIZE,
287            keepalive: None,
288            ping_before_query: true,
289            send_retries: 3,
290            retry_timeout: Duration::from_secs(1),
291            ping_timeout: Duration::from_millis(700),
292            connection_timeout: Duration::from_millis(500),
293            query_timeout: default_duration,
294            query_block_timeout: default_duration,
295            insert_timeout: default_duration,
296            execute_timeout: default_duration,
297            #[cfg(feature = "tls")]
298            secure: false,
299            #[cfg(feature = "tls")]
300            skip_verify: false,
301            readonly: 0,
302        }
303    }
304}
305
306impl TryFrom<Url> for Options {
307    type Error = UrlError;
308    fn try_from(value: Url) -> Result<Self> {
309        Options::new(value)
310    }
311}
312
313/// Weird template TryFrom<T> implementation collision
314/// ( https://github.com/rust-lang/rust/issues/50133 )
315/// with TryFrom<&Url> make us to draw up two separate implementations
316/// for &str and String
317impl TryFrom<&str> for Options {
318    type Error = UrlError;
319
320    fn try_from(value: &str) -> Result<Self> {
321        let url = Url::parse(value)?;
322        Options::new(url)
323    }
324}
325
326impl TryFrom<String> for Options {
327    type Error = UrlError;
328
329    fn try_from(value: String) -> Result<Self> {
330        let url = Url::parse(value.as_ref())?;
331        Options::new(url)
332    }
333}
334
335impl Options {
336    pub(crate) fn take_addr(&mut self) -> Vec<String> {
337        std::mem::replace(&mut self.addr, Vec::new())
338    }
339}
340
#[cfg(test)]
mod test {
    use super::*;
    use crate::pool::Pool;

    /// A URL with only a host keeps the default database, credentials,
    /// compression and port.
    #[test]
    fn test_default_config() -> Result<()> {
        let pool = Pool::create("tcp://localhost?ping_timeout=1ms").unwrap();
        let options = pool.options();

        assert_eq!(options.database, "default");
        assert_eq!(options.compression, CompressionMethod::LZ4);
        assert_eq!(options.username, "default");
        assert_eq!(options.password, "");
        assert_eq!(pool.inner.hosts[0], "localhost:9000");
        Ok(())
    }

    #[test]
    fn test_configuration() -> Result<()> {
        // Database path, mixed duration units and explicit compression.
        let url =
            Url::parse("tcp://localhost/db1?query_block_timeout=300&ping_timeout=110ms&query_timeout=25s&compression=lz4")?;
        let opts = Options::new(url)?;

        assert_eq!(opts.addr[0], String::from("localhost:9000"));
        assert_eq!(
            opts.compression,
            CompressionMethod::LZ4,
            "compression url parameter"
        );
        assert_eq!(opts.query_timeout, Duration::from_secs(25));
        assert_eq!(opts.ping_timeout, Duration::from_millis(110));
        assert_eq!(opts.query_block_timeout, Duration::from_secs(300));
        assert_eq!(opts.database, "db1");

        // An explicit port ends up in the address.
        let url = Url::parse(
            "tcp://host1:9001/db2?ping_timeout=110ms&query_timeout=25s&compression=lz4",
        )?;
        assert_eq!(url.host_str(), Some("host1"));
        assert_eq!(url.port(), Some(9001));
        assert_eq!(Options::new(url)?.addr[0], String::from("host1:9001"));

        // `Url` keeps a comma-separated host list as one host string.
        let url = Url::parse(
            "tcp://host1,host2:9001/db2?ping_timeout=110ms&query_timeout=25s&compression=lz4",
        )?;
        assert_eq!(url.host_str(), Some("host1,host2"));

        // A malformed duration and an unsupported compression method are rejected.
        let rejected = [
            "tcp://host1:9001/db2?ping_timeout=ms&query_timeout=25s&compression=lz4",
            "tcp://host1:9001/db2?ping_timeout=1ms&query_timeout=25s&compression=zlib",
        ];
        for source in rejected.iter() {
            let url = Url::parse(source)?;
            assert!(Options::new(url).is_err());
        }

        // `Options` accepts pool_min > pool_max, but `Pool::create` refuses it.
        let url = Url::parse(
            "tcp://host1:9001/db2?ping_timeout=1ms&query_timeout=25s&pool_min=11&pool_max=10",
        )?;
        assert!(Options::new(url.clone()).is_ok());
        assert!(Pool::create(url).is_err());

        Ok(())
    }
}