redis_driver/clients/
config.rs

1use crate::{Error, Result};
2#[cfg(feature = "tls")]
3use native_tls::{Certificate, Identity, Protocol, TlsConnector, TlsConnectorBuilder};
4use std::{str::FromStr, time::Duration};
5use url::Url;
6
/// Port assumed when an address (`host` without `:port`) does not specify one.
const DEFAULT_PORT: u16 = 6_379;
/// Logical database selected when a URI carries no database path segment.
const DEFAULT_DATABASE: usize = 0;

/// A URI broken down into borrowed pieces of the input string:
/// `(scheme, username, password, [(host, port), ...], path_segments)`.
type Uri<'s> = (
    &'s str,
    Option<&'s str>,
    Option<&'s str>,
    Vec<(&'s str, u16)>,
    Vec<&'s str>,
);
17
18#[derive(Clone, Default)]
19pub struct Config {
20    pub server: ServerConfig,
21    pub username: Option<String>,
22    pub password: Option<String>,
23    pub database: usize,
24    #[cfg(feature = "tls")]
25    pub tls_config: Option<TlsConfig>,
26}
27
28impl FromStr for Config {
29    type Err = Error;
30
31    /// Build a config from an URI or a standard address format `host`:`port`
32    fn from_str(str: &str) -> Result<Config> {
33        if let Some(config) = Self::parse_uri(str) {
34            Ok(config)
35        } else if let Some(addr) = Self::parse_addr(str) {
36            addr.into_config()
37        } else {
38            Err(Error::Config(format!("Cannot parse config from {str}")))
39        }
40    }
41}
42
43impl Config {
44    /// Build a config from an URI in the format `redis[s]://[[username]:password@]host[:port]/[database]`
45    pub fn from_uri(uri: Url) -> Result<Config> {
46        Self::from_str(uri.as_str())
47    }
48
49    /// Parse address in the standard formart `host`:`port`
50    fn parse_addr(str: &str) -> Option<(&str, u16)> {
51        let mut iter = str.split(':');
52
53        match (iter.next(), iter.next(), iter.next()) {
54            (Some(host), Some(port), None) => {
55                if let Ok(port) = port.parse::<u16>() {
56                    Some((host, port))
57                } else {
58                    None
59                }
60            }
61            (Some(host), None, None) => Some((host, DEFAULT_PORT)),
62            _ => None,
63        }
64    }
65
66    fn parse_uri(uri: &str) -> Option<Config> {
67        let (scheme, username, password, hosts, path_segments) = Self::break_down_uri(uri)?;
68        let mut hosts = hosts;
69        let mut path_segments = path_segments.into_iter();
70
71        enum ServerType {
72            Standalone,
73            Sentinel,
74            Cluster,
75        }
76
77        #[cfg(feature = "tls")]
78        let (tls_config, server_type) = match scheme {
79            "redis" => (None, ServerType::Standalone),
80            "rediss" => (Some(TlsConfig::default()), ServerType::Standalone),
81            "redis+sentinel" => (None, ServerType::Sentinel),
82            "rediss+sentinel" => (Some(TlsConfig::default()), ServerType::Sentinel),
83            "redis+cluster" => (None, ServerType::Cluster),
84            "rediss+cluster" => (Some(TlsConfig::default()), ServerType::Cluster),
85            _ => {
86                return None;
87            }
88        };
89
90        #[cfg(not(feature = "tls"))]
91        let server_type = match scheme {
92            "redis" => ServerType::Standalone,
93            "redis+sentinel" => ServerType::Sentinel,
94            "redis+cluster" => ServerType::Cluster,
95            _ => {
96                return None;
97            }
98        };
99
100        let server = match server_type {
101            ServerType::Standalone => {
102                if hosts.len() > 1 {
103                    return None;
104                } else {
105                    let (host, port) = hosts.pop()?;
106                    ServerConfig::Standalone {
107                        host: host.to_owned(),
108                        port,
109                    }
110                }
111            }
112            ServerType::Sentinel => {
113                let instances = hosts
114                    .iter()
115                    .map(|(host, port)| ((*host).to_owned(), *port))
116                    .collect::<Vec<_>>();
117
118                let service_name = match path_segments.next() {
119                    Some(service_name) => service_name.to_owned(),
120                    None => {
121                        return None;
122                    }
123                };
124
125                ServerConfig::Sentinel(SentinelConfig {
126                    instances,
127                    service_name,
128                    ..Default::default()
129                })
130            }
131            ServerType::Cluster => {
132                let nodes = hosts
133                    .iter()
134                    .map(|(host, port)| ((*host).to_owned(), *port))
135                    .collect::<Vec<_>>();
136
137                ServerConfig::Cluster(ClusterConfig { nodes })
138            }
139        };
140
141        let database = match path_segments.next() {
142            Some(database) => match database.parse::<usize>() {
143                Ok(database) => database,
144                Err(_) => {
145                    return None;
146                }
147            },
148            None => DEFAULT_DATABASE,
149        };
150
151        Some(Config {
152            server,
153            username: username.map(|u| u.to_owned()),
154            password: password.map(|p| p.to_owned()),
155            database,
156            #[cfg(feature = "tls")]
157            tls_config,
158        })
159    }
160
161    /// break down an uri in a tuple (scheme, username, password, hosts, path_segments)
162    fn break_down_uri(uri: &str) -> Option<Uri> {
163        let end_of_scheme = match uri.find("://") {
164            Some(index) => index,
165            None => {
166                return None;
167            }
168        };
169
170        let scheme = &uri[..end_of_scheme];
171
172        let after_scheme = &uri[end_of_scheme + 3..];
173
174        let (before_query, _query) = match after_scheme.find('?') {
175            Some(index) => match Self::exclusive_split_at(after_scheme, index) {
176                (Some(before_query), after_query) => (before_query, after_query),
177                _ => {
178                    return None;
179                }
180            },
181            None => (after_scheme, None),
182        };
183
184        let (authority, path) = match after_scheme.find('/') {
185            Some(index) => match Self::exclusive_split_at(before_query, index) {
186                (Some(authority), path) => (authority, path),
187                _ => {
188                    return None;
189                }
190            },
191            None => (after_scheme, None),
192        };
193
194        let (user_info, hosts) = match authority.rfind('@') {
195            Some(index) => {
196                // if '@' is in the host section, it MUST be interpreted as a request for
197                // authentication, even if the credentials are empty.
198                let (user_info, hosts) = Self::exclusive_split_at(authority, index);
199                match hosts {
200                    Some(hosts) => (user_info, hosts),
201                    None => {
202                        // missing hosts
203                        return None;
204                    }
205                }
206            }
207            None => (None, authority),
208        };
209
210        let (username, password) = match user_info {
211            Some(user_info) => match user_info.find(':') {
212                Some(index) => match Self::exclusive_split_at(user_info, index) {
213                    (username, None) => (username, Some("")),
214                    (username, password) => (username, password),
215                },
216                None => {
217                    // username without password is not accepted
218                    return None;
219                }
220            },
221            None => (None, None),
222        };
223
224        let hosts = hosts
225            .split(',')
226            .map(Self::parse_addr)
227            .collect::<Option<Vec<_>>>();
228        let hosts = hosts?;
229
230        let path_segments = match path {
231            Some(path) => path.split('/').collect::<Vec<_>>(),
232            None => Vec::new(),
233        };
234
235        Some((scheme, username, password, hosts, path_segments))
236    }
237
238    /// Splits a string into a section before a given index and a section exclusively after the index.
239    /// Empty portions are returned as `None`.
240    fn exclusive_split_at(s: &str, i: usize) -> (Option<&str>, Option<&str>) {
241        let (l, r) = s.split_at(i);
242
243        let lout = if !l.is_empty() { Some(l) } else { None };
244        let rout = if r.len() > 1 { Some(&r[1..]) } else { None };
245
246        (lout, rout)
247    }
248}
249
250impl ToString for Config {
251    fn to_string(&self) -> String {
252        #[cfg(feature = "tls")]
253        let mut s = if self.tls_config.is_some() {
254            match &self.server {
255                ServerConfig::Standalone { host: _, port: _ } => "rediss://",
256                ServerConfig::Sentinel(_) => "rediss+sentinel://",
257                ServerConfig::Cluster(_) => "rediss+cluster://",
258            }
259        } else {
260            match &self.server {
261                ServerConfig::Standalone { host: _, port: _ } => "redis://",
262                ServerConfig::Sentinel(_) => "redis+sentinel://",
263                ServerConfig::Cluster(_) => "redis+cluster://",
264            }
265        }
266        .to_owned();
267
268        #[cfg(not(feature = "tls"))]
269        let mut s = match &self.server {
270            ServerConfig::Standalone { host: _, port: _ } => "redis://",
271            ServerConfig::Sentinel(_) => "redis+sentinel://",
272            ServerConfig::Cluster(_) => "redis+cluster://",
273        }
274        .to_owned();
275
276        if let Some(username) = &self.username {
277            s.push_str(username);
278        }
279
280        if let Some(password) = &self.password {
281            s.push(':');
282            s.push_str(password);
283            s.push('@');
284        }
285
286        match &self.server {
287            ServerConfig::Standalone { host, port } => {
288                s.push_str(host);
289                s.push(':');
290                s.push_str(&port.to_string());
291            }
292            ServerConfig::Sentinel(SentinelConfig {
293                instances,
294                service_name,
295                wait_beetween_failures: _,
296            }) => {
297                s.push_str(
298                    &instances
299                        .iter()
300                        .map(|(host, port)| format!("{host}:{port}"))
301                        .collect::<Vec<String>>()
302                        .join(","),
303                );
304                s.push('/');
305                s.push_str(service_name);
306            }
307            ServerConfig::Cluster(ClusterConfig { nodes }) => {
308                s.push_str(
309                    &nodes
310                        .iter()
311                        .map(|(host, port)| format!("{host}:{port}"))
312                        .collect::<Vec<String>>()
313                        .join(","),
314                );
315            }
316        }
317
318        if self.database > 0 {
319            s.push('/');
320            s.push_str(&self.database.to_string());
321        }
322
323        s
324    }
325}
326
327/// Configuration for connecting to a Redis server
328#[derive(Clone)]
329pub enum ServerConfig {
330    /// Connection to a simple server (no master-replica, no cluster)
331    Standalone {
332        host: String,
333        port: u16,
334    },
335    Sentinel(SentinelConfig),
336    Cluster(ClusterConfig),
337}
338
339impl Default for ServerConfig {
340    fn default() -> Self {
341        ServerConfig::Standalone {
342            host: "127.0.0.1".to_owned(),
343            port: 6379,
344        }
345    }
346}
347
/// Settings for reaching Redis through Sentinel.
#[derive(Clone)]
pub struct SentinelConfig {
    /// Known sentinel instances, as `(host, port)` pairs.
    pub instances: Vec<(String, u16)>,

