#![cfg_attr(not(test), no_std)]
#![forbid(unsafe_code)]
#![doc = include_str!("../README.md")]
/// Fixed-capacity FIFO of variable-length byte packets stored in an inline
/// `[u8; N]` array.
///
/// Each packet is stored as `[len: u32][payload][len: u32]` with native-endian
/// lengths, so it costs `payload + 8` bytes and can be located from either end
/// of the buffer. Packets may wrap around the end of the storage.
#[derive(Clone, Debug)]
pub struct BytearrayRingbuffer<const N: usize> {
    /// Backing storage.
    buffer: [u8; N],
    /// Index one past the newest packet; the next packet is written here.
    head: usize,
    /// Index of the oldest packet's leading length field.
    tail: usize,
    /// Number of stored packets; disambiguates `head == tail` (empty vs. full).
    count: usize,
}
/// Error returned when a packet (or multipart chunk) does not fit into the
/// ring buffer.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct NotEnoughSpaceError;

impl core::fmt::Display for NotEnoughSpaceError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("not enough space in ring buffer")
    }
}

// `core::error::Error` keeps this usable in `no_std` builds.
impl core::error::Error for NotEnoughSpaceError {}
/// Guard for assembling one packet from several chunks.
///
/// Created by [`BytearrayRingbuffer::push_multipart`] and
/// [`BytearrayRingbuffer::push_multipart_force`]. Dropping the guard commits the
/// packet (an empty one if nothing was pushed); [`MultipartPush::cancel`]
/// discards it instead.
///
/// The guard must not be leaked (e.g. via `core::mem::forget`): the space for
/// the packet is reserved when the guard is created, but its length fields and
/// the packet count are only written on drop.
#[must_use = "dropping a MultipartPush immediately commits an empty packet"]
pub struct MultipartPush<'a, const N: usize> {
    /// Buffer being written; mutably borrowed for the guard's lifetime.
    buf: &'a mut BytearrayRingbuffer<N>,
    /// Storage index of the packet's leading length field.
    start: usize,
    /// Payload bytes written so far.
    len: usize,
    /// Whether `push` may drop old packets to make room.
    force: bool,
    /// Set by `cancel` so that `Drop` does not commit the packet.
    cancelled: bool,
}
impl<const N: usize> MultipartPush<'_, N> {
    /// Appends `data` to the packet under construction.
    ///
    /// Empty chunks are accepted and ignored. In force mode (guards created by
    /// `push_multipart_force`) the oldest stored packets are dropped until the
    /// chunk fits.
    ///
    /// # Errors
    ///
    /// Returns [`NotEnoughSpaceError`] without writing anything if the total
    /// payload would exceed `N - 8` bytes, or if the unused space cannot hold
    /// the chunk plus the trailing 4-byte length field.
    pub fn push(&mut self, data: &[u8]) -> Result<(), NotEnoughSpaceError> {
        if data.is_empty() {
            return Ok(());
        }
        // The finished packet (payload + two 4-byte lengths) must fit in N bytes.
        if self.len + data.len() > N - 8 {
            return Err(NotEnoughSpaceError);
        }
        // This chunk plus the trailing length field that `Drop` will write.
        let required = data.len() + 4;
        if self.force {
            while !self.buf.empty() && self.buf.bytes_unused() < required {
                self.buf.pop_front();
            }
        }
        if self.buf.bytes_unused() < required {
            return Err(NotEnoughSpaceError);
        }
        let at = self.buf.head;
        write_wrapping(&mut self.buf.buffer, at, data);
        self.buf.head = add_wrapping::<N>(at, data.len());
        self.len += data.len();
        Ok(())
    }

    /// Discards the packet under construction and releases its reserved space.
    ///
    /// Packets already dropped in force mode are not restored.
    pub fn cancel(mut self) {
        self.buf.head = self.start;
        self.cancelled = true;
    }
}
impl<const N: usize> Drop for MultipartPush<'_, N> {
    /// Commits the packet unless it was cancelled: writes the payload length
    /// into the reserved leading slot and after the payload, then counts it.
    fn drop(&mut self) {
        if !self.cancelled {
            let header = (self.len as u32).to_ne_bytes();
            let end = self.buf.head;
            write_wrapping(&mut self.buf.buffer, self.start, &header);
            write_wrapping(&mut self.buf.buffer, end, &header);
            self.buf.head = add_wrapping::<N>(end, 4);
            self.buf.count += 1;
        }
    }
}
impl<const N: usize> BytearrayRingbuffer<N> {
    /// Creates an empty ring buffer with `N` bytes of backing storage.
    ///
    /// Every packet costs its payload plus 8 bytes of framing, so the largest
    /// single payload is `N - 8` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `N <= 8` or if `N` is not below `u32::MAX as usize`.
    pub const fn new() -> Self {
        assert!(N > 8);
        assert!(N < (u32::MAX as usize));
        Self {
            buffer: [0; N],
            head: 0,
            tail: 0,
            count: 0,
        }
    }

