1use redis::{ConnectionAddr, ProtocolVersion, RedisConnectionInfo};
2use std::path::Path;
3use std::{env, fs, path::PathBuf, process};
4
5use tempfile::TempDir;
6
7use crate::utils::{build_keys_and_certs_for_tls, get_random_available_port, TlsFilePaths};
8
9pub fn use_protocol() -> ProtocolVersion {
10 if env::var("PROTOCOL").unwrap_or_default() == "RESP3" {
11 ProtocolVersion::RESP3
12 } else {
13 ProtocolVersion::RESP2
14 }
15}
16
/// Kind of listener a spawned test `redis-server` should expose.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServerType {
    /// TCP listener; `tls: true` means the port speaks TLS.
    Tcp { tls: bool },
    /// Unix domain socket listener.
    Unix,
}
22
/// Redis modules that can be loaded into a spawned test server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Module {
    /// RedisJSON; its shared-library path is read from the
    /// `REDIS_RS_REDIS_JSON_PATH` environment variable when the server is spawned.
    Json,
}
26
/// A `redis-server` child process spawned for tests, together with the
/// resources it uses.
///
/// NOTE: keep the field order stable — fields are dropped in declaration
/// order once this type's `Drop` impl has run.
pub struct RedisServer {
    /// Handle to the spawned `redis-server` process.
    pub process: process::Child,
    /// Scratch directory owned by this server; the log file lives inside it.
    pub tempdir: tempfile::TempDir,
    /// Path of the server's log file (inside `tempdir`).
    pub log_file: PathBuf,
    /// Address clients should use to reach this server.
    pub addr: redis::ConnectionAddr,
    /// Certificate/key file paths; `Some` only for TCP+TLS servers.
    pub tls_paths: Option<TlsFilePaths>,
}
34
35impl ServerType {
36 fn get_intended() -> ServerType {
37 match env::var("REDISRS_SERVER_TYPE")
38 .ok()
39 .as_ref()
40 .map(|x| &x[..])
41 {
42 Some("tcp") => ServerType::Tcp { tls: false },
43 Some("tcp+tls") => ServerType::Tcp { tls: true },
44 Some("unix") => ServerType::Unix,
45 Some(val) => {
46 panic!("Unknown server type {val:?}");
47 }
48 None => ServerType::Tcp { tls: false },
49 }
50 }
51}
52
53impl Drop for RedisServer {
54 fn drop(&mut self) {
55 self.stop()
56 }
57}
58
59impl Default for RedisServer {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl RedisServer {
66 pub fn new() -> RedisServer {
67 RedisServer::with_modules(&[], false)
68 }
69
70 pub fn new_with_mtls() -> RedisServer {
71 RedisServer::with_modules(&[], true)
72 }
73
74 pub fn log_file_contents(&self) -> Option<String> {
75 std::fs::read_to_string(self.log_file.clone()).ok()
76 }
77
78 pub fn get_addr(port: u16) -> ConnectionAddr {
79 let server_type = ServerType::get_intended();
80 match server_type {
81 ServerType::Tcp { tls } => {
82 if tls {
83 redis::ConnectionAddr::TcpTls {
84 host: "127.0.0.1".to_string(),
85 port,
86 insecure: true,
87 tls_params: None,
88 }
89 } else {
90 redis::ConnectionAddr::Tcp("127.0.0.1".to_string(), port)
91 }
92 }
93 ServerType::Unix => {
94 let (a, b) = rand::random::<(u64, u64)>();
95 let path = format!("/tmp/redis-rs-test-{a}-{b}.sock");
96 redis::ConnectionAddr::Unix(PathBuf::from(&path))
97 }
98 }
99 }
100
101 pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> RedisServer {
102 let redis_port = get_random_available_port();
105 let addr = RedisServer::get_addr(redis_port);
106
107 RedisServer::new_with_addr_tls_modules_and_spawner(
108 addr,
109 None,
110 None,
111 mtls_enabled,
112 modules,
113 |cmd| {
114 cmd.spawn()
115 .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
116 },
117 )
118 }
119
120 pub fn new_with_addr_and_modules(
121 addr: redis::ConnectionAddr,
122 modules: &[Module],
123 mtls_enabled: bool,
124 ) -> RedisServer {
125 RedisServer::new_with_addr_tls_modules_and_spawner(
126 addr,
127 None,
128 None,
129 mtls_enabled,
130 modules,
131 |cmd| {
132 cmd.spawn()
133 .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
134 },
135 )
136 }
137
138 pub fn new_with_addr_tls_modules_and_spawner<
139 F: FnOnce(&mut process::Command) -> process::Child,
140 >(
141 addr: redis::ConnectionAddr,
142 config_file: Option<&Path>,
143 tls_paths: Option<TlsFilePaths>,
144 mtls_enabled: bool,
145 modules: &[Module],
146 spawner: F,
147 ) -> RedisServer {
148 let mut redis_cmd = process::Command::new("redis-server");
149
150 if let Some(config_path) = config_file {
151 redis_cmd.arg(config_path);
152 }
153
154 for module in modules {
156 match module {
157 Module::Json => {
158 redis_cmd
159 .arg("--loadmodule")
160 .arg(env::var("REDIS_RS_REDIS_JSON_PATH").expect(
161 "Unable to find path to RedisJSON at REDIS_RS_REDIS_JSON_PATH, is it set?",
162 ));
163 }
164 };
165 }
166
167 redis_cmd
168 .stdout(process::Stdio::piped())
169 .stderr(process::Stdio::piped());
170 let tempdir = tempfile::Builder::new()
171 .prefix("redis")
172 .tempdir()
173 .expect("failed to create tempdir");
174 let log_file = Self::log_file(&tempdir);
175 redis_cmd.arg("--logfile").arg(log_file.clone());
176 if get_major_version() > 6 {
177 redis_cmd.arg("--enable-debug-command").arg("yes");
178 }
179 match addr {
180 redis::ConnectionAddr::Tcp(ref bind, server_port) => {
181 redis_cmd
182 .arg("--port")
183 .arg(server_port.to_string())
184 .arg("--bind")
185 .arg(bind);
186
187 RedisServer {
188 process: spawner(&mut redis_cmd),
189 log_file,
190 tempdir,
191 addr,
192 tls_paths: None,
193 }
194 }
195 redis::ConnectionAddr::TcpTls { ref host, port, .. } => {
196 let tls_paths = tls_paths.unwrap_or_else(|| build_keys_and_certs_for_tls(&tempdir));
197
198 let auth_client = if mtls_enabled { "yes" } else { "no" };
199
200 redis_cmd
202 .arg("--tls-port")
203 .arg(port.to_string())
204 .arg("--port")
205 .arg("0")
206 .arg("--tls-cert-file")
207 .arg(&tls_paths.redis_crt)
208 .arg("--tls-key-file")
209 .arg(&tls_paths.redis_key)
210 .arg("--tls-ca-cert-file")
211 .arg(&tls_paths.ca_crt)
212 .arg("--tls-auth-clients")
213 .arg(auth_client)
214 .arg("--bind")
215 .arg(host);
216
217 let insecure = !mtls_enabled;
219
220 let addr = redis::ConnectionAddr::TcpTls {
221 host: host.clone(),
222 port,
223 insecure,
224 tls_params: None,
225 };
226
227 RedisServer {
228 process: spawner(&mut redis_cmd),
229 log_file,
230 tempdir,
231 addr,
232 tls_paths: Some(tls_paths),
233 }
234 }
235 redis::ConnectionAddr::Unix(ref path) => {
236 redis_cmd
237 .arg("--port")
238 .arg("0")
239 .arg("--unixsocket")
240 .arg(path);
241 RedisServer {
242 process: spawner(&mut redis_cmd),
243 log_file,
244 tempdir,
245 addr,
246 tls_paths: None,
247 }
248 }
249 }
250 }
251
252 pub fn client_addr(&self) -> &redis::ConnectionAddr {
253 &self.addr
254 }
255
256 pub fn connection_info(&self) -> redis::ConnectionInfo {
257 redis::ConnectionInfo {
258 addr: self.client_addr().clone(),
259 redis: RedisConnectionInfo {
260 protocol: use_protocol(),
261 ..Default::default()
262 },
263 }
264 }
265
266 pub fn stop(&mut self) {
267 let _ = self.process.kill();
268 let _ = self.process.wait();
269 if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() {
270 fs::remove_file(path).ok();
271 }
272 }
273
274 pub fn log_file(tempdir: &TempDir) -> PathBuf {
275 tempdir.path().join("redis.log")
276 }
277}
278
/// Returns the major version of the `redis-server` found on `PATH`, parsed
/// from `redis-server -v` output (e.g. `7` for `Redis server v=7.2.4 ...`).
///
/// # Panics
/// Panics with a descriptive message if `redis-server` cannot be run, prints
/// non-UTF-8 output, or its output lacks a `v=<major>.<rest>` token.
fn get_major_version() -> u8 {
    let output = process::Command::new("redis-server")
        .arg("-v")
        .output()
        .expect("failed to run `redis-server -v`; is redis-server installed and on PATH?");
    let full_string =
        String::from_utf8(output.stdout).expect("`redis-server -v` printed non-UTF-8 output");
    let (_, version) = full_string.split_once(" v=").unwrap_or_else(|| {
        panic!("no ` v=` version token in `redis-server -v` output: {full_string:?}")
    });
    let (major, _) = version.split_once('.').unwrap_or_else(|| {
        panic!("malformed version in `redis-server -v` output: {full_string:?}")
    });
    major.parse().unwrap_or_else(|err| {
        panic!("invalid major version {major:?} in `redis-server -v` output: {err}")
    })
}