1#![forbid(unsafe_code)]
2
3use std::{ops::{Deref, DerefMut}};
4use std::{sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}};
5use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6
/// Shared state behind a [`MarchingBuffer`]: the growable backing store plus
/// the atomic bookkeeping used to coordinate one writer and many readers.
#[derive(Clone)]
struct InnerMarchingBuffer<T> {
    // Backing storage; readers hold read guards, the writer holds write guards.
    data: Arc<RwLock<Vec<T>>>,
    // Total number of elements sealed by `Writer::finish` so far.
    finished_len: Arc<AtomicUsize>,
    // Count of live `Reader` handles (bumped on finish/clone, decremented on drop).
    readers: Arc<AtomicUsize>,
    // True while a `Writer` handle exists; enforces at most one writer at a time.
    has_writer: Arc<AtomicBool>,
    // Offset in `data` where the next write region begins.
    write_offset: Arc<AtomicUsize>,
    // NOTE(review): the per-field `Arc`s look redundant given the outer
    // `Arc<InnerMarchingBuffer<T>>` in `MarchingBuffer` — presumably kept so the
    // derived `Clone` works for any `T`; confirm before simplifying.
}
20
impl<T> InnerMarchingBuffer<T> {
    /// Resets the buffer to empty once no readers and no writer remain.
    ///
    /// Best-effort: `try_write` fails (and the reset is skipped) while any
    /// guard is held, which is fine because the last handle to drop calls this
    /// again. Taking the write lock *before* checking the counters ensures no
    /// new access can race the reset.
    fn check_reset(&self) {
        if let Ok(mut data) = self.data.try_write() {
            if self.readers.load(Ordering::SeqCst) == 0 && !self.has_writer.load(Ordering::SeqCst) {
                self.write_offset.store(0, Ordering::SeqCst);
                self.finished_len.store(0, Ordering::SeqCst);
                // Drops the elements but keeps the allocation for reuse.
                data.clear();
            }
        }
    }
}
32
33impl<T> std::fmt::Debug for InnerMarchingBuffer<T> {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self.data.try_read() {
36 Ok(data) => {
37 f.debug_struct("InnerMarchingBuffer")
38 .field("data_len", &data.len())
39 .field("data_capacity", &data.capacity())
40 .field("finished_len", &self.finished_len.load(Ordering::SeqCst))
41 .field("readers", &self.readers.load(Ordering::SeqCst))
42 .field("has_writer", &self.has_writer.load(Ordering::SeqCst))
43 .field("write_offset", &self.write_offset.load(Ordering::SeqCst))
44 .finish()
45 }
46 Err(_) => {
47 f.debug_struct("InnerMarchingBuffer")
48 .field("data_len", &"(locked)")
49 .field("data_capacity", &"(locked)")
50 .field("finished_len", &self.finished_len.load(Ordering::SeqCst))
51 .field("readers", &self.readers.load(Ordering::SeqCst))
52 .field("has_writer", &self.has_writer.load(Ordering::SeqCst))
53 .field("write_offset", &self.write_offset.load(Ordering::SeqCst))
54 .finish()
55 }
56 }
57 }
58}
59
/// An append-only buffer written by a single `Writer` and viewed by any number
/// of `Reader`s over finished regions; cloning shares the same storage.
#[derive(Clone)]
pub struct MarchingBuffer<T> {
    inner: Arc<InnerMarchingBuffer<T>>
}
64
65impl<T> MarchingBuffer<T> {
66 pub fn new() -> Self {
67 Self {
68 inner: Arc::new(InnerMarchingBuffer {
69 data: Arc::new(RwLock::new(Vec::new())),
70 finished_len: Arc::new(AtomicUsize::new(0)),
71 readers: Arc::new(AtomicUsize::new(0)),
72 has_writer: Arc::new(AtomicBool::new(false)),
73 write_offset: Arc::new(AtomicUsize::new(0))
74 })
75 }
76 }
77
78 pub fn finished_len(&self) -> usize {
79 self.inner.finished_len.load(Ordering::SeqCst)
80 }
81
82 pub fn get_writer(&self) -> Writer<T> {
83 self.try_get_writer().expect("Cannot get Writer because one already exists")
84 }
85
86 pub fn try_get_writer(&self) -> Option<Writer<T>> {
87 match self.inner.has_writer.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) {
88 Ok(_) => Some(Writer {
89 inner: self.inner.clone(),
90 write_offset: self.inner.write_offset.load(Ordering::SeqCst),
91 amount_written: 0,
92 }),
93 Err(_) => None
94 }
95 }
96}
97
impl<T> std::fmt::Debug for MarchingBuffer<T> {
    // Delegates to the inner state's `Debug`, which reports lengths/counters.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.inner.fmt(f)
    }
}
103
/// A handle to one finished region of the buffer, produced by `Writer::finish`.
/// Every live `Reader` is counted in `InnerMarchingBuffer::readers`.
pub struct Reader<T> {
    inner: Arc<InnerMarchingBuffer<T>>,
    // Start of this reader's region within the shared `data` vector.
    read_offset: usize,
    // Length of this reader's region.
    read_len: usize
}
109
110impl<T> Reader<T> {
111 pub fn access(&self) -> ReaderAccess<T> {
112 self.try_access().expect("Cannot access Reader because concurrent Writer is already accessed")
113 }
114
115 pub fn try_access(&self) -> Option<ReaderAccess<T>> {
116 match self.inner.data.try_read() {
117 Ok(data) => {
118 Some(ReaderAccess {
119 reader: self,
120 data,
121 read_offset: self.read_offset,
122 read_len: self.read_len,
123 })
124 },
125 Err(_) => {
126 None
127 }
128 }
129 }
130}
131
impl<T> Drop for Reader<T> {
    fn drop(&mut self) {
        // Decrement the live-reader count, then give the buffer a chance to
        // reset itself if this was the last outstanding handle.
        self.inner.readers.fetch_sub(1, Ordering::SeqCst);
        self.inner.check_reset();
    }
}
138
impl<T> Clone for Reader<T> {
    /// Clones view the same region. The live-reader count is bumped *before*
    /// constructing the clone so the buffer cannot reset in between.
    fn clone(&self) -> Self {
        self.inner.readers.fetch_add(1, Ordering::SeqCst);
        Self {
            inner: self.inner.clone(),
            read_offset: self.read_offset,
            read_len: self.read_len
        }
    }
}
149
impl<T> std::fmt::Debug for Reader<T> {
    // Shows only this reader's window; the shared data lock is not touched.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Reader")
            .field("read_offset", &self.read_offset)
            .field("read_len", &self.read_len)
            .finish()
    }
}
158
/// An RAII read guard over a reader's region. Holding this keeps the `RwLock`
/// read-locked, which blocks `Writer::access` until it is dropped.
pub struct ReaderAccess<'reader, 'data, T> {
    reader: &'reader Reader<T>,
    data: RwLockReadGuard<'data, Vec<T>>,
    // Cursor within the region; advanced by the `io::Read` impl.
    read_offset: usize,
    // Elements remaining from the cursor to the end of the region.
    read_len: usize
}
166
167impl<'reader, 'data, T> ReaderAccess<'reader, 'data, T> {
168 pub fn as_slice(&self) -> &[T] {
169 &self.data[self.read_offset .. (self.read_offset + self.read_len)]
170 }
171
172 pub fn is_empty(&self) -> bool {
173 self.read_len == 0
174 }
175}
176
impl<'reader, 'data, T> std::fmt::Debug for ReaderAccess<'reader, 'data, T> {
    // Reports both the reader's fixed window and this access's (possibly
    // advanced) cursor, so partially-consumed accesses are distinguishable.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReaderAccess")
            .field("reader_offset", &self.reader.read_offset)
            .field("reader_len", &self.reader.read_len)
            .field("access_offset", &self.read_offset)
            .field("access_len", &self.read_len)
            .finish()
    }
}
187
188impl<'reader, 'data> std::io::Read for ReaderAccess<'reader, 'data, u8> {
189 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
190 let amount_read = std::cmp::min(self.read_len, buf.len());
191 buf.copy_from_slice(&self.data.as_slice()[self.reader.read_offset .. (self.reader.read_offset + amount_read)]);
192 self.read_offset += amount_read;
193 self.read_len -= amount_read;
194 Ok(amount_read)
195 }
196}
197
impl<'reader, 'data, T> Deref for ReaderAccess<'reader, 'data, T> {
    type Target = [T];

