use core::cmp::min;
use core::num::Wrapping;
use portable_atomic::{AtomicU32, Ordering};
use ringbuf::Rb;
use crate::util::write_from_slices;
/// Fixed-capacity byte log whose writers never block.
///
/// Readers and writers share `ring` through a non-blocking `try_lock`; a
/// `write` that finds the lock already taken does not store its bytes and only
/// adds their count to `lost`.
#[derive(Default)]
pub struct Buffer<const N: usize> {
    /// Bytes dropped by writes that could not take the lock and that have not
    /// yet been folded into the ring's `total_written` counter.
    lost: AtomicU32,
    /// The retained bytes together with the running stream position.
    ring: try_lock::TryLock<CountedRing<N>>,
}
/// Ring storage for up to `N` bytes, paired with the stream position just past
/// its newest byte.
#[derive(Default)]
struct CountedRing<const N: usize> {
    /// Bytes accounted for so far (stored plus lost), modulo 2^32. The oldest
    /// retained byte therefore sits at position `total_written - buf.len()`.
    total_written: Wrapping<u32>,
    /// Fixed-size byte storage; writers overwrite the oldest bytes when full.
    buf: ringbuf::ring_buffer::LocalRb<u8, [core::mem::MaybeUninit<u8>; N]>,
}
/// Returned when the buffer's ring is locked by another reader or writer at
/// the moment of the call; retrying later may succeed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BufferUnavailable;
/// Reasons a cursor-based read can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadErr {
    /// The ring was locked by another reader or writer at the time of the call.
    BufferUnavailable,
    /// The requested cursor lies outside the retained window: either its bytes
    /// were overwritten or lost, or it is ahead of the newest written byte.
    DataUnavailable,
}
impl<const N: usize> Buffer<N> {
    /// Copies the oldest retained bytes into `outbuf`.
    ///
    /// The leading part of the retained data is copied into `outbuf` via
    /// `write_from_slices`. On success returns `(cursor, len)`:
    /// `cursor` is the stream position (bytes since creation, modulo 2^32) of
    /// the first retained byte, and `len` is the number of bytes currently
    /// retained, which may exceed `outbuf.len()`.
    ///
    /// # Errors
    ///
    /// [`BufferUnavailable`] if the ring is currently locked.
    pub fn read_earliest(
        &self,
        outbuf: &mut [u8],
    ) -> Result<(Wrapping<u32>, usize), BufferUnavailable> {
        let Some(ring) = self.ring.try_lock() else {
            return Err(BufferUnavailable);
        };
        let (first, second) = ring.buf.as_slices();
        write_from_slices(outbuf, first, second);
        let retained = ring.buf.len();
        Ok((ring.total_written - Wrapping(retained as u32), retained))
    }

    /// Copies the bytes from stream position `cursor` onwards into `outbuf`.
    ///
    /// `cursor` uses the same modulo-2^32 position space as the cursor returned
    /// by [`Self::read_earliest`]. Returns the number of retained bytes from
    /// `cursor` up to the newest byte, which may exceed `outbuf.len()`.
    ///
    /// # Errors
    ///
    /// * [`ReadErr::BufferUnavailable`] if the ring is currently locked.
    /// * [`ReadErr::DataUnavailable`] if `cursor` is older than the earliest
    ///   retained byte (overwritten or lost) or newer than the last written byte.
    pub fn read_from_cursor(
        &self,
        cursor: Wrapping<u32>,
        outbuf: &mut [u8],
    ) -> Result<usize, ReadErr> {
        let Some(ring) = self.ring.try_lock() else {
            return Err(ReadErr::BufferUnavailable);
        };
        let (first, second) = ring.buf.as_slices();
        let retained = first.len() + second.len();

        // Offset of `cursor` from the earliest retained byte. A cursor behind the
        // retained window, or ahead of `total_written`, wraps around to a value
        // larger than `retained`.
        let offset = (cursor - ring.total_written + Wrapping(retained as u32)).0;

        // Range-check in u32 *before* narrowing: on targets with a 16-bit
        // `usize`, `offset as usize` would truncate a wrapped value and could
        // make an unavailable cursor look valid.
        if offset > retained as u32 {
            return Err(ReadErr::DataUnavailable);
        }
        // Lossless: `offset <= retained`, and `retained` is itself a usize.
        let offset = offset as usize;

        // Skip `offset` bytes across the two halves of the ring.
        let skip_first = min(first.len(), offset);
        let first = &first[skip_first..];
        let second = &second[offset - skip_first..];

        write_from_slices(outbuf, first, second);
        Ok(first.len() + second.len())
    }

