use std::sync::atomic::{AtomicU64, Ordering};
/// Handle for a copy submitted through [`AsyncCopy::issue`].
///
/// Pass it to [`AsyncCopy::wait`] before reading the destination of the copy.
/// Tokens compare and hash by their raw value; for [`SyncCopy`] that value is
/// the number of copies the engine issued before this one.
#[must_use = "pass the BarrierToken to `AsyncCopy::wait` before reading the destination"]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct BarrierToken(pub u64);
/// A copy engine that starts a byte copy and later waits for it to finish.
///
/// An implementation may complete the copy inside `issue` or defer it. Either
/// way, callers must not read the destination, and must not modify the source,
/// until `wait` has returned for the token `issue` produced.
pub trait AsyncCopy {
    /// Starts copying `bytes` bytes from `src` to `dst` and returns a token
    /// identifying the copy.
    ///
    /// # Safety
    ///
    /// Until `wait` has returned for the returned token, the caller must
    /// guarantee that:
    ///
    /// * `src` is valid for reads of `bytes` bytes and `dst` is valid for
    ///   writes of `bytes` bytes;
    /// * the source and destination regions do not overlap;
    /// * the source is not written, and the destination is not read or
    ///   written, through any other pointer or reference.
    unsafe fn issue(&mut self, src: *const u8, dst: *mut u8, bytes: usize) -> BarrierToken;

    /// Returns once the copy identified by `token` has completed. After this,
    /// the destination of that copy may be read.
    fn wait(&mut self, token: BarrierToken);
}
/// An [`AsyncCopy`] engine that performs every copy immediately, inside
/// `issue`, on the calling thread.
///
/// Because the copy has already finished when `issue` returns, `wait` does
/// nothing. Tokens come from a per-engine counter that starts at 0 and is
/// incremented (wrapping) once per `issue`.
#[derive(Debug)]
pub struct SyncCopy {
    /// Number of copies issued so far (mod 2^64), i.e. the next token's value.
    counter: AtomicU64,
}

impl SyncCopy {
    /// Creates an engine whose first issued token is `BarrierToken(0)`.
    pub const fn new() -> Self {
        Self {
            counter: AtomicU64::new(0),
        }
    }
}

impl Default for SyncCopy {
    /// Equivalent to [`SyncCopy::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl AsyncCopy for SyncCopy {
    /// Copies `bytes` bytes from `src` to `dst` before returning, then hands
    /// out the next token from the engine's counter.
    ///
    /// # Safety
    ///
    /// `src` must be valid for reads and `dst` valid for writes of `bytes`
    /// bytes, and the two regions must not overlap. These are the requirements
    /// of [`std::ptr::copy_nonoverlapping`].
    unsafe fn issue(&mut self, src: *const u8, dst: *mut u8, bytes: usize) -> BarrierToken {
        // SAFETY: the caller upholds this function's contract, which is
        // exactly `copy_nonoverlapping`'s contract for `u8` (byte pointers
        // are always sufficiently aligned).
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, bytes);
        }
        // `&mut self` already gives exclusive access, so no cross-thread
        // ordering is needed; `Relaxed` is enough. `fetch_add` wraps on overflow.
        BarrierToken(self.counter.fetch_add(1, Ordering::Relaxed))
    }

    /// No-op: the copy already completed inside `issue`.
    fn wait(&mut self, _token: BarrierToken) {}
}
/// Two buffers with rotating roles: one is "current", the other is "next".
///
/// A typical use is ping-pong pipelining: fill `next` while consuming
/// `current`, then [`swap`](DoubleBuffer::swap) the roles.
#[derive(Debug, Clone)]
pub struct DoubleBuffer<T> {
    buffers: [T; 2],
    /// Index (0 or 1) of the current buffer; the other index is `next`.
    active: usize,
}

impl<T> DoubleBuffer<T> {
    /// Creates a double buffer with `a` as the current buffer and `b` as next.
    pub fn new(a: T, b: T) -> Self {
        Self {
            buffers: [a, b],
            active: 0,
        }
    }

    /// Borrows the current buffer.
    pub fn current(&self) -> &T {
        &self.buffers[self.active]
    }

    /// Mutably borrows the current buffer.
    pub fn current_mut(&mut self) -> &mut T {
        &mut self.buffers[self.active]
    }

    /// Borrows the next (non-current) buffer.
    pub fn next(&self) -> &T {
        &self.buffers[1 - self.active]
    }

    /// Mutably borrows the next (non-current) buffer.
    pub fn next_mut(&mut self) -> &mut T {
        &mut self.buffers[1 - self.active]
    }

