dtypes/redis/
barrier.rs

/// The waiting script.
///
/// It registers the calling participant as waiting and reports the state of the barrier:
/// - returns `2` if the caller is (or just became) the leader,
/// - returns `1` if another participant is already the leader,
/// - returns `0` if fewer than the required number of participants are waiting.
///
/// It must be called in a loop: every call refreshes the expiration time of the caller's
/// waiting key, which is how a participant signals that it is still waiting.
///
/// `SCAN` may return the same key more than once, so the waiting keys are de-duplicated
/// before they are counted. The scan stops as soon as enough waiters have been found.
///
/// Takes 4 arguments:
/// 1. The key of the barrier.
/// 2. The id (uuid) of the calling participant.
/// 3. The number of participants that should wait for the barrier.
/// 4. The timeout (expiration) in seconds for the waiting and leader keys.
const WAITING_SCRIPT: &str = r#"
redis.call("set", ARGV[1] .. ":waiting:" .. ARGV[2], 1, "EX", ARGV[4])

local leader_id = redis.call("get", ARGV[1] .. ":leader")
if leader_id then
    if leader_id == ARGV[2] then
        return 2
    end
    return 1
end

local needed = tonumber(ARGV[3])
local seen = {}
local count = 0
local cursor = "0"

repeat
    local res = redis.call("scan", cursor, "MATCH", ARGV[1] .. ":waiting:*", "COUNT", needed + 1)
    for _, waiting_key in ipairs(res[2]) do
        -- SCAN may report a key more than once; count every waiter only once.
        if not seen[waiting_key] then
            seen[waiting_key] = true
            count = count + 1
        end
    end
    cursor = res[1]
until cursor == "0" or count >= needed

if count < needed then
    return 0
end

-- No leader existed above, so try to claim leadership atomically.
if redis.call("set", ARGV[1] .. ":leader", ARGV[2], "EX", ARGV[4], "NX") then
    return 2
end

return 1
"#;
46
/// The reset script.
///
/// It is used to reset the barrier so its key can be used again. It deletes the caller's
/// waiting key. If no other participant is waiting any more, it also deletes the leader key
/// and the uuid counter.
///
/// Takes 3 arguments:
/// 1. The key of the barrier.
/// 2. The uuid of the participant that is being reset.
/// 3. The number of participants that should wait for the barrier.
const RESET_SCRIPT: &str = r#"
redis.call("del", ARGV[1] .. ":waiting:" .. ARGV[2])

-- Only the existence of another waiting key matters, so stop at the first one found.
local remaining = false
local cursor = "0"

repeat
    local res = redis.call("scan", cursor, "MATCH", ARGV[1] .. ":waiting:*", "COUNT", ARGV[3] + 1)
    if #res[2] > 0 then
        remaining = true
    end
    cursor = res[1]
until cursor == "0" or remaining

-- if it is the last participant, delete the leader and uuids key
if not remaining then
    redis.call("del", ARGV[1] .. ":leader")
    redis.call("del", ARGV[1] .. ":uuids")
end
"#;
75
/// The uuid script.
///
/// It is used to generate an id for a participant that is unique per barrier key.
/// A counter stored under `<key>:uuids` is incremented atomically and its new value is
/// returned, so every number is handed out only once until the counter is deleted again.
///
/// Takes 1 argument:
/// 1. The key of the barrier.
const UUID_SCRIPT: &str = r#"
return redis.call("incr", ARGV[1] .. ":uuids")
"#;
87
88pub struct Barrier {
89    uuid: usize,
90    num: usize,
91    key: String,
92    _client: redis::Client,
93    conn: Option<redis::Connection>,
94}
95
/// State of a barrier as reported by the waiting script.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RedisBarrierStatus {
    /// Not enough participants are waiting yet (script result `0`).
    Waiting,
    /// The calling participant is the leader (script result `2`).
    Leader,
    /// Another participant is the leader, so the barrier is open (script result `1`).
    Done,
}

