/// Lua script executed on every poll of the barrier.
///
/// Arguments (`ARGV`):
/// 1. key prefix of the barrier,
/// 2. id of the polling participant,
/// 3. number of participants the barrier waits for,
/// 4. time-to-live, in seconds, of the keys the script writes.
///
/// Returns `0` while fewer than `ARGV[3]` participants are waiting, `2` when
/// the polling participant is (or has just become) the leader, and `1` when
/// the barrier has been released with another participant as leader.
const WAITING_SCRIPT: &str = r#"
local prefix = ARGV[1]
local self_id = ARGV[2]
local expected = tonumber(ARGV[3])
local ttl = ARGV[4]

-- Register (or refresh) this participant; the key expires by itself if the
-- participant stops polling.
redis.call("set", prefix .. ":waiting:" .. self_id, 1, "EX", ttl)

-- A leader key means the barrier has already been released.
local leader_id = redis.call("get", prefix .. ":leader")
if leader_id then
    if leader_id == self_id then
        return 2
    end
    return 1
end

-- Count the distinct waiting participants. SCAN is allowed to return the same
-- key more than once, so duplicates are filtered out before counting.
local seen = {}
local count = 0
local cursor = "0"

repeat
    local res = redis.call("scan", cursor, "MATCH", prefix .. ":waiting:*", "COUNT", expected + 1)
    for _, key in ipairs(res[2]) do
        if not seen[key] then
            seen[key] = true
            count = count + 1
        end
    end
    cursor = res[1]
until cursor == "0"

if count < expected then
    return 0
end

-- No leader was found above, so try to claim leadership; NX guarantees only
-- one participant succeeds.
if redis.call("set", prefix .. ":leader", self_id, "EX", ttl, "NX") then
    return 2
end

return 1
"#;
46
/// Lua script executed when a participant leaves the barrier.
///
/// Arguments (`ARGV`):
/// 1. key prefix of the barrier,
/// 2. id of the leaving participant,
/// 3. number of participants (only used, plus one, as the SCAN `COUNT` hint).
///
/// The script deletes the participant's `:waiting:<id>` key, counts the
/// `:waiting:*` keys that remain under the prefix and, when none are left,
/// also deletes the `:leader` and `:uuids` keys. It has no `return` statement.
const RESET_SCRIPT: &str = r#"
redis.call("del", ARGV[1] .. ":waiting:" .. ARGV[2])

local count = 0
local cursor = "0"

repeat
 local res = redis.call("scan", cursor, "MATCH", ARGV[1] .. ":waiting:*", "COUNT", ARGV[3] + 1)
 if next(res[2]) ~= nil then
 count = count + #res[2]
 end
 cursor = res[1]
until cursor == "0"

-- if it is the last barrier, delete the leader and uuids key
if count == 0 then
 redis.call("del", ARGV[1] .. ":leader")
 redis.call("del", ARGV[1] .. ":uuids")
end
"#;
75
/// Lua script that hands out the next participant id for a barrier.
///
/// `ARGV[1]` is the key prefix of the barrier. The script increments the
/// `<prefix>:uuids` counter and returns the incremented value; `INCR` already
/// replies with the new value, so no separate read is needed.
const UUID_SCRIPT: &str = r#"
return redis.call("incr", ARGV[1] .. ":uuids")
"#;
87
/// A barrier whose state lives in Redis under a common key prefix.
///
/// Each value represents one participant: it holds the id allocated for it at
/// construction time and the connection used to run the barrier scripts.
pub struct Barrier {
    /// Participant id obtained from `UUID_SCRIPT` when the barrier was created.
    uuid: usize,
    /// Number of participants passed to the scripts as the release threshold.
    num: usize,
    /// Prefix of every Redis key the scripts touch for this barrier.
    key: String,
    /// Client the connection was opened from; stored but not read in the
    /// visible code.
    _client: redis::Client,
    /// Connection the scripts are invoked on. Kept in an `Option` so it can be
    /// moved out with `take()`.
    conn: Option<redis::Connection>,
}
95
/// Decoded result of one barrier poll.
///
/// The numeric codes are the integers accepted by the `From<u8>` conversion
/// below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RedisBarrierStatus {
    /// Code `0`: the barrier has not been released yet.
    Waiting,
    /// Code `2`: the barrier has been released and this participant is the leader.
    Leader,
    /// Code `1`: the barrier has been released and this participant is not the leader.
    Done,
}

