1use arc_swap::ArcSwap;
2use std::alloc::{self, Layout};
3use std::ops::Deref;
4use std::ptr::NonNull;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::Arc;
7
/// Fixed-capacity heap buffer shared between a [`CowVecWriter`] and its readers.
///
/// `ptr` points at room for `cap` elements of `T` (or is dangling when `cap == 0`), and the
/// first `len` of those slots are initialized. `len` is atomic so the writer can publish
/// newly appended elements to concurrent readers without a lock.
struct RawBuf<T> {
    /// Start of the allocation; dangling (but well-aligned) when `cap == 0`.
    ptr: NonNull<T>,
    /// Number of initialized elements at the front of the allocation.
    len: AtomicUsize,
    /// Number of `T` slots the allocation can hold.
    cap: usize,
}

impl<T> RawBuf<T> {
    /// Wraps an existing allocation without touching its contents.
    ///
    /// The caller must ensure `ptr` points at `cap` slots of `T` obtained from the global
    /// allocator with `Layout::array::<T>(cap)` (or is dangling with `cap == 0`), and that the
    /// first `len` slots are initialized; `Drop` relies on both.
    #[inline]
    const fn new(ptr: NonNull<T>, len: usize, cap: usize) -> Self {
        Self {
            ptr,
            len: AtomicUsize::new(len),
            cap,
        }
    }

    /// A buffer that owns no allocation: dangling pointer, zero length, zero capacity.
    #[inline]
    const fn empty() -> Self {
        Self::new(NonNull::dangling(), 0, 0)
    }

