// redis_universal_client/lib.rs

1#[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
2use redis::TlsMode;
3use redis::{
4    Client, ErrorKind, RedisConnectionInfo, RedisError, RedisResult, cluster::ClusterClient,
5};
6
7/// A universal Redis client that works with both standalone Redis and Redis Cluster.
8///
9/// Wraps either a [`redis::Client`] or a [`redis::cluster::ClusterClient`], similar to
10/// go-redis's `UniversalClient`.
11///
12/// # Examples
13///
14/// ```no_run
15/// use redis::AsyncCommands;
16/// use redis_universal_client::UniversalClient;
17///
18/// # async fn example() -> redis::RedisResult<()> {
19/// // Standalone Redis
20/// let client = UniversalClient::open(vec!["redis://127.0.0.1:6379"])?;
21/// let mut conn = client.get_connection().await?;
22/// conn.set::<_, _, ()>("key", "value").await?;
23/// let val: String = conn.get("key").await?;
24///
25/// // Redis Cluster (multiple addresses)
26/// let client = UniversalClient::open(vec![
27///     "redis://127.0.0.1:7000",
28///     "redis://127.0.0.1:7001",
29/// ])?;
30/// let mut conn = client.get_connection().await?;
31/// # Ok(())
32/// # }
33/// ```
34#[derive(Clone)]
35pub enum UniversalClient {
36    Client(Client),
37    Cluster(ClusterClient),
38}
39
40impl UniversalClient {
41    pub async fn get_connection(&self) -> RedisResult<UniversalConnection> {
42        match self {
43            Self::Client(cli) => cli
44                .get_multiplexed_async_connection()
45                .await
46                .map(UniversalConnection::Client),
47            Self::Cluster(cli) => cli
48                .get_async_connection()
49                .await
50                .map(|c| UniversalConnection::Cluster(Box::new(c))),
51        }
52    }
53
54    /// Creates a [`UniversalClient`] from a list of addresses.
55    ///
56    /// - 1 address: creates a standalone [`redis::Client`]
57    /// - Multiple addresses: creates a [`redis::cluster::ClusterClient`]
58    ///
59    /// To force cluster mode with a single address, use [`UniversalBuilder`] instead.
60    pub fn open<T: redis::IntoConnectionInfo + Clone>(
61        addrs: Vec<T>,
62    ) -> RedisResult<UniversalClient> {
63        let mut addrs = addrs;
64
65        if addrs.is_empty() {
66            return Err(RedisError::from((
67                ErrorKind::InvalidClientConfig,
68                "No address specified",
69            )));
70        }
71
72        if addrs.len() == 1 {
73            Client::open(addrs.remove(0)).map(Self::Client)
74        } else {
75            ClusterClient::new(addrs).map(Self::Cluster)
76        }
77    }
78}
79
/// Builder for [`UniversalClient`] with explicit control over cluster mode and credentials.
///
/// [`UniversalClient::open`] infers the mode from how many addresses it receives. The
/// builder instead lets you force cluster mode even with a single seed address. It also
/// accepts an ACL username/password programmatically, so they need not be embedded in
/// the URL.
///
/// # Examples
///
/// ```no_run
/// use redis_universal_client::UniversalBuilder;
///
/// # fn example() -> redis::RedisResult<()> {
/// // Force cluster mode with a single address
/// let client = UniversalBuilder::new(vec!["redis://127.0.0.1:7000".to_string()])
///     .cluster(true)
///     .build()?;
///
/// // Standalone Redis with ACL credentials
/// let client = UniversalBuilder::new(vec!["redis://127.0.0.1:6379".to_string()])
///     .username("alice")
///     .password("secret")
///     .build()?;
/// # Ok(())
/// # }
/// ```
pub struct UniversalBuilder<T> {
    // Seed addresses. When cluster mode is off, only the first one is used.
    addrs: Vec<T>,
    // When `true`, a cluster client is built regardless of the address count.
    cluster: bool,
    // Explicit ACL username, if any.
    username: Option<String>,
    // Explicit password, if any.
    password: Option<String>,
    // Explicit TLS mode, if any (only present with a TLS feature enabled).
    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
    tls: Option<TlsMode>,
}

