redis_test/
server.rs

1use redis::{ConnectionAddr, IntoConnectionInfo, ProtocolVersion, RedisConnectionInfo};
2use std::path::Path;
3use std::{env, fs, path::PathBuf, process};
4
5use tempfile::TempDir;
6
7use crate::utils::{build_keys_and_certs_for_tls, get_random_available_port, TlsFilePaths};
8
9pub fn use_protocol() -> ProtocolVersion {
10    if env::var("PROTOCOL").unwrap_or_default() == "RESP3" {
11        ProtocolVersion::RESP3
12    } else {
13        ProtocolVersion::RESP2
14    }
15}
16
17pub fn redis_settings() -> RedisConnectionInfo {
18    RedisConnectionInfo::default().set_protocol(use_protocol())
19}
20
/// Kind of listener the test `redis-server` should expose, selected via the
/// `REDISRS_SERVER_TYPE` environment variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServerType {
    /// A TCP listener on the loopback interface; `tls` makes it a TLS-only listener.
    Tcp { tls: bool },
    /// A Unix domain socket listener.
    Unix,
}
26
/// Redis modules that can be loaded into a test server via `--loadmodule`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Module {
    /// RedisJSON; the shared-library path is read from `REDIS_RS_REDIS_JSON_PATH`.
    Json,
}
31
/// Handle to a `redis-server` child process started for a test.
///
/// Dropping the handle calls `stop`, which kills the process and waits for it.
pub struct RedisServer {
    /// The spawned `redis-server` process.
    pub process: process::Child,
    /// Per-server temporary directory; the log file lives here, and it is handed to
    /// `build_keys_and_certs_for_tls` when TLS material has to be generated.
    pub tempdir: tempfile::TempDir,
    /// Path passed to `--logfile`.
    pub log_file: PathBuf,
    /// Address clients should use to connect to this server.
    pub addr: redis::ConnectionAddr,
    /// TLS certificate/key paths; only `Some` for TLS (`TcpTls`) servers.
    pub tls_paths: Option<TlsFilePaths>,
}
39
40impl ServerType {
41    fn get_intended() -> ServerType {
42        match env::var("REDISRS_SERVER_TYPE")
43            .ok()
44            .as_ref()
45            .map(|x| &x[..])
46        {
47            Some("tcp") => ServerType::Tcp { tls: false },
48            Some("tcp+tls") => ServerType::Tcp { tls: true },
49            Some("unix") => ServerType::Unix,
50            Some(val) => {
51                panic!("Unknown server type {val:?}");
52            }
53            None => ServerType::Tcp { tls: false },
54        }
55    }
56}
57
58impl Drop for RedisServer {
59    fn drop(&mut self) {
60        self.stop()
61    }
62}
63
64impl Default for RedisServer {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl RedisServer {
71    pub fn new() -> RedisServer {
72        RedisServer::with_modules(&[], false)
73    }
74
75    pub fn new_with_mtls() -> RedisServer {
76        RedisServer::with_modules(&[], true)
77    }
78
79    pub fn log_file_contents(&self) -> Option<String> {
80        std::fs::read_to_string(self.log_file.clone()).ok()
81    }
82
83    pub fn get_addr(port: u16) -> ConnectionAddr {
84        let server_type = ServerType::get_intended();
85        match server_type {
86            ServerType::Tcp { tls } => {
87                if tls {
88                    redis::ConnectionAddr::TcpTls {
89                        host: "127.0.0.1".to_string(),
90                        port,
91                        insecure: true,
92                        tls_params: None,
93                    }
94                } else {
95                    redis::ConnectionAddr::Tcp("127.0.0.1".to_string(), port)
96                }
97            }
98            ServerType::Unix => {
99                let (a, b) = rand::random::<(u64, u64)>();
100                let path = format!("/tmp/redis-rs-test-{a}-{b}.sock");
101                redis::ConnectionAddr::Unix(PathBuf::from(&path))
102            }
103        }
104    }
105
106    pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> RedisServer {
107        // this is technically a race but we can't do better with
108        // the tools that redis gives us :(
109        let redis_port = get_random_available_port();
110        let addr = RedisServer::get_addr(redis_port);
111
112        RedisServer::new_with_addr_tls_modules_and_spawner(
113            addr,
114            None,
115            None,
116            mtls_enabled,
117            modules,
118            |cmd| {
119                cmd.spawn()
120                    .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
121            },
122        )
123    }
124
125    pub fn new_with_addr_and_modules(
126        addr: redis::ConnectionAddr,
127        modules: &[Module],
128        mtls_enabled: bool,
129    ) -> RedisServer {
130        RedisServer::new_with_addr_tls_modules_and_spawner(
131            addr,
132            None,
133            None,
134            mtls_enabled,
135            modules,
136            |cmd| {
137                cmd.spawn()
138                    .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
139            },
140        )
141    }
142
143    pub fn new_with_addr_tls_modules_and_spawner<
144        F: FnOnce(&mut process::Command) -> process::Child,
145    >(
146        addr: redis::ConnectionAddr,
147        config_file: Option<&Path>,
148        tls_paths: Option<TlsFilePaths>,
149        mtls_enabled: bool,
150        modules: &[Module],
151        spawner: F,
152    ) -> RedisServer {
153        let mut redis_cmd = process::Command::new("redis-server");
154
155        if let Some(config_path) = config_file {
156            redis_cmd.arg(config_path);
157        }
158
159        // Load Redis Modules
160        for module in modules {
161            match module {
162                Module::Json => {
163                    redis_cmd
164                        .arg("--loadmodule")
165                        .arg(env::var("REDIS_RS_REDIS_JSON_PATH").expect(
166                        "Unable to find path to RedisJSON at REDIS_RS_REDIS_JSON_PATH, is it set?",
167                    ));
168                }
169            };
170        }
171
172        redis_cmd
173            .stdout(process::Stdio::piped())
174            .stderr(process::Stdio::piped());
175        let tempdir = tempfile::Builder::new()
176            .prefix("redis")
177            .tempdir()
178            .expect("failed to create tempdir");
179        let log_file = Self::log_file(&tempdir);
180        redis_cmd.arg("--logfile").arg(log_file.clone());
181        if get_major_version() > 6 {
182            redis_cmd.arg("--enable-debug-command").arg("yes");
183        }
184        match addr {
185            redis::ConnectionAddr::Tcp(ref bind, server_port) => {
186                redis_cmd
187                    .arg("--port")
188                    .arg(server_port.to_string())
189                    .arg("--bind")
190                    .arg(bind);
191
192                RedisServer {
193                    process: spawner(&mut redis_cmd),
194                    log_file,
195                    tempdir,
196                    addr,
197                    tls_paths: None,
198                }
199            }
200            redis::ConnectionAddr::TcpTls { ref host, port, .. } => {
201                let tls_paths = tls_paths.unwrap_or_else(|| build_keys_and_certs_for_tls(&tempdir));
202
203                let auth_client = if mtls_enabled { "yes" } else { "no" };
204
205                // prepare redis with TLS
206                redis_cmd
207                    .arg("--tls-port")
208                    .arg(port.to_string())
209                    .arg("--port")
210                    .arg("0")
211                    .arg("--tls-cert-file")
212                    .arg(&tls_paths.redis_crt)
213                    .arg("--tls-key-file")
214                    .arg(&tls_paths.redis_key)
215                    .arg("--tls-ca-cert-file")
216                    .arg(&tls_paths.ca_crt)
217                    .arg("--tls-auth-clients")
218                    .arg(auth_client)
219                    .arg("--bind")
220                    .arg(host);
221
222                // Insecure only disabled if `mtls` is enabled
223                let insecure = !mtls_enabled;
224
225                let addr = redis::ConnectionAddr::TcpTls {
226                    host: host.clone(),
227                    port,
228                    insecure,
229                    tls_params: None,
230                };
231
232                RedisServer {
233                    process: spawner(&mut redis_cmd),
234                    log_file,
235                    tempdir,
236                    addr,
237                    tls_paths: Some(tls_paths),
238                }
239            }
240            redis::ConnectionAddr::Unix(ref path) => {
241                redis_cmd
242                    .arg("--port")
243                    .arg("0")
244                    .arg("--unixsocket")
245                    .arg(path);
246                RedisServer {
247                    process: spawner(&mut redis_cmd),
248                    log_file,
249                    tempdir,
250                    addr,
251                    tls_paths: None,
252                }
253            }
254            _ => panic!("Unknown address format: {addr:?}"),
255        }
256    }
257
258    pub fn client_addr(&self) -> &redis::ConnectionAddr {
259        &self.addr
260    }
261
262    pub fn connection_info(&self) -> redis::ConnectionInfo {
263        self.client_addr()
264            .clone()
265            .into_connection_info()
266            .unwrap()
267            .set_redis_settings(redis_settings())
268    }
269
270    pub fn stop(&mut self) {
271        let _ = self.process.kill();
272        let _ = self.process.wait();
273        if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() {
274            fs::remove_file(path).ok();
275        }
276    }
277
278    pub fn log_file(tempdir: &TempDir) -> PathBuf {
279        tempdir.path().join("redis.log")
280    }
281}
282
/// Returns the major version of the `redis-server` binary found on `PATH`.
///
/// For example, a `redis-server -v` banner containing ` v=7.2.4 ` yields `7`.
///
/// # Panics
///
/// Panics in any of these cases:
/// * `redis-server -v` cannot be run;
/// * its output is not UTF-8;
/// * the output lacks a ` v=<major>.` version field.
fn get_major_version() -> u8 {
    let output = process::Command::new("redis-server")
        .arg("-v")
        .output()
        .expect("failed to run `redis-server -v`");
    let banner =
        String::from_utf8(output.stdout).expect("`redis-server -v` printed non-UTF-8 output");
    parse_major_version(&banner)
        .unwrap_or_else(|| panic!("could not parse a major version from {banner:?}"))
}

/// Extracts the major version from a `redis-server -v` banner.
///
/// The major version is the text between ` v=` and the next `.`. Returns `None` if that
/// field is missing or is not a number that fits in a `u8`.
fn parse_major_version(banner: &str) -> Option<u8> {
    let (_, version) = banner.split_once(" v=")?;
    let (major, _) = version.split_once('.')?;
    major.parse().ok()
}