Skip to main content

ds_ext/
queue.rs

1//! A linked hash map ordered by insertion which can be reordered by swapping,
2//! useful as a simple priority queue (e.g. an LFU or LRU cache).
3
4use std::borrow::Borrow;
5use std::cell::{Ref, RefCell, RefMut};
6use std::collections::HashMap;
7use std::hash::Hash;
8use std::sync::Arc;
9use std::{fmt, mem};
10
/// The links of one entry in the list: the key of its neighbour toward the front (`prev`)
/// and toward the back (`next`). `None` marks the first or last entry respectively.
struct ItemState<K> {
    prev: Option<Arc<K>>,
    next: Option<Arc<K>>,
}

/// One entry of a [`LinkedHashMap`]: its shared key, its value, and its list links.
///
/// The links live behind a [`RefCell`] so that neighbouring entries can be re-linked
/// while only holding shared references into the backing `HashMap`.
struct Item<K, V> {
    key: Arc<K>,
    value: V,
    state: RefCell<ItemState<K>>,
}

impl<K, V> Item<K, V> {
    /// Borrow this entry's links immutably (panics if they are currently borrowed mutably).
    #[inline]
    fn state(&self) -> Ref<'_, ItemState<K>> {
        RefCell::borrow(&self.state)
    }

    /// Borrow this entry's links mutably (panics if they are currently borrowed at all).
    #[inline]
    fn state_mut(&self) -> RefMut<'_, ItemState<K>> {
        RefCell::borrow_mut(&self.state)
    }
}

