1use crate::ttl::{ExpirationMap, Time};
2use crate::utils::{change_lifetime_const, SharedValue, ValueRef, ValueRefMut};
3use crate::{CacheError, DefaultUpdateValidator, Item as CrateItem, UpdateValidator};
4use parking_lot::RwLock;
5use std::collections::hash_map::RandomState;
6use std::collections::HashMap;
7use std::fmt::{Debug, Formatter};
8use std::hash::BuildHasher;
9use std::mem;
10use std::sync::Arc;
11
/// Number of independently locked shards a `ShardedMap` is split into.
///
/// A key is routed to shard `(key as usize) % NUM_OF_SHARDS`.
const NUM_OF_SHARDS: usize = 1 << 8;
13
/// One entry of a shard: the stored value plus the bookkeeping used for conflict
/// checks (`conflict`) and expiry (`expiration`).
pub(crate) struct StoreItem<V> {
    /// The `u64` key this entry is stored under in its shard's map.
    pub(crate) key: u64,
    /// Conflict discriminator: lookups that pass a non-zero `conflict` only match an
    /// entry carrying this same value.
    pub(crate) conflict: u64,
    /// The cached value.
    pub(crate) value: SharedValue<V>,
    /// Expiration time of the entry; a zero `Time` is never treated as expired by lookups.
    pub(crate) expiration: Time,
}
20
21impl<V> Debug for StoreItem<V> {
22 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("StoreItem")
24 .field("key", &self.key)
25 .field("conflict", &self.conflict)
26 .field("expiration", &self.expiration)
27 .finish()
28 }
29}
30
/// Heap-allocated, fixed-size array of `NUM_OF_SHARDS` lock-protected hash maps, each
/// mapping a `u64` key to its `StoreItem`.
type Shards<V, SS> = Box<[RwLock<HashMap<u64, StoreItem<V>, SS>>; NUM_OF_SHARDS]>;
32
/// Concurrent map from `u64` keys to `StoreItem`s, split into `NUM_OF_SHARDS`
/// independently locked shards, paired with an expiration map that is updated alongside
/// the shards on insert, update and remove.
pub(crate) struct ShardedMap<V, U = DefaultUpdateValidator<V>, SS = RandomState, ES = RandomState> {
    /// The shards; a key lives in shard `(key as usize) % NUM_OF_SHARDS`.
    shards: Shards<V, SS>,
    /// Expiration bookkeeping consulted by the cleanup routines.
    em: ExpirationMap<ES>,
    /// `mem::size_of::<StoreItem<V>>()`, computed once at construction.
    store_item_size: usize,
    /// Decides whether an existing value may be overwritten on insert/update.
    validator: U,
}
39
40impl<V: Send + Sync + 'static> ShardedMap<V> {
41 #[allow(dead_code)]
42 pub fn new() -> Self {
43 Self::with_validator(ExpirationMap::new(), DefaultUpdateValidator::default())
44 }
45}
46
47impl<V: Send + Sync + 'static, U: UpdateValidator<Value = V>> ShardedMap<V, U> {
48 #[allow(dead_code)]
49 pub fn with_validator(em: ExpirationMap<RandomState>, validator: U) -> Self {
50 let shards = Box::new(
51 (0..NUM_OF_SHARDS)
52 .map(|_| RwLock::new(HashMap::new()))
53 .collect::<Vec<_>>()
54 .try_into()
55 .unwrap(),
56 );
57
58 let size = mem::size_of::<StoreItem<V>>();
59 Self {
60 shards,
61 em,
62 store_item_size: size,
63 validator,
64 }
65 }
66}
67
68impl<
69 V: Send + Sync + 'static,
70 U: UpdateValidator<Value = V>,
71 SS: BuildHasher + Clone + 'static,
72 ES: BuildHasher + Clone + 'static,
73 > ShardedMap<V, U, SS, ES>
74{
75 pub fn with_validator_and_hasher(em: ExpirationMap<ES>, validator: U, hasher: SS) -> Self {
76 let shards = Box::new(
77 (0..NUM_OF_SHARDS)
78 .map(|_| RwLock::new(HashMap::with_hasher(hasher.clone())))
79 .collect::<Vec<_>>()
80 .try_into()
81 .unwrap(),
82 );
83
84 let size = mem::size_of::<StoreItem<V>>();
85 Self {
86 shards,
87 em,
88 store_item_size: size,
89 validator,
90 }
91 }
92
93 pub fn get(&self, key: &u64, conflict: u64) -> Option<ValueRef<'_, V, SS>> {
94 let data = self.shards[(*key as usize) % NUM_OF_SHARDS].read();
95
96 if let Some(item) = data.get(key) {
97 if conflict != 0 && (conflict != item.conflict) {
98 return None;
99 }
100
101 if !item.expiration.is_zero() && item.expiration.is_expired() {
103 return None;
104 }
105
106 unsafe {
107 let vptr = change_lifetime_const(item);
108 Some(ValueRef::new(data, vptr))
109 }
110 } else {
111 None
112 }
113 }
114
115 pub fn get_mut(&self, key: &u64, conflict: u64) -> Option<ValueRefMut<'_, V, SS>> {
116 let data = self.shards[(*key as usize) % NUM_OF_SHARDS].write();
117
118 if let Some(item) = data.get(key) {
119 if conflict != 0 && (conflict != item.conflict) {
120 return None;
121 }
122
123 if !item.expiration.is_zero() && item.expiration.is_expired() {
125 return None;
126 }
127
128 unsafe {
129 let vptr = &mut *item.value.as_ptr();
130 Some(ValueRefMut::new(data, vptr))
131 }
132 } else {
133 None
134 }
135 }
136
137 pub fn try_insert(
138 &self,
139 key: u64,
140 val: V,
141 conflict: u64,
142 expiration: Time,
143 ) -> Result<(), CacheError> {
144 let mut data = self.shards[(key as usize) % NUM_OF_SHARDS].write();
145
146 match data.get(&key) {
147 None => {
148 self.em.try_insert(key, conflict, expiration)?;
151 }
152 Some(sitem) => {
153 if conflict != 0 && (conflict != sitem.conflict) {
156 return Ok(());
157 }
158
159 if !self.validator.should_update(sitem.value.get(), &val) {
160 return Ok(());
161 }
162
163 self.em
164 .try_update(key, conflict, sitem.expiration, expiration)?;
165 }
166 }
167
168 data.insert(
169 key,
170 StoreItem {
171 key,
172 conflict,
173 value: SharedValue::new(val),
174 expiration,
175 },
176 );
177
178 Ok(())
179 }
180
181 pub fn try_update(
182 &self,
183 key: u64,
184 mut val: V,
185 conflict: u64,
186 expiration: Time,
187 ) -> Result<UpdateResult<V>, CacheError> {
188 let mut data = self.shards[(key as usize) % NUM_OF_SHARDS].write();
189 match data.get_mut(&key) {
190 None => Ok(UpdateResult::NotExist(val)),
191 Some(item) => {
192 if conflict != 0 && (conflict != item.conflict) {
193 return Ok(UpdateResult::Conflict(val));
194 }
195
196 if !self.validator.should_update(item.value.get(), &val) {
197 return Ok(UpdateResult::Reject(val));
198 }
199
200 self.em
201 .try_update(key, conflict, item.expiration, expiration)?;
202 mem::swap(&mut val, item.value.get_mut());
203 item.expiration = expiration;
204 Ok(UpdateResult::Update(val))
205 }
206 }
207 }
208
209 pub fn len(&self) -> usize {
210 self.shards.iter().map(|l| l.read().len()).sum()
211 }
212
213 pub fn try_remove(&self, key: &u64, conflict: u64) -> Result<Option<StoreItem<V>>, CacheError> {
214 let mut data = self.shards[(*key as usize) % NUM_OF_SHARDS].write();
215
216 match data.get(key) {
217 None => Ok(None),
218 Some(item) => {
219 if conflict != 0 && (conflict != item.conflict) {
220 return Ok(None);
221 }
222
223 if !item.expiration.is_zero() {
224 self.em.try_remove(key, item.expiration)?;
225 }
226
227 Ok(data.remove(key))
228 }
229 }
230 }
231
232 pub fn expiration(&self, key: &u64) -> Option<Time> {
233 self.shards[((*key) as usize) % NUM_OF_SHARDS]
234 .read()
235 .get(key)
236 .map(|val| val.expiration)
237 }
238
239 #[cfg(feature = "sync")]
240 pub fn try_cleanup<PS: BuildHasher + Clone + 'static>(
241 &self,
242 policy: Arc<crate::policy::LFUPolicy<PS>>,
243 ) -> Result<Vec<CrateItem<V>>, CacheError> {
244 let now = Time::now();
245 Ok(self
246 .em
247 .try_cleanup(now)?
248 .map_or(Vec::with_capacity(0), |m| {
249 m.iter()
250 .filter_map(|(k, v)| {
252 self.expiration(k)
253 .and_then(|t| {
254 if t.is_expired() {
255 let cost = policy.cost(k);
256 policy.remove(k);
257 self.try_remove(k, *v)
258 .map(|maybe_sitem| {
259 maybe_sitem.map(|sitem| CrateItem {
260 val: Some(sitem.value.into_inner()),
261 index: sitem.key,
262 conflict: sitem.conflict,
263 cost,
264 exp: t,
265 })
266 })
267 .ok()
268 } else {
269 None
270 }
271 })
272 .flatten()
273 })
274 .collect()
275 }))
276 }
277
278 #[cfg(feature = "async")]
279 pub fn try_cleanup_async<PS: BuildHasher + Clone + 'static>(
280 &self,
281 policy: Arc<crate::policy::AsyncLFUPolicy<PS>>,
282 ) -> Result<Vec<CrateItem<V>>, CacheError> {
283 let now = Time::now();
284 let items = self.em.try_cleanup(now)?;
285
286 let mut removed_items = Vec::new();
287 if let Some(items) = items {
288 for (k, v) in items.iter() {
289 let expiration = self.expiration(k);
290 if let Some(t) = expiration {
291 if t.is_expired() {
292 let cost = policy.cost(k);
293 policy.remove(k);
294 let removed_item = self.try_remove(k, *v)?;
295 if let Some(sitem) = removed_item {
296 removed_items.push(CrateItem {
297 val: Some(sitem.value.into_inner()),
298 index: sitem.key,
299 conflict: sitem.conflict,
300 cost,
301 exp: t,
302 })
303 }
304 }
305 }
306 }
307 }
308
309 Ok(removed_items)
310 }
311
312 pub fn clear(&self) {
313 self.shards.iter().for_each(|shard| shard.write().clear());
315 }
316
317 pub fn hasher(&self) -> ES {
318 self.em.hasher()
319 }
320
321 pub fn item_size(&self) -> usize {
322 self.store_item_size
323 }
324}
325
// SAFETY: `V` is required to be `Send + Sync`. Within this file, stored values are
// reached through the per-shard `RwLock`s, whose guards are also handed to
// `ValueRef`/`ValueRefMut`.
//
// NOTE(review): these impls assert `Send`/`Sync` for *any* `U`, `SS` and `ES` without
// requiring those types to be `Send`/`Sync` themselves. Soundness therefore depends on
// every validator and hasher used with `ShardedMap` being thread-safe. TODO: confirm this
// (e.g. through supertrait bounds on `UpdateValidator`), or add explicit bounds here if
// callers allow it.
unsafe impl<V: Send + Sync + 'static, U: UpdateValidator<Value = V>, SS: BuildHasher, ES: BuildHasher>
    Send for ShardedMap<V, U, SS, ES>
{
}
unsafe impl<V: Send + Sync + 'static, U: UpdateValidator<Value = V>, SS: BuildHasher, ES: BuildHasher>
    Sync for ShardedMap<V, U, SS, ES>
{
}
334
/// Outcome of `ShardedMap::try_update`; every variant hands a value back to the caller.
pub(crate) enum UpdateResult<V: Send + Sync + 'static> {
    /// The key was not present; carries the new value, unused.
    NotExist(V),
    /// The update validator declined the write; carries the new value, unused.
    Reject(V),
    /// A non-zero conflict did not match the stored entry; carries the new value, unused.
    Conflict(V),
    /// The write happened; carries the previous value that was replaced.
    Update(V),
}
341
#[cfg(test)]
impl<V: Send + Sync + 'static> UpdateResult<V> {
    /// Unwraps the value carried by any variant, regardless of the outcome it reports.
    fn into_inner(self) -> V {
        match self {
            Self::NotExist(v) | Self::Reject(v) | Self::Conflict(v) | Self::Update(v) => v,
        }
    }
}
353
#[cfg(test)]
mod test {
    use crate::store::{ShardedMap, StoreItem};
    use crate::ttl::Time;
    use crate::utils::SharedValue;
    use std::sync::Arc;
    use std::time::Duration;

