// redis_test/server.rs

use redis::{ConnectionAddr, ProtocolVersion, RedisConnectionInfo};
use std::path::Path;
use std::{env, fs, path::PathBuf, process};

use tempfile::TempDir;

use crate::utils::{build_keys_and_certs_for_tls, get_random_available_port, TlsFilePaths};

9pub fn use_protocol() -> ProtocolVersion {
10    if env::var("PROTOCOL").unwrap_or_default() == "RESP3" {
11        ProtocolVersion::RESP3
12    } else {
13        ProtocolVersion::RESP2
14    }
15}
16
/// Transport a spawned test server listens on; selected at runtime from the
/// `REDISRS_SERVER_TYPE` environment variable.
#[derive(PartialEq)]
enum ServerType {
    /// A TCP socket, wrapped in TLS when `tls` is `true`.
    Tcp { tls: bool },
    /// A Unix domain socket.
    Unix,
}

/// Optional Redis modules that can be loaded into a spawned test server.
pub enum Module {
    /// RedisJSON, loaded from the path given in `REDIS_RS_REDIS_JSON_PATH`.
    Json,
}

27pub struct RedisServer {
28    pub process: process::Child,
29    pub tempdir: tempfile::TempDir,
30    pub log_file: PathBuf,
31    pub addr: redis::ConnectionAddr,
32    pub tls_paths: Option<TlsFilePaths>,
33}
34
35impl ServerType {
36    fn get_intended() -> ServerType {
37        match env::var("REDISRS_SERVER_TYPE")
38            .ok()
39            .as_ref()
40            .map(|x| &x[..])
41        {
42            Some("tcp") => ServerType::Tcp { tls: false },
43            Some("tcp+tls") => ServerType::Tcp { tls: true },
44            Some("unix") => ServerType::Unix,
45            Some(val) => {
46                panic!("Unknown server type {val:?}");
47            }
48            None => ServerType::Tcp { tls: false },
49        }
50    }
51}
52
53impl Drop for RedisServer {
54    fn drop(&mut self) {
55        self.stop()
56    }
57}
58
59impl Default for RedisServer {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl RedisServer {
66    pub fn new() -> RedisServer {
67        RedisServer::with_modules(&[], false)
68    }
69
70    pub fn new_with_mtls() -> RedisServer {
71        RedisServer::with_modules(&[], true)
72    }
73
74    pub fn log_file_contents(&self) -> Option<String> {
75        std::fs::read_to_string(self.log_file.clone()).ok()
76    }
77
78    pub fn get_addr(port: u16) -> ConnectionAddr {
79        let server_type = ServerType::get_intended();
80        match server_type {
81            ServerType::Tcp { tls } => {
82                if tls {
83                    redis::ConnectionAddr::TcpTls {
84                        host: "127.0.0.1".to_string(),
85                        port,
86                        insecure: true,
87                        tls_params: None,
88                    }
89                } else {
90                    redis::ConnectionAddr::Tcp("127.0.0.1".to_string(), port)
91                }
92            }
93            ServerType::Unix => {
94                let (a, b) = rand::random::<(u64, u64)>();
95                let path = format!("/tmp/redis-rs-test-{a}-{b}.sock");
96                redis::ConnectionAddr::Unix(PathBuf::from(&path))
97            }
98        }
99    }
100
101    pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> RedisServer {
102        // this is technically a race but we can't do better with
103        // the tools that redis gives us :(
104        let redis_port = get_random_available_port();
105        let addr = RedisServer::get_addr(redis_port);
106
107        RedisServer::new_with_addr_tls_modules_and_spawner(
108            addr,
109            None,
110            None,
111            mtls_enabled,
112            modules,
113            |cmd| {
114                cmd.spawn()
115                    .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
116            },
117        )
118    }
119
120    pub fn new_with_addr_and_modules(
121        addr: redis::ConnectionAddr,
122        modules: &[Module],
123        mtls_enabled: bool,
124    ) -> RedisServer {
125        RedisServer::new_with_addr_tls_modules_and_spawner(
126            addr,
127            None,
128            None,
129            mtls_enabled,
130            modules,
131            |cmd| {
132                cmd.spawn()
133                    .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
134            },
135        )
136    }
137
138    pub fn new_with_addr_tls_modules_and_spawner<
139        F: FnOnce(&mut process::Command) -> process::Child,
140    >(
141        addr: redis::ConnectionAddr,
142        config_file: Option<&Path>,
143        tls_paths: Option<TlsFilePaths>,
144        mtls_enabled: bool,
145        modules: &[Module],
146        spawner: F,
147    ) -> RedisServer {
148        let mut redis_cmd = process::Command::new("redis-server");
149
150        if let Some(config_path) = config_file {
151            redis_cmd.arg(config_path);
152        }
153
154        // Load Redis Modules
155        for module in modules {
156            match module {
157                Module::Json => {
158                    redis_cmd
159                        .arg("--loadmodule")
160                        .arg(env::var("REDIS_RS_REDIS_JSON_PATH").expect(
161                        "Unable to find path to RedisJSON at REDIS_RS_REDIS_JSON_PATH, is it set?",
162                    ));
163                }
164            };
165        }
166
167        redis_cmd
168            .stdout(process::Stdio::piped())
169            .stderr(process::Stdio::piped());
170        let tempdir = tempfile::Builder::new()
171            .prefix("redis")
172            .tempdir()
173            .expect("failed to create tempdir");
174        let log_file = Self::log_file(&tempdir);
175        redis_cmd.arg("--logfile").arg(log_file.clone());
176        if get_major_version() > 6 {
177            redis_cmd.arg("--enable-debug-command").arg("yes");
178        }
179        match addr {
180            redis::ConnectionAddr::Tcp(ref bind, server_port) => {
181                redis_cmd
182                    .arg("--port")
183                    .arg(server_port.to_string())
184                    .arg("--bind")
185                    .arg(bind);
186
187                RedisServer {
188                    process: spawner(&mut redis_cmd),
189                    log_file,
190                    tempdir,
191                    addr,
192                    tls_paths: None,
193                }
194            }
195            redis::ConnectionAddr::TcpTls { ref host, port, .. } => {
196                let tls_paths = tls_paths.unwrap_or_else(|| build_keys_and_certs_for_tls(&tempdir));
197
198                let auth_client = if mtls_enabled { "yes" } else { "no" };
199
200                // prepare redis with TLS
201                redis_cmd
202                    .arg("--tls-port")
203                    .arg(port.to_string())
204                    .arg("--port")
205                    .arg("0")
206                    .arg("--tls-cert-file")
207                    .arg(&tls_paths.redis_crt)
208                    .arg("--tls-key-file")
209                    .arg(&tls_paths.redis_key)
210                    .arg("--tls-ca-cert-file")
211                    .arg(&tls_paths.ca_crt)
212                    .arg("--tls-auth-clients")
213                    .arg(auth_client)
214                    .arg("--bind")
215                    .arg(host);
216
217                // Insecure only disabled if `mtls` is enabled
218                let insecure = !mtls_enabled;
219
220                let addr = redis::ConnectionAddr::TcpTls {
221                    host: host.clone(),
222                    port,
223                    insecure,
224                    tls_params: None,
225                };
226
227                RedisServer {
228                    process: spawner(&mut redis_cmd),
229                    log_file,
230                    tempdir,
231                    addr,
232                    tls_paths: Some(tls_paths),
233                }
234            }
235            redis::ConnectionAddr::Unix(ref path) => {
236                redis_cmd
237                    .arg("--port")
238                    .arg("0")
239                    .arg("--unixsocket")
240                    .arg(path);
241                RedisServer {
242                    process: spawner(&mut redis_cmd),
243                    log_file,
244                    tempdir,
245                    addr,
246                    tls_paths: None,
247                }
248            }
249        }
250    }
251
252    pub fn client_addr(&self) -> &redis::ConnectionAddr {
253        &self.addr
254    }
255
256    pub fn connection_info(&self) -> redis::ConnectionInfo {
257        redis::ConnectionInfo {
258            addr: self.client_addr().clone(),
259            redis: RedisConnectionInfo {
260                protocol: use_protocol(),
261                ..Default::default()
262            },
263        }
264    }
265
266    pub fn stop(&mut self) {
267        let _ = self.process.kill();
268        let _ = self.process.wait();
269        if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() {
270            fs::remove_file(path).ok();
271        }
272    }
273
274    pub fn log_file(tempdir: &TempDir) -> PathBuf {
275        tempdir.path().join("redis.log")
276    }
277}
278
/// Returns the major version of the `redis-server` binary found on `PATH`,
/// determined by running `redis-server -v`.
///
/// # Panics
///
/// Panics if `redis-server` cannot be executed or its version output cannot
/// be parsed. The panic message includes the raw output to ease debugging.
fn get_major_version() -> u8 {
    let output = process::Command::new("redis-server")
        .arg("-v")
        .output()
        .expect("failed to run `redis-server -v`");
    let version_text = String::from_utf8_lossy(&output.stdout);
    parse_major_version(&version_text).unwrap_or_else(|| {
        panic!("cannot parse major version from `redis-server -v` output: {version_text:?}")
    })
}

/// Extracts the major version from `redis-server -v` output such as
/// `Redis server v=7.2.4 sha=...`: the text between ` v=` and the next `.`.
///
/// Returns `None` if either delimiter is missing or the text is not a `u8`.
fn parse_major_version(version_output: &str) -> Option<u8> {
    let (_, after_marker) = version_output.split_once(" v=")?;
    let (major, _) = after_marker.split_once('.')?;
    major.parse().ok()
}