use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use nix::sys::eventfd::{eventfd, EfdFlags};
use std::os::unix::io::{IntoRawFd, RawFd};
use vil_types::Descriptor;
use crate::traits::QueueBackend;
/// Wrapper aligned (and therefore padded) to 128 bytes, intended to keep
/// independently written atomics on separate cache lines.
#[repr(align(128))]
struct CachePadded<T> {
    value: T,
}

impl<T> CachePadded<T> {
    fn new(value: T) -> Self {
        Self { value }
    }
}

/// Bounded, lock-free single-producer / single-consumer ring buffer.
///
/// `head` and `tail` are free-running counters that wrap on overflow. The slot
/// for counter `i` is `i & mask`, which is why the capacity is rounded up to a
/// power of two.
///
/// # Concurrency contract
/// At most one thread may act as producer ([`try_push`](Self::try_push),
/// [`push`](Self::push)) and at most one thread may act as consumer
/// ([`try_pop`](Self::try_pop)) at any time. These methods take `&self` and the
/// type is `Sync`, so the compiler does NOT enforce this. Two concurrent
/// producers (or consumers) would race on the same slot, which is undefined
/// behaviour.
pub struct SpscRingBuffer<T> {
    /// Slot storage; a slot is initialized iff its counter lies in `head..tail`.
    buffer: Box<[UnsafeCell<MaybeUninit<T>>]>,
    /// Number of slots (a power of two).
    capacity: usize,
    /// `capacity - 1`; maps a counter to a slot index.
    mask: usize,
    /// Next counter to write; only the producer stores to it.
    tail: CachePadded<AtomicUsize>,
    /// Next counter to read; only the consumer stores to it.
    head: CachePadded<AtomicUsize>,
}

// SAFETY: values of `T` move between threads through the buffer, hence `T: Send`.
// Slot access is handed off via the Release/Acquire pairs on `head`/`tail`, provided
// the single-producer/single-consumer contract documented above is upheld.
unsafe impl<T: Send> Send for SpscRingBuffer<T> {}
unsafe impl<T: Send> Sync for SpscRingBuffer<T> {}

impl<T> SpscRingBuffer<T> {
    /// Creates an empty ring buffer holding at least `min_capacity` items.
    ///
    /// The actual capacity is `min_capacity` rounded up to the next power of two.
    ///
    /// # Panics
    /// Panics if `min_capacity` is 0, or if rounding it up to a power of two
    /// overflows `usize`.
    pub fn new(min_capacity: usize) -> Self {
        assert!(min_capacity > 0, "SPSC capacity must be > 0");
        let capacity = min_capacity
            .checked_next_power_of_two()
            .expect("SPSC capacity overflows usize when rounded up to a power of two");
        let buffer: Box<[UnsafeCell<MaybeUninit<T>>]> = (0..capacity)
            .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
            .collect();
        Self {
            buffer,
            capacity,
            mask: capacity - 1,
            tail: CachePadded::new(AtomicUsize::new(0)),
            head: CachePadded::new(AtomicUsize::new(0)),
        }
    }

    /// Returns the actual (power-of-two) number of slots.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Attempts to enqueue `item` without blocking. Producer side only.
    ///
    /// # Errors
    /// Returns `Err(item)`, handing the value back, when the buffer is full.
    pub fn try_push(&self, item: T) -> Result<(), T> {
        // Only the producer writes `tail`, so a relaxed load sees our own latest store.
        let tail = self.tail.value.load(Ordering::Relaxed);
        // Acquire pairs with the consumer's Release store of `head`: the slot it
        // vacated has been fully read before we overwrite it.
        let head = self.head.value.load(Ordering::Acquire);
        if tail.wrapping_sub(head) >= self.capacity {
            return Err(item);
        }
        let slot = &self.buffer[tail & self.mask];
        // SAFETY: the slot's counter is outside `head..tail`, so the consumer is not
        // reading it, and under the SPSC contract no other producer is writing it.
        unsafe {
            (*slot.get()).write(item);
        }
        // Release publishes the slot write to the consumer's Acquire load of `tail`.
        self.tail
            .value
            .store(tail.wrapping_add(1), Ordering::Release);
        Ok(())
    }

