//! A single-producer, single-consumer "bip buffer" (bipartite circular buffer) of bytes.
//!
//! [`bip_buffer_from`] and [`bip_buffer_with_len`] split a byte buffer into a
//! [`BipBufferWriter`] and a [`BipBufferReader`], each of which can be moved to its own
//! thread. The writer reserves contiguous regions and commits them; the reader observes
//! committed bytes as contiguous slices and consumes them. A reservation never straddles
//! the end of the storage: if it does not fit in the remaining tail, it is placed at the
//! start of the buffer instead.
#![deny(missing_docs)]
use std::sync::Arc;
use core::sync::atomic::{AtomicUsize, Ordering};
use cache_line_size::CacheAligned;
struct BipBuffer {
sequestered: Box<std::any::Any>,
buf: *mut u8,
len: usize,
read: CacheAligned<AtomicUsize>,
write: CacheAligned<AtomicUsize>,
last: CacheAligned<AtomicUsize>,
}
#[cfg(feature = "debug")]
impl BipBuffer {
    /// Renders the shared positions and the capacity for the `debug` feature's trace lines.
    fn dbg_info(&self) -> String {
        let BipBuffer { read, write, last, len, .. } = self;
        format!(" read: {:?} -- write: {:?} -- last: {:?} [len: {:?}] ", read, write, last, len)
    }
}
/// The producing half of a bip buffer, created by [`bip_buffer_from`] or
/// [`bip_buffer_with_len`].
///
/// Space is claimed with [`reserve`](BipBufferWriter::reserve) or
/// [`spin_reserve`](BipBufferWriter::spin_reserve). Only one reservation can be outstanding
/// at a time, because a reservation mutably borrows the writer.
pub struct BipBufferWriter {
    /// Shared state, also referenced by the matching reader.
    buffer: Arc<BipBuffer>,
    /// Writer's private write position; published to the shared `write` when a reservation
    /// is committed.
    write: usize,
    /// Initialized to the buffer length.
    // NOTE(review): reservations never extend past the buffer end, so the `write > last`
    // check in the reservation's `Drop` looks unreachable and this value appears never to
    // change — confirm.
    last: usize,
}
// SAFETY: the writer only mutates regions of `buf` it has reserved and not yet published,
// while the reader only accesses regions published through the `write` Release store, so
// moving the writer to another thread does not introduce overlapping access.
// NOTE(review): `bip_buffer_from` does not require `B: Send`, yet whichever half is dropped
// last drops the sequestered storage on its own thread — consider adding a `Send` bound.
unsafe impl Send for BipBufferWriter {}
/// The consuming half of a bip buffer, created by [`bip_buffer_from`] or
/// [`bip_buffer_with_len`].
///
/// Committed bytes are inspected with [`valid`](BipBufferReader::valid) and handed back to
/// the writer with [`consume`](BipBufferReader::consume).
pub struct BipBufferReader {
    /// Shared state, also referenced by the matching writer.
    buffer: Arc<BipBuffer>,
    /// Reader's private read position; published to the shared `read` by `consume`.
    read: usize,
    /// Snapshot of the shared `write`, taken by the latest `valid` call.
    priv_write: usize,
    /// Snapshot of the shared `last`, taken by `valid` when the writer has wrapped around.
    priv_last: usize,
}
// SAFETY: the reader only accesses regions of `buf` that the writer has published through
// its `write` Release store and that the reader has not yet released via `read`, so moving
// the reader to another thread does not introduce overlapping access.
// NOTE(review): as with the writer, `B` is not required to be `Send` — confirm this is sound.
unsafe impl Send for BipBufferReader {}
/// Creates a writer/reader pair backed by caller-provided storage.
///
/// The full length of `from` becomes the buffer capacity. After one half has been dropped,
/// the other can recover the storage with [`BipBufferWriter::try_unwrap`] or
/// [`BipBufferReader::try_unwrap`], naming the same type `B`.
pub fn bip_buffer_from<B: std::ops::DerefMut<Target=[u8]>+'static>(from: B) -> (BipBufferWriter, BipBufferReader) {
    // Boxing keeps `B` itself at a fixed heap address, so `buf` stays valid when the box is
    // moved into the shared state, even for storage types whose bytes live inline in `B`.
    let mut sequestered = Box::new(from);
    let len = sequestered.len();
    let buf = sequestered.as_mut_ptr();
    // NOTE(review): `B` is not required to be `Send`, yet both halves are `Send` and whichever
    // is dropped last drops the storage on its own thread — consider adding a `Send` bound.
    let buffer = Arc::new(BipBuffer {
        sequestered,
        buf,
        len,
        read: CacheAligned(AtomicUsize::new(0)),
        write: CacheAligned(AtomicUsize::new(0)),
        last: CacheAligned(AtomicUsize::new(0)),
    });
    let writer = BipBufferWriter {
        buffer: Arc::clone(&buffer),
        write: 0,
        last: len,
    };
    let reader = BipBufferReader {
        buffer,
        read: 0,
        priv_write: 0,
        priv_last: len,
    };
    (writer, reader)
}
/// Creates a writer/reader pair backed by a freshly allocated, zero-filled buffer of `len`
/// bytes.
///
/// The storage is a `Box<[u8]>`; name that type when recovering it with `try_unwrap`.
pub fn bip_buffer_with_len(len: usize) -> (BipBufferWriter, BipBufferReader) {
    let storage: Box<[u8]> = vec![0u8; len].into_boxed_slice();
    bip_buffer_from(storage)
}
impl BipBuffer {
fn into_inner<B: std::ops::DerefMut<Target=[u8]>+'static>(self) -> B {
let BipBuffer { sequestered, .. } = self;
*sequestered.downcast::<B>().expect("incorrect underlying type")
}
}
/// Placement chosen by `BipBufferWriter::reserve_core` for a reservation that has not yet
/// been handed out to the caller.
#[derive(Clone, Copy)]
struct PendingReservation {
    /// Offset of the first byte of the region.
    start: usize,
    /// Length of the region in bytes.
    len: usize,
    /// Whether the region starts at offset 0 because it did not fit in the tail.
    wraparound: bool,
}
impl BipBufferWriter {
    /// Decides where a reservation of `len` bytes would go, without committing anything.
    ///
    /// Returns `None` when there is currently not enough contiguous free space.
    ///
    /// # Panics
    ///
    /// Panics if `len` is zero.
    fn reserve_core(&mut self, len: usize) -> Option<PendingReservation> {
        assert!(len > 0);
        // Acquire pairs with the reader's Release store in `consume`: once the writer sees the
        // reader's new position, the reader's accesses to the freed bytes have completed.
        let read = self.buffer.read.0.load(Ordering::Acquire);
        if self.write >= read {
            // Not wrapped: free space is the tail `[write, len)` plus the head before `read`.
            if self.buffer.len.saturating_sub(self.write) >= len {
                Some(PendingReservation { start: self.write, len, wraparound: false })
            } else if read.saturating_sub(1) >= len {
                // The tail is too small, so start over at offset 0. The region must end
                // strictly before `read`: `write == read` would read back as an empty buffer.
                Some(PendingReservation { start: 0, len, wraparound: true })
            } else {
                None
            }
        } else if (read - self.write).saturating_sub(1) >= len {
            // Wrapped: free space is `[write, read)`, again leaving the byte before `read` unused.
            Some(PendingReservation { start: self.write, len, wraparound: false })
        } else {
            None
        }
    }

