use crate::metrics::{AtomicMetrics, MetricsCollector};
use crate::util::CachePadded;
use core::marker::PhantomData;
use core::mem::ManuallyDrop;
use core::ptr;
use core::sync::atomic::{AtomicPtr, Ordering};
use std::sync::atomic::{AtomicUsize, Ordering as StdOrdering};
#[cfg(feature = "std")]
use std::boxed::Box;
#[cfg(feature = "std")]
use std::vec::Vec;
#[derive(Debug)]
struct Node<T> {
data: T,
next: AtomicPtr<Node<T>>,
}
#[derive(Debug)]
pub struct LockFreeStack<T> {
head: CachePadded<AtomicPtr<Node<T>>>,
metrics: AtomicMetrics,
metrics_enabled: AtomicUsize,
}
impl<T> LockFreeStack<T> {
pub fn new() -> Self {
Self {
head: CachePadded::new(AtomicPtr::new(std::ptr::null_mut())),
metrics: AtomicMetrics::default(),
metrics_enabled: AtomicUsize::new(1), }
}
pub fn push(&self, value: T) {
#[cfg(feature = "std")]
let start = std::time::Instant::now();
let node = Box::into_raw(Box::new(Node {
data: value,
next: AtomicPtr::new(std::ptr::null_mut()),
}));
loop {
let head = self.head.get().load(Ordering::Acquire);
unsafe {
(*node).next.store(head, Ordering::Relaxed);
}
match self.head.get().compare_exchange_weak(
head,
node,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => {
#[cfg(feature = "std")]
self.metrics.record_success(start.elapsed());
break;
}
Err(_) => {
#[cfg(feature = "std")]
self.metrics.record_contention();
}
}
}
}
pub fn pop(&self) -> Option<T> {
#[cfg(feature = "std")]
let start = std::time::Instant::now();
loop {
let head = self.head.get().load(Ordering::Acquire);
if head.is_null() {
#[cfg(feature = "std")]
self.metrics.record_failure();
return None;
}
let next = unsafe { (*head).next.load(Ordering::Relaxed) };
match self.head.get().compare_exchange_weak(
head,
next,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => {
let data = unsafe {
let node = Box::from_raw(head);
node.data
};
#[cfg(feature = "std")]
self.metrics.record_success(start.elapsed());
return Some(data);
}
Err(_) => {
#[cfg(feature = "std")]
self.metrics.record_contention();
}
}
}
}
pub fn is_empty(&self) -> bool {
self.head.get().load(Ordering::Acquire).is_null()
}
pub fn len(&self) -> usize {
let mut count = 0;
let mut current = self.head.get().load(Ordering::Acquire);
while !current.is_null() {
count += 1;
current = unsafe { (*current).next.load(Ordering::Relaxed) };
}
count
}
pub fn pop_batch(&self, max_count: usize) -> Vec<T> {
let mut result = Vec::with_capacity(max_count);
for _ in 0..max_count {
if let Some(value) = self.pop() {
result.push(value);
} else {
break;
}
}
result
}
pub fn push_batch<I>(&self, values: I)
where
I: IntoIterator<Item = T>,
{
for value in values {
self.push(value);
}
}
}
impl<T> Default for LockFreeStack<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Drop for LockFreeStack<T> {
fn drop(&mut self) {
let mut current = self.head.get().load(Ordering::Acquire);
while !current.is_null() {
let next = unsafe { (*current).next.load(Ordering::Relaxed) };
unsafe {
drop(Box::from_raw(current));
}
current = next;
}
}
}
#[cfg(feature = "std")]
impl<T> MetricsCollector for LockFreeStack<T> {
    /// Returns a snapshot of the counters held by the stack's `AtomicMetrics`.
    fn metrics(&self) -> crate::metrics::PerformanceMetrics {
        self.metrics.snapshot()
    }

    /// Resets the counters held by the stack's `AtomicMetrics`.
    fn reset_metrics(&self) {
        self.metrics.reset();
    }

    /// Stores the enable flag as `1` (enabled) or `0` (disabled).
    fn set_metrics_enabled(&self, enabled: bool) {
        let flag = usize::from(enabled);
        self.metrics_enabled.store(flag, StdOrdering::Relaxed);
    }