    // Deref to the remaining readable slice so `&*access` and all slice
    // methods work directly on the guard.
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}
205
/// The unique writing handle; at most one exists per buffer (enforced via
/// `has_writer`). Dropping it clears the flag and may reset the buffer.
pub struct Writer<T> {
    inner: Arc<InnerMarchingBuffer<T>>,
    // Start of the current in-progress region within `data`.
    write_offset: usize,
    // Number of elements appended since the last `finish`.
    amount_written: usize,
}
213
impl<T> Writer<T> {
    /// Seals everything written since the last `finish` into a `Reader` over
    /// that region, then starts a fresh, empty region after it.
    pub fn finish(&mut self) -> Reader<T> {
        let reader = Reader {
            inner: self.inner.clone(),
            read_offset: self.write_offset,
            read_len: self.amount_written,
        };
        // Register the reader before publishing the new offsets, so the
        // region cannot be reset out from under the freshly created handle.
        self.inner.readers.fetch_add(1, Ordering::SeqCst);
        self.inner.write_offset.fetch_add(self.amount_written, Ordering::SeqCst);
        self.inner.finished_len.fetch_add(self.amount_written, Ordering::SeqCst);
        self.write_offset += self.amount_written;
        self.amount_written = 0;
        reader
    }

    /// Locks the shared data for writing.
    ///
    /// Panics if any `ReaderAccess` guard is currently live; use
    /// `try_access` for the non-panicking variant.
    pub fn access(&mut self) -> WriterAccess<T> {
        self.try_access().expect("Cannot access Writer because at least one concurrent Reader is already accessed")
    }