    /// Tries to reserve `len` contiguous bytes for writing.
    ///
    /// Returns `None` if there is currently not enough contiguous free space. The returned
    /// reservation dereferences to the reserved bytes and is committed (made visible to the
    /// reader) when it is dropped or [`send`](BipBufferWriterReservation::send) is called.
    ///
    /// # Panics
    ///
    /// Panics if `len` is zero.
    pub fn reserve(&mut self, len: usize) -> Option<BipBufferWriterReservation<'_>> {
        let PendingReservation { start, len, wraparound } = self.reserve_core(len)?;
        Some(BipBufferWriterReservation { writer: self, start, len, wraparound })
    }

    /// Reserves `len` contiguous bytes, busy-waiting until the reader frees enough space.
    ///
    /// # Panics
    ///
    /// Panics if `len` is zero or larger than the buffer capacity.
    ///
    /// # Blocking
    ///
    /// A reservation must fit either in the tail of the buffer or, after wrapping, strictly
    /// below the reader's position. A request that can never fit this way (for example, one
    /// for the full capacity once the writer has moved off offset 0) spins forever.
    pub fn spin_reserve(&mut self, len: usize) -> BipBufferWriterReservation<'_> {
        assert!(
            len <= self.buffer.len,
            "reservation of {} bytes exceeds buffer capacity of {} bytes",
            len,
            self.buffer.len
        );
        let PendingReservation { start, len, wraparound } = loop {
            if let Some(pending) = self.reserve_core(len) {
                break pending;
            }
            // Tell the CPU this is a spin-wait so it can relax while the reader catches up.
            std::hint::spin_loop();
        };
        BipBufferWriterReservation { writer: self, start, len, wraparound }
    }

    /// Recovers the storage passed to [`bip_buffer_from`] once the reader has been dropped.
    ///
    /// Returns `Err(self)` unchanged if the reader is still alive.
    ///
    /// # Panics
    ///
    /// Panics if `B` is not the concrete storage type the buffer was created from.
    pub fn try_unwrap<B: std::ops::DerefMut<Target=[u8]>+'static>(self) -> Result<B, Self> {
        let BipBufferWriter { buffer, write, last, } = self;
        match Arc::try_unwrap(buffer) {
            Ok(b) => Ok(b.into_inner()),
            Err(buffer) => Err(BipBufferWriter { buffer, write, last, }),
        }
    }
}
/// A contiguous region of the buffer reserved for writing, returned by
/// [`BipBufferWriter::reserve`] and [`BipBufferWriter::spin_reserve`].
///
/// Dereferences to the reserved bytes. The region is committed, becoming visible to the
/// reader, when the reservation is dropped or [`send`](BipBufferWriterReservation::send) is
/// called — whether or not anything was written into it.
pub struct BipBufferWriterReservation<'a> {
    /// The owning writer; borrowing it mutably rules out a second concurrent reservation.
    writer: &'a mut BipBufferWriter,
    /// Offset of the first reserved byte.
    start: usize,
    /// Number of reserved bytes.
    len: usize,
    /// Whether the region was placed at offset 0 after wrapping around the buffer end.
    wraparound: bool,
}
impl<'a> core::ops::Deref for BipBufferWriterReservation<'a> {
    type Target = [u8];

