use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, MutexGuard};
use memmap2::MmapMut;
/// A fixed-capacity, append-only byte buffer backed by an anonymous memory map.
///
/// Any number of threads may read the published prefix via [`Buffer::read`]
/// while at most one thread at a time appends through a [`BufferWrite`]
/// obtained from [`Buffer::write`].
///
/// Invariants:
/// - `filled <= capacity`, and `filled` only grows.
/// - Bytes `[..filled]` are never written again once published.
/// - Only the holder of `lock` touches bytes `[filled..]`.
///
/// All access to the mapping goes through `ptr`, with slices built over exactly
/// those disjoint ranges. A reader's `&[u8]` therefore never overlaps a
/// writer's `&mut [u8]`.
pub(crate) struct Buffer {
    /// Owns the mapping so it is unmapped when the `Buffer` is dropped.
    /// It is never accessed after construction; all access goes through `ptr`.
    /// The `UnsafeCell` is not needed for soundness (the bytes live behind a raw
    /// pointer), but it marks that they are mutated through a shared `&Buffer`.
    _mmap: UnsafeCell<MmapMut>,
    /// Base address of the mapping, captured once in [`Buffer::new`].
    ptr: *mut u8,
    /// Length of the mapping in bytes.
    capacity: usize,
    /// Number of bytes published to readers.
    filled: AtomicUsize,
    /// Serializes writers.
    lock: Mutex<()>,
}

// SAFETY: `ptr` points into the mapping owned by `_mmap`, which lives exactly
// as long as the `Buffer` and is not tied to any thread. Cross-thread access
// follows the invariants documented on `Buffer`:
// - readers only see `[..filled]`;
// - `lock` admits a single writer, and that writer only touches `[filled..]`;
// - `filled` is advanced with a `SeqCst` RMW after the writes it covers, and
//   readers load it with `SeqCst` before reading.
unsafe impl Send for Buffer {}
unsafe impl Sync for Buffer {}

/// Exclusive write access to the unpublished tail of a [`Buffer`].
///
/// Dereferences to the bytes `[filled..capacity]`. Nothing written becomes
/// visible to readers until [`BufferWrite::written`] publishes it. If the
/// handle is dropped without calling it, the bytes stay unpublished and the
/// next writer starts at the same offset.
pub(crate) struct BufferWrite<'buffer> {
    buffer: &'buffer Buffer,
    _guard: MutexGuard<'buffer, ()>,
}

impl Buffer {
    /// Creates a buffer backed by an anonymous mapping of `capacity` bytes.
    ///
    /// # Panics
    ///
    /// Panics if the anonymous mapping cannot be created.
    pub(crate) fn new(capacity: usize) -> Buffer {
        let mut mmap =
            MmapMut::map_anon(capacity).expect("Buffer::new: failed to create anonymous memory map");
        // Take the base pointer while we still own the map mutably. Moving
        // `mmap` into the struct below does not move the mapped memory.
        let ptr = mmap.as_mut_ptr();
        let capacity = mmap.len();
        Buffer {
            _mmap: UnsafeCell::new(mmap),
            ptr,
            capacity,
            filled: AtomicUsize::new(0),
            lock: Mutex::new(()),
        }
    }

    /// Acquires exclusive write access, blocking until any other writer is done.
    ///
    /// A previous writer that panicked does not make the buffer unusable. The
    /// lock guards no data of its own, and `filled` is only ever changed by a
    /// single atomic update, so the buffer is consistent even after such a
    /// panic. Lock poisoning is therefore ignored.
    pub(crate) fn write(&self) -> BufferWrite<'_> {
        BufferWrite {
            buffer: self,
            _guard: self
                .lock
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner),
        }
    }

    /// Returns the bytes published so far.
    ///
    /// Published bytes are never modified, so the slice stays valid and
    /// unchanged while writers keep appending.
    pub(crate) fn read(&self) -> &[u8] {
        let end = self.filled.load(Ordering::SeqCst);
        // SAFETY: `end <= capacity`, because `BufferWrite::written` enforces it,
        // so the range lies inside the mapping. Bytes below `filled` are never
        // written again, since writers only touch `[filled..]`. The `SeqCst`
        // load synchronizes with the update in `written`, so those bytes are
        // visible here.
        unsafe { std::slice::from_raw_parts(self.ptr, end) }
    }

    /// Returns the number of published bytes, i.e. the length `read()` would
    /// return right now.
    pub(crate) fn available(&self) -> usize {
        self.filled.load(Ordering::SeqCst)
    }
}

impl<'buffer> BufferWrite<'buffer> {
    /// Offset of the first unpublished byte.
    ///
    /// This cannot change while this handle holds the lock, because only
    /// `written` advances it.
    fn start(&self) -> usize {
        self.buffer.filled.load(Ordering::SeqCst)
    }

    /// Publishes the first `len` bytes of this write to readers and releases
    /// the lock.
    ///
    /// # Panics
    ///
    /// Panics if `len` exceeds the writable space (the length of this handle's
    /// slice).
    pub(crate) fn written(self, len: usize) {
        let remaining = self.buffer.capacity - self.start();
        // Catch the overrun here; otherwise `filled` would pass `capacity` and
        // break `read()` later.
        assert!(
            len <= remaining,
            "BufferWrite::written: len {} exceeds the {} writable bytes",
            len,
            remaining
        );
        self.buffer.filled.fetch_add(len, Ordering::SeqCst);
    }
}

impl<'buffer> Deref for BufferWrite<'buffer> {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        let start = self.start();
        // SAFETY: `start <= capacity`, so `[start..capacity]` lies inside the
        // mapping (`add` may yield one-past-the-end with length 0). Only the
        // lock holder touches these bytes, and borrowck prevents an overlapping
        // `deref_mut` on this same handle.
        unsafe {
            std::slice::from_raw_parts(self.buffer.ptr.add(start), self.buffer.capacity - start)
        }
    }
}

impl<'buffer> DerefMut for BufferWrite<'buffer> {
    fn deref_mut(&mut self) -> &mut [u8] {
        let start = self.start();
        // SAFETY: same range argument as `deref`. This range is disjoint from
        // `[..filled]`, which is the only part readers ever borrow. We hold the
        // lock, so no other writer has a view of it, and `&mut self` rules out
        // other borrows from this handle.
        unsafe {
            std::slice::from_raw_parts_mut(self.buffer.ptr.add(start), self.buffer.capacity - start)
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Appends from the main thread, then from two racing threads, then from the
    /// main thread again. Checks that every byte was published exactly once, in
    /// write order.
    #[test]
    fn test_buffer_write() {
        let buffer = Arc::new(Buffer::new(20));

        let mut head = buffer.write();
        head[0] = 42;
        head.written(1);

        // The two workers race for the lock, so their bytes may land in
        // either order; only their sum is checked below.
        let workers: Vec<_> = [64u8, 81]
            .iter()
            .map(|&value| {
                let shared = Arc::clone(&buffer);
                thread::spawn(move || {
                    let mut slot = shared.write();
                    slot[0] = value;
                    slot.written(1);
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }

        let mut tail = buffer.write();
        tail[0] = 101;
        tail[1] = 99;
        tail.written(2);

        let data = buffer.read();
        assert_eq!(data.len(), 5);
        assert_eq!(data[0], 42);
        assert_eq!(data[1] + data[2], 64 + 81);
        assert_eq!(data[3], 101);
        assert_eq!(data[4], 99);
    }
}