    /// Borrows `(current, next)` mutably at the same time.
    ///
    /// This lets a pipeline fill the next buffer while working on the current
    /// one, without going through raw pointers.
    pub fn split_mut(&mut self) -> (&mut T, &mut T) {
        let [first, second] = &mut self.buffers;
        if self.active == 0 {
            (first, second)
        } else {
            (second, first)
        }
    }

    /// Exchanges the roles of current and next. No buffer contents are moved.
    pub fn swap(&mut self) {
        self.active = 1 - self.active;
    }

    /// Borrows both buffers in storage order `(a, b)`, as passed to
    /// [`new`](DoubleBuffer::new), regardless of which one is current.
    pub fn pair(&self) -> (&T, &T) {
        (&self.buffers[0], &self.buffers[1])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sums a tile's bytes as `u64` so the running total cannot overflow.
    fn tile_sum(tile: &[u8]) -> u64 {
        tile.iter().map(|&b| u64::from(b)).sum()
    }

    #[test]
    fn double_buffer_swap_round_trip() {
        let mut db = DoubleBuffer::new(vec![1u8; 4], vec![2u8; 4]);
        assert_eq!(db.current(), &vec![1u8; 4]);
        db.swap();
        assert_eq!(db.current(), &vec![2u8; 4]);
        db.swap();
        assert_eq!(db.current(), &vec![1u8; 4]);
    }

    #[test]
    fn double_buffer_next_and_pair_track_roles() {
        let mut db = DoubleBuffer::new(1u32, 2u32);
        assert_eq!(*db.next(), 2);
        assert_eq!(db.pair(), (&1, &2));
        db.swap();
        assert_eq!(*db.current(), 2);
        assert_eq!(*db.next(), 1);
        // `pair` reports storage order, not (current, next).
        assert_eq!(db.pair(), (&1, &2));
        *db.next_mut() = 7;
        assert_eq!(db.pair(), (&7, &2));
    }

    #[test]
    fn sync_copy_round_trips_data() {
        let src = [1u8, 2, 3, 4];
        let mut dst = [0u8; 4];
        let mut engine = SyncCopy::new();
        // SAFETY: `src` and `dst` are distinct, live 4-byte arrays.
        let token = unsafe { engine.issue(src.as_ptr(), dst.as_mut_ptr(), src.len()) };
        engine.wait(token);
        assert_eq!(dst, src);
    }

    #[test]
    fn sync_copy_tokens_count_up_from_zero() {
        let src = [5u8; 2];
        let mut dst = [0u8; 2];
        let mut engine = SyncCopy::default();
        for expected in 0..3u64 {
            // SAFETY: `src` and `dst` are distinct, live arrays of `src.len()` bytes.
            let token = unsafe { engine.issue(src.as_ptr(), dst.as_mut_ptr(), src.len()) };
            assert_eq!(token, BarrierToken(expected));
            engine.wait(token);
        }
        assert_eq!(dst, src);
    }

    #[test]
    fn pipelined_pattern_through_double_buffer() {
        let source: Vec<u8> = (0..16u8).collect();
        let tile_bytes = 4;
        assert_eq!(source.len() % tile_bytes, 0, "source must be a whole number of tiles");
        let mut db = DoubleBuffer::new(vec![0u8; tile_bytes], vec![0u8; tile_bytes]);
        let mut engine = SyncCopy::new();
        let mut tiles = source.chunks_exact(tile_bytes);

        // Prologue: load the first tile into the current buffer.
        let first = tiles.next().expect("source holds at least one tile");
        // SAFETY: `first` has exactly `tile_bytes` readable bytes, and the current
        // buffer is a separate allocation with `tile_bytes` writable bytes.
        let t0 =
            unsafe { engine.issue(first.as_ptr(), db.current_mut().as_mut_ptr(), tile_bytes) };
        engine.wait(t0);

        let mut total: u64 = 0;
        for tile in tiles {
            // Start loading this tile into `next`...
            // SAFETY: `tile` has exactly `tile_bytes` readable bytes; the next
            // buffer is a separate allocation with `tile_bytes` writable bytes
            // and is not accessed again until after `wait`.
            let t = unsafe { engine.issue(tile.as_ptr(), db.next_mut().as_mut_ptr(), tile_bytes) };
            // ...while consuming the tile already sitting in `current`.
            total += tile_sum(db.current());
            engine.wait(t);
            db.swap();
        }
        // Epilogue: consume the final tile.
        total += tile_sum(db.current());

        let expected: u64 = (0..16u64).sum();
        assert_eq!(total, expected);
    }
}