    /// Borrows the reserved region as a shared byte slice.
    fn deref(&self) -> &[u8] {
        let shared = &self.writer.buffer;
        // SAFETY: `reserve_core` only hands out regions with `start + len <= shared.len`,
        // `buf` stays alive while the writer's `Arc` does, and the reader cannot see the
        // region until this reservation commits it.
        unsafe {
            let first = shared.buf.add(self.start);
            core::slice::from_raw_parts(first, self.len)
        }
    }
}
impl<'a> core::ops::DerefMut for BipBufferWriterReservation<'a> {
    /// Borrows the reserved region as a mutable byte slice for the caller to fill.
    fn deref_mut(&mut self) -> &mut [u8] {
        let shared = &self.writer.buffer;
        // SAFETY: `reserve_core` only hands out regions with `start + len <= shared.len`,
        // `buf` stays alive while the writer's `Arc` does, and the reader cannot access the
        // region until this reservation commits it.
        unsafe {
            let first = shared.buf.add(self.start);
            core::slice::from_raw_parts_mut(first, self.len)
        }
    }
}
impl<'a> core::ops::Drop for BipBufferWriterReservation<'a> {
    /// Commits the reserved region so the reader can observe it.
    ///
    /// This runs for every reservation, so dropping one always commits it.
    fn drop(&mut self) {
        if self.wraparound {
            // The region starts at offset 0: record where the older data in the tail ends,
            // so the reader knows where to wrap, then restart the write cursor at 0.
            self.writer.buffer.last.0.store(self.writer.write, Ordering::Relaxed);
            self.writer.write = 0;
        }
        self.writer.write += self.len;
        // NOTE(review): `writer.last` starts at the buffer length and reservations never end
        // past it, so this branch appears unreachable — confirm before relying on it.
        if self.writer.write > self.writer.last {
            self.writer.last = self.writer.write;
            self.writer.buffer.last.0.store(self.writer.last, Ordering::Relaxed);
        }
        // Release publishes the bytes written into the region, and the `last` store above, to
        // the reader's Acquire load of `write` in `valid`.
        self.writer.buffer.write.0.store(self.writer.write, Ordering::Release);
        #[cfg(feature = "debug")]
        eprintln!("+++{}", self.writer.buffer.dbg_info());
    }
}
impl<'a> BipBufferWriterReservation<'a> {
    /// Commits the reservation, making its bytes visible to the reader.
    ///
    /// This is equivalent to dropping the reservation; it exists to make the commit point
    /// explicit at the call site.
    pub fn send(self) {
        core::mem::drop(self);
    }
}
impl BipBufferReader {
pub fn valid(&mut self) -> &mut [u8] {
#[cfg(feature = "debug")]
eprintln!("???{}", self.buffer.dbg_info());
self.priv_write = self.buffer.write.0.load(Ordering::Acquire);
if self.priv_write >= self.read {
unsafe {
core::slice::from_raw_parts_mut(self.buffer.buf.add(self.read), self.priv_write - self.read)
}
} else {
self.priv_last = self.buffer.last.0.load(Ordering::Relaxed);
if self.read == self.priv_last {
self.read = 0;
return self.valid();
}
unsafe {
core::slice::from_raw_parts_mut(self.buffer.buf.add(self.read), self.priv_last - self.read)
}
}
}
pub fn consume(&mut self, len: usize) -> bool {
if self.priv_write >= self.read {
if len <= self.priv_write - self.read {
self.read += len;
} else {
return false;
}
} else {
let remaining = self.priv_last - self.read;
if len == remaining {
self.read = 0;
} else if len <= remaining {
self.read += len;
} else {
return false;
}
}
self.buffer.read.0.store(self.read, Ordering::Release);
#[cfg(feature = "debug")]
eprintln!("---{}", self.buffer.dbg_info());
true
}
pub fn try_unwrap<B: std::ops::DerefMut<Target=[u8]>+'static>(self) -> Result<B, Self> {
let BipBufferReader { buffer, read, priv_write, priv_last, } = self;
match Arc::try_unwrap(buffer) {
Ok(b) => Ok(b.into_inner()),
Err(buffer) => Err(BipBufferReader { buffer, read, priv_write, priv_last, }),
}
}
}
#[cfg(test)]
mod tests {
    use crate::bip_buffer_from;