    /// Allocates room for `cap` elements and records `init_len` of them as initialized.
    ///
    /// The memory is left uninitialized; callers must fill the first `init_len` slots before
    /// anything reads them. A `cap` of zero performs no allocation and returns
    /// [`RawBuf::empty`].
    ///
    /// # Panics
    ///
    /// Panics if `T` is zero-sized while `cap > 0` (requesting a zero-byte block from the
    /// global allocator is undefined behaviour), or if `cap` elements of `T` would exceed
    /// `isize::MAX` bytes. Calls [`alloc::handle_alloc_error`] if the allocator fails.
    #[inline]
    fn allocate(init_len: usize, cap: usize) -> Self {
        debug_assert!(init_len <= cap, "initialized length exceeds capacity");
        if cap == 0 {
            return Self::empty();
        }

        assert!(
            std::mem::size_of::<T>() != 0,
            "RawBuf cannot allocate zero-sized types"
        );

        // `Layout::array` already rejects totals above `isize::MAX` bytes (the limit the global
        // allocator requires), so no separate size check is needed.
        let layout = Layout::array::<T>(cap).expect("Allocation too large");

        // SAFETY: `layout` has a non-zero size because `cap > 0` and `T` is not zero-sized.
        let raw = unsafe { alloc::alloc(layout) };
        let Some(ptr) = NonNull::new(raw.cast::<T>()) else {
            alloc::handle_alloc_error(layout)
        };

        Self::new(ptr, init_len, cap)
    }
}
54
55impl<T> RawBuf<T>
56where
57 T: Copy,
58{
59 fn allocate_copy(&self, len: usize, new_cap: Option<usize>) -> Self {
61 let new_cap = new_cap.unwrap_or((self.cap * 2).max(1));
62 debug_assert!(new_cap >= self.cap);
63
64 let new_buf = Self::allocate(len, new_cap);
65 if self.cap != 0 {
66 let old_ptr = self.ptr.as_ptr().cast::<u8>();
67 let new_ptr = new_buf.ptr.as_ptr().cast::<u8>();
69 let old_layout_len = Layout::array::<T>(len).unwrap();
71 unsafe { std::ptr::copy_nonoverlapping(old_ptr, new_ptr, old_layout_len.size()) };
72 }
73 new_buf
74 }
75}
76
77impl<T> Deref for RawBuf<T> {
78 type Target = NonNull<T>;
79
80 #[inline]
81 fn deref(&self) -> &Self::Target {
82 &self.ptr
83 }
84}
85
// SAFETY: `RawBuf` owns the allocation behind `ptr` (it is freed in `Drop`), and its other
// fields are an `AtomicUsize` and a `usize`, so moving it to another thread is like moving a
// `Vec<T>`: sound whenever `T` itself is `Send`.
unsafe impl<T: Send> Send for RawBuf<T> {}
// SAFETY: through `&RawBuf<T>`, safe code can only read `cap`, access the atomic `len`, and
// copy the `NonNull` pointer, none of which is a data race. Element reads/writes through that
// pointer happen in `unsafe` code elsewhere in this file, which must synchronize them itself
// (the writer's `Release` store of `len`, paired with readers' Acquire-or-stronger loads).
// NOTE(review): readers copy out `T`s produced by the writer on another thread, which morally
// also needs `T: Send`; confirm the public handles only become shareable when `T: Send + Sync`.
unsafe impl<T: Sync> Sync for RawBuf<T> {}
88
89impl<T> Drop for RawBuf<T> {
90 fn drop(&mut self) {
91 let cap = self.cap;
92 if cap != 0 {
93 unsafe {
95 std::ptr::drop_in_place(std::ptr::slice_from_raw_parts_mut(
96 self.ptr.as_ptr(),
97 self.len.load(Ordering::Relaxed),
98 ));
99 }
100 unsafe {
101 alloc::dealloc(
102 self.ptr.as_ptr().cast::<u8>(),
103 Layout::array::<T>(cap).unwrap(),
104 );
105 }
106 }
107 }
108}
109
110pub struct CowVecWriter<T> {
112 buf: Arc<ArcSwap<RawBuf<T>>>,
113}
114
115impl<T> CowVecWriter<T>
116where
117 T: Copy,
118{
119 pub fn push(&mut self, elem: T) {
123 let buf = self.buf.load();
124 let len = buf.len.load(Ordering::Acquire);
125 let cap = buf.cap;
126
127 let push_inner = move |buf: &RawBuf<T>| {
128 unsafe { std::ptr::write(buf.ptr.as_ptr().add(len), elem) }
129 buf.len.store(len + 1, Ordering::Release);
130 };
131
132 if len == cap {
133 push_inner(&self.grow(&buf, len, None))
135 } else {
136 push_inner(&buf)
137 }
138 }
139
140 #[allow(dead_code)]
145 pub fn insert(&mut self, index: usize, elem: T) {
146 let buf = self.buf.load();
151 let len = buf.len.load(Ordering::Acquire);
152
153 assert!(index <= len, "index out of bounds");
154 let mut new_buf = if buf.cap == len {
155 buf.allocate_copy(index, None)
156 } else {
157 buf.allocate_copy(index, Some(buf.cap))
158 };
159
160 unsafe {
161 std::ptr::copy_nonoverlapping(
163 buf.as_ptr().add(index),
164 new_buf.as_ptr().add(index + 1),
165 len - index,
166 );
167 std::ptr::write(new_buf.as_ptr().add(index), elem);
168 }
169
170 *new_buf.len.get_mut() = len + 1;
171
172 self.buf.store(Arc::new(new_buf))
173 }
174
175 pub fn reserve(&mut self, additional: usize) {
181 let buf = self.buf.load();
182 let len = buf.len.load(Ordering::Acquire);
183 if len.saturating_add(additional) > buf.cap {
184 self.grow(&buf, len, Some(buf.cap + additional));
185 }
186 }
187
188 fn grow(&mut self, buf: &RawBuf<T>, len: usize, new_cap: Option<usize>) -> Arc<RawBuf<T>> {
190 let ret = Arc::new(buf.allocate_copy(len, new_cap));
191 self.buf.store(ret.clone());
192 ret
193 }
194}
195
196impl<T> Deref for CowVecWriter<T> {
197 type Target = [T];
198
199 #[inline]
200 fn deref(&self) -> &Self::Target {
201 let buf = self.buf.load();
205 let len = buf.len.load(Ordering::SeqCst);
206 unsafe { std::slice::from_raw_parts(buf.as_ptr(), len) }
207 }
208}
209
210#[derive(Clone)]
221pub struct CowVec<T> {
222 buf: Arc<ArcSwap<RawBuf<T>>>,
223}
224
225impl<T> CowVec<T> {
226 #[inline]
230 pub fn new() -> (Self, CowVecWriter<T>) {
231 assert!(std::mem::size_of::<T>() != 0);
232 let buf = Arc::new(ArcSwap::from_pointee(RawBuf::empty()));
233 (Self { buf: buf.clone() }, CowVecWriter { buf })
234 }
235
236 #[allow(dead_code)]
242 pub fn with_capacity(cap: usize) -> (Self, CowVecWriter<T>) {
243 assert!(std::mem::size_of::<T>() != 0);
244 let buf = Arc::new(ArcSwap::from_pointee(RawBuf::allocate(0, cap)));
245 (Self { buf: buf.clone() }, CowVecWriter { buf })
246 }
247
248 #[inline]
250 pub fn empty() -> Self {
251 Self::new().0
252 }
253
254 pub fn len(&self) -> usize {
256 self.read(|slice| slice.len())
257 }
258
259 pub fn is_empty(&self) -> bool {
261 self.len() == 0
262 }
263
264 #[inline(always)]
265 fn read<F, R>(&self, cb: F) -> R
266 where
267 F: FnOnce(&[T]) -> R,
268 {
269 let buf = self.buf.load();
270 let len = buf.len.load(Ordering::SeqCst);
271 cb(unsafe { std::slice::from_raw_parts(buf.as_ptr(), len) })
272 }
273
274 pub fn snapshot(&self) -> CowVecSnapshot<T> {
279 let buf = self.buf.load_full();
280 CowVecSnapshot {
281 len: buf.len.load(Ordering::SeqCst),
282 buf,
283 }
284 }
285}
286
287impl<T> CowVec<T>
288where
289 T: Copy,
290{
291 pub fn get(&self, index: usize) -> Option<T> {
293 self.read(|slice| slice.get(index).copied())
294 }
295
296 #[allow(dead_code)]
298 pub unsafe fn get_unchecked(&self, index: usize) -> T {
299 self.get(index).unwrap_unchecked()
300 }
301}
302
/// Builds a `CowVec` together with its writer, in the spirit of `vec!`.
///
/// `cowvec![]` expands to `CowVec::new()`. `cowvec![a, b, c]` additionally pushes each
/// element through the writer. Both forms evaluate to a `(CowVec<T>, CowVecWriter<T>)` pair.
// NOTE(review): the two arms name different module paths (`$crate::vec` vs
// `$crate::cowvec`); at most one can match where this file lives — confirm and unify.
#[macro_export]
macro_rules! cowvec {
    () => (
        $crate::vec::CowVec::new()
    );
    ($($x:expr),+ $(,)?) => ({
        // `CowVec::new` returns `(reader, writer)`; only the writer can push.
        let (vec, mut writer) = $crate::cowvec::CowVec::new();
        $(writer.push($x);)+
        (vec, writer)
    });
}
314
315impl<T: Copy> From<Vec<T>> for CowVec<T> {
316 fn from(vec: Vec<T>) -> Self {
317 let mut me = std::mem::ManuallyDrop::new(vec);
318 let (ptr, len, cap) = (me.as_mut_ptr(), me.len(), me.capacity());
319
320 Self {
321 buf: Arc::new(ArcSwap::from_pointee(RawBuf::new(
322 NonNull::new(ptr).unwrap(),
323 len,
324 cap,
325 ))),
326 }
327 }
328}
329
330pub struct CowVecSnapshot<T> {
331 buf: Arc<RawBuf<T>>,
332 len: usize,
333}
334
335impl<T> CowVecSnapshot<T>
336where
337 T: Copy,
338{
339 pub fn get(&self, index: usize) -> Option<T> {
341 self.deref().get(index).copied()
342 }
343
344 pub unsafe fn get_unchecked(&self, index: usize) -> T {
346 self.get(index).unwrap_unchecked()
347 }
348
349 pub fn as_slice(&self) -> &[T] {
353 self
354 }
355}
356
357impl<T> Deref for CowVecSnapshot<T> {
358 type Target = [T];
359
360 #[inline(always)]
361 fn deref(&self) -> &Self::Target {
362 let buf = &self.buf;
364 let len = self.len;
365 unsafe { std::slice::from_raw_parts(buf.as_ptr(), len) }
366 }
367}
368
#[cfg(test)]
mod test {
    use super::{CowVec, CowVecWriter};
    use std::thread::{self, JoinHandle};
    use std::time::Duration;

