1use crate::{Error, Result};
2#[cfg(feature = "tls")]
3use native_tls::{Certificate, Identity, Protocol, TlsConnector, TlsConnectorBuilder};
4use std::{str::FromStr, time::Duration};
5use url::Url;
6
/// Port used when an address or URI host does not specify one (the standard Redis port).
const DEFAULT_PORT: u16 = 6_379;
/// Logical database used when a URI carries no database path segment.
const DEFAULT_DATABASE: usize = 0;

/// Raw pieces of a connection URI as split apart by `Config::break_down_uri`,
/// in the order `(scheme, username, password, hosts, path_segments)`.
type Uri<'a> = (
    // scheme, e.g. `redis` or `rediss+sentinel`
    &'a str,
    // username from the user-info section, if present
    Option<&'a str>,
    // password from the user-info section, if present
    Option<&'a str>,
    // comma-separated `(host, port)` list from the authority
    Vec<(&'a str, u16)>,
    // `/`-separated path segments following the authority
    Vec<&'a str>,
);
17
18#[derive(Clone, Default)]
19pub struct Config {
20 pub server: ServerConfig,
21 pub username: Option<String>,
22 pub password: Option<String>,
23 pub database: usize,
24 #[cfg(feature = "tls")]
25 pub tls_config: Option<TlsConfig>,
26}
27
28impl FromStr for Config {
29 type Err = Error;
30
31 fn from_str(str: &str) -> Result<Config> {
33 if let Some(config) = Self::parse_uri(str) {
34 Ok(config)
35 } else if let Some(addr) = Self::parse_addr(str) {
36 addr.into_config()
37 } else {
38 Err(Error::Config(format!("Cannot parse config from {str}")))
39 }
40 }
41}
42
43impl Config {
44 pub fn from_uri(uri: Url) -> Result<Config> {
46 Self::from_str(uri.as_str())
47 }
48
49 fn parse_addr(str: &str) -> Option<(&str, u16)> {
51 let mut iter = str.split(':');
52
53 match (iter.next(), iter.next(), iter.next()) {
54 (Some(host), Some(port), None) => {
55 if let Ok(port) = port.parse::<u16>() {
56 Some((host, port))
57 } else {
58 None
59 }
60 }
61 (Some(host), None, None) => Some((host, DEFAULT_PORT)),
62 _ => None,
63 }
64 }
65
66 fn parse_uri(uri: &str) -> Option<Config> {
67 let (scheme, username, password, hosts, path_segments) = Self::break_down_uri(uri)?;
68 let mut hosts = hosts;
69 let mut path_segments = path_segments.into_iter();
70
71 enum ServerType {
72 Standalone,
73 Sentinel,
74 Cluster,
75 }
76
77 #[cfg(feature = "tls")]
78 let (tls_config, server_type) = match scheme {
79 "redis" => (None, ServerType::Standalone),
80 "rediss" => (Some(TlsConfig::default()), ServerType::Standalone),
81 "redis+sentinel" => (None, ServerType::Sentinel),
82 "rediss+sentinel" => (Some(TlsConfig::default()), ServerType::Sentinel),
83 "redis+cluster" => (None, ServerType::Cluster),
84 "rediss+cluster" => (Some(TlsConfig::default()), ServerType::Cluster),
85 _ => {
86 return None;
87 }
88 };
89
90 #[cfg(not(feature = "tls"))]
91 let server_type = match scheme {
92 "redis" => ServerType::Standalone,
93 "redis+sentinel" => ServerType::Sentinel,
94 "redis+cluster" => ServerType::Cluster,
95 _ => {
96 return None;
97 }
98 };
99
100 let server = match server_type {
101 ServerType::Standalone => {
102 if hosts.len() > 1 {
103 return None;
104 } else {
105 let (host, port) = hosts.pop()?;
106 ServerConfig::Standalone {
107 host: host.to_owned(),
108 port,
109 }
110 }
111 }
112 ServerType::Sentinel => {
113 let instances = hosts
114 .iter()
115 .map(|(host, port)| ((*host).to_owned(), *port))
116 .collect::<Vec<_>>();
117
118 let service_name = match path_segments.next() {
119 Some(service_name) => service_name.to_owned(),
120 None => {
121 return None;
122 }
123 };
124
125 ServerConfig::Sentinel(SentinelConfig {
126 instances,
127 service_name,
128 ..Default::default()
129 })
130 }
131 ServerType::Cluster => {
132 let nodes = hosts
133 .iter()
134 .map(|(host, port)| ((*host).to_owned(), *port))
135 .collect::<Vec<_>>();
136
137 ServerConfig::Cluster(ClusterConfig { nodes })
138 }
139 };
140
141 let database = match path_segments.next() {
142 Some(database) => match database.parse::<usize>() {
143 Ok(database) => database,
144 Err(_) => {
145 return None;
146 }
147 },
148 None => DEFAULT_DATABASE,
149 };
150
151 Some(Config {
152 server,
153 username: username.map(|u| u.to_owned()),
154 password: password.map(|p| p.to_owned()),
155 database,
156 #[cfg(feature = "tls")]
157 tls_config,
158 })
159 }
160
161 fn break_down_uri(uri: &str) -> Option<Uri> {
163 let end_of_scheme = match uri.find("://") {
164 Some(index) => index,
165 None => {
166 return None;
167 }
168 };
169
170 let scheme = &uri[..end_of_scheme];
171
172 let after_scheme = &uri[end_of_scheme + 3..];
173
174 let (before_query, _query) = match after_scheme.find('?') {
175 Some(index) => match Self::exclusive_split_at(after_scheme, index) {
176 (Some(before_query), after_query) => (before_query, after_query),
177 _ => {
178 return None;
179 }
180 },
181 None => (after_scheme, None),
182 };
183
184 let (authority, path) = match after_scheme.find('/') {
185 Some(index) => match Self::exclusive_split_at(before_query, index) {
186 (Some(authority), path) => (authority, path),
187 _ => {
188 return None;
189 }
190 },
191 None => (after_scheme, None),
192 };
193
194 let (user_info, hosts) = match authority.rfind('@') {
195 Some(index) => {
196 let (user_info, hosts) = Self::exclusive_split_at(authority, index);
199 match hosts {
200 Some(hosts) => (user_info, hosts),
201 None => {
202 return None;
204 }
205 }
206 }
207 None => (None, authority),
208 };
209
210 let (username, password) = match user_info {
211 Some(user_info) => match user_info.find(':') {
212 Some(index) => match Self::exclusive_split_at(user_info, index) {
213 (username, None) => (username, Some("")),
214 (username, password) => (username, password),
215 },
216 None => {
217 return None;
219 }
220 },
221 None => (None, None),
222 };
223
224 let hosts = hosts
225 .split(',')
226 .map(Self::parse_addr)
227 .collect::<Option<Vec<_>>>();
228 let hosts = hosts?;
229
230 let path_segments = match path {
231 Some(path) => path.split('/').collect::<Vec<_>>(),
232 None => Vec::new(),
233 };
234
235 Some((scheme, username, password, hosts, path_segments))
236 }
237
238 fn exclusive_split_at(s: &str, i: usize) -> (Option<&str>, Option<&str>) {
241 let (l, r) = s.split_at(i);
242
243 let lout = if !l.is_empty() { Some(l) } else { None };
244 let rout = if r.len() > 1 { Some(&r[1..]) } else { None };
245
246 (lout, rout)
247 }
248}
249
250impl ToString for Config {
251 fn to_string(&self) -> String {
252 #[cfg(feature = "tls")]
253 let mut s = if self.tls_config.is_some() {
254 match &self.server {
255 ServerConfig::Standalone { host: _, port: _ } => "rediss://",
256 ServerConfig::Sentinel(_) => "rediss+sentinel://",
257 ServerConfig::Cluster(_) => "rediss+cluster://",
258 }
259 } else {
260 match &self.server {
261 ServerConfig::Standalone { host: _, port: _ } => "redis://",
262 ServerConfig::Sentinel(_) => "redis+sentinel://",
263 ServerConfig::Cluster(_) => "redis+cluster://",
264 }
265 }
266 .to_owned();
267
268 #[cfg(not(feature = "tls"))]
269 let mut s = match &self.server {
270 ServerConfig::Standalone { host: _, port: _ } => "redis://",
271 ServerConfig::Sentinel(_) => "redis+sentinel://",
272 ServerConfig::Cluster(_) => "redis+cluster://",
273 }
274 .to_owned();
275
276 if let Some(username) = &self.username {
277 s.push_str(username);
278 }
279
280 if let Some(password) = &self.password {
281 s.push(':');
282 s.push_str(password);
283 s.push('@');
284 }
285
286 match &self.server {
287 ServerConfig::Standalone { host, port } => {
288 s.push_str(host);
289 s.push(':');
290 s.push_str(&port.to_string());
291 }
292 ServerConfig::Sentinel(SentinelConfig {
293 instances,
294 service_name,
295 wait_beetween_failures: _,
296 }) => {
297 s.push_str(
298 &instances
299 .iter()
300 .map(|(host, port)| format!("{host}:{port}"))
301 .collect::<Vec<String>>()
302 .join(","),
303 );
304 s.push('/');
305 s.push_str(service_name);
306 }
307 ServerConfig::Cluster(ClusterConfig { nodes }) => {
308 s.push_str(
309 &nodes
310 .iter()
311 .map(|(host, port)| format!("{host}:{port}"))
312 .collect::<Vec<String>>()
313 .join(","),
314 );
315 }
316 }
317
318 if self.database > 0 {
319 s.push('/');
320 s.push_str(&self.database.to_string());
321 }
322
323 s
324 }
325}
326
327#[derive(Clone)]
329pub enum ServerConfig {
330 Standalone {
332 host: String,
333 port: u16,
334 },
335 Sentinel(SentinelConfig),
336 Cluster(ClusterConfig),
337}
338
339impl Default for ServerConfig {
340 fn default() -> Self {
341 ServerConfig::Standalone {
342 host: "127.0.0.1".to_owned(),
343 port: 6379,
344 }
345 }
346}
347
/// Settings for reaching a Redis deployment through Sentinel.
#[derive(Clone)]
pub struct SentinelConfig {
    /// `(host, port)` addresses of the Sentinel instances.
    pub instances: Vec<(String, u16)>,

