use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
/// Policy for whether the engine keeps persistent worker threads resident.
///
/// NOTE(review): not referenced elsewhere in this chunk; presumably consumed
/// by other modules — confirm before changing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PersistentThreadMode {
    /// Let the engine pick a strategy (the `Default`).
    #[default]
    Auto,
    /// Always use persistent threads.
    Force,
    /// Never use persistent threads.
    Disable,
}
/// One unit of work. Exactly 128 bits of `u32` fields (`#[repr(C)]`), so it
/// can be packed into two `u64` words for atomic slot storage (see
/// `pack_work_item` / `unpack_work_item`). All fields are opaque payload to
/// the ring itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct WorkItem {
    // Offset of this item's input — presumably into a shared input buffer;
    // meaning is defined by the caller, the ring never interprets it.
    pub input_offset: u32,
    // Length of the input region, in the caller's units.
    pub input_len: u32,
    // Identifier of the rule set to apply (caller-defined).
    pub rule_set_id: u32,
    // Caller-supplied tag echoed back so results can be matched to requests.
    pub correlation: u32,
}
/// Shared atomic state backing the ring buffer.
#[derive(Debug)]
pub struct RingAtomics {
    /// Producer cursor: total number of enqueue positions ever claimed
    /// (monotonically increasing, never wrapped to ring size).
    pub head: AtomicU64,
    /// Consumer cursor: total number of claim positions ever taken
    /// (monotonically increasing).
    pub tail: AtomicU64,
    /// Per-slot publication sequence used to hand slots between producers
    /// and consumers; one entry per ring slot.
    pub ready: Vec<AtomicU64>,
    /// Per-slot completion flags (0 = pending, 1 = done), driven by
    /// `PersistentEngine::mark_done` / `is_done`.
    pub done: Vec<AtomicU32>,
}
impl RingAtomics {
    /// Builds the atomic state for a ring of `ring_size` slots, with every
    /// cursor, sequence word, and flag starting at zero.
    fn new(ring_size: u32) -> Self {
        let slot_count = ring_size as usize;
        let mut ready = Vec::with_capacity(slot_count);
        let mut done = Vec::with_capacity(slot_count);
        for _ in 0..slot_count {
            ready.push(AtomicU64::new(0));
            done.push(AtomicU32::new(0));
        }
        Self {
            head: AtomicU64::new(0),
            tail: AtomicU64::new(0),
            ready,
            done,
        }
    }
}
/// One ring slot: a `WorkItem` split across two atomic 64-bit words so it
/// can be written and read without locks.
///
/// NOTE(review): the two words are accessed as separate relaxed atomics, so
/// a torn read across `lo`/`hi` is only prevented by the ring's
/// publish/claim handshake — the slot is not atomic on its own.
#[derive(Debug)]
struct WorkSlot {
    // Holds `input_offset` (low 32 bits) and `input_len` (high 32 bits).
    lo: AtomicU64,
    // Holds `rule_set_id` (low 32 bits) and `correlation` (high 32 bits).
    hi: AtomicU64,
}
impl WorkSlot {
    /// Creates a slot pre-seeded with `item`.
    fn new(item: WorkItem) -> Self {
        let (low, high) = pack_work_item(item);
        Self {
            lo: AtomicU64::new(low),
            hi: AtomicU64::new(high),
        }
    }

    /// Overwrites the slot with `item`. Relaxed stores are sufficient here:
    /// ordering against readers is provided by the ring's publication fence.
    fn store(&self, item: WorkItem) {
        let (low, high) = pack_work_item(item);
        self.lo.store(low, Ordering::Relaxed);
        self.hi.store(high, Ordering::Relaxed);
    }