    /// Runs the writer on its own thread: ten rounds of pushing `0..1000`, each round
    /// followed by a 100ms pause.
    fn spawn_writer(mut writer: CowVecWriter<usize>) -> JoinHandle<()> {
        thread::spawn(move || {
            for _ in 0..10 {
                (0..1000).for_each(|i| writer.push(i));
                thread::sleep(Duration::from_millis(100));
            }
        })
    }

    /// Address of the buffer the vector currently publishes.
    fn base_ptr<T>(v: &CowVec<T>) -> *mut T {
        v.buf.load().as_ptr()
    }

    #[test]
    fn test_miri_push_and_access() {
        let (arr, mut writer) = CowVec::new();
        (0..10000).for_each(|i| writer.push(i));
        for i in 0..10000 {
            assert_eq!(arr.get(i), Some(i));
        }
    }

    #[test]
    fn test_miri_push_and_concurrent_access() {
        let (arr, writer) = CowVec::new();
        let handle = spawn_writer(writer);

        while !handle.is_finished() {
            let observed = arr.len();
            for i in 0..observed {
                assert_eq!(arr.get(i), Some(i % 1000));
            }
        }

        handle.join().unwrap();
    }

    #[test]
    fn test_miri_push_and_concurrent_access_snapshot() {
        let (arr, writer) = CowVec::new();
        let handle = spawn_writer(writer);

        while !handle.is_finished() {
            let snap = arr.snapshot();
            for v in snap.iter().copied() {
                assert_eq!(snap[v], v);
            }
        }

        handle.join().unwrap();
    }

