lightning/
linked_map.rs

// A concurrent linked hash map: entries are stored in a hash map and threaded (by key) into a doubly linked list that can be traversed without taking node locks.
2
3use crate::map::{Map, ObjectMap};
4use crate::spin::SpinLock;
5use std::ops::Deref;
6use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
7use std::sync::atomic::{fence, AtomicUsize};
8use std::sync::Arc;
9
/// Sentinel key meaning "no node": an empty list's `head`/`tail`, or a missing neighbour.
/// Because of this it can never be used as a real key.
const NONE_KEY: usize = usize::MAX >> 1;
11
12pub type NodeRef<T> = Arc<Node<T>>;
13
14pub struct Node<T> {
15    // Prev and next node keys
16    lock: SpinLock<()>,
17    prev: AtomicUsize,
18    next: AtomicUsize,
19    obj: T,
20}
21
22pub struct LinkedObjectMap<T> {
23    map: ObjectMap<NodeRef<T>>,
24    head: AtomicUsize,
25    tail: AtomicUsize,
26}
27
28impl<T> LinkedObjectMap<T> {
29    pub fn with_capacity(cap: usize) -> Self {
30        LinkedObjectMap {
31            map: ObjectMap::with_capacity(cap),
32            head: AtomicUsize::new(NONE_KEY),
33            tail: AtomicUsize::new(NONE_KEY),
34        }
35    }
36
37    pub fn insert_front(&self, key: &usize, value: T) {
38        debug_assert_ne!(*key, NONE_KEY);
39        let backoff = crossbeam_utils::Backoff::new();
40        let new_front = Node::new(value, NONE_KEY, NONE_KEY);
41        if let Some(_) = self.map.insert(key, new_front.clone()) {
42            return;
43        }
44        let _new_guard = new_front.lock.lock();
45        loop {
46            let front = self.head.load(Acquire);
47            let front_node = self.map.get(&front);
48            let _front_guard = front_node.as_ref().map(|n| n.lock.lock());
49            if let Some(ref front_node) = front_node {
50                if front_node.get_prev() != NONE_KEY {
51                    backoff.spin();
52                    continue;
53                }
54            } else if front != NONE_KEY {
55                // Inconsistent with map, will spin wait
56                backoff.spin();
57                continue;
58            }
59            new_front.set_next(front);
60            if self.head.compare_and_swap(front, *key, AcqRel) == front {
61                if let Some(ref front_node) = front_node {
62                    front_node.prev.store(*key, Release);
63                } else {
64                    debug_assert_eq!(front, NONE_KEY);
65                    self.tail.compare_and_swap(NONE_KEY, *key, AcqRel);
66                }
67                break;
68            } else {
69                backoff.spin();
70            }
71        }
72    }
73
74    pub fn insert_back(&self, key: &usize, value: T) {
75        debug_assert_ne!(*key, NONE_KEY);
76        let backoff = crossbeam_utils::Backoff::new();
77        let new_back = Node::new(value, NONE_KEY, NONE_KEY);
78        let _new_guard = new_back.lock.lock();
79        if let Some(_) = self.map.insert(key, new_back.clone()) {
80            return;
81        }
82        loop {
83            let back = self.tail.load(Acquire);
84            let back_node = self.map.get(&back);
85            let _back_guard = back_node.as_ref().map(|n| n.lock.lock());
86            if let Some(ref back_node) = back_node {
87                if back_node.get_next() != NONE_KEY {
88                    backoff.spin();
89                    continue;
90                }
91            } else if back != NONE_KEY {
92                backoff.spin();
93                continue;
94            }
95            new_back.set_prev(back);
96            if self.tail.compare_and_swap(back, *key, AcqRel) == back {
97                if let Some(ref back_node) = back_node {
98                    back_node.next.store(*key, Release);
99                } else {
100                    debug_assert_eq!(back, NONE_KEY);
101                    self.head.compare_and_swap(NONE_KEY, *key, AcqRel);
102                }
103                break;
104            } else {
105                backoff.spin();
106            }
107        }
108    }
109
110    pub fn get(&self, key: &usize) -> Option<NodeRef<T>> {
111        self.map.get(key)
112    }
113
114    pub fn remove(&self, key: &usize) -> Option<NodeRef<T>> {
115        let val = self.map.get(key);
116        if let Some(val_node) = val {
117            self.remove_node(*key, val_node);
118            return self.map.remove(key);
119        } else {
120            return val;
121        }
122    }
123
124    fn remove_node(&self, key: usize, val_node: NodeRef<T>) {
125        let backoff = crossbeam_utils::Backoff::new();
126        loop {
127            let prev = val_node.get_prev();
128            let next = val_node.get_next();
129            let prev_node = self.map.get(&prev);
130            let next_node = self.map.get(&next);
131            if (prev != NONE_KEY && prev_node.is_none())
132                || (next != NONE_KEY && next_node.is_none())
133            {
134                backoff.spin();
135                continue;
136            }
137            // Lock 3 nodes, from left to right to avoid dead lock
138            let _prev_guard = prev_node.as_ref().map(|n| n.lock.lock());
139            let _self_guard = val_node.lock.lock();
140            let _next_guard = next_node.as_ref().map(|n| n.lock.lock());
141            // Validate 3 nodes, retry on failure
142            if {
143                prev_node
144                    .as_ref()
145                    .map(|n| n.get_next() != key)
146                    .unwrap_or(false)
147                    | (val_node.get_prev() != prev)
148                    | (val_node.get_next() != next)
149                    | next_node
150                        .as_ref()
151                        .map(|n| n.get_prev() != key)
152                        .unwrap_or(false)
153            } {
154                backoff.spin();
155                continue;
156            }
157            // Bacause all the nodes we are about to modify are locked, we shall use store
158            // instead of CAS
159            prev_node.as_ref().map(|n| n.set_next(next));
160            next_node.as_ref().map(|n| n.set_prev(prev));
161            if prev_node.is_none() {
162                debug_assert_eq!(self.head.load(Acquire), key);
163                self.head.store(next, Release);
164            }
165            if next_node.is_none() {
166                debug_assert_eq!(self.tail.load(Acquire), key);
167                self.tail.store(prev, Release);
168            }
169            return;
170        }
171    }
172
173    pub fn len(&self) -> usize {
174        self.map.len()
175    }
176
177    pub fn contains_key(&self, key: &usize) -> bool {
178        self.map.contains_key(key)
179    }
180
181    pub fn all_pairs(&self) -> Vec<(usize, NodeRef<T>)> {
182        let mut res = vec![];
183        let mut node_key = self.head.load(Acquire);
184        loop {
185            if let Some(node) = self.map.get(&node_key) {
186                let new_node_key = node.get_next();
187                res.push((node_key, node));
188                node_key = new_node_key;
189            } else if node_key == NONE_KEY {
190                break;
191            } else {
192                unreachable!();
193            }
194        }
195        res
196    }
197
198    pub fn all_keys(&self) -> Vec<usize> {
199        let mut res = vec![];
200        let mut node_key = self.head.load(Acquire);
201        loop {
202            if let Some(node) = self.map.get(&node_key) {
203                res.push(node_key);
204                node_key = node.get_next();
205            } else if node_key == NONE_KEY {
206                break;
207            } else {
208                unreachable!();
209            }
210        }
211        res
212    }
213
214    pub fn all_values(&self) -> Vec<NodeRef<T>> {
215        let mut res = vec![];
216        let mut node_key = self.head.load(Acquire);
217        loop {
218            if let Some(node) = self.map.get(&node_key) {
219                node_key = node.get_next();
220                res.push(node);
221            } else if node_key == NONE_KEY {
222                break;
223            } else {
224                unreachable!();
225            }
226        }
227        res
228    }
229
230    pub fn iter(&self) -> LinkedMapIter<T> {
231        loop {
232            
233        }
234    }
235}
236
237pub struct LinkedMapIter<'a, T> {
238    node: Arc<Node<T>>,
239    map: &'a LinkedObjectMap<T>
240}
241
242impl<T> Node<T> {
243    pub fn new(obj: T, prev: usize, next: usize) -> NodeRef<T> {
244        Arc::new(Self {
245            obj,
246            lock: SpinLock::new(()),
247            prev: AtomicUsize::new(prev),
248            next: AtomicUsize::new(next),
249        })
250    }
251
252    fn get_next(&self) -> usize {
253        self.next.load(Acquire)
254    }
255
256    fn get_prev(&self) -> usize {
257        self.prev.load(Acquire)
258    }
259
260    fn set_next(&self, new: usize) {
261        self.next.store(new, Release)
262    }
263
264    fn set_prev(&self, new: usize) {
265        self.prev.store(new, Release)
266    }
267}
268
269impl<T> Deref for Node<T> {
270    type Target = T;
271
272    fn deref(&self) -> &Self::Target {
273        &self.obj
274    }
275}
276
#[cfg(test)]
mod test {
    use super::*;
    use std::{collections::HashSet, thread};