    /// Reads the slot back out as a `WorkItem`.
    fn load(&self) -> WorkItem {
        let low = self.lo.load(Ordering::Relaxed);
        let high = self.hi.load(Ordering::Relaxed);
        unpack_work_item(low, high)
    }
}
/// Packs a `WorkItem` into two 64-bit words:
/// `lo = offset | len << 32`, `hi = rule_set | correlation << 32`.
fn pack_work_item(item: WorkItem) -> (u64, u64) {
    let lo = (u64::from(item.input_len) << 32) | u64::from(item.input_offset);
    let hi = (u64::from(item.correlation) << 32) | u64::from(item.rule_set_id);
    (lo, hi)
}
/// Inverse of `pack_work_item`: splits the two 64-bit words back into fields.
fn unpack_work_item(lo: u64, hi: u64) -> WorkItem {
    // Low 32 bits of each word hold one field; high 32 bits hold the other.
    let low_mask = u64::from(u32::MAX);
    WorkItem {
        input_offset: (lo & low_mask) as u32,
        input_len: (lo >> 32) as u32,
        rule_set_id: (hi & low_mask) as u32,
        correlation: (hi >> 32) as u32,
    }
}
/// A fixed-capacity multi-producer/multi-consumer ring buffer of `WorkItem`s
/// built entirely on atomics (no locks).
///
/// `ring_size` is always a nonzero power of two so slot indices can be
/// computed as `position & (ring_size - 1)`.
#[derive(Debug)]
pub struct PersistentEngine {
    // Payload storage, one `WorkSlot` per ring position.
    slots: Vec<WorkSlot>,
    // Head/tail cursors plus per-slot ready/done state.
    atomics: RingAtomics,
    // Capacity; invariant: nonzero power of two (enforced by constructors).
    ring_size: u32,
}
impl PersistentEngine {
    /// Infallible constructor: rounds `ring_size` up to the next power of two
    /// (zero and overflowing values normalize to 1) so mask-based slot
    /// indexing stays valid.
    pub fn new(ring_size: u32) -> Self {
        let ring_size = ring_size
            .checked_next_power_of_two()
            .filter(|&size| size > 0)
            .unwrap_or(1);
        Self::with_valid_ring_size(ring_size)
    }

    /// Fallible constructor: rejects ring sizes that are zero or not a power
    /// of two instead of silently normalizing them.
    ///
    /// # Errors
    /// Returns a human-readable message naming the invalid size.
    pub fn try_new(ring_size: u32) -> Result<Self, String> {
        if ring_size.is_power_of_two() && ring_size > 0 {
            Ok(Self::with_valid_ring_size(ring_size))
        } else {
            Err(format!(
                "Fix: ring_size must be a nonzero power of two, got {ring_size}."
            ))
        }
    }

    /// Precondition: `ring_size` is a nonzero power of two.
    fn with_valid_ring_size(ring_size: u32) -> Self {
        let zero = WorkItem {
            input_offset: 0,
            input_len: 0,
            rule_set_id: 0,
            correlation: 0,
        };
        let slots = (0..ring_size).map(|_| WorkSlot::new(zero)).collect();
        let atomics = RingAtomics::new(ring_size);
        // Per-slot hand-off protocol stored in `ready` (global positions are
        // monotonically increasing u64s; `slot = pos & (ring_size - 1)`):
        //   2 * pos     -> slot is free for the producer at position `pos`
        //   2 * pos + 1 -> the item for position `pos` is published
        // The low-bit tag keeps "free" and "published" distinct for every
        // ring size, including ring_size == 1 where `pos + 1` would collide
        // with the next lap's free value. Seed each slot as free for its
        // first-lap position.
        for (idx, seq) in atomics.ready.iter().enumerate() {
            seq.store((idx as u64) << 1, Ordering::Relaxed);
        }
        Self {
            slots,
            atomics,
            ring_size,
        }
    }

    /// Capacity of the ring (always a nonzero power of two).
    pub fn ring_size(&self) -> u32 {
        self.ring_size
    }

