#![cfg_attr(all(not(test), not(feature = "std")), no_std)]
use core::cell::UnsafeCell;
// Packed 16-bit state word shared by the front (reader) and back (writer)
// sides of `AtomicTripleBuffer`. Bit ranges below are written msb, lsb as in
// the macro:
//   bit  0    swap_pending  - set by `TBBackGuard::swap`, cleared when
//                             `front_buffer` takes the published slot
//   bit  1    back_locked   - set while a `TBBackGuard` is alive
//   bit  2    front_locked  - set while a `TBFrontGuard` is alive
//   bits 4-3  front_index   - slot used by the front side
//   bits 6-5  work_index    - slot written by the back side
//   bits 8-7  pending_index - slot exchanged between the two sides
// The three indices start as 0/1/2 and are only ever exchanged pairwise, so
// they always form a permutation of {0, 1, 2}.
bitfield::bitfield! {
#[derive(Copy, Clone, Eq, PartialEq)]
#[repr(transparent)]
struct BufferStatus(u16);
impl Debug;
// Back side published a buffer the front side has not taken yet.
#[inline]
swap_pending, set_swap_pending: 0;
// A back guard currently holds the work (and possibly pending) slot.
#[inline]
back_locked, set_back_locked: 1;
// A front guard currently holds the front slot.
#[inline]
front_locked, set_front_locked: 2;
// Index (0..=2) of the slot owned by the front side.
#[inline]
u8, front_index, set_front_index: 4, 3;
// Index (0..=2) of the slot owned by the back side.
#[inline]
u8, work_index, set_work_index: 6, 5;
// Index (0..=2) of the slot handed between the sides on swap.
#[inline]
u8, pending_index, set_pending_index: 8, 7;
}
impl Default for BufferStatus {
#[inline]
fn default() -> Self {
let mut status = Self(0);
status.set_work_index(1);
status.set_pending_index(2);
status
}
}
/// Atomic cell holding the raw bits of a `BufferStatus`.
///
/// `repr(transparent)` keeps it layout-identical to the wrapped `AtomicU16`.
#[derive(Debug)]
#[repr(transparent)]
struct AtomicStatus(core::sync::atomic::AtomicU16);
impl Default for AtomicStatus {
    /// Starts from the bits of `BufferStatus::default()`.
    #[inline]
    fn default() -> Self {
        let initial = BufferStatus::default();
        Self(core::sync::atomic::AtomicU16::new(initial.0))
    }
}
impl AtomicStatus {
    /// Atomically replaces the status with `f(current)` via a CAS loop.
    ///
    /// `f` may be called more than once if another thread races the update.
    /// Returning `None` from `f` aborts the update.
    ///
    /// Returns `Ok(previous)` on success, or `Err(current)` when `f`
    /// returned `None`. Successful updates use `AcqRel` ordering and loads
    /// use `Acquire`.
    #[inline]
    fn fetch_update<F>(&self, mut f: F) -> Result<BufferStatus, BufferStatus>
    where
        F: FnMut(BufferStatus) -> Option<BufferStatus>,
    {
        use core::sync::atomic::Ordering::{AcqRel, Acquire};
        let outcome = self
            .0
            .fetch_update(AcqRel, Acquire, |raw| f(BufferStatus(raw)).map(|next| next.0));
        match outcome {
            Ok(previous) => Ok(BufferStatus(previous)),
            Err(current) => Err(BufferStatus(current)),
        }
    }
}
/// Three-slot buffer shared between one front (reader) side and one back
/// (writer) side, coordinated through a single atomic status word.
///
/// Each side can hold at most one guard at a time. Asking for a second guard
/// on a side that is already held fails with `TBLockError::AlreadyLocked`
/// instead of blocking.
#[derive(Debug)]
pub struct AtomicTripleBuffer<T> {
    /// The three slots; which slot plays which role is recorded in `status`.
    buffers: [UnsafeCell<T>; 3],
    /// Packed lock flags, swap flag and slot indices.
    status: AtomicStatus,
}
// SAFETY: the slots are only reachable through guards. The front/back lock
// bits and the index permutation in `status` keep any two live guards on
// distinct slots, so values may move between threads whenever `T: Send`.
unsafe impl<T> Send for AtomicTripleBuffer<T> where T: Send {}
// SAFETY: a shared `&AtomicTripleBuffer` only yields guards with exclusive
// access to a slot (as with `Mutex<T>`), so `T: Send` is sufficient.
unsafe impl<T> Sync for AtomicTripleBuffer<T> where T: Send {}
impl<T: Default> Default for AtomicTripleBuffer<T> {
#[inline]
fn default() -> Self {
Self {
buffers: [Default::default(), Default::default(), Default::default()],
status: Default::default(),
}
}
}
/// Error returned when one side of an [`AtomicTripleBuffer`] is already held.
///
/// [`AtomicTripleBuffer::front_buffer`] returns it while a front guard is
/// alive, and [`AtomicTripleBuffer::back_buffers`] returns it while a back
/// guard is alive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TBLockError {
    /// The requested side is held by a guard that has not been dropped yet.
    AlreadyLocked,
}
#[cfg(feature = "std")]
impl std::error::Error for TBLockError {}
impl core::fmt::Display for TBLockError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("buffer already locked")
    }
}
impl<T> AtomicTripleBuffer<T> {
pub fn new(init: T) -> Self
where
T: Clone,
{
Self {
buffers: [
UnsafeCell::new(init.clone()),
UnsafeCell::new(init.clone()),
UnsafeCell::new(init),
],
status: Default::default(),
}
}
pub fn front_buffer(&self) -> Result<TBFrontGuard<'_, T>, TBLockError> {
let mut front_index = 0;
self.status
.fetch_update(|mut status| {
if status.front_locked() {
return None;
}
status.set_front_locked(true);
if status.swap_pending() {
status.set_swap_pending(false);
front_index = status.pending_index();
status.set_pending_index(status.front_index());
status.set_front_index(front_index);
} else {
front_index = status.front_index();
}
Some(status)
})
.map_err(|_| TBLockError::AlreadyLocked)?;
Ok(TBFrontGuard {
cell: &self.buffers[front_index as usize],
status: &self.status,
})
}
pub fn back_buffers(&self) -> Result<TBBackGuard<'_, T>, TBLockError> {
let mut locked_status = BufferStatus::default();
self.status
.fetch_update(|mut status| {
if status.back_locked() {
return None;
}
status.set_back_locked(true);
locked_status = status;
Some(status)
})
.map_err(|_| TBLockError::AlreadyLocked)?;
Ok(TBBackGuard {
bufs: self,
locked_status,
})
}
}
/// Exclusive handle to the front (reader-side) buffer of an [`AtomicTripleBuffer`].
///
/// Obtained from [`AtomicTripleBuffer::front_buffer`]. It dereferences to the
/// front buffer and releases the front lock when dropped. Acquiring and
/// immediately dropping it still consumes any published swap, hence `must_use`.
#[must_use = "dropping the guard immediately releases the front lock"]
#[derive(Debug)]
pub struct TBFrontGuard<'a, T> {
    /// The slot this guard owns until it is dropped.
    cell: &'a UnsafeCell<T>,
    /// Shared status word, used to clear the front lock on drop.
    status: &'a AtomicStatus,
}
impl<T> core::ops::Deref for TBFrontGuard<'_, T> {
    type Target = T;
    /// Borrows the front buffer held by this guard.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let slot: *const T = self.cell.get();
        // SAFETY: while `front_locked` is set, the front index cannot change
        // (only `front_buffer` changes it, and it refuses while locked). The
        // back side only touches the work and pending slots, so nothing else
        // writes this slot for the guard's lifetime.
        unsafe { &*slot }
    }
}
impl<T> core::ops::DerefMut for TBFrontGuard<'_, T> {
    /// Mutably borrows the front buffer held by this guard.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let slot = self.cell.get();
        // SAFETY: the front slot is exclusive to this guard while
        // `front_locked` is set, and `&mut self` rules out other borrows
        // through the guard itself.
        unsafe { &mut *slot }
    }
}
impl<T> core::fmt::Display for TBFrontGuard<'_, T>
where
    T: core::fmt::Display,
{
    /// Formats the guarded front buffer using `T`'s `Display`.
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let value: &T = self;
        core::fmt::Display::fmt(value, f)
    }
}
impl<T> Drop for TBFrontGuard<'_, T> {
    /// Clears the front lock so a later `front_buffer` call can succeed.
    fn drop(&mut self) {
        // The closure never returns `None`, so the update cannot fail.
        let _ = self.status.fetch_update(|current| {
            let mut released = current;
            released.set_front_locked(false);
            Some(released)
        });
    }
}
// SAFETY: the guard has exclusive access to its slot until dropped, so moving
// it (and the `&mut T` it can hand out) to another thread only needs `T: Send`.
unsafe impl<T> Send for TBFrontGuard<'_, T> where T: Send {}
// SAFETY: a shared `&TBFrontGuard` only exposes `&T`, which requires `T: Sync`.
// `Send` is additionally (conservatively) required.
unsafe impl<T> Sync for TBFrontGuard<'_, T> where T: Send + Sync {}
/// Exclusive handle to the back (writer-side) buffers of an [`AtomicTripleBuffer`].
///
/// Obtained from [`AtomicTripleBuffer::back_buffers`].
/// [`TBBackGuard::swap`] publishes the work buffer to the front side.
/// Dropping the guard without swapping just releases the back lock.
#[must_use = "dropping the guard immediately releases the back lock without publishing a swap"]
#[derive(Debug)]
pub struct TBBackGuard<'a, T> {
    /// The buffer this guard was taken from.
    bufs: &'a AtomicTripleBuffer<T>,
    /// Status as observed when the back lock was acquired.
    locked_status: BufferStatus,
}
impl<T> TBBackGuard<'_, T> {
pub fn back(&self) -> &T {
let index = self.locked_status.work_index() as usize;
unsafe { &*self.bufs.buffers[index].get() }
}
pub fn back_mut(&mut self) -> &mut T {
let index = self.locked_status.work_index() as usize;
unsafe { &mut *self.bufs.buffers[index].get() }
}
pub fn pending(&self) -> Option<&T> {
if self.locked_status.swap_pending() {
return None;
}
let index = self.locked_status.pending_index() as usize;
Some(unsafe { &*self.bufs.buffers[index].get() })
}
pub fn pending_mut(&mut self) -> Option<&mut T> {
if self.locked_status.swap_pending() {
return None;
}
let index = self.locked_status.pending_index() as usize;
Some(unsafe { &mut *self.bufs.buffers[index].get() })
}
pub fn swap(self) {
self.bufs
.status
.fetch_update(|mut status| {
status.set_back_locked(false);
status.set_swap_pending(true);
let pending_index = status.work_index();
status.set_work_index(status.pending_index());
status.set_pending_index(pending_index);
Some(status)
})
.ok();
core::mem::forget(self);
}
}
impl<T> core::fmt::Display for TBBackGuard<'_, T>
where
    T: core::fmt::Display,
{
    /// Formats the work (back) buffer using `T`'s `Display`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let current = self.back();
        core::fmt::Display::fmt(current, f)
    }
}
impl<T> Drop for TBBackGuard<'_, T> {
    /// Releases the back lock without publishing anything.
    fn drop(&mut self) {
        // The closure never returns `None`, so the update cannot fail.
        let _ = self.bufs.status.fetch_update(|current| {
            let mut released = current;
            released.set_back_locked(false);
            Some(released)
        });
    }
}
// SAFETY: the guard has exclusive access to the work slot (and, when granted,
// the pending slot) until it is dropped or swapped, so sending it only needs
// `T: Send`.
unsafe impl<T> Send for TBBackGuard<'_, T> where T: Send {}
// SAFETY: a shared `&TBBackGuard` only exposes `&T`, which requires `T: Sync`.
// `Send` is additionally (conservatively) required.
unsafe impl<T> Sync for TBBackGuard<'_, T> where T: Send + Sync {}
#[cfg(test)]
mod tests {
    use super::{AtomicTripleBuffer, TBLockError};
    use std::sync::Arc;

