Skip to main content

redis_test/
cluster.rs

1use std::{env, process, thread::sleep, time::Duration};
2
3use tempfile::TempDir;
4
5use crate::{
6    server::{Module, RedisServer},
7    utils::{TlsFilePaths, build_keys_and_certs_for_tls_ext, get_random_available_port},
8};
9
/// Options consumed by `RedisCluster::new`.
///
/// `Default` yields 3 nodes with no replicas; `single_replica_config` yields 6 nodes
/// with one replica.
pub struct RedisClusterConfiguration {
    /// Number of `redis-server` processes to start. When `ports` is non-empty it must
    /// contain exactly this many entries.
    pub num_nodes: u16,
    /// When non-zero, passed to `redis-cli --cluster create` as `--cluster-replicas`; in
    /// TLS mode a non-zero value also adds `--tls-replication yes` to every node.
    pub num_replicas: u16,
    /// Modules handed to every node via `RedisServer::new_with_addr_tls_modules_and_spawner`
    /// (presumably loaded at start-up — confirm in `server`).
    pub modules: Vec<Module>,
    /// In TLS mode without mTLS: `true` makes `redis-cli` connect with `--insecure`,
    /// `false` makes it pass `--cacert` with the generated CA certificate.
    pub tls_insecure: bool,
    /// Forwarded to every node, and in TLS mode makes `redis-cli` present the generated
    /// certificate/key (`--cert`/`--key`/`--cacert`). Presumably enables mutual TLS on the
    /// servers — confirm in `server`.
    pub mtls_enabled: bool,
    /// Explicit ports, one per node. Leave empty to have random available ports chosen.
    pub ports: Vec<u16>,
    /// Forwarded to `build_keys_and_certs_for_tls_ext` when generating the shared TLS files
    /// (NOTE(review): presumably adds IP-address alternative names to the certs — confirm).
    pub certs_with_ip_alts: bool,
}
19
20impl RedisClusterConfiguration {
21    pub fn single_replica_config() -> Self {
22        Self {
23            num_nodes: 6,
24            num_replicas: 1,
25            ..Default::default()
26        }
27    }
28}
29
30impl Default for RedisClusterConfiguration {
31    fn default() -> Self {
32        Self {
33            num_nodes: 3,
34            num_replicas: 0,
35            modules: vec![],
36            tls_insecure: true,
37            mtls_enabled: false,
38            ports: vec![],
39            certs_with_ip_alts: true,
40        }
41    }
42}
43
/// Transport used to reach the test cluster, selected through the `REDISRS_SERVER_TYPE`
/// environment variable (see `ClusterType::get_intended`).
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the type can be used as a key in
/// sets and maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ClusterType {
    /// Plain TCP (`REDISRS_SERVER_TYPE=tcp`, or the variable unset).
    Tcp,
    /// TCP with TLS (`REDISRS_SERVER_TYPE=tcp+tls`).
    TcpTls,
}
50
51impl ClusterType {
52    pub fn get_intended() -> ClusterType {
53        match env::var("REDISRS_SERVER_TYPE")
54            .ok()
55            .as_ref()
56            .map(|x| &x[..])
57        {
58            Some("tcp") => ClusterType::Tcp,
59            Some("tcp+tls") => ClusterType::TcpTls,
60            Some(val) => {
61                panic!("Unknown server type {val:?}");
62            }
63            None => ClusterType::Tcp,
64        }
65    }
66
67    fn build_addr(port: u16) -> redis::ConnectionAddr {
68        match ClusterType::get_intended() {
69            ClusterType::Tcp => redis::ConnectionAddr::Tcp("127.0.0.1".into(), port),
70            ClusterType::TcpTls => redis::ConnectionAddr::TcpTls {
71                host: "127.0.0.1".into(),
72                port,
73                insecure: true,
74                tls_params: None,
75            },
76        }
77    }
78}
79
/// Returns `true` when a TCP connection to `addr` succeeds, i.e. something is currently
/// accepting connections there.
///
/// Any connection failure — refused, timed out, or unable to create a socket — is
/// reported as `false`.
///
/// # Panics
///
/// Panics if `addr` is not a valid `ip:port` socket address.
fn port_in_use(addr: &str) -> bool {
    // Bound the probe so a stalled handshake (e.g. a full accept backlog) cannot block a
    // caller that polls this in a loop; a timeout simply counts as "not accepting yet".
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(1);

    let socket_addr: std::net::SocketAddr = addr
        .parse()
        .unwrap_or_else(|err| panic!("Invalid address {addr:?}: {err}"));
    std::net::TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT).is_ok()
}
91
/// A running local Redis cluster created by `RedisCluster::new`; its `Drop` impl calls
/// `stop` on every server.
///
/// Field order matters: Rust drops fields in declaration order, so `servers` are dropped
/// before `folders`, whose temporary directories `new` uses as the nodes' working
/// directories.
pub struct RedisCluster {
    /// One handle per cluster node, in the order the nodes were started.
    pub servers: Vec<RedisServer>,
    /// Temporary directories kept alive for the cluster's lifetime: each node's working
    /// directory (holding `nodes.conf` and `users.acl`) and, in TLS mode, the directory
    /// holding the shared key/certificate files.
    pub folders: Vec<TempDir>,
    /// Shared TLS key/certificate paths; `new` sets this only when
    /// `REDISRS_SERVER_TYPE=tcp+tls`.
    pub tls_paths: Option<TlsFilePaths>,
}
97
98impl RedisCluster {
99    pub fn username() -> &'static str {
100        "hello"
101    }
102
103    pub fn password() -> &'static str {
104        "world"
105    }
106
107    pub fn new(configuration: RedisClusterConfiguration) -> RedisCluster {
108        let RedisClusterConfiguration {
109            num_nodes: nodes,
110            num_replicas: replicas,
111            modules,
112            tls_insecure,
113            mtls_enabled,
114            ports,
115            certs_with_ip_alts,
116        } = configuration;
117
118        let optional_ports = if ports.is_empty() {
119            vec![None; nodes as usize]
120        } else {
121            assert!(ports.len() == nodes as usize);
122            ports.into_iter().map(Some).collect()
123        };
124        let mut chosen_ports = std::collections::HashSet::new();
125
126        let mut folders = vec![];
127        let mut addrs = vec![];
128        let mut tls_paths = None;
129
130        let mut is_tls = false;
131
132        if let ClusterType::TcpTls = ClusterType::get_intended() {
133            // Create a shared set of keys in cluster mode
134            let tempdir = tempfile::Builder::new()
135                .prefix("redis")
136                .tempdir()
137                .expect("failed to create tempdir");
138            let files = build_keys_and_certs_for_tls_ext(&tempdir, certs_with_ip_alts);
139            folders.push(tempdir);
140            tls_paths = Some(files);
141            is_tls = true;
142        }
143
144        let max_attempts = 5;
145
146        let mut make_server = |port| {
147            RedisServer::new_with_addr_tls_modules_and_spawner(
148                ClusterType::build_addr(port),
149                None,
150                tls_paths.clone(),
151                mtls_enabled,
152                &modules,
153                |cmd| {
154                    let tempdir = tempfile::Builder::new()
155                        .prefix("redis")
156                        .tempdir()
157                        .expect("failed to create tempdir");
158                    let acl_path = tempdir.path().join("users.acl");
159                    let acl_content = format!(
160                        "user {} on allcommands allkeys >{}",
161                        Self::username(),
162                        Self::password()
163                    );
164                    std::fs::write(&acl_path, acl_content).expect("failed to write acl file");
165                    cmd.arg("--cluster-enabled")
166                        .arg("yes")
167                        .arg("--cluster-config-file")
168                        .arg(tempdir.path().join("nodes.conf"))
169                        .arg("--cluster-node-timeout")
170                        .arg("5000")
171                        .arg("--aclfile")
172                        .arg(&acl_path);
173                    if is_tls {
174                        cmd.arg("--tls-cluster").arg("yes");
175                        if replicas > 0 {
176                            cmd.arg("--tls-replication").arg("yes");
177                        }
178                    }
179                    cmd.current_dir(tempdir.path());
180                    folders.push(tempdir);
181                    cmd.spawn().unwrap()
182                },
183            )
184        };
185
186        let verify_server = |server: &mut RedisServer| {
187            let process = &mut server.process;
188            match process.try_wait() {
189                Ok(Some(status)) => {
190                    let log_file_contents = server.log_file_contents();
191                    let err = format!(
192                        "redis server creation failed with status {status:?}.\nlog file: {log_file_contents:?}"
193                    );
194                    Err(err)
195                }
196                Ok(None) => {
197                    // wait for 10 seconds for the server to be available.
198                    let max_attempts = 200;
199                    let mut cur_attempts = 0;
200                    loop {
201                        if cur_attempts == max_attempts {
202                            let log_file_contents = server.log_file_contents();
203                            break Err(format!(
204                                "redis server creation failed: Address {} closed. {log_file_contents:?}",
205                                server.addr
206                            ));
207                        } else if port_in_use(&server.addr.to_string()) {
208                            break Ok(());
209                        }
210                        eprintln!("Waiting for redis process to initialize");
211                        sleep(Duration::from_millis(50));
212                        cur_attempts += 1;
213                    }
214                }
215                Err(e) => {
216                    panic!("Unexpected error in redis server creation {e}");
217                }
218            }
219        };
220
221        let servers = optional_ports
222            .into_iter()
223            .map(|port_option| {
224                for _ in 0..5 {
225                    let port = match port_option {
226                        Some(port) => port,
227                        None => loop {
228                            let port = get_random_available_port();
229                            if chosen_ports.contains(&port) {
230                                continue;
231                            }
232                            chosen_ports.insert(port);
233                            break port;
234                        },
235                    };
236                    let mut server = make_server(port);
237                    sleep(Duration::from_millis(50));
238
239                    match verify_server(&mut server) {
240                        Ok(_) => {
241                            let addr = format!("127.0.0.1:{port}");
242                            addrs.push(addr.clone());
243                            return server;
244                        }
245                        Err(err) => eprintln!("{err}"),
246                    }
247                }
248                panic!("Exhausted retries");
249            })
250            .collect();
251
252        let mut cmd = process::Command::new("redis-cli");
253        cmd.stdout(process::Stdio::piped())
254            .arg("--cluster")
255            .arg("create")
256            .args(&addrs);
257        if replicas > 0 {
258            cmd.arg("--cluster-replicas").arg(replicas.to_string());
259        }
260        cmd.arg("--cluster-yes");
261
262        if is_tls {
263            if mtls_enabled {
264                if let Some(TlsFilePaths {
265                    redis_crt,
266                    redis_key,
267                    ca_crt,
268                }) = &tls_paths
269                {
270                    cmd.arg("--cert");
271                    cmd.arg(redis_crt);
272                    cmd.arg("--key");
273                    cmd.arg(redis_key);
274                    cmd.arg("--cacert");
275                    cmd.arg(ca_crt);
276                    cmd.arg("--tls");
277                }
278            } else if !tls_insecure && tls_paths.is_some() {
279                let ca_crt = &tls_paths.as_ref().unwrap().ca_crt;
280                cmd.arg("--tls").arg("--cacert").arg(ca_crt);
281            } else {
282                cmd.arg("--tls").arg("--insecure");
283            }
284        }
285
286        let mut cur_attempts = 0;
287        loop {
288            let output = cmd.output().unwrap();
289            if output.status.success() {
290                break;
291            } else {
292                let err = format!("Cluster creation failed: {output:?}");
293                if cur_attempts == max_attempts {
294                    panic!("{err}");
295                }
296                eprintln!("Retrying: {err}");
297                sleep(Duration::from_millis(50));
298                cur_attempts += 1;
299            }
300        }
301
302        let cluster = RedisCluster {
303            servers,
304            folders,
305            tls_paths,
306        };
307        if replicas > 0 {
308            cluster.wait_for_replicas(replicas);
309        }
310
311        wait_for_status_ok(&cluster);
312        cluster
313    }
314
315    fn wait_for_replicas(&self, replicas: u16) {
316        'server: for server in &self.servers {
317            let conn_info = server.connection_info();
318            eprintln!(
319                "waiting until {:?} knows required number of replicas",
320                conn_info.addr()
321            );
322
323            let client = redis::Client::open(server.connection_info()).unwrap();
324
325            let mut con = client.get_connection().unwrap();
326
327            // retry 500 times
328            for _ in 1..500 {
329                let value = redis::cmd("CLUSTER").arg("SLOTS").query(&mut con).unwrap();
330                let slots: Vec<Vec<redis::Value>> = redis::from_redis_value(value).unwrap();
331
332                // all slots should have following items:
333                // [start slot range, end slot range, master's IP, replica1's IP, replica2's IP,... ]
334                if slots.iter().all(|slot| slot.len() >= 3 + replicas as usize) {
335                    continue 'server;
336                }
337
338                sleep(Duration::from_millis(100));
339            }
340
341            panic!("failed to create enough replicas");
342        }
343    }
344
345    pub fn stop(&mut self) {
346        for server in &mut self.servers {
347            server.stop();
348        }
349    }
350
351    pub fn iter_servers(&self) -> impl Iterator<Item = &RedisServer> {
352        self.servers.iter()
353    }
354}
355
356fn wait_for_status_ok(cluster: &RedisCluster) {
357    'server: for server in &cluster.servers {
358        let log_file = RedisServer::log_file(&server.tempdir);
359
360        for _ in 1..500 {
361            let contents =
362                std::fs::read_to_string(&log_file).expect("Should have been able to read the file");
363
364            if contents.contains("Cluster state changed: ok") {
365                continue 'server;
366            }
367            sleep(Duration::from_millis(20));
368        }
369        panic!("failed to reach state change: OK");
370    }
371}
372
373impl Drop for RedisCluster {
374    fn drop(&mut self) {
375        self.stop()
376    }
377}