    /// Enqueues `item`, returning the slot index it was written to.
    ///
    /// # Errors
    /// Returns `QueueFull` when `ring_size` positions are already in flight.
    ///
    /// Fix: the original wrote the slot immediately after winning the head
    /// CAS. On wrap-around, a producer at position `p` could overwrite slot
    /// `p & mask` while the item from position `p - ring_size` was still
    /// unpublished or mid-read, corrupting/losing items and potentially
    /// deadlocking a consumer spinning for a `ready` value that got skipped.
    /// Producers now wait for the slot's "free for this position" sequence
    /// before storing. The unreachable `Err(QueueFull)` after a successful
    /// CAS (which would have permanently desynced `head` from slot state) is
    /// replaced by direct indexing, which is in-bounds by construction.
    pub fn enqueue(&self, item: WorkItem) -> Result<u32, QueueFull> {
        loop {
            let head = self.atomics.head.load(Ordering::Acquire);
            let tail = self.atomics.tail.load(Ordering::Acquire);
            // Cursors never wrap in practice, so head - tail is the number of
            // claimed-but-unconsumed positions.
            if head.wrapping_sub(tail) >= u64::from(self.ring_size) {
                return Err(QueueFull);
            }
            if self
                .atomics
                .head
                .compare_exchange(
                    head,
                    head.wrapping_add(1),
                    Ordering::AcqRel,
                    Ordering::Acquire,
                )
                .is_err()
            {
                // Another producer won position `head`; retry with fresh cursors.
                continue;
            }
            let slot_idx = (head as u32) & (self.ring_size - 1);
            let slot_offset = slot_idx as usize;
            let seq = &self.atomics.ready[slot_offset];
            let free_mark = head << 1;
            // Wait until the previous lap's consumer hands this slot back.
            // The wait is bounded: the full-queue check above guarantees the
            // consumer for position `head - ring_size` has already claimed it.
            while seq.load(Ordering::Acquire) != free_mark {
                std::hint::spin_loop();
            }
            self.slots[slot_offset].store(item);
            // Reset the completion flag before publishing so `is_done` never
            // reports a stale result for the new item.
            self.atomics.done[slot_offset].store(0, Ordering::Release);
            // Publish: flip the slot to "item for `head` is visible".
            seq.store(free_mark | 1, Ordering::Release);
            return Ok(slot_idx);
        }
    }

    /// Claims the oldest published item, or `None` when the ring is empty.
    pub fn claim(&self) -> Option<WorkItem> {
        loop {
            let head = self.atomics.head.load(Ordering::Acquire);
            let tail = self.atomics.tail.load(Ordering::Acquire);
            if tail >= head {
                return None;
            }
            if self
                .atomics
                .tail
                .compare_exchange(
                    tail,
                    tail.wrapping_add(1),
                    Ordering::AcqRel,
                    Ordering::Acquire,
                )
                .is_err()
            {
                // Another consumer won position `tail`; retry.
                continue;
            }
            let slot_idx = (tail as u32) & (self.ring_size - 1);
            let slot_offset = slot_idx as usize;
            let seq = &self.atomics.ready[slot_offset];
            let published = (tail << 1) | 1;
            // Wait for the producer that owns this position to publish.
            while seq.load(Ordering::Acquire) != published {
                std::hint::spin_loop();
            }
            let item = self.slots[slot_offset].load();
            // Hand the slot back to the producer one lap ahead (position
            // `tail + ring_size`). This release store is what makes
            // wrap-around safe; it must happen only after the payload and
            // done-flag protocol no longer need the slot's previous contents.
            seq.store(
                tail.wrapping_add(u64::from(self.ring_size)) << 1,
                Ordering::Release,
            );
            return Some(item);
        }
    }

    /// Marks the given slot's work as completed; out-of-range indices are
    /// ignored (never panics).
    pub fn mark_done(&self, slot_idx: u32) {
        if let Some(done) = self.atomics.done.get(slot_idx as usize) {
            done.store(1, Ordering::Release);
        }
    }

    /// Whether the given slot has been marked done; `false` for out-of-range
    /// indices. Note the flag is reset when a later enqueue reuses the slot.
    pub fn is_done(&self, slot_idx: u32) -> bool {
        self.atomics
            .done
            .get(slot_idx as usize)
            .is_some_and(|done| done.load(Ordering::Acquire) != 0)
    }

    /// Number of claimed-but-unconsumed positions, saturated to `u32::MAX`.
    pub fn in_flight(&self) -> u32 {
        self.atomics
            .head
            .load(Ordering::Acquire)
            .wrapping_sub(self.atomics.tail.load(Ordering::Acquire))
            .min(u64::from(u32::MAX)) as u32
    }