114impl<T> UniversalBuilder<T> {
115    pub fn new(addrs: Vec<T>) -> UniversalBuilder<T> {
116        UniversalBuilder {
117            addrs,
118            cluster: false,
119            username: None,
120            password: None,
121            #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
122            tls: None,
123        }
124    }
125
126    pub fn cluster(mut self, flag: bool) -> UniversalBuilder<T> {
127        self.cluster = flag;
128        self
129    }
130
131    /// Set the ACL username for authentication (Redis 6.0+).
132    pub fn username(mut self, username: impl Into<String>) -> UniversalBuilder<T> {
133        self.username = Some(username.into());
134        self
135    }
136
137    /// Set the password for authentication.
138    pub fn password(mut self, password: impl Into<String>) -> UniversalBuilder<T> {
139        self.password = Some(password.into());
140        self
141    }
142
143    /// Enable TLS. Use [`TlsMode::Secure`] to verify certificates (recommended)
144    /// or [`TlsMode::Insecure`] to skip verification.
145    ///
146    /// Requires the `tls-native-tls` or `tls-rustls` feature.
147    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
148    pub fn tls(mut self, mode: TlsMode) -> UniversalBuilder<T> {
149        self.tls = Some(mode);
150        self
151    }
152
153    pub fn build(self) -> RedisResult<UniversalClient>
154    where
155        T: redis::IntoConnectionInfo + Clone,
156    {
157        let UniversalBuilder {
158            mut addrs,
159            cluster,
160            username,
161            password,
162            #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
163            tls,
164        } = self;
165
166        if addrs.is_empty() {
167            return Err(RedisError::from((
168                ErrorKind::InvalidClientConfig,
169                "No address specified",
170            )));
171        }
172
173        if cluster {
174            let mut builder = ClusterClient::builder(addrs);
175            if let Some(u) = username {
176                builder = builder.username(u);
177            }
178            if let Some(p) = password {
179                builder = builder.password(p);
180            }
181            #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
182            if let Some(mode) = tls {
183                builder = builder.tls(mode);
184            }
185            builder.build().map(UniversalClient::Cluster)
186        } else if username.is_some() || password.is_some() || {
187            #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
188            {
189                tls.is_some()
190            }
191            #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
192            {
193                false
194            }
195        } {
196            let conn_info = addrs.remove(0).into_connection_info()?;
197            let orig = conn_info.redis_settings();
198            let mut redis_info = RedisConnectionInfo::default()
199                .set_db(orig.db())
200                .set_protocol(orig.protocol());
201            if let Some(u) = username {
202                redis_info = redis_info.set_username(u);
203            }
204            if let Some(p) = password {
205                redis_info = redis_info.set_password(p);
206            }
207            let conn_info = conn_info.set_redis_settings(redis_info);
208            #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
209            let conn_info = if let Some(mode) = tls {
210                apply_tls_to_conn_info(conn_info, mode)?
211            } else {
212                conn_info
213            };
214            Client::open(conn_info).map(UniversalClient::Client)
215        } else {
216            Client::open(addrs.remove(0)).map(UniversalClient::Client)
217        }
218    }
219}
220
/// Applies an explicit TLS mode to a standalone `ConnectionInfo`.
///
/// - `ConnectionAddr::Tcp` is upgraded to `ConnectionAddr::TcpTls` with no custom
///   TLS parameters.
/// - An address that is already `ConnectionAddr::TcpTls` (for example from a
///   `rediss://` URL) keeps its host, port and TLS parameters. Its `insecure` flag is
///   set from `mode`, so the explicitly requested verification level wins over the
///   URL. This is consistent with cluster mode, where the explicit mode is handed to
///   `ClusterClientBuilder::tls`.
/// - Any other address kind (e.g. a Unix socket) is returned unchanged.
///
/// # Errors
///
/// Currently never fails. The `RedisResult` return type lets callers use `?`.
#[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
fn apply_tls_to_conn_info(
    conn_info: redis::ConnectionInfo,
    mode: TlsMode,
) -> RedisResult<redis::ConnectionInfo> {
    let insecure = mode == TlsMode::Insecure;
    let new_addr = match conn_info.addr() {
        redis::ConnectionAddr::Tcp(host, port) => redis::ConnectionAddr::TcpTls {
            host: host.clone(),
            port: *port,
            insecure,
            tls_params: None,
        },
        // Already TLS: keep endpoint and TLS parameters, but honour the requested mode.
        redis::ConnectionAddr::TcpTls {
            host,
            port,
            tls_params,
            ..
        } => redis::ConnectionAddr::TcpTls {
            host: host.clone(),
            port: *port,
            insecure,
            tls_params: tls_params.clone(),
        },
        // Unix sockets (and any other transport) are left as-is.
        other => other.clone(),
    };
    Ok(conn_info.set_addr(new_addr))
}

244/// Async multiplexed connection for both standalone and cluster Redis.
245///
246/// Wraps either a [`redis::aio::MultiplexedConnection`] or a
247/// [`redis::cluster_async::ClusterConnection`]. Implements [`redis::aio::ConnectionLike`],
248/// so all [`redis::AsyncCommands`] work transparently.
249///
250/// Both variants are `Clone + Send + Sync`.
251#[derive(Clone)]
252pub enum UniversalConnection {
253    Client(redis::aio::MultiplexedConnection),
254    Cluster(Box<redis::cluster_async::ClusterConnection>),
255}
256
#[cfg(test)]
impl UniversalClient {
    /// `true` when this wraps a standalone [`Client`].
    fn is_client(&self) -> bool {
        match self {
            Self::Client(_) => true,
            Self::Cluster(_) => false,
        }
    }