    /// Returns the largest payload, in bytes, that [`push`](Self::push) can
    /// currently store without dropping any packet.
    pub const fn free(&self) -> usize {
        self.bytes_unused().saturating_sub(8)
    }

    /// Appends `data` as a new packet at the back of the buffer.
    ///
    /// # Errors
    ///
    /// Returns [`NotEnoughSpaceError`], leaving the buffer unchanged, if the
    /// payload plus its 8 bytes of framing does not fit into the unused space.
    pub fn push(&mut self, data: &[u8]) -> Result<(), NotEnoughSpaceError> {
        self._push(data, false)
    }

    /// Appends `data` as a new packet, dropping the oldest packets until it fits.
    ///
    /// # Errors
    ///
    /// Returns [`NotEnoughSpaceError`] only if `data` is longer than `N - 8`
    /// bytes; in that case nothing is dropped.
    pub fn push_force(&mut self, data: &[u8]) -> Result<(), NotEnoughSpaceError> {
        self._push(data, true)
    }

    /// Starts assembling a packet from several chunks without dropping old packets.
    ///
    /// Chunks are appended with [`MultipartPush::push`]. The packet is committed
    /// when the returned guard is dropped, or discarded with
    /// [`MultipartPush::cancel`].
    ///
    /// # Errors
    ///
    /// Returns [`NotEnoughSpaceError`] if fewer than 8 bytes are unused, i.e.
    /// not even an empty packet would fit.
    pub fn push_multipart(&mut self) -> Result<MultipartPush<'_, N>, NotEnoughSpaceError> {
        // Room for at least an empty packet: two 4-byte length fields.
        if self.bytes_unused() < 8 {
            return Err(NotEnoughSpaceError);
        }
        let start = self.head;
        // Reserve the leading length field; it is written when the guard drops.
        self.head = add_wrapping::<N>(self.head, 4);
        Ok(MultipartPush {
            buf: self,
            start,
            len: 0,
            force: false,
            cancelled: false,
        })
    }

    /// Starts assembling a packet, dropping old packets whenever space runs out.
    ///
    /// The oldest packets are dropped immediately until an empty packet fits,
    /// and [`MultipartPush::push`] drops further packets as needed. Packets
    /// dropped this way stay dropped even if the push is cancelled.
    pub fn push_multipart_force(&mut self) -> MultipartPush<'_, N> {
        while self.bytes_unused() < 8 && !self.empty() {
            self.pop_front();
        }
        let start = self.head;
        // Reserve the leading length field; it is written when the guard drops.
        self.head = add_wrapping::<N>(self.head, 4);
        MultipartPush {
            buf: self,
            start,
            len: 0,
            force: true,
            cancelled: false,
        }
    }

    /// Returns `true` if the buffer holds no packets.
    #[inline(always)]
    pub const fn empty(&self) -> bool {
        self.count == 0
    }

    /// Number of raw storage bytes not occupied by stored packets.
    ///
    /// `head == tail` is ambiguous (empty or full), so the packet count decides.
    /// This returns `N` whenever no packet is stored, even while a multipart
    /// push has reserved space; `MultipartPush::push` bounds its total payload
    /// separately.
    const fn bytes_unused(&self) -> usize {
        if self.empty() {
            N
        } else if self.head > self.tail {
            N + self.tail - self.head
        } else {
            self.tail - self.head
        }
    }

    /// Shared implementation of [`push`](Self::push) and
    /// [`push_force`](Self::push_force).
    fn _push(&mut self, data: &[u8], force: bool) -> Result<(), NotEnoughSpaceError> {
        // `new` (the only constructor) guarantees `N < u32::MAX`, so this bound
        // also guarantees the length fits in the `u32` header below. Oversized
        // input is reported as an error instead of panicking.
        if data.len() > N - 8 {
            return Err(NotEnoughSpaceError);
        }
        let needed = data.len() + 8;
        if needed > self.bytes_unused() {
            if !force {
                return Err(NotEnoughSpaceError);
            }
            // Terminates: once the buffer is empty, `bytes_unused() == N >= needed`.
            while needed > self.bytes_unused() {
                self.pop_front();
            }
        }
        let addr_a = self.head;
        let addr_b = add_wrapping::<N>(self.head, 4);
        let addr_c = add_wrapping::<N>(self.head, 4 + data.len());
        let len_buffer: [u8; 4] = (data.len() as u32).to_ne_bytes();
        write_wrapping(&mut self.buffer, addr_a, &len_buffer);
        write_wrapping(&mut self.buffer, addr_b, data);
        write_wrapping(&mut self.buffer, addr_c, &len_buffer);
        self.head = add_wrapping::<N>(self.head, 8 + data.len());
        self.count += 1;
        Ok(())
    }

    /// Removes the oldest packet and returns its payload.
    ///
    /// The payload is returned as two slices into the internal storage. The
    /// second slice is non-empty only when the payload wraps around the end of
    /// the storage, and concatenating the two yields the packet. Returns `None`
    /// if the buffer is empty.
    pub fn pop_front(&mut self) -> Option<(&[u8], &[u8])> {
        if self.empty() {
            return None;
        }
        let mut len_buffer = [0; 4];
        read_wrapping(&self.buffer, self.tail, &mut len_buffer);
        let len = u32::from_ne_bytes(len_buffer) as usize;
        let index_data = add_wrapping::<N>(self.tail, 4);
        let len_a = (N - index_data).min(len);
        let a = &self.buffer[index_data..index_data + len_a];
        let b = if len_a == len {
            &[]
        } else {
            &self.buffer[..len - len_a]
        };
        self.tail = add_wrapping::<N>(self.tail, len + 8);
        self.count -= 1;
        Some((a, b))
    }

    /// Iterates over the stored packets from newest to oldest, yielding the
    /// same `(first, wrapped)` slice pairs as [`pop_front`](Self::pop_front).
    pub fn iter_backwards(&self) -> IterBackwards<'_, N> {
        IterBackwards {
            buffer: &self.buffer,
            head: self.head,
            count: self.count,
        }
    }

    /// Iterates over the stored packets from oldest to newest, yielding the
    /// same `(first, wrapped)` slice pairs as [`pop_front`](Self::pop_front).
    pub fn iter(&self) -> Iter<'_, N> {
        Iter {
            buffer: &self.buffer,
            head: self.head,
            tail: self.tail,
            count: self.count,
        }
    }

    /// Number of packets currently stored.
    #[inline(always)]
    pub const fn count(&self) -> usize {
        self.count
    }

    /// Returns the `n`-th oldest packet (0 = oldest) without removing it.
    pub fn nth(&self, n: usize) -> Option<(&[u8], &[u8])> {
        self.iter().nth(n)
    }

    /// Returns the `n`-th newest packet (0 = newest) without removing it.
    pub fn nth_reverse(&self, n: usize) -> Option<(&[u8], &[u8])> {
        self.iter_backwards().nth(n)
    }

    /// Returns the `n`-th oldest packet as a single contiguous slice.
    ///
    /// If the payload wraps around the end of the storage, the whole storage
    /// is rotated (an O(N) operation) so that the payload starts at index 0.
    /// `head` and `tail` are adjusted accordingly, so the stored packets and
    /// their order are unaffected. Returns `None` if `n >= self.count()`.
    pub fn nth_contiguous(&mut self, mut n: usize) -> Option<&[u8]> {
        if self.empty() || n >= self.count {
            return None;
        }
        // Walk forward from the oldest packet using the leading length fields.
        let mut tail = self.tail;
        let len_data = loop {
            let mut buf = [0u8; 4];
            read_wrapping(&self.buffer, tail, &mut buf);
            let len_data = u32::from_ne_bytes(buf) as usize;
            if n == 0 {
                break len_data;
            }
            n -= 1;
            tail = add_wrapping::<N>(tail, len_data + 8);
        };
        let index_data = add_wrapping::<N>(tail, 4);
        if index_data + len_data <= N {
            return Some(&self.buffer[index_data..index_data + len_data]);
        }
        // The payload wraps: rotate so it starts at index 0 and shift the
        // indices by the same amount.
        self.buffer.rotate_left(index_data);
        self.tail = sub_wrapping::<N>(self.tail, index_data);
        self.head = sub_wrapping::<N>(self.head, index_data);
        Some(&self.buffer[..len_data])
    }
}
/// Iterator over stored packets from newest to oldest.
///
/// Created by [`BytearrayRingbuffer::iter_backwards`]. It yields
/// `(first, wrapped)` slice pairs whose concatenation is the packet payload.
#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct IterBackwards<'a, const N: usize> {
    /// Backing storage of the buffer being iterated.
    buffer: &'a [u8; N],
    /// One past the end of the next packet to yield.
    head: usize,
    /// Packets remaining.
    count: usize,
}
impl<'a, const N: usize> Iterator for IterBackwards<'a, N> {
    type Item = (&'a [u8], &'a [u8]);