    /// Low 32 bits of the producer cursor (truncating; diagnostic use).
    pub fn head(&self) -> u32 {
        self.atomics.head.load(Ordering::Acquire) as u32
    }

    /// Low 32 bits of the consumer cursor (truncating; diagnostic use).
    pub fn tail(&self) -> u32 {
        self.atomics.tail.load(Ordering::Acquire) as u32
    }
}
/// Error returned by `PersistentEngine::enqueue` when every ring slot is
/// already in flight.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QueueFull;
impl std::fmt::Display for QueueFull {
    /// Human-readable message used by `{}` formatting and error reporting.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "persistent engine ring buffer is full")
    }
}
impl std::error::Error for QueueFull {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    // Builds a distinct WorkItem keyed by `i`; `correlation` carries `i` so
    // tests can identify items after they round-trip through the ring.
    fn item(i: u32) -> WorkItem {
        WorkItem {
            input_offset: i * 1024,
            input_len: 1024,
            rule_set_id: 0,
            correlation: i,
        }
    }

    // `try_new` rejects non-power-of-two and zero sizes with a message.
    #[test]
    fn invalid_ring_size_has_explicit_error_api() {
        let err = PersistentEngine::try_new(7).unwrap_err();
        assert!(err.contains("Fix:"));
        assert!(PersistentEngine::try_new(0).is_err());
    }

    // `new` rounds up to the next power of two; zero becomes one.
    #[test]
    fn infallible_constructor_normalizes_ring_size() {
        assert_eq!(PersistentEngine::new(7).ring_size(), 8);
        assert_eq!(PersistentEngine::new(0).ring_size(), 1);
    }

    // Single-threaded fill-then-drain preserves FIFO order; slot indices
    // returned by enqueue match insertion order on the first lap.
    #[test]
    fn enqueue_claim_fifo_single_thread() {
        let eng = PersistentEngine::new(8);
        for i in 0..8 {
            assert_eq!(eng.enqueue(item(i)).unwrap(), i);
        }
        for i in 0..8 {
            assert_eq!(eng.claim().unwrap().correlation, i);
        }
        assert!(eng.claim().is_none());
    }

    // A full ring rejects further enqueues with QueueFull.
    #[test]
    fn queue_full_on_overflow() {
        let eng = PersistentEngine::new(4);
        for i in 0..4 {
            eng.enqueue(item(i)).unwrap();
        }
        assert_eq!(eng.enqueue(item(99)), Err(QueueFull));
    }

    // Claiming one item frees exactly one enqueue slot.
    #[test]
    fn space_reclaims_after_claim() {
        let eng = PersistentEngine::new(4);
        for i in 0..4 {
            eng.enqueue(item(i)).unwrap();
        }
        assert!(eng.enqueue(item(99)).is_err());
        let _ = eng.claim().unwrap();
        assert!(eng.enqueue(item(99)).is_ok());
    }

    // in_flight equals enqueued minus claimed.
    #[test]
    fn in_flight_tracks_correctly() {
        let eng = PersistentEngine::new(16);
        assert_eq!(eng.in_flight(), 0);
        for i in 0..5 {
            eng.enqueue(item(i)).unwrap();
        }
        assert_eq!(eng.in_flight(), 5);
        eng.claim().unwrap();
        eng.claim().unwrap();
        assert_eq!(eng.in_flight(), 3);
    }

    // mark_done/is_done round-trip on a slot returned by enqueue.
    #[test]
    fn done_marker_flows_through() {
        let eng = PersistentEngine::new(4);
        let slot = eng.enqueue(item(1)).unwrap();
        assert!(!eng.is_done(slot));
        let _ = eng.claim().unwrap();
        eng.mark_done(slot);
        assert!(eng.is_done(slot));
    }

    // Out-of-range slot indices are ignored rather than panicking.
    #[test]
    fn mark_done_never_panics_on_invalid_slot() {
        let eng = PersistentEngine::new(4);
        eng.mark_done(u32::MAX);
        assert!(!eng.is_done(u32::MAX));
    }

