1use redis::{ConnectionAddr, IntoConnectionInfo, ProtocolVersion, RedisConnectionInfo};
2use std::path::Path;
3use std::{env, fs, path::PathBuf, process};
4
5use tempfile::TempDir;
6
7use crate::utils::{build_keys_and_certs_for_tls, get_random_available_port, TlsFilePaths};
8
9pub fn use_protocol() -> ProtocolVersion {
10 if env::var("PROTOCOL").unwrap_or_default() == "RESP3" {
11 ProtocolVersion::RESP3
12 } else {
13 ProtocolVersion::RESP2
14 }
15}
16
17pub fn redis_settings() -> RedisConnectionInfo {
18 RedisConnectionInfo::default().set_protocol(use_protocol())
19}
20
/// Kind of listener a spawned test server should expose.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServerType {
    /// Plain TCP, or TCP wrapped in TLS when `tls` is `true`.
    Tcp { tls: bool },
    /// A Unix domain socket.
    Unix,
}
26
/// Redis modules that can be loaded into a test server.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Module {
    /// RedisJSON, loaded from the path in `REDIS_RS_REDIS_JSON_PATH`.
    Json,
}
31
32pub struct RedisServer {
33 pub process: process::Child,
34 pub tempdir: tempfile::TempDir,
35 pub log_file: PathBuf,
36 pub addr: redis::ConnectionAddr,
37 pub tls_paths: Option<TlsFilePaths>,
38}
39
40impl ServerType {
41 fn get_intended() -> ServerType {
42 match env::var("REDISRS_SERVER_TYPE")
43 .ok()
44 .as_ref()
45 .map(|x| &x[..])
46 {
47 Some("tcp") => ServerType::Tcp { tls: false },
48 Some("tcp+tls") => ServerType::Tcp { tls: true },
49 Some("unix") => ServerType::Unix,
50 Some(val) => {
51 panic!("Unknown server type {val:?}");
52 }
53 None => ServerType::Tcp { tls: false },
54 }
55 }
56}
57
58impl Drop for RedisServer {
59 fn drop(&mut self) {
60 self.stop()
61 }
62}
63
64impl Default for RedisServer {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl RedisServer {
71 pub fn new() -> RedisServer {
72 RedisServer::with_modules(&[], false)
73 }
74
75 pub fn new_with_mtls() -> RedisServer {
76 RedisServer::with_modules(&[], true)
77 }
78
79 pub fn log_file_contents(&self) -> Option<String> {
80 std::fs::read_to_string(self.log_file.clone()).ok()
81 }
82
83 pub fn get_addr(port: u16) -> ConnectionAddr {
84 let server_type = ServerType::get_intended();
85 match server_type {
86 ServerType::Tcp { tls } => {
87 if tls {
88 redis::ConnectionAddr::TcpTls {
89 host: "127.0.0.1".to_string(),
90 port,
91 insecure: true,
92 tls_params: None,
93 }
94 } else {
95 redis::ConnectionAddr::Tcp("127.0.0.1".to_string(), port)
96 }
97 }
98 ServerType::Unix => {
99 let (a, b) = rand::random::<(u64, u64)>();
100 let path = format!("/tmp/redis-rs-test-{a}-{b}.sock");
101 redis::ConnectionAddr::Unix(PathBuf::from(&path))
102 }
103 }
104 }
105
106 pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> RedisServer {
107 let redis_port = get_random_available_port();
110 let addr = RedisServer::get_addr(redis_port);
111
112 RedisServer::new_with_addr_tls_modules_and_spawner(
113 addr,
114 None,
115 None,
116 mtls_enabled,
117 modules,
118 |cmd| {
119 cmd.spawn()
120 .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
121 },
122 )
123 }
124
125 pub fn new_with_addr_and_modules(
126 addr: redis::ConnectionAddr,
127 modules: &[Module],
128 mtls_enabled: bool,
129 ) -> RedisServer {
130 RedisServer::new_with_addr_tls_modules_and_spawner(
131 addr,
132 None,
133 None,
134 mtls_enabled,
135 modules,
136 |cmd| {
137 cmd.spawn()
138 .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
139 },
140 )
141 }
142
143 pub fn new_with_addr_tls_modules_and_spawner<
144 F: FnOnce(&mut process::Command) -> process::Child,
145 >(
146 addr: redis::ConnectionAddr,
147 config_file: Option<&Path>,
148 tls_paths: Option<TlsFilePaths>,
149 mtls_enabled: bool,
150 modules: &[Module],
151 spawner: F,
152 ) -> RedisServer {
153 let mut redis_cmd = process::Command::new("redis-server");
154
155 if let Some(config_path) = config_file {
156 redis_cmd.arg(config_path);
157 }
158
159 for module in modules {
161 match module {
162 Module::Json => {
163 redis_cmd
164 .arg("--loadmodule")
165 .arg(env::var("REDIS_RS_REDIS_JSON_PATH").expect(
166 "Unable to find path to RedisJSON at REDIS_RS_REDIS_JSON_PATH, is it set?",
167 ));
168 }
169 };
170 }
171
172 redis_cmd
173 .stdout(process::Stdio::piped())
174 .stderr(process::Stdio::piped());
175 let tempdir = tempfile::Builder::new()
176 .prefix("redis")
177 .tempdir()
178 .expect("failed to create tempdir");
179 let log_file = Self::log_file(&tempdir);
180 redis_cmd.arg("--logfile").arg(log_file.clone());
181 if get_major_version() > 6 {
182 redis_cmd.arg("--enable-debug-command").arg("yes");
183 }
184 match addr {
185 redis::ConnectionAddr::Tcp(ref bind, server_port) => {
186 redis_cmd
187 .arg("--port")
188 .arg(server_port.to_string())
189 .arg("--bind")
190 .arg(bind);
191
192 RedisServer {
193 process: spawner(&mut redis_cmd),
194 log_file,
195 tempdir,
196 addr,
197 tls_paths: None,
198 }
199 }
200 redis::ConnectionAddr::TcpTls { ref host, port, .. } => {
201 let tls_paths = tls_paths.unwrap_or_else(|| build_keys_and_certs_for_tls(&tempdir));
202
203 let auth_client = if mtls_enabled { "yes" } else { "no" };
204
205 redis_cmd
207 .arg("--tls-port")
208 .arg(port.to_string())
209 .arg("--port")
210 .arg("0")
211 .arg("--tls-cert-file")
212 .arg(&tls_paths.redis_crt)
213 .arg("--tls-key-file")
214 .arg(&tls_paths.redis_key)
215 .arg("--tls-ca-cert-file")
216 .arg(&tls_paths.ca_crt)
217 .arg("--tls-auth-clients")
218 .arg(auth_client)
219 .arg("--bind")
220 .arg(host);
221
222 let insecure = !mtls_enabled;
224
225 let addr = redis::ConnectionAddr::TcpTls {
226 host: host.clone(),
227 port,
228 insecure,
229 tls_params: None,
230 };
231
232 RedisServer {
233 process: spawner(&mut redis_cmd),
234 log_file,
235 tempdir,
236 addr,
237 tls_paths: Some(tls_paths),
238 }
239 }
240 redis::ConnectionAddr::Unix(ref path) => {
241 redis_cmd
242 .arg("--port")
243 .arg("0")
244 .arg("--unixsocket")
245 .arg(path);
246 RedisServer {
247 process: spawner(&mut redis_cmd),
248 log_file,
249 tempdir,
250 addr,
251 tls_paths: None,
252 }
253 }
254 _ => panic!("Unknown address format: {addr:?}"),
255 }
256 }
257
258 pub fn client_addr(&self) -> &redis::ConnectionAddr {
259 &self.addr
260 }
261
262 pub fn connection_info(&self) -> redis::ConnectionInfo {
263 self.client_addr()
264 .clone()
265 .into_connection_info()
266 .unwrap()
267 .set_redis_settings(redis_settings())
268 }
269
270 pub fn stop(&mut self) {
271 let _ = self.process.kill();
272 let _ = self.process.wait();
273 if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() {
274 fs::remove_file(path).ok();
275 }
276 }
277
278 pub fn log_file(tempdir: &TempDir) -> PathBuf {
279 tempdir.path().join("redis.log")
280 }
281}
282
/// Returns the major version of the `redis-server` binary found on `PATH`.
///
/// Runs `redis-server -v` and reads the number between ` v=` and the following `.`
/// in its stdout.
///
/// # Panics
///
/// Panics with a descriptive message in three cases: `redis-server` cannot be run,
/// its output is not UTF-8, or the output has no ` v=<major>.` version marker.
fn get_major_version() -> u8 {
    let output = process::Command::new("redis-server")
        .arg("-v")
        .output()
        .expect("failed to run `redis-server -v`; is redis-server on PATH?");
    let version_text =
        String::from_utf8(output.stdout).expect("`redis-server -v` printed non-UTF-8 output");
    parse_major_version(&version_text).unwrap_or_else(|| {
        panic!("could not parse a major version from `redis-server -v` output: {version_text:?}")
    })
}

/// Extracts the major version from `redis-server -v` output.
///
/// Returns `None` if there is no ` v=<major>.` marker or `<major>` does not fit in a `u8`.
fn parse_major_version(version_text: &str) -> Option<u8> {
    let (_, after_marker) = version_text.split_once(" v=")?;
    let (major, _) = after_marker.split_once('.')?;
    major.parse().ok()
}