redis_test/
cluster.rs

1use std::{env, process, thread::sleep, time::Duration};
2
3use tempfile::TempDir;
4
5use crate::{
6    server::{Module, RedisServer},
7    utils::{TlsFilePaths, build_keys_and_certs_for_tls_ext, get_random_available_port},
8};
9
/// Settings describing the cluster that `RedisCluster::new` should start.
pub struct RedisClusterConfiguration {
    /// Total number of server processes to start.
    pub num_nodes: u16,
    /// Value passed to `redis-cli --cluster create` as `--cluster-replicas`;
    /// when `0` the flag is omitted.
    pub num_replicas: u16,
    /// Modules handed to every server when it is constructed.
    pub modules: Vec<Module>,
    /// Only consulted in TLS mode without mTLS: `true` runs `redis-cli` with
    /// `--insecure`, `false` passes the generated CA certificate via `--cacert`.
    pub tls_insecure: bool,
    /// Forwarded to every server; in TLS mode it also makes `redis-cli` present
    /// the generated certificate and key.
    pub mtls_enabled: bool,
    /// Explicit ports, one per node. Leave empty to pick random available ports.
    pub ports: Vec<u16>,
    /// Forwarded to `build_keys_and_certs_for_tls_ext`.
    /// NOTE(review): presumably adds IP subject-alternative names to the
    /// generated certificates — confirm in `utils`.
    pub certs_with_ip_alts: bool,
}
19
20impl RedisClusterConfiguration {
21    pub fn single_replica_config() -> Self {
22        Self {
23            num_nodes: 6,
24            num_replicas: 1,
25            ..Default::default()
26        }
27    }
28}
29
30impl Default for RedisClusterConfiguration {
31    fn default() -> Self {
32        Self {
33            num_nodes: 3,
34            num_replicas: 0,
35            modules: vec![],
36            tls_insecure: true,
37            mtls_enabled: false,
38            ports: vec![],
39            certs_with_ip_alts: true,
40        }
41    }
42}
43
/// Transport used by the cluster nodes started for a test run.
///
/// Equality between variants is total, so `Eq` is derived alongside
/// `PartialEq`; this lets the type be used wherever an `Eq` bound is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ClusterType {
    /// Plain TCP connections.
    Tcp,
    /// TCP connections wrapped in TLS.
    TcpTls,
}
50
51impl ClusterType {
52    pub fn get_intended() -> ClusterType {
53        match env::var("REDISRS_SERVER_TYPE")
54            .ok()
55            .as_ref()
56            .map(|x| &x[..])
57        {
58            Some("tcp") => ClusterType::Tcp,
59            Some("tcp+tls") => ClusterType::TcpTls,
60            Some(val) => {
61                panic!("Unknown server type {val:?}");
62            }
63            None => ClusterType::Tcp,
64        }
65    }
66
67    fn build_addr(port: u16) -> redis::ConnectionAddr {
68        match ClusterType::get_intended() {
69            ClusterType::Tcp => redis::ConnectionAddr::Tcp("127.0.0.1".into(), port),
70            ClusterType::TcpTls => redis::ConnectionAddr::TcpTls {
71                host: "127.0.0.1".into(),
72                port,
73                insecure: true,
74                tls_params: None,
75            },
76        }
77    }
78}
79
80fn port_in_use(addr: &str) -> bool {
81    let socket_addr: std::net::SocketAddr = addr.parse().expect("Invalid address");
82    let socket = socket2::Socket::new(
83        socket2::Domain::for_address(socket_addr),
84        socket2::Type::STREAM,
85        None,
86    )
87    .expect("Failed to create socket");
88
89    socket.connect(&socket_addr.into()).is_ok()
90}
91
/// A started Redis cluster together with the temporary resources created for it.
pub struct RedisCluster {
    /// One handle per started node, in the order the nodes were started.
    pub servers: Vec<RedisServer>,
    /// Temporary directories created while starting the cluster (per-node
    /// ACL/cluster-config directories and, in TLS mode, the shared key
    /// directory). Kept here so they are owned for as long as the cluster is.
    pub folders: Vec<TempDir>,
    /// Paths of the shared TLS key material; `None` unless the intended
    /// cluster type is `TcpTls`.
    pub tls_paths: Option<TlsFilePaths>,
}
97
98impl RedisCluster {
99    pub fn username() -> &'static str {
100        "hello"
101    }
102
103    pub fn password() -> &'static str {
104        "world"
105    }
106
107    pub fn new(configuration: RedisClusterConfiguration) -> RedisCluster {
108        let RedisClusterConfiguration {
109            num_nodes: nodes,
110            num_replicas: replicas,
111            modules,
112            tls_insecure,
113            mtls_enabled,
114            ports,
115            certs_with_ip_alts,
116        } = configuration;
117
118        let optional_ports = if ports.is_empty() {
119            vec![None; nodes as usize]
120        } else {
121            assert!(ports.len() == nodes as usize);
122            ports.into_iter().map(Some).collect()
123        };
124        let mut chosen_ports = std::collections::HashSet::new();
125
126        let mut folders = vec![];
127        let mut addrs = vec![];
128        let mut tls_paths = None;
129
130        let mut is_tls = false;
131
132        if let ClusterType::TcpTls = ClusterType::get_intended() {
133            // Create a shared set of keys in cluster mode
134            let tempdir = tempfile::Builder::new()
135                .prefix("redis")
136                .tempdir()
137                .expect("failed to create tempdir");
138            let files = build_keys_and_certs_for_tls_ext(&tempdir, certs_with_ip_alts);
139            folders.push(tempdir);
140            tls_paths = Some(files);
141            is_tls = true;
142        }
143
144        let max_attempts = 5;
145
146        let mut make_server = |port| {
147            RedisServer::new_with_addr_tls_modules_and_spawner(
148                ClusterType::build_addr(port),
149                None,
150                tls_paths.clone(),
151                mtls_enabled,
152                &modules,
153                |cmd| {
154                    let tempdir = tempfile::Builder::new()
155                        .prefix("redis")
156                        .tempdir()
157                        .expect("failed to create tempdir");
158                    let acl_path = tempdir.path().join("users.acl");
159                    let acl_content = format!(
160                        "user {} on allcommands allkeys >{}",
161                        Self::username(),
162                        Self::password()
163                    );
164                    std::fs::write(&acl_path, acl_content).expect("failed to write acl file");
165                    cmd.arg("--cluster-enabled")
166                        .arg("yes")
167                        .arg("--cluster-config-file")
168                        .arg(tempdir.path().join("nodes.conf"))
169                        .arg("--cluster-node-timeout")
170                        .arg("5000")
171                        .arg("--appendonly")
172                        .arg("yes")
173                        .arg("--aclfile")
174                        .arg(&acl_path);
175                    if is_tls {
176                        cmd.arg("--tls-cluster").arg("yes");
177                        if replicas > 0 {
178                            cmd.arg("--tls-replication").arg("yes");
179                        }
180                    }
181                    cmd.current_dir(tempdir.path());
182                    folders.push(tempdir);
183                    cmd.spawn().unwrap()
184                },
185            )
186        };
187
188        let verify_server = |server: &mut RedisServer| {
189            let process = &mut server.process;
190            match process.try_wait() {
191                Ok(Some(status)) => {
192                    let log_file_contents = server.log_file_contents();
193                    let err = format!(
194                        "redis server creation failed with status {status:?}.\nlog file: {log_file_contents:?}"
195                    );
196                    Err(err)
197                }
198                Ok(None) => {
199                    // wait for 10 seconds for the server to be available.
200                    let max_attempts = 200;
201                    let mut cur_attempts = 0;
202                    loop {
203                        if cur_attempts == max_attempts {
204                            let log_file_contents = server.log_file_contents();
205                            break Err(format!(
206                                "redis server creation failed: Address {} closed. {log_file_contents:?}",
207                                server.addr
208                            ));
209                        } else if port_in_use(&server.addr.to_string()) {
210                            break Ok(());
211                        }
212                        eprintln!("Waiting for redis process to initialize");
213                        sleep(Duration::from_millis(50));
214                        cur_attempts += 1;
215                    }
216                }
217                Err(e) => {
218                    panic!("Unexpected error in redis server creation {e}");
219                }
220            }
221        };
222
223        let servers = optional_ports
224            .into_iter()
225            .map(|port_option| {
226                for _ in 0..5 {
227                    let port = match port_option {
228                        Some(port) => port,
229                        None => loop {
230                            let port = get_random_available_port();
231                            if chosen_ports.contains(&port) {
232                                continue;
233                            }
234                            chosen_ports.insert(port);
235                            break port;
236                        },
237                    };
238                    let mut server = make_server(port);
239                    sleep(Duration::from_millis(50));
240
241                    match verify_server(&mut server) {
242                        Ok(_) => {
243                            let addr = format!("127.0.0.1:{port}");
244                            addrs.push(addr.clone());
245                            return server;
246                        }
247                        Err(err) => eprintln!("{err}"),
248                    }
249                }
250                panic!("Exhausted retries");
251            })
252            .collect();
253
254        let mut cmd = process::Command::new("redis-cli");
255        cmd.stdout(process::Stdio::piped())
256            .arg("--cluster")
257            .arg("create")
258            .args(&addrs);
259        if replicas > 0 {
260            cmd.arg("--cluster-replicas").arg(replicas.to_string());
261        }
262        cmd.arg("--cluster-yes");
263
264        if is_tls {
265            if mtls_enabled {
266                if let Some(TlsFilePaths {
267                    redis_crt,
268                    redis_key,
269                    ca_crt,
270                }) = &tls_paths
271                {
272                    cmd.arg("--cert");
273                    cmd.arg(redis_crt);
274                    cmd.arg("--key");
275                    cmd.arg(redis_key);
276                    cmd.arg("--cacert");
277                    cmd.arg(ca_crt);
278                    cmd.arg("--tls");
279                }
280            } else if !tls_insecure && tls_paths.is_some() {
281                let ca_crt = &tls_paths.as_ref().unwrap().ca_crt;
282                cmd.arg("--tls").arg("--cacert").arg(ca_crt);
283            } else {
284                cmd.arg("--tls").arg("--insecure");
285            }
286        }
287
288        let mut cur_attempts = 0;
289        loop {
290            let output = cmd.output().unwrap();
291            if output.status.success() {
292                break;
293            } else {
294                let err = format!("Cluster creation failed: {output:?}");
295                if cur_attempts == max_attempts {
296                    panic!("{err}");
297                }
298                eprintln!("Retrying: {err}");
299                sleep(Duration::from_millis(50));
300                cur_attempts += 1;
301            }
302        }
303
304        let cluster = RedisCluster {
305            servers,
306            folders,
307            tls_paths,
308        };
309        if replicas > 0 {
310            cluster.wait_for_replicas(replicas);
311        }
312
313        wait_for_status_ok(&cluster);
314        cluster
315    }
316
317    fn wait_for_replicas(&self, replicas: u16) {
318        'server: for server in &self.servers {
319            let conn_info = server.connection_info();
320            eprintln!(
321                "waiting until {:?} knows required number of replicas",
322                conn_info.addr()
323            );
324
325            let client = redis::Client::open(server.connection_info()).unwrap();
326
327            let mut con = client.get_connection().unwrap();
328
329            // retry 500 times
330            for _ in 1..500 {
331                let value = redis::cmd("CLUSTER").arg("SLOTS").query(&mut con).unwrap();
332                let slots: Vec<Vec<redis::Value>> = redis::from_redis_value(value).unwrap();
333
334                // all slots should have following items:
335                // [start slot range, end slot range, master's IP, replica1's IP, replica2's IP,... ]
336                if slots.iter().all(|slot| slot.len() >= 3 + replicas as usize) {
337                    continue 'server;
338                }
339
340                sleep(Duration::from_millis(100));
341            }
342
343            panic!("failed to create enough replicas");
344        }
345    }
346
347    pub fn stop(&mut self) {
348        for server in &mut self.servers {
349            server.stop();
350        }
351    }
352
353    pub fn iter_servers(&self) -> impl Iterator<Item = &RedisServer> {
354        self.servers.iter()
355    }
356}
357
358fn wait_for_status_ok(cluster: &RedisCluster) {
359    'server: for server in &cluster.servers {
360        let log_file = RedisServer::log_file(&server.tempdir);
361
362        for _ in 1..500 {
363            let contents =
364                std::fs::read_to_string(&log_file).expect("Should have been able to read the file");
365
366            if contents.contains("Cluster state changed: ok") {
367                continue 'server;
368            }
369            sleep(Duration::from_millis(20));
370        }
371        panic!("failed to reach state change: OK");
372    }
373}
374
375impl Drop for RedisCluster {
376    fn drop(&mut self) {
377        self.stop()
378    }
379}