1use std::cell::Cell;
9use std::marker::PhantomData;
10use std::ptr::NonNull;
11use std::time::Instant;
12
13use fail::fail_point;
14use parking_lot::{Condvar, Mutex};
15
16use crate::PerfContext;
17
/// Nullable, non-owning pointer to a `T`; `None` stands in for a null link.
type Ptr<T> = Option<NonNull<T>>;
19
/// A single write request queued on a [`WriteBarrier`].
///
/// Writers are chained into an intrusive singly-linked list through `next`.
/// The leader of a write group walks that list (via `WriteGroup::iter_mut`)
/// to reach every member's payload and record an output for it.
pub struct Writer<P, O> {
    // Writer queued directly behind this one. Linked by `WriteBarrier::enter`
    // (under the barrier lock) when another writer is appended after this one.
    next: Cell<Ptr<Writer<P, O>>>,
    // Raw pointer to the payload passed to `Writer::new`. No lifetime is
    // tracked, so whoever created the writer must keep the payload alive for
    // as long as `mut_payload` may be called.
    payload: *mut P,
    // Result recorded via `set_output`; moved out by `finish`.
    output: Option<O>,

    // Flag supplied to `Writer::new`; not interpreted in this module
    // (presumably requests a synced write — confirm with callers).
    pub(crate) sync: bool,
    // Initialized to `None` and never set in this module; presumably filled
    // in by callers — TODO confirm.
    pub(crate) entered_time: Option<Instant>,
    // Initialized to `PerfContext::default()`; not touched in this module.
    pub(crate) perf_context_diff: PerfContext,
}
30
31impl<P, O> Writer<P, O> {
32 pub fn new(payload: &mut P, sync: bool) -> Self {
39 Writer {
40 next: Cell::new(None),
41 payload: payload as *mut _,
42 output: None,
43 sync,
44 entered_time: None,
45 perf_context_diff: PerfContext::default(),
46 }
47 }
48
49 pub fn mut_payload(&mut self) -> &mut P {
51 unsafe { &mut *self.payload }
52 }
53
54 pub fn set_output(&mut self, output: O) {
56 self.output = Some(output);
57 }
58
59 pub fn finish(mut self) -> O {
66 self.output.take().unwrap()
67 }
68
69 fn get_next(&self) -> Ptr<Writer<P, O>> {
70 self.next.get()
71 }
72
73 fn set_next(&self, next: Ptr<Writer<P, O>>) {
74 self.next.set(next);
75 }
76}
77
/// The batch of queued writers handed to the leader by [`WriteBarrier::enter`].
///
/// The group covers the linked writers from `start` through `back`
/// (inclusive). Dropping it calls `WriteBarrier::leader_exit`, which wakes
/// the group's followers and hands over to the pending leader, if any.
pub struct WriteGroup<'a, 'b, P: 'a, O: 'a> {
    // First writer of the group: the leader's own writer.
    start: Ptr<Writer<P, O>>,
    // Last writer of the group: the barrier's tail when the group was formed.
    back: Ptr<Writer<P, O>>,

    // Barrier released when this group is dropped.
    ref_barrier: &'a WriteBarrier<P, O>,
    // Ties the group to the lifetime of the leader's writer borrow given to `enter`.
    marker: PhantomData<&'b Writer<P, O>>,
}
87
88impl<'a, 'b, P, O> WriteGroup<'a, 'b, P, O> {
89 pub fn iter_mut(&mut self) -> WriterIter<'_, 'a, 'b, P, O> {
90 WriterIter {
91 start: self.start,
92 back: self.back,
93 marker: PhantomData,
94 }
95 }
96}
97
98impl<'a, 'b, P, O> Drop for WriteGroup<'a, 'b, P, O> {
99 fn drop(&mut self) {
100 self.ref_barrier.leader_exit();
101 }
102}
103
/// Iterator over the writers of a [`WriteGroup`], produced by
/// `WriteGroup::iter_mut`; yields `&mut Writer`s in queue order.
pub struct WriterIter<'a, 'b, 'c, P: 'c, O: 'c> {
    // Next writer to yield; `None` once the group's last writer was yielded.
    start: Ptr<Writer<P, O>>,
    // Group's last writer; iteration stops right after yielding it.
    back: Ptr<Writer<P, O>>,
    // Keeps the `WriteGroup` borrowed for `'a`, the lifetime of yielded references.
    marker: PhantomData<&'a WriteGroup<'b, 'c, P, O>>,
}
110
111impl<'a, 'b, 'c, P, O> Iterator for WriterIter<'a, 'b, 'c, P, O> {
112 type Item = &'a mut Writer<P, O>;
113
114 fn next(&mut self) -> Option<Self::Item> {
115 if self.start.is_none() {
116 None
117 } else {
118 let writer = unsafe { self.start.unwrap().as_mut() };
119 if self.start == self.back {
120 self.start = None;
121 } else {
122 self.start = writer.get_next();
123 }
124 Some(writer)
125 }
126 }
127}
128
/// Queue state of a [`WriteBarrier`]; only accessed through the barrier's mutex.
struct WriteBarrierInner<P, O> {
    // First writer of the active group, or `None` when idle. Written by
    // `enter`/`leader_exit` but never read in this module — NOTE(review):
    // confirm whether it is still needed.
    head: Cell<Ptr<Writer<P, O>>>,
    // Most recently queued writer; new writers are linked behind it.
    tail: Cell<Ptr<Writer<P, O>>>,