    /// Enqueues `item`, busy-spinning while the buffer is full. Producer side only.
    ///
    /// This never sleeps or yields to the scheduler; it relies on the consumer
    /// making progress.
    pub fn push(&self, mut item: T) {
        while let Err(rejected) = self.try_push(item) {
            item = rejected;
            std::hint::spin_loop();
        }
    }

    /// Removes and returns the oldest item, or `None` if the buffer is empty.
    /// Consumer side only.
    pub fn try_pop(&self) -> Option<T> {
        // Only the consumer writes `head`.
        let head = self.head.value.load(Ordering::Relaxed);
        // Acquire pairs with the producer's Release store of `tail`, making the
        // slot contents visible.
        let tail = self.tail.value.load(Ordering::Acquire);
        if head == tail {
            return None;
        }
        let slot = &self.buffer[head & self.mask];
        // SAFETY: `head != tail`, so this slot holds a published, not-yet-read value;
        // advancing `head` below transfers its ownership to us exactly once.
        let item = unsafe { (*slot.get()).assume_init_read() };
        // Release lets the producer reuse the slot only after our read completed.
        self.head
            .value
            .store(head.wrapping_add(1), Ordering::Release);
        Some(item)
    }

    /// Returns a snapshot of the number of queued items, always in `0..=capacity`.
    ///
    /// While the other side is active, the value may already be stale on return.
    pub fn len(&self) -> usize {
        // Load `head` first. The counters only move forward and `head` never passes
        // `tail`, so a `tail` read afterwards is at least the `head` we saw. Reading
        // `tail` first could pair it with a `head` that has since overtaken it, which
        // wraps to a huge value. The consumer may still advance `head` (and the
        // producer refill) between the two loads, so clamp to the capacity.
        let head = self.head.value.load(Ordering::Acquire);
        let tail = self.tail.value.load(Ordering::Acquire);
        tail.wrapping_sub(head).min(self.capacity)
    }

    /// Returns `true` if no items are queued (snapshot).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` if every slot is occupied (snapshot).
    pub fn is_full(&self) -> bool {
        self.len() >= self.capacity
    }
}

impl<T> Drop for SpscRingBuffer<T> {
    // Drops every item still queued (counters in `head..tail`); popped slots are
    // logically uninitialized and are left untouched.
    fn drop(&mut self) {
        let head = *self.head.value.get_mut();
        let tail = *self.tail.value.get_mut();
        // The counters are free-running, so `tail` may have wrapped below `head`.
        // Count with wrapping arithmetic instead of iterating `head..tail`.
        let queued = tail.wrapping_sub(head);
        for offset in 0..queued {
            let idx = head.wrapping_add(offset) & self.mask;
            // SAFETY: counters in `head..tail` index initialized, never-popped values,
            // and `&mut self` rules out any concurrent access.
            unsafe {
                self.buffer[idx].get_mut().assume_init_drop();
            }
        }
    }
}
/// Cloneable handle to a shared [`SpscRingBuffer`] of [`Descriptor`]s.
///
/// All clones share one buffer. NOTE(review): the single-producer /
/// single-consumer rule applies across clones — at most one clone may push and
/// at most one may pop at the same time.
#[derive(Clone)]
pub struct SpscQueue {
    inner: Arc<SpscRingBuffer<Descriptor>>,
}

impl SpscQueue {
    /// Creates a queue holding at least `min_capacity` descriptors, rounded up
    /// to a power of two.
    ///
    /// # Panics
    /// Panics if `min_capacity` is 0 (via `SpscRingBuffer::new`).
    #[doc(alias = "vil_keep")]
    pub fn new(min_capacity: usize) -> Self {
        let ring = SpscRingBuffer::new(min_capacity);
        SpscQueue {
            inner: Arc::new(ring),
        }
    }

