use alloc::vec::Vec;
use core::cell::Cell;
use core::default::Default;
use crossbeam_utils::CachePadded;
/// Number of slots in a `Context` batch for regular (non-loom) builds.
#[cfg(not(loom))]
pub(crate) const MAX_PENDING_OPS: usize = 32;

/// Number of slots in a `Context` batch when compiled with `--cfg loom`.
// NOTE(review): presumably kept at 1 to bound the state space explored by
// loom -- confirm.
#[cfg(loom)]
pub(crate) const MAX_PENDING_OPS: usize = 1;

// `Context::index` maps a logical offset to a slot by masking with
// `MAX_PENDING_OPS - 1`; that is only a valid modulo when the batch size is a
// non-zero power of two.
const_assert!(MAX_PENDING_OPS.is_power_of_two());

/// One batch slot: the pending operation (`.0`) and, once available, the
/// response produced for it (`.1`).
type PendingOperation<T, R> = Cell<(Option<T>, Option<R>)>;
/// A fixed-size batch of pending operations and their responses.
///
/// The batch is a ring of `MAX_PENDING_OPS` slots addressed through three
/// logical cursors (`tail`, `comb`, `head`). The methods on this type only
/// ever advance the cursors; a logical cursor value is mapped to a physical
/// slot by `index`. The struct itself is aligned to 64 bytes.
#[repr(align(64))]
pub(crate) struct Context<T, R>
where
T: Sized + Clone,
R: Sized + Clone,
{
/// Slot storage; each slot holds an `(operation, response)` pair.
batch: [CachePadded<PendingOperation<T, R>>; MAX_PENDING_OPS],
/// Logical offset of the slot the next `enqueue` writes its operation to.
pub tail: CachePadded<Cell<usize>>,
/// Logical offset of the slot the next `res` reads its response from.
pub head: CachePadded<Cell<usize>>,
/// Combiner offset: `ops` hands out the operations in `comb..tail`, and
/// `enqueue_resps` stores responses starting here and then advances it.
pub comb: CachePadded<Cell<usize>>,
}
impl<T, R> Default for Context<T, R>
where
    T: Sized + Clone,
    R: Sized + Clone,
{
    /// Creates an empty context: every slot holds `(None, None)` and the
    /// `tail`, `head` and `comb` cursors all start at zero.
    fn default() -> Context<T, R> {
        // Construct every slot directly in its final state. Zero-initialising
        // the array and overwriting it afterwards is not sound for arbitrary
        // `T`/`R`: all-zero bytes are not guaranteed to be a valid
        // `(Option<T>, Option<R>)`, and overwriting would also run the
        // destructor of each such zeroed value.
        let batch: [CachePadded<PendingOperation<T, R>>; MAX_PENDING_OPS] =
            ::core::array::from_fn(|_| CachePadded::new(Cell::new((None, None))));

        Context {
            batch,
            tail: CachePadded::new(Cell::new(Default::default())),
            head: CachePadded::new(Cell::new(Default::default())),
            comb: CachePadded::new(Cell::new(Default::default())),
        }
    }
}
impl<T, R> Context<T, R>
where
    T: Sized + Clone,
    R: Sized + Clone,
{
    /// Stores `op` in the slot addressed by `tail` and advances `tail`.
    ///
    /// Returns `false`, leaving the context untouched, when
    /// `tail - head == MAX_PENDING_OPS` (every slot is in use); returns `true`
    /// once the operation has been stored.
    #[inline(always)]
    pub(crate) fn enqueue(&self, op: T) -> bool {
        let tail = self.tail.get();
        let head = self.head.get();

        let in_flight = tail - head;
        if in_flight == MAX_PENDING_OPS {
            return false;
        }

        let slot = self.batch[self.index(tail)].as_ptr();
        // SAFETY: `Cell::as_ptr` points at the cell's contents, which live as
        // long as `self`. NOTE(review): soundness also relies on nothing else
        // accessing this slot during the write -- confirm against callers.
        unsafe { (*slot).0 = Some(op) };

        self.tail.set(tail + 1);
        true
    }

    /// Stores a clone of every element of `responses` into consecutive slots
    /// starting at `comb`, then advances `comb` by `responses.len()`.
    ///
    /// An empty slice is a no-op.
    // NOTE(review): nothing here checks that `comb` stays at or below `tail`;
    // presumably callers pass at most as many responses as `ops` handed out --
    // confirm.
    #[inline(always)]
    pub(crate) fn enqueue_resps(&self, responses: &[R]) {
        let start = self.comb.get();
        if responses.is_empty() {
            return;
        }

        for (offset, response) in responses.iter().enumerate() {
            let slot = self.batch[self.index(start + offset)].as_ptr();
            // SAFETY: same reasoning as in `enqueue`; the pointer comes from a
            // cell owned by `self`.
            unsafe { (*slot).1 = Some(response.clone()) };
        }

        self.comb.set(start + responses.len());
    }

    /// Appends a clone of every operation in the logical range `comb..tail`
    /// to `buffer` and returns how many were appended.
    ///
    /// No cursor is modified by this call.
    ///
    /// # Panics
    ///
    /// Panics if `comb` is greater than `tail`, or if a slot inside the range
    /// holds no operation.
    #[inline(always)]
    pub(crate) fn ops(&self, buffer: &mut Vec<T>) -> usize {
        let first = self.comb.get();
        let end = self.tail.get();

        if first == end {
            return 0;
        }
        if first > end {
            panic!("Combiner Head of thread-local batch has advanced beyond tail!");
        }

        for logical in first..end {
            let slot = self.batch[self.index(logical)].as_ptr();
            // SAFETY: the pointer comes from a cell owned by `self`; the slot
            // is only read here.
            let op = unsafe { (*slot).0.as_ref().unwrap().clone() };
            buffer.push(op);
        }

        end - first
    }

    /// Returns a clone of the response stored in the slot addressed by `head`
    /// and advances `head` by one.
    ///
    /// Returns `None` without advancing when `head == comb`. If the slot holds
    /// no response, the result is also `None` even though `head` has already
    /// been advanced.
    ///
    /// # Panics
    ///
    /// Panics if `head` is greater than `comb`.
    #[inline(always)]
    pub(crate) fn res(&self) -> Option<R> {
        let next = self.head.get();
        let ready = self.comb.get();

        if next == ready {
            return None;
        }
        if next > ready {
            panic!("Head of thread-local batch has advanced beyond combiner offset!");
        }

        self.head.set(next + 1);

        let slot = self.batch[self.index(next)].as_ptr();
        // SAFETY: the pointer comes from a cell owned by `self`; the slot is
        // only read here.
        unsafe { (*slot).1.clone() }
    }

    /// Returns the number of slots in a batch (`MAX_PENDING_OPS`).
    #[inline(always)]
    pub(crate) fn batch_size() -> usize {
        MAX_PENDING_OPS
    }

    /// Maps a logical offset to a physical slot index by masking with
    /// `MAX_PENDING_OPS - 1`.
    #[inline(always)]
    fn index(&self, logical: usize) -> usize {
        const SLOT_MASK: usize = MAX_PENDING_OPS - 1;
        logical & SLOT_MASK
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::vec;

    /// A default context has a full-size batch and all cursors at zero.
    #[test]
    fn test_context_create_default() {
        let c = Context::<u64, Result<u64, ()>>::default();
        assert_eq!(c.batch.len(), MAX_PENDING_OPS);
        assert_eq!(c.tail.get(), 0);
        assert_eq!(c.head.get(), 0);
        assert_eq!(c.comb.get(), 0);
    }

    /// `enqueue` stores the operation in slot 0 and only advances `tail`.
    #[test]
    fn test_context_enqueue() {
        let c = Context::<u64, Result<u64, ()>>::default();
        assert!(c.enqueue(121));
        unsafe { assert_eq!((*c.batch[0].as_ptr()).0, Some(121)) };
        // Read the cursors with `get()`: `take()` would reset them as a side
        // effect of the check itself.
        assert_eq!(c.tail.get(), 1);
        assert_eq!(c.head.get(), 0);
        assert_eq!(c.comb.get(), 0);
    }

    /// `enqueue` refuses new operations once `tail - head` reaches the batch
    /// size, and leaves every cursor untouched.
    #[test]
    fn test_context_enqueue_full() {
        let c = Context::<u64, Result<u64, ()>>::default();
        c.tail.set(MAX_PENDING_OPS);
        assert!(!c.enqueue(100));
        assert_eq!(c.tail.get(), MAX_PENDING_OPS);
        assert_eq!(c.head.get(), 0);
        assert_eq!(c.comb.get(), 0);
    }

    /// `enqueue_resps` writes responses starting at `comb` and advances `comb`
    /// by the number of responses.
    #[test]
    fn test_context_enqueue_resps() {
        let c = Context::<u64, Result<u64, ()>>::default();
        let r = [Ok(11), Ok(12), Ok(13), Ok(14)];
        c.tail.set(16);
        c.comb.set(12);
        c.enqueue_resps(&r);
        assert_eq!(c.tail.get(), 16);
        assert_eq!(c.head.get(), 0);
        assert_eq!(c.comb.get(), 16);
        assert_eq!(c.batch[12].get().1, Some(r[0]));
        assert_eq!(c.batch[13].get().1, Some(r[1]));
        assert_eq!(c.batch[14].get().1, Some(r[2]));
        assert_eq!(c.batch[15].get().1, Some(r[3]));
    }

    /// An empty response slice changes neither the cursors nor the slots.
    #[test]
    fn test_context_enqueue_resps_empty() {
        let c = Context::<u64, Result<u64, ()>>::default();
        let r = [];
        c.tail.set(16);
        c.comb.set(12);
        c.enqueue_resps(&r);
        assert_eq!(c.tail.get(), 16);
        assert_eq!(c.head.get(), 0);
        assert_eq!(c.comb.get(), 12);
        assert_eq!(c.batch[12].get().1, None);
    }

    /// `ops` copies out every pending operation in order without moving any
    /// cursor.
    #[test]
    fn test_context_ops() {
        let c = Context::<usize, usize>::default();
        let mut o = vec![];
        for idx in 0..MAX_PENDING_OPS / 2 {
            assert!(c.enqueue(idx * idx))
        }
        assert_eq!(c.ops(&mut o), MAX_PENDING_OPS / 2);
        assert_eq!(o.len(), MAX_PENDING_OPS / 2);
        assert_eq!(c.tail.get(), MAX_PENDING_OPS / 2);
        assert_eq!(c.head.get(), 0);
        assert_eq!(c.comb.get(), 0);
        for idx in 0..MAX_PENDING_OPS / 2 {
            assert_eq!(o[idx], idx * idx)
        }
    }

    /// `ops` returns 0 and leaves the buffer empty when `comb == tail`.
    #[test]
    fn test_context_ops_empty() {
        let c = Context::<usize, usize>::default();
        let mut o = vec![];
        c.tail.set(8);
        c.comb.set(8);
        assert_eq!(c.ops(&mut o), 0);
        assert_eq!(o.len(), 0);
        assert_eq!(c.tail.get(), 8);
        assert_eq!(c.head.get(), 0);
        assert_eq!(c.comb.get(), 8);
    }

    /// `ops` must hit its own invariant check when `comb` is past `tail`.
    /// The expected message keeps an unrelated panic from passing the test.
    #[test]
    #[should_panic(expected = "Combiner Head of thread-local batch has advanced beyond tail!")]
    fn test_context_ops_panic() {
        let c = Context::<usize, usize>::default();
        let mut o = vec![];
        c.tail.set(6);
        c.comb.set(9);
        assert_eq!(c.ops(&mut o), 0);
    }

    /// `res` hands back the stored responses in order, advancing `head` by one
    /// per call.
    #[test]
    fn test_context_res() {
        let c = Context::<u64, Result<u64, ()>>::default();
        let r = [Ok(11), Ok(12), Ok(13), Ok(14)];
        c.tail.set(16);
        c.enqueue_resps(&r);
        assert_eq!(c.tail.get(), 16);
        assert_eq!(c.comb.get(), 4);
        assert_eq!(c.res(), Some(r[0]));
        assert_eq!(c.head.get(), 1);
        assert_eq!(c.res(), Some(r[1]));
        assert_eq!(c.head.get(), 2);
        assert_eq!(c.res(), Some(r[2]));
        assert_eq!(c.head.get(), 3);
        assert_eq!(c.res(), Some(r[3]));
        assert_eq!(c.head.get(), 4);
    }

    /// `res` returns `None` when `head == comb`, even if operations are
    /// pending.
    #[test]
    fn test_context_res_empty() {
        let c = Context::<usize, usize>::default();
        c.tail.set(8);
        assert_eq!(c.tail.get(), 8);
        assert_eq!(c.head.get(), 0);
        assert_eq!(c.comb.get(), 0);
        assert_eq!(c.res(), None);
    }

    /// `res` must hit its own invariant check when `head` is past `comb`.
    /// The expected message keeps an unrelated panic from passing the test.
    #[test]
    #[should_panic(expected = "Head of thread-local batch has advanced beyond combiner offset!")]
    fn test_context_res_panic() {
        let c = Context::<usize, usize>::default();
        c.tail.set(8);
        c.comb.set(4);
        c.head.set(6);
        assert_eq!(c.res(), None);
    }

    /// `batch_size` reports the compile-time batch capacity.
    #[test]
    fn test_context_batch_size() {
        assert_eq!(Context::<usize, usize>::batch_size(), MAX_PENDING_OPS);
    }

    /// `index` reduces a logical offset modulo the batch size, including at
    /// the wrap-around boundary.
    #[test]
    fn test_index() {
        let c = Context::<u64, Result<u64, ()>>::default();
        assert_eq!(c.index(0), 0);
        assert_eq!(c.index(MAX_PENDING_OPS - 1), MAX_PENDING_OPS - 1);
        assert_eq!(c.index(MAX_PENDING_OPS), 0);
        assert_eq!(c.index(100), 100 % MAX_PENDING_OPS);
    }
}