    #[test]
    fn test_store_item_debug() {
        let item = StoreItem {
            key: 0,
            conflict: 0,
            value: SharedValue::new(3),
            expiration: Time::now(),
        };

        let rendered = format!("{:?}", item);
        eprintln!("{}", rendered);
        // Only the bookkeeping fields are printed (the Debug impl has no `V: Debug` bound).
        assert!(rendered.starts_with("StoreItem { key: 0, conflict: 0, expiration: "));
    }

    #[test]
    fn test_store() {
        let _s: ShardedMap<u64> = ShardedMap::new();
    }

    #[test]
    fn test_store_set_get() {
        let s: ShardedMap<u64> = ShardedMap::new();

        s.try_insert(1, 2, 0, Time::now()).unwrap();
        let val = s.get(&1, 0).unwrap();
        assert_eq!(&2, val.value());
        val.release();

        let mut val = s.get_mut(&1, 0).unwrap();
        *val.value_mut() = 3;
        val.release();

        let v = s.get(&1, 0).unwrap();
        assert_eq!(&3, v.value());
    }

    #[test]
    fn test_concurrent_get_insert() {
        let s = Arc::new(ShardedMap::new());
        let s1 = s.clone();

        let writer = std::thread::spawn(move || {
            s.try_insert(1, 2, 0, Time::now()).unwrap();
        });

        loop {
            // Read the value inside the closure so the shard's read lock is released
            // before yielding.
            match s1.get(&1, 0).map(|v| v.read()) {
                Some(v) => {
                    assert_eq!(v, 2);
                    break;
                }
                None => std::thread::yield_now(),
            }
        }

        // Joining surfaces any panic in the writer thread as a test failure.
        writer.join().unwrap();
    }

