redis_test/
cluster.rs

1use std::{env, process, thread::sleep, time::Duration};
2
3use tempfile::TempDir;
4
5use crate::{
6    server::{Module, RedisServer},
7    utils::{build_keys_and_certs_for_tls_ext, get_random_available_port, TlsFilePaths},
8};
9
10pub struct RedisClusterConfiguration {
11    pub num_nodes: u16,
12    pub num_replicas: u16,
13    pub modules: Vec<Module>,
14    pub tls_insecure: bool,
15    pub mtls_enabled: bool,
16    pub ports: Vec<u16>,
17    pub certs_with_ip_alts: bool,
18}
19
20impl RedisClusterConfiguration {
21    pub fn single_replica_config() -> Self {
22        Self {
23            num_nodes: 6,
24            num_replicas: 1,
25            ..Default::default()
26        }
27    }
28}
29
30impl Default for RedisClusterConfiguration {
31    fn default() -> Self {
32        Self {
33            num_nodes: 3,
34            num_replicas: 0,
35            modules: vec![],
36            tls_insecure: true,
37            mtls_enabled: false,
38            ports: vec![],
39            certs_with_ip_alts: true,
40        }
41    }
42}
43
/// Transport used by the servers of a test cluster.
///
/// `Eq` is derived alongside `PartialEq`: the enum has no fields, so equality is total
/// and the type can be used wherever a full equivalence relation is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ClusterType {
    /// Plain TCP connections.
    Tcp,
    /// TCP connections secured with TLS.
    TcpTls,
}
50
51impl ClusterType {
52    pub fn get_intended() -> ClusterType {
53        match env::var("REDISRS_SERVER_TYPE")
54            .ok()
55            .as_ref()
56            .map(|x| &x[..])
57        {
58            Some("tcp") => ClusterType::Tcp,
59            Some("tcp+tls") => ClusterType::TcpTls,
60            Some(val) => {
61                panic!("Unknown server type {val:?}");
62            }
63            None => ClusterType::Tcp,
64        }
65    }
66
67    fn build_addr(port: u16) -> redis::ConnectionAddr {
68        match ClusterType::get_intended() {
69            ClusterType::Tcp => redis::ConnectionAddr::Tcp("127.0.0.1".into(), port),
70            ClusterType::TcpTls => redis::ConnectionAddr::TcpTls {
71                host: "127.0.0.1".into(),
72                port,
73                insecure: true,
74                tls_params: None,
75            },
76        }
77    }
78}
79
/// Reports whether something is accepting TCP connections on `addr`.
///
/// `addr` must be a socket address in `ip:port` form (for example `127.0.0.1:6379`).
/// The check performs a blocking connect with the standard library and drops the
/// connection immediately; any connection error is reported as `false`.
///
/// # Panics
///
/// Panics if `addr` cannot be parsed as a [`std::net::SocketAddr`].
fn port_in_use(addr: &str) -> bool {
    let socket_addr: std::net::SocketAddr = addr.parse().expect("Invalid address");
    std::net::TcpStream::connect(socket_addr).is_ok()
}
91
92pub struct RedisCluster {
93    pub servers: Vec<RedisServer>,
94    pub folders: Vec<TempDir>,
95    pub tls_paths: Option<TlsFilePaths>,
96}
97
98impl RedisCluster {
99    pub fn username() -> &'static str {
100        "hello"
101    }
102
103    pub fn password() -> &'static str {
104        "world"
105    }
106
107    pub fn new(configuration: RedisClusterConfiguration) -> RedisCluster {
108        let RedisClusterConfiguration {
109            num_nodes: nodes,
110            num_replicas: replicas,
111            modules,
112            tls_insecure,
113            mtls_enabled,
114            ports,
115            certs_with_ip_alts,
116        } = configuration;
117
118        let optional_ports = if ports.is_empty() {
119            vec![None; nodes as usize]
120        } else {
121            assert!(ports.len() == nodes as usize);
122            ports.into_iter().map(Some).collect()
123        };
124        let mut chosen_ports = std::collections::HashSet::new();
125
126        let mut folders = vec![];
127        let mut addrs = vec![];
128        let mut tls_paths = None;
129
130        let mut is_tls = false;
131
132        if let ClusterType::TcpTls = ClusterType::get_intended() {
133            // Create a shared set of keys in cluster mode
134            let tempdir = tempfile::Builder::new()
135                .prefix("redis")
136                .tempdir()
137                .expect("failed to create tempdir");
138            let files = build_keys_and_certs_for_tls_ext(&tempdir, certs_with_ip_alts);
139            folders.push(tempdir);
140            tls_paths = Some(files);
141            is_tls = true;
142        }
143
144        let max_attempts = 5;
145
146        let mut make_server = |port| {
147            RedisServer::new_with_addr_tls_modules_and_spawner(
148                ClusterType::build_addr(port),
149                None,
150                tls_paths.clone(),
151                mtls_enabled,
152                &modules,
153                |cmd| {
154                    let tempdir = tempfile::Builder::new()
155                        .prefix("redis")
156                        .tempdir()
157                        .expect("failed to create tempdir");
158                    let acl_path = tempdir.path().join("users.acl");
159                    let acl_content = format!(
160                        "user {} on allcommands allkeys >{}",
161                        Self::username(),
162                        Self::password()
163                    );
164                    std::fs::write(&acl_path, acl_content).expect("failed to write acl file");
165                    cmd.arg("--cluster-enabled")
166                        .arg("yes")
167                        .arg("--cluster-config-file")
168                        .arg(tempdir.path().join("nodes.conf"))
169                        .arg("--cluster-node-timeout")
170                        .arg("5000")
171                        .arg("--appendonly")
172                        .arg("yes")
173                        .arg("--aclfile")
174                        .arg(&acl_path);
175                    if is_tls {
176                        cmd.arg("--tls-cluster").arg("yes");
177                        if replicas > 0 {
178                            cmd.arg("--tls-replication").arg("yes");
179                        }
180                    }
181                    cmd.current_dir(tempdir.path());
182                    folders.push(tempdir);
183                    cmd.spawn().unwrap()
184                },
185            )
186        };
187
188        let verify_server = |server: &mut RedisServer| {
189            let process = &mut server.process;
190            match process.try_wait() {
191                Ok(Some(status)) => {
192                    let log_file_contents = server.log_file_contents();
193                    let err =
194                                    format!("redis server creation failed with status {status:?}.\nlog file: {log_file_contents:?}");
195                    Err(err)
196                }
197                Ok(None) => {
198                    // wait for 10 seconds for the server to be available.
199                    let max_attempts = 200;
200                    let mut cur_attempts = 0;
201                    loop {
202                        if cur_attempts == max_attempts {
203                            let log_file_contents = server.log_file_contents();
204                            break Err(format!("redis server creation failed: Address {} closed. {log_file_contents:?}", server.addr));
205                        } else if port_in_use(&server.addr.to_string()) {
206                            break Ok(());
207                        }
208                        eprintln!("Waiting for redis process to initialize");
209                        sleep(Duration::from_millis(50));
210                        cur_attempts += 1;
211                    }
212                }
213                Err(e) => {
214                    panic!("Unexpected error in redis server creation {e}");
215                }
216            }
217        };
218
219        let servers = optional_ports
220            .into_iter()
221            .map(|port_option| {
222                for _ in 0..5 {
223                    let port = match port_option {
224                        Some(port) => port,
225                        None => loop {
226                            let port = get_random_available_port();
227                            if chosen_ports.contains(&port) {
228                                continue;
229                            }
230                            chosen_ports.insert(port);
231                            break port;
232                        },
233                    };
234                    let mut server = make_server(port);
235                    sleep(Duration::from_millis(50));
236
237                    match verify_server(&mut server) {
238                        Ok(_) => {
239                            let addr = format!("127.0.0.1:{port}");
240                            addrs.push(addr.clone());
241                            return server;
242                        }
243                        Err(err) => eprintln!("{err}"),
244                    }
245                }
246                panic!("Exhausted retries");
247            })
248            .collect();
249
250        let mut cmd = process::Command::new("redis-cli");
251        cmd.stdout(process::Stdio::piped())
252            .arg("--cluster")
253            .arg("create")
254            .args(&addrs);
255        if replicas > 0 {
256            cmd.arg("--cluster-replicas").arg(replicas.to_string());
257        }
258        cmd.arg("--cluster-yes");
259
260        if is_tls {
261            if mtls_enabled {
262                if let Some(TlsFilePaths {
263                    redis_crt,
264                    redis_key,
265                    ca_crt,
266                }) = &tls_paths
267                {
268                    cmd.arg("--cert");
269                    cmd.arg(redis_crt);
270                    cmd.arg("--key");
271                    cmd.arg(redis_key);
272                    cmd.arg("--cacert");
273                    cmd.arg(ca_crt);
274                    cmd.arg("--tls");
275                }
276            } else if !tls_insecure && tls_paths.is_some() {
277                let ca_crt = &tls_paths.as_ref().unwrap().ca_crt;
278                cmd.arg("--tls").arg("--cacert").arg(ca_crt);
279            } else {
280                cmd.arg("--tls").arg("--insecure");
281            }
282        }
283
284        let mut cur_attempts = 0;
285        loop {
286            let output = cmd.output().unwrap();
287            if output.status.success() {
288                break;
289            } else {
290                let err = format!("Cluster creation failed: {output:?}");
291                if cur_attempts == max_attempts {
292                    panic!("{err}");
293                }
294                eprintln!("Retrying: {err}");
295                sleep(Duration::from_millis(50));
296                cur_attempts += 1;
297            }
298        }
299
300        let cluster = RedisCluster {
301            servers,
302            folders,
303            tls_paths,
304        };
305        if replicas > 0 {
306            cluster.wait_for_replicas(replicas);
307        }
308
309        wait_for_status_ok(&cluster);
310        cluster
311    }
312
313    fn wait_for_replicas(&self, replicas: u16) {
314        'server: for server in &self.servers {
315            let conn_info = server.connection_info();
316            eprintln!(
317                "waiting until {:?} knows required number of replicas",
318                conn_info.addr()
319            );
320
321            let client = redis::Client::open(server.connection_info()).unwrap();
322
323            let mut con = client.get_connection().unwrap();
324
325            // retry 500 times
326            for _ in 1..500 {
327                let value = redis::cmd("CLUSTER").arg("SLOTS").query(&mut con).unwrap();
328                let slots: Vec<Vec<redis::Value>> = redis::from_redis_value(value).unwrap();
329
330                // all slots should have following items:
331                // [start slot range, end slot range, master's IP, replica1's IP, replica2's IP,... ]
332                if slots.iter().all(|slot| slot.len() >= 3 + replicas as usize) {
333                    continue 'server;
334                }
335
336                sleep(Duration::from_millis(100));
337            }
338
339            panic!("failed to create enough replicas");
340        }
341    }
342
343    pub fn stop(&mut self) {
344        for server in &mut self.servers {
345            server.stop();
346        }
347    }
348
349    pub fn iter_servers(&self) -> impl Iterator<Item = &RedisServer> {
350        self.servers.iter()
351    }
352}
353
354fn wait_for_status_ok(cluster: &RedisCluster) {
355    'server: for server in &cluster.servers {
356        let log_file = RedisServer::log_file(&server.tempdir);
357
358        for _ in 1..500 {
359            let contents =
360                std::fs::read_to_string(&log_file).expect("Should have been able to read the file");
361
362            if contents.contains("Cluster state changed: ok") {
363                continue 'server;
364            }
365            sleep(Duration::from_millis(20));
366        }
367        panic!("failed to reach state change: OK");
368    }
369}
370
371impl Drop for RedisCluster {
372    fn drop(&mut self) {
373        self.stop()
374    }
375}