    // Writer waiting to lead the next group, if any.
    pending_leader: Cell<Ptr<Writer<P, O>>>,
    // Bumped (wrapping) whenever a new pending leader is chosen; its parity
    // selects which of `WriteBarrier::follower_cvs` that group's followers use.
    pending_index: Cell<usize>,
}
136
// SAFETY (rationale as visible in this module): the `Cell`s are only touched
// through the `Mutex` wrapping this type in `WriteBarrier`, and the raw
// `Writer` pointers they hold refer to writers whose owning threads are either
// blocked in `WriteBarrier::enter` or hold the active `WriteGroup`. Requiring
// `P: Send` and `O: Send` lets payloads and outputs be accessed from the
// leader's thread.
unsafe impl<P: Send, O: Send> Send for WriteBarrierInner<P, O> {}
138
139impl<P, O> Default for WriteBarrierInner<P, O> {
140 fn default() -> Self {
141 WriteBarrierInner {
142 head: Cell::new(None),
143 tail: Cell::new(None),
144 pending_leader: Cell::new(None),
145 pending_index: Cell::new(0),
146 }
147 }
148}
149
/// Batches concurrently arriving writers into groups, each processed by a
/// single leader thread.
///
/// The first writer to `enter` an idle barrier leads a group on its own.
/// Writers arriving while a group is active queue up behind it: the first of
/// them becomes the pending leader of the next group, and the rest join that
/// group as followers and block until its leader drops its `WriteGroup`.
pub struct WriteBarrier<P, O> {
    inner: Mutex<WriteBarrierInner<P, O>>,
    // Wakes the pending leader when the active group's leader exits.
    leader_cv: Condvar,
    // Two slots indexed by `pending_index % 2`, so the followers of the
    // finished group can be woken without waking the followers already queued
    // behind the next pending leader.
    follower_cvs: [Condvar; 2],
}
156
157impl<P, O> Default for WriteBarrier<P, O> {
158 fn default() -> Self {
159 WriteBarrier {
160 leader_cv: Condvar::new(),
161 follower_cvs: [Condvar::new(), Condvar::new()],
162 inner: Mutex::new(WriteBarrierInner::default()),
163 }
164 }
165}
166
167impl<P, O> WriteBarrier<P, O> {
168 pub fn enter<'a>(&self, writer: &'a mut Writer<P, O>) -> Option<WriteGroup<'_, 'a, P, O>> {
172 let node = unsafe { Some(NonNull::new_unchecked(writer)) };
173 let mut inner = self.inner.lock();
174 if let Some(tail) = inner.tail.get() {
175 unsafe {
176 tail.as_ref().set_next(node);
177 }
178 inner.tail.set(node);
179
180 if inner.pending_leader.get().is_some() {
181 self.follower_cvs[inner.pending_index.get() % 2].wait(&mut inner);
183 return None;
184 } else {
185 inner.pending_leader.set(node);
187 inner
188 .pending_index
189 .set(inner.pending_index.get().wrapping_add(1));
190 self.leader_cv.wait(&mut inner);
192 inner.pending_leader.set(None);
193 }
194 } else {
195 debug_assert!(inner.pending_leader.get().is_none());
197 inner.head.set(node);
198 inner.tail.set(node);
199 }
200
201 Some(WriteGroup {
202 start: node,
203 back: inner.tail.get(),
204 ref_barrier: self,
205 marker: PhantomData,
206 })
207 }
208
209 fn leader_exit(&self) {
212 fail_point!("write_barrier::leader_exit", |_| {});
213 let inner = self.inner.lock();
214 if let Some(leader) = inner.pending_leader.get() {
215 self.leader_cv.notify_one();
217 self.follower_cvs[inner.pending_index.get().wrapping_sub(1) % 2].notify_all();
219 inner.head.set(Some(leader));
220 } else {
221 self.follower_cvs[inner.pending_index.get() % 2].notify_all();
223 inner.head.set(None);
224 inner.tail.set(None);
225 }
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::sync::{Arc, Barrier};
    use std::thread::{self, Builder as ThreadBuilder};
    use std::time::Duration;

    /// Without contention, every writer that enters leads a group containing
    /// only itself, and its output is visible through `finish`.
    #[test]
    fn test_sequential_groups() {
        let barrier: WriteBarrier<(), u32> = Default::default();
        let mut leaders = 0;
        let mut processed_writers = 0;

        for _ in 0..4 {
            let mut writer = Writer::new(&mut (), false);
            {
                let mut wg = barrier.enter(&mut writer).unwrap();
                leaders += 1;
                for writer in wg.iter_mut() {
                    writer.set_output(7);
                    processed_writers += 1;
                }
            }
            assert_eq!(writer.finish(), 7);
        }

        assert_eq!(processed_writers, 4);
        assert_eq!(leaders, 4);
    }

    /// Drives a sequence of concurrent write groups, one leader at a time.
    struct ConcurrentWriteContext {
        barrier: Arc<WriteBarrier<u32, u32>>,

        // Last sequence number handed out; each writer's payload is its own
        // sequence number, which the leader copies into the writer's output.
        seq: u32,
        // Spawned writer threads not yet joined.
        ths: Vec<thread::JoinHandle<()>>,
        // Zero-capacity (rendezvous) channel: a leader's `send` blocks until
        // the main thread receives, which holds that leader's group open.
        leader_exit_tx: mpsc::SyncSender<()>,
        leader_exit_rx: mpsc::Receiver<()>,
    }

    impl ConcurrentWriteContext {
        fn new() -> Self {
            let (leader_exit_tx, leader_exit_rx) = mpsc::sync_channel(0);
            Self {
                barrier: Default::default(),
                seq: 0,
                ths: Vec::new(),
                leader_exit_tx,
                leader_exit_rx,
            }
        }

        /// Spawns `n` writers that should all form one new group behind the
        /// currently held leader, then releases that leader.
        ///
        /// On the first call, a lone leader is spawned first and held inside
        /// its one-writer group. The `n` new threads then enter while that
        /// leader is still blocked, so one becomes the pending leader and the
        /// rest follow it. After a short sleep, the old leader is released and
        /// its threads are joined. The call returns once the new leader has
        /// entered its group, which it keeps open until the next `step`/`join`.
        fn step(&mut self, n: usize) {
            if self.ths.is_empty() {
                self.seq += 1;
                let (leader_enter_tx, leader_enter_rx) = mpsc::channel();

                let barrier = self.barrier.clone();
                let leader_exit_tx = self.leader_exit_tx.clone();
                let mut seq = self.seq;
                self.ths.push(
                    ThreadBuilder::new()
                        .spawn(move || {
                            let mut writer = Writer::new(&mut seq, false);
                            {
                                let mut wg = barrier.enter(&mut writer).unwrap();
                                leader_enter_tx.send(()).unwrap();
                                let mut n = 0;
                                for w in wg.iter_mut() {
                                    let p = *w.mut_payload();
                                    w.set_output(p);
                                    n += 1;
                                }
                                // The first leader's group holds only itself.
                                assert_eq!(n, 1);
                                // Blocks until the main thread releases this group.
                                leader_exit_tx.send(()).unwrap();
                            }
                            assert_eq!(writer.finish(), seq);
                        })
                        .unwrap(),
                );

                leader_enter_rx.recv().unwrap();
            }

            // Threads spawned before this step: the held leader's group.
            let prev_writers = self.ths.len();
            let (leader_enter_tx, leader_enter_rx) = mpsc::channel();
            // Releases all `n` writers (plus the main thread) together.
            let start_thread = Arc::new(Barrier::new(n + 1));
            for _ in 0..n {
                self.seq += 1;

                let barrier = self.barrier.clone();
                let start_thread = start_thread.clone();
                let leader_enter_tx_clone = leader_enter_tx.clone();
                let leader_exit_tx = self.leader_exit_tx.clone();
                let mut seq = self.seq;
                self.ths.push(
                    ThreadBuilder::new()
                        .spawn(move || {
                            let mut writer = Writer::new(&mut seq, false);
                            start_thread.wait();
                            if let Some(mut wg) = barrier.enter(&mut writer) {
                                leader_enter_tx_clone.send(()).unwrap();
                                let mut idx = 0;
                                for w in wg.iter_mut() {
                                    let p = *w.mut_payload();
                                    w.set_output(p);
                                    idx += 1;
                                }
                                // All `n` writers of this step share one group.
                                assert_eq!(idx, n as u32);
                                leader_exit_tx.send(()).unwrap();
                            }
                            // Followers see the output their leader set.
                            assert_eq!(writer.finish(), seq);
                        })
                        .unwrap(),
                );
            }
            start_thread.wait();
            // NOTE(review): presumably gives every new writer time to queue
            // before the old leader is released. This is timing-based; a slow
            // scheduler could split the writers across groups and trip the
            // `idx == n` assertion.
            std::thread::sleep(Duration::from_millis(100));
            self.leader_exit_rx.recv().unwrap();
            for th in self.ths.drain(0..prev_writers) {
                th.join().unwrap();
            }
            leader_enter_rx.recv().unwrap();
        }

        /// Releases the currently held leader and joins every remaining thread.
        fn join(&mut self) {
            self.leader_exit_rx.recv().unwrap();
            for th in self.ths.drain(..) {
                th.join().unwrap();
            }
        }
    }

    /// Runs back-to-back groups of 1, 2, 3 and 4 concurrent writers.
    #[test]
    fn test_parallel_groups() {
        let mut ctx = ConcurrentWriteContext::new();
        for i in 1..5 {
            ctx.step(i);
        }
        ctx.join();
    }
}