    /// Yields the newest not-yet-visited packet as `(first, wrapped)` slices.
    fn next(&mut self) -> Option<Self::Item> {
        if self.count == 0 {
            return None;
        }
        // The trailing length field occupies the 4 bytes just before `head`.
        let index_len = sub_wrapping::<N>(self.head, 4);
        let mut buf = [0u8; 4];
        read_wrapping(self.buffer, index_len, &mut buf);
        let len_data = u32::from_ne_bytes(buf) as usize;
        debug_assert!((len_data + 8) <= N);
        // In test builds, cross-check against the packet's leading length field.
        #[cfg(test)]
        {
            let index_len = sub_wrapping::<N>(self.head, 8 + len_data);
            let mut buf = [0u8; 4];
            read_wrapping(self.buffer, index_len, &mut buf);
            let len_2 = u32::from_ne_bytes(buf) as usize;
            assert_eq!(len_data, len_2);
        }
        let index_data = sub_wrapping::<N>(self.head, 4 + len_data);
        let first = (N - index_data).min(len_data);
        let slice_a = &self.buffer[index_data..index_data + first];
        let slice_b = if first < len_data {
            &self.buffer[..len_data - first]
        } else {
            &[]
        };
        self.head = sub_wrapping::<N>(self.head, 8 + len_data);
        self.count -= 1;
        Some((slice_a, slice_b))
    }