    #[test]
    pub fn linked_map_serial() {
        let map = LinkedObjectMap::with_capacity(16);
        for i in 0..1024 {
            map.insert_front(&i, i);
        }
        for i in 1024..2048 {
            map.insert_back(&i, i);
        }
        // Front insertions end up newest-first, back insertions oldest-first.
        let expected: Vec<usize> = (0..1024).rev().chain(1024..2048).collect();
        assert_eq!(map.all_keys(), expected);
        for (key, node) in map.all_pairs() {
            assert_eq!(key, **node);
        }
    }

    #[test]
    pub fn linked_map_remove() {
        let map = LinkedObjectMap::with_capacity(16);
        for i in 0..8 {
            map.insert_back(&i, i);
        }
        // Head, middle and tail removals.
        assert_eq!(map.remove(&0).map(|n| **n), Some(0));
        assert_eq!(map.remove(&4).map(|n| **n), Some(4));
        assert_eq!(map.remove(&7).map(|n| **n), Some(7));
        // Removing an absent key is a no-op.
        assert!(map.remove(&4).is_none());
        assert_eq!(map.all_keys(), vec![1, 2, 3, 5, 6]);
        for i in &[1usize, 2, 3, 5, 6] {
            map.remove(i);
        }
        assert!(map.all_keys().is_empty());
        // The emptied list accepts new entries at either end.
        map.insert_back(&10, 10);
        map.insert_front(&9, 9);
        assert_eq!(map.all_keys(), vec![9, 10]);
    }