/// The backing storage: every entry keyed by a shared pointer to its own key.
type Inner<K, V> = HashMap<Arc<K>, Item<K, V>>;
35
36/// An iterator over the contents of a [`LinkedHashMap`]
37pub struct IntoIter<K, V> {
38    queue: LinkedHashMap<K, V>,
39}
40
41impl<K: Eq + Hash + fmt::Debug, V> Iterator for IntoIter<K, V> {
42    type Item = (K, V);
43
44    fn next(&mut self) -> Option<Self::Item> {
45        self.queue.pop_first_entry()
46    }
47
48    fn size_hint(&self) -> (usize, Option<usize>) {
49        (self.queue.len(), Some(self.queue.len()))
50    }
51}
52
53impl<K: Eq + Hash + fmt::Debug, V> DoubleEndedIterator for IntoIter<K, V> {
54    fn next_back(&mut self) -> Option<Self::Item> {
55        self.queue.pop_last_entry()
56    }
57}
58
59/// An iterator over the entries in a [`LinkedHashMap`]
60pub struct Iter<'a, K, V> {
61    list: &'a Inner<K, V>,
62    next: Option<Arc<K>>,
63    last: Option<Arc<K>>,
64    size: usize,
65}
66
67impl<'a, K: Eq + Hash, V> Iterator for Iter<'a, K, V> {
68    type Item = (&'a K, &'a V);
69
70    fn next(&mut self) -> Option<Self::Item> {
71        let next = self.next.take()?;
72        let (key, item) = self.list.get_key_value(&*next).expect("next");
73
74        if self.last == Some(next) {
75            self.next = None;
76            self.last = None;
77        } else {
78            self.next = item.state().next.clone();
79        }
80
81        self.size -= 1;
82
83        Some((key, &item.value))
84    }
85
86    fn size_hint(&self) -> (usize, Option<usize>) {
87        (self.size, Some(self.size))
88    }
89}
90
91impl<'a, K: Eq + Hash, V> DoubleEndedIterator for Iter<'a, K, V> {
92    fn next_back(&mut self) -> Option<Self::Item> {
93        let last = self.last.take()?;
94        let (key, item) = self.list.get_key_value(&*last).expect("next");
95
96        if self.next == Some(last) {
97            self.next = None;
98            self.last = None;
99        } else {
100            self.last = item.state().prev.clone();
101        }
102
103        self.size -= 1;
104
105        Some((key, &item.value))
106    }
107}
108
109/// An iterator over the keys in a [`LinkedHashMap`]
110pub struct Keys<'a, K, V> {
111    inner: Iter<'a, K, V>,
112}
113
114impl<'a, K: Hash + Eq, V> Iterator for Keys<'a, K, V> {
115    type Item = &'a K;
116
117    fn next(&mut self) -> Option<Self::Item> {
118        self.inner.next().map(|(key, _value)| key)
119    }
120
121    fn size_hint(&self) -> (usize, Option<usize>) {
122        self.inner.size_hint()
123    }
124}
125
126impl<'a, K: Hash + Eq, V> DoubleEndedIterator for Keys<'a, K, V> {
127    fn next_back(&mut self) -> Option<Self::Item> {
128        self.inner.next_back().map(|(key, _value)| key)
129    }
130}
131
132/// An iterator over the values in a [`LinkedHashMap`]
133pub struct Values<'a, K, V> {
134    inner: Iter<'a, K, V>,
135}
136
137impl<'a, K: Eq + Hash, V> Iterator for Values<'a, K, V> {
138    type Item = &'a V;
139
140    fn next(&mut self) -> Option<Self::Item> {
141        self.inner.next().map(|(_key, value)| value)
142    }
143
144    fn size_hint(&self) -> (usize, Option<usize>) {
145        self.inner.size_hint()
146    }
147}
148
149impl<'a, K: Eq + Hash, V> DoubleEndedIterator for Values<'a, K, V> {
150    fn next_back(&mut self) -> Option<Self::Item> {
151        self.inner.next_back().map(|(_key, value)| value)
152    }
153}
154
155/// A hash map in insertion order which can be reordered using [`Self::bump`] and [`Self::swap`].
156pub struct LinkedHashMap<K, V> {
157    list: Inner<K, V>,
158    head: Option<Arc<K>>,
159    tail: Option<Arc<K>>,
160}
161
162impl<K: Clone + Eq + Hash, V: Clone> Clone for LinkedHashMap<K, V> {
163    fn clone(&self) -> Self {
164        let mut other = Self::with_capacity(self.list.capacity());
165
166        for (key, item) in &self.list {
167            let key = K::clone(&**key);
168            let value = V::clone(&item.value);
169            other.insert(key, value);
170        }
171
172        other
173    }
174}
175
176impl<K: Eq + Hash, V> LinkedHashMap<K, V> {
177    /// Construct a new [`LinkedHashMap`].
178    pub fn new() -> Self {
179        Self {
180            list: HashMap::new(),
181            head: None,
182            tail: None,
183        }
184    }
185
186    /// Construct a new [`LinkedHashMap`] with the given `capacity`.
187    pub fn with_capacity(capacity: usize) -> Self {
188        Self {
189            list: HashMap::with_capacity(capacity),
190            head: None,
191            tail: None,
192        }
193    }
194
195    /// If `key` is present, increase its priority by one and return `true`.
196    pub fn bump(&mut self, key: &K) -> bool {
197        let item = if let Some(item) = self.list.get(key) {
198            item
199        } else {
200            return false;
201        };
202
203        let mut item_state = item.state_mut();
204
205        if item_state.prev.is_none() {
206            // can't bump the first item
207            return true;
208        } else if item_state.next.is_none() && item_state.prev.is_some() {
209            // bump the last item
210
211            let prev_key = item_state.prev.as_ref().expect("prev key").clone();
212            let mut prev = self.list.get::<K>(&prev_key).expect("prev").state_mut();
213
214            mem::swap(&mut prev.next, &mut item_state.next); // set prev.next
215            mem::swap(&mut item_state.prev, &mut prev.prev); // set item.prev
216            mem::swap(&mut item_state.next, &mut prev.prev); // set item.next & prev.prev
217
218            self.tail = Some(prev_key)
219        } else {
220            // bump an item in the middle
221
222            let prev_key = item_state.prev.as_ref().expect("previous key").clone();
223            let mut prev = self.list.get::<K>(&prev_key).expect("prev").state_mut();
224
225            let next_key = item_state.next.as_ref().expect("next key").clone();
226            let mut next = self.list.get::<K>(&next_key).expect("next").state_mut();
227
228            mem::swap(&mut next.prev, &mut item_state.prev); // set next.prev
229            mem::swap(&mut item_state.prev, &mut prev.prev); // set item.prev
230            mem::swap(&mut prev.next, &mut item_state.next); // set prev.next
231
232            item_state.next = Some(prev_key);
233        }
234
235        if let Some(prev_key) = &item_state.prev {
236            let mut prev = self.list.get::<K>(prev_key).expect("prev").state_mut();
237            prev.next = Some(item.key.clone());
238        } else {
239            self.head = Some(item.key.clone());
240        }
241
242        std::mem::drop(item_state);
243
244        true
245    }
246
247    /// Remove all entries from this [`LinkedHashMap`].
248    pub fn clear(&mut self) {
249        self.list.clear();
250        self.head = None;
251        self.tail = None;
252    }
253
254    /// Return `true` if there is an entry for the given `key` in this [`LinkedHashMap`].
255    pub fn contains_key<Q>(&self, key: &Q) -> bool
256    where
257        Arc<K>: Borrow<Q>,
258        Q: Eq + Hash + ?Sized,
259    {
260        self.list.contains_key(key)
261    }
262
263    /// Consume the `iter` and insert all its elements into this [`LinkedHashMap`].
264    pub fn extend<I: IntoIterator<Item = (K, V)>>(&mut self, iter: I) {
265        for (key, value) in iter {
266            self.insert(key, value);
267        }
268    }
269
270    /// Borrow the value at the given `key`, if present.
271    pub fn get<Q>(&self, key: &Q) -> Option<&V>
272    where
273        Arc<K>: Borrow<Q>,
274        Q: Eq + Hash + ?Sized,
275    {
276        self.list.get(key).map(|item| &item.value)
277    }
278
279    /// Borrow the entry at the given `key`, if present.
280    pub fn get_key_value<Q>(&self, key: &Q) -> Option<(&K, &V)>
281    where
282        Arc<K>: Borrow<Q>,
283        Q: Eq + Hash + ?Sized,
284    {
285        self.list
286            .get_key_value(key)
287            .map(|(key, item)| (&**key, &item.value))
288    }
289
290    /// Borrow the value at the given `key` mutably, if present.
291    pub fn get_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
292    where
293        Arc<K>: Borrow<Q>,
294        Q: Eq + Hash + ?Sized,
295    {
296        self.list.get_mut(key).map(|item| &mut item.value)
297    }
298
299    /// Insert a new `value` at `key` and return the previous value, if any.
300    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
301        let old_value = self.remove(&key);
302
303        let key = Arc::new(key);
304        let mut next = Some(key.clone());
305        mem::swap(&mut self.head, &mut next);
306
307        if let Some(prev_key) = &next {
308            let mut prev = self.list.get::<K>(prev_key).expect("prev").state_mut();
309            prev.prev = Some(key.clone());
310        } else {
311            debug_assert!(self.tail.is_none());
312            self.tail = Some(key.clone());
313        }
314
315        let item = Item {
316            key: key.clone(),
317            value,
318            state: RefCell::new(ItemState { prev: None, next }),
319        };
320
321        assert!(self.list.insert(key, item).is_none());
322
323        old_value
324    }
325
326    /// Construct an iterator over the entries in this [`LinkedHashMap`].
327    pub fn iter(&self) -> Iter<'_, K, V> {
328        Iter {
329            list: &self.list,
330            next: self.head.clone(),
331            last: self.tail.clone(),
332            size: self.len(),
333        }
334    }
335
336    /// Return `true` if this [`LinkedHashMap`] is empty.
337    pub fn is_empty(&self) -> bool {
338        self.list.is_empty()
339    }
340
341    /// Construct an iterator over keys of this [`LinkedHashMap`].
342    pub fn keys(&self) -> Keys<'_, K, V> {
343        Keys { inner: self.iter() }
344    }
345
346    /// Return the size of this [`LinkedHashMap`].
347    pub fn len(&self) -> usize {
348        self.list.len()
349    }
350
351    /// Remove and return the first value in this [`LinkedHashMap`].
352    pub fn pop_first(&mut self) -> Option<V> {
353        let head = self.head.as_ref()?;
354        let item = self.list.remove(head).expect("head");
355
356        Some(self.remove_inner(item))
357    }
358
359    /// Remove and return the first entry in this [`LinkedHashMap`].
360    pub fn pop_first_entry(&mut self) -> Option<(K, V)>
361    where
362        K: fmt::Debug,
363    {
364        let head = self.head.as_ref()?;
365        let (key, item) = self
366            .list
367            .remove_entry(head)
368            .expect("head");
369
370        let value = self.remove_inner(item);
371        let key = Arc::try_unwrap(key).expect("key");
372        Some((key, value))
373    }
374
375    /// Remove and return the last value in this [`LinkedHashMap`].
376    pub fn pop_last(&mut self) -> Option<V> {
377        let tail = self.tail.as_ref()?;
378        let item = self.list.remove(tail).expect("tail");
379
380        Some(self.remove_inner(item))
381    }
382
383    /// Remove and return the last entry in this [`LinkedHashMap`].
384    pub fn pop_last_entry(&mut self) -> Option<(K, V)>
385    where
386        K: fmt::Debug,
387    {
388        let tail = self.tail.as_ref()?;
389        let (key, item) = self
390            .list
391            .remove_entry(tail)
392            .expect("tail");
393
394        let value = self.remove_inner(item);
395        let key = Arc::try_unwrap(key).expect("key");
396        Some((key, value))
397    }
398
399    fn remove_inner(&mut self, item: Item<K, V>) -> V {
400        let mut item_state = item.state_mut();
401
402        if item_state.prev.is_none() && item_state.next.is_none() {
403            // there was only one item and now the map is empty
404            self.head = None;
405            self.tail = None;
406        } else if item_state.prev.is_none() {
407            // the first item has been removed
408            self.head = item_state.next.clone();
409
410            let next_key = self.head.as_ref().expect("next key");
411            let mut next = self.list.get::<K>(next_key).expect("next").state_mut();
412
413            mem::swap(&mut next.prev, &mut item_state.prev);
414        } else if item_state.next.is_none() {
415            // the last item has been removed
416            self.tail = item_state.prev.clone();
417
418            let prev_key = self.tail.as_ref().expect("previous key");
419            let mut prev = self.list.get::<K>(prev_key).expect("prev").state_mut();
420
421            mem::swap(&mut prev.next, &mut item_state.next);
422        } else {
423            // an item in the middle has been removed
424            let prev_key = item_state.prev.as_ref().expect("previous key");
425            let mut prev = self.list.get::<K>(prev_key).expect("prev").state_mut();
426
427            let next_key = item_state.next.as_ref().expect("next key");
428            let mut next = self.list.get::<K>(next_key).expect("next item").state_mut();
429
430            mem::swap(&mut next.prev, &mut item_state.prev);
431            mem::swap(&mut prev.next, &mut item_state.next);
432        }
433
434        std::mem::drop(item_state);
435
436        item.value
437    }
438
439    /// Remove an entry from this [`LinkedHashMap`] and return its value, if present.
440    pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
441    where
442        Arc<K>: Borrow<Q>,
443        Q: Hash + Eq + ?Sized,
444    {
445        let item = self.list.remove(key)?;
446        Some(self.remove_inner(item))
447    }
448
449    /// Remove and return an entry from this [`LinkedHashMap`], if present.
450    pub fn remove_entry<Q>(&mut self, key: &Q) -> Option<(K, V)>
451    where
452        K: fmt::Debug,
453        Arc<K>: Borrow<Q>,
454        Q: Hash + Eq + ?Sized,
455    {
456        let (key, item) = self.list.remove_entry(key)?;
457        let key = Arc::try_unwrap(key).expect("key");
458        Some((key, self.remove_inner(item)))
459    }
460
461    /// Swap the position of two keys in this [`LinkedHashMap`].
462    /// Returns `true` if both keys are present.
463    pub fn swap<Q>(&mut self, l: &Q, r: &Q) -> bool
464    where
465        Arc<K>: Borrow<Q>,
466        Q: Hash + Eq + ?Sized,
467    {
468        if l == r {
469            return self.contains_key(l) && self.contains_key(r);
470        }
471
472        let (l_key, l_item) = if let Some(entry) = self.list.get_key_value(l) {
473            entry
474        } else {
475            return false;
476        };
477
478        let (r_key, r_item) = if let Some(entry) = self.list.get_key_value(r) {
479            entry
480        } else {
481            return false;
482        };
483
484        if l_item.state().next.as_ref() == Some(r_key) {
485            let key = r_key.clone();
486            return self.bump(&key);
487        } else if r_item.state().next.as_ref() == Some(l_key) {
488            let key = l_key.clone();
489            return self.bump(&key);
490        } else {
491            let mut l_state = l_item.state_mut();
492            let mut r_state = r_item.state_mut();
493            mem::swap(&mut *l_state, &mut *r_state);
494        }
495
496        if self.head.as_ref() == Some(l_key) {
497            self.head = Some(r_key.clone());
498        } else if self.head.as_ref() == Some(r_key) {
499            self.head = Some(l_key.clone());
500        }
501
502        if self.tail.as_ref() == Some(l_key) {
503            self.tail = Some(r_key.clone());
504        } else if self.tail.as_ref() == Some(r_key) {
505            self.tail = Some(l_key.clone());
506        }
507
508        true
509    }
510
511    /// Construct an iterator over the values in this [`LinkedHashMap`].
512    pub fn values(&self) -> Values<'_, K, V> {
513        Values { inner: self.iter() }
514    }
515}
516
517impl<K: Eq + Hash, V> Default for LinkedHashMap<K, V> {
518    fn default() -> Self {
519        Self::new()
520    }
521}
522
523impl<K: Eq + Hash, V> FromIterator<(K, V)> for LinkedHashMap<K, V> {
524    fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
525        let iter = iter.into_iter();
526        let mut map = match iter.size_hint() {
527            (_, Some(max)) => Self::with_capacity(max),
528            (min, None) if min > 0 => Self::with_capacity(min),
529            _ => Self::new(),
530        };
531
532        map.extend(iter);
533        map
534    }
535}
536
537impl<K: Eq + Hash + fmt::Debug, V> IntoIterator for LinkedHashMap<K, V> {
538    type Item = (K, V);
539    type IntoIter = IntoIter<K, V>;
540
541    fn into_iter(self) -> Self::IntoIter {
542        IntoIter { queue: self }
543    }
544}
545
546impl<'a, K: Eq + Hash, V> IntoIterator for &'a LinkedHashMap<K, V> {
547    type Item = (&'a K, &'a V);
548    type IntoIter = Iter<'a, K, V>;
549
550    fn into_iter(self) -> Self::IntoIter {
551        self.iter()
552    }
553}
554
555impl<K: Eq + Hash + fmt::Debug, V: fmt::Debug> fmt::Debug for LinkedHashMap<K, V> {
556    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
557        f.write_str("{")?;
558
559        for (key, value) in self.iter() {
560            write!(f, "{:?}: {:?}, ", key, value)?;
561        }
562
563        f.write_str("}")
564    }
565}
566
567#[allow(dead_code)]
568fn validate<K: Eq + Hash + fmt::Debug, V>(queue: &LinkedHashMap<K, V>) {
569    if queue.list.is_empty() {
570        assert!(queue.head.is_none(), "head is {:?}", queue.head);
571        assert!(queue.tail.is_none(), "tail is {:?}", queue.tail);
572    } else {
573        let first_key = queue.head.as_ref().expect("first key");
574        let first = queue.list.get::<K>(first_key).expect("first item");
575        assert_eq!(first.state().prev, None);
576
577        let last_key = queue.tail.as_ref().expect("last key");
578        let last = queue.list.get::<K>(last_key).expect("last item");
579        assert_eq!(last.state().next, None);
580    }
581
582    let mut size = 0;
583    let mut last = None;
584    let mut next = queue.head.clone();
585    while let Some(key) = next {
586        let item = queue.list.get::<K>(&key).expect("item");
587
588        let item_state = item.state.borrow();
589        assert_ne!(item_state.prev.as_ref(), Some(&key));
590        assert_ne!(item_state.next.as_ref(), Some(&key));
591
592        let prev_key = item_state.prev.as_ref();
593        assert_eq!(last.as_ref(), prev_key);
594
595        last = Some(key);
596        next = item.state.borrow().next.clone();
597        size += 1;
598    }
599
600    assert_eq!(size, queue.len());
601}
602
603#[allow(dead_code)]
604fn print_debug<K: fmt::Debug + Eq + Hash, V>(queue: &LinkedHashMap<K, V>) {
605    let mut next = queue.head.clone();
606
607    if next.is_some() {
608        println!();
609    }
610
611    while let Some(next_key) = next {
612        let item = queue.list.get::<K>(&next_key).expect("item").state();
613
614        if let Some(prev_key) = item.prev.as_ref() {
615            print!("{:?}-", prev_key);
616        }
617
618        print!("{:?}", next_key);
619
620        next = item.next.clone();
621        if let Some(next_key) = &next {
622            print!("-{:?}", next_key);
623        }
624
625        println!();
626    }
627}
628
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    /// Collect the keys of `queue` in iteration (front-to-back) order.
    fn keys_in_order<V>(queue: &LinkedHashMap<i32, V>) -> Vec<i32> {
        queue.iter().map(|(k, _)| *k).collect()
    }

    #[test]
    fn test_order() {
        let mut queue = LinkedHashMap::new();
        let expected: Vec<i32> = (0..10).collect();

        for i in expected.iter() {
            queue.insert(*i, i.to_string());
            validate(&queue);
        }

        assert_eq!(queue.len(), expected.len());

        let mut actual = Vec::with_capacity(expected.len());
        for (i, s) in queue.iter() {
            assert_eq!(&i.to_string(), s);
            actual.push(i);
        }

        // `insert` pushes to the front, so iteration runs in reverse insertion order
        assert_eq!(actual.len(), expected.len());
        assert!(actual
            .iter()
            .zip(expected.into_iter().rev())
            .all(|(l, r)| **l == r))
    }

    #[test]
    fn test_access() {
        let mut queue = LinkedHashMap::new();
        validate(&queue);

        let mut rng = rand::rng();
        for _ in 1..100_000 {
            let i: i32 = rng.random_range(0..1000);
            queue.insert(i, i.to_string());
            validate(&queue);

            let mut size = 0;
            for _ in queue.iter() {
                size += 1;
            }

            assert_eq!(queue.len(), size);
            assert!(!queue.is_empty());

            while !queue.is_empty() {
                let i: i32 = rng.random_range(0..queue.len() as i32);
                queue.bump(&i);
                validate(&queue);

                if queue.pop_first().is_some() {
                    validate(&queue);
                    size -= 1;
                }

                if !queue.is_empty() {
                    // inclusive: the exclusive range `0..tail` is empty (and panics) when the
                    // tail key is 0
                    let i: i32 = rng.random_range(0..=**queue.tail.as_ref().expect("tail"));
                    queue.bump(&i);
                    validate(&queue);
                }

                if queue.pop_last().is_some() {
                    validate(&queue);
                    size -= 1;
                }

                assert_eq!(queue.len(), size);
            }

            assert_eq!(queue.len(), 0);
        }
    }

    #[test]
    fn test_random_ops_invariants() {
        let mut rng = StdRng::seed_from_u64(0x_51a3_2b7d);
        let mut queue = LinkedHashMap::new();
        let mut live = std::collections::HashSet::new();

        for _ in 0..10_000 {
            let key: i32 = rng.random_range(0..500);
            let action: u8 = rng.random_range(0..4);

            match action {
                0 => {
                    queue.insert(key, key.to_string());
                    live.insert(key);
                }
                1 => {
                    queue.remove(&key);
                    live.remove(&key);
                }
                2 => {
                    queue.bump(&key);
                }
                _ => {
                    if rng.random() {
                        if let Some(value) = queue.pop_first() {
                            let key: i32 = value.parse().expect("key");
                            live.remove(&key);
                        }
                    } else if let Some(value) = queue.pop_last() {
                        let key: i32 = value.parse().expect("key");
                        live.remove(&key);
                    }
                }
            }

            validate(&queue);
            assert_eq!(queue.len(), live.len());
            for key in &live {
                assert!(queue.get(key).is_some());
            }
        }
    }

    #[test]
    fn test_bump_head_tail_middle() {
        let mut queue = LinkedHashMap::new();
        queue.insert(1, "one".to_string());
        queue.insert(2, "two".to_string());
        queue.insert(3, "three".to_string());

        // bump head: no change
        assert!(queue.bump(&3));
        assert_eq!(keys_in_order(&queue), vec![3, 2, 1]);

        // bump tail: move up one position
        assert!(queue.bump(&1));
        assert_eq!(keys_in_order(&queue), vec![3, 1, 2]);

        // bump middle: move up one position
        assert!(queue.bump(&1));
        assert_eq!(keys_in_order(&queue), vec![1, 3, 2]);

        // bump a missing key: reported, no change
        assert!(!queue.bump(&4));
        assert_eq!(keys_in_order(&queue), vec![1, 3, 2]);
    }

    #[test]
    fn test_double_ended_iter() {
        let queue: LinkedHashMap<i32, String> = (0..5).map(|i| (i, i.to_string())).collect();
        validate(&queue);

        // `insert` pushes to the front, so iteration starts from the last key inserted
        assert_eq!(keys_in_order(&queue), vec![4, 3, 2, 1, 0]);
        assert_eq!(
            queue.keys().rev().copied().collect::<Vec<_>>(),
            vec![0, 1, 2, 3, 4]
        );
        assert_eq!(
            queue.values().rev().cloned().collect::<Vec<_>>(),
            vec!["0", "1", "2", "3", "4"]
        );

        // alternating ends must yield every entry exactly once, then stop
        let mut iter = queue.iter();
        assert_eq!(iter.next().map(|(k, _)| *k), Some(4));
        assert_eq!(iter.next_back().map(|(k, _)| *k), Some(0));
        assert_eq!(iter.next().map(|(k, _)| *k), Some(3));
        assert_eq!(iter.next_back().map(|(k, _)| *k), Some(1));
        assert_eq!(iter.next().map(|(k, _)| *k), Some(2));
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next_back(), None);
    }

    #[test]
    fn test_pop_entries() {
        let mut queue = LinkedHashMap::new();
        for i in 0..4 {
            queue.insert(i, i.to_string());
        }
        // the order is now [3, 2, 1, 0]

        assert_eq!(queue.pop_first_entry(), Some((3, "3".to_string())));
        validate(&queue);
        assert_eq!(queue.pop_last_entry(), Some((0, "0".to_string())));
        validate(&queue);

        let rest: Vec<(i32, String)> = queue.into_iter().collect();
        assert_eq!(rest, vec![(2, "2".to_string()), (1, "1".to_string())]);
    }
}