    /// A reader thread must only observe fully written, monotonically
    /// increasing snapshots while a writer publishes 10 000 updates.
    #[test]
    fn basic() {
        #[derive(Clone, Default)]
        struct Data {
            a: i32,
            b: i32,
        }
        let buf = Arc::new(AtomicTripleBuffer::<Data>::default());
        let b = Arc::clone(&buf);
        let thread = std::thread::spawn(move || {
            let mut prev = Data::default();
            loop {
                let front = b.front_buffer().unwrap();
                // Both fields are written together; a torn read would differ.
                assert_eq!(front.a, front.b);
                assert!(front.a >= prev.a && front.b >= prev.b);
                if front.a >= 10000 {
                    break;
                }
                prev = (*front).clone();
            }
        });
        let mut data = Data::default();
        for _ in 0..10000 {
            data.a += 1;
            data.b += 1;
            let mut bufs = buf.back_buffers().unwrap();
            *bufs.back_mut() = data.clone();
            bufs.swap();
        }
        thread.join().unwrap();
    }

    /// Each side admits one guard at a time, independently of the other side.
    #[test]
    fn second_guard_on_same_side_is_rejected() {
        let buf = AtomicTripleBuffer::new(0u32);
        let front = buf.front_buffer().unwrap();
        assert!(matches!(buf.front_buffer(), Err(TBLockError::AlreadyLocked)));
        // The back side is not affected by the front lock.
        let back = buf.back_buffers().unwrap();
        assert!(matches!(buf.back_buffers(), Err(TBLockError::AlreadyLocked)));
        drop(front);
        drop(back);
        assert!(buf.front_buffer().is_ok());
        assert!(buf.back_buffers().is_ok());
    }

    /// A swap publishes the written buffer. The pending buffer is withheld
    /// from the back side until the front side has taken the swap.
    #[test]
    fn swap_publishes_back_buffer_to_front() {
        let buf = AtomicTripleBuffer::new(0u32);
        let mut back = buf.back_buffers().unwrap();
        *back.back_mut() = 7;
        back.swap();
        {
            let back = buf.back_buffers().unwrap();
            assert!(back.pending().is_none());
        }
        assert_eq!(*buf.front_buffer().unwrap(), 7);
        // The previous front slot (still holding the initial value) is now pending.
        let back = buf.back_buffers().unwrap();
        assert_eq!(back.pending(), Some(&0));
    }
}