// lock_free/list.rs

1//! A lock-free ordered linked list implementation.
2
3use std::sync::atomic::{AtomicPtr, Ordering};
4use std::ptr;
5use std::cmp::Ordering as CmpOrdering;
6
/// A single list node. Nodes are heap-allocated via `Box::into_raw` in
/// `insert` and, because `remove` is a no-op, are only freed in `Drop`.
struct Node<T> {
    // The stored element; written once at construction and never mutated
    // after the node is published into the list.
    data: T,
    // Link to the next node in ascending order; null marks the tail.
    // Updated concurrently via compare-exchange in `insert`.
    next: AtomicPtr<Node<T>>,
}
11
/// A lock-free ordered linked list implementation.
///
/// Elements are kept in ascending order (requires `T: Ord` on the methods).
/// Insertion and lookup are lock-free; physical removal is intentionally not
/// implemented (see `remove`), which is what makes the raw-pointer traversal
/// in `insert`/`contains` sound without a memory-reclamation scheme.
pub struct List<T> {
    // First node of the list, or null when the list is empty.
    head: AtomicPtr<Node<T>>,
}
16
impl<T: Ord> List<T> {
    /// Creates a new empty list.
    pub fn new() -> Self {
        List {
            head: AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Inserts an element into the list, returning false if it already exists.
    ///
    /// The list stays sorted in ascending order. On any CAS failure the
    /// traversal restarts from the head; the operation is lock-free (some
    /// thread always makes progress) but an individual insert may retry.
    pub fn insert(&self, data: T) -> bool {
        // Allocate the node up front. Ownership transfers to the list on a
        // successful CAS, or the box is reclaimed below on a duplicate key.
        let new_node = Box::into_raw(Box::new(Node {
            data,
            next: AtomicPtr::new(ptr::null_mut()),
        }));

        loop {
            // `prev_ptr` is the link we will try to CAS; each retry restarts
            // the walk from the head of the list.
            let mut prev_ptr = &self.head;
            let mut curr = prev_ptr.load(Ordering::Acquire);

            loop {
                if curr.is_null() {
                    // Reached the tail: append the new node here.
                    // SAFETY: `new_node` is not yet published, so no other
                    // thread can observe this store (hence Relaxed is enough).
                    unsafe { (*new_node).next.store(curr, Ordering::Relaxed) };
                    match prev_ptr.compare_exchange_weak(
                        curr,
                        new_node,
                        Ordering::Release, // publishes the initialized node
                        Ordering::Acquire,
                    ) {
                        Ok(_) => return true,
                        // Lost the race (or weak-CAS spurious failure):
                        // restart the traversal from the head.
                        Err(_) => break,
                    }
                }

                // SAFETY: `curr` is non-null, and nodes are never freed while
                // the list is alive (`remove` is a no-op), so dereferencing
                // `curr` and holding `&(*curr).next` across iterations is
                // sound. `new_node` is still private to this thread.
                unsafe {
                    let curr_data = &(*curr).data;
                    match curr_data.cmp(&(*new_node).data) {
                        CmpOrdering::Less => {
                            // Insertion point is further right: keep walking.
                            prev_ptr = &(*curr).next;
                            curr = prev_ptr.load(Ordering::Acquire);
                        }
                        CmpOrdering::Equal => {
                            // Duplicate key: reclaim the unpublished node.
                            drop(Box::from_raw(new_node));
                            return false;
                        }
                        CmpOrdering::Greater => {
                            // Splice in before `curr`. The node is still
                            // unpublished, so a Relaxed store suffices; the
                            // Release CAS below makes it visible to readers.
                            (*new_node).next.store(curr, Ordering::Relaxed);
                            match prev_ptr.compare_exchange_weak(
                                curr,
                                new_node,
                                Ordering::Release,
                                Ordering::Acquire,
                            ) {
                                Ok(_) => return true,
                                Err(_) => break,
                            }
                        }
                    }
                }
            }
        }
    }

    /// Returns true if the list contains the given key.
    ///
    /// Exploits the sorted order: the scan stops early as soon as a node
    /// greater than `key` is observed.
    pub fn contains(&self, key: &T) -> bool {
        let mut curr = self.head.load(Ordering::Acquire);
        
        while !curr.is_null() {
            // SAFETY: non-null nodes remain allocated for the lifetime of the
            // list (no physical removal), so the dereference is valid.
            unsafe {
                match (*curr).data.cmp(key) {
                    CmpOrdering::Less => {
                        curr = (*curr).next.load(Ordering::Acquire);
                    }
                    CmpOrdering::Equal => return true,
                    // Walked past the sorted position: key cannot be present.
                    CmpOrdering::Greater => return false,
                }
            }
        }
        false
    }

    /// Returns true if the list is empty.
    ///
    /// NOTE(review): under concurrency this is only a snapshot — another
    /// thread may insert immediately after the load returns.
    pub fn is_empty(&self) -> bool {
        self.head.load(Ordering::Acquire).is_null()
    }

    /// Removes an element from the list, returning it if found.
    /// Note: This implementation leaks memory for simplicity.
    /// In production, use a memory reclamation scheme like hazard pointers.
    ///
    /// Deliberately a no-op: `insert`/`contains` hold raw pointers into nodes
    /// while traversing, which is only sound because nodes are never freed
    /// before `Drop`. Implementing removal would require marked pointers plus
    /// hazard pointers or epoch-based reclamation (e.g. the Harris algorithm).
    pub fn remove(&self, _key: &T) -> Option<T> {
        // Simplified: doesn't actually remove to avoid memory safety issues
        None
    }
}
110
111impl<T: Ord> Default for List<T> {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
impl<T> Drop for List<T> {
    /// Frees every node in the list.
    ///
    /// `&mut self` guarantees exclusive access here, so no other thread can
    /// be traversing the nodes while they are reclaimed.
    fn drop(&mut self) {
        let mut curr = self.head.load(Ordering::Acquire);
        while !curr.is_null() {
            // SAFETY: every non-null pointer reachable from `head` was
            // produced by `Box::into_raw` in `insert` and is freed exactly
            // once here; `next` is read before the node is dropped.
            unsafe {
                let next = (*curr).next.load(Ordering::Acquire);
                drop(Box::from_raw(curr));
                curr = next;
            }
        }
    }
}
129
130unsafe impl<T: Send> Send for List<T> {}
131unsafe impl<T: Send> Sync for List<T> {}
132
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Basic single-threaded insert/lookup behaviour, including duplicates.
    #[test]
    fn test_single_thread() {
        let list = List::new();

        assert!(list.is_empty());
        assert!(!list.contains(&42));

        // First insertion of each key succeeds.
        for value in [42, 17, 99] {
            assert!(list.insert(value));
        }

        // A duplicate key must be rejected.
        assert!(!list.insert(42));

        for value in [42, 17, 99] {
            assert!(list.contains(&value));
        }
        assert!(!list.contains(&100));
    }

    /// Hammers the list from several threads using disjoint key ranges, then
    /// verifies every inserted key is visible afterwards.
    #[test]
    fn test_concurrent_insert() {
        let list = Arc::new(List::new());
        let num_threads = 8;
        let operations_per_thread = 1000;

        let handles: Vec<_> = (0..num_threads)
            .map(|t| {
                let list = Arc::clone(&list);
                thread::spawn(move || {
                    for j in 0..operations_per_thread {
                        list.insert(t * operations_per_thread + j);
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let total = num_threads * operations_per_thread;
        let found = (0..total).filter(|key| list.contains(key)).count();
        assert_eq!(found, total as usize);
    }
}