Skip to main content

redis_test/
server.rs

1use redis::{ConnectionAddr, IntoConnectionInfo, ProtocolVersion, RedisConnectionInfo};
2use std::path::Path;
3use std::{env, fs, path::PathBuf, process};
4
5use tempfile::TempDir;
6
7use crate::utils::{TlsFilePaths, build_keys_and_certs_for_tls, get_random_available_port};
8
9pub fn use_protocol() -> ProtocolVersion {
10    if env::var("PROTOCOL").unwrap_or_default() == "RESP3" {
11        ProtocolVersion::RESP3
12    } else {
13        ProtocolVersion::RESP2
14    }
15}
16
17pub fn redis_settings() -> RedisConnectionInfo {
18    RedisConnectionInfo::default().set_protocol(use_protocol())
19}
20
/// Kind of listener a spawned test server should expose.
///
/// Chosen from the `REDISRS_SERVER_TYPE` environment variable by
/// `ServerType::get_intended`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServerType {
    /// TCP listener on the loopback interface; `tls` selects a TLS listener.
    Tcp { tls: bool },
    /// Unix domain socket listener.
    Unix,
}
26
/// Redis modules that can be loaded into a spawned test server.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Module {
    /// RedisJSON. Its shared-library path is read from `REDISRS_REDIS_JSON_PATH`
    /// (or the deprecated `REDIS_RS_REDIS_JSON_PATH`).
    Json,
}
31
32pub struct RedisServer {
33    pub process: process::Child,
34    pub tempdir: tempfile::TempDir,
35    pub log_file: PathBuf,
36    pub addr: redis::ConnectionAddr,
37    pub tls_paths: Option<TlsFilePaths>,
38}
39
40impl ServerType {
41    fn get_intended() -> ServerType {
42        match env::var("REDISRS_SERVER_TYPE")
43            .ok()
44            .as_ref()
45            .map(|x| &x[..])
46        {
47            Some("tcp") => ServerType::Tcp { tls: false },
48            Some("tcp+tls") => ServerType::Tcp { tls: true },
49            Some("unix") => ServerType::Unix,
50            Some(val) => {
51                panic!("Unknown server type {val:?}");
52            }
53            None => ServerType::Tcp { tls: false },
54        }
55    }
56}
57
58impl Drop for RedisServer {
59    fn drop(&mut self) {
60        self.stop()
61    }
62}
63
64impl Default for RedisServer {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl RedisServer {
71    pub fn new() -> RedisServer {
72        RedisServer::with_modules(&[], false)
73    }
74
75    pub fn new_with_mtls() -> RedisServer {
76        RedisServer::with_modules(&[], true)
77    }
78
79    pub fn log_file_contents(&self) -> Option<String> {
80        std::fs::read_to_string(self.log_file.clone()).ok()
81    }
82
83    pub fn get_addr(port: u16) -> ConnectionAddr {
84        let server_type = ServerType::get_intended();
85        match server_type {
86            ServerType::Tcp { tls } => {
87                if tls {
88                    redis::ConnectionAddr::TcpTls {
89                        host: "127.0.0.1".to_string(),
90                        port,
91                        insecure: true,
92                        tls_params: None,
93                    }
94                } else {
95                    redis::ConnectionAddr::Tcp("127.0.0.1".to_string(), port)
96                }
97            }
98            ServerType::Unix => {
99                let (a, b) = rand::random::<(u64, u64)>();
100                let path = format!("/tmp/redis-rs-test-{a}-{b}.sock");
101                redis::ConnectionAddr::Unix(PathBuf::from(&path))
102            }
103        }
104    }
105
106    pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> RedisServer {
107        // this is technically a race but we can't do better with
108        // the tools that redis gives us :(
109        let redis_port = get_random_available_port();
110        let addr = RedisServer::get_addr(redis_port);
111
112        RedisServer::new_with_addr_tls_modules_and_spawner(
113            addr,
114            None,
115            None,
116            mtls_enabled,
117            modules,
118            |cmd| {
119                cmd.spawn()
120                    .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
121            },
122        )
123    }
124
125    pub fn new_with_addr_and_modules(
126        addr: redis::ConnectionAddr,
127        modules: &[Module],
128        mtls_enabled: bool,
129    ) -> RedisServer {
130        RedisServer::new_with_addr_tls_modules_and_spawner(
131            addr,
132            None,
133            None,
134            mtls_enabled,
135            modules,
136            |cmd| {
137                cmd.spawn()
138                    .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
139            },
140        )
141    }
142
143    pub fn new_with_addr_tls_modules_and_spawner<
144        F: FnOnce(&mut process::Command) -> process::Child,
145    >(
146        addr: redis::ConnectionAddr,
147        config_file: Option<&Path>,
148        tls_paths: Option<TlsFilePaths>,
149        mtls_enabled: bool,
150        modules: &[Module],
151        spawner: F,
152    ) -> RedisServer {
153        let bin = env::var("REDISRS_SERVER_BIN").unwrap_or_else(|_| "redis-server".to_string());
154        let mut redis_cmd = process::Command::new(bin);
155
156        if let Some(config_path) = config_file {
157            redis_cmd.arg(config_path);
158        }
159
160        // Disable snapshotting
161        // This stops littering `dump.rdb` files during testing/development.
162        redis_cmd.arg("--save").arg("");
163
164        // Load Redis Modules
165        for module in modules {
166            match module {
167                Module::Json => {
168                    // Try to pick up json module path from REDISRS_REDIS_JSON_PATH environment variable
169                    let path = match env::var("REDISRS_REDIS_JSON_PATH") {
170                        Ok(path) => path,
171                        // Falling back to legacy REDIS_RS_REDIS_JSON_PATH environment variable
172                        Err(_) => match env::var("REDIS_RS_REDIS_JSON_PATH") {
173                            Ok(path) => {
174                                eprintln!(
175                                    "Warning: Use of REDIS_RS_REDIS_JSON_PATH is deprecated. Use REDISRS_REDIS_JSON_PATH (no '_' before 'RS') instead"
176                                );
177                                path
178                            }
179                            Err(_) => {
180                                panic!(
181                                    "Unable to find path to RedisJSON at REDISRS_REDIS_JSON_PATH, is it set?"
182                                );
183                            }
184                        },
185                    };
186
187                    redis_cmd.arg("--loadmodule").arg(path);
188                }
189            };
190        }
191
192        redis_cmd
193            .stdout(process::Stdio::piped())
194            .stderr(process::Stdio::piped());
195        let tempdir = tempfile::Builder::new()
196            .prefix("redis")
197            .tempdir()
198            .expect("failed to create tempdir");
199        let log_file = Self::log_file(&tempdir);
200        redis_cmd.arg("--logfile").arg(log_file.clone());
201        if get_major_version() > 6 {
202            redis_cmd.arg("--enable-debug-command").arg("yes");
203        }
204        match addr {
205            redis::ConnectionAddr::Tcp(ref bind, server_port) => {
206                redis_cmd
207                    .arg("--port")
208                    .arg(server_port.to_string())
209                    .arg("--bind")
210                    .arg(bind);
211
212                RedisServer {
213                    process: spawner(&mut redis_cmd),
214                    log_file,
215                    tempdir,
216                    addr,
217                    tls_paths: None,
218                }
219            }
220            redis::ConnectionAddr::TcpTls { ref host, port, .. } => {
221                let tls_paths = tls_paths.unwrap_or_else(|| build_keys_and_certs_for_tls(&tempdir));
222
223                let auth_client = if mtls_enabled { "yes" } else { "no" };
224
225                // prepare redis with TLS
226                redis_cmd
227                    .arg("--tls-port")
228                    .arg(port.to_string())
229                    .arg("--port")
230                    .arg("0")
231                    .arg("--tls-cert-file")
232                    .arg(&tls_paths.redis_crt)
233                    .arg("--tls-key-file")
234                    .arg(&tls_paths.redis_key)
235                    .arg("--tls-ca-cert-file")
236                    .arg(&tls_paths.ca_crt)
237                    .arg("--tls-auth-clients")
238                    .arg(auth_client)
239                    .arg("--bind")
240                    .arg(host);
241
242                // Insecure only disabled if `mtls` is enabled
243                let insecure = !mtls_enabled;
244
245                let addr = redis::ConnectionAddr::TcpTls {
246                    host: host.clone(),
247                    port,
248                    insecure,
249                    tls_params: None,
250                };
251
252                RedisServer {
253                    process: spawner(&mut redis_cmd),
254                    log_file,
255                    tempdir,
256                    addr,
257                    tls_paths: Some(tls_paths),
258                }
259            }
260            redis::ConnectionAddr::Unix(ref path) => {
261                redis_cmd
262                    .arg("--port")
263                    .arg("0")
264                    .arg("--unixsocket")
265                    .arg(path);
266                RedisServer {
267                    process: spawner(&mut redis_cmd),
268                    log_file,
269                    tempdir,
270                    addr,
271                    tls_paths: None,
272                }
273            }
274            _ => panic!("Unknown address format: {addr:?}"),
275        }
276    }
277
278    pub fn client_addr(&self) -> &redis::ConnectionAddr {
279        &self.addr
280    }
281
282    pub fn host_and_port(&self) -> Option<(&str, u16)> {
283        match &self.addr {
284            ConnectionAddr::Tcp(host, port) => Some((host, *port)),
285            ConnectionAddr::TcpTls { host, port, .. } => Some((host, *port)),
286            _ => None,
287        }
288    }
289
290    pub fn connection_info(&self) -> redis::ConnectionInfo {
291        self.client_addr()
292            .clone()
293            .into_connection_info()
294            .unwrap()
295            .set_redis_settings(redis_settings())
296    }
297
298    pub fn stop(&mut self) {
299        let _ = self.process.kill();
300        let _ = self.process.wait();
301        if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() {
302            fs::remove_file(path).ok();
303        }
304    }
305
306    pub fn log_file(tempdir: &TempDir) -> PathBuf {
307        tempdir.path().join("redis.log")
308    }
309}
310
/// Returns the major version of the Redis server binary used by the tests.
///
/// Runs `<bin> -v` and parses the number after ` v=` up to the first `.`. Here
/// `<bin>` is `$REDISRS_SERVER_BIN` or `redis-server`, the same binary that
/// `RedisServer` spawns. For example, `Redis server v=7.2.4 sha=...` yields `7`.
///
/// # Panics
///
/// Panics if the binary cannot be run, or if its output has no parseable
/// ` v=<major>.` marker.
fn get_major_version() -> u8 {
    // Probe the same binary the tests will spawn. A hard-coded `redis-server`
    // could be absent or a different version.
    let bin = env::var("REDISRS_SERVER_BIN").unwrap_or_else(|_| "redis-server".to_string());
    let output = process::Command::new(&bin)
        .arg("-v")
        .output()
        .unwrap_or_else(|err| panic!("Failed to run `{bin} -v`: {err}"));
    let text = String::from_utf8_lossy(&output.stdout);
    text.split_once(" v=")
        .and_then(|(_, rest)| rest.split_once('.'))
        .and_then(|(major, _)| major.parse().ok())
        .unwrap_or_else(|| {
            panic!("Could not parse a major version from `{bin} -v` output: {text:?}")
        })
}