    /// Appends `data` to the log without ever blocking.
    ///
    /// When the ring is full, the oldest bytes are overwritten. If the ring is
    /// locked (by a reader or another writer), `data` is dropped and its length
    /// is added to `lost`. The next write that does get the lock advances
    /// `total_written` past the lost bytes and clears the ring. A reader whose
    /// cursor points into or before that gap then gets
    /// [`ReadErr::DataUnavailable`] instead of silently skipping data.
    pub fn write(&self, data: &[u8]) {
        let Some(mut ring) = self.ring.try_lock() else {
            // Cannot store the bytes; remember how many went missing.
            // Truncating to u32 matches the modulo-2^32 position arithmetic.
            self.lost.fetch_add(data.len() as u32, Ordering::Relaxed);
            return;
        };
        let lost_before_us = self.lost.swap(0, Ordering::Relaxed);
        if lost_before_us != 0 {
            // Retained bytes must be contiguous with `total_written`, so
            // everything before the gap is discarded.
            ring.total_written += Wrapping(lost_before_us);
            ring.buf.clear();
        }
        ring.buf.push_slice_overwrite(data);
        ring.total_written += Wrapping(data.len() as u32);
    }
}
/// Exercises reads and writes on an uncontended buffer, including the point
/// where the ring fills up and starts overwriting its oldest bytes.
#[test]
fn test_buffer_readwrite() {
    const N: usize = 1024;
    const HWLEN: usize = "Hello World!".len();
    let b: Buffer<N> = Default::default();
    let mut outbuf = [0; 4];

    // Empty buffer: nothing retained; a cursor ahead of the data is rejected.
    assert_eq!(b.read_earliest(&mut outbuf), Ok((Wrapping(0), 0)));
    assert_eq!(b.read_from_cursor(Wrapping(0), &mut outbuf), Ok(0));
    assert_eq!(
        b.read_from_cursor(Wrapping(10), &mut outbuf),
        Err(ReadErr::DataUnavailable)
    );

    // Reads report the full retained length but copy only what fits.
    b.write(b"Hello");
    assert_eq!(b.read_earliest(&mut outbuf), Ok((Wrapping(0), 5)));
    assert_eq!(&outbuf, b"Hell");
    // Reads are non-destructive.
    outbuf[0] = 0;
    assert_eq!(b.read_earliest(&mut outbuf), Ok((Wrapping(0), 5)));
    assert_eq!(&outbuf, b"Hell");
    outbuf[0] = 0;
    assert_eq!(b.read_from_cursor(Wrapping(2), &mut outbuf), Ok(3));
    assert_eq!(&outbuf[..3], b"llo");

    b.write(b" World!");
    outbuf[0] = 0;
    assert_eq!(b.read_earliest(&mut outbuf), Ok((Wrapping(0), HWLEN)));
    assert_eq!(&outbuf, b"Hell");

    // Overflow the ring by exactly one byte so the leading 'H' is overwritten.
    let erase_h = [0; N - HWLEN + 1];
    b.write(&erase_h);
    assert_eq!(b.read_earliest(&mut outbuf), Ok((Wrapping(1), N)));
    assert_eq!(&outbuf, b"ello");

    // The overwritten position is no longer readable; the new earliest one is.
    assert_eq!(
        b.read_from_cursor(Wrapping(0), &mut outbuf),
        Err(ReadErr::DataUnavailable)
    );
    outbuf[0] = 0;
    assert_eq!(b.read_from_cursor(Wrapping(1), &mut outbuf), Ok(N));
    assert_eq!(&outbuf, b"ello");
}
/// Checks that a write colliding with a held lock is dropped and accounted for
/// as lost bytes on the next successful write.
#[test]
fn test_buffer_collisions() {
    const N: usize = 1024;
    const OFFSET: usize = N - 6;
    let b: Buffer<N> = Default::default();
    let mut outbuf = [0; 4];

    b.write(&[0; OFFSET]);
    b.write(b"1234");

    // Hold the lock to simulate a concurrent reader; this write must be dropped.
    let locked = b.ring.try_lock();
    assert!(locked.is_some(), "lock should be free before the collision");
    b.write(b"5678");
    drop(locked);

    // The dropped bytes are not reflected until the next successful write.
    assert_eq!(b.read_earliest(&mut outbuf), Ok((Wrapping(0), OFFSET + 4)));
    assert_eq!(
        b.read_from_cursor(Wrapping(OFFSET as _), &mut outbuf),
        Ok(4)
    );
    assert_eq!(&outbuf, b"1234");

    // The next write skips past the 4 lost bytes and discards older data.
    b.write(b"90ab");
    assert_eq!(
        b.read_earliest(&mut outbuf),
        Ok((Wrapping(OFFSET as u32 + 8), 4))
    );
    assert_eq!(&outbuf, b"90ab");

    // A reader positioned at the lost bytes learns that data went missing.
    assert_eq!(
        b.read_from_cursor(Wrapping(OFFSET as u32 + 4), &mut outbuf),
        Err(ReadErr::DataUnavailable)
    );
}