impl From<u8> for RedisBarrierStatus {
    /// Converts a raw status code into a [`RedisBarrierStatus`].
    ///
    /// # Panics
    ///
    /// Panics if `val` is not `0`, `1` or `2`.
    fn from(val: u8) -> Self {
        match val {
            0 => RedisBarrierStatus::Waiting,
            1 => RedisBarrierStatus::Done,
            2 => RedisBarrierStatus::Leader,
            // Include the offending code so a protocol mismatch can be diagnosed.
            _ => panic!("Invalid RedisBarrierStatus: {}", val),
        }
    }
}
113
/// Value returned by `Barrier::wait` once the barrier has been released.
///
/// Wraps a single flag telling whether the participant was reported as the
/// leader.
#[derive(Debug)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Returns `true` if this participant was reported as the leader.
    pub fn is_leader(&self) -> bool {
        self.0
    }
}
149
/// Error type for barrier operations.
///
/// NOTE(review): nothing in the visible code constructs or returns this type;
/// `Barrier::new` and `Barrier::wait` panic on Redis failures instead.
/// Confirm whether it is used elsewhere or is a leftover.
enum BarrierError {
    /// A failure reported by the `redis` crate.
    RedisError(redis::RedisError),
}
153
154impl Barrier {
155 pub fn new(num: usize, key: &str, client: redis::Client) -> Self {
156 let mut conn = client.get_connection().unwrap();
157
158 let uuid = redis::Script::new(UUID_SCRIPT)
159 .arg(&key)
160 .arg(&num)
161 .invoke::<usize>(&mut conn)
162 .expect("Failed to create barrier");
163
164 Barrier {
165 uuid: uuid,
166 num,
167 key: key.to_string(),
168 _client: client,
169 conn: Some(conn),
170 }
171 }
172
173 pub fn wait(&mut self) -> BarrierWaitResult {
211 let mut conn = self.conn.take().unwrap();
212 let timeout = 2;
213
214 let mut status = RedisBarrierStatus::Waiting;
215 while status == RedisBarrierStatus::Waiting {
216 status = redis::Script::new(WAITING_SCRIPT)
217 .arg(&self.key)
218 .arg(self.uuid)
219 .arg(self.num)
220 .arg(timeout)
221 .invoke::<u8>(&mut conn)
222 .expect("Failed to wait for barrier")
223 .into();
224 }
225 self.conn = Some(conn);
226
227 if status == RedisBarrierStatus::Leader {
228 BarrierWaitResult(true)
229 } else {
230 BarrierWaitResult(false)
231 }
232 }
233}
234
235impl Drop for Barrier {
236 fn drop(&mut self) {
237 let mut conn = self.conn.take().unwrap();
238 redis::Script::new(RESET_SCRIPT)
239 .arg(&self.key)
240 .arg(self.uuid)
241 .arg(self.num)
242 .invoke::<()>(&mut conn)
243 .expect("Failed to reset barrier");
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Opens a client for the local Redis instance the tests run against.
    fn local_client() -> redis::Client {
        redis::Client::open("redis://localhost:6379").unwrap()
    }

    /// A barrier of size one releases immediately with its only participant
    /// as leader.
    #[test]
    fn test_barrier_leader() {
        let mut solo = Barrier::new(1, "barrier_test_leader", local_client());
        assert!(solo.wait().is_leader());
    }

    /// With two participants, the one that arrives last is the leader and the
    /// one that was already waiting is not.
    #[test]
    fn test_barrier_not_leader() {
        let client = local_client();

        let mut late_barrier = Barrier::new(2, "barrier_test_notleader", client.clone());

        let early = thread::spawn(move || {
            let mut early_barrier = Barrier::new(2, "barrier_test_notleader", client);
            let outcome = early_barrier.wait();
            assert!(!outcome.is_leader());
        });

        let late = thread::spawn(move || {
            thread::sleep(Duration::from_secs(1));
            let outcome = late_barrier.wait();
            assert!(outcome.is_leader());
        });

        early.join().unwrap();
        late.join().unwrap();
    }

    /// Ten participants share one barrier; exactly one of them is the leader.
    #[test]
    fn test_barrier_slow_check() {
        let participants = 10;
        let client = local_client();

        // Spawn every participant before joining any of them.
        let workers: Vec<_> = (0..participants)
            .map(|_| {
                let mut barrier = Barrier::new(participants, "barrier_doc_test2", client.clone());
                thread::spawn(move || barrier.wait().is_leader())
            })
            .collect();

        let leaders = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .filter(|&is_leader| is_leader)
            .count();
        assert_eq!(leaders, 1);
    }

    /// A second barrier on the same key can be created and waited on while
    /// the first one is still alive.
    #[test]
    fn test_barrier_reuse() {
        let client = local_client();

        let mut first = Barrier::new(1, "barrier_test_reuse", client.clone());
        first.wait();
        let mut second = Barrier::new(1, "barrier_test_reuse", client.clone());
        second.wait();
    }
}