egglog_concurrency/parallel_writer.rs
//! A utility struct for writing to a vector in parallel without blocking reads.
2
3use std::{
4 mem,
5 ops::{Deref, Range},
6 sync::atomic::{AtomicUsize, Ordering},
7};
8
9use crate::{MutexReader, ReadOptimizedLock};
10
/// A struct that wraps a vector and allows for parallel writes to it.
///
/// While the writes happen, reads to the vector can proceed without being
/// blocked by writes (except during a vector resize). The final vector can be
/// extracted using the `finish` method. Elements written to the vector behind a
/// ParallelVecWriter will not be dropped unless `finish` is called.
pub struct ParallelVecWriter<T> {
    /// The underlying buffer. Readers take shared access via `read()`; the
    /// exclusive `lock()` is only taken when the buffer must be resized
    /// (see `write_contents`).
    data: ReadOptimizedLock<Vec<T>>,
    /// One past the last slot claimed by a `write_contents` call. Starts at
    /// the wrapped vector's initial length; the vector's own `len` is only
    /// synced to this value in `finish`/`take`.
    end_len: AtomicUsize,
}
21
/// A handle that can be used to read arbitrary locations in a vector wrapped by a
/// [`ParallelVecWriter`], even if they weren't
/// initialized when the [`ParallelVecWriter`] was created.
pub struct UnsafeReadAccess<'a, T> {
    /// Shared read guard on the wrapped vector; while it is held, the buffer
    /// presumably cannot be reallocated by the exclusive resize path (writers
    /// drop their readers before calling `lock()` — see `write_contents`).
    reader: MutexReader<'a, Vec<T>>,
}
28
29impl<T> UnsafeReadAccess<'_, T> {
30 /// Get a reference to the given index in the vector.
31 ///
32 /// # Safety
33 /// `idx` must be either less than the length of the vector when the underlying
34 /// [`ParallelVecWriter`] was created, or it must be within bounds of a completed write to
35 /// [`ParallelVecWriter::write_contents`].
36 pub unsafe fn get_unchecked(&self, idx: usize) -> &T {
37 unsafe { &*self.reader.as_ptr().add(idx) }
38 }
39
40 /// Get a subslice of given index in the vector.
41 ///
42 /// # Safety
43 /// `slice`'s contents must be either within the vector when the underlying
44 /// [`ParallelVecWriter`] was created, or they must be within bounds of a completed write to
45 /// [`ParallelVecWriter::write_contents`].
46 pub unsafe fn get_unchecked_slice(&self, slice: Range<usize>) -> &[T] {
47 unsafe {
48 let start: *const T = self.reader.as_ptr().add(slice.start);
49 std::slice::from_raw_parts(start, slice.end - slice.start)
50 }
51 }
52}
53
impl<T> ParallelVecWriter<T> {
    /// Wrap `data` for parallel writes. `data`'s current length marks the end
    /// of the always-readable prefix; parallel writes claim slots after it.
    pub fn new(data: Vec<T>) -> Self {
        let start_len = data.len();
        // Slot claims in `write_contents` start counting from the initial length.
        let end_len = AtomicUsize::new(start_len);
        Self {
            data: ReadOptimizedLock::new(data),
            end_len,
        }
    }