    /// Locks the shared data for writing; `None` if any read guard is held.
    pub fn try_access(&mut self) -> Option<WriterAccess<T>> {
        Some(WriterAccess {
            data: self.inner.data.try_write().ok()?,
            write_offset: &mut self.write_offset,
            amount_written: &mut self.amount_written,
        })
    }
}
241
242impl<T: Default + Copy> Writer<T> {
243 pub fn copy_from<const COPY_BUFFER_SIZE: usize>(&mut self, reader: &Reader<T>) {
246 let mut copy_buffer = [T::default(); COPY_BUFFER_SIZE];
251 let mut bytes_copied = 0;
252 let mut bytes_remaining = reader.access().len();
253 while bytes_remaining > 0 {
254 let copied_this_round = std::cmp::min(bytes_remaining, 4096);
255 &mut copy_buffer[..copied_this_round].copy_from_slice(&reader.access().as_slice()[bytes_copied .. (bytes_copied + copied_this_round)]);
256 self.access().extend_from_slice(©_buffer[..copied_this_round]);
257 bytes_copied += copied_this_round;
258 bytes_remaining -= copied_this_round;
259 }
260 }
261}
262
impl<T> Drop for Writer<T> {
    fn drop(&mut self) {
        // Release the writer slot; a failed exchange would mean the
        // single-writer invariant was already broken.
        // NOTE(review): panicking inside `drop` aborts the process if we are
        // already unwinding — consider a `debug_assert!` instead; confirm
        // intended behavior.
        self.inner.has_writer.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
            .expect("has_writer was false somehow when Writer was dropped");
        self.inner.check_reset();
    }
}
270
impl<T> std::fmt::Debug for Writer<T> {
    // Shows only the writer's local cursor state; the data lock is not taken.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Writer")
            .field("write_offset", &self.write_offset)
            .field("amount_written", &self.amount_written)
            .finish()
    }
}
279
/// An RAII write guard; holds the `RwLock` write lock (blocking all readers)
/// and updates the owning `Writer`'s counters through the `&mut` references.
pub struct WriterAccess<'writer, 'data, T> {
    data: RwLockWriteGuard<'data, Vec<T>>,
    write_offset: &'writer mut usize,
    amount_written: &'writer mut usize,
}
285
286impl<'writer, 'data, T> WriterAccess<'writer, 'data, T> {
287 pub fn as_slice(&self) -> &[T] {
288 &self.data[*self.write_offset .. (*self.write_offset + *self.amount_written)]
289 }
290
291 pub fn as_mut_slice(&mut self) -> &mut [T] {
292 &mut self.data[*self.write_offset .. (*self.write_offset + *self.amount_written)]
293 }
294
295 pub fn push(&mut self, value: T) {
296 self.data.push(value);
297 *self.amount_written += 1;
298 }
299
300 pub fn pop(&mut self) -> Option<T> {
301 if *self.amount_written > 0 {
302 *self.amount_written -= 1;
303 self.data.pop()
304 } else {
305 None
306 }
307 }
308}
309
impl<'writer, 'data, T: Clone> WriterAccess<'writer, 'data, T> {
    /// Clones every element of `slice` onto the end of the in-progress region.
    pub fn extend_from_slice(&mut self, slice: &[T]) {
        self.data.extend_from_slice(slice);
        *self.amount_written += slice.len();
    }
}
316
impl<'writer, 'data> std::fmt::Write for WriterAccess<'writer, 'data, u8> {
    /// Appends the UTF-8 bytes of `s`, enabling `write!` into a byte buffer.
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        self.data.extend_from_slice(s.as_bytes());
        // `str::len` is the byte length, matching what was just appended.
        *self.amount_written += s.len();
        Ok(())
    }
}
324
impl<'writer, 'data> std::io::Write for WriterAccess<'writer, 'data, u8> {
    /// Appends all of `buf`; never performs a short write and never fails.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.data.extend_from_slice(buf);
        *self.amount_written += buf.len();
        Ok(buf.len())
    }

    // Nothing is buffered beyond the Vec itself, so flushing is a no-op.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
