1use std::{
2 collections::{hash_map::DefaultHasher, HashMap, VecDeque},
3 hash::{self, Hasher},
4 sync::{Arc, RwLock},
5};
6
7pub struct CncrKLtdRing<K, V>
8where
9 K: Eq + hash::Hash,
10{
11 shards: Arc<Vec<Shard<K, V>>>,
12}
13
14impl<K, V> CncrKLtdRing<K, V>
15where
16 K: Eq + hash::Hash,
17{
18 #[must_use]
19 pub fn new(shard_count: usize, ring_size: usize) -> Self {
20 let shards = (0..shard_count)
21 .map(|_| Shard::new(ring_size))
22 .collect::<Vec<_>>();
23 Self {
24 shards: Arc::new(shards),
25 }
26 }
27
28 #[must_use]
29 pub fn take_ring(&self, key: &K) -> Option<VecDeque<V>> {
30 let index = self.shard_index(key);
31 self.shards[index].take_ring(key)
32 }
33
34 pub fn push(&self, key: K, value: V) {
35 let index = self.shard_index(&key);
36 self.shards[index].push(key, value);
37 }
38
39 fn shard_index(&self, key: &K) -> usize {
40 let mut hasher = DefaultHasher::new();
41 key.hash(&mut hasher);
42 hasher.finish() as usize % self.shards.len()
43 }
44}
45
46impl<K, V> CncrKLtdRing<K, V>
47where
48 K: Eq + hash::Hash,
49 V: Clone,
50{
51 #[must_use]
52 pub fn clone_batch_last(&self, key: &K, count: usize) -> Option<Vec<V>> {
53 let index = self.shard_index(key);
54 self.shards[index].clone_batch_last(key, count)
55 }
56}
57
/// One independently locked partition of the sharded ring map.
///
/// For each key it holds a `VecDeque` used as a bounded ring. New values are
/// pushed at the back and the oldest are evicted from the front.
///
/// Lock poisoning is deliberately ignored. The guarded data is a plain map of
/// deques with no cross-entry invariants, and every mutation keeps each ring at
/// or below `ring_size`. So a panic in another thread while it held the lock
/// (e.g. from a user `Drop` or `Hash` impl) leaves nothing half-updated that
/// this type relies on. Propagating the poison would instead make every later
/// call on any key in this shard panic.
struct Shard<K, V>
where
    K: Eq + hash::Hash,
{
    /// Per-key rings, all guarded by one reader/writer lock.
    map: RwLock<HashMap<K, VecDeque<V>>>,
    /// Maximum number of values retained per key.
    ring_size: usize,
}

impl<K, V> Shard<K, V>
where
    K: Eq + hash::Hash,
{
    /// Creates an empty shard whose rings hold at most `ring_size` values each.
    #[must_use]
    fn new(ring_size: usize) -> Self {
        Self {
            map: RwLock::new(HashMap::new()),
            ring_size,
        }
    }

    /// Moves the ring for `key` out and leaves an empty ring registered in its place.
    ///
    /// Returns `None` if `key` has never been pushed to this shard.
    #[must_use]
    fn take_ring(&self, key: &K) -> Option<VecDeque<V>> {
        let mut map = self.map.write().unwrap_or_else(std::sync::PoisonError::into_inner);
        map.get_mut(key).map(std::mem::take)
    }

    /// Appends `value` to the ring for `key` and evicts the oldest value once
    /// the ring already holds `ring_size` values.
    ///
    /// With `ring_size == 0`, the key is registered but `value` is dropped
    /// immediately. The previous equality check let such rings grow without bound.
    fn push(&self, key: K, value: V) {
        let mut map = self.map.write().unwrap_or_else(std::sync::PoisonError::into_inner);
        let ring = map
            .entry(key)
            .or_insert_with(|| VecDeque::with_capacity(self.ring_size));
        if self.ring_size == 0 {
            return;
        }
        // `>=` rather than `==` so the capacity bound holds even if the
        // invariant were ever broken.
        if ring.len() >= self.ring_size {
            ring.pop_front();
        }
        ring.push_back(value);
    }
}

impl<K, V> Shard<K, V>
where
    K: Eq + hash::Hash,
    V: Clone,
{
    /// Clones up to `count` of the most recent values for `key`, newest first.
    ///
    /// Returns `None` if `key` has never been pushed to this shard.
    #[must_use]
    fn clone_batch_last(&self, key: &K, count: usize) -> Option<Vec<V>> {
        let map = self.map.read().unwrap_or_else(std::sync::PoisonError::into_inner);
        let ring = map.get(key)?;
        let newest_first = ring.iter().rev().take(count).cloned().collect();
        Some(newest_first)
    }
}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-threaded walk-through: lookups on an unknown key, eviction once
    /// the ring is full, newest-first batch cloning, and draining via `take_ring`.
    #[test]
    fn push_evicts_oldest_and_take_drains() {
        let ring_map = CncrKLtdRing::new(4, 3);

        // Nothing pushed yet: the key is unknown, on every call.
        assert!(ring_map.take_ring(&"a").is_none());
        assert!(ring_map.take_ring(&"a").is_none());

        // Four pushes into a 3-slot ring evict the first value.
        for value in 1..=4 {
            ring_map.push("a", value);
        }
        assert_eq!(ring_map.clone_batch_last(&"a", 2), Some(vec![4, 3]));
        // Asking for more than is buffered returns everything, newest first.
        assert_eq!(ring_map.clone_batch_last(&"a", 4), Some(vec![4, 3, 2]));

        // Draining yields the buffered values oldest first.
        let ring = ring_map.take_ring(&"a").unwrap();
        assert_eq!(ring, VecDeque::from(vec![2, 3, 4]));

        // After a take, the key stays registered with an empty ring.
        let ring = ring_map.take_ring(&"a").unwrap();
        assert!(ring.is_empty());
    }

    /// Many threads pushing to one key must leave exactly `ring_size`
    /// distinct pushed values behind.
    #[test]
    fn test_concurrency_same_key() {
        let ring_map = Arc::new(CncrKLtdRing::new(4, 3));
        let threads: Vec<_> = (0..100)
            .map(|i| {
                let ring_map = Arc::clone(&ring_map);
                std::thread::spawn(move || ring_map.push("a", i))
            })
            .collect();
        for thread in threads {
            thread.join().unwrap();
        }

        let ring = ring_map.take_ring(&"a").unwrap();
        assert_eq!(ring.len(), 3);
        // Scheduling decides which pushes survive. Each thread pushed a unique
        // value, so the survivors must be three distinct inputs.
        let mut survivors: Vec<i32> = ring.into_iter().collect();
        survivors.sort_unstable();
        survivors.dedup();
        assert_eq!(survivors.len(), 3);
        assert!(survivors.iter().all(|value| (0..100).contains(value)));
    }

    /// One thread per key: each key's ring must end with that thread's last
    /// three pushes, in push order, regardless of interleaving across keys.
    #[test]
    fn test_concurrency_different_keys() {
        let ring_map = Arc::new(CncrKLtdRing::new(4, 3));
        let threads: Vec<_> = (0..100)
            .map(|key| {
                let ring_map = Arc::clone(&ring_map);
                std::thread::spawn(move || {
                    for value in 0..100 {
                        ring_map.push(key, value);
                    }
                })
            })
            .collect();
        for thread in threads {
            thread.join().unwrap();
        }

        for key in 0..100 {
            let ring = ring_map.take_ring(&key).unwrap();
            assert_eq!(ring, VecDeque::from(vec![97, 98, 99]));
        }
    }
}