    #[test]
    fn test_concurrent_get_mut_insert() {
        let s = Arc::new(ShardedMap::new());
        let s1 = s.clone();

        let reader = std::thread::spawn(move || {
            s.try_insert(1, 2, 0, Time::now()).unwrap();
            loop {
                match s.get(&1, 0).map(|v| v.read()) {
                    Some(7) => break,
                    Some(2) | None => std::thread::yield_now(),
                    Some(_) => panic!("get wrong value"),
                }
            }
        });

        loop {
            match s1.get(&1, 0).map(|v| v.read()) {
                Some(v) => {
                    assert_eq!(v, 2);
                    break;
                }
                None => std::thread::yield_now(),
            }
        }

        s1.get_mut(&1, 0).unwrap().write(7);

        // Without joining, a panic in the reader thread would go unnoticed.
        reader.join().unwrap();
    }

    #[test]
    fn test_store_remove() {
        let s: ShardedMap<u64> = ShardedMap::new();

        s.try_insert(1, 2, 0, Time::now()).unwrap();
        assert_eq!(s.try_remove(&1, 0).unwrap().unwrap().value.into_inner(), 2);
        let v = s.get(&1, 0);
        assert!(v.is_none());
        assert!(s.try_remove(&2, 0).unwrap().is_none());
    }