    /// Borrows the shared ring buffer behind the `Arc`.
    fn ring(&self) -> &SpscRingBuffer<Descriptor> {
        &self.inner
    }

    /// Actual (power-of-two) number of slots.
    pub fn capacity(&self) -> usize {
        self.ring().capacity()
    }

    /// `true` when every slot is occupied (snapshot).
    pub fn is_full(&self) -> bool {
        self.ring().is_full()
    }

    /// Non-blocking enqueue; hands the descriptor back as `Err` when full.
    pub fn try_push(&self, descriptor: Descriptor) -> Result<(), Descriptor> {
        self.ring().try_push(descriptor)
    }
}

impl QueueBackend for SpscQueue {
    /// Enqueues, busy-spinning while the queue is full.
    fn push(&self, descriptor: Descriptor) {
        self.ring().push(descriptor)
    }

    /// Dequeues the oldest descriptor, if any.
    fn try_pop(&self) -> Option<Descriptor> {
        self.ring().try_pop()
    }

    /// Snapshot of the number of queued descriptors.
    fn len(&self) -> usize {
        self.ring().len()
    }
}
/// Header at the start of a shared-memory SPSC region. The slot array begins
/// `size_of::<ShmSpscLayout<T>>()` bytes after the header's start.
///
/// `#[repr(C)]` fixes the field order and offsets. NOTE(review): unlike
/// `SpscRingBuffer`, `tail` and `head` share a cache line here. Separating them
/// would change the shared-memory layout, so any such change must be made in
/// every process mapping the region.
#[repr(C, align(128))]
pub struct ShmSpscLayout<T> {
    /// Number of slots; indexing via `mask` assumes a power of two.
    pub capacity: usize,
    /// `capacity - 1`; maps a free-running counter to a slot index.
    pub mask: usize,
    /// Producer's free-running write counter.
    pub tail: AtomicUsize,
    /// Consumer's free-running read counter.
    pub head: AtomicUsize,
    /// Ties the header to the slot element type without storing one.
    _phantom: std::marker::PhantomData<T>,
}
/// Handle to an SPSC queue of [`Descriptor`]s in caller-provided (typically
/// shared) memory, with an optional eventfd used as a wake-up signal.
///
/// As far as this file shows, the handle owns neither the memory nor the fd.
/// Clones are shallow copies of the pointers and the fd number.
#[derive(Clone)]
pub struct ShmSpscQueue {
    /// Header at the start of the region.
    layout: *mut ShmSpscLayout<Descriptor>,
    /// First slot, located directly after the header.
    buffer: *mut UnsafeCell<MaybeUninit<Descriptor>>,
    /// Eventfd written by `signal` and read by `wait`, if configured.
    signal_fd: Option<RawFd>,
}

// SAFETY: the pointed-to region must stay valid for the handle's lifetime (a
// precondition of `from_raw_parts`). Slot access is coordinated through the
// atomic `head`/`tail` counters.
// NOTE(review): `Sync` + `Clone` means the compiler will not stop two threads
// from pushing (or popping) at once, which would race. Callers must keep to one
// producer and one consumer.
unsafe impl Send for ShmSpscQueue {}

unsafe impl Sync for ShmSpscQueue {}
impl ShmSpscQueue {
    /// Creates a close-on-exec, non-blocking eventfd suitable for `signal_fd`.
    ///
    /// Ownership of the returned raw fd passes to the caller, who must close it.
    ///
    /// # Errors
    /// Returns the OS error reported by `eventfd(2)`. The errno is preserved, so
    /// [`std::io::Error::raw_os_error`] and [`std::io::Error::kind`] are meaningful.
    pub fn create_eventfd() -> std::io::Result<RawFd> {
        eventfd(0, EfdFlags::EFD_CLOEXEC | EfdFlags::EFD_NONBLOCK)
            .map(|fd| fd.into_raw_fd())
            .map_err(std::io::Error::from)
    }