    /// Exact: `count` is decremented once per yielded packet.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.count, Some(self.count))
    }
}

// `size_hint` reports the exact number of remaining packets.
impl<const N: usize> ExactSizeIterator for IterBackwards<'_, N> {}

// `next` keeps returning `None` once `count` has reached zero.
impl<const N: usize> core::iter::FusedIterator for IterBackwards<'_, N> {}
impl<const N: usize> Default for BytearrayRingbuffer<N> {
fn default() -> Self {
Self::new()
}
}
/// Iterator over stored packets from oldest to newest.
///
/// Created by [`BytearrayRingbuffer::iter`]. It yields `(first, wrapped)` slice
/// pairs whose concatenation is the packet payload.
#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Iter<'a, const N: usize> {
    /// Backing storage of the buffer being iterated.
    buffer: &'a [u8; N],
    /// Buffer head at creation; only read by debug consistency checks.
    head: usize,
    /// Start of the next packet to yield.
    tail: usize,
    /// Packets remaining.
    count: usize,
}
impl<'a, const N: usize> Iterator for Iter<'a, N> {
    type Item = (&'a [u8], &'a [u8]);

    /// Yields the oldest not-yet-visited packet as `(first, wrapped)` slices.
    fn next(&mut self) -> Option<Self::Item> {
        if self.count == 0 {
            return None;
        }
        // Occupancy is only needed for the debug consistency checks below.
        let bytes_unused = if self.head > self.tail {
            N + self.tail - self.head
        } else {
            self.tail - self.head
        };
        let bytes_occupied = N - bytes_unused;
        debug_assert!(bytes_occupied >= 8);
        let mut buf = [0u8; 4];
        read_wrapping(self.buffer, self.tail, &mut buf);
        let len_data = u32::from_ne_bytes(buf) as usize;
        debug_assert!((len_data + 8) <= N);
        debug_assert!((len_data + 8) <= bytes_occupied);
        let index_data = add_wrapping::<N>(self.tail, 4);
        let first = (N - index_data).min(len_data);
        let slice_a = &self.buffer[index_data..index_data + first];
        let slice_b = if first < len_data {
            &self.buffer[..len_data - first]
        } else {
            &[]
        };
        self.tail = add_wrapping::<N>(self.tail, 8 + len_data);
        self.count -= 1;
        Some((slice_a, slice_b))
    }

