use std::borrow::Cow;
use std::convert::TryFrom;
use std::fmt;
use std::str::FromStr;
use std::time::Duration;

use url::Url;

use crate::errors::UrlError;
6
7type Result<T> = std::result::Result<T, UrlError>;
8
/// Default for `Options::pool_min` (overridable with the `pool_min` URL parameter).
const DEFAULT_MIN_POOL_SIZE: u16 = 2;
/// Default for `Options::pool_max` (overridable with the `pool_max` URL parameter).
const DEFAULT_MAX_POOL_SIZE: u16 = 10;
11
/// Compression method selected for a connection.
///
/// In a connection URL it is chosen with the `compression` query parameter,
/// whose accepted values are `none` and `lz4`.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum CompressionMethod {
    /// No compression.
    None,
    /// LZ4 compression.
    LZ4,
}

impl CompressionMethod {
    /// Returns `true` when this is [`CompressionMethod::None`].
    #[inline]
    pub fn is_none(&self) -> bool {
        matches!(self, CompressionMethod::None)
    }
}
26
27#[derive(Clone)]
29pub struct Options {
30 pub(crate) addr: Vec<String>,
32 pub(crate) database: String,
34 pub(crate) username: String,
36 pub(crate) password: String,
38 pub(crate) compression: CompressionMethod,
40 pub(crate) pool_min: u16,
42 pub(crate) pool_max: u16,
44 pub(crate) keepalive: Option<Duration>,
48 pub(crate) ping_before_query: bool,
50 pub(crate) ping_timeout: Duration,
52 pub(crate) connection_timeout: Duration,
54 pub(crate) query_timeout: Duration,
56 pub(crate) query_block_timeout: Duration,
58 pub(crate) insert_timeout: Duration,
60 pub(crate) execute_timeout: Duration,
62 #[cfg(feature = "tls")]
64 pub(crate) secure: bool,
65 #[cfg(feature = "tls")]
67 pub(crate) skip_verify: bool,
68
69 pub(crate) readonly: u8,
75
76 pub(crate) send_retries: u8,
78
79 pub(crate) retry_timeout: Duration,
81}
82
83fn parse_param<'a, F, T, E>(param: Cow<'a, str>, value: Cow<'a, str>, parse: F) -> Result<T>
85where
86 F: Fn(&str) -> std::result::Result<T, E>,
87{
88 match parse(value.as_ref()) {
89 Ok(value) => Ok(value),
90 Err(_) => Err(UrlError::InvalidParamValue {
91 param: param.into(),
92 value: value.into(),
93 }),
94 }
95}
96
97fn get_database_from_url(url: &Url) -> Result<Option<&str>> {
98 match url.path_segments() {
99 None => Ok(None),
100 Some(mut segments) => {
101 let head = segments.next();
102
103 if segments.next().is_some() {
104 return Err(UrlError::Invalid);
105 }
106
107 match head {
108 Some(database) if !database.is_empty() => Ok(Some(database)),
109 _ => Ok(None),
110 }
111 }
112 }
113}
114
115fn parse_duration(source: &str) -> Result<Duration> {
116 let (num, unit) = match source.find(|c: char| !c.is_digit(10)) {
117 Some(pos) if pos > 0 => (u64::from_str(&source[0..pos]), &source[pos..]),
118 None => (u64::from_str(source), "s"),
119 _ => {
120 return Err(UrlError::Invalid);
121 }
122 };
123
124 let num = match num {
125 Ok(value) => value,
126 Err(_) => return Err(UrlError::Invalid),
127 };
128
129 match unit {
130 "s" => Ok(Duration::from_secs(num)),
131 "ms" => Ok(Duration::from_millis(num)),
132 _ => Err(UrlError::Invalid),
133 }
134}
135
136fn parse_opt_duration(source: &str) -> Result<Option<Duration>> {
137 if source == "none" {
138 return Ok(None);
139 }
140
141 let duration = parse_duration(source)?;
142 Ok(Some(duration))
143}
144
145fn parse_u8(source: &str) -> Result<u8> {
146 let duration: u8 = match source.parse() {
147 Ok(value) => value,
148 Err(_) => return Err(UrlError::Invalid),
149 };
150
151 Ok(duration)
152}
153
154fn parse_compression(source: &str) -> Result<CompressionMethod> {
155 match source {
156 "none" => Ok(CompressionMethod::None),
157 "lz4" => Ok(CompressionMethod::LZ4),
158 _ => Err(UrlError::Invalid),
159 }
160}
161
162impl Options {
163 fn new(url: Url) -> Result<Options> {
164 let defport = match url.scheme() {
165 "tcp" => 9000,
166 "tls" => 9009,
167 _ => {
168 return Err(UrlError::UnsupportedScheme {
169 scheme: url.scheme().to_string(),
170 })
171 }
172 };
173
174 let mut options = crate::DEF_OPTIONS.clone(); let user = url.username();
177 if !user.is_empty() {
178 options.username = user.into();
179 }
180
181 if let Some(password) = url.password() {
182 options.password = password.into();
183 }
184
185 let port = url.port().unwrap_or(defport);
186 if url.cannot_be_a_base() || !url.has_host() {
187 return Err(UrlError::Invalid);
188 }
189
190 options.addr.clear();
191 options.addr.push(format!(
192 "{}:{}",
193 url.host_str().unwrap_or("localhost"),
194 port
195 ));
196
197 if let Some(database) = get_database_from_url(&url)? {
198 options.database = database.into();
199 }
200
201 for (key, value) in url.query_pairs() {
202 match key.as_ref() {
203 "pool_min" => options.pool_min = parse_param(key, value, u16::from_str)?,
204 "pool_max" => options.pool_max = parse_param(key, value, u16::from_str)?,
205 "keepalive" => options.keepalive = parse_param(key, value, parse_opt_duration)?,
206 "ping_before_query" => {
207 options.ping_before_query = parse_param(key, value, bool::from_str)?
208 }
209 "send_retries" => options.send_retries = parse_param(key, value, u8::from_str)?,
210 "retry_timeout" => options.retry_timeout = parse_param(key, value, parse_duration)?,
211 "ping_timeout" => options.ping_timeout = parse_param(key, value, parse_duration)?,
212 "connection_timeout" => {
213 options.connection_timeout = parse_param(key, value, parse_duration)?
214 }
215 "query_timeout" => options.query_timeout = parse_param(key, value, parse_duration)?,
216 "query_block_timeout" => {
217 options.query_block_timeout = parse_param(key, value, parse_duration)?
218 }
219 "insert_timeout" => {
220 options.insert_timeout = parse_param(key, value, parse_duration)?
221 }
222 "execute_timeout" => {
223 options.execute_timeout = parse_param(key, value, parse_duration)?
224 }
225 "compression" => options.compression = parse_param(key, value, parse_compression)?,
226 #[cfg(feature = "tls")]
227 "secure" => options.secure = parse_param(key, value, bool::from_str)?,
228 #[cfg(feature = "tls")]
229 "skip_verify" => options.skip_verify = parse_param(key, value, bool::from_str)?,
230 "readonly" => options.readonly = parse_param(key, value, parse_u8)?,
231 "host" => options.addr.push(value.into_owned()),
232 _ => return Err(UrlError::UnknownParameter { param: key.into() }),
233 };
234 }
235
236 Ok(options)
237 }
238
239 pub fn set_compression(mut self, compression: CompressionMethod) -> Self {
240 self.compression = compression;
241 self
242 }
243
244 pub fn set_timeout(mut self, timeout: Duration) -> Self {
245 self.ping_timeout = timeout;
246 self.execute_timeout = timeout;
247 self.query_timeout = timeout;
248 self.insert_timeout = timeout;
249 self
250 }
251}
252
253impl fmt::Debug for Options {
254 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
255 f.debug_struct("Options")
256 .field("addr", &self.addr)
257 .field("database", &self.database)
258 .field("compression", &self.compression)
259 .field("pool_min", &self.pool_min)
260 .field("pool_max", &self.pool_max)
261 .field("keepalive", &self.keepalive)
262 .field("ping_before_query", &self.ping_before_query)
263 .field("send_retries", &self.send_retries)
264 .field("retry_timeout", &self.retry_timeout)
265 .field("ping_timeout", &self.ping_timeout)
266 .field("connection_timeout", &self.connection_timeout)
267 .field("query_timeout", &self.query_timeout)
268 .field("query_block_timeout", &self.query_block_timeout)
269 .field("insert_timeout", &self.insert_timeout)
270 .field("execute_timeout", &self.execute_timeout)
271 .field("readonly", &self.readonly)
272 .finish()
273 }
274}
275
276impl Default for Options {
277 fn default() -> Self {
278 let default_duration = Duration::from_secs(180);
279 Self {
280 addr: vec!["localhost:9000".into()],
281 database: "default".into(),
282 username: "default".into(),
283 password: "".into(),
284 compression: CompressionMethod::LZ4,
285 pool_min: DEFAULT_MIN_POOL_SIZE,
286 pool_max: DEFAULT_MAX_POOL_SIZE,
287 keepalive: None,
288 ping_before_query: true,
289 send_retries: 3,
290 retry_timeout: Duration::from_secs(1),
291 ping_timeout: Duration::from_millis(700),
292 connection_timeout: Duration::from_millis(500),
293 query_timeout: default_duration,
294 query_block_timeout: default_duration,
295 insert_timeout: default_duration,
296 execute_timeout: default_duration,
297 #[cfg(feature = "tls")]
298 secure: false,
299 #[cfg(feature = "tls")]
300 skip_verify: false,
301 readonly: 0,
302 }
303 }
304}
305
306impl TryFrom<Url> for Options {
307 type Error = UrlError;
308 fn try_from(value: Url) -> Result<Self> {
309 Options::new(value)
310 }
311}
312
313impl TryFrom<&str> for Options {
318 type Error = UrlError;
319
320 fn try_from(value: &str) -> Result<Self> {
321 let url = Url::parse(value)?;
322 Options::new(url)
323 }
324}
325
326impl TryFrom<String> for Options {
327 type Error = UrlError;
328
329 fn try_from(value: String) -> Result<Self> {
330 let url = Url::parse(value.as_ref())?;
331 Options::new(url)
332 }
333}
334
335impl Options {
336 pub(crate) fn take_addr(&mut self) -> Vec<String> {
337 std::mem::replace(&mut self.addr, Vec::new())
338 }
339}
340
#[cfg(test)]
mod test {
    use super::*;
    use crate::pool::Pool;