impl From<u8> for RedisBarrierStatus {
    /// Converts the integer returned by the waiting script.
    ///
    /// # Panics
    ///
    /// Panics if `val` is not `0`, `1` or `2`.
    fn from(val: u8) -> Self {
        match val {
            0 => RedisBarrierStatus::Waiting,
            1 => RedisBarrierStatus::Done,
            2 => RedisBarrierStatus::Leader,
            other => panic!("Invalid RedisBarrierStatus: {}", other),
        }
    }
}
113
/// A `BarrierWaitResult` is returned by [`Barrier::wait()`] when all systems
/// in the [`Barrier`] have rendezvoused.
///
/// # Examples
///
/// ```
/// use dtypes::redis::sync::Barrier;
///
/// let client = redis::Client::open("redis://localhost:6379").unwrap();
/// let mut barrier = Barrier::new(1, "barrier_doc_test", client);
/// let barrier_wait_result = barrier.wait();
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Returns `true` if this thread is the "leader thread" for the call to
    /// [`Barrier::wait()`].
    ///
    /// Only one thread will have `true` returned from their result, all other
    /// threads will have `false` returned.
    ///
    /// # Examples
    ///
    /// ```
    /// use dtypes::redis::sync::Barrier;
    ///
    /// let client = redis::Client::open("redis://localhost:6379").unwrap();
    /// let mut barrier = Barrier::new(1, "barrier_doc_test", client);
    /// let barrier_wait_result = barrier.wait();
    /// println!("{:?}", barrier_wait_result.is_leader());
    /// ```
    pub fn is_leader(&self) -> bool {
        self.0
    }
}
149
150enum BarrierError {
151    RedisError(redis::RedisError),
152}
153
154impl Barrier {
155    pub fn new(num: usize, key: &str, client: redis::Client) -> Self {
156        let mut conn = client.get_connection().unwrap();
157
158        let uuid = redis::Script::new(UUID_SCRIPT)
159            .arg(&key)
160            .arg(&num)
161            .invoke::<usize>(&mut conn)
162            .expect("Failed to create barrier");
163
164        Barrier {
165            uuid: uuid,
166            num,
167            key: key.to_string(),
168            _client: client,
169            conn: Some(conn),
170        }
171    }
172
173    /// Blocks the current thread until all threads have rendezvoused here.
174    ///
175    /// Barriers are re-usable after all threads have rendezvoused once, and can
176    /// be used continuously.
177    ///
178    /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
179    /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning
180    /// from this function, and all other threads will receive a result that
181    /// will return `false` from [`BarrierWaitResult::is_leader()`].
182    ///
183    /// The barrier needs to be mutable, because it guarantees that the barrier is only used once in thread.
184    /// If you want to synchronize threads, you need to create a new barrier for each thread, so it has its own uuid.
185    ///
186    /// # Examples
187    ///
188    /// ```
189    /// use dtypes::redis::sync::Barrier;
190    /// use std::thread;
191    ///
192    /// let n = 10;
193    /// let mut handles = Vec::with_capacity(n);
194    /// let client = redis::Client::open("redis://localhost:6379").unwrap();
195    /// for _ in 0..n {
196    ///     // The same messages will be printed together.
197    ///     // You will NOT see any interleaving.
198    ///     let mut barrier = Barrier::new(n, "barrier_doc_test2", client.clone());
199    ///     handles.push(thread::spawn(move|| {
200    ///         println!("before wait");
201    ///         barrier.wait();
202    ///         println!("after wait");
203    ///     }));
204    /// }
205    /// // Wait for other threads to finish.
206    /// for handle in handles {
207    ///     handle.join().unwrap();
208    /// }
209    /// ```
210    pub fn wait(&mut self) -> BarrierWaitResult {
211        let mut conn = self.conn.take().unwrap();
212        let timeout = 2;
213
214        let mut status = RedisBarrierStatus::Waiting;
215        while status == RedisBarrierStatus::Waiting {
216            status = redis::Script::new(WAITING_SCRIPT)
217                .arg(&self.key)
218                .arg(self.uuid)
219                .arg(self.num)
220                .arg(timeout)
221                .invoke::<u8>(&mut conn)
222                .expect("Failed to wait for barrier")
223                .into();
224        }
225        self.conn = Some(conn);
226
227        if status == RedisBarrierStatus::Leader {
228            BarrierWaitResult(true)
229        } else {
230            BarrierWaitResult(false)
231        }
232    }
233}
234
235impl Drop for Barrier {
236    fn drop(&mut self) {
237        let mut conn = self.conn.take().unwrap();
238        redis::Script::new(RESET_SCRIPT)
239            .arg(&self.key)
240            .arg(self.uuid)
241            .arg(self.num)
242            .invoke::<()>(&mut conn)
243            .expect("Failed to reset barrier");
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::thread::sleep;
    use std::time::Duration;

    #[test]
    fn test_barrier_leader() {
        let client = redis::Client::open("redis://localhost:6379").unwrap();
        let mut barrier = Barrier::new(1, "barrier_test_leader", client);
        let barrier_wait_result = barrier.wait();
        assert!(barrier_wait_result.is_leader());
    }

    #[test]
    fn test_barrier_not_leader() {
        let client = redis::Client::open("redis://localhost:6379").unwrap();

        let mut barrier = Barrier::new(2, "barrier_test_notleader", client.clone());

        let h1 = thread::spawn(move || {
            let mut barrier = Barrier::new(2, "barrier_test_notleader", client);
            let barrier_wait_result = barrier.wait();
            assert!(!barrier_wait_result.is_leader());
        });

        // The participant that arrives last completes the rendezvous and becomes the leader.
        let h2 = thread::spawn(move || {
            sleep(Duration::from_millis(1000));
            let barrier_wait_result = barrier.wait();
            assert!(barrier_wait_result.is_leader());
        });

        h1.join().unwrap();
        h2.join().unwrap();
    }

    #[test]
    fn test_barrier_slow_check() {
        let n = 10;
        let client = redis::Client::open("redis://localhost:6379").unwrap();

        // Spawn all participants before joining any of them; each needs its own `Barrier`.
        // A dedicated key keeps this test independent of the doc examples.
        let handles: Vec<_> = (0..n)
            .map(|_| {
                let mut barrier = Barrier::new(n, "barrier_test_slow_check", client.clone());
                thread::spawn(move || barrier.wait().is_leader())
            })
            .collect();

        let leaders = handles
            .into_iter()
            .map(|h| h.join().unwrap())
            .filter(|&is_leader| is_leader)
            .count();
        assert_eq!(leaders, 1);
    }

    #[test]
    fn test_barrier_reuse() {
        let client = redis::Client::open("redis://localhost:6379").unwrap();

        // First round: the only participant becomes the leader.
        let mut barrier = Barrier::new(1, "barrier_test_reuse", client.clone());
        assert!(barrier.wait().is_leader());
        // Dropping the last participant resets the shared state in Redis ...
        drop(barrier);

        // ... so a new participant on the same key starts a fresh round.
        let mut barrier = Barrier::new(1, "barrier_test_reuse", client);
        assert!(barrier.wait().is_leader());
    }
}
314}