1#[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
2use redis::TlsMode;
3use redis::{
4 Client, ErrorKind, RedisConnectionInfo, RedisError, RedisResult, cluster::ClusterClient,
5};
6
7#[derive(Clone)]
35pub enum UniversalClient {
36 Client(Client),
37 Cluster(ClusterClient),
38}
39
40impl UniversalClient {
41 pub async fn get_connection(&self) -> RedisResult<UniversalConnection> {
42 match self {
43 Self::Client(cli) => cli
44 .get_multiplexed_async_connection()
45 .await
46 .map(UniversalConnection::Client),
47 Self::Cluster(cli) => cli
48 .get_async_connection()
49 .await
50 .map(|c| UniversalConnection::Cluster(Box::new(c))),
51 }
52 }
53
54 pub fn open<T: redis::IntoConnectionInfo + Clone>(
61 addrs: Vec<T>,
62 ) -> RedisResult<UniversalClient> {
63 let mut addrs = addrs;
64
65 if addrs.is_empty() {
66 return Err(RedisError::from((
67 ErrorKind::InvalidClientConfig,
68 "No address specified",
69 )));
70 }
71
72 if addrs.len() == 1 {
73 Client::open(addrs.remove(0)).map(Self::Client)
74 } else {
75 ClusterClient::new(addrs).map(Self::Cluster)
76 }
77 }
78}
79
/// Builder for [`UniversalClient`] with an explicit topology choice and
/// optional authentication / TLS overrides.
///
/// Create it with [`UniversalBuilder::new`], chain the setters, then call
/// [`UniversalBuilder::build`].
pub struct UniversalBuilder<T> {
    /// Node addresses. In standalone mode only the first one is used.
    addrs: Vec<T>,
    /// `true` builds a cluster client; `false` builds a standalone client.
    cluster: bool,
    /// Username override, if one was supplied.
    username: Option<String>,
    /// Password override, if one was supplied.
    password: Option<String>,
    /// TLS mode override, if one was supplied.
    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
    tls: Option<redis::TlsMode>,
}
113
114impl<T> UniversalBuilder<T> {
115 pub fn new(addrs: Vec<T>) -> UniversalBuilder<T> {
116 UniversalBuilder {
117 addrs,
118 cluster: false,
119 username: None,
120 password: None,
121 #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
122 tls: None,
123 }
124 }
125
126 pub fn cluster(mut self, flag: bool) -> UniversalBuilder<T> {
127 self.cluster = flag;
128 self
129 }
130
131 pub fn username(mut self, username: impl Into<String>) -> UniversalBuilder<T> {
133 self.username = Some(username.into());
134 self
135 }
136
137 pub fn password(mut self, password: impl Into<String>) -> UniversalBuilder<T> {
139 self.password = Some(password.into());
140 self
141 }
142
143 #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
148 pub fn tls(mut self, mode: TlsMode) -> UniversalBuilder<T> {
149 self.tls = Some(mode);
150 self
151 }
152
153 pub fn build(self) -> RedisResult<UniversalClient>
154 where
155 T: redis::IntoConnectionInfo + Clone,
156 {
157 let UniversalBuilder {
158 mut addrs,
159 cluster,
160 username,
161 password,
162 #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
163 tls,
164 } = self;
165
166 if addrs.is_empty() {
167 return Err(RedisError::from((
168 ErrorKind::InvalidClientConfig,
169 "No address specified",
170 )));
171 }
172
173 if cluster {
174 let mut builder = ClusterClient::builder(addrs);
175 if let Some(u) = username {
176 builder = builder.username(u);
177 }
178 if let Some(p) = password {
179 builder = builder.password(p);
180 }
181 #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
182 if let Some(mode) = tls {
183 builder = builder.tls(mode);
184 }
185 builder.build().map(UniversalClient::Cluster)
186 } else if username.is_some() || password.is_some() || {
187 #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
188 {
189 tls.is_some()
190 }
191 #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
192 {
193 false
194 }
195 } {
196 let conn_info = addrs.remove(0).into_connection_info()?;
197 let orig = conn_info.redis_settings();
198 let mut redis_info = RedisConnectionInfo::default()
199 .set_db(orig.db())
200 .set_protocol(orig.protocol());
201 if let Some(u) = username {
202 redis_info = redis_info.set_username(u);
203 }
204 if let Some(p) = password {
205 redis_info = redis_info.set_password(p);
206 }
207 let conn_info = conn_info.set_redis_settings(redis_info);
208 #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
209 let conn_info = if let Some(mode) = tls {
210 apply_tls_to_conn_info(conn_info, mode)?
211 } else {
212 conn_info
213 };
214 Client::open(conn_info).map(UniversalClient::Client)
215 } else {
216 Client::open(addrs.remove(0)).map(UniversalClient::Client)
217 }
218 }
219}
220
/// Applies the requested TLS `mode` to the address inside `conn_info`.
///
/// * A plain `Tcp` address becomes `TcpTls`. Its `insecure` flag comes from
///   `mode`, and it has no custom TLS parameters.
/// * An address that is already `TcpTls` (for example from a `rediss://` URL)
///   keeps its host, port and TLS parameters. Its `insecure` flag is set from
///   `mode`, so the explicitly requested mode takes precedence over the URL.
/// * Any other address kind (for example a Unix socket) is returned unchanged;
///   TLS is not applied to it.
///
/// This function currently never returns `Err`. The `RedisResult` return type
/// leaves room for rejecting unsupported combinations later.
#[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
fn apply_tls_to_conn_info(
    conn_info: redis::ConnectionInfo,
    mode: TlsMode,
) -> RedisResult<redis::ConnectionInfo> {
    let insecure = mode == TlsMode::Insecure;
    let new_addr = match conn_info.addr() {
        redis::ConnectionAddr::Tcp(host, port) => redis::ConnectionAddr::TcpTls {
            host: host.clone(),
            port: *port,
            insecure,
            tls_params: None,
        },
        // Already TLS: honour the explicitly requested mode instead of
        // silently keeping whatever `insecure` flag the URL carried.
        redis::ConnectionAddr::TcpTls {
            host,
            port,
            tls_params,
            ..
        } => redis::ConnectionAddr::TcpTls {
            host: host.clone(),
            port: *port,
            insecure,
            tls_params: tls_params.clone(),
        },
        other => other.clone(),
    };
    Ok(conn_info.set_addr(new_addr))
}
243
244#[derive(Clone)]
252pub enum UniversalConnection {
253 Client(redis::aio::MultiplexedConnection),
254 Cluster(Box<redis::cluster_async::ClusterConnection>),
255}
256
/// Test-only helpers for asserting which topology a client was built with.
#[cfg(test)]
impl UniversalClient {
    /// `true` for the standalone variant. The enum has exactly two variants,
    /// so this is the negation of `is_cluster`.
    fn is_client(&self) -> bool {
        !self.is_cluster()
    }