    /// Builds a handle over a region laid out as a [`ShmSpscLayout`] header,
    /// immediately followed by the slot array.
    ///
    /// `_capacity` is currently unused; capacity and mask are read from the
    /// header on every operation.
    ///
    /// # Safety
    /// - `base_ptr` must be non-null and aligned to
    ///   `align_of::<ShmSpscLayout<Descriptor>>()` (128 bytes).
    /// - The region must be valid for reads and writes of the header plus
    ///   `capacity` slots for as long as this handle or any clone of it is used.
    /// - Before any push or pop, the header must be initialized with
    ///   `capacity` a power of two and `mask == capacity - 1`, and with
    ///   consistent `head`/`tail` values. `capacity`/`mask` must not change
    ///   afterwards.
    /// - At most one producer and one consumer may use the region concurrently.
    /// - `signal_fd`, if `Some`, must remain open while the handle is used.
    ///   It is not closed by this type.
    pub unsafe fn from_raw_parts(
        base_ptr: *mut u8,
        _capacity: usize,
        signal_fd: Option<RawFd>,
    ) -> Self {
        debug_assert!(
            !base_ptr.is_null()
                && (base_ptr as usize) % std::mem::align_of::<ShmSpscLayout<Descriptor>>() == 0,
            "ShmSpscQueue base_ptr must be non-null and aligned for ShmSpscLayout"
        );
        let layout = base_ptr.cast::<ShmSpscLayout<Descriptor>>();
        // SAFETY: the caller guarantees the region covers the header plus the slot
        // array, so offsetting by the header size stays in bounds.
        let buffer = unsafe { base_ptr.add(std::mem::size_of::<ShmSpscLayout<Descriptor>>()) }
            .cast::<UnsafeCell<MaybeUninit<Descriptor>>>();
        Self {
            layout,
            buffer,
            signal_fd,
        }
    }

    /// Best-effort wake-up: adds 1 to the eventfd counter, if one is configured.
    ///
    /// Write errors are deliberately ignored. Presumably the relevant one is
    /// `EAGAIN` on a saturated non-blocking eventfd, in which case a wake-up is
    /// already pending.
    pub fn signal(&self) {
        if let Some(fd) = self.signal_fd {
            let buf = 1u64.to_ne_bytes();
            let _ = nix::unistd::write(fd, &buf);
        }
    }

    /// Consumes any pending wake-up from the eventfd, if one is configured.
    ///
    /// NOTE(review): fds from [`Self::create_eventfd`] are `EFD_NONBLOCK`, so this
    /// read returns immediately (error ignored) when nothing is pending — it does
    /// not block. A caller that wants to sleep must poll the fd first.
    pub fn wait(&self) {
        if let Some(fd) = self.signal_fd {
            let mut buf = [0u8; 8];
            let _ = nix::unistd::read(fd, &mut buf);
        }
    }

    /// Borrows the shared header.
    fn header(&self) -> &ShmSpscLayout<Descriptor> {
        // SAFETY: `from_raw_parts` requires `layout` to be aligned, valid and
        // initialized for this handle's lifetime. `head`/`tail` are atomics, and
        // `capacity`/`mask` are not mutated after initialization.
        unsafe { &*self.layout }
    }