    #[test]
    fn test_miri_clone() {
        let (arr, mut writer) = CowVec::new();
        (0..10).for_each(|i| writer.push(i));

        let cloned_arr = arr.clone();
        assert_eq!(cloned_arr.len(), arr.len());
        for i in 0..10 {
            assert_eq!(cloned_arr.get(i), arr.get(i));
        }

        writer.push(10);
        assert_eq!(cloned_arr.get(10), arr.get(10));
        assert_eq!(cloned_arr.len(), arr.len());
    }

    #[test]
    fn test_miri_deref() {
        let (arr, mut writer) = CowVec::new();
        (0..10).for_each(|i| writer.push(i));

        let snap = arr.snapshot();
        let slice: &[i32] = &snap;
        assert_eq!(slice.len(), arr.len());
        for i in 0..10 {
            let expected = arr.get(i);
            assert_eq!(slice.get(i).copied(), expected);
            assert_eq!(snap.get(i), expected);
        }
    }

    #[test]
    fn test_miri_with_capacity() {
        let (arr, mut writer) = CowVec::with_capacity(100);
        let before = base_ptr(&arr);
        (0..100).for_each(|i| writer.push(i));
        let full = base_ptr(&arr);
        assert_eq!(before, full);

        writer.push(100);
        assert_ne!(full, base_ptr(&arr));
    }

    #[test]
    fn test_miri_reserve() {
        let (arr, mut writer) = CowVec::new();
        writer.reserve(100);
        let before = base_ptr(&arr);
        (0..100).for_each(|i| writer.push(i));
        let full = base_ptr(&arr);
        assert_eq!(before, full);

        writer.push(100);
        assert_ne!(full, base_ptr(&arr));
    }

    #[test]
    fn test_miri_insert() {
        let (arr, mut writer) = CowVec::new();
        (0..100).step_by(10).for_each(|i| writer.push(i));

        let check = |expected: &[i32]| {
            for (i, &want) in expected.iter().enumerate() {
                assert_eq!(arr.get(i), Some(want));
            }
        };

        check(&[0, 10, 20, 30, 40, 50, 60, 70, 80, 90]);

        writer.insert(1, 5);
        check(&[0, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90]);

        writer.insert(1, 5);
        check(&[0, 5, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90]);

        writer.insert(12, 100);
        check(&[0, 5, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]);

        writer.insert(0, 1);
        check(&[1, 0, 5, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]);
    }
}