Skip to main content

redis_test/
server.rs

1use redis::{ConnectionAddr, IntoConnectionInfo, ProtocolVersion, RedisConnectionInfo};
2use std::path::Path;
3use std::{env, fs, path::PathBuf, process};
4
5use tempfile::TempDir;
6
7use crate::utils::{TlsFilePaths, build_keys_and_certs_for_tls, get_random_available_port};
8
9pub fn use_protocol() -> ProtocolVersion {
10    if env::var("PROTOCOL").unwrap_or_default() == "RESP3" {
11        ProtocolVersion::RESP3
12    } else {
13        ProtocolVersion::RESP2
14    }
15}
16
17pub fn redis_settings() -> RedisConnectionInfo {
18    RedisConnectionInfo::default().set_protocol(use_protocol())
19}
20
/// Returns the host that TCP (and TCP+TLS) test servers bind to and clients connect to:
/// the IPv4 loopback address `127.0.0.1`.
pub fn get_default_host() -> String {
    String::from("127.0.0.1")
}
25
/// Transport a test server listens on, selected from the `REDISRS_SERVER_TYPE`
/// environment variable by `ServerType::get_intended`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServerType {
    /// Plain TCP, or TCP with TLS when `tls` is `true`.
    Tcp { tls: bool },
    /// A Unix domain socket.
    Unix,
}
31
/// Represents a module that can be loaded into the Redis server.
///
/// Each requested module is passed to the server with `--loadmodule`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum Module {
    /// RedisJSON. Its shared-library path is read from the `REDISRS_REDIS_JSON_PATH`
    /// environment variable (or the deprecated `REDIS_RS_REDIS_JSON_PATH`).
    Json,
}
37
/// A standalone Redis server instance for testing.
///
/// `RedisServer` manages the lifecycle of a Redis process, including startup,
/// configuration, and shutdown.
///
/// Dropping a `RedisServer` calls [`RedisServer::stop`]. That kills and reaps the process and
/// removes its Unix socket file, if there is one. The fields are dropped afterwards in
/// declaration order.
///
/// # Example
/// ```rust,no_run
/// use redis_test::server::RedisServer;
///
/// let server = RedisServer::new();
/// let info = server.connection_info();
/// // Connect to the server using `info`...
/// ```
pub struct RedisServer {
    /// The spawned server process. Its command is configured with piped stdout/stderr
    /// before being handed to the spawner.
    pub process: process::Child,
    /// Temporary directory that holds the server's log file. It is also passed to
    /// `build_keys_and_certs_for_tls` when TLS material has to be generated.
    pub tempdir: tempfile::TempDir,
    /// Path of the server log (`redis.log` inside `tempdir`).
    pub log_file: PathBuf,
    /// Address clients should connect to. For TLS servers this is rebuilt with
    /// `insecure` set to `!mtls_enabled`.
    pub addr: redis::ConnectionAddr,
    /// TLS certificate/key paths. The constructors in this module set it only for
    /// TCP+TLS servers.
    pub tls_paths: Option<TlsFilePaths>,
}
58
59impl ServerType {
60    fn get_intended() -> ServerType {
61        match env::var("REDISRS_SERVER_TYPE")
62            .ok()
63            .as_ref()
64            .map(|x| &x[..])
65        {
66            Some("tcp") => ServerType::Tcp { tls: false },
67            Some("tcp+tls") => ServerType::Tcp { tls: true },
68            Some("unix") => ServerType::Unix,
69            Some(val) => {
70                panic!("Unknown server type {val:?}");
71            }
72            None => ServerType::Tcp { tls: false },
73        }
74    }
75}
76
77impl Drop for RedisServer {
78    fn drop(&mut self) {
79        self.stop()
80    }
81}
82
83impl Default for RedisServer {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
89impl RedisServer {
90    pub fn new() -> RedisServer {
91        RedisServer::with_modules(&[], false)
92    }
93
94    pub fn new_with_mtls() -> RedisServer {
95        RedisServer::with_modules(&[], true)
96    }
97
98    pub fn log_file_contents(&self) -> Option<String> {
99        std::fs::read_to_string(self.log_file.clone()).ok()
100    }
101
102    pub fn get_addr(port: u16) -> ConnectionAddr {
103        let server_type = ServerType::get_intended();
104        match server_type {
105            ServerType::Tcp { tls } => {
106                if tls {
107                    redis::ConnectionAddr::TcpTls {
108                        host: get_default_host(),
109                        port,
110                        insecure: true,
111                        tls_params: None,
112                    }
113                } else {
114                    redis::ConnectionAddr::Tcp(get_default_host(), port)
115                }
116            }
117            ServerType::Unix => {
118                let (a, b) = rand::random::<(u64, u64)>();
119                let path = format!("/tmp/redis-rs-test-{a}-{b}.sock");
120                redis::ConnectionAddr::Unix(PathBuf::from(&path))
121            }
122        }
123    }
124
125    pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> RedisServer {
126        // this is technically a race but we can't do better with
127        // the tools that redis gives us :(
128        let redis_port = get_random_available_port();
129        let addr = RedisServer::get_addr(redis_port);
130
131        RedisServer::new_with_addr_tls_modules_and_spawner(
132            addr,
133            None,
134            None,
135            mtls_enabled,
136            None,
137            modules,
138            |cmd| {
139                cmd.spawn()
140                    .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
141            },
142        )
143    }
144
145    pub fn new_with_addr_and_modules(
146        addr: redis::ConnectionAddr,
147        modules: &[Module],
148        mtls_enabled: bool,
149    ) -> RedisServer {
150        RedisServer::new_with_addr_tls_modules_and_spawner(
151            addr,
152            None,
153            None,
154            mtls_enabled,
155            None,
156            modules,
157            |cmd| {
158                cmd.spawn()
159                    .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
160            },
161        )
162    }
163
164    pub fn new_with_addr_tls_modules_and_spawner<
165        F: FnOnce(&mut process::Command) -> process::Child,
166    >(
167        addr: redis::ConnectionAddr,
168        config_file: Option<&Path>,
169        tls_paths: Option<TlsFilePaths>,
170        mtls_enabled: bool,
171        cert_auth_field: Option<&str>,
172        modules: &[Module],
173        spawner: F,
174    ) -> RedisServer {
175        let bin = env::var("REDISRS_SERVER_BIN").unwrap_or_else(|_| "redis-server".to_string());
176        let mut redis_cmd = process::Command::new(bin);
177
178        if let Some(config_path) = config_file {
179            redis_cmd.arg(config_path);
180        }
181
182        // Disable snapshotting
183        // This stops littering `dump.rdb` files during testing/development.
184        redis_cmd.arg("--save").arg("");
185
186        // Load Redis Modules
187        for module in modules {
188            match module {
189                Module::Json => {
190                    // Try to pick up json module path from REDISRS_REDIS_JSON_PATH environment variable
191                    let path = match env::var("REDISRS_REDIS_JSON_PATH") {
192                        Ok(path) => path,
193                        // Falling back to legacy REDIS_RS_REDIS_JSON_PATH environment variable
194                        Err(_) => match env::var("REDIS_RS_REDIS_JSON_PATH") {
195                            Ok(path) => {
196                                eprintln!(
197                                    "Warning: Use of REDIS_RS_REDIS_JSON_PATH is deprecated. Use REDISRS_REDIS_JSON_PATH (no '_' before 'RS') instead"
198                                );
199                                path
200                            }
201                            Err(_) => {
202                                panic!(
203                                    "Unable to find path to RedisJSON at REDISRS_REDIS_JSON_PATH, is it set?"
204                                );
205                            }
206                        },
207                    };
208
209                    redis_cmd.arg("--loadmodule").arg(path);
210                }
211            };
212        }
213
214        redis_cmd
215            .stdout(process::Stdio::piped())
216            .stderr(process::Stdio::piped());
217        let tempdir = tempfile::Builder::new()
218            .prefix("redis")
219            .tempdir()
220            .expect("failed to create tempdir");
221        let log_file = Self::log_file(&tempdir);
222        redis_cmd.arg("--logfile").arg(log_file.clone());
223        if get_major_version() > 6 {
224            redis_cmd.arg("--enable-debug-command").arg("yes");
225        }
226        match addr {
227            redis::ConnectionAddr::Tcp(ref bind, server_port) => {
228                redis_cmd
229                    .arg("--port")
230                    .arg(server_port.to_string())
231                    .arg("--bind")
232                    .arg(bind);
233
234                RedisServer {
235                    process: spawner(&mut redis_cmd),
236                    log_file,
237                    tempdir,
238                    addr,
239                    tls_paths: None,
240                }
241            }
242            redis::ConnectionAddr::TcpTls { ref host, port, .. } => {
243                let tls_paths = tls_paths.unwrap_or_else(|| build_keys_and_certs_for_tls(&tempdir));
244
245                let auth_client = if mtls_enabled { "yes" } else { "no" };
246
247                // prepare redis with TLS
248                redis_cmd
249                    .arg("--tls-port")
250                    .arg(port.to_string())
251                    .arg("--port")
252                    .arg("0")
253                    .arg("--tls-cert-file")
254                    .arg(&tls_paths.redis_crt)
255                    .arg("--tls-key-file")
256                    .arg(&tls_paths.redis_key)
257                    .arg("--tls-ca-cert-file")
258                    .arg(&tls_paths.ca_crt)
259                    .arg("--tls-auth-clients")
260                    .arg(auth_client)
261                    .arg("--bind")
262                    .arg(host);
263
264                // Enable certificate-based authentication (Redis 8.6+)
265                // The cert_auth_field specifies which certificate field to use for username mapping
266                // (e.g., "CN" for Common Name)
267                if let Some(field) = cert_auth_field {
268                    redis_cmd.arg("--tls-auth-clients-user").arg(field);
269                }
270
271                // Insecure only disabled if `mtls` is enabled
272                let insecure = !mtls_enabled;
273
274                let addr = redis::ConnectionAddr::TcpTls {
275                    host: host.clone(),
276                    port,
277                    insecure,
278                    tls_params: None,
279                };
280
281                RedisServer {
282                    process: spawner(&mut redis_cmd),
283                    log_file,
284                    tempdir,
285                    addr,
286                    tls_paths: Some(tls_paths),
287                }
288            }
289            redis::ConnectionAddr::Unix(ref path) => {
290                redis_cmd
291                    .arg("--port")
292                    .arg("0")
293                    .arg("--unixsocket")
294                    .arg(path);
295                RedisServer {
296                    process: spawner(&mut redis_cmd),
297                    log_file,
298                    tempdir,
299                    addr,
300                    tls_paths: None,
301                }
302            }
303            _ => panic!("Unknown address format: {addr:?}"),
304        }
305    }
306
307    pub fn client_addr(&self) -> &redis::ConnectionAddr {
308        &self.addr
309    }
310
311    pub fn host_and_port(&self) -> Option<(&str, u16)> {
312        match &self.addr {
313            ConnectionAddr::Tcp(host, port) => Some((host, *port)),
314            ConnectionAddr::TcpTls { host, port, .. } => Some((host, *port)),
315            _ => None,
316        }
317    }
318
319    pub fn connection_info(&self) -> redis::ConnectionInfo {
320        self.client_addr()
321            .clone()
322            .into_connection_info()
323            .unwrap()
324            .set_redis_settings(redis_settings())
325    }
326
327    pub fn stop(&mut self) {
328        let _ = self.process.kill();
329        let _ = self.process.wait();
330        if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() {
331            fs::remove_file(path).ok();
332        }
333    }
334
335    pub fn log_file(tempdir: &TempDir) -> PathBuf {
336        tempdir.path().join("redis.log")
337    }
338}
339
/// Returns the major version of the server binary that [`RedisServer`] spawns.
///
/// Honors `REDISRS_SERVER_BIN` (default `redis-server`), the same variable used to pick the
/// binary that is actually started. This keeps version-gated flags in line with the real server.
///
/// # Panics
///
/// Panics if the binary cannot be run, or if its `-v` output is not UTF-8 or cannot be parsed.
fn get_major_version() -> u8 {
    let bin = env::var("REDISRS_SERVER_BIN").unwrap_or_else(|_| "redis-server".to_string());
    let output = process::Command::new(&bin)
        .arg("-v")
        .output()
        .unwrap_or_else(|err| panic!("Failed to run `{bin} -v`: {err}"));
    let version_text = String::from_utf8(output.stdout)
        .unwrap_or_else(|err| panic!("`{bin} -v` printed non-UTF-8 output: {err}"));
    parse_major_version(&version_text)
}

/// Extracts the major version from `-v` output such as
/// `Redis server v=7.2.4 sha=00000000:0 malloc=jemalloc-5.3.0 bits=64 build=...`.
///
/// # Panics
///
/// Panics if there is no ` v=<major>.` component, or if the major version does not fit a `u8`.
fn parse_major_version(version_text: &str) -> u8 {
    let (_, after_marker) = version_text
        .split_once(" v=")
        .unwrap_or_else(|| panic!("Unexpected version output: {version_text:?}"));
    let (major, _) = after_marker
        .split_once('.')
        .unwrap_or_else(|| panic!("Unexpected version output: {version_text:?}"));
    major
        .parse()
        .unwrap_or_else(|err| panic!("Invalid major version {major:?} in {version_text:?}: {err}"))
}