    /// Non-blocking enqueue; signals the eventfd on success. Producer side only.
    ///
    /// # Errors
    /// Returns `Err(item)` when the queue is full.
    pub fn try_push(&self, item: Descriptor) -> Result<(), Descriptor> {
        let header = self.header();
        // Only the producer writes `tail`.
        let tail = header.tail.load(Ordering::Relaxed);
        // Acquire pairs with the consumer's Release store of `head`.
        let head = header.head.load(Ordering::Acquire);
        if tail.wrapping_sub(head) >= header.capacity {
            return Err(item);
        }
        let idx = tail & header.mask;
        // SAFETY: `idx < capacity` because `mask == capacity - 1`; the slot is not in
        // `head..tail`, so the consumer is not reading it.
        unsafe {
            let slot_ptr = self.buffer.add(idx);
            (*(*slot_ptr).get()).write(item);
        }
        // Release publishes the slot write to the consumer.
        header.tail.store(tail.wrapping_add(1), Ordering::Release);
        self.signal();
        Ok(())
    }

    /// Non-blocking dequeue of the oldest descriptor. Consumer side only.
    pub fn try_pop(&self) -> Option<Descriptor> {
        let header = self.header();
        // Only the consumer writes `head`.
        let head = header.head.load(Ordering::Relaxed);
        // Acquire pairs with the producer's Release store of `tail`.
        let tail = header.tail.load(Ordering::Acquire);
        if head == tail {
            return None;
        }
        let idx = head & header.mask;
        // SAFETY: `head != tail`, so the slot holds a published, not-yet-read value;
        // advancing `head` below transfers ownership to us exactly once.
        let item = unsafe {
            let slot_ptr = self.buffer.add(idx);
            (*(*slot_ptr).get()).assume_init_read()
        };
        // Release lets the producer reuse the slot only after our read completed.
        header.head.store(head.wrapping_add(1), Ordering::Release);
        Some(item)
    }
}
impl QueueBackend for ShmSpscQueue {
    /// Enqueues `descriptor`, busy-spinning while the queue is full.
    /// Every successful push also signals the eventfd (inherent `try_push`).
    fn push(&self, descriptor: Descriptor) {
        let mut item = descriptor;
        while let Err(rejected) = self.try_push(item) {
            item = rejected;
            std::hint::spin_loop();
        }
    }

    /// Dequeues the oldest descriptor, or `None` if the queue is empty.
    fn try_pop(&self) -> Option<Descriptor> {
        // Method resolution picks the inherent `ShmSpscQueue::try_pop`, not this trait method.
        self.try_pop()
    }

