// redis_test/cluster.rs

use std::{collections::HashSet, env, process, thread::sleep, time::Duration};

use tempfile::TempDir;

use crate::{
    server::{Module, RedisServer},
    utils::{build_keys_and_certs_for_tls_ext, get_random_available_port, TlsFilePaths},
};

10pub struct RedisClusterConfiguration {
11    pub num_nodes: u16,
12    pub num_replicas: u16,
13    pub modules: Vec<Module>,
14    pub tls_insecure: bool,
15    pub mtls_enabled: bool,
16    pub ports: Vec<u16>,
17    pub certs_with_ip_alts: bool,
18}
19
20impl RedisClusterConfiguration {
21    pub fn single_replica_config() -> Self {
22        Self {
23            num_nodes: 6,
24            num_replicas: 1,
25            ..Default::default()
26        }
27    }
28}
29
30impl Default for RedisClusterConfiguration {
31    fn default() -> Self {
32        Self {
33            num_nodes: 3,
34            num_replicas: 0,
35            modules: vec![],
36            tls_insecure: true,
37            mtls_enabled: false,
38            ports: vec![],
39            certs_with_ip_alts: true,
40        }
41    }
42}
43
44#[derive(Debug, Clone, Copy, PartialEq)]
45pub enum ClusterType {
46    Tcp,
47    TcpTls,
48}
49
50impl ClusterType {
51    pub fn get_intended() -> ClusterType {
52        match env::var("REDISRS_SERVER_TYPE")
53            .ok()
54            .as_ref()
55            .map(|x| &x[..])
56        {
57            Some("tcp") => ClusterType::Tcp,
58            Some("tcp+tls") => ClusterType::TcpTls,
59            Some(val) => {
60                panic!("Unknown server type {val:?}");
61            }
62            None => ClusterType::Tcp,
63        }
64    }
65
66    fn build_addr(port: u16) -> redis::ConnectionAddr {
67        match ClusterType::get_intended() {
68            ClusterType::Tcp => redis::ConnectionAddr::Tcp("127.0.0.1".into(), port),
69            ClusterType::TcpTls => redis::ConnectionAddr::TcpTls {
70                host: "127.0.0.1".into(),
71                port,
72                insecure: true,
73                tls_params: None,
74            },
75        }
76    }
77}
78
/// Returns `true` if a TCP connection to `addr` (e.g. `"127.0.0.1:6379"`)
/// succeeds, i.e. something is already listening there.
///
/// The probe connection is closed immediately when it is dropped.
///
/// # Panics
///
/// Panics with `"Invalid address"` if `addr` is not a valid socket address.
fn port_in_use(addr: &str) -> bool {
    let socket_addr: std::net::SocketAddr = addr.parse().expect("Invalid address");
    // A plain blocking connect is all that is needed; std covers it without socket2.
    std::net::TcpStream::connect(socket_addr).is_ok()
}

