1use std::{env, process, thread::sleep, time::Duration};
2
3use tempfile::TempDir;
4
5use crate::{
6 server::{Module, RedisServer},
7 utils::{build_keys_and_certs_for_tls_ext, get_random_available_port, TlsFilePaths},
8};
9
10pub struct RedisClusterConfiguration {
11 pub num_nodes: u16,
12 pub num_replicas: u16,
13 pub modules: Vec<Module>,
14 pub tls_insecure: bool,
15 pub mtls_enabled: bool,
16 pub ports: Vec<u16>,
17 pub certs_with_ip_alts: bool,
18}
19
20impl RedisClusterConfiguration {
21 pub fn single_replica_config() -> Self {
22 Self {
23 num_nodes: 6,
24 num_replicas: 1,
25 ..Default::default()
26 }
27 }
28}
29
30impl Default for RedisClusterConfiguration {
31 fn default() -> Self {
32 Self {
33 num_nodes: 3,
34 num_replicas: 0,
35 modules: vec![],
36 tls_insecure: true,
37 mtls_enabled: false,
38 ports: vec![],
39 certs_with_ip_alts: true,
40 }
41 }
42}
43
/// Transport used by the test cluster's nodes; selected by [`ClusterType::get_intended`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ClusterType {
    /// Plain TCP (`REDISRS_SERVER_TYPE` unset or `tcp`).
    Tcp,
    /// TCP with TLS (`REDISRS_SERVER_TYPE=tcp+tls`).
    TcpTls,
}
49
50impl ClusterType {
51 pub fn get_intended() -> ClusterType {
52 match env::var("REDISRS_SERVER_TYPE")
53 .ok()
54 .as_ref()
55 .map(|x| &x[..])
56 {
57 Some("tcp") => ClusterType::Tcp,
58 Some("tcp+tls") => ClusterType::TcpTls,
59 Some(val) => {
60 panic!("Unknown server type {val:?}");
61 }
62 None => ClusterType::Tcp,
63 }
64 }
65
66 fn build_addr(port: u16) -> redis::ConnectionAddr {
67 match ClusterType::get_intended() {
68 ClusterType::Tcp => redis::ConnectionAddr::Tcp("127.0.0.1".into(), port),
69 ClusterType::TcpTls => redis::ConnectionAddr::TcpTls {
70 host: "127.0.0.1".into(),
71 port,
72 insecure: true,
73 tls_params: None,
74 },
75 }
76 }
77}
78
/// Returns `true` if something accepts TCP connections at `addr`, an `ip:port` string
/// such as `"127.0.0.1:6379"`.
///
/// Each probe is bounded by a short connect timeout, so checking an unreachable
/// address cannot stall a caller that polls this function in a loop.
///
/// # Panics
///
/// Panics if `addr` does not parse as a socket address.
fn port_in_use(addr: &str) -> bool {
    // Plenty for loopback; prevents a single probe from waiting out the OS default
    // connect timeout.
    const CONNECT_TIMEOUT: Duration = Duration::from_millis(500);

    let socket_addr: std::net::SocketAddr = addr.parse().expect("Invalid address");
    std::net::TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT).is_ok()
}
90
91pub struct RedisCluster {
92 pub servers: Vec<RedisServer>,
93 pub folders: Vec<TempDir>,
94 pub tls_paths: Option<TlsFilePaths>,
95}
96
97impl RedisCluster {
98 pub fn username() -> &'static str {
99 "hello"
100 }
101
102 pub fn password() -> &'static str {
103 "world"
104 }
105
106 pub fn new(configuration: RedisClusterConfiguration) -> RedisCluster {
107 let RedisClusterConfiguration {
108 num_nodes: nodes,
109 num_replicas: replicas,
110 modules,
111 tls_insecure,
112 mtls_enabled,
113 ports,
114 certs_with_ip_alts,
115 } = configuration;
116
117 let optional_ports = if ports.is_empty() {
118 vec![None; nodes as usize]
119 } else {
120 assert!(ports.len() == nodes as usize);
121 ports.into_iter().map(Some).collect()
122 };
123 let mut chosen_ports = std::collections::HashSet::new();
124
125 let mut folders = vec![];
126 let mut addrs = vec![];
127 let mut tls_paths = None;
128
129 let mut is_tls = false;
130
131 if let ClusterType::TcpTls = ClusterType::get_intended() {
132 let tempdir = tempfile::Builder::new()
134 .prefix("redis")
135 .tempdir()
136 .expect("failed to create tempdir");
137 let files = build_keys_and_certs_for_tls_ext(&tempdir, certs_with_ip_alts);
138 folders.push(tempdir);
139 tls_paths = Some(files);
140 is_tls = true;
141 }
142
143 let max_attempts = 5;
144
145 let mut make_server = |port| {
146 RedisServer::new_with_addr_tls_modules_and_spawner(
147 ClusterType::build_addr(port),
148 None,
149 tls_paths.clone(),
150 mtls_enabled,
151 &modules,
152 |cmd| {
153 let tempdir = tempfile::Builder::new()
154 .prefix("redis")
155 .tempdir()
156 .expect("failed to create tempdir");
157 let acl_path = tempdir.path().join("users.acl");
158 let acl_content = format!(
159 "user {} on allcommands allkeys >{}",
160 Self::username(),
161 Self::password()
162 );
163 std::fs::write(&acl_path, acl_content).expect("failed to write acl file");
164 cmd.arg("--cluster-enabled")
165 .arg("yes")
166 .arg("--cluster-config-file")
167 .arg(tempdir.path().join("nodes.conf"))
168 .arg("--cluster-node-timeout")
169 .arg("5000")
170 .arg("--appendonly")
171 .arg("yes")
172 .arg("--aclfile")
173 .arg(&acl_path);
174 if is_tls {
175 cmd.arg("--tls-cluster").arg("yes");
176 if replicas > 0 {
177 cmd.arg("--tls-replication").arg("yes");
178 }
179 }
180 cmd.current_dir(tempdir.path());
181 folders.push(tempdir);
182 cmd.spawn().unwrap()
183 },
184 )
185 };
186
187 let verify_server = |server: &mut RedisServer| {
188 let process = &mut server.process;
189 match process.try_wait() {
190 Ok(Some(status)) => {
191 let log_file_contents = server.log_file_contents();
192 let err =
193 format!("redis server creation failed with status {status:?}.\nlog file: {log_file_contents:?}");
194 Err(err)
195 }
196 Ok(None) => {
197 let max_attempts = 200;
199 let mut cur_attempts = 0;
200 loop {
201 if cur_attempts == max_attempts {
202 let log_file_contents = server.log_file_contents();
203 break Err(format!("redis server creation failed: Address {} closed. {log_file_contents:?}", server.addr));
204 } else if port_in_use(&server.addr.to_string()) {
205 break Ok(());
206 }
207 eprintln!("Waiting for redis process to initialize");
208 sleep(Duration::from_millis(50));
209 cur_attempts += 1;
210 }
211 }
212 Err(e) => {
213 panic!("Unexpected error in redis server creation {e}");
214 }
215 }
216 };
217
218 let servers = optional_ports
219 .into_iter()
220 .map(|port_option| {
221 for _ in 0..5 {
222 let port = match port_option {
223 Some(port) => port,
224 None => loop {
225 let port = get_random_available_port();
226 if chosen_ports.contains(&port) {
227 continue;
228 }
229 chosen_ports.insert(port);
230 break port;
231 },
232 };
233 let mut server = make_server(port);
234 sleep(Duration::from_millis(50));
235
236 match verify_server(&mut server) {
237 Ok(_) => {
238 let addr = format!("127.0.0.1:{port}");
239 addrs.push(addr.clone());
240 return server;
241 }
242 Err(err) => eprintln!("{err}"),
243 }
244 }
245 panic!("Exhausted retries");
246 })
247 .collect();
248
249 let mut cmd = process::Command::new("redis-cli");
250 cmd.stdout(process::Stdio::piped())
251 .arg("--cluster")
252 .arg("create")
253 .args(&addrs);
254 if replicas > 0 {
255 cmd.arg("--cluster-replicas").arg(replicas.to_string());
256 }
257 cmd.arg("--cluster-yes");
258
259 if is_tls {
260 if mtls_enabled {
261 if let Some(TlsFilePaths {
262 redis_crt,
263 redis_key,
264 ca_crt,
265 }) = &tls_paths
266 {
267 cmd.arg("--cert");
268 cmd.arg(redis_crt);
269 cmd.arg("--key");
270 cmd.arg(redis_key);
271 cmd.arg("--cacert");
272 cmd.arg(ca_crt);
273 cmd.arg("--tls");
274 }
275 } else if !tls_insecure && tls_paths.is_some() {
276 let ca_crt = &tls_paths.as_ref().unwrap().ca_crt;
277 cmd.arg("--tls").arg("--cacert").arg(ca_crt);
278 } else {
279 cmd.arg("--tls").arg("--insecure");
280 }
281 }
282
283 let mut cur_attempts = 0;
284 loop {
285 let output = cmd.output().unwrap();
286 if output.status.success() {
287 break;
288 } else {
289 let err = format!("Cluster creation failed: {output:?}");
290 if cur_attempts == max_attempts {
291 panic!("{err}");
292 }
293 eprintln!("Retrying: {err}");
294 sleep(Duration::from_millis(50));
295 cur_attempts += 1;
296 }
297 }
298
299 let cluster = RedisCluster {
300 servers,
301 folders,
302 tls_paths,
303 };
304 if replicas > 0 {
305 cluster.wait_for_replicas(replicas);
306 }
307
308 wait_for_status_ok(&cluster);
309 cluster
310 }
311
312 fn wait_for_replicas(&self, replicas: u16) {
313 'server: for server in &self.servers {
314 let conn_info = server.connection_info();
315 eprintln!(
316 "waiting until {:?} knows required number of replicas",
317 conn_info.addr
318 );
319
320 let client = redis::Client::open(server.connection_info()).unwrap();
321
322 let mut con = client.get_connection().unwrap();
323
324 for _ in 1..500 {
326 let value = redis::cmd("CLUSTER").arg("SLOTS").query(&mut con).unwrap();
327 let slots: Vec<Vec<redis::Value>> = redis::from_owned_redis_value(value).unwrap();
328
329 if slots.iter().all(|slot| slot.len() >= 3 + replicas as usize) {
332 continue 'server;
333 }
334
335 sleep(Duration::from_millis(100));
336 }
337
338 panic!("failed to create enough replicas");
339 }
340 }
341
342 pub fn stop(&mut self) {
343 for server in &mut self.servers {
344 server.stop();
345 }
346 }
347
348 pub fn iter_servers(&self) -> impl Iterator<Item = &RedisServer> {
349 self.servers.iter()
350 }
351}
352
353fn wait_for_status_ok(cluster: &RedisCluster) {
354 'server: for server in &cluster.servers {
355 let log_file = RedisServer::log_file(&server.tempdir);
356
357 for _ in 1..500 {
358 let contents =
359 std::fs::read_to_string(&log_file).expect("Should have been able to read the file");
360
361 if contents.contains("Cluster state changed: ok") {
362 continue 'server;
363 }
364 sleep(Duration::from_millis(20));
365 }
366 panic!("failed to reach state change: OK");
367 }
368}
369
370impl Drop for RedisCluster {
371 fn drop(&mut self) {
372 self.stop()
373 }
374}