    #[test]
    pub fn linked_map_insertions() {
        let _ = env_logger::try_init();
        let linked_map = Arc::new(LinkedObjectMap::with_capacity(16));
        let num_threads = num_cpus::get();
        let mut threads = vec![];
        let num_data = 999;
        for i in 0..num_threads {
            let map = linked_map.clone();
            threads.push(thread::spawn(move || {
                for j in 0..num_data {
                    let num = i * 1000 + j;
                    debug!("Insert {}", num);
                    if j % 2 == 1 {
                        map.insert_back(&num, num);
                    } else {
                        map.insert_front(&num, num);
                    }
                }
                map.all_keys();
                map.all_values();
                map.all_pairs();
            }));
        }
        info!("Waiting for threads to finish");
        for t in threads {
            t.join().unwrap();
        }
        let mut num_set = HashSet::new();
        for (key, node) in linked_map.all_pairs() {
            let value = **node;
            assert_eq!(key, value);
            num_set.insert(key);
        }
        assert_eq!(num_set.len(), num_threads * num_data);
    }

    #[test]
    pub fn linked_map_concurrent_remove() {
        let linked_map = Arc::new(LinkedObjectMap::with_capacity(16));
        let num_threads = num_cpus::get();
        let num_data = 500;
        let threads: Vec<_> = (0..num_threads)
            .map(|i| {
                let map = linked_map.clone();
                thread::spawn(move || {
                    for j in 0..num_data {
                        let num = i * 1000 + j;
                        map.insert_back(&num, num);
                    }
                    // Remove this thread's even keys; odd keys stay.
                    for j in (0..num_data).step_by(2) {
                        let num = i * 1000 + j;
                        assert_eq!(map.remove(&num).map(|n| **n), Some(num));
                    }
                })
            })
            .collect();
        for t in threads {
            t.join().unwrap();
        }
        let keys = linked_map.all_keys();
        assert_eq!(keys.len(), num_threads * num_data / 2);
        assert!(keys.iter().all(|k| k % 2 == 1));
    }
}