// raft_engine/write_barrier.rs

1// Copyright (c) 2017-present, PingCAP, Inc. Licensed under Apache-2.0.
2
3//! Synchronizer of writes.
4//!
5//! This module relies heavily on unsafe codes. Extra call site constraints are
6//! required to maintain memory safety. Use it with great caution.
7
8use std::cell::Cell;
9use std::marker::PhantomData;
10use std::ptr::NonNull;
11use std::time::Instant;
12
13use fail::fail_point;
14use parking_lot::{Condvar, Mutex};
15
16use crate::PerfContext;
17
/// Shorthand for a nullable raw pointer to `T`; used to build the intrusive
/// linked list of writers without borrow-lifetime constraints.
type Ptr<T> = Option<NonNull<T>>;
19
/// A single write request that waits in a [`WriteBarrier`] to be grouped with
/// other writers and processed by a group leader. Each writer is a node of an
/// intrusive singly linked list.
pub struct Writer<P, O> {
    // Link to the next writer in the intrusive list. `Cell` because links are
    // updated through shared references (under the barrier's mutex).
    next: Cell<Ptr<Writer<P, O>>>,
    // Raw pointer to the caller-owned payload; see `Writer::new` for the
    // aliasing contract that makes dereferencing it sound.
    payload: *mut P,
    // Result produced by the group leader (or the caller), consumed by
    // `finish`.
    output: Option<O>,

    pub(crate) sync: bool,
    pub(crate) entered_time: Option<Instant>,
    pub(crate) perf_context_diff: PerfContext,
}
30
31impl<P, O> Writer<P, O> {
32    /// Creates a new writer.
33    ///
34    /// # Safety
35    ///
36    /// Data pointed by `payload` is mutably referenced by this writer. Do not
37    /// access the payload by its original name during this writer's lifetime.
38    pub fn new(payload: &mut P, sync: bool) -> Self {
39        Writer {
40            next: Cell::new(None),
41            payload: payload as *mut _,
42            output: None,
43            sync,
44            entered_time: None,
45            perf_context_diff: PerfContext::default(),
46        }
47    }
48
49    /// Returns a mutable reference to the payload.
50    pub fn mut_payload(&mut self) -> &mut P {
51        unsafe { &mut *self.payload }
52    }
53
54    /// Sets the output. This method is re-entrant.
55    pub fn set_output(&mut self, output: O) {
56        self.output = Some(output);
57    }
58
59    /// Consumes itself and yields an output.
60    ///
61    /// # Panics
62    ///
63    /// Panics if called before being processed by a [`WriteBarrier`] or setting
64    /// the output itself.
65    pub fn finish(mut self) -> O {
66        self.output.take().unwrap()
67    }
68
69    fn get_next(&self) -> Ptr<Writer<P, O>> {
70        self.next.get()
71    }
72
73    fn set_next(&self, next: Ptr<Writer<P, O>>) {
74        self.next.set(next);
75    }
76}
77
/// A collection of writers. User thread (leader) that receives a [`WriteGroup`]
/// is responsible for processing its containing writers.
pub struct WriteGroup<'a, 'b, P: 'a, O: 'a> {
    // First writer of the group (the leader itself).
    start: Ptr<Writer<P, O>>,
    // Last writer of the group, inclusive.
    back: Ptr<Writer<P, O>>,