    /// Get read access to the portion of the vector that was present before the
    /// ParallelVecWriter was created. Unlike the `with_` methods, callers
    /// should be careful about keeping the object returned from this method
    /// around for too long.
    pub fn read_access(&self) -> impl Deref<Target = [T]> + '_ {
        // Adapter that narrows the guard's deref target from `Vec<T>` to `[T]`.
        struct PrefixReader<'a, T> {
            reader: MutexReader<'a, Vec<T>>,
        }
        impl<T> Deref for PrefixReader<'_, T> {
            type Target = [T];

            fn deref(&self) -> &[T] {
                // Exposes only the initialized prefix (up to the vector's
                // `len`); slots claimed by in-flight writes are excluded
                // because parallel writes never update `len`.
                self.reader.as_slice()
            }
        }
        PrefixReader {
            reader: self.data.read(),
        }
    }

    /// Get unsafe read access to the vector.
    ///
    /// This handle allows for reads past the end of the wrapped vector. Callers must guarantee
    /// that any cells read are covered by a corresponding call to
    /// [`ParallelVecWriter::write_contents`].
    pub fn unsafe_read_access(&self) -> UnsafeReadAccess<'_, T> {
        UnsafeReadAccess {
            reader: self.data.read(),
        }
    }

    /// Runs `f` with access to the element at `idx`.
    ///
    /// # Panics
    /// This method panics if `idx` is greater than or equal to the length of
    /// the vector when the ParallelVecWriter was created.
    pub fn with_index<R>(&self, idx: usize, f: impl FnOnce(&T) -> R) -> R {
        f(&self.read_access()[idx])
    }

    /// Runs `f` with access to the slice of elements in the range `slice`.
    ///
    /// # Panics
    /// This method panics if `slice.end` is greater than the length of the
    /// vector when the ParallelVecWriter was created, or if the range is
    /// inverted (`slice.start > slice.end`).
    pub fn with_slice<R>(&self, slice: Range<usize>, f: impl FnOnce(&[T]) -> R) -> R {
        f(&self.read_access()[slice])
    }

    /// Write the contents of `items` to a contiguous chunk of the vector,
    /// returning the index of the first element in `items`.
    ///
    /// *Panics* It is very important that `items` does not lie about its
    /// length. This method panics if the actual length does not match the
    /// length method.
    pub fn write_contents(&self, items: impl ExactSizeIterator<Item = T>) -> usize {
        // Atomically claim `items.len()` slots; this is what gives the caller
        // exclusive ownership of `[start, end)`.
        let start = self.end_len.fetch_add(items.len(), Ordering::AcqRel);
        let end = start + items.len();
        let reader = self.data.read();
        let current_len = reader.len();
        let current_cap = reader.capacity();
        // Drop the read guard before possibly taking the exclusive lock;
        // presumably a held reader would block `lock()` — see ReadOptimizedLock.
        mem::drop(reader);
        if current_cap < end {
            let mut writer = self.data.lock();
            // Re-check under the exclusive lock: another writer may already
            // have grown the buffer past `end`.
            if writer.capacity() < end {
                // Grow at least geometrically to amortize repeated resizes.
                let new_cap = std::cmp::max(end, current_cap * 2);
                // `reserve` counts from `len`; parallel writes never change
                // `len`, so `current_len` is still accurate here.
                writer.reserve(new_cap - current_len);
            }
        }
        // SAFETY: the unsafe operations that `write_contents_at` performs are:
        // * Writing to a shared buffer: this is safe because the `fetch_add` we
        //   perform gives us unique access to the subslice.
        // * Writing past the length of the vector: this is safe because the
        //   above code pre-reserves sufficient capacity for `items` to write.
        unsafe { self.write_contents_at(items, start) };
        start
    }

    /// Consume the writer and return the vector, with its length extended to
    /// cover every slot claimed by completed `write_contents` calls.
    pub fn finish(self) -> Vec<T> {
        let mut res = self.data.into_inner();
        // SAFETY: this value is incremented past the original length of the
        // vector once for each item written to it.
        unsafe {
            res.set_len(self.end_len.load(Ordering::Acquire));
        }
        res
    }

    /// Like `finish`, but leaves an empty vector behind so the writer can be
    /// reused.
    pub fn take(&mut self) -> Vec<T> {
        let mut res = mem::take(self.data.as_mut_ref());
        // SAFETY: this value is incremented past the original length of the
        // vector once for each item written to it.
        unsafe {
            res.set_len(self.end_len.load(Ordering::Acquire));
        }
        // Reset the claim counter to match the now-empty wrapped vector.
        self.end_len.store(0, Ordering::Release);
        res
    }

    /// Write `items` into slots `[start, start + items.len())` of the buffer.
    ///
    /// # Safety
    /// The caller must have exclusive ownership of that slot range (e.g. via
    /// the `end_len` `fetch_add` in `write_contents`) and the buffer's
    /// capacity must already cover `start + items.len()`.
    unsafe fn write_contents_at(&self, items: impl ExactSizeIterator<Item = T>, start: usize) {
        let mut written = 0;
        let expected = items.len();
        // Hold a read guard so the buffer cannot be reallocated out from
        // under us while we write through the raw pointer.
        let reader = self.data.read();
        debug_assert!(reader.capacity() >= start + items.len());
        unsafe {
            // NOTE(review): this casts away const on a pointer derived from a
            // shared reference; soundness relies on ReadOptimizedLock
            // providing interior-mutability semantics for its contents —
            // confirm.
            let mut mut_ptr = (reader.as_ptr() as *mut T).add(start);
            for item in items {
                written += 1;
                // `ptr::write` does not read or drop the (uninitialized)
                // previous contents of the slot.
                std::ptr::write(mut_ptr, item);
                mut_ptr = mut_ptr.offset(1);
            }
        }
        // A lying ExactSizeIterator would leave claimed slots uninitialized
        // (or overrun the reservation), so fail loudly rather than let
        // `finish` expose garbage.
        assert_eq!(
            written, expected,
            "passed ExactSizeIterator with incorrect number of items"
        );
    }
}
181}