    /// Name of the monitored service (first path segment of a `redis+sentinel` URI).
    pub service_name: String,

    /// Presumably the pause between failed attempts to reach a Sentinel instance
    /// (its use is not visible in this module); defaults to 250 ms.
    /// The misspelled name is part of the public API and is kept as-is.
    pub wait_beetween_failures: Duration,
}

impl Default for SentinelConfig {
    /// No instances, an empty service name and a 250 ms wait between failures.
    fn default() -> Self {
        let wait_beetween_failures = Duration::from_millis(250);
        Self {
            instances: Vec::new(),
            service_name: String::new(),
            wait_beetween_failures,
        }
    }
}
370
/// Node addresses of a Redis Cluster deployment.
#[derive(Default, Clone)]
pub struct ClusterConfig {
    /// `(host, port)` pairs, as listed in a `redis+cluster://` URI.
    pub nodes: Vec<(String, u16)>,
}
377
/// TLS settings applied to connections made with a `rediss*` scheme.
///
/// Fields are private; use the builder-style setters and
/// `into_tls_connector_builder` to obtain a `native_tls` connector builder.
#[cfg(feature = "tls")]
#[derive(Clone)]
pub struct TlsConfig {
    /// Client identity for certificate-based client authentication.
    identity: Option<Identity>,
    /// Extra root certificates to trust.
    root_certificates: Option<Vec<Certificate>>,
    /// Minimum accepted protocol version (`None` = no minimum).
    min_protocol_version: Option<Protocol>,
    /// Maximum accepted protocol version (`None` = no maximum).
    max_protocol_version: Option<Protocol>,
    /// Whether to ignore the platform's built-in root certificates.
    disable_built_in_roots: bool,
    /// Dangerous: accept invalid server certificates.
    danger_accept_invalid_certs: bool,
    /// Dangerous: accept certificates whose hostname does not match.
    danger_accept_invalid_hostnames: bool,
    /// Whether to send Server Name Indication.
    use_sni: bool,
}
393
#[cfg(feature = "tls")]
impl Default for TlsConfig {
    /// SNI enabled, TLS 1.0 minimum, no maximum, built-in roots trusted, no client
    /// identity, and all `danger_*` relaxations off.
    ///
    /// NOTE(review): TLS 1.0/1.1 are deprecated (RFC 8996); consider whether the
    /// minimum should be raised.
    fn default() -> Self {
        TlsConfig {
            use_sni: true,
            min_protocol_version: Some(Protocol::Tlsv10),
            max_protocol_version: None,
            identity: None,
            root_certificates: None,
            disable_built_in_roots: false,
            danger_accept_invalid_certs: false,
            danger_accept_invalid_hostnames: false,
        }
    }
}
409
#[cfg(feature = "tls")]
impl TlsConfig {
    /// Sets the identity used for client certificate authentication.
    pub fn identity(&mut self, identity: Identity) -> &mut Self {
        self.identity = Some(identity);
        self
    }