91pub struct RedisCluster {
92    pub servers: Vec<RedisServer>,
93    pub folders: Vec<TempDir>,
94    pub tls_paths: Option<TlsFilePaths>,
95}
96
97impl RedisCluster {
98    pub fn username() -> &'static str {
99        "hello"
100    }
101
102    pub fn password() -> &'static str {
103        "world"
104    }
105
106    pub fn new(configuration: RedisClusterConfiguration) -> RedisCluster {
107        let RedisClusterConfiguration {
108            num_nodes: nodes,
109            num_replicas: replicas,
110            modules,
111            tls_insecure,
112            mtls_enabled,
113            ports,
114            certs_with_ip_alts,
115        } = configuration;
116
117        let optional_ports = if ports.is_empty() {
118            vec![None; nodes as usize]
119        } else {
120            assert!(ports.len() == nodes as usize);
121            ports.into_iter().map(Some).collect()
122        };
123        let mut chosen_ports = std::collections::HashSet::new();
124
125        let mut folders = vec![];
126        let mut addrs = vec![];
127        let mut tls_paths = None;
128
129        let mut is_tls = false;
130
131        if let ClusterType::TcpTls = ClusterType::get_intended() {
132            // Create a shared set of keys in cluster mode
133            let tempdir = tempfile::Builder::new()
134                .prefix("redis")
135                .tempdir()
136                .expect("failed to create tempdir");
137            let files = build_keys_and_certs_for_tls_ext(&tempdir, certs_with_ip_alts);
138            folders.push(tempdir);
139            tls_paths = Some(files);
140            is_tls = true;
141        }
142
143        let max_attempts = 5;
144
145        let mut make_server = |port| {
146            RedisServer::new_with_addr_tls_modules_and_spawner(
147                ClusterType::build_addr(port),
148                None,
149                tls_paths.clone(),
150                mtls_enabled,
151                &modules,
152                |cmd| {
153                    let tempdir = tempfile::Builder::new()
154                        .prefix("redis")
155                        .tempdir()
156                        .expect("failed to create tempdir");
157                    let acl_path = tempdir.path().join("users.acl");
158                    let acl_content = format!(
159                        "user {} on allcommands allkeys >{}",
160                        Self::username(),
161                        Self::password()
162                    );
163                    std::fs::write(&acl_path, acl_content).expect("failed to write acl file");
164                    cmd.arg("--cluster-enabled")
165                        .arg("yes")
166                        .arg("--cluster-config-file")
167                        .arg(tempdir.path().join("nodes.conf"))
168                        .arg("--cluster-node-timeout")
169                        .arg("5000")
170                        .arg("--appendonly")
171                        .arg("yes")
172                        .arg("--aclfile")
173                        .arg(&acl_path);
174                    if is_tls {
175                        cmd.arg("--tls-cluster").arg("yes");
176                        if replicas > 0 {
177                            cmd.arg("--tls-replication").arg("yes");
178                        }
179                    }
180                    cmd.current_dir(tempdir.path());
181                    folders.push(tempdir);
182                    cmd.spawn().unwrap()
183                },
184            )
185        };
186
187        let verify_server = |server: &mut RedisServer| {
188            let process = &mut server.process;
189            match process.try_wait() {
190                Ok(Some(status)) => {
191                    let log_file_contents = server.log_file_contents();
192                    let err =
193                                    format!("redis server creation failed with status {status:?}.\nlog file: {log_file_contents:?}");
194                    Err(err)
195                }
196                Ok(None) => {
197                    // wait for 10 seconds for the server to be available.
198                    let max_attempts = 200;
199                    let mut cur_attempts = 0;
200                    loop {
201                        if cur_attempts == max_attempts {
202                            let log_file_contents = server.log_file_contents();
203                            break Err(format!("redis server creation failed: Address {} closed. {log_file_contents:?}", server.addr));
204                        } else if port_in_use(&server.addr.to_string()) {
205                            break Ok(());
206                        }
207                        eprintln!("Waiting for redis process to initialize");
208                        sleep(Duration::from_millis(50));
209                        cur_attempts += 1;
210                    }
211                }
212                Err(e) => {
213                    panic!("Unexpected error in redis server creation {e}");
214                }
215            }
216        };
217
218        let servers = optional_ports
219            .into_iter()
220            .map(|port_option| {
221                for _ in 0..5 {
222                    let port = match port_option {
223                        Some(port) => port,
224                        None => loop {
225                            let port = get_random_available_port();
226                            if chosen_ports.contains(&port) {
227                                continue;
228                            }
229                            chosen_ports.insert(port);
230                            break port;
231                        },
232                    };
233                    let mut server = make_server(port);
234                    sleep(Duration::from_millis(50));
235
236                    match verify_server(&mut server) {
237                        Ok(_) => {
238                            let addr = format!("127.0.0.1:{port}");
239                            addrs.push(addr.clone());
240                            return server;
241                        }
242                        Err(err) => eprintln!("{err}"),
243                    }
244                }
245                panic!("Exhausted retries");
246            })
247            .collect();
248
249        let mut cmd = process::Command::new("redis-cli");
250        cmd.stdout(process::Stdio::piped())
251            .arg("--cluster")
252            .arg("create")
253            .args(&addrs);
254        if replicas > 0 {
255            cmd.arg("--cluster-replicas").arg(replicas.to_string());
256        }
257        cmd.arg("--cluster-yes");
258
259        if is_tls {
260            if mtls_enabled {
261                if let Some(TlsFilePaths {
262                    redis_crt,
263                    redis_key,
264                    ca_crt,
265                }) = &tls_paths
266                {
267                    cmd.arg("--cert");
268                    cmd.arg(redis_crt);
269                    cmd.arg("--key");
270                    cmd.arg(redis_key);
271                    cmd.arg("--cacert");
272                    cmd.arg(ca_crt);
273                    cmd.arg("--tls");
274                }
275            } else if !tls_insecure && tls_paths.is_some() {
276                let ca_crt = &tls_paths.as_ref().unwrap().ca_crt;
277                cmd.arg("--tls").arg("--cacert").arg(ca_crt);
278            } else {
279                cmd.arg("--tls").arg("--insecure");
280            }
281        }
282
283        let mut cur_attempts = 0;
284        loop {
285            let output = cmd.output().unwrap();
286            if output.status.success() {
287                break;
288            } else {
289                let err = format!("Cluster creation failed: {output:?}");
290                if cur_attempts == max_attempts {
291                    panic!("{err}");
292                }
293                eprintln!("Retrying: {err}");
294                sleep(Duration::from_millis(50));
295                cur_attempts += 1;
296            }
297        }
298
299        let cluster = RedisCluster {
300            servers,
301            folders,
302            tls_paths,
303        };
304        if replicas > 0 {
305            cluster.wait_for_replicas(replicas);
306        }
307
308        wait_for_status_ok(&cluster);
309        cluster
310    }
311
312    fn wait_for_replicas(&self, replicas: u16) {
313        'server: for server in &self.servers {
314            let conn_info = server.connection_info();
315            eprintln!(
316                "waiting until {:?} knows required number of replicas",
317                conn_info.addr
318            );
319
320            let client = redis::Client::open(server.connection_info()).unwrap();
321
322            let mut con = client.get_connection().unwrap();
323
324            // retry 500 times
325            for _ in 1..500 {
326                let value = redis::cmd("CLUSTER").arg("SLOTS").query(&mut con).unwrap();
327                let slots: Vec<Vec<redis::Value>> = redis::from_owned_redis_value(value).unwrap();
328
329                // all slots should have following items:
330                // [start slot range, end slot range, master's IP, replica1's IP, replica2's IP,... ]
331                if slots.iter().all(|slot| slot.len() >= 3 + replicas as usize) {
332                    continue 'server;
333                }
334
335                sleep(Duration::from_millis(100));
336            }
337
338            panic!("failed to create enough replicas");
339        }
340    }
341
342    pub fn stop(&mut self) {
343        for server in &mut self.servers {
344            server.stop();
345        }
346    }
347
348    pub fn iter_servers(&self) -> impl Iterator<Item = &RedisServer> {
349        self.servers.iter()
350    }
351}
352
353fn wait_for_status_ok(cluster: &RedisCluster) {
354    'server: for server in &cluster.servers {
355        let log_file = RedisServer::log_file(&server.tempdir);
356
357        for _ in 1..500 {
358            let contents =
359                std::fs::read_to_string(&log_file).expect("Should have been able to read the file");
360
361            if contents.contains("Cluster state changed: ok") {
362                continue 'server;
363            }
364            sleep(Duration::from_millis(20));
365        }
366        panic!("failed to reach state change: OK");
367    }
368}
369
370impl Drop for RedisCluster {
371    fn drop(&mut self) {
372        self.stop()
373    }
374}