// egglog_concurrency/parallel_writer.rs
//! A utility struct for writing to a vector in parallel without blocking reads.
2
3use std::{
4    mem,
5    ops::{Deref, Range},
6    sync::atomic::{AtomicUsize, Ordering},
7};
8
9use crate::{MutexReader, ReadOptimizedLock};
10
/// A struct that wraps a vector and allows for parallel writes to it.
///
/// While the writes happen, reads to the vector can proceed without being
/// blocked by writes (except during a vector resize). The final vector can be
/// extracted using the `finish` method. Elements written to the vector behind a
/// [`ParallelVecWriter`] will not be dropped unless `finish` is called.
pub struct ParallelVecWriter<T> {
    // The underlying buffer. Its `len()` stays frozen at whatever it was at
    // construction time; new elements are written through raw pointers into
    // reserved-but-not-yet-counted capacity, and `len` is only advanced in
    // `finish`/`take` via `set_len`.
    data: ReadOptimizedLock<Vec<T>>,
    // One past the last slot claimed by any writer. `write_contents` uses a
    // `fetch_add` on this counter to hand each caller a disjoint range.
    end_len: AtomicUsize,
}
21
/// A handle that can be used to read arbitrary locations in a vector wrapped by a
/// [`ParallelVecWriter`], even if they weren't
/// initialized when the [`ParallelVecWriter`] was created.
pub struct UnsafeReadAccess<'a, T> {
    // Read guard on the wrapped vector. Holding it keeps the buffer in place:
    // resizes go through the exclusive lock (see
    // `ParallelVecWriter::write_contents`), so they cannot happen while this
    // guard is alive.
    reader: MutexReader<'a, Vec<T>>,
}
28
29impl<T> UnsafeReadAccess<'_, T> {
30    /// Get a reference to the given index in the vector.
31    ///
32    /// # Safety
33    /// `idx` must be either less than the length of the vector when the underlying
34    /// [`ParallelVecWriter`] was created, or it must be within bounds of a completed write to
35    /// [`ParallelVecWriter::write_contents`].
36    pub unsafe fn get_unchecked(&self, idx: usize) -> &T {
37        unsafe { &*self.reader.as_ptr().add(idx) }
38    }
39
40    /// Get a subslice of given index in the vector.
41    ///
42    /// # Safety
43    /// `slice`'s contents must be either within the vector when the underlying
44    /// [`ParallelVecWriter`] was created, or they must be within bounds of a completed write to
45    /// [`ParallelVecWriter::write_contents`].
46    pub unsafe fn get_unchecked_slice(&self, slice: Range<usize>) -> &[T] {
47        unsafe {
48            let start: *const T = self.reader.as_ptr().add(slice.start);
49            std::slice::from_raw_parts(start, slice.end - slice.start)
50        }
51    }
52}
53
impl<T> ParallelVecWriter<T> {
    /// Wrap `data`, allowing parallel appends past its current end while the
    /// existing prefix stays readable.
    pub fn new(data: Vec<T>) -> Self {
        let start_len = data.len();
        // `end_len` starts at the existing length: writes claimed via
        // `write_contents` land after the original contents.
        let end_len = AtomicUsize::new(start_len);
        Self {
            data: ReadOptimizedLock::new(data),
            end_len,
        }
    }