    /// Name of the service to resolve through the sentinels.
    pub service_name: String,

    /// Delay after a failure before trying the next sentinel instance (250ms by default).
    /// The field name's spelling is part of the public API and is kept as-is.
    pub wait_beetween_failures: Duration,
}
360
361impl Default for SentinelConfig {
362    fn default() -> Self {
363        Self {
364            instances: Default::default(),
365            service_name: Default::default(),
366            wait_beetween_failures: Duration::from_millis(250),
367        }
368    }
369}
370
/// Settings for reaching a Redis Cluster.
#[derive(Default, Clone)]
pub struct ClusterConfig {
    /// Known cluster nodes, as `(host, port)` pairs.
    pub nodes: Vec<(String, u16)>,
}
377
/// Config for TLS.
///
/// The fields correspond to settings of native-tls's `TlsConnectorBuilder`; see
/// [TlsConnectorBuilder](https://docs.rs/tokio-native-tls/0.3.0/tokio_native_tls/native_tls/struct.TlsConnectorBuilder.html).
/// Values are set through the builder-style methods on this type.
#[cfg(feature = "tls")]
#[derive(Clone)]
pub struct TlsConfig {
    /// Client identity set through `identity()`, if any.
    identity: Option<Identity>,
    /// Extra root certificates to trust, if any.
    root_certificates: Option<Vec<Certificate>>,
    /// Lowest protocol version to allow (`None` = no lower bound set).
    min_protocol_version: Option<Protocol>,
    /// Highest protocol version to allow (`None` = no upper bound set).
    max_protocol_version: Option<Protocol>,
    /// Whether to ignore the platform's built-in root certificates.
    disable_built_in_roots: bool,
    /// Dangerous: accept invalid certificates.
    danger_accept_invalid_certs: bool,
    /// Dangerous: accept certificates whose hostname does not match.
    danger_accept_invalid_hostnames: bool,
    /// Whether to send Server Name Indication.
    use_sni: bool,
}
393
#[cfg(feature = "tls")]
impl Default for TlsConfig {
    /// No identity and no extra roots. The minimum protocol is TLS 1.0 with no maximum.
    /// Built-in roots are kept, no `danger_*` relaxations apply, and SNI is enabled.
    fn default() -> Self {
        Self {
            use_sni: true,
            danger_accept_invalid_hostnames: false,
            danger_accept_invalid_certs: false,
            disable_built_in_roots: false,
            max_protocol_version: None,
            min_protocol_version: Some(Protocol::Tlsv10),
            root_certificates: None,
            identity: None,
        }
    }
}
409
#[cfg(feature = "tls")]
impl TlsConfig {
    /// Set the client identity (certificate and private key) to present to the server.
    pub fn identity(&mut self, identity: Identity) -> &mut Self {
        self.identity = Some(identity);
        self
    }