    /// Replaces the set of additional root certificates to trust.
    pub fn root_certificates(&mut self, root_certificates: Vec<Certificate>) -> &mut Self {
        self.root_certificates = Some(root_certificates);
        self
    }

    /// Sets the minimum accepted protocol version.
    pub fn min_protocol_version(&mut self, min_protocol_version: Protocol) -> &mut Self {
        self.min_protocol_version = Some(min_protocol_version);
        self
    }

    /// Sets the maximum accepted protocol version.
    pub fn max_protocol_version(&mut self, max_protocol_version: Protocol) -> &mut Self {
        self.max_protocol_version = Some(max_protocol_version);
        self
    }

    /// Controls whether the platform's built-in root certificates are ignored.
    pub fn disable_built_in_roots(&mut self, disable_built_in_roots: bool) -> &mut Self {
        self.disable_built_in_roots = disable_built_in_roots;
        self
    }

    /// Dangerous: when `true`, invalid server certificates are accepted.
    pub fn danger_accept_invalid_certs(&mut self, danger_accept_invalid_certs: bool) -> &mut Self {
        self.danger_accept_invalid_certs = danger_accept_invalid_certs;
        self
    }

    /// Controls whether Server Name Indication is sent.
    pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
        self.use_sni = use_sni;
        self
    }

    /// Dangerous: when `true`, certificates whose hostname does not match are accepted.
    pub fn danger_accept_invalid_hostnames(
        &mut self,
        danger_accept_invalid_hostnames: bool,
    ) -> &mut Self {
        self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
        self
    }

    /// Builds a `native_tls` connector builder populated from every setting in this
    /// configuration, including the client identity and extra root certificates.
    pub fn into_tls_connector_builder(&self) -> TlsConnectorBuilder {
        let mut builder = TlsConnector::builder();

        // Client identity (mutual TLS). Without this, an identity set through
        // `identity()` would be stored but silently never presented to the server.
        if let Some(identity) = &self.identity {
            builder.identity(identity.clone());
        }

        if let Some(root_certificates) = &self.root_certificates {
            for root_certificate in root_certificates {
                builder.add_root_certificate(root_certificate.clone());
            }
        }

        builder.min_protocol_version(self.min_protocol_version);
        builder.max_protocol_version(self.max_protocol_version);
        builder.disable_built_in_roots(self.disable_built_in_roots);
        builder.danger_accept_invalid_certs(self.danger_accept_invalid_certs);
        builder.danger_accept_invalid_hostnames(self.danger_accept_invalid_hostnames);
        builder.use_sni(self.use_sni);

        builder
    }
}
474
475pub trait IntoConfig {
476 fn into_config(self) -> Result<Config>;
477}
478
479impl IntoConfig for Config {
480 fn into_config(self) -> Result<Config> {
481 Ok(self)
482 }
483}
484
485impl<T: Into<String>> IntoConfig for (T, u16) {
486 fn into_config(self) -> Result<Config> {
487 Ok(Config {
488 server: ServerConfig::Standalone {
489 host: self.0.into(),
490 port: self.1,
491 },
492 username: None,
493 password: None,
494 database: 0,
495 #[cfg(feature = "tls")]
496 tls_config: None,
497 })
498 }
499}
500
501impl IntoConfig for &str {
502 fn into_config(self) -> Result<Config> {
503 Config::from_str(self)
504 }
505}
506
507impl IntoConfig for String {
508 fn into_config(self) -> Result<Config> {
509 Config::from_str(&self)
510 }
511}
512
513impl IntoConfig for Url {
514 fn into_config(self) -> Result<Config> {
515 Config::from_uri(self)
516 }
517}