1use crate::map::{Map, ObjectMap};
4use crate::spin::SpinLock;
5use std::ops::Deref;
6use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
7use std::sync::atomic::{fence, AtomicUsize};
8use std::sync::Arc;
9
/// Sentinel "null" key: the all-ones pattern shifted right once
/// (`usize::MAX >> 1`). Real keys must stay below it — insertion guards
/// this with `debug_assert_ne!(*key, NONE_KEY)`.
const NONE_KEY: usize = !0 >> 1;

/// Shared, reference-counted handle to a list node.
pub type NodeRef<T> = Arc<Node<T>>;
13
/// A value wrapper that is simultaneously a map entry and a doubly
/// linked list node. Links are stored as *keys* (not pointers) into the
/// owning `LinkedObjectMap`, with `NONE_KEY` meaning "no neighbour".
pub struct Node<T> {
    // Per-node lock held while this node's links are being rewired.
    lock: SpinLock<()>,
    // Key of the previous node in the list, or NONE_KEY at the head.
    prev: AtomicUsize,
    // Key of the next node in the list, or NONE_KEY at the tail.
    next: AtomicUsize,
    // The stored payload; exposed read-only through `Deref`.
    obj: T,
}
21
/// A concurrent keyed map whose entries are additionally threaded into a
/// doubly linked list, so insertion order (front/back) can be traversed.
/// The list is navigated by key: each node stores its neighbours' keys,
/// and `head`/`tail` hold the end keys (`NONE_KEY` when empty).
pub struct LinkedObjectMap<T> {
    map: ObjectMap<NodeRef<T>>,
    // Key of the first node in the list, or NONE_KEY if the list is empty.
    head: AtomicUsize,
    // Key of the last node in the list, or NONE_KEY if the list is empty.
    tail: AtomicUsize,
}
27
28impl<T> LinkedObjectMap<T> {
29 pub fn with_capacity(cap: usize) -> Self {
30 LinkedObjectMap {
31 map: ObjectMap::with_capacity(cap),
32 head: AtomicUsize::new(NONE_KEY),
33 tail: AtomicUsize::new(NONE_KEY),
34 }
35 }
36
37 pub fn insert_front(&self, key: &usize, value: T) {
38 debug_assert_ne!(*key, NONE_KEY);
39 let backoff = crossbeam_utils::Backoff::new();
40 let new_front = Node::new(value, NONE_KEY, NONE_KEY);
41 if let Some(_) = self.map.insert(key, new_front.clone()) {
42 return;
43 }
44 let _new_guard = new_front.lock.lock();
45 loop {
46 let front = self.head.load(Acquire);
47 let front_node = self.map.get(&front);
48 let _front_guard = front_node.as_ref().map(|n| n.lock.lock());
49 if let Some(ref front_node) = front_node {
50 if front_node.get_prev() != NONE_KEY {
51 backoff.spin();
52 continue;
53 }
54 } else if front != NONE_KEY {
55 backoff.spin();
57 continue;
58 }
59 new_front.set_next(front);
60 if self.head.compare_and_swap(front, *key, AcqRel) == front {
61 if let Some(ref front_node) = front_node {
62 front_node.prev.store(*key, Release);
63 } else {
64 debug_assert_eq!(front, NONE_KEY);
65 self.tail.compare_and_swap(NONE_KEY, *key, AcqRel);
66 }
67 break;
68 } else {
69 backoff.spin();
70 }
71 }
72 }
73
74 pub fn insert_back(&self, key: &usize, value: T) {
75 debug_assert_ne!(*key, NONE_KEY);
76 let backoff = crossbeam_utils::Backoff::new();
77 let new_back = Node::new(value, NONE_KEY, NONE_KEY);
78 let _new_guard = new_back.lock.lock();
79 if let Some(_) = self.map.insert(key, new_back.clone()) {
80 return;
81 }
82 loop {
83 let back = self.tail.load(Acquire);
84 let back_node = self.map.get(&back);
85 let _back_guard = back_node.as_ref().map(|n| n.lock.lock());
86 if let Some(ref back_node) = back_node {
87 if back_node.get_next() != NONE_KEY {
88 backoff.spin();
89 continue;
90 }
91 } else if back != NONE_KEY {
92 backoff.spin();
93 continue;
94 }
95 new_back.set_prev(back);
96 if self.tail.compare_and_swap(back, *key, AcqRel) == back {
97 if let Some(ref back_node) = back_node {
98 back_node.next.store(*key, Release);
99 } else {
100 debug_assert_eq!(back, NONE_KEY);
101 self.head.compare_and_swap(NONE_KEY, *key, AcqRel);
102 }
103 break;
104 } else {
105 backoff.spin();
106 }
107 }
108 }
109
110 pub fn get(&self, key: &usize) -> Option<NodeRef<T>> {
111 self.map.get(key)
112 }
113
114 pub fn remove(&self, key: &usize) -> Option<NodeRef<T>> {
115 let val = self.map.get(key);
116 if let Some(val_node) = val {
117 self.remove_node(*key, val_node);
118 return self.map.remove(key);
119 } else {
120 return val;
121 }
122 }
123
124 fn remove_node(&self, key: usize, val_node: NodeRef<T>) {
125 let backoff = crossbeam_utils::Backoff::new();
126 loop {
127 let prev = val_node.get_prev();
128 let next = val_node.get_next();
129 let prev_node = self.map.get(&prev);
130 let next_node = self.map.get(&next);
131 if (prev != NONE_KEY && prev_node.is_none())
132 || (next != NONE_KEY && next_node.is_none())
133 {
134 backoff.spin();
135 continue;
136 }
137 let _prev_guard = prev_node.as_ref().map(|n| n.lock.lock());
139 let _self_guard = val_node.lock.lock();
140 let _next_guard = next_node.as_ref().map(|n| n.lock.lock());
141 if {
143 prev_node
144 .as_ref()
145 .map(|n| n.get_next() != key)
146 .unwrap_or(false)
147 | (val_node.get_prev() != prev)
148 | (val_node.get_next() != next)
149 | next_node
150 .as_ref()
151 .map(|n| n.get_prev() != key)
152 .unwrap_or(false)
153 } {
154 backoff.spin();
155 continue;
156 }
157 prev_node.as_ref().map(|n| n.set_next(next));
160 next_node.as_ref().map(|n| n.set_prev(prev));
161 if prev_node.is_none() {
162 debug_assert_eq!(self.head.load(Acquire), key);
163 self.head.store(next, Release);
164 }
165 if next_node.is_none() {
166 debug_assert_eq!(self.tail.load(Acquire), key);
167 self.tail.store(prev, Release);
168 }
169 return;
170 }
171 }
172
173 pub fn len(&self) -> usize {
174 self.map.len()
175 }
176
177 pub fn contains_key(&self, key: &usize) -> bool {
178 self.map.contains_key(key)
179 }
180
181 pub fn all_pairs(&self) -> Vec<(usize, NodeRef<T>)> {
182 let mut res = vec![];
183 let mut node_key = self.head.load(Acquire);
184 loop {
185 if let Some(node) = self.map.get(&node_key) {
186 let new_node_key = node.get_next();
187 res.push((node_key, node));
188 node_key = new_node_key;
189 } else if node_key == NONE_KEY {
190 break;
191 } else {
192 unreachable!();
193 }
194 }
195 res
196 }
197
198 pub fn all_keys(&self) -> Vec<usize> {
199 let mut res = vec![];
200 let mut node_key = self.head.load(Acquire);
201 loop {
202 if let Some(node) = self.map.get(&node_key) {
203 res.push(node_key);
204 node_key = node.get_next();
205 } else if node_key == NONE_KEY {
206 break;
207 } else {
208 unreachable!();
209 }
210 }
211 res
212 }
213
214 pub fn all_values(&self) -> Vec<NodeRef<T>> {
215 let mut res = vec![];
216 let mut node_key = self.head.load(Acquire);
217 loop {
218 if let Some(node) = self.map.get(&node_key) {
219 node_key = node.get_next();
220 res.push(node);
221 } else if node_key == NONE_KEY {
222 break;
223 } else {
224 unreachable!();
225 }
226 }
227 res
228 }
229
230 pub fn iter(&self) -> LinkedMapIter<T> {
231 loop {
232
233 }
234 }
235}
236
/// Cursor over a `LinkedObjectMap`, pairing the current node with a
/// borrow of the map it came from.
/// NOTE(review): nothing visible in this file ever constructs this type
/// or implements `Iterator` for it — it appears to be work in progress.
pub struct LinkedMapIter<'a, T> {
    // The node the cursor is currently positioned on.
    node: Arc<Node<T>>,
    // The owning map, used to resolve neighbour keys during traversal.
    map: &'a LinkedObjectMap<T>
}
241
242impl<T> Node<T> {
243 pub fn new(obj: T, prev: usize, next: usize) -> NodeRef<T> {
244 Arc::new(Self {
245 obj,
246 lock: SpinLock::new(()),
247 prev: AtomicUsize::new(prev),
248 next: AtomicUsize::new(next),
249 })
250 }
251
252 fn get_next(&self) -> usize {
253 self.next.load(Acquire)
254 }
255
256 fn get_prev(&self) -> usize {
257 self.prev.load(Acquire)
258 }
259
260 fn set_next(&self, new: usize) {
261 self.next.store(new, Release)
262 }
263
264 fn set_prev(&self, new: usize) {
265 self.prev.store(new, Release)
266 }
267}
268
// Deref straight to the payload, so `**node_ref` (through the Arc and
// then the Node) yields the stored `T`.
impl<T> Deref for Node<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.obj
    }
}
276
#[cfg(test)]
mod test {
    use super::*;
    use std::{collections::HashSet, thread};