    /// A bare `tcp://localhost` URL keeps the crate-wide defaults.
    #[test]
    fn test_default_config() -> Result<()> {
        let pool = Pool::create("tcp://localhost?ping_timeout=1ms").unwrap();
        let opts = pool.options();

        assert_eq!(opts.database, "default");
        assert_eq!(opts.compression, CompressionMethod::LZ4);
        assert_eq!(opts.username, "default");
        assert_eq!(opts.password, "");
        assert_eq!(pool.inner.hosts[0], "localhost:9000");
        Ok(())
    }

    #[test]
    fn test_configuration() -> Result<()> {
        // Path database plus bare-number (seconds) and suffixed durations.
        let config = Options::new(Url::parse(
            "tcp://localhost/db1?query_block_timeout=300&ping_timeout=110ms&query_timeout=25s&compression=lz4",
        )?)?;
        assert_eq!(config.addr[0], String::from("localhost:9000"));
        assert_eq!(
            config.compression,
            CompressionMethod::LZ4,
            "compression url parameter"
        );
        assert_eq!(config.query_timeout, Duration::from_secs(25));
        assert_eq!(config.ping_timeout, Duration::from_millis(110));
        assert_eq!(config.query_block_timeout, Duration::from_secs(300));
        assert_eq!(config.database, "db1");

        // An explicit port replaces the scheme default.
        let url = Url::parse(
            "tcp://host1:9001/db2?ping_timeout=110ms&query_timeout=25s&compression=lz4",
        )?;
        assert_eq!(url.host_str(), Some("host1"));
        assert_eq!(url.port(), Some(9001));
        assert_eq!(Options::new(url)?.addr[0], String::from("host1:9001"));

        // The URL parser keeps a comma-separated host list as one host string.
        let url = Url::parse(
            "tcp://host1,host2:9001/db2?ping_timeout=110ms&query_timeout=25s&compression=lz4",
        )?;
        assert_eq!(url.host_str(), Some("host1,host2"));

        // A malformed duration and an unknown compression method are rejected.
        let rejected = [
            "tcp://host1:9001/db2?ping_timeout=ms&query_timeout=25s&compression=lz4",
            "tcp://host1:9001/db2?ping_timeout=1ms&query_timeout=25s&compression=zlib",
        ];
        for bad in rejected.iter() {
            assert!(Options::new(Url::parse(bad)?).is_err());
        }

        // pool_min > pool_max parses as options, but Pool::create refuses it.
        let url = Url::parse(
            "tcp://host1:9001/db2?ping_timeout=1ms&query_timeout=25s&pool_min=11&pool_max=10",
        )?;
        assert!(Options::new(url.clone()).is_ok());
        assert!(Pool::create(url).is_err());

        Ok(())
    }
}