//! marching_buffer/lib.rs

1#![forbid(unsafe_code)]
2
3use std::{ops::{Deref, DerefMut}};
4use std::{sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}};
5use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6
/// Shared state behind a `MarchingBuffer` and all of its `Reader`s/`Writer`.
///
/// The `Vec` holds both the frozen (readable) prefix and the in-progress
/// (writable) suffix; the atomics below partition it and track outstanding
/// handles so the buffer can be reset when the last one is dropped.
#[derive(Clone)]
struct InnerMarchingBuffer<T> {
    /// Backing storage: `[0 .. write_offset)` is frozen for Readers,
    /// `[write_offset ..)` is the Writer's in-progress region.
    data: Arc<RwLock<Vec<T>>>,
    /// Number of entries that have been finished in the buffer.
    finished_len: Arc<AtomicUsize>,
    /// How many Readers exist.
    readers: Arc<AtomicUsize>,
    /// True if a Writer exists, false otherwise.
    has_writer: Arc<AtomicBool>,
    /// Offset into the data vec of where the writable section starts. Equivalently, the total amount of data that has
    /// been frozen for reading. This is reset to 0 once all Readers and Writers are dropped. This is updated whenever a WriterAccess is dropped.
    write_offset: Arc<AtomicUsize>,
}
20
21impl<T> InnerMarchingBuffer<T> {
22    fn check_reset(&self) {
23        if let Ok(mut data) = self.data.try_write() {
24            if self.readers.load(Ordering::SeqCst) == 0 && !self.has_writer.load(Ordering::SeqCst) {
25                self.write_offset.store(0, Ordering::SeqCst);
26                self.finished_len.store(0, Ordering::SeqCst);
27                data.clear();
28            }
29        }
30    }
31}
32
33impl<T> std::fmt::Debug for InnerMarchingBuffer<T> {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self.data.try_read() {
36            Ok(data) => {
37                f.debug_struct("InnerMarchingBuffer")
38                    .field("data_len", &data.len())
39                    .field("data_capacity", &data.capacity())
40                    .field("finished_len", &self.finished_len.load(Ordering::SeqCst))
41                    .field("readers", &self.readers.load(Ordering::SeqCst))
42                    .field("has_writer", &self.has_writer.load(Ordering::SeqCst))
43                    .field("write_offset", &self.write_offset.load(Ordering::SeqCst))
44                    .finish()
45            }
46            Err(_) => {
47                f.debug_struct("InnerMarchingBuffer")
48                    .field("data_len", &"(locked)")
49                    .field("data_capacity", &"(locked)")
50                    .field("finished_len", &self.finished_len.load(Ordering::SeqCst))
51                    .field("readers", &self.readers.load(Ordering::SeqCst))
52                    .field("has_writer", &self.has_writer.load(Ordering::SeqCst))
53                    .field("write_offset", &self.write_offset.load(Ordering::SeqCst))
54                    .finish()
55            }
56        }
57    }
58}
59
/// A growable buffer that hands out one exclusive `Writer` at a time and any
/// number of `Reader`s over regions frozen by `Writer::finish`. Cloning is
/// cheap (shared `Arc`); the storage resets once all handles are dropped.
#[derive(Clone)]
pub struct MarchingBuffer<T> {
    inner: Arc<InnerMarchingBuffer<T>>
}
64
65impl<T> MarchingBuffer<T> {
66    pub fn new() -> Self {
67        Self {
68            inner: Arc::new(InnerMarchingBuffer {
69                data: Arc::new(RwLock::new(Vec::new())),
70                finished_len: Arc::new(AtomicUsize::new(0)),
71                readers: Arc::new(AtomicUsize::new(0)),
72                has_writer: Arc::new(AtomicBool::new(false)),
73                write_offset: Arc::new(AtomicUsize::new(0))
74            })
75        }
76    }
77
78    pub fn finished_len(&self) -> usize {
79        self.inner.finished_len.load(Ordering::SeqCst)
80    }
81
82    pub fn get_writer(&self) -> Writer<T> {
83        self.try_get_writer().expect("Cannot get Writer because one already exists")
84    }
85
86    pub fn try_get_writer(&self) -> Option<Writer<T>> {
87        match self.inner.has_writer.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) {
88            Ok(_) => Some(Writer {
89                inner: self.inner.clone(),
90                write_offset: self.inner.write_offset.load(Ordering::SeqCst),
91                amount_written: 0,
92            }),
93            Err(_) => None
94        }
95    }
96}
97
98impl<T> std::fmt::Debug for MarchingBuffer<T> {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        self.inner.fmt(f)
101    }
102}
103
/// A handle to one frozen region of a `MarchingBuffer`, produced by
/// `Writer::finish`. Keeps the buffer from resetting while it is alive.
pub struct Reader<T> {
    inner: Arc<InnerMarchingBuffer<T>>,
    // Start of this Reader's region within the shared data Vec.
    read_offset: usize,
    // Length of this Reader's region.
    read_len: usize
}
109
110impl<T> Reader<T> {
111    pub fn access(&self) -> ReaderAccess<T> {
112        self.try_access().expect("Cannot access Reader because concurrent Writer is already accessed")
113    }
114
115    pub fn try_access(&self) -> Option<ReaderAccess<T>> {
116        match self.inner.data.try_read() {
117            Ok(data) => {
118                Some(ReaderAccess {
119                    reader: self,
120                    data,
121                    read_offset: self.read_offset,
122                    read_len: self.read_len,
123                })
124            },
125            Err(_) => {
126                None
127            }
128        }
129    }
130}
131
impl<T> Drop for Reader<T> {
    fn drop(&mut self) {
        // Decrement first so check_reset observes this Reader as gone; the
        // buffer is cleared only if this was the last outstanding handle.
        self.inner.readers.fetch_sub(1, Ordering::SeqCst);
        self.inner.check_reset();
    }
}
138
139impl<T> Clone for Reader<T> {
140    fn clone(&self) -> Self {
141        self.inner.readers.fetch_add(1, Ordering::SeqCst);
142        Self {
143            inner: self.inner.clone(),
144            read_offset: self.read_offset,
145            read_len: self.read_len
146        }
147    }
148}
149
150impl<T> std::fmt::Debug for Reader<T> {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        f.debug_struct("Reader")
153            .field("read_offset", &self.read_offset)
154            .field("read_len", &self.read_len)
155            .finish()
156    }
157}
158
/// A read-locked view into a `Reader`'s frozen region. Holds the shared
/// `RwLock` read guard, so no `WriterAccess` can exist concurrently.
pub struct ReaderAccess<'reader, 'data, T> {
    reader: &'reader Reader<T>,
    data: RwLockReadGuard<'data, Vec<T>>,
    // Stored in addition to the identically named fields in Reader so that ReaderAccess can implement Read, which mutates the Read struct.
    read_offset: usize,
    read_len: usize
}
166
167impl<'reader, 'data, T> ReaderAccess<'reader, 'data, T> {
168    pub fn as_slice(&self) -> &[T] {
169        &self.data[self.read_offset .. (self.read_offset + self.read_len)]
170    }
171
172    pub fn is_empty(&self) -> bool {
173        self.read_len == 0
174    }
175}
176
177impl<'reader, 'data, T> std::fmt::Debug for ReaderAccess<'reader, 'data, T> {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        f.debug_struct("ReaderAccess")
180            .field("reader_offset", &self.reader.read_offset)
181            .field("reader_len", &self.reader.read_len)
182            .field("access_offset", &self.read_offset)
183            .field("access_len", &self.read_len)
184            .finish()
185    }
186}
187
188impl<'reader, 'data> std::io::Read for ReaderAccess<'reader, 'data, u8> {
189    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
190        let amount_read = std::cmp::min(self.read_len, buf.len());
191        buf.copy_from_slice(&self.data.as_slice()[self.reader.read_offset .. (self.reader.read_offset + amount_read)]);
192        self.read_offset += amount_read;
193        self.read_len -= amount_read;
194        Ok(amount_read)
195    }
196}
197
198impl<'reader, 'data, T> Deref for ReaderAccess<'reader, 'data, T> {
199    type Target = [T];
200
201    fn deref(&self) -> &Self::Target {
202        self.as_slice()
203    }
204}
205
/// The buffer's unique writing handle; at most one exists at a time.
/// Data is appended through `WriterAccess` and becomes readable only after
/// `finish()` freezes it into a `Reader`.
pub struct Writer<T> {
    inner: Arc<InnerMarchingBuffer<T>>,
    // Where in the data Vec the writable data starts at.
    write_offset: usize,
    // How much has been written into this Write.
    amount_written: usize,
}
213
214impl<T> Writer<T> {
215    pub fn finish(&mut self) -> Reader<T> {
216        let reader = Reader {
217            inner: self.inner.clone(),
218            read_offset: self.write_offset,
219            read_len: self.amount_written,
220        };
221        self.inner.readers.fetch_add(1, Ordering::SeqCst);
222        self.inner.write_offset.fetch_add(self.amount_written, Ordering::SeqCst);
223        self.inner.finished_len.fetch_add(self.amount_written, Ordering::SeqCst);
224        self.write_offset += self.amount_written;
225        self.amount_written = 0;
226        reader
227    }
228
229    pub fn access(&mut self) -> WriterAccess<T> {
230        self.try_access().expect("Cannot access Writer because at least one concurrent Reader is already accessed")
231    }
232
233    pub fn try_access(&mut self) -> Option<WriterAccess<T>> {
234        Some(WriterAccess {
235            data: self.inner.data.try_write().ok()?,
236            write_offset: &mut self.write_offset,
237            amount_written: &mut self.amount_written,
238        })
239    }
240}
241
242impl<T: Default + Copy> Writer<T> {
243    // A necessary convenience method for copying the contents of a Reader<T> into a Writer<T>. This method internally uses a temporary buffer
244    // of size COPY_BUFFER_SIZE, which is needed because the underlying data buffer may be reallocated at any point during the copy.
245    pub fn copy_from<const COPY_BUFFER_SIZE: usize>(&mut self, reader: &Reader<T>) {
246        // Technically we would be able to void the double copying if there's sufficient capacity in the data buffer not to need a reallocation.
247        // Or, if we just ensure() that there's enough additional capacity ahead of time. We could then use std::ptr::copy_nonoverlapping, but
248        // that would require unsafe.
249        // Or, I wonder if we could use various slice split() methods to get the disjoint slices without requiring unsafe?
250        let mut copy_buffer = [T::default(); COPY_BUFFER_SIZE];
251        let mut bytes_copied = 0;
252        let mut bytes_remaining = reader.access().len();
253        while bytes_remaining > 0 {
254            let copied_this_round = std::cmp::min(bytes_remaining, 4096);
255            &mut copy_buffer[..copied_this_round].copy_from_slice(&reader.access().as_slice()[bytes_copied .. (bytes_copied + copied_this_round)]);
256            self.access().extend_from_slice(&copy_buffer[..copied_this_round]);
257            bytes_copied += copied_this_round;
258            bytes_remaining -= copied_this_round;
259        }
260    }
261}
262
impl<T> Drop for Writer<T> {
    fn drop(&mut self) {
        // Atomically flip has_writer true -> false; failure means the
        // single-writer invariant was broken, which is a library bug.
        // Note: any data written but not finish()ed is silently discarded
        // when the subsequent check_reset clears the buffer.
        self.inner.has_writer.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
            .expect("has_writer was false somehow when Writer was dropped");
        self.inner.check_reset();
    }
}
270
271impl<T> std::fmt::Debug for Writer<T> {
272    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273        f.debug_struct("Writer")
274            .field("write_offset", &self.write_offset)
275            .field("amount_written", &self.amount_written)
276            .finish()
277    }
278}
279
/// A write-locked view over the buffer's unfinished region. Holds the shared
/// `RwLock` write guard, so no `ReaderAccess` can exist concurrently. The
/// cursor fields borrow the owning `Writer` mutably so pushes/pops update it.
pub struct WriterAccess<'writer, 'data, T> {
    data: RwLockWriteGuard<'data, Vec<T>>,
    write_offset: &'writer mut usize,
    amount_written: &'writer mut usize,
}
285
286impl<'writer, 'data, T> WriterAccess<'writer, 'data, T> {
287    pub fn as_slice(&self) -> &[T] {
288        &self.data[*self.write_offset .. (*self.write_offset + *self.amount_written)]
289    }
290
291    pub fn as_mut_slice(&mut self) -> &mut [T] {
292        &mut self.data[*self.write_offset .. (*self.write_offset + *self.amount_written)]
293    }
294
295    pub fn push(&mut self, value: T) {
296        self.data.push(value);
297        *self.amount_written += 1;
298    }
299
300    pub fn pop(&mut self) -> Option<T> {
301        if *self.amount_written > 0 {
302            *self.amount_written -= 1;
303            self.data.pop()
304        } else {
305            None
306        }
307    }
308}
309
310impl<'writer, 'data, T: Clone> WriterAccess<'writer, 'data, T> {
311    pub fn extend_from_slice(&mut self, slice: &[T]) {
312        self.data.extend_from_slice(slice);
313        *self.amount_written += slice.len();
314    }
315}
316
317impl<'writer, 'data> std::fmt::Write for WriterAccess<'writer, 'data, u8> {
318    fn write_str(&mut self, s: &str) -> std::fmt::Result {
319        self.data.extend_from_slice(s.as_bytes());
320        *self.amount_written += s.len();
321        Ok(())
322    }
323}
324
325impl<'writer, 'data> std::io::Write for WriterAccess<'writer, 'data, u8> {
326    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
327        self.data.extend_from_slice(buf);
328        *self.amount_written += buf.len();
329        Ok(buf.len())
330    }
331
332    fn flush(&mut self) -> std::io::Result<()> {
333        Ok(())
334    }
335}
336
337impl<'writer, 'data, T> Deref for WriterAccess<'writer, 'data, T> {
338    type Target = [T];
339
340    fn deref(&self) -> &Self::Target {
341        self.as_slice()
342    }
343}
344
345impl<'writer, 'data, T> DerefMut for WriterAccess<'writer, 'data, T> {
346    fn deref_mut(&mut self) -> &mut Self::Target {
347        self.as_mut_slice()
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Write;

    /// Interleaves several write/finish cycles and checks that finished_len
    /// marches forward, then drops everything and checks the buffer resets.
    #[test]
    fn basic_nesting_test() {
        let buffer = MarchingBuffer::new();
        {
            let mut writer = buffer.get_writer();

            write!(writer.access(), "Hello world").unwrap();
            let hello_world = writer.finish();
            assert_eq!(b"Hello world", hello_world.access().as_slice());
            assert_eq!("Hello world".len(), buffer.finished_len());

            write!(writer.access(), "Foo").unwrap();
            // "Foo" not counted to length until the current write is finished.
            assert_eq!("Hello world".len(), buffer.finished_len());

            write!(writer.access(), "Bar").unwrap();
            let foo_bar = writer.finish();
            assert_eq!(b"FooBar", foo_bar.access().as_slice());
            assert_eq!("Hello world".len() + "FooBar".len(), buffer.finished_len());

            write!(writer.access(), "End of line").unwrap();
            writer.finish();
            assert_eq!(
                "Hello world".len() + "FooBar".len() + "End of line".len(),
                buffer.finished_len()
            );
        }
        // All Readers and the Writer are gone, so the buffer reset itself.
        assert_eq!(0, buffer.finished_len());
    }

    /// Data written but never finish()ed is discarded when the Writer drops.
    #[test]
    fn unfinished_writes_are_ignored() {
        let buffer = MarchingBuffer::new();
        {
            let mut writer = buffer.get_writer();
            write!(writer.access(), "Hello world").unwrap();
        }
        {
            let mut writer = buffer.get_writer();
            write!(writer.access(), "foo bar").unwrap();
            assert_eq!(b"foo bar", writer.finish().access().as_slice());
        }
        assert_eq!(0, buffer.finished_len());
    }
}
397}