    /// Snapshot of the number of queued descriptors, always in `0..=capacity`.
    fn len(&self) -> usize {
        // SAFETY: `from_raw_parts` requires `layout` to stay valid and initialized
        // for the handle's lifetime.
        let header = unsafe { &*self.layout };
        // Load `head` before `tail`. Counters only move forward and `head` never
        // passes `tail`, so this order cannot underflow. The reverse order can
        // pair a stale `tail` with a newer `head`. Clamp because the other side
        // may move between the two loads.
        let head = header.head.load(Ordering::Acquire);
        let tail = header.tail.load(Ordering::Acquire);
        tail.wrapping_sub(head).min(header.capacity)
    }
}
// Unit tests for the in-process ring buffer, its Descriptor wrapper, and the
// shared-memory queue (simulated with a heap buffer).
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use vil_types::{HostId, PortId, SampleId};
// Builds a Descriptor whose sample_id is `id` and lineage_id is `id * 10`.
fn make_desc(id: u64) -> Descriptor {
Descriptor {
sample_id: SampleId(id),
origin_host: HostId(0),
origin_port: PortId(1),
lineage_id: id * 10,
publish_ts: 0,
}
}
// Requested capacities are rounded up to the next power of two.
#[test]
fn test_capacity_rounds_up() {
let rb = SpscRingBuffer::<u64>::new(3);
assert_eq!(rb.capacity(), 4);
let rb = SpscRingBuffer::<u64>::new(8);
assert_eq!(rb.capacity(), 8);
let rb = SpscRingBuffer::<u64>::new(1);
assert_eq!(rb.capacity(), 1);
}
#[test]
#[should_panic(expected = "SPSC capacity must be > 0")]
fn test_zero_capacity_panics() {
let _ = SpscRingBuffer::<u64>::new(0);
}
#[test]
fn test_push_pop_single() {
let rb = SpscRingBuffer::new(4);
assert!(rb.is_empty());
rb.try_push(42u64).unwrap();
assert_eq!(rb.len(), 1);
assert!(!rb.is_empty());
let val = rb.try_pop().unwrap();
assert_eq!(val, 42);
assert!(rb.is_empty());
}
// Filling to capacity then draining must preserve insertion order.
#[test]
fn test_fifo_ordering() {
let rb = SpscRingBuffer::new(8);
for i in 0..8 {
rb.try_push(i as u64).unwrap();
}
assert!(rb.is_full());
for i in 0..8 {
assert_eq!(rb.try_pop().unwrap(), i as u64);
}
assert!(rb.is_empty());
}
// A push into a full buffer hands the rejected value back.
#[test]
fn test_full_returns_err() {
let rb = SpscRingBuffer::new(2);
rb.try_push(1u64).unwrap();
rb.try_push(2u64).unwrap();
let result = rb.try_push(3u64);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), 3);
}
#[test]
fn test_empty_returns_none() {
let rb = SpscRingBuffer::<u64>::new(4);
assert!(rb.try_pop().is_none());
}
// Several fill/drain rounds reuse each slot (slot index = counter & mask).
#[test]
fn test_wrap_around() {
let rb = SpscRingBuffer::new(4);
for round in 0..5 {
for i in 0..4 {
rb.try_push(round * 10 + i).unwrap();
}
for i in 0..4 {
assert_eq!(rb.try_pop().unwrap(), round * 10 + i);
}
assert!(rb.is_empty());
}
}
#[test]
fn test_interleaved_push_pop() {
let rb = SpscRingBuffer::new(4);
rb.try_push(1u64).unwrap();
rb.try_push(2u64).unwrap();
assert_eq!(rb.try_pop().unwrap(), 1);
rb.try_push(3u64).unwrap();
assert_eq!(rb.try_pop().unwrap(), 2);
assert_eq!(rb.try_pop().unwrap(), 3);
assert!(rb.is_empty());
}
// Items still queued when the buffer is dropped must each be dropped once.
#[test]
fn test_drop_drains_remaining() {
use std::sync::atomic::AtomicUsize;
static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);
#[derive(Debug)]
struct Counted(#[allow(dead_code)] u64);
impl Drop for Counted {
fn drop(&mut self) {
DROP_COUNT.fetch_add(1, Ordering::Relaxed);
}
}
DROP_COUNT.store(0, Ordering::Relaxed);
{
let rb = SpscRingBuffer::new(8);
rb.try_push(Counted(1)).unwrap();
rb.try_push(Counted(2)).unwrap();
rb.try_push(Counted(3)).unwrap();
}
assert_eq!(DROP_COUNT.load(Ordering::Relaxed), 3);
}
// NOTE(review): `is_empty`, `len`, `push` and `try_pop` are not inherent on
// SpscQueue; presumably they come from QueueBackend (`is_empty` as a default
// method) — confirm in crate::traits.
#[test]
fn test_spsc_queue_descriptor() {
let q = SpscQueue::new(64);
assert!(q.is_empty());
assert_eq!(q.capacity(), 64);
q.push(make_desc(1));
q.push(make_desc(2));
assert_eq!(q.len(), 2);
let d1 = q.try_pop().unwrap();
assert_eq!(d1.sample_id, SampleId(1));
let d2 = q.try_pop().unwrap();
assert_eq!(d2.sample_id, SampleId(2));
assert!(q.is_empty());
}
// Exercises SpscQueue through the trait object.
#[test]
fn test_spsc_queue_backend_trait() {
let q: Box<dyn QueueBackend> = Box::new(SpscQueue::new(16));
q.push(make_desc(42));
assert_eq!(q.len(), 1);
let d = q.try_pop().unwrap();
assert_eq!(d.sample_id, SampleId(42));
}
#[test]
fn test_spsc_queue_full() {
let q = SpscQueue::new(2); assert!(q.try_push(make_desc(1)).is_ok());
assert!(q.try_push(make_desc(2)).is_ok());
assert!(q.try_push(make_desc(3)).is_err());
assert!(q.is_full());
}
// One producer thread and one consumer thread; verifies strict FIFO order
// over 100_000 items through a 1024-slot buffer.
#[test]
fn test_cross_thread_stress() {
let count = 100_000u64;
let rb = Arc::new(SpscRingBuffer::new(1024));
let producer_rb = rb.clone();
let producer = thread::spawn(move || {
for i in 0..count {
producer_rb.push(i);
}
});
let consumer_rb = rb.clone();
let consumer = thread::spawn(move || {
let mut received = Vec::with_capacity(count as usize);
while received.len() < count as usize {
if let Some(val) = consumer_rb.try_pop() {
received.push(val);
} else {
std::hint::spin_loop();
}
}
received
});
producer.join().unwrap();
let received = consumer.join().unwrap();
assert_eq!(received.len(), count as usize);
for (i, val) in received.iter().enumerate() {
assert_eq!(*val, i as u64, "FIFO violation at index {}", i);
}
}
// Same as above, but through cloned SpscQueue handles carrying Descriptors.
#[test]
fn test_cross_thread_descriptor_stress() {
let count = 50_000u64;
let q = SpscQueue::new(512);
let producer_q = q.clone();
let producer = thread::spawn(move || {
for i in 0..count {
producer_q.push(make_desc(i));
}
});
let consumer_q = q.clone();
let consumer = thread::spawn(move || {
let mut received = Vec::with_capacity(count as usize);
while received.len() < count as usize {
if let Some(d) = consumer_q.try_pop() {
received.push(d);
} else {
std::hint::spin_loop();
}
}
received
});
producer.join().unwrap();
let received = consumer.join().unwrap();
assert_eq!(received.len(), count as usize);
for (i, d) in received.iter().enumerate() {
assert_eq!(d.sample_id, SampleId(i as u64), "FIFO violation at {}", i);
}
}
// Simulates a shared-memory region with a heap Vec. The base pointer is
// aligned to 128 bytes (the extra 256 bytes leave room for the offset). The
// header is initialized by hand, then two handles over the same region act as
// producer and consumer.
#[test]
fn test_shm_spsc_queue_basic() {
let capacity = 16;
let size = std::mem::size_of::<ShmSpscLayout<Descriptor>>()
+ capacity * std::mem::size_of::<UnsafeCell<MaybeUninit<Descriptor>>>();
let mut storage = vec![0u8; size + 256];
let raw_ptr = storage.as_mut_ptr();
let base_ptr = unsafe {
let offset = (128 - (raw_ptr as usize % 128)) % 128;
raw_ptr.add(offset)
};
unsafe {
let layout = base_ptr as *mut ShmSpscLayout<Descriptor>;
(*layout).capacity = capacity;
(*layout).mask = capacity - 1;
(*layout).tail.store(0, Ordering::Release);
(*layout).head.store(0, Ordering::Release);
let efd = ShmSpscQueue::create_eventfd().unwrap();
let q1 = ShmSpscQueue::from_raw_parts(base_ptr, capacity, Some(efd));
let q2 = ShmSpscQueue::from_raw_parts(base_ptr, capacity, Some(efd));
let desc = make_desc(123);
q1.try_push(desc).unwrap();
assert_eq!(q2.len(), 1);
let popped = q2.try_pop().unwrap();
assert_eq!(popped.sample_id, SampleId(123));
// ShmSpscQueue does not close its fd, so the test closes it explicitly.
nix::unistd::close(efd).unwrap();
}
}
}