    /// Reports whether the stored enable flag is non-zero.
    fn is_metrics_enabled(&self) -> bool {
        let flag = self.metrics_enabled.load(StdOrdering::Relaxed);
        flag != 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::vec;

    #[test]
    fn test_basic_operations() {
        let stack = LockFreeStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);
        assert_eq!(stack.pop(), None);
        stack.push(1);
        stack.push(2);
        stack.push(3);
        assert!(!stack.is_empty());
        assert_eq!(stack.len(), 3);
        assert_eq!(stack.pop(), Some(3));
        assert_eq!(stack.pop(), Some(2));
        assert_eq!(stack.pop(), Some(1));
        assert_eq!(stack.pop(), None);
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);
    }

    #[test]
    fn test_batch_operations() {
        let stack = LockFreeStack::new();
        stack.push_batch(vec![1, 2, 3, 4, 5]);
        assert_eq!(stack.len(), 5);
        let elements = stack.pop_batch(3);
        assert_eq!(elements.len(), 3);
        assert_eq!(stack.len(), 2);
        assert_eq!(stack.pop(), Some(2));
        assert_eq!(stack.pop(), Some(1));
        assert_eq!(stack.pop(), None);
    }

    #[test]
    fn test_concurrent_operations() {
        let stack = Arc::new(LockFreeStack::new());
        let mut handles = vec![];
        for i in 0..4 {
            let stack_clone = Arc::clone(&stack);
            let handle = thread::spawn(move || {
                for j in 0..100 {
                    stack_clone.push(i * 100 + j);
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let mut count = 0;
        while stack.pop().is_some() {
            count += 1;
        }
        assert_eq!(count, 400);
        assert!(stack.is_empty());
    }

    /// Several threads push and pop at the same time; every thread pops right
    /// after its own push, so no pop can ever observe an empty stack.
    #[test]
    fn test_concurrent_push_and_pop() {
        let stack = Arc::new(LockFreeStack::new());
        let handles: Vec<_> = (0..4)
            .map(|t| {
                let stack = Arc::clone(&stack);
                thread::spawn(move || {
                    let mut popped = 0usize;
                    for j in 0..500 {
                        stack.push(t * 500 + j);
                        if stack.pop().is_some() {
                            popped += 1;
                        }
                    }
                    popped
                })
            })
            .collect();
        let popped: usize = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert_eq!(popped, 2000);
        assert!(stack.is_empty());
    }

    #[test]
    fn test_producer_consumer() {
        let stack = Arc::new(LockFreeStack::new());
        let producer_stack = Arc::clone(&stack);
        let producer = thread::spawn(move || {
            for i in 0..1000 {
                producer_stack.push(i);
            }
        });
        let consumer_stack = Arc::clone(&stack);
        let consumer = thread::spawn(move || {
            let mut sum = 0;
            let mut count = 0;
            while count < 1000 {
                if let Some(value) = consumer_stack.pop() {
                    sum += value;
                    count += 1;
                }
                thread::yield_now();
            }
            sum
        });
        producer.join().unwrap();
        let result = consumer.join().unwrap();
        assert_eq!(result, 499500);
    }

    /// Every pushed value is dropped exactly once: popped values by the
    /// caller, values still on the stack when the stack itself is dropped.
    #[test]
    fn test_drop_releases_every_value_once() {
        use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};

        struct DropCounter(Arc<AtomicUsize>);
        impl Drop for DropCounter {
            fn drop(&mut self) {
                self.0.fetch_add(1, AtomicOrdering::SeqCst);
            }
        }

        let drops = Arc::new(AtomicUsize::new(0));
        {
            let stack = LockFreeStack::new();
            for _ in 0..5 {
                stack.push(DropCounter(Arc::clone(&drops)));
            }
            drop(stack.pop());
            drop(stack.pop());
            assert_eq!(drops.load(AtomicOrdering::SeqCst), 2);
        }
        assert_eq!(drops.load(AtomicOrdering::SeqCst), 5);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_metrics() {
        let stack = LockFreeStack::new();
        stack.push(1);
        stack.push(2);
        stack.push(3);
        let _ = stack.pop();
        let _ = stack.pop();
        let _ = stack.pop();
        let _ = stack.pop();
        let metrics = stack.metrics();
        assert_eq!(metrics.total_operations, 7);
        assert_eq!(metrics.successful_operations, 6);
        assert_eq!(metrics.failed_operations, 1);
        assert!(metrics.success_rate() > 80.0);
        stack.set_metrics_enabled(false);
        assert!(!stack.is_metrics_enabled());
        stack.reset_metrics();
        let reset_metrics = stack.metrics();
        assert_eq!(reset_metrics.total_operations, 0);
    }
}