    #[test]
    fn test_store_update() {
        let s = ShardedMap::new();
        s.try_insert(1, 1, 0, Time::now()).unwrap();
        let v = s.try_update(1, 2, 0, Time::now()).unwrap();
        assert_eq!(v.into_inner(), 1);

        assert_eq!(s.get(&1, 0).unwrap().read(), 2);

        let v = s.try_update(1, 3, 0, Time::now()).unwrap();
        assert_eq!(v.into_inner(), 2);

        assert_eq!(s.get(&1, 0).unwrap().read(), 3);

        let v = s.try_update(2, 2, 0, Time::now()).unwrap();
        assert_eq!(v.into_inner(), 2);
        let v = s.get(&2, 0);
        assert!(v.is_none());
    }

    #[test]
    fn test_store_expiration() {
        let exp = Time::now_with_expiration(Duration::from_secs(1));
        let s = ShardedMap::new();
        s.try_insert(1, 1, 0, exp).unwrap();

        assert_eq!(s.get(&1, 0).unwrap().read(), 1);

        let ttl = s.expiration(&1);
        assert_eq!(exp, ttl.unwrap());

        s.try_remove(&1, 0).unwrap();
        assert!(s.get(&1, 0).is_none());
        let ttl = s.expiration(&1);
        assert!(ttl.is_none());

        assert!(s.expiration(&4340958203495).is_none());
    }

    #[test]
    fn test_store_collision() {
        let s = ShardedMap::new();
        let mut data1 = s.shards[1].write();
        data1.insert(
            1,
            StoreItem {
                key: 1,
                conflict: 0,
                value: SharedValue::new(1),
                expiration: Time::now(),
            },
        );
        drop(data1);
        assert!(s.get(&1, 1).is_none());

        s.try_insert(1, 2, 1, Time::now()).unwrap();
        assert_ne!(s.get(&1, 0).unwrap().read(), 2);

        let v = s.try_update(1, 2, 1, Time::now()).unwrap();
        assert_eq!(v.into_inner(), 2);
        assert_ne!(s.get(&1, 0).unwrap().read(), 2);

        assert!(s.try_remove(&1, 1).unwrap().is_none());
        assert_eq!(s.get(&1, 0).unwrap().read(), 1);
    }
}