    /// `true` when this wraps a [`ClusterClient`].
    fn is_cluster(&self) -> bool {
        match self {
            Self::Cluster(_) => true,
            Self::Client(_) => false,
        }
    }
}

268impl redis::aio::ConnectionLike for UniversalConnection {
269    fn req_packed_command<'a>(
270        &'a mut self,
271        cmd: &'a redis::Cmd,
272    ) -> redis::RedisFuture<'a, redis::Value> {
273        match self {
274            Self::Client(conn) => conn.req_packed_command(cmd),
275            Self::Cluster(conn) => conn.req_packed_command(cmd),
276        }
277    }
278
279    fn req_packed_commands<'a>(
280        &'a mut self,
281        cmd: &'a redis::Pipeline,
282        offset: usize,
283        count: usize,
284    ) -> redis::RedisFuture<'a, Vec<redis::Value>> {
285        match self {
286            Self::Client(conn) => conn.req_packed_commands(cmd, offset, count),
287            Self::Cluster(conn) => conn.req_packed_commands(cmd, offset, count),
288        }
289    }
290
291    fn get_db(&self) -> i64 {
292        match self {
293            Self::Client(conn) => conn.get_db(),
294            Self::Cluster(conn) => conn.get_db(),
295        }
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned address list the builder tests use.
    fn owned(addrs: &[&str]) -> Vec<String> {
        addrs.iter().map(|addr| (*addr).to_string()).collect()
    }

    #[test]
    fn open_empty_addrs_error() {
        assert!(UniversalClient::open(Vec::<String>::new()).is_err());
    }

    #[test]
    fn open_single_addr_is_client() {
        let client = UniversalClient::open(vec!["redis://127.0.0.1:6379"]).unwrap();
        assert!(client.is_client());
    }

    #[test]
    fn open_multiple_addrs_is_cluster() {
        let client =
            UniversalClient::open(vec!["redis://127.0.0.1:7000", "redis://127.0.0.1:7001"])
                .unwrap();
        assert!(client.is_cluster());
    }

    #[test]
    fn builder_empty_addrs_error() {
        assert!(UniversalBuilder::new(Vec::<String>::new()).build().is_err());
    }

    #[test]
    fn builder_cluster_true_forces_cluster() {
        let client = UniversalBuilder::new(owned(&["redis://127.0.0.1:6379"]))
            .cluster(true)
            .build()
            .unwrap();
        assert!(client.is_cluster());
    }

    #[test]
    fn builder_cluster_false_uses_first_addr() {
        let client = UniversalBuilder::new(owned(&[
            "redis://127.0.0.1:7000",
            "redis://127.0.0.1:7001",
        ]))
        .cluster(false)
        .build()
        .unwrap();
        assert!(client.is_client());
    }

    #[test]
    fn builder_with_password_is_client() {
        let client = UniversalBuilder::new(owned(&["redis://127.0.0.1:6379"]))
            .password("secret")
            .build()
            .unwrap();
        assert!(client.is_client());
    }

    #[test]
    fn builder_with_username_and_password_is_client() {
        let client = UniversalBuilder::new(owned(&["redis://127.0.0.1:6379"]))
            .username("alice")
            .password("secret")
            .build()
            .unwrap();
        assert!(client.is_client());
    }

    #[test]
    fn builder_with_password_cluster_is_cluster() {
        let client = UniversalBuilder::new(owned(&[
            "redis://127.0.0.1:7000",
            "redis://127.0.0.1:7001",
        ]))
        .password("secret")
        .cluster(true)
        .build()
        .unwrap();
        assert!(client.is_cluster());
    }

    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
    #[test]
    fn builder_tls_secure_is_client() {
        let client = UniversalBuilder::new(owned(&["redis://127.0.0.1:6380"]))
            .tls(redis::TlsMode::Secure)
            .build()
            .unwrap();
        assert!(client.is_client());
    }

    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
    #[test]
    fn builder_tls_insecure_is_client() {
        let client = UniversalBuilder::new(owned(&["redis://127.0.0.1:6380"]))
            .tls(redis::TlsMode::Insecure)
            .build()
            .unwrap();
        assert!(client.is_client());
    }

    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
    #[test]
    fn builder_tls_cluster_is_cluster() {
        let client = UniversalBuilder::new(owned(&[
            "redis://127.0.0.1:7000",
            "redis://127.0.0.1:7001",
        ]))
        .tls(redis::TlsMode::Secure)
        .cluster(true)
        .build()
        .unwrap();
        assert!(client.is_cluster());
    }
}