    // Single-threaded smoke test: front and back insertion over disjoint
    // key ranges must complete without panicking (correctness of the
    // resulting list is only checked by internal debug assertions).
    #[test]
    pub fn linked_map_serial() {
        let map = LinkedObjectMap::with_capacity(16);
        for i in 0..1024 {
            map.insert_front(&i, i);
        }
        for i in 1024..2048 {
            map.insert_back(&i, i);
        }
    }

    // Concurrent stress test: one thread per CPU inserts a disjoint key
    // range, alternating front/back, while also exercising the traversal
    // methods. Afterwards the list must contain exactly every inserted
    // key, each mapped to itself.
    #[test]
    pub fn linked_map_insertions() {
        let _ = env_logger::try_init();
        let linked_map = Arc::new(LinkedObjectMap::with_capacity(16));
        let num_threads = num_cpus::get();
        let mut threads = vec![];
        let num_data = 999;
        for i in 0..num_threads {
            let map = linked_map.clone();
            threads.push(thread::spawn(move || {
                for j in 0..num_data {
                    // Keys are disjoint per thread (i * 1000 + j with
                    // j < 999), so inserts never collide across threads.
                    let num = i * 1000 + j;
                    // debug!/info! presumably come from a #[macro_use]
                    // `log` import at the crate root — not visible here.
                    debug!("Insert {}", num);
                    if j % 2 == 1 {
                        map.insert_back(&num, num);
                    } else {
                        map.insert_front(&num, num);
                    }
                }
                // Run the traversals concurrently with other inserters.
                map.all_keys();
                map.all_values();
                map.all_pairs();
            }));
        }
        info!("Waiting for threads to finish");
        for t in threads {
            t.join().unwrap();
        }
        let mut num_set = HashSet::new();
        for (key, node) in linked_map.all_pairs() {
            // `**node` derefs Arc<Node<T>> and then Node's Deref to T.
            let value = **node;
            assert_eq!(key, value);
            num_set.insert(key);
        }
        assert_eq!(num_set.len(), num_threads * num_data);
    }
}