1use redis::{ConnectionAddr, IntoConnectionInfo, ProtocolVersion, RedisConnectionInfo};
2use std::path::Path;
3use std::{env, fs, path::PathBuf, process};
4
5use tempfile::TempDir;
6
7use crate::utils::{TlsFilePaths, build_keys_and_certs_for_tls, get_random_available_port};
8
9pub fn use_protocol() -> ProtocolVersion {
10 if env::var("PROTOCOL").unwrap_or_default() == "RESP3" {
11 ProtocolVersion::RESP3
12 } else {
13 ProtocolVersion::RESP2
14 }
15}
16
17pub fn redis_settings() -> RedisConnectionInfo {
18 RedisConnectionInfo::default().set_protocol(use_protocol())
19}
20
/// Transport the test server should listen on, selected via `REDISRS_SERVER_TYPE`.
///
/// Derives `Debug`/`Copy`/`Eq` in addition to `PartialEq` so values can be printed in
/// diagnostics, compared with `assert_eq!`, and passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServerType {
    /// Plain TCP, or TCP wrapped in TLS when `tls` is `true`.
    Tcp { tls: bool },
    /// A Unix domain socket.
    Unix,
}
26
/// Optional Redis modules that can be loaded into a spawned test server.
///
/// Derives the standard traits so callers can print, copy and compare module lists.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Module {
    /// RedisJSON, loaded from the path in `REDISRS_REDIS_JSON_PATH`.
    Json,
}
31
32pub struct RedisServer {
33 pub process: process::Child,
34 pub tempdir: tempfile::TempDir,
35 pub log_file: PathBuf,
36 pub addr: redis::ConnectionAddr,
37 pub tls_paths: Option<TlsFilePaths>,
38}
39
40impl ServerType {
41 fn get_intended() -> ServerType {
42 match env::var("REDISRS_SERVER_TYPE")
43 .ok()
44 .as_ref()
45 .map(|x| &x[..])
46 {
47 Some("tcp") => ServerType::Tcp { tls: false },
48 Some("tcp+tls") => ServerType::Tcp { tls: true },
49 Some("unix") => ServerType::Unix,
50 Some(val) => {
51 panic!("Unknown server type {val:?}");
52 }
53 None => ServerType::Tcp { tls: false },
54 }
55 }
56}
57
58impl Drop for RedisServer {
59 fn drop(&mut self) {
60 self.stop()
61 }
62}
63
64impl Default for RedisServer {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl RedisServer {
71 pub fn new() -> RedisServer {
72 RedisServer::with_modules(&[], false)
73 }
74
75 pub fn new_with_mtls() -> RedisServer {
76 RedisServer::with_modules(&[], true)
77 }
78
79 pub fn log_file_contents(&self) -> Option<String> {
80 std::fs::read_to_string(self.log_file.clone()).ok()
81 }
82
83 pub fn get_addr(port: u16) -> ConnectionAddr {
84 let server_type = ServerType::get_intended();
85 match server_type {
86 ServerType::Tcp { tls } => {
87 if tls {
88 redis::ConnectionAddr::TcpTls {
89 host: "127.0.0.1".to_string(),
90 port,
91 insecure: true,
92 tls_params: None,
93 }
94 } else {
95 redis::ConnectionAddr::Tcp("127.0.0.1".to_string(), port)
96 }
97 }
98 ServerType::Unix => {
99 let (a, b) = rand::random::<(u64, u64)>();
100 let path = format!("/tmp/redis-rs-test-{a}-{b}.sock");
101 redis::ConnectionAddr::Unix(PathBuf::from(&path))
102 }
103 }
104 }
105
106 pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> RedisServer {
107 let redis_port = get_random_available_port();
110 let addr = RedisServer::get_addr(redis_port);
111
112 RedisServer::new_with_addr_tls_modules_and_spawner(
113 addr,
114 None,
115 None,
116 mtls_enabled,
117 modules,
118 |cmd| {
119 cmd.spawn()
120 .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
121 },
122 )
123 }
124
125 pub fn new_with_addr_and_modules(
126 addr: redis::ConnectionAddr,
127 modules: &[Module],
128 mtls_enabled: bool,
129 ) -> RedisServer {
130 RedisServer::new_with_addr_tls_modules_and_spawner(
131 addr,
132 None,
133 None,
134 mtls_enabled,
135 modules,
136 |cmd| {
137 cmd.spawn()
138 .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
139 },
140 )
141 }
142
143 pub fn new_with_addr_tls_modules_and_spawner<
144 F: FnOnce(&mut process::Command) -> process::Child,
145 >(
146 addr: redis::ConnectionAddr,
147 config_file: Option<&Path>,
148 tls_paths: Option<TlsFilePaths>,
149 mtls_enabled: bool,
150 modules: &[Module],
151 spawner: F,
152 ) -> RedisServer {
153 let bin = env::var("REDISRS_SERVER_BIN").unwrap_or_else(|_| "redis-server".to_string());
154 let mut redis_cmd = process::Command::new(bin);
155
156 if let Some(config_path) = config_file {
157 redis_cmd.arg(config_path);
158 }
159
160 redis_cmd.arg("--save").arg("");
163
164 for module in modules {
166 match module {
167 Module::Json => {
168 let path = match env::var("REDISRS_REDIS_JSON_PATH") {
170 Ok(path) => path,
171 Err(_) => match env::var("REDIS_RS_REDIS_JSON_PATH") {
173 Ok(path) => {
174 eprintln!(
175 "Warning: Use of REDIS_RS_REDIS_JSON_PATH is deprecated. Use REDISRS_REDIS_JSON_PATH (no '_' before 'RS') instead"
176 );
177 path
178 }
179 Err(_) => {
180 panic!(
181 "Unable to find path to RedisJSON at REDISRS_REDIS_JSON_PATH, is it set?"
182 );
183 }
184 },
185 };
186
187 redis_cmd.arg("--loadmodule").arg(path);
188 }
189 };
190 }
191
192 redis_cmd
193 .stdout(process::Stdio::piped())
194 .stderr(process::Stdio::piped());
195 let tempdir = tempfile::Builder::new()
196 .prefix("redis")
197 .tempdir()
198 .expect("failed to create tempdir");
199 let log_file = Self::log_file(&tempdir);
200 redis_cmd.arg("--logfile").arg(log_file.clone());
201 if get_major_version() > 6 {
202 redis_cmd.arg("--enable-debug-command").arg("yes");
203 }
204 match addr {
205 redis::ConnectionAddr::Tcp(ref bind, server_port) => {
206 redis_cmd
207 .arg("--port")
208 .arg(server_port.to_string())
209 .arg("--bind")
210 .arg(bind);
211
212 RedisServer {
213 process: spawner(&mut redis_cmd),
214 log_file,
215 tempdir,
216 addr,
217 tls_paths: None,
218 }
219 }
220 redis::ConnectionAddr::TcpTls { ref host, port, .. } => {
221 let tls_paths = tls_paths.unwrap_or_else(|| build_keys_and_certs_for_tls(&tempdir));
222
223 let auth_client = if mtls_enabled { "yes" } else { "no" };
224
225 redis_cmd
227 .arg("--tls-port")
228 .arg(port.to_string())
229 .arg("--port")
230 .arg("0")
231 .arg("--tls-cert-file")
232 .arg(&tls_paths.redis_crt)
233 .arg("--tls-key-file")
234 .arg(&tls_paths.redis_key)
235 .arg("--tls-ca-cert-file")
236 .arg(&tls_paths.ca_crt)
237 .arg("--tls-auth-clients")
238 .arg(auth_client)
239 .arg("--bind")
240 .arg(host);
241
242 let insecure = !mtls_enabled;
244
245 let addr = redis::ConnectionAddr::TcpTls {
246 host: host.clone(),
247 port,
248 insecure,
249 tls_params: None,
250 };
251
252 RedisServer {
253 process: spawner(&mut redis_cmd),
254 log_file,
255 tempdir,
256 addr,
257 tls_paths: Some(tls_paths),
258 }
259 }
260 redis::ConnectionAddr::Unix(ref path) => {
261 redis_cmd
262 .arg("--port")
263 .arg("0")
264 .arg("--unixsocket")
265 .arg(path);
266 RedisServer {
267 process: spawner(&mut redis_cmd),
268 log_file,
269 tempdir,
270 addr,
271 tls_paths: None,
272 }
273 }
274 _ => panic!("Unknown address format: {addr:?}"),
275 }
276 }
277
278 pub fn client_addr(&self) -> &redis::ConnectionAddr {
279 &self.addr
280 }
281
282 pub fn host_and_port(&self) -> Option<(&str, u16)> {
283 match &self.addr {
284 ConnectionAddr::Tcp(host, port) => Some((host, *port)),
285 ConnectionAddr::TcpTls { host, port, .. } => Some((host, *port)),
286 _ => None,
287 }
288 }
289
290 pub fn connection_info(&self) -> redis::ConnectionInfo {
291 self.client_addr()
292 .clone()
293 .into_connection_info()
294 .unwrap()
295 .set_redis_settings(redis_settings())
296 }
297
298 pub fn stop(&mut self) {
299 let _ = self.process.kill();
300 let _ = self.process.wait();
301 if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() {
302 fs::remove_file(path).ok();
303 }
304 }
305
306 pub fn log_file(tempdir: &TempDir) -> PathBuf {
307 tempdir.path().join("redis.log")
308 }
309}
310
/// Returns the major version of the server binary that the tests will spawn.
///
/// Runs `<bin> -v`, where `<bin>` is `REDISRS_SERVER_BIN`, or `redis-server` when that is unset.
/// This is the same binary `RedisServer::new_with_addr_tls_modules_and_spawner` launches, so a
/// custom binary is probed rather than whatever `redis-server` happens to be on `PATH`. The
/// number between ` v=` and the first `.` is parsed (e.g. `... v=7.2.4 ...` gives `7`).
///
/// # Panics
/// Panics with a descriptive message if the binary cannot be run or its output does not
/// have the expected shape.
fn get_major_version() -> u8 {
    let bin = env::var("REDISRS_SERVER_BIN").unwrap_or_else(|_| "redis-server".to_string());
    let output = process::Command::new(&bin)
        .arg("-v")
        .output()
        .unwrap_or_else(|err| panic!("Failed to run `{bin} -v`: {err}"));
    let full_string = String::from_utf8(output.stdout)
        .unwrap_or_else(|err| panic!("`{bin} -v` printed non-UTF-8 output: {err}"));
    let (_, rest) = full_string
        .split_once(" v=")
        .unwrap_or_else(|| panic!("Unexpected `{bin} -v` output: {full_string:?}"));
    let (major, _) = rest
        .split_once('.')
        .unwrap_or_else(|| panic!("Unexpected `{bin} -v` output: {full_string:?}"));
    major
        .parse()
        .unwrap_or_else(|err| panic!("Cannot parse major version {major:?}: {err}"))
}