    /// Get read access to the portion of the vector that was present before the
    /// ParallelVecWriter was created. Unlike the `with_` methods, callers
    /// should be careful about keeping the object returned from this method
    /// around for too long.
    pub fn read_access(&self) -> impl Deref<Target = [T]> + '_ {
        // Small adapter so callers see `[T]` (the `len()`-bounded prefix)
        // instead of the guard over the whole `Vec<T>`.
        struct PrefixReader<'a, T> {
            reader: MutexReader<'a, Vec<T>>,
        }
        impl<T> Deref for PrefixReader<'_, T> {
            type Target = [T];

            fn deref(&self) -> &[T] {
                self.reader.as_slice()
            }
        }
        PrefixReader {
            reader: self.data.read(),
        }
    }

    /// Get unsafe read access to the vector.
    ///
    /// This handle allows for reads past the end of the wrapped vector. Callers must guarantee
    /// that any cells read are covered by a corresponding call to
    /// [`ParallelVecWriter::write_contents`].
    pub fn unsafe_read_access(&self) -> UnsafeReadAccess<'_, T> {
        UnsafeReadAccess {
            reader: self.data.read(),
        }
    }

    /// Runs `f` with access to the element at `idx`.
    ///
    /// # Panics
    /// This method panics if `idx` is greater than or equal to the length of
    /// the vector when the ParallelVecWriter was created.
    pub fn with_index<R>(&self, idx: usize, f: impl FnOnce(&T) -> R) -> R {
        f(&self.read_access()[idx])
    }

    /// Runs `f` with access to the slice of elements in the range `slice`.
    ///
    /// # Panics
    /// This method panics if `slice` is out of bounds for the vector as it was
    /// when the ParallelVecWriter was created, i.e. if `slice.end` is greater
    /// than that length (or `slice.start > slice.end`).
    pub fn with_slice<R>(&self, slice: Range<usize>, f: impl FnOnce(&[T]) -> R) -> R {
        f(&self.read_access()[slice])
    }

    /// Write the contents of `items` to a contiguous chunk of the vector,
    /// returning the index of the first element in `items`.
    ///
    /// # Panics
    /// It is very important that `items` does not lie about its
    /// length. This method panics if the actual length does not match the
    /// length method.
    pub fn write_contents(&self, items: impl ExactSizeIterator<Item = T>) -> usize {
        // Claim a disjoint range [start, end): the fetch_add makes this thread
        // the only one that will ever write those slots.
        let start = self.end_len.fetch_add(items.len(), Ordering::AcqRel);
        let end = start + items.len();
        // Cheap path: peek at the capacity under a read guard, then drop the
        // guard before (possibly) taking the exclusive lock to resize.
        let reader = self.data.read();
        let current_len = reader.len();
        let current_cap = reader.capacity();
        mem::drop(reader);
        if current_cap < end {
            let mut writer = self.data.lock();
            // Re-check under the lock: another writer may have already grown
            // the buffer between dropping the read guard and locking.
            if writer.capacity() < end {
                // Grow geometrically (at least doubling) to amortize resizes.
                // `current_cap` may be stale, but `new_cap >= end` regardless.
                let new_cap = std::cmp::max(end, current_cap * 2);
                // `reserve` is relative to `len` (which stays `current_len`
                // throughout — only `finish`/`take` move it), so this
                // guarantees capacity >= new_cap >= end.
                writer.reserve(new_cap - current_len);
            }
        }
        // SAFETY: the unsafe operations that `write_contents_at` performs are:
        // * Writing to a shared buffer: this is safe because the `fetch_add` we
        // perform gives us unique access to the subslice.
        // * Writing past the length of the vector: this is safe because the
        // above code pre-reserves sufficient capacity for `items` to write.
        unsafe { self.write_contents_at(items, start) };
        start
    }

    /// Consume the writer and return the vector with all written elements
    /// counted in its length.
    pub fn finish(self) -> Vec<T> {
        let mut res = self.data.into_inner();
        // SAFETY: this value is incremented past the original length of the
        // vector once for each item written to it. Taking `self` by value means
        // no `write_contents` call can still be in flight, so every slot below
        // `end_len` is initialized.
        unsafe {
            res.set_len(self.end_len.load(Ordering::Acquire));
        }
        res
    }

    /// Take the accumulated vector, leaving the writer empty (subsequent
    /// writes start again at index 0).
    pub fn take(&mut self) -> Vec<T> {
        let mut res = mem::take(self.data.as_mut_ref());
        // SAFETY: this value is incremented past the original length of the
        // vector once for each item written to it. `&mut self` rules out
        // concurrent writers, so every slot below `end_len` is initialized.
        unsafe {
            res.set_len(self.end_len.load(Ordering::Acquire));
        }
        // Match the fresh empty vector that `mem::take` left behind.
        self.end_len.store(0, Ordering::Release);
        res
    }

    /// Write all of `items` into the buffer starting at slot `start`, without
    /// touching the vector's `len`.
    ///
    /// # Safety
    /// The caller must guarantee that (a) it has exclusive access to the range
    /// `start..start + items.len()` (e.g. via the `fetch_add` in
    /// `write_contents`) and (b) the vector's capacity already covers that
    /// range, so no reallocation is needed while writing.
    unsafe fn write_contents_at(&self, items: impl ExactSizeIterator<Item = T>, start: usize) {
        let mut written = 0;
        let expected = items.len();
        // Hold a read guard for the duration of the writes: resizes require
        // the exclusive lock, so the buffer cannot move under us.
        let reader = self.data.read();
        debug_assert!(reader.capacity() >= start + items.len());
        unsafe {
            // Casting away const is sound here: the claimed range is not
            // aliased by any other writer, and the unsafe-read APIs require
            // callers to touch only completed writes.
            let mut mut_ptr = (reader.as_ptr() as *mut T).add(start);
            for item in items {
                written += 1;
                std::ptr::write(mut_ptr, item);
                mut_ptr = mut_ptr.offset(1);
            }
        }
        // Detect `ExactSizeIterator` impls that lie about their length: a
        // short iterator would leave uninitialized holes below `end_len`.
        assert_eq!(
            written, expected,
            "passed ExactSizeIterator with incorrect number of items"
        );
    }
}
181}