    /// One 8-byte message crosses threads intact; repeated with fresh buffers to vary timing.
    #[test]
    fn basic() {
        for i in 0..128 {
            let (mut writer, mut reader) = bip_buffer_from(vec![0u8; 16].into_boxed_slice());
            let sender = std::thread::spawn(move || {
                writer.reserve(8).as_mut().expect("reserve").copy_from_slice(&[10, 11, 12, 13, 14, 15, 16, i]);
            });
            let receiver = std::thread::spawn(move || {
                while reader.valid().len() < 8 {}
                assert_eq!(reader.valid(), &[10, 11, 12, 13, 14, 15, 16, i]);
                assert!(reader.consume(8));
            });
            sender.join().unwrap();
            receiver.join().unwrap();
        }
    }

    /// A stream of fixed-size messages that wraps around the buffer several times.
    #[test]
    fn spsc() {
        let (mut writer, mut reader) = bip_buffer_from(vec![0u8; 256].into_boxed_slice());
        let sender = std::thread::spawn(move || {
            for i in 0..128 {
                writer.spin_reserve(8).copy_from_slice(&[10, 11, 12, 13, 14, 15, 16, i]);
            }
        });
        let receiver = std::thread::spawn(move || {
            for i in 0..128 {
                while reader.valid().len() < 8 {}
                assert_eq!(&reader.valid()[..8], &[10, 11, 12, 13, 14, 15, 16, i]);
                assert!(reader.consume(8));
            }
        });
        sender.join().unwrap();
        receiver.join().unwrap();
    }