    /// Set extra root certificates to trust (replaces any previously set list).
    pub fn root_certificates(&mut self, root_certificates: Vec<Certificate>) -> &mut Self {
        self.root_certificates = Some(root_certificates);
        self
    }

    /// Set the lowest protocol version to allow.
    pub fn min_protocol_version(&mut self, min_protocol_version: Protocol) -> &mut Self {
        self.min_protocol_version = Some(min_protocol_version);
        self
    }

    /// Set the highest protocol version to allow.
    pub fn max_protocol_version(&mut self, max_protocol_version: Protocol) -> &mut Self {
        self.max_protocol_version = Some(max_protocol_version);
        self
    }

    /// Choose whether to ignore the platform's built-in root certificates.
    pub fn disable_built_in_roots(&mut self, disable_built_in_roots: bool) -> &mut Self {
        self.disable_built_in_roots = disable_built_in_roots;
        self
    }

    /// Dangerous: choose whether to accept invalid certificates.
    pub fn danger_accept_invalid_certs(&mut self, danger_accept_invalid_certs: bool) -> &mut Self {
        self.danger_accept_invalid_certs = danger_accept_invalid_certs;
        self
    }

    /// Choose whether to send Server Name Indication.
    pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
        self.use_sni = use_sni;
        self
    }

    /// Dangerous: choose whether to accept certificates whose hostname does not match.
    pub fn danger_accept_invalid_hostnames(
        &mut self,
        danger_accept_invalid_hostnames: bool,
    ) -> &mut Self {
        self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
        self
    }

    /// Build a `TlsConnectorBuilder` configured from every setting in this config.
    ///
    /// Despite the `into_` prefix this borrows `self`; certificates and the identity are
    /// cloned into the builder, so the config can be reused.
    pub fn into_tls_connector_builder(&self) -> TlsConnectorBuilder {
        let mut builder = TlsConnector::builder();

        // Forward the client identity; without this, `identity()` would have no effect and
        // client-certificate authentication would be silently skipped.
        if let Some(identity) = &self.identity {
            builder.identity(identity.clone());
        }

        if let Some(root_certificates) = &self.root_certificates {
            for root_certificate in root_certificates {
                builder.add_root_certificate(root_certificate.clone());
            }
        }

        builder.min_protocol_version(self.min_protocol_version);
        builder.max_protocol_version(self.max_protocol_version);
        builder.disable_built_in_roots(self.disable_built_in_roots);
        builder.danger_accept_invalid_certs(self.danger_accept_invalid_certs);
        builder.danger_accept_invalid_hostnames(self.danger_accept_invalid_hostnames);
        builder.use_sni(self.use_sni);

        builder
    }
}
474
475pub trait IntoConfig {
476    fn into_config(self) -> Result<Config>;
477}
478
479impl IntoConfig for Config {
480    fn into_config(self) -> Result<Config> {
481        Ok(self)
482    }
483}
484
485impl<T: Into<String>> IntoConfig for (T, u16) {
486    fn into_config(self) -> Result<Config> {
487        Ok(Config {
488            server: ServerConfig::Standalone {
489                host: self.0.into(),
490                port: self.1,
491            },
492            username: None,
493            password: None,
494            database: 0,
495            #[cfg(feature = "tls")]
496            tls_config: None,
497        })
498    }
499}
500
501impl IntoConfig for &str {
502    fn into_config(self) -> Result<Config> {
503        Config::from_str(self)
504    }
505}
506
507impl IntoConfig for String {
508    fn into_config(self) -> Result<Config> {
509        Config::from_str(&self)
510    }
511}
512
513impl IntoConfig for Url {
514    fn into_config(self) -> Result<Config> {
515        Config::from_uri(self)
516    }
517}