use std::
{
collections::VecDeque,
ops::{Deref, DerefMut},
ptr::{self, NonNull},
sync::atomic::{AtomicU64, Ordering},
fmt,
};
/// Error kinds returned by the buffer-pool API.
///
/// NOTE(review): the `…TryAgianLater` variant names carry a historical typo;
/// they are part of the public interface and are kept as-is. Only the
/// user-facing message text below was corrected ("bufers" -> "buffers").
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RwBufferError
{
    TooManyRead,
    TooManyBase,
    ReadTryAgianLater,
    WriteTryAgianLater,
    OutOfBuffers,
    DowngradeFailed,
    InvalidArguments
}
impl fmt::Display for RwBufferError
{
    /// Human-readable description; the leading token mirrors the variant name
    /// so log lines can be grepped back to the enum.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        match self
        {
            RwBufferError::TooManyRead =>
                write!(f, "TooManyRead: read soft limit reached"),
            RwBufferError::TooManyBase =>
                write!(f, "TooManyBase: base soft limit reached"),
            RwBufferError::ReadTryAgianLater =>
                write!(f, "ReadTryAgianLater: shared access not available, try again later"),
            RwBufferError::WriteTryAgianLater =>
                write!(f, "WriteTryAgianLater: exclusive access not available, try again later"),
            RwBufferError::OutOfBuffers =>
                write!(f, "OutOfBuffers: no more free buffers are left"),
            RwBufferError::DowngradeFailed =>
                write!(f, "DowngradeFailed: can not downgrade exclusive to shared, race condition"),
            RwBufferError::InvalidArguments =>
                write!(f, "InvalidArguments: arguments are not valid"),
        }
    }
}
/// Convenience alias used by every fallible API in this module.
pub type RwBufferRes<T> = Result<T, RwBufferError>;
/// Shared (read-only) handle to a pooled buffer; each live `RBuffer` is one
/// unit of the `read` counter inside the shared flag word.
#[derive(Debug, PartialEq, Eq)]
pub struct RBuffer(NonNull<RwBufferInner>);
// SAFETY: the pointee is heap-allocated and reference-counted through the
// atomic flag word; only the last handle of any kind releases it, and the
// exposed view is read-only. NOTE(review): soundness also depends on the
// flag updates in the impls being race-free — confirm under contention.
unsafe impl Send for RBuffer {}
unsafe impl Sync for RBuffer {}
impl RBuffer
{
// Wraps a pointer whose `read` count the caller has already incremented
// (see `RwBuffer::read` / `WBuffer::downgrade`).
#[inline]
fn new(inner: NonNull<RwBufferInner>) -> Self
{
return Self(inner);
}
// Test-only snapshot of the packed flag word.
#[cfg(test)]
fn get_flags(&self) -> RwBufferFlags
{
let inner = unsafe{ self.0.as_ref() };
let flags: RwBufferFlags = inner.flags.load(Ordering::Acquire).into();
return flags;
}
// Read-only view of the buffered bytes.
pub
fn as_slice(&self) -> &[u8]
{
let inner = unsafe { self.0.as_ref() };
return inner.buf.as_ref().unwrap().as_slice();
}
// Tries to take ownership of the underlying `Vec<u8>`.
//
// Succeeds only when this is the sole remaining handle of any kind
// (exactly one reader, no writer, no base references); otherwise the
// handle is handed back unchanged in `Err`.
pub
fn try_inner(mut self) -> Result<Vec<u8>, Self>
{
let inner = unsafe { self.0.as_ref() };
let flags: RwBufferFlags = inner.flags.load(Ordering::Acquire).into();
if flags.read == 1 && flags.write == false && flags.base == 0
{
// Sole owner: no other handle exists, so the `take` cannot be observed.
let inner = unsafe { self.0.as_mut() };
let buf = inner.buf.take().unwrap();
// Drop runs next: it decrements `read` to 0 and, with `base == 0`,
// releases the (now buffer-less) shared state.
drop(self);
return Ok(buf);
}
else
{
return Err(self);
}
}
// Shared view of the heap-allocated state behind the pointer.
fn inner(&self) -> &RwBufferInner
{
return unsafe { self.0.as_ref() };
}
}
impl Deref for RBuffer
{
    type Target = Vec<u8>;
    /// Read-only access to the underlying byte vector.
    fn deref(&self) -> &Vec<u8>
    {
        self.inner().buf.as_ref().unwrap()
    }
}
impl Clone for RBuffer
{
    /// Adds one shared reference atomically.
    ///
    /// Panics when the reference cannot be added (reader soft limit reached,
    /// or a writer raced in). The previous implementation did a non-atomic
    /// load/modify/store, so two concurrent clones could lose an increment;
    /// `fetch_update` retries on contention instead.
    fn clone(&self) -> Self
    {
        let inner = self.inner();
        let res = inner.flags.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |word|
            {
                let mut flags: RwBufferFlags = word.into();
                match flags.read()
                {
                    Ok(true) => Some(flags.into()),
                    // Cap reached or write flag set: store nothing.
                    _ => None,
                }
            });
        if res.is_err()
        {
            panic!("too many read references for RBuffer");
        }
        return Self(self.0);
    }
}
impl Drop for RBuffer
{
fn drop(&mut self)
{
let inner = self.inner();
let mut flags: RwBufferFlags = inner.flags.load(Ordering::Acquire).into();
flags.unread();
if flags.read == 0 && flags.base == 0
{
unsafe { ptr::drop_in_place(self.0.as_ptr()) };
return;
}
inner.flags.store(flags.into(), Ordering::Release);
return;
}
}
/// Exclusive (writable) handle to a pooled buffer; at most one can exist per
/// buffer while the `write` flag is set.
#[derive(Debug, PartialEq, Eq)]
pub struct WBuffer
{
buf: NonNull<RwBufferInner>,
// Set by `downgrade` so Drop does not release a write flag that now
// belongs to the returned `RBuffer`.
downgraded: bool
}
// SAFETY: exclusivity is enforced through the `write` flag in the shared
// atomic word; `Sync` is deliberately NOT implemented, so `&WBuffer` cannot
// be shared across threads.
unsafe impl Send for WBuffer{}
impl WBuffer
{
    /// Wraps a pointer whose `write` flag the caller has already set
    /// (see `RwBuffer::write`).
    #[inline]
    fn new(inner: NonNull<RwBufferInner>) -> Self
    {
        return Self{ buf: inner, downgraded: false };
    }
    /// Atomically converts this exclusive handle into a shared `RBuffer`
    /// (clears `write`, adds one reader).
    ///
    /// The previous implementation updated the flag word with a non-atomic
    /// load/modify/store; `fetch_update` makes the transition race-free.
    pub
    fn downgrade(mut self) -> RwBufferRes<RBuffer>
    {
        let inner = unsafe { self.buf.as_ref() };
        let res = inner.flags.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |word|
            {
                let mut flags: RwBufferFlags = word.into();
                if flags.downgrade() { Some(flags.into()) } else { None }
            });
        if res.is_ok()
        {
            // Our reference now lives in the returned reader; make Drop a no-op.
            self.downgraded = true;
            return Ok(RBuffer::new(self.buf));
        }
        return Err(RwBufferError::DowngradeFailed);
    }
    /// Read-only view of the buffered bytes.
    pub
    fn as_slice(&self) -> &[u8]
    {
        let inner = unsafe { self.buf.as_ref() };
        return inner.buf.as_ref().unwrap().as_slice();
    }
}
impl Deref for WBuffer
{
    type Target = Vec<u8>;
    /// Shared access to the owned byte vector.
    fn deref(&self) -> &Vec<u8>
    {
        let state = unsafe { self.buf.as_ref() };
        state.buf.as_ref().unwrap()
    }
}
impl DerefMut for WBuffer
{
    /// Mutable access to the owned byte vector; exclusivity is guaranteed by
    /// the write flag held by this handle.
    fn deref_mut(&mut self) -> &mut Vec<u8>
    {
        let state = unsafe { self.buf.as_mut() };
        state.buf.as_mut().unwrap()
    }
}
impl Drop for WBuffer
{
fn drop(&mut self)
{
if self.downgraded == true
{
return;
}
let inner = unsafe { self.buf.as_ref() };
let mut flags: RwBufferFlags = inner.flags.load(Ordering::Acquire).into();
flags.unwrite();
if flags.read == 0 && flags.base == 0
{
unsafe { ptr::drop_in_place(self.buf.as_ptr()) };
return;
}
inner.flags.store(flags.into(), Ordering::Release);
return;
}
}
/// Packed per-buffer state: reference counters plus the exclusive-write bit.
///
/// The struct is exactly 8 bytes so the whole state fits in one `AtomicU64`
/// and converts to/from the raw word via `transmute`. `#[repr(C)]` pins the
/// field order and offsets (4 + 2 + 1 + 1 = 8, no padding), making the
/// transmutes layout-stable; the previous `repr(Rust)` layout was
/// unspecified and merely happened to pack into 8 bytes.
#[repr(C, align(8))]
#[derive(Debug, PartialEq, Eq)]
struct RwBufferFlags
{
    read: u32,   // number of live shared (`RBuffer`) references
    base: u16,   // number of live base (`RwBuffer`) references
    write: bool, // true while an exclusive (`WBuffer`) handle exists
    unused0: u8, // spare byte keeping the size at exactly 8
}
impl From<u64> for RwBufferFlags
{
    /// Unpacks a word previously produced by `From<RwBufferFlags> for u64`.
    fn from(value: u64) -> Self
    {
        // SAFETY: identical size and alignment, and every stored word
        // originates from a valid `RwBufferFlags`, so the `write` byte is
        // always a valid bool (0 or 1).
        return unsafe { std::mem::transmute(value) };
    }
}
impl From<RwBufferFlags> for u64
{
    /// Packs the flags into the raw word held by the `AtomicU64`.
    fn from(value: RwBufferFlags) -> Self
    {
        // SAFETY: plain bit copy of an 8-byte struct with no padding.
        return unsafe { std::mem::transmute(value) };
    }
}
impl Default for RwBufferFlags
{
    /// Fresh state: no readers, no writer, one base reference (the handle
    /// created by `RwBuffer::new`).
    fn default() -> Self
    {
        return Self { read: 0, write: false, base: 1, unused0: 0 };
    }
}
impl RwBufferFlags
{
    /// Soft cap on concurrent shared references.
    pub const MAX_READ_REFS: u32 = u32::MAX - 2;
    /// Soft cap on concurrent base references.
    pub const MAX_BASE_REFS: u16 = u16::MAX - 2;
    /// Adds one base reference; returns `false` when the cap is reached.
    /// The counter is NOT incremented on rejection (the old version
    /// incremented first, so the `u16` could eventually overflow).
    #[inline]
    fn base(&mut self) -> bool
    {
        if self.base >= Self::MAX_BASE_REFS
        {
            return false;
        }
        self.base += 1;
        return true;
    }
    /// Drops one base reference; returns `true` while references remain.
    #[inline]
    fn unbase(&mut self) -> bool
    {
        self.base -= 1;
        return self.base != 0;
    }
    /// Drops one shared reference.
    #[inline]
    fn unread(&mut self)
    {
        self.read -= 1;
    }
    /// Trades the exclusive flag for one shared reference.
    /// Returns `false` when no writer was active.
    #[inline]
    fn downgrade(&mut self) -> bool
    {
        if self.write == true
        {
            self.write = false;
            self.read += 1;
            return true;
        }
        return false;
    }
    /// Adds one shared reference.
    /// `Err` while a writer is active; `Ok(false)` when the cap is reached —
    /// in that case the counter is NOT incremented (the old version bumped
    /// it anyway, so every rejected request leaked a reader slot).
    #[inline]
    fn read(&mut self) -> RwBufferRes<bool>
    {
        if self.write == true
        {
            return Err(RwBufferError::ReadTryAgianLater);
        }
        if self.read >= Self::MAX_READ_REFS
        {
            return Ok(false);
        }
        self.read += 1;
        return Ok(true);
    }
    /// Claims exclusive access.
    /// Fails when readers exist OR a writer is already active — the old
    /// version only checked readers, so two writers could both be granted.
    #[inline]
    fn write(&mut self) -> RwBufferRes<()>
    {
        if self.write == true || self.read != 0
        {
            return Err(RwBufferError::WriteTryAgianLater);
        }
        self.write = true;
        return Ok(());
    }
    /// Releases exclusive access.
    #[inline]
    fn unwrite(&mut self)
    {
        self.write = false;
    }
}
// Heap state shared by every handle of one buffer: the packed flag word
// (see `RwBufferFlags`) plus the payload. `buf` is an `Option` so
// `RBuffer::try_inner` can move the vector out before the state is released.
#[derive(Debug)]
pub struct RwBufferInner
{
flags: AtomicU64,
buf: Option<Vec<u8>>,
}
impl RwBufferInner
{
    /// Builds the shared state: default flags (one base reference) and a
    /// zero-filled buffer of `buf_size` bytes.
    fn new(buf_size: usize) -> Self
    {
        let flags = AtomicU64::new(RwBufferFlags::default().into());
        let buf = Some(vec![0_u8; buf_size]);
        return Self{ flags, buf };
    }
}
/// Base (pool-side) handle to a buffer; each live `RwBuffer` is one unit of
/// the `base` counter. Readers/writers are obtained via `read`/`write`.
#[derive(Debug, PartialEq, Eq)]
pub struct RwBuffer(NonNull<RwBufferInner>);
// SAFETY: the pointee is heap-allocated and reference-counted through the
// atomic flag word; only the last handle of any kind releases it.
// NOTE(review): soundness also depends on the flag updates being race-free.
unsafe impl Send for RwBuffer {}
unsafe impl Sync for RwBuffer {}
impl RwBuffer
{
    /// Allocates a fresh buffer of `buf_size` zeroed bytes with one base
    /// reference (this handle). The state is leaked onto the heap and is
    /// reclaimed by the last `Drop` of any handle.
    #[inline]
    fn new(buf_size: usize) -> Self
    {
        let status = Box::new(RwBufferInner::new(buf_size));
        return Self(Box::leak(status).into());
    }
    /// Shared view of the heap state behind the pointer.
    #[inline]
    fn inner(&self) -> &RwBufferInner
    {
        // SAFETY: the pointee lives until the last handle of any kind drops.
        return unsafe { self.0.as_ref() };
    }
    /// True when nobody uses the buffer: no reader, no writer, and the
    /// pool's own handle is the only base reference.
    #[inline]
    pub
    fn is_free(&self) -> bool
    {
        let flags: RwBufferFlags = self.inner().flags.load(Ordering::Acquire).into();
        return flags.write == false && flags.read == 0 && flags.base == 1;
    }
    /// Atomically claims the buffer (adds one base reference) when free.
    ///
    /// Uses a CAS loop instead of the old load/modify/store sequence, which
    /// let two threads both observe "free" and silently lose one update.
    #[inline]
    pub
    fn accure_if_free(&self) -> bool
    {
        let res = self.inner().flags.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |word|
            {
                let mut flags: RwBufferFlags = word.into();
                if flags.write == false && flags.read == 0 && flags.base == 1
                {
                    let _ = flags.base();
                    return Some(flags.into());
                }
                return None;
            });
        return res.is_ok();
    }
    /// Grants exclusive access while no reader or writer is active.
    pub
    fn write(&self) -> RwBufferRes<WBuffer>
    {
        let res = self.inner().flags.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |word|
            {
                let mut flags: RwBufferFlags = word.into();
                // Refuse when a writer is already active; `flags.write()`
                // itself refuses active readers.
                if flags.write == true
                {
                    return None;
                }
                return flags.write().ok().map(|_| flags.into());
            });
        if res.is_err()
        {
            return Err(RwBufferError::WriteTryAgianLater);
        }
        return Ok(WBuffer::new(self.0));
    }
    /// Grants shared access; fails while a writer is active or the reader
    /// cap is reached.
    pub
    fn read(&self) -> RwBufferRes<RBuffer>
    {
        let mut err = RwBufferError::ReadTryAgianLater;
        let res = self.inner().flags.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |word|
            {
                let mut flags: RwBufferFlags = word.into();
                match flags.read()
                {
                    Ok(true) => Some(flags.into()),
                    // On rejection nothing is stored, so a refused request no
                    // longer leaks a reader slot (the old code stored the
                    // incremented count and then returned `TooManyRead`).
                    Ok(false) => { err = RwBufferError::TooManyRead; None },
                    Err(e) => { err = e; None },
                }
            });
        if res.is_err()
        {
            return Err(err);
        }
        return Ok(RBuffer::new(self.0));
    }
    // Test-only snapshot of the packed flag word.
    #[cfg(test)]
    fn get_flags(&self) -> RwBufferFlags
    {
        return self.inner().flags.load(Ordering::Acquire).into();
    }
}
impl Clone for RwBuffer
{
    /// Adds one base reference atomically; panics when the base soft limit
    /// is reached. The old version used a racy load/modify/store, and its
    /// panic message named the wrong type ("RBuffer").
    fn clone(&self) -> Self
    {
        let inner = self.inner();
        let res = inner.flags.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |word|
            {
                let mut flags: RwBufferFlags = word.into();
                if flags.base() { Some(flags.into()) } else { None }
            });
        if res.is_err()
        {
            panic!("too many base references for RwBuffer");
        }
        return Self(self.0);
    }
}
impl Drop for RwBuffer
{
fn drop(&mut self)
{
let inner = self.inner();
let mut flags: RwBufferFlags = inner.flags.load(Ordering::Acquire).into();
let unbased = flags.unbase();
if flags.read == 0 && flags.write == false && unbased == false
{
unsafe { ptr::drop_in_place(self.0.as_ptr()) };
return;
}
inner.flags.store(flags.into(), Ordering::Release);
}
}
/// Grow-on-demand pool of reusable `RwBuffer`s.
#[derive(Debug)]
pub struct RwBuffers
{
buf_len: usize, // size in bytes of every pooled buffer
bufs_cnt_lim: usize, // maximum pool size; 0 means unbounded
buffs: VecDeque<RwBuffer>
}
impl RwBuffers
{
    /// Creates a bounded pool of buffers of `buf_len` bytes, pre-allocating
    /// `pre_init_cnt` of them.
    ///
    /// # Errors
    /// `InvalidArguments` when `pre_init_cnt > bufs_cnt_lim` or `buf_len == 0`.
    pub
    fn new(buf_len: usize, pre_init_cnt: usize, bufs_cnt_lim: usize) -> RwBufferRes<Self>
    {
        if pre_init_cnt > bufs_cnt_lim || buf_len == 0
        {
            return Err(RwBufferError::InvalidArguments);
        }
        let mut buffs: VecDeque<RwBuffer> = VecDeque::with_capacity(bufs_cnt_lim);
        for _ in 0..pre_init_cnt
        {
            buffs.push_back(RwBuffer::new(buf_len));
        }
        return Ok(
            Self
            {
                buf_len,
                bufs_cnt_lim,
                buffs,
            }
        );
    }
    /// Creates a pool with no size limit (`bufs_cnt_lim == 0` means
    /// "unbounded" to `allocate`).
    pub
    fn new_unbounded(buf_len: usize, pre_init_cnt: usize) -> Self
    {
        let mut buffs = VecDeque::with_capacity(pre_init_cnt);
        for _ in 0..pre_init_cnt
        {
            buffs.push_back(RwBuffer::new(buf_len));
        }
        return Self{ buf_len, bufs_cnt_lim: 0, buffs };
    }
    /// Hands out a clone of a free pooled buffer, growing the pool when the
    /// limit allows; the pool keeps its own handle.
    ///
    /// # Errors
    /// `OutOfBuffers` when every buffer is busy and the pool is at its limit.
    pub
    fn allocate(&mut self) -> RwBufferRes<RwBuffer>
    {
        if let Some(buf) = self.buffs.iter().find(|buf| buf.is_free())
        {
            return Ok(buf.clone());
        }
        if self.bufs_cnt_lim == 0 || self.buffs.len() < self.bufs_cnt_lim
        {
            let buf = RwBuffer::new(self.buf_len);
            let c_buf = buf.clone();
            self.buffs.push_back(buf);
            return Ok(c_buf);
        }
        return Err(RwBufferError::OutOfBuffers);
    }
    /// Moves a free buffer out of the pool, or allocates a fresh one that is
    /// not tracked by the pool.
    ///
    /// BUGFIX: the old version called `accure_if_free` (base 1 -> 2) and then
    /// *moved* the pool's handle out of the deque, leaving the base count one
    /// too high — the buffer could never be freed, and the tests asserting
    /// `base == 0` after release failed. Checking `is_free` and removing is
    /// sufficient: `base == 1` means the deque holds the only handle, and we
    /// hold `&mut self`, so nothing can race between the check and the removal.
    pub
    fn allocate_in_place(&mut self) -> RwBuffer
    {
        for i in 0..self.buffs.len()
        {
            if self.buffs[i].is_free() == true
            {
                return self.buffs.remove(i).unwrap();
            }
        }
        return RwBuffer::new(self.buf_len);
    }
    /// Drops up to `cnt` currently-free buffers from the pool and returns
    /// how many were actually removed.
    pub
    fn compact(&mut self, mut cnt: usize) -> usize
    {
        let p_cnt = cnt;
        self.buffs.retain(
            |buf|
            {
                // Stop once the requested count is reached; the old version
                // decremented unconditionally, underflowing `cnt` (and
                // removing more buffers than asked) when more than `cnt`
                // buffers were free.
                if cnt > 0 && buf.is_free() == true
                {
                    cnt -= 1;
                    return false;
                }
                return true;
            }
        );
        return p_cnt - cnt;
    }
    // Test-only: flags snapshot of the pooled buffer at `index`.
    #[cfg(test)]
    fn get_flags_by_index(&self, index: usize) -> Option<RwBufferFlags>
    {
        return Some(self.buffs.get(index)?.get_flags());
    }
}
// Unit and smoke tests for the buffer pool.
// NOTE(review): `test_multithreading` is a `#[tokio::test]`, so tokio (with
// macros + time features) must be available as a dev-dependency.
#[cfg(test)]
mod tests
{
use std::time::{Duration, Instant};
use super::*;
// Basic read/write exclusion rules and reference-count bookkeeping.
#[test]
fn simple_test()
{
let mut bufs = RwBuffers::new(4096, 1, 2).unwrap();
let buf0_res = bufs.allocate();
assert_eq!(buf0_res.is_ok(), true, "{:?}", buf0_res.err().unwrap());
let buf0 = buf0_res.unwrap();
let buf0_w = buf0.write();
assert_eq!(buf0_w.is_ok(), true, "{:?}", buf0_w.err().unwrap());
// Writer active -> shared access refused.
assert_eq!(buf0.read(), Err(RwBufferError::ReadTryAgianLater));
drop(buf0_w);
let buf0_r = buf0.read();
assert_eq!(buf0_r.is_ok(), true, "{:?}", buf0_r.err().unwrap());
// Reader active -> exclusive access refused.
assert_eq!(buf0.write(), Err(RwBufferError::WriteTryAgianLater));
let buf0_1 = buf0.clone();
assert_eq!(buf0_1.write(), Err(RwBufferError::WriteTryAgianLater));
let flags0 = buf0.get_flags();
let flags0_1 = buf0_1.get_flags();
assert_eq!(flags0, flags0_1);
// pool handle + buf0 + buf0_1 = 3 base references.
assert_eq!(flags0.base, 3);
assert_eq!(flags0.read, 1);
assert_eq!(flags0.write, false);
}
// Dropping the base handle while a writer is live must not free the state;
// the pool's remaining handle still sees the write flag.
#[test]
fn simple_test_dopped_in_place()
{
let mut bufs = RwBuffers::new(4096, 1, 2).unwrap();
let buf0_res = bufs.allocate();
assert_eq!(buf0_res.is_ok(), true, "{:?}", buf0_res.err().unwrap());
let buf0 = buf0_res.unwrap();
println!("{:?}", buf0.get_flags());
let buf0_w = buf0.write();
assert_eq!(buf0_w.is_ok(), true, "{:?}", buf0_w.err().unwrap());
assert_eq!(buf0.read(), Err(RwBufferError::ReadTryAgianLater));
drop(buf0);
let buf0_flags = bufs.get_flags_by_index(0);
assert_eq!(buf0_flags.is_some(), true, "no flags");
let buf0_flags = buf0_flags.unwrap();
println!("{:?}", buf0_flags);
assert_eq!(buf0_flags.base, 1);
assert_eq!(buf0_flags.read, 0);
assert_eq!(buf0_flags.write, true);
drop(buf0_w.unwrap());
let buf0_flags = bufs.get_flags_by_index(0);
assert_eq!(buf0_flags.is_some(), true, "no flags");
let buf0_flags = buf0_flags.unwrap();
println!("{:?}", buf0_flags);
assert_eq!(buf0_flags.base, 1);
assert_eq!(buf0_flags.read, 0);
assert_eq!(buf0_flags.write, false);
}
// Downgrading after the base handle is gone leaves exactly one reader.
#[test]
fn simple_test_dropped_in_place_downgrade()
{
let mut bufs = RwBuffers::new(4096, 1, 2).unwrap();
let buf0_res = bufs.allocate();
assert_eq!(buf0_res.is_ok(), true, "{:?}", buf0_res.err().unwrap());
let buf0 = buf0_res.unwrap();
println!("{:?}", buf0.get_flags());
let buf0_w = buf0.write();
assert_eq!(buf0_w.is_ok(), true, "{:?}", buf0_w.err().unwrap());
assert_eq!(buf0.read(), Err(RwBufferError::ReadTryAgianLater));
drop(buf0);
let buf0_rd = buf0_w.unwrap().downgrade();
assert_eq!(buf0_rd.is_ok(), true, "{:?}", buf0_rd.err().unwrap());
let buf0_flags = bufs.get_flags_by_index(0);
assert_eq!(buf0_flags.is_some(), true, "no flags");
let buf0_flags = buf0_flags.unwrap();
println!("{:?}", buf0_flags);
assert_eq!(buf0_flags.base, 1);
assert_eq!(buf0_flags.read, 1);
assert_eq!(buf0_flags.write, false);
}
// Same as above but the buffer is moved OUT of the pool first.
// NOTE(review): the `base == 0` assertion relies on `allocate_in_place`
// not double-counting the base reference when removing from the deque.
#[test]
fn simple_test_drop_in_place_downgrade()
{
let mut bufs = RwBuffers::new(4096, 1, 2).unwrap();
let buf0 = bufs.allocate_in_place();
println!("{:?}", buf0.get_flags());
let buf0_w = buf0.write();
assert_eq!(buf0_w.is_ok(), true, "{:?}", buf0_w.err().unwrap());
assert_eq!(buf0.read(), Err(RwBufferError::ReadTryAgianLater));
drop(buf0);
let buf0_rd = buf0_w.unwrap().downgrade();
assert_eq!(buf0_rd.is_ok(), true, "{:?}", buf0_rd.err().unwrap());
let buf0_flags = bufs.get_flags_by_index(0);
assert_eq!(buf0_flags.is_some(), false, "flags");
let buf0_rd = buf0_rd.unwrap();
let buf0_flags = buf0_rd.get_flags();
println!("{:?}", buf0_flags);
assert_eq!(buf0_flags.base, 0);
assert_eq!(buf0_flags.read, 1);
assert_eq!(buf0_flags.write, false);
}
// Rough latency printout for allocate / write / read; asserts only success.
#[test]
fn timing_test()
{
let mut bufs = RwBuffers::new(4096, 1, 2).unwrap();
for _ in 0..10
{
let inst = Instant::now();
let buf0_res = bufs.allocate_in_place();
let end = inst.elapsed();
println!("alloc: {:?}", end);
drop(buf0_res);
}
let buf0_res = bufs.allocate();
assert_eq!(buf0_res.is_ok(), true, "{:?}", buf0_res.err().unwrap());
let buf0 = buf0_res.unwrap();
for _ in 0..10
{
let inst = Instant::now();
let buf0_w = buf0.write();
let end = inst.elapsed();
println!("write: {:?}", end);
assert_eq!(buf0_w.is_ok(), true, "{:?}", buf0_w.err().unwrap());
assert_eq!(buf0.read(), Err(RwBufferError::ReadTryAgianLater));
drop(buf0_w);
}
for _ in 0..10
{
let inst = Instant::now();
let buf0_r = buf0.read();
let end = inst.elapsed();
println!("read: {:?}", end);
assert_eq!(buf0_r.is_ok(), true, "{:?}", buf0_r.err().unwrap());
assert_eq!(buf0.write(), Err(RwBufferError::WriteTryAgianLater));
drop(buf0_r);
}
}
// Two OS threads each hold a reader; counters settle after join.
#[test]
fn simple_test_mth()
{
let mut bufs = RwBuffers::new(4096, 1, 3).unwrap();
let buf0 = bufs.allocate().unwrap();
let buf0_rd = buf0.write().unwrap().downgrade().unwrap();
let join1=
std::thread::spawn(move ||
{
println!("{:?}", buf0_rd);
std::thread::sleep(Duration::from_secs(2));
return;
}
);
let buf1_rd = buf0.read().unwrap();
let join2=
std::thread::spawn(move ||
{
println!("{:?}", buf1_rd);
std::thread::sleep(Duration::from_secs(2));
return;
}
);
let flags = buf0.get_flags();
assert_eq!(flags.base, 2);
assert_eq!(flags.read, 2);
assert_eq!(flags.write, false);
let _ = join1.join();
let _ = join2.join();
let flags = buf0.get_flags();
assert_eq!(flags.base, 2);
assert_eq!(flags.read, 0);
assert_eq!(flags.write, false);
}
// `try_inner` succeeds only for the sole surviving reader handle.
#[test]
fn test_try_into_read()
{
let mut bufs = RwBuffers::new(4096, 1, 2).unwrap();
let buf0 = bufs.allocate_in_place();
println!("{:?}", buf0.get_flags());
let buf0_w = buf0.write();
assert_eq!(buf0_w.is_ok(), true, "{:?}", buf0_w.err().unwrap());
assert_eq!(buf0.read(), Err(RwBufferError::ReadTryAgianLater));
drop(buf0);
let buf0_rd = buf0_w.unwrap().downgrade();
assert_eq!(buf0_rd.is_ok(), true, "{:?}", buf0_rd.err().unwrap());
let buf0_flags = bufs.get_flags_by_index(0);
assert_eq!(buf0_flags.is_some(), false, "flags");
let buf0_rd = buf0_rd.unwrap();
let buf0_flags = buf0_rd.get_flags();
println!("{:?}", buf0_flags);
assert_eq!(buf0_flags.base, 0);
assert_eq!(buf0_flags.read, 1);
assert_eq!(buf0_flags.write, false);
let inst = Instant::now();
let ve = buf0_rd.try_inner();
let end = inst.elapsed();
println!("try inner: {:?}", end);
assert_eq!(ve.is_ok(), true);
}
// Readers shared across tokio tasks observe the bytes written before the
// writer was downgraded.
#[tokio::test]
async fn test_multithreading()
{
let mut bufs = RwBuffers::new(4096, 1, 3).unwrap();
let buf0 = bufs.allocate().unwrap();
let mut buf0_write = buf0.write().unwrap();
buf0_write.as_mut_slice()[0] = 5;
buf0_write.as_mut_slice()[1] = 4;
println!("{}", buf0_write[0]);
let buf0_r = buf0_write.downgrade().unwrap();
let join1=
tokio::task::spawn(async move
{
println!("thread[1]:{}", buf0_r[0]);
tokio::time::sleep(Duration::from_millis(200)).await;
return;
}
);
let buf0_r = buf0.read().unwrap();
drop(buf0);
let join2=
tokio::task::spawn(async move
{
println!("thread[2]: {}", buf0_r[0]);
println!("thread[2]: {}", buf0_r[1]);
tokio::time::sleep(Duration::from_millis(200)).await;
return;
}
);
let _ = join1.await;
let _ = join2.await;
return;
}
}