1use std::{env, process, thread::sleep, time::Duration};
2
3use tempfile::TempDir;
4
5use crate::{
6 server::{Module, RedisServer},
7 utils::{TlsFilePaths, build_keys_and_certs_for_tls_ext, get_random_available_port},
8};
9
/// Parameters for [`RedisCluster::new`].
///
/// `tls_insecure` and `certs_with_ip_alts` are only consulted when
/// `REDISRS_SERVER_TYPE` is `tcp+tls`.
pub struct RedisClusterConfiguration {
    /// Number of `redis-server` processes to spawn.
    pub num_nodes: u16,
    /// Passed to `redis-cli --cluster create` as `--cluster-replicas` (omitted when 0);
    /// a non-zero value also enables `--tls-replication` on TLS clusters.
    pub num_replicas: u16,
    /// Modules handed to every node's spawner (`new_with_addr_tls_modules_and_spawner`).
    pub modules: Vec<Module>,
    /// For TLS clusters without mTLS: `true` runs `redis-cli` with `--insecure`,
    /// `false` passes the generated CA via `--cacert` instead.
    pub tls_insecure: bool,
    /// Forwarded to every node's spawner; on TLS clusters `redis-cli` is additionally
    /// given the generated `--cert`/`--key`/`--cacert` files.
    pub mtls_enabled: bool,
    /// Explicit ports, exactly one per node; leave empty to pick distinct random ports.
    pub ports: Vec<u16>,
    /// Forwarded to `build_keys_and_certs_for_tls_ext`; presumably controls whether the
    /// generated certificates carry IP subject-alternative names — TODO confirm.
    pub certs_with_ip_alts: bool,
}
19
20impl RedisClusterConfiguration {
21 pub fn single_replica_config() -> Self {
22 Self {
23 num_nodes: 6,
24 num_replicas: 1,
25 ..Default::default()
26 }
27 }
28}
29
30impl Default for RedisClusterConfiguration {
31 fn default() -> Self {
32 Self {
33 num_nodes: 3,
34 num_replicas: 0,
35 modules: vec![],
36 tls_insecure: true,
37 mtls_enabled: false,
38 ports: vec![],
39 certs_with_ip_alts: true,
40 }
41 }
42}
43
/// Transport used by the spawned cluster nodes, selected through the
/// `REDISRS_SERVER_TYPE` environment variable.
///
/// The type has no fields, so equality is total; it derives `Eq` and `Hash` in
/// addition to `PartialEq`, which lets it be used as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ClusterType {
    /// Plain TCP (`REDISRS_SERVER_TYPE` unset or `"tcp"`).
    Tcp,
    /// TCP with TLS (`REDISRS_SERVER_TYPE="tcp+tls"`).
    TcpTls,
}
50
51impl ClusterType {
52 pub fn get_intended() -> ClusterType {
53 match env::var("REDISRS_SERVER_TYPE")
54 .ok()
55 .as_ref()
56 .map(|x| &x[..])
57 {
58 Some("tcp") => ClusterType::Tcp,
59 Some("tcp+tls") => ClusterType::TcpTls,
60 Some(val) => {
61 panic!("Unknown server type {val:?}");
62 }
63 None => ClusterType::Tcp,
64 }
65 }
66
67 fn build_addr(port: u16) -> redis::ConnectionAddr {
68 match ClusterType::get_intended() {
69 ClusterType::Tcp => redis::ConnectionAddr::Tcp("127.0.0.1".into(), port),
70 ClusterType::TcpTls => redis::ConnectionAddr::TcpTls {
71 host: "127.0.0.1".into(),
72 port,
73 insecure: true,
74 tls_params: None,
75 },
76 }
77 }
78}
79
/// Returns `true` if a TCP connection to `addr` can be established, i.e. something is
/// accepting connections there.
///
/// `addr` must be a literal socket address such as `"127.0.0.1:6379"`; host names are
/// not resolved.
///
/// The attempt is bounded by a short timeout because callers poll this in a loop with a
/// fixed budget. A timed-out attempt counts as "not in use".
///
/// # Panics
///
/// Panics with `"Invalid address"` if `addr` does not parse as a socket address.
fn port_in_use(addr: &str) -> bool {
    // Upper bound for a single probe; callers retry, so a failed probe is cheap.
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(1);

    let socket_addr: std::net::SocketAddr = addr.parse().expect("Invalid address");
    // The probe connection is closed as soon as the returned stream is dropped.
    std::net::TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT).is_ok()
}
91
/// A running local Redis cluster created by [`RedisCluster::new`].
///
/// Dropping it stops every node.
pub struct RedisCluster {
    /// The spawned nodes, in the order they were started.
    pub servers: Vec<RedisServer>,
    /// Temporary directories kept alive alongside the cluster.
    ///
    /// This includes the TLS material directory when TLS is enabled, plus one working
    /// directory per node spawn attempt (failed attempts included).
    pub folders: Vec<TempDir>,
    /// Generated TLS key/certificate paths; `Some` only for `tcp+tls` clusters.
    pub tls_paths: Option<TlsFilePaths>,
}
97
98impl RedisCluster {
99 pub fn username() -> &'static str {
100 "hello"
101 }
102
103 pub fn password() -> &'static str {
104 "world"
105 }
106
107 pub fn new(configuration: RedisClusterConfiguration) -> RedisCluster {
108 let RedisClusterConfiguration {
109 num_nodes: nodes,
110 num_replicas: replicas,
111 modules,
112 tls_insecure,
113 mtls_enabled,
114 ports,
115 certs_with_ip_alts,
116 } = configuration;
117
118 let optional_ports = if ports.is_empty() {
119 vec![None; nodes as usize]
120 } else {
121 assert!(ports.len() == nodes as usize);
122 ports.into_iter().map(Some).collect()
123 };
124 let mut chosen_ports = std::collections::HashSet::new();
125
126 let mut folders = vec![];
127 let mut addrs = vec![];
128 let mut tls_paths = None;
129
130 let mut is_tls = false;
131
132 if let ClusterType::TcpTls = ClusterType::get_intended() {
133 let tempdir = tempfile::Builder::new()
135 .prefix("redis")
136 .tempdir()
137 .expect("failed to create tempdir");
138 let files = build_keys_and_certs_for_tls_ext(&tempdir, certs_with_ip_alts);
139 folders.push(tempdir);
140 tls_paths = Some(files);
141 is_tls = true;
142 }
143
144 let max_attempts = 5;
145
146 let mut make_server = |port| {
147 RedisServer::new_with_addr_tls_modules_and_spawner(
148 ClusterType::build_addr(port),
149 None,
150 tls_paths.clone(),
151 mtls_enabled,
152 &modules,
153 |cmd| {
154 let tempdir = tempfile::Builder::new()
155 .prefix("redis")
156 .tempdir()
157 .expect("failed to create tempdir");
158 let acl_path = tempdir.path().join("users.acl");
159 let acl_content = format!(
160 "user {} on allcommands allkeys >{}",
161 Self::username(),
162 Self::password()
163 );
164 std::fs::write(&acl_path, acl_content).expect("failed to write acl file");
165 cmd.arg("--cluster-enabled")
166 .arg("yes")
167 .arg("--cluster-config-file")
168 .arg(tempdir.path().join("nodes.conf"))
169 .arg("--cluster-node-timeout")
170 .arg("5000")
171 .arg("--aclfile")
172 .arg(&acl_path);
173 if is_tls {
174 cmd.arg("--tls-cluster").arg("yes");
175 if replicas > 0 {
176 cmd.arg("--tls-replication").arg("yes");
177 }
178 }
179 cmd.current_dir(tempdir.path());
180 folders.push(tempdir);
181 cmd.spawn().unwrap()
182 },
183 )
184 };
185
186 let verify_server = |server: &mut RedisServer| {
187 let process = &mut server.process;
188 match process.try_wait() {
189 Ok(Some(status)) => {
190 let log_file_contents = server.log_file_contents();
191 let err = format!(
192 "redis server creation failed with status {status:?}.\nlog file: {log_file_contents:?}"
193 );
194 Err(err)
195 }
196 Ok(None) => {
197 let max_attempts = 200;
199 let mut cur_attempts = 0;
200 loop {
201 if cur_attempts == max_attempts {
202 let log_file_contents = server.log_file_contents();
203 break Err(format!(
204 "redis server creation failed: Address {} closed. {log_file_contents:?}",
205 server.addr
206 ));
207 } else if port_in_use(&server.addr.to_string()) {
208 break Ok(());
209 }
210 eprintln!("Waiting for redis process to initialize");
211 sleep(Duration::from_millis(50));
212 cur_attempts += 1;
213 }
214 }
215 Err(e) => {
216 panic!("Unexpected error in redis server creation {e}");
217 }
218 }
219 };
220
221 let servers = optional_ports
222 .into_iter()
223 .map(|port_option| {
224 for _ in 0..5 {
225 let port = match port_option {
226 Some(port) => port,
227 None => loop {
228 let port = get_random_available_port();
229 if chosen_ports.contains(&port) {
230 continue;
231 }
232 chosen_ports.insert(port);
233 break port;
234 },
235 };
236 let mut server = make_server(port);
237 sleep(Duration::from_millis(50));
238
239 match verify_server(&mut server) {
240 Ok(_) => {
241 let addr = format!("127.0.0.1:{port}");
242 addrs.push(addr.clone());
243 return server;
244 }
245 Err(err) => eprintln!("{err}"),
246 }
247 }
248 panic!("Exhausted retries");
249 })
250 .collect();
251
252 let mut cmd = process::Command::new("redis-cli");
253 cmd.stdout(process::Stdio::piped())
254 .arg("--cluster")
255 .arg("create")
256 .args(&addrs);
257 if replicas > 0 {
258 cmd.arg("--cluster-replicas").arg(replicas.to_string());
259 }
260 cmd.arg("--cluster-yes");
261
262 if is_tls {
263 if mtls_enabled {
264 if let Some(TlsFilePaths {
265 redis_crt,
266 redis_key,
267 ca_crt,
268 }) = &tls_paths
269 {
270 cmd.arg("--cert");
271 cmd.arg(redis_crt);
272 cmd.arg("--key");
273 cmd.arg(redis_key);
274 cmd.arg("--cacert");
275 cmd.arg(ca_crt);
276 cmd.arg("--tls");
277 }
278 } else if !tls_insecure && tls_paths.is_some() {
279 let ca_crt = &tls_paths.as_ref().unwrap().ca_crt;
280 cmd.arg("--tls").arg("--cacert").arg(ca_crt);
281 } else {
282 cmd.arg("--tls").arg("--insecure");
283 }
284 }
285
286 let mut cur_attempts = 0;
287 loop {
288 let output = cmd.output().unwrap();
289 if output.status.success() {
290 break;
291 } else {
292 let err = format!("Cluster creation failed: {output:?}");
293 if cur_attempts == max_attempts {
294 panic!("{err}");
295 }
296 eprintln!("Retrying: {err}");
297 sleep(Duration::from_millis(50));
298 cur_attempts += 1;
299 }
300 }
301
302 let cluster = RedisCluster {
303 servers,
304 folders,
305 tls_paths,
306 };
307 if replicas > 0 {
308 cluster.wait_for_replicas(replicas);
309 }
310
311 wait_for_status_ok(&cluster);
312 cluster
313 }
314
315 fn wait_for_replicas(&self, replicas: u16) {
316 'server: for server in &self.servers {
317 let conn_info = server.connection_info();
318 eprintln!(
319 "waiting until {:?} knows required number of replicas",
320 conn_info.addr()
321 );
322
323 let client = redis::Client::open(server.connection_info()).unwrap();
324
325 let mut con = client.get_connection().unwrap();
326
327 for _ in 1..500 {
329 let value = redis::cmd("CLUSTER").arg("SLOTS").query(&mut con).unwrap();
330 let slots: Vec<Vec<redis::Value>> = redis::from_redis_value(value).unwrap();
331
332 if slots.iter().all(|slot| slot.len() >= 3 + replicas as usize) {
335 continue 'server;
336 }
337
338 sleep(Duration::from_millis(100));
339 }
340
341 panic!("failed to create enough replicas");
342 }
343 }
344
345 pub fn stop(&mut self) {
346 for server in &mut self.servers {
347 server.stop();
348 }
349 }
350
351 pub fn iter_servers(&self) -> impl Iterator<Item = &RedisServer> {
352 self.servers.iter()
353 }
354}
355
356fn wait_for_status_ok(cluster: &RedisCluster) {
357 'server: for server in &cluster.servers {
358 let log_file = RedisServer::log_file(&server.tempdir);
359
360 for _ in 1..500 {
361 let contents =
362 std::fs::read_to_string(&log_file).expect("Should have been able to read the file");
363
364 if contents.contains("Cluster state changed: ok") {
365 continue 'server;
366 }
367 sleep(Duration::from_millis(20));
368 }
369 panic!("failed to reach state change: OK");
370 }
371}
372
373impl Drop for RedisCluster {
374 fn drop(&mut self) {
375 self.stop()
376 }
377}