    // MPSC stress: 4 producers spinning on QueueFull, one consumer draining;
    // every correlation id must be consumed exactly once.
    #[test]
    fn multi_producer_single_consumer_no_item_lost() {
        let eng = Arc::new(PersistentEngine::new(128));
        let producers = 4;
        let items_per_producer = 16;
        let mut handles = Vec::new();
        for p in 0..producers {
            let eng = Arc::clone(&eng);
            handles.push(thread::spawn(move || {
                for i in 0..items_per_producer {
                    let corr = (p * 1000 + i) as u32;
                    loop {
                        if eng.enqueue(item(corr)).is_ok() {
                            break;
                        }
                        thread::yield_now();
                    }
                }
            }));
        }
        let consumer_eng = Arc::clone(&eng);
        let consumer = thread::spawn(move || {
            let total = (producers * items_per_producer) as usize;
            let mut seen = Vec::with_capacity(total);
            while seen.len() < total {
                if let Some(it) = consumer_eng.claim() {
                    seen.push(it.correlation);
                } else {
                    thread::yield_now();
                }
            }
            seen
        });
        for h in handles {
            h.join().unwrap();
        }
        let seen = consumer.join().unwrap();
        // Dedup check: no correlation id may appear twice.
        let mut sorted = seen.clone();
        sorted.sort();
        sorted.dedup();
        assert_eq!(sorted.len(), seen.len(), "duplicate items consumed");
        // Completeness check: every produced id was consumed.
        for p in 0..producers {
            for i in 0..items_per_producer {
                let expected = (p * 1000 + i) as u32;
                assert!(
                    seen.contains(&expected),
                    "missing correlation id {expected}"
                );
            }
        }
    }

    // Repeated fill/drain cycles exercise index wrap-around; cursors end
    // equal and in_flight returns to zero.
    #[test]
    fn wrap_around_works_for_large_throughput() {
        let eng = PersistentEngine::new(16);
        let passes = 10;
        for p in 0..passes {
            for i in 0..16 {
                let corr = (p * 1000 + i) as u32;
                assert!(eng.enqueue(item(corr)).is_ok());
            }
            for i in 0..16 {
                let corr = (p * 1000 + i) as u32;
                assert_eq!(eng.claim().unwrap().correlation, corr);
            }
        }
        assert_eq!(eng.head(), (passes * 16) as u32);
        assert_eq!(eng.tail(), (passes * 16) as u32);
        assert_eq!(eng.in_flight(), 0);
    }

    // MC stress: 4 consumers racing claim() over a pre-filled ring; the union
    // of consumed ids must be exactly 0..total with no duplicates.
    #[test]
    fn multi_consumer_no_double_claim() {
        let eng = Arc::new(PersistentEngine::new(128));
        let total = 100_u32;
        for i in 0..total {
            eng.enqueue(item(i)).unwrap();
        }
        let consumers = 4;
        let mut handles = Vec::new();
        let shared_consumed = Arc::new(std::sync::Mutex::new(Vec::new()));
        for _ in 0..consumers {
            let eng = Arc::clone(&eng);
            let out = Arc::clone(&shared_consumed);
            handles.push(thread::spawn(move || {
                // Each consumer drains until claim() reports empty, then
                // merges its local batch under the mutex once.
                let mut local = Vec::new();
                while let Some(it) = eng.claim() {
                    local.push(it.correlation);
                }
                out.lock().unwrap().extend(local);
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        let mut consumed = Arc::try_unwrap(shared_consumed)
            .unwrap()
            .into_inner()
            .unwrap();
        consumed.sort();
        assert_eq!(consumed.len(), total as usize);
        for (i, c) in consumed.iter().enumerate() {
            assert_eq!(*c, i as u32, "duplicated or missing item at idx {i}");
        }
    }

    // Display output mentions the ring buffer so logs are actionable.
    #[test]
    fn queue_full_error_display_is_useful() {
        let s = format!("{QueueFull}");
        assert!(s.contains("ring buffer"));
    }
}