    /// Exact: `count` is decremented once per yielded packet.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.count, Some(self.count))
    }
}

// `size_hint` reports the exact number of remaining packets.
impl<const N: usize> ExactSizeIterator for Iter<'_, N> {}

// `next` keeps returning `None` once `count` has reached zero.
impl<const N: usize> core::iter::FusedIterator for Iter<'_, N> {}
/// Advances storage index `addr` by `offset` bytes, wrapping at `N`.
///
/// Expects `addr < N` and `offset <= N` (checked in debug builds).
fn add_wrapping<const N: usize>(addr: usize, offset: usize) -> usize {
    debug_assert!(addr < N);
    debug_assert!(offset <= N);
    let sum = addr + offset;
    // At most one wrap is possible given the preconditions.
    sum.checked_sub(N).unwrap_or(sum)
}
/// Moves storage index `addr` back by `offset` bytes, wrapping at `N`.
///
/// Expects `addr < N` and `offset <= N` (checked in debug builds).
fn sub_wrapping<const N: usize>(addr: usize, offset: usize) -> usize {
    debug_assert!(addr < N);
    debug_assert!(offset <= N);
    match addr.checked_sub(offset) {
        Some(back) => back,
        // Stepped past index 0: continue from the end of the storage.
        None => N + addr - offset,
    }
}
/// Copies `data` into `buffer` starting at `index`, continuing at the start of
/// `buffer` once its end is reached.
///
/// Panics if `index > buffer.len()` or if the wrapped remainder of `data` is
/// longer than `buffer`.
fn write_wrapping(buffer: &mut [u8], index: usize, data: &[u8]) {
    let room_before_end = buffer.len() - index;
    let (direct, wrapped) = data.split_at(room_before_end.min(data.len()));
    buffer[index..index + direct.len()].copy_from_slice(direct);
    if !wrapped.is_empty() {
        buffer[..wrapped.len()].copy_from_slice(wrapped);
    }
}
/// Fills `data` from `buffer` starting at `index`, continuing at the start of
/// `buffer` once its end is reached.
///
/// Panics if `index > buffer.len()` or if the wrapped remainder of `data` is
/// longer than `buffer`.
fn read_wrapping(buffer: &[u8], index: usize, data: &mut [u8]) {
    let room_before_end = buffer.len() - index;
    let split = room_before_end.min(data.len());
    let (direct, wrapped) = data.split_at_mut(split);
    direct.copy_from_slice(&buffer[index..index + direct.len()]);
    if !wrapped.is_empty() {
        wrapped.copy_from_slice(&buffer[..wrapped.len()]);
    }
}
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use super::BytearrayRingbuffer;

    #[test]
    fn push_some_packets() {
        const N: usize = 64;
        for start_offset in 0..N {
            let mut buf = BytearrayRingbuffer::<N>::new();
            buf.head = start_offset;
            buf.tail = start_offset;
            // Every packet costs its payload plus 8 bytes of framing.
            let free = N - 8;
            assert_eq!(buf.free(), free);
            buf.push(b"01234567").unwrap();
            let free = free - 8 - 8;
            assert_eq!(buf.free(), free);
            buf.push(b"").unwrap();
            let free = free - 8;
            assert_eq!(buf.free(), free);
            buf.push(b"0123").unwrap();
            let free = free - 4 - 8;
            assert_eq!(buf.free(), free);
            buf.push(b"0123").unwrap();
            let free = free - 4 - 8;
            assert_eq!(buf.free(), free);
        }
    }

    #[test]
    fn push_force() {
        let mut buf = BytearrayRingbuffer::<16>::new();
        assert_eq!(buf.bytes_unused(), 16);
        let a = b"012345";
        let b = b"0123";
        buf.push(a).unwrap();
        assert_eq!(buf.bytes_unused(), 16 - a.len() - 8);
        buf.push(b).unwrap_err();
        assert_eq!(buf.bytes_unused(), 16 - a.len() - 8);
        buf.push_force(b).unwrap();
        assert_eq!(buf.bytes_unused(), 16 - b.len() - 8);
    }

    #[test]
    fn push_all_data_lengths() {
        for n in 0..(32 - 8) {
            let mut buf = BytearrayRingbuffer::<32>::new();
            let data = (0..n as u8).collect::<Vec<u8>>();
            assert_eq!(buf.free(), 32 - 8);
            buf.push(&data).unwrap();
            assert_eq!(buf.free(), (32usize - 16).saturating_sub(n));
        }
    }

    #[test]
    fn push_sum_of_lengths_possible() {
        let mut buf = BytearrayRingbuffer::<32>::new();
        assert_eq!(buf.free(), 32 - 8);
        buf.push(b"01234567").unwrap();
        assert_eq!(buf.free(), 32 - 8 - 16);
        buf.push(b"01234567").unwrap();
        assert_eq!(buf.free(), 0);
    }

    #[test]
    fn push_pop() {
        const N: usize = 64;
        for start_offset in 0..N {
            let mut buf = BytearrayRingbuffer::<N>::new();
            buf.head = start_offset;
            buf.tail = start_offset;
            let data = b"01234567";
            buf.push(data).unwrap();
            let (a, b) = buf.pop_front().unwrap();
            let out = collect(a, b);
            assert_eq!(out.as_slice(), data, "start_offset = {start_offset}");
            assert_eq!(buf.head, buf.tail);
            assert_eq!(buf.bytes_unused(), N);
        }
    }

    #[test]
    fn push_read_back() {
        let data = [b"hello world" as &[u8], b"", b"test"];
        const N: usize = 64;
        for start_offset in 0..N {
            let mut buf = BytearrayRingbuffer::<N>::new();
            buf.head = start_offset;
            buf.tail = start_offset;
            for &d in &data {
                buf.push(d).unwrap();
            }
            let mut it = buf.iter();
            for &d in data.iter() {
                let (a, b) = it.next().unwrap();
                assert_eq!(d, collect(a, b).as_slice());
            }
            assert_eq!(it.next(), None);
            let mut it = buf.iter_backwards();
            for &d in data.iter().rev() {
                let (a, b) = it.next().unwrap();
                assert_eq!(d, collect(a, b).as_slice());
            }
            assert_eq!(it.next(), None);
        }
    }

    #[test]
    fn push_count() {
        let mut buf = BytearrayRingbuffer::<64>::new();
        buf.push(b"1234").unwrap();
        assert_eq!(buf.count(), 1);
        buf.push(b"1234").unwrap();
        assert_eq!(buf.count(), 2);
        buf.push(b"1234").unwrap();
        assert_eq!(buf.count(), 3);
    }

    /// Pushes `words` with `push_force` and checks the buffer against a
    /// `VecDeque` model that evicts the oldest words until the new one fits.
    fn test_with_readback<const N: usize>(words: &[&'static str]) {
        eprintln!("--------------------------");
        let mut buf = BytearrayRingbuffer::<N>::new();
        let mut current_words: VecDeque<String> = VecDeque::new();
        // Bytes used by the model: each packet costs its payload plus 8 bytes.
        let used = |q: &VecDeque<String>| q.iter().map(|w| w.len() + 8).sum::<usize>();
        for &word in words {
            eprintln!("adding {word:?}");
            let word = word.to_owned();
            // `push_force` may drop several packets, so evict in a loop.
            while !current_words.is_empty() && used(&current_words) + 8 + word.len() > N {
                current_words.pop_front();
            }
            buf.push_force(word.as_bytes()).unwrap();
            current_words.push_back(word);
            assert_eq!(buf.count(), current_words.len());
            for (a, b) in buf.iter_backwards().zip(current_words.iter().rev()) {
                eprintln!("read back {b:?}");
                let mut st = String::new();
                st.push_str(core::str::from_utf8(a.0).unwrap());
                st.push_str(core::str::from_utf8(a.1).unwrap());
                assert_eq!(st, *b);
            }
        }
    }

    #[test]
    fn readback_various() {
        test_with_readback::<32>(&["ab", "123", "hello", "world"]);
        test_with_readback::<32>(&["", "", "a", "", "", ""]);
        test_with_readback::<32>(&["", "", "ab", "", "", ""]);
        test_with_readback::<32>(&["", "", "abc", "", "", ""]);
        test_with_readback::<32>(&["", "", "abcd", "", "", ""]);
        test_with_readback::<32>(&["", "", "abcde", "", "", ""]);
        test_with_readback::<24>(&["0", "1", "a", "2", "3", "4"]);
        test_with_readback::<24>(&["0", "1", "ab", "2", "3", "4"]);
        test_with_readback::<24>(&["0", "1", "abc", "2", "3", "4"]);
        test_with_readback::<24>(&["0", "1", "abcd", "2", "3", "4"]);
        test_with_readback::<24>(&["0", "1", "abcde", "2", "3", "4"]);
        test_with_readback::<24>(&["0", "1", "abcdef", "2", "3", "4"]);
        test_with_readback::<24>(&["0", "1", "abcdefg", "2", "3", "4"]);
    }

    #[test]
    fn nth_contiguous_out_of_range_returns_none() {
        let mut buf = BytearrayRingbuffer::<64>::new();
        buf.push(b"hello").unwrap();
        assert_eq!(buf.count(), 1);
        assert_eq!(buf.nth_contiguous(1), None);
    }

    #[test]
    fn rotate_contiguous() {
        const N: usize = 48;
        let data: [&[u8]; _] = [b"012345", b"hello world", b"xyz"];
        for offset in 0..N {
            let mut buf = BytearrayRingbuffer::<N>::new();
            buf.head = offset;
            buf.tail = offset;
            for &d in &data {
                buf.push(d).unwrap();
            }
            let read = buf.nth_contiguous(1).unwrap();
            assert_eq!(data[1], read);
            // Rotation must not lose or reorder packets.
            assert_eq!(buf.count(), data.len());
            for (&r, (a, b)) in data.iter().zip(buf.iter()) {
                assert_eq!(collect(a, b).as_slice(), r);
            }
        }
    }

    /// Concatenates the two halves of a (possibly wrapped) packet.
    fn collect(a: &[u8], b: &[u8]) -> Vec<u8> {
        let mut v = Vec::new();
        v.extend_from_slice(a);
        v.extend_from_slice(b);
        v
    }

    #[test]
    fn multipart_normal_fits() {
        const N: usize = 64;
        for offset in 0..N {
            let mut buf = BytearrayRingbuffer::<N>::new();
            buf.head = offset;
            buf.tail = offset;
            let mut mp = buf.push_multipart().unwrap();
            mp.push(b"hello").unwrap();
            mp.push(b" ").unwrap();
            mp.push(b"world").unwrap();
            drop(mp);
            assert_eq!(buf.count(), 1);
            let (a, b) = buf.pop_front().unwrap();
            assert_eq!(collect(a, b), b"hello world");
            assert_eq!(buf.count(), 0);
        }
    }

    #[test]
    fn multipart_empty_packet() {
        let mut buf = BytearrayRingbuffer::<64>::new();
        let mp = buf.push_multipart().unwrap();
        drop(mp);
        assert_eq!(buf.count(), 1);
        let (a, b) = buf.pop_front().unwrap();
        assert_eq!(collect(a, b), b"");
    }

    #[test]
    fn multipart_normal_overflow_returns_err() {
        let mut buf = BytearrayRingbuffer::<24>::new();
        buf.push(b"").unwrap();
        let mut mp = buf.push_multipart().unwrap();
        mp.push(b"abcd").unwrap();
        let err = mp.push(b"12345");
        assert!(err.is_err());
        drop(mp);
        assert_eq!(buf.count(), 2);
        let (a, b) = buf.nth(1).unwrap();
        assert_eq!(collect(a, b), b"abcd");
    }

    #[test]
    fn multipart_cancel_normal_mode() {
        let mut buf = BytearrayRingbuffer::<64>::new();
        let original_unused = buf.bytes_unused();
        let original_count = buf.count();
        let mut mp = buf.push_multipart().unwrap();
        mp.push(b"data that will be discarded").unwrap();
        mp.cancel();
        assert_eq!(buf.count(), original_count);
        assert_eq!(buf.bytes_unused(), original_unused);
    }

    #[test]
    fn multipart_normal_no_room_for_start() {
        let mut buf = BytearrayRingbuffer::<24>::new();
        buf.push(&[0u8; 16]).unwrap();
        assert_eq!(buf.bytes_unused(), 0);
        assert!(buf.push_multipart().is_err());
    }

    #[test]
    fn multipart_force_drops_old_packets() {
        let mut buf = BytearrayRingbuffer::<24>::new();
        buf.push(b"AA").unwrap();
        buf.push(b"BB").unwrap();
        assert_eq!(buf.count(), 2);
        let mut mp = buf.push_multipart_force();
        mp.push(b"hello world").unwrap();
        drop(mp);
        assert_eq!(buf.count(), 1);
        let (a, b) = buf.pop_front().unwrap();
        assert_eq!(collect(a, b), b"hello world");
    }

    #[test]
    fn multipart_force_cancel_drops_are_permanent() {
        let mut buf = BytearrayRingbuffer::<24>::new();
        buf.push(b"AA").unwrap();
        buf.push(b"BB").unwrap();
        let count_before = buf.count();
        let mut mp = buf.push_multipart_force();
        mp.push(b"hello world").unwrap();
        mp.cancel();
        assert!(buf.count() < count_before);
        assert_eq!(buf.count(), 0);
    }

    #[test]
    fn multipart_push_after_multipart() {
        let mut buf = BytearrayRingbuffer::<64>::new();
        {
            let mut mp = buf.push_multipart().unwrap();
            mp.push(b"first").unwrap();
        }
        buf.push(b"second").unwrap();
        assert_eq!(buf.count(), 2);
        let (a, b) = buf.nth(0).unwrap();
        assert_eq!(collect(a, b), b"first");
        let (a, b) = buf.nth(1).unwrap();
        assert_eq!(collect(a, b), b"second");
    }

    #[test]
    fn multipart_force_max_payload() {
        const N: usize = 32;
        let mut buf = BytearrayRingbuffer::<N>::new();
        buf.push(b"old").unwrap();
        let payload: Vec<u8> = (0..((N - 8) as u8)).collect();
        let mut mp = buf.push_multipart_force();
        mp.push(&payload).unwrap();
        drop(mp);
        assert_eq!(buf.count(), 1);
        let (a, b) = buf.pop_front().unwrap();
        assert_eq!(collect(a, b), payload);
    }

    #[test]
    fn multipart_wraparound_all_offsets() {
        const N: usize = 48;
        for offset in 0..N {
            let mut buf = BytearrayRingbuffer::<N>::new();
            buf.head = offset;
            buf.tail = offset;
            buf.push(b"prefix").unwrap();
            let mut mp = buf.push_multipart().unwrap();
            mp.push(b"foo").unwrap();
            mp.push(b"bar").unwrap();
            drop(mp);
            buf.push(b"suffix").unwrap();
            assert_eq!(buf.count(), 3);
            let (a, b) = buf.nth(0).unwrap();
            assert_eq!(collect(a, b), b"prefix");
            let (a, b) = buf.nth(1).unwrap();
            assert_eq!(collect(a, b), b"foobar");
            let (a, b) = buf.nth(2).unwrap();
            assert_eq!(collect(a, b), b"suffix");
        }
    }
}