1use redis::{ConnectionAddr, IntoConnectionInfo, ProtocolVersion, RedisConnectionInfo};
2use std::path::Path;
3use std::{env, fs, path::PathBuf, process};
4
5use tempfile::TempDir;
6
7use crate::utils::{TlsFilePaths, build_keys_and_certs_for_tls, get_random_available_port};
8
9pub fn use_protocol() -> ProtocolVersion {
10 if env::var("PROTOCOL").unwrap_or_default() == "RESP3" {
11 ProtocolVersion::RESP3
12 } else {
13 ProtocolVersion::RESP2
14 }
15}
16
17pub fn redis_settings() -> RedisConnectionInfo {
18 RedisConnectionInfo::default().set_protocol(use_protocol())
19}
20
/// Host that TCP and TLS test servers bind to: the IPv4 loopback address.
pub fn get_default_host() -> String {
    String::from("127.0.0.1")
}
25
/// Transport a test server should listen on, selected via `REDISRS_SERVER_TYPE`.
///
/// `Debug` lets the value appear in diagnostics. It is a tiny value type, so it is
/// also `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServerType {
    /// Plain TCP, or TCP wrapped in TLS when `tls` is true.
    Tcp { tls: bool },
    /// A Unix domain socket.
    Unix,
}
31
/// Server modules that can be loaded into a test server with `--loadmodule`.
///
/// Marked `#[non_exhaustive]` so more modules can be added later. Derives `Debug` and
/// friends so values can be printed and compared in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Module {
    /// RedisJSON. Its path is read from `REDISRS_REDIS_JSON_PATH`.
    Json,
}
37
/// A server child process started for a test, together with the resources it owns.
///
/// The `Drop` impl for this type calls `RedisServer::stop`.
pub struct RedisServer {
    /// The spawned server process.
    pub process: process::Child,
    /// Temporary directory that holds the server's log file (and, presumably, TLS
    /// material generated by `build_keys_and_certs_for_tls` — confirm). Storing it
    /// here ties its lifetime to the server's.
    pub tempdir: tempfile::TempDir,
    /// Path handed to the server via `--logfile`.
    pub log_file: PathBuf,
    /// Address clients should use to connect to this server.
    pub addr: redis::ConnectionAddr,
    /// Certificate/key paths. The spawning code sets this to `Some` only for TLS
    /// addresses.
    pub tls_paths: Option<TlsFilePaths>,
}
58
59impl ServerType {
60 fn get_intended() -> ServerType {
61 match env::var("REDISRS_SERVER_TYPE")
62 .ok()
63 .as_ref()
64 .map(|x| &x[..])
65 {
66 Some("tcp") => ServerType::Tcp { tls: false },
67 Some("tcp+tls") => ServerType::Tcp { tls: true },
68 Some("unix") => ServerType::Unix,
69 Some(val) => {
70 panic!("Unknown server type {val:?}");
71 }
72 None => ServerType::Tcp { tls: false },
73 }
74 }
75}
76
77impl Drop for RedisServer {
78 fn drop(&mut self) {
79 self.stop()
80 }
81}
82
83impl Default for RedisServer {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl RedisServer {
90 pub fn new() -> RedisServer {
91 RedisServer::with_modules(&[], false)
92 }
93
94 pub fn new_with_mtls() -> RedisServer {
95 RedisServer::with_modules(&[], true)
96 }
97
98 pub fn log_file_contents(&self) -> Option<String> {
99 std::fs::read_to_string(self.log_file.clone()).ok()
100 }
101
102 pub fn get_addr(port: u16) -> ConnectionAddr {
103 let server_type = ServerType::get_intended();
104 match server_type {
105 ServerType::Tcp { tls } => {
106 if tls {
107 redis::ConnectionAddr::TcpTls {
108 host: get_default_host(),
109 port,
110 insecure: true,
111 tls_params: None,
112 }
113 } else {
114 redis::ConnectionAddr::Tcp(get_default_host(), port)
115 }
116 }
117 ServerType::Unix => {
118 let (a, b) = rand::random::<(u64, u64)>();
119 let path = format!("/tmp/redis-rs-test-{a}-{b}.sock");
120 redis::ConnectionAddr::Unix(PathBuf::from(&path))
121 }
122 }
123 }
124
125 pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> RedisServer {
126 let redis_port = get_random_available_port();
129 let addr = RedisServer::get_addr(redis_port);
130
131 RedisServer::new_with_addr_tls_modules_and_spawner(
132 addr,
133 None,
134 None,
135 mtls_enabled,
136 None,
137 modules,
138 |cmd| {
139 cmd.spawn()
140 .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
141 },
142 )
143 }
144
145 pub fn new_with_addr_and_modules(
146 addr: redis::ConnectionAddr,
147 modules: &[Module],
148 mtls_enabled: bool,
149 ) -> RedisServer {
150 RedisServer::new_with_addr_tls_modules_and_spawner(
151 addr,
152 None,
153 None,
154 mtls_enabled,
155 None,
156 modules,
157 |cmd| {
158 cmd.spawn()
159 .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
160 },
161 )
162 }
163
164 pub fn new_with_addr_tls_modules_and_spawner<
165 F: FnOnce(&mut process::Command) -> process::Child,
166 >(
167 addr: redis::ConnectionAddr,
168 config_file: Option<&Path>,
169 tls_paths: Option<TlsFilePaths>,
170 mtls_enabled: bool,
171 cert_auth_field: Option<&str>,
172 modules: &[Module],
173 spawner: F,
174 ) -> RedisServer {
175 let bin = env::var("REDISRS_SERVER_BIN").unwrap_or_else(|_| "redis-server".to_string());
176 let mut redis_cmd = process::Command::new(bin);
177
178 if let Some(config_path) = config_file {
179 redis_cmd.arg(config_path);
180 }
181
182 redis_cmd.arg("--save").arg("");
185
186 for module in modules {
188 match module {
189 Module::Json => {
190 let path = match env::var("REDISRS_REDIS_JSON_PATH") {
192 Ok(path) => path,
193 Err(_) => match env::var("REDIS_RS_REDIS_JSON_PATH") {
195 Ok(path) => {
196 eprintln!(
197 "Warning: Use of REDIS_RS_REDIS_JSON_PATH is deprecated. Use REDISRS_REDIS_JSON_PATH (no '_' before 'RS') instead"
198 );
199 path
200 }
201 Err(_) => {
202 panic!(
203 "Unable to find path to RedisJSON at REDISRS_REDIS_JSON_PATH, is it set?"
204 );
205 }
206 },
207 };
208
209 redis_cmd.arg("--loadmodule").arg(path);
210 }
211 };
212 }
213
214 redis_cmd
215 .stdout(process::Stdio::piped())
216 .stderr(process::Stdio::piped());
217 let tempdir = tempfile::Builder::new()
218 .prefix("redis")
219 .tempdir()
220 .expect("failed to create tempdir");
221 let log_file = Self::log_file(&tempdir);
222 redis_cmd.arg("--logfile").arg(log_file.clone());
223 if get_major_version() > 6 {
224 redis_cmd.arg("--enable-debug-command").arg("yes");
225 }
226 match addr {
227 redis::ConnectionAddr::Tcp(ref bind, server_port) => {
228 redis_cmd
229 .arg("--port")
230 .arg(server_port.to_string())
231 .arg("--bind")
232 .arg(bind);
233
234 RedisServer {
235 process: spawner(&mut redis_cmd),
236 log_file,
237 tempdir,
238 addr,
239 tls_paths: None,
240 }
241 }
242 redis::ConnectionAddr::TcpTls { ref host, port, .. } => {
243 let tls_paths = tls_paths.unwrap_or_else(|| build_keys_and_certs_for_tls(&tempdir));
244
245 let auth_client = if mtls_enabled { "yes" } else { "no" };
246
247 redis_cmd
249 .arg("--tls-port")
250 .arg(port.to_string())
251 .arg("--port")
252 .arg("0")
253 .arg("--tls-cert-file")
254 .arg(&tls_paths.redis_crt)
255 .arg("--tls-key-file")
256 .arg(&tls_paths.redis_key)
257 .arg("--tls-ca-cert-file")
258 .arg(&tls_paths.ca_crt)
259 .arg("--tls-auth-clients")
260 .arg(auth_client)
261 .arg("--bind")
262 .arg(host);
263
264 if let Some(field) = cert_auth_field {
268 redis_cmd.arg("--tls-auth-clients-user").arg(field);
269 }
270
271 let insecure = !mtls_enabled;
273
274 let addr = redis::ConnectionAddr::TcpTls {
275 host: host.clone(),
276 port,
277 insecure,
278 tls_params: None,
279 };
280
281 RedisServer {
282 process: spawner(&mut redis_cmd),
283 log_file,
284 tempdir,
285 addr,
286 tls_paths: Some(tls_paths),
287 }
288 }
289 redis::ConnectionAddr::Unix(ref path) => {
290 redis_cmd
291 .arg("--port")
292 .arg("0")
293 .arg("--unixsocket")
294 .arg(path);
295 RedisServer {
296 process: spawner(&mut redis_cmd),
297 log_file,
298 tempdir,
299 addr,
300 tls_paths: None,
301 }
302 }
303 _ => panic!("Unknown address format: {addr:?}"),
304 }
305 }
306
307 pub fn client_addr(&self) -> &redis::ConnectionAddr {
308 &self.addr
309 }
310
311 pub fn host_and_port(&self) -> Option<(&str, u16)> {
312 match &self.addr {
313 ConnectionAddr::Tcp(host, port) => Some((host, *port)),
314 ConnectionAddr::TcpTls { host, port, .. } => Some((host, *port)),
315 _ => None,
316 }
317 }
318
319 pub fn connection_info(&self) -> redis::ConnectionInfo {
320 self.client_addr()
321 .clone()
322 .into_connection_info()
323 .unwrap()
324 .set_redis_settings(redis_settings())
325 }
326
327 pub fn stop(&mut self) {
328 let _ = self.process.kill();
329 let _ = self.process.wait();
330 if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() {
331 fs::remove_file(path).ok();
332 }
333 }
334
335 pub fn log_file(tempdir: &TempDir) -> PathBuf {
336 tempdir.path().join("redis.log")
337 }
338}
339
/// Returns the major version of the server binary the tests spawn.
///
/// Runs `<bin> -v`, where `<bin>` is `REDISRS_SERVER_BIN` or `redis-server` when
/// unset. This is the same rule `RedisServer` uses to pick the binary it launches, so
/// the version check and the spawned server always agree. The major version is the
/// number between ` v=` and the next `.` (output containing ` v=7.2.4` yields `7`).
///
/// # Panics
/// If the binary cannot be run, its output is not UTF-8, or no version can be parsed.
fn get_major_version() -> u8 {
    let bin = env::var("REDISRS_SERVER_BIN").unwrap_or_else(|_| "redis-server".to_string());
    let output = process::Command::new(&bin)
        .arg("-v")
        .output()
        .unwrap_or_else(|err| panic!("Failed to run `{bin} -v`: {err}"));
    let version_line = String::from_utf8(output.stdout)
        .unwrap_or_else(|err| panic!("`{bin} -v` printed non-UTF-8 output: {err}"));
    parse_major_version(&version_line).unwrap_or_else(|| {
        panic!("Could not parse a major version from `{bin} -v` output: {version_line:?}")
    })
}

/// Extracts the major version from server `-v` output: the number between ` v=` and
/// the next `.`. Returns `None` if the output does not have that shape.
fn parse_major_version(version_line: &str) -> Option<u8> {
    let (_, rest) = version_line.split_once(" v=")?;
    let (major, _) = rest.split_once('.')?;
    major.parse().ok()
}