336
impl<'writer, 'data, T> Deref for WriterAccess<'writer, 'data, T> {
    type Target = [T];

    // Deref to the unfinished portion only — finished regions stay hidden.
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}
344
impl<'writer, 'data, T> DerefMut for WriterAccess<'writer, 'data, T> {
    // Mutable counterpart of `Deref`: only the unfinished region is exposed.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}
350
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Write;

    // Exercises the write/finish cycle: finished regions accumulate while any
    // handle is alive, and the buffer resets to empty once all handles drop.
    #[test]
    fn basic_nesting_test() {
        let alloc = MarchingBuffer::new();
        {
            let mut writer = alloc.get_writer();

            write!(writer.access(), "Hello world").unwrap();
            let hello_world = writer.finish();
            assert_eq!(b"Hello world", hello_world.access().as_slice());
            assert_eq!("Hello world".len(), alloc.finished_len());

            write!(writer.access(), "Foo").unwrap();
            // Unfinished bytes are not counted until `finish` is called.
            assert_eq!("Hello world".len(), alloc.finished_len());

            write!(writer.access(), "Bar").unwrap();
            let foo_bar = writer.finish();
            assert_eq!(b"FooBar", foo_bar.access().as_slice());
            assert_eq!("Hello world".len() + "FooBar".len(), alloc.finished_len());

            write!(writer.access(), "End of line").unwrap();
            writer.finish();
            assert_eq!("Hello world".len() + "FooBar".len() + "End of line".len(), alloc.finished_len());
        }
        // Writer and all readers dropped above -> check_reset cleared the buffer.
        assert_eq!(0, alloc.finished_len());
    }

    // Bytes written but never passed to `finish` are discarded when the
    // writer handle drops, so the next writer starts from a clean buffer.
    #[test]
    fn unfinished_writes_are_ignored() {
        let alloc = MarchingBuffer::new();
        {
            let mut writer = alloc.get_writer();
            write!(writer.access(), "Hello world").unwrap();
        }
        {
            let mut writer = alloc.get_writer();
            write!(writer.access(), "foo bar").unwrap();
            assert_eq!(b"foo bar", writer.finish().access().as_slice());
        }
        assert_eq!(0, alloc.finished_len());
    }
}