1use std::{env, process, thread::sleep, time::Duration};
2
3use tempfile::TempDir;
4
5use crate::{
6 server::{Module, RedisServer},
7 utils::{TlsFilePaths, build_keys_and_certs_for_tls_ext, get_random_available_port},
8};
9
/// Options for [`RedisCluster::new`].
///
/// Whether the cluster uses TLS is not configured here; it is selected at runtime by
/// [`ClusterType::get_intended`] (the `REDISRS_SERVER_TYPE` environment variable).
pub struct RedisClusterConfiguration {
    /// Number of redis-server processes to start.
    pub num_nodes: u16,
    /// Replicas per primary. When non-zero it is passed to `redis-cli --cluster create`
    /// as `--cluster-replicas`, and in TLS mode it also enables `--tls-replication`.
    pub num_replicas: u16,
    /// Modules handed to every server at startup (see `RedisServer`).
    pub modules: Vec<Module>,
    /// Only consulted in TLS mode without mTLS: when `true`, `redis-cli` connects with
    /// `--insecure`; when `false`, it verifies servers against the generated CA (`--cacert`).
    pub tls_insecure: bool,
    /// Enables mutual TLS: forwarded to each server, and makes `redis-cli` present the
    /// generated certificate/key together with the CA certificate.
    pub mtls_enabled: bool,
    /// Explicit ports, exactly one per node. Leave empty to pick random free ports.
    pub ports: Vec<u16>,
    /// Forwarded to `build_keys_and_certs_for_tls_ext`; presumably controls whether the
    /// generated certificates carry IP subject-alternative names — confirm in `utils`.
    pub certs_with_ip_alts: bool,
}
19
20impl RedisClusterConfiguration {
21 pub fn single_replica_config() -> Self {
22 Self {
23 num_nodes: 6,
24 num_replicas: 1,
25 ..Default::default()
26 }
27 }
28}
29
30impl Default for RedisClusterConfiguration {
31 fn default() -> Self {
32 Self {
33 num_nodes: 3,
34 num_replicas: 0,
35 modules: vec![],
36 tls_insecure: true,
37 mtls_enabled: false,
38 ports: vec![],
39 certs_with_ip_alts: true,
40 }
41 }
42}
43
/// Transport used by the test cluster's servers; chosen at runtime by
/// [`ClusterType::get_intended`].
///
/// `Eq` and `Hash` are derived alongside `PartialEq` because equality on this fieldless
/// enum is total, which also lets it be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ClusterType {
    /// Plain TCP connections.
    Tcp,
    /// TCP connections wrapped in TLS.
    TcpTls,
}
50
51impl ClusterType {
52 pub fn get_intended() -> ClusterType {
53 match env::var("REDISRS_SERVER_TYPE")
54 .ok()
55 .as_ref()
56 .map(|x| &x[..])
57 {
58 Some("tcp") => ClusterType::Tcp,
59 Some("tcp+tls") => ClusterType::TcpTls,
60 Some(val) => {
61 panic!("Unknown server type {val:?}");
62 }
63 None => ClusterType::Tcp,
64 }
65 }
66
67 fn build_addr(port: u16) -> redis::ConnectionAddr {
68 match ClusterType::get_intended() {
69 ClusterType::Tcp => redis::ConnectionAddr::Tcp("127.0.0.1".into(), port),
70 ClusterType::TcpTls => redis::ConnectionAddr::TcpTls {
71 host: "127.0.0.1".into(),
72 port,
73 insecure: true,
74 tls_params: None,
75 },
76 }
77 }
78}
79
80fn port_in_use(addr: &str) -> bool {
81 let socket_addr: std::net::SocketAddr = addr.parse().expect("Invalid address");
82 let socket = socket2::Socket::new(
83 socket2::Domain::for_address(socket_addr),
84 socket2::Type::STREAM,
85 None,
86 )
87 .expect("Failed to create socket");
88
89 socket.connect(&socket_addr.into()).is_ok()
90}
91
/// A running local Redis Cluster created by [`RedisCluster::new`]; its servers are
/// stopped when it is dropped.
///
/// Fields are dropped in declaration order (servers before folders). NOTE(review):
/// presumably this matters so that processes are gone before their working
/// directories are deleted — keep `servers` first.
pub struct RedisCluster {
    /// One server per node, in the order their ports were assigned.
    pub servers: Vec<RedisServer>,
    /// Temporary directories kept alive for the cluster's lifetime: the TLS
    /// certificate directory (TLS mode only) plus each spawned server's working dir.
    pub folders: Vec<TempDir>,
    /// Generated certificate/key paths; `Some` only in TLS mode.
    pub tls_paths: Option<TlsFilePaths>,
}
97
98impl RedisCluster {
99 pub fn username() -> &'static str {
100 "hello"
101 }
102
103 pub fn password() -> &'static str {
104 "world"
105 }
106
107 pub fn new(configuration: RedisClusterConfiguration) -> RedisCluster {
108 let RedisClusterConfiguration {
109 num_nodes: nodes,
110 num_replicas: replicas,
111 modules,
112 tls_insecure,
113 mtls_enabled,
114 ports,
115 certs_with_ip_alts,
116 } = configuration;
117
118 let optional_ports = if ports.is_empty() {
119 vec![None; nodes as usize]
120 } else {
121 assert!(ports.len() == nodes as usize);
122 ports.into_iter().map(Some).collect()
123 };
124 let mut chosen_ports = std::collections::HashSet::new();
125
126 let mut folders = vec![];
127 let mut addrs = vec![];
128 let mut tls_paths = None;
129
130 let mut is_tls = false;
131
132 if let ClusterType::TcpTls = ClusterType::get_intended() {
133 let tempdir = tempfile::Builder::new()
135 .prefix("redis")
136 .tempdir()
137 .expect("failed to create tempdir");
138 let files = build_keys_and_certs_for_tls_ext(&tempdir, certs_with_ip_alts);
139 folders.push(tempdir);
140 tls_paths = Some(files);
141 is_tls = true;
142 }
143
144 let max_attempts = 5;
145
146 let mut make_server = |port| {
147 RedisServer::new_with_addr_tls_modules_and_spawner(
148 ClusterType::build_addr(port),
149 None,
150 tls_paths.clone(),
151 mtls_enabled,
152 &modules,
153 |cmd| {
154 let tempdir = tempfile::Builder::new()
155 .prefix("redis")
156 .tempdir()
157 .expect("failed to create tempdir");
158 let acl_path = tempdir.path().join("users.acl");
159 let acl_content = format!(
160 "user {} on allcommands allkeys >{}",
161 Self::username(),
162 Self::password()
163 );
164 std::fs::write(&acl_path, acl_content).expect("failed to write acl file");
165 cmd.arg("--cluster-enabled")
166 .arg("yes")
167 .arg("--cluster-config-file")
168 .arg(tempdir.path().join("nodes.conf"))
169 .arg("--cluster-node-timeout")
170 .arg("5000")
171 .arg("--appendonly")
172 .arg("yes")
173 .arg("--aclfile")
174 .arg(&acl_path);
175 if is_tls {
176 cmd.arg("--tls-cluster").arg("yes");
177 if replicas > 0 {
178 cmd.arg("--tls-replication").arg("yes");
179 }
180 }
181 cmd.current_dir(tempdir.path());
182 folders.push(tempdir);
183 cmd.spawn().unwrap()
184 },
185 )
186 };
187
188 let verify_server = |server: &mut RedisServer| {
189 let process = &mut server.process;
190 match process.try_wait() {
191 Ok(Some(status)) => {
192 let log_file_contents = server.log_file_contents();
193 let err = format!(
194 "redis server creation failed with status {status:?}.\nlog file: {log_file_contents:?}"
195 );
196 Err(err)
197 }
198 Ok(None) => {
199 let max_attempts = 200;
201 let mut cur_attempts = 0;
202 loop {
203 if cur_attempts == max_attempts {
204 let log_file_contents = server.log_file_contents();
205 break Err(format!(
206 "redis server creation failed: Address {} closed. {log_file_contents:?}",
207 server.addr
208 ));
209 } else if port_in_use(&server.addr.to_string()) {
210 break Ok(());
211 }
212 eprintln!("Waiting for redis process to initialize");
213 sleep(Duration::from_millis(50));
214 cur_attempts += 1;
215 }
216 }
217 Err(e) => {
218 panic!("Unexpected error in redis server creation {e}");
219 }
220 }
221 };
222
223 let servers = optional_ports
224 .into_iter()
225 .map(|port_option| {
226 for _ in 0..5 {
227 let port = match port_option {
228 Some(port) => port,
229 None => loop {
230 let port = get_random_available_port();
231 if chosen_ports.contains(&port) {
232 continue;
233 }
234 chosen_ports.insert(port);
235 break port;
236 },
237 };
238 let mut server = make_server(port);
239 sleep(Duration::from_millis(50));
240
241 match verify_server(&mut server) {
242 Ok(_) => {
243 let addr = format!("127.0.0.1:{port}");
244 addrs.push(addr.clone());
245 return server;
246 }
247 Err(err) => eprintln!("{err}"),
248 }
249 }
250 panic!("Exhausted retries");
251 })
252 .collect();
253
254 let mut cmd = process::Command::new("redis-cli");
255 cmd.stdout(process::Stdio::piped())
256 .arg("--cluster")
257 .arg("create")
258 .args(&addrs);
259 if replicas > 0 {
260 cmd.arg("--cluster-replicas").arg(replicas.to_string());
261 }
262 cmd.arg("--cluster-yes");
263
264 if is_tls {
265 if mtls_enabled {
266 if let Some(TlsFilePaths {
267 redis_crt,
268 redis_key,
269 ca_crt,
270 }) = &tls_paths
271 {
272 cmd.arg("--cert");
273 cmd.arg(redis_crt);
274 cmd.arg("--key");
275 cmd.arg(redis_key);
276 cmd.arg("--cacert");
277 cmd.arg(ca_crt);
278 cmd.arg("--tls");
279 }
280 } else if !tls_insecure && tls_paths.is_some() {
281 let ca_crt = &tls_paths.as_ref().unwrap().ca_crt;
282 cmd.arg("--tls").arg("--cacert").arg(ca_crt);
283 } else {
284 cmd.arg("--tls").arg("--insecure");
285 }
286 }
287
288 let mut cur_attempts = 0;
289 loop {
290 let output = cmd.output().unwrap();
291 if output.status.success() {
292 break;
293 } else {
294 let err = format!("Cluster creation failed: {output:?}");
295 if cur_attempts == max_attempts {
296 panic!("{err}");
297 }
298 eprintln!("Retrying: {err}");
299 sleep(Duration::from_millis(50));
300 cur_attempts += 1;
301 }
302 }
303
304 let cluster = RedisCluster {
305 servers,
306 folders,
307 tls_paths,
308 };
309 if replicas > 0 {
310 cluster.wait_for_replicas(replicas);
311 }
312
313 wait_for_status_ok(&cluster);
314 cluster
315 }
316
317 fn wait_for_replicas(&self, replicas: u16) {
318 'server: for server in &self.servers {
319 let conn_info = server.connection_info();
320 eprintln!(
321 "waiting until {:?} knows required number of replicas",
322 conn_info.addr()
323 );
324
325 let client = redis::Client::open(server.connection_info()).unwrap();
326
327 let mut con = client.get_connection().unwrap();
328
329 for _ in 1..500 {
331 let value = redis::cmd("CLUSTER").arg("SLOTS").query(&mut con).unwrap();
332 let slots: Vec<Vec<redis::Value>> = redis::from_redis_value(value).unwrap();
333
334 if slots.iter().all(|slot| slot.len() >= 3 + replicas as usize) {
337 continue 'server;
338 }
339
340 sleep(Duration::from_millis(100));
341 }
342
343 panic!("failed to create enough replicas");
344 }
345 }
346
347 pub fn stop(&mut self) {
348 for server in &mut self.servers {
349 server.stop();
350 }
351 }
352
353 pub fn iter_servers(&self) -> impl Iterator<Item = &RedisServer> {
354 self.servers.iter()
355 }
356}
357
358fn wait_for_status_ok(cluster: &RedisCluster) {
359 'server: for server in &cluster.servers {
360 let log_file = RedisServer::log_file(&server.tempdir);
361
362 for _ in 1..500 {
363 let contents =
364 std::fs::read_to_string(&log_file).expect("Should have been able to read the file");
365
366 if contents.contains("Cluster state changed: ok") {
367 continue 'server;
368 }
369 sleep(Duration::from_millis(20));
370 }
371 panic!("failed to reach state change: OK");
372 }
373}
374
375impl Drop for RedisCluster {
376 fn drop(&mut self) {
377 self.stop()
378 }
379}