    /// `true` for the cluster variant.
    fn is_cluster(&self) -> bool {
        match self {
            UniversalClient::Cluster(_) => true,
            UniversalClient::Client(_) => false,
        }
    }
}
267
268impl redis::aio::ConnectionLike for UniversalConnection {
269 fn req_packed_command<'a>(
270 &'a mut self,
271 cmd: &'a redis::Cmd,
272 ) -> redis::RedisFuture<'a, redis::Value> {
273 match self {
274 Self::Client(conn) => conn.req_packed_command(cmd),
275 Self::Cluster(conn) => conn.req_packed_command(cmd),
276 }
277 }
278
279 fn req_packed_commands<'a>(
280 &'a mut self,
281 cmd: &'a redis::Pipeline,
282 offset: usize,
283 count: usize,
284 ) -> redis::RedisFuture<'a, Vec<redis::Value>> {
285 match self {
286 Self::Client(conn) => conn.req_packed_commands(cmd, offset, count),
287 Self::Cluster(conn) => conn.req_packed_commands(cmd, offset, count),
288 }
289 }
290
291 fn get_db(&self) -> i64 {
292 match self {
293 Self::Client(conn) => conn.get_db(),
294 Self::Cluster(conn) => conn.get_db(),
295 }
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned addresses the builder accepts.
    fn owned(addrs: &[&str]) -> Vec<String> {
        addrs.iter().copied().map(String::from).collect()
    }

    #[test]
    fn open_empty_addrs_error() {
        assert!(UniversalClient::open(Vec::<String>::new()).is_err());
    }

    #[test]
    fn open_single_addr_is_client() {
        let client = UniversalClient::open(vec!["redis://127.0.0.1:6379"])
            .expect("a single valid address should open");
        assert!(client.is_client());
    }

    #[test]
    fn open_multiple_addrs_is_cluster() {
        let client =
            UniversalClient::open(vec!["redis://127.0.0.1:7000", "redis://127.0.0.1:7001"])
                .expect("multiple valid addresses should open");
        assert!(client.is_cluster());
    }

    #[test]
    fn builder_empty_addrs_error() {
        assert!(UniversalBuilder::new(Vec::<String>::new()).build().is_err());
    }

    #[test]
    fn builder_cluster_true_forces_cluster() {
        let client = UniversalBuilder::new(owned(&["redis://127.0.0.1:6379"]))
            .cluster(true)
            .build()
            .expect("cluster build should succeed");
        assert!(client.is_cluster());
    }

    #[test]
    fn builder_cluster_false_uses_first_addr() {
        let client = UniversalBuilder::new(owned(&[
            "redis://127.0.0.1:7000",
            "redis://127.0.0.1:7001",
        ]))
        .cluster(false)
        .build()
        .expect("standalone build should succeed");
        assert!(client.is_client());
    }

    #[test]
    fn builder_with_password_is_client() {
        let client = UniversalBuilder::new(owned(&["redis://127.0.0.1:6379"]))
            .password("secret")
            .build()
            .expect("password build should succeed");
        assert!(client.is_client());
    }

    #[test]
    fn builder_with_username_and_password_is_client() {
        let client = UniversalBuilder::new(owned(&["redis://127.0.0.1:6379"]))
            .username("alice")
            .password("secret")
            .build()
            .expect("username/password build should succeed");
        assert!(client.is_client());
    }

    #[test]
    fn builder_with_password_cluster_is_cluster() {
        let client = UniversalBuilder::new(owned(&[
            "redis://127.0.0.1:7000",
            "redis://127.0.0.1:7001",
        ]))
        .password("secret")
        .cluster(true)
        .build()
        .expect("password cluster build should succeed");
        assert!(client.is_cluster());
    }

    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
    #[test]
    fn builder_tls_secure_is_client() {
        let client = UniversalBuilder::new(owned(&["redis://127.0.0.1:6380"]))
            .tls(redis::TlsMode::Secure)
            .build()
            .expect("secure TLS build should succeed");
        assert!(client.is_client());
    }

    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
    #[test]
    fn builder_tls_insecure_is_client() {
        let client = UniversalBuilder::new(owned(&["redis://127.0.0.1:6380"]))
            .tls(redis::TlsMode::Insecure)
            .build()
            .expect("insecure TLS build should succeed");
        assert!(client.is_client());
    }

    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
    #[test]
    fn builder_tls_cluster_is_cluster() {
        let client = UniversalBuilder::new(owned(&[
            "redis://127.0.0.1:7000",
            "redis://127.0.0.1:7001",
        ]))
        .tls(redis::TlsMode::Secure)
        .cluster(true)
        .build()
        .expect("TLS cluster build should succeed");
        assert!(client.is_cluster());
    }
}