    /// The caller's storage can be recovered once the writer is gone.
    #[test]
    fn provided_storage() {
        let storage = vec![0u8; 256].into_boxed_slice();
        let (mut writer, mut reader) = bip_buffer_from(storage);
        let sender = std::thread::spawn(move || {
            writer.spin_reserve(8).copy_from_slice(&[10, 11, 12, 13, 14, 15, 16, 17]);
        });
        let receiver = std::thread::spawn(move || {
            while reader.valid().len() < 8 {}
            assert!(reader.consume(8));
            reader
        });
        sender.join().unwrap();
        let reader = receiver.join().unwrap();
        let _: Box<[u8]> = reader.try_unwrap().map_err(|_| ()).expect("failed to recover storage");
    }

    /// Recovering the storage as the wrong type panics.
    #[test]
    #[should_panic]
    fn provided_storage_wrong_type() {
        let storage = vec![0u8; 256].into_boxed_slice();
        let (writer, reader) = bip_buffer_from(storage);
        std::mem::drop(writer);
        let _: Vec<u8> = reader.try_unwrap().map_err(|_| ()).expect("failed to recover storage");
    }

    /// Recovering the storage fails while the other half is still alive.
    #[test]
    fn provided_storage_still_alive() {
        let storage = vec![0u8; 256].into_boxed_slice();
        let (writer, reader) = bip_buffer_from(storage);
        let result: Result<Box<[u8]>, _> = reader.try_unwrap();
        assert!(result.is_err());
        std::mem::drop(writer);
    }

    /// Messages whose length does not divide the capacity exercise the wraparound path.
    #[test]
    fn static_prime_length() {
        const MSG_LENGTH: u8 = 17;
        let (mut writer, mut reader) = bip_buffer_from(vec![128u8; 64].into_boxed_slice());
        let sender = std::thread::spawn(move || {
            let mut msg = [0u8; MSG_LENGTH as usize];
            for _ in 0..1024 {
                for i in 0..128u8 {
                    msg.copy_from_slice(&[i; MSG_LENGTH as usize]);
                    msg[i as usize % (MSG_LENGTH as usize)] = 0;
                    writer.spin_reserve(MSG_LENGTH as usize).copy_from_slice(&msg[..]);
                }
            }
        });
        let receiver = std::thread::spawn(move || {
            let mut msg = [0u8; MSG_LENGTH as usize];
            for _ in 0..1024 {
                for i in 0..128u8 {
                    msg.copy_from_slice(&[i; MSG_LENGTH as usize]);
                    msg[i as usize % (MSG_LENGTH as usize)] = 0;
                    while reader.valid().len() < (MSG_LENGTH as usize) {}
                    assert_eq!(&reader.valid()[..MSG_LENGTH as usize], &msg[..]);
                    assert!(reader.consume(MSG_LENGTH as usize));
                }
            }
        });
        sender.join().unwrap();
        receiver.join().unwrap();
    }

    /// Variable-length messages prefixed with their own length byte.
    #[test]
    fn random_length() {
        use rand::Rng;
        const MAX_LENGTH: usize = 127;
        let (mut writer, mut reader) = bip_buffer_from(vec![0u8; 1024]);
        let sender = std::thread::spawn(move || {
            let mut rng = rand::thread_rng();
            let mut msg = [0u8; MAX_LENGTH];
            for _ in 0..1024 {
                for round in 0..128u8 {
                    let length: u8 = rng.gen_range(1, MAX_LENGTH as u8);
                    msg[0] = length;
                    for i in 1..length {
                        msg[i as usize] = round;
                    }
                    writer.spin_reserve(length as usize).copy_from_slice(&msg[..length as usize]);
                }
            }
        });
        let receiver = std::thread::spawn(move || {
            let mut msg = [0u8; MAX_LENGTH];
            for _ in 0..1024 {
                for round in 0..128u8 {
                    let msg_len = loop {
                        let valid = reader.valid();
                        if valid.is_empty() { continue; }
                        break valid[0] as usize;
                    };
                    let recv_msg = loop {
                        let valid = reader.valid();
                        if valid.len() < msg_len { continue; }
                        break valid;
                    };
                    msg[0] = msg_len as u8;
                    for i in 1..msg_len {
                        msg[i] = round;
                    }
                    assert_eq!(&recv_msg[..msg_len], &msg[..msg_len]);
                    assert!(reader.consume(msg_len));
                }
            }
        });
        sender.join().unwrap();
        receiver.join().unwrap();
    }
}