    // Used on drop to notify the barrier that this group has been processed.
    ref_barrier: &'a WriteBarrier<P, O>,
    // Ties the group to the lifetime of the writers it points into.
    marker: PhantomData<&'b Writer<P, O>>,
}
87
88impl<'a, 'b, P, O> WriteGroup<'a, 'b, P, O> {
89    pub fn iter_mut(&mut self) -> WriterIter<'_, 'a, 'b, P, O> {
90        WriterIter {
91            start: self.start,
92            back: self.back,
93            marker: PhantomData,
94        }
95    }
96}
97
impl<'a, 'b, P, O> Drop for WriteGroup<'a, 'b, P, O> {
    fn drop(&mut self) {
        // Dropping the group marks the end of the leader's processing; this
        // hands the barrier over so the next write group can be formed.
        self.ref_barrier.leader_exit();
    }
}
103
/// An iterator over the [`Writer`]s in one [`WriteGroup`].
pub struct WriterIter<'a, 'b, 'c, P: 'c, O: 'c> {
    // Next writer to yield; `None` once the iterator is exhausted.
    start: Ptr<Writer<P, O>>,
    // Last writer to yield, inclusive.
    back: Ptr<Writer<P, O>>,
    marker: PhantomData<&'a WriteGroup<'b, 'c, P, O>>,
}
110
111impl<'a, 'b, 'c, P, O> Iterator for WriterIter<'a, 'b, 'c, P, O> {
112    type Item = &'a mut Writer<P, O>;
113
114    fn next(&mut self) -> Option<Self::Item> {
115        if self.start.is_none() {
116            None
117        } else {
118            let writer = unsafe { self.start.unwrap().as_mut() };
119            if self.start == self.back {
120                self.start = None;
121            } else {
122                self.start = writer.get_next();
123            }
124            Some(writer)
125        }
126    }
127}
128
/// Mutex-protected state of a [`WriteBarrier`]: the intrusive list of queued
/// writers plus bookkeeping for the pending (next) write group.
struct WriteBarrierInner<P, O> {
    // Head of the writer list: leader of the currently active write group.
    head: Cell<Ptr<Writer<P, O>>>,
    // Tail of the writer list: most recent writer to have entered.
    tail: Cell<Ptr<Writer<P, O>>>,

    // Leader of the next (pending) write group, if one is forming.
    pending_leader: Cell<Ptr<Writer<P, O>>>,
    // Wrapping counter advanced each time a pending leader is installed; its
    // parity selects which follower condvar a group's followers wait on.
    pending_index: Cell<usize>,
}
136
// SAFETY: the inner state only stores raw pointers to writers; the extra
// call-site constraints documented at the top of this module are relied upon
// to keep those writers alive and properly synchronized, so the state may be
// sent across threads when `P` and `O` are `Send`.
unsafe impl<P: Send, O: Send> Send for WriteBarrierInner<P, O> {}
138
139impl<P, O> Default for WriteBarrierInner<P, O> {
140    fn default() -> Self {
141        WriteBarrierInner {
142            head: Cell::new(None),
143            tail: Cell::new(None),
144            pending_leader: Cell::new(None),
145            pending_index: Cell::new(0),
146        }
147    }
148}
149
/// A synchronizer of [`Writer`]s.
pub struct WriteBarrier<P, O> {
    inner: Mutex<WriteBarrierInner<P, O>>,
    // Wakes the leader of the pending write group when the active one exits.
    leader_cv: Condvar,
    // Wakes the followers of a write group; indexed by the parity of the
    // group's `pending_index` so adjacent groups use distinct condvars.
    follower_cvs: [Condvar; 2],
}
156
157impl<P, O> Default for WriteBarrier<P, O> {
158    fn default() -> Self {
159        WriteBarrier {
160            leader_cv: Condvar::new(),
161            follower_cvs: [Condvar::new(), Condvar::new()],
162            inner: Mutex::new(WriteBarrierInner::default()),
163        }
164    }
165}
166
impl<P, O> WriteBarrier<P, O> {
    /// Waits until the caller should perform some work. If `writer` has become
    /// the leader of a set of writers, returns a [`WriteGroup`] that contains
    /// them, `writer` included.
    ///
    /// Returns `None` when `writer` joined an existing pending group as a
    /// follower; by the time this call returns, that group's leader is
    /// expected to have processed the writer already.
    pub fn enter<'a>(&self, writer: &'a mut Writer<P, O>) -> Option<WriteGroup<'_, 'a, P, O>> {
        // SAFETY: `writer` comes from a valid reference, hence is non-null.
        let node = unsafe { Some(NonNull::new_unchecked(writer)) };
        let mut inner = self.inner.lock();
        if let Some(tail) = inner.tail.get() {
            // Queue is non-empty: link this writer behind the current tail.
            unsafe {
                tail.as_ref().set_next(node);
            }
            inner.tail.set(node);

            if inner.pending_leader.get().is_some() {
                // follower of next write group.
                // NOTE(review): a single, un-looped wait relies on parking_lot
                // condvars not having spurious wakeups.
                self.follower_cvs[inner.pending_index.get() % 2].wait(&mut inner);
                return None;
            } else {
                // leader of next write group.
                inner.pending_leader.set(node);
                inner
                    .pending_index
                    .set(inner.pending_index.get().wrapping_add(1));
                // Sleep until the active group's leader exits and hands the
                // barrier over to this pending group (see `leader_exit`).
                self.leader_cv.wait(&mut inner);
                inner.pending_leader.set(None);
            }
        } else {
            // leader of an empty write group. proceed directly.
            debug_assert!(inner.pending_leader.get().is_none());
            inner.head.set(node);
            inner.tail.set(node);
        }

        // Reaching here means this writer leads the active group, which spans
        // from itself up to the tail observed while holding the lock.
        Some(WriteGroup {
            start: node,
            back: inner.tail.get(),
            ref_barrier: self,
            marker: PhantomData,
        })
    }

    /// Must be called when a write group leader finishes processing its
    /// responsible writers, and the next write group should be formed.
    /// Invoked automatically when the leader drops its [`WriteGroup`].
    fn leader_exit(&self) {
        fail_point!("write_barrier::leader_exit", |_| {});
        let inner = self.inner.lock();
        if let Some(leader) = inner.pending_leader.get() {
            // wake up leader of next write group.
            self.leader_cv.notify_one();
            // wake up follower of current write group. `pending_index` was
            // already advanced for the pending group, so the current group's
            // followers are waiting on the previous parity.
            self.follower_cvs[inner.pending_index.get().wrapping_sub(1) % 2].notify_all();
            inner.head.set(Some(leader));
        } else {
            // wake up follower of current write group; no pending group
            // exists, so the queue becomes empty.
            self.follower_cvs[inner.pending_index.get() % 2].notify_all();
            inner.head.set(None);
            inner.tail.set(None);
        }
    }
}
228
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::sync::{Arc, Barrier};
    use std::thread::{self, Builder as ThreadBuilder};
    use std::time::Duration;

    #[test]
    fn test_sequential_groups() {
        // With a single thread, every writer trivially becomes the leader of
        // its own one-writer group.
        let barrier: WriteBarrier<(), u32> = Default::default();
        let mut leaders = 0;
        let mut processed_writers = 0;

        for _ in 0..4 {
            let mut writer = Writer::new(&mut (), false);
            {
                let mut wg = barrier.enter(&mut writer).unwrap();
                leaders += 1;
                for writer in wg.iter_mut() {
                    writer.set_output(7);
                    processed_writers += 1;
                }
            }
            assert_eq!(writer.finish(), 7);
        }

        assert_eq!(processed_writers, 4);
        assert_eq!(leaders, 4);
    }

    // Drives several concurrent writer threads against one barrier while
    // controlling when each group's leader enters and exits.
    struct ConcurrentWriteContext {
        barrier: Arc<WriteBarrier<u32, u32>>,

        // Monotonic sequence number, used as both payload and expected output.
        seq: u32,
        // Join handles of all writer threads still running.
        ths: Vec<thread::JoinHandle<()>>,
        // Rendezvous (capacity-0) channel: a leader blocks on send until the
        // test receives, which releases the active write group.
        leader_exit_tx: mpsc::SyncSender<()>,
        leader_exit_rx: mpsc::Receiver<()>,
    }

    impl ConcurrentWriteContext {
        fn new() -> Self {
            let (leader_exit_tx, leader_exit_rx) = mpsc::sync_channel(0);
            Self {
                barrier: Default::default(),
                seq: 0,
                ths: Vec::new(),
                leader_exit_tx,
                leader_exit_rx,
            }
        }

        // 1) create `n` writers and form a new write group
        // 2) current active write group finishes writing and exits
        // 3) the new write group enters writing phrase
        fn step(&mut self, n: usize) {
            if self.ths.is_empty() {
                // ensure there is at least one active writer.
                self.seq += 1;
                let (leader_enter_tx, leader_enter_rx) = mpsc::channel();

                let barrier = self.barrier.clone();
                let leader_exit_tx = self.leader_exit_tx.clone();
                let mut seq = self.seq;
                self.ths.push(
                    ThreadBuilder::new()
                        .spawn(move || {
                            let mut writer = Writer::new(&mut seq, false);
                            {
                                // First writer on an empty barrier: must lead
                                // a single-writer group.
                                let mut wg = barrier.enter(&mut writer).unwrap();
                                leader_enter_tx.send(()).unwrap();
                                let mut n = 0;
                                for w in wg.iter_mut() {
                                    let p = *w.mut_payload();
                                    w.set_output(p);
                                    n += 1;
                                }
                                assert_eq!(n, 1);
                                // Block until the test allows this group to
                                // exit (rendezvous channel).
                                leader_exit_tx.send(()).unwrap();
                            }
                            assert_eq!(writer.finish(), seq);
                        })
                        .unwrap(),
                );

                // Wait until the writer is actually inside the barrier.
                leader_enter_rx.recv().unwrap();
            }

            let prev_writers = self.ths.len();
            let (leader_enter_tx, leader_enter_rx) = mpsc::channel();
            let start_thread = Arc::new(Barrier::new(n + 1));
            for _ in 0..n {
                self.seq += 1;

                let barrier = self.barrier.clone();
                let start_thread = start_thread.clone();
                let leader_enter_tx_clone = leader_enter_tx.clone();
                let leader_exit_tx = self.leader_exit_tx.clone();
                let mut seq = self.seq;
                self.ths.push(
                    ThreadBuilder::new()
                        .spawn(move || {
                            let mut writer = Writer::new(&mut seq, false);
                            // Release all `n` writers at once so they queue
                            // behind the currently active group.
                            start_thread.wait();
                            if let Some(mut wg) = barrier.enter(&mut writer) {
                                // This thread became the new group's leader;
                                // it must see all `n` writers in the group.
                                leader_enter_tx_clone.send(()).unwrap();
                                let mut idx = 0;
                                for w in wg.iter_mut() {
                                    let p = *w.mut_payload();
                                    w.set_output(p);
                                    idx += 1;
                                }
                                assert_eq!(idx, n as u32);
                                leader_exit_tx.send(()).unwrap();
                            }
                            assert_eq!(writer.finish(), seq);
                        })
                        .unwrap(),
                );
            }
            start_thread.wait();
            // NOTE(review): timing-based — give the new writers time to queue
            // up behind the active leader before releasing it; best-effort.
            std::thread::sleep(Duration::from_millis(100));
            // unblock current leader
            self.leader_exit_rx.recv().unwrap();
            for th in self.ths.drain(0..prev_writers) {
                th.join().unwrap();
            }
            // make sure new leader is ready
            leader_enter_rx.recv().unwrap();
        }

        fn join(&mut self) {
            // Release the last pending leader and wait for all writers.
            self.leader_exit_rx.recv().unwrap();
            for th in self.ths.drain(..) {
                th.join().unwrap();
            }
        }
    }

    #[test]
    fn test_parallel_groups() {
        let mut ctx = ConcurrentWriteContext::new();
        for i in 1..5 {
            ctx.step(i);
        }
        ctx.join();
    }
}