stretto/
store.rs

1use crate::ttl::{ExpirationMap, Time};
2use crate::utils::{change_lifetime_const, SharedValue, ValueRef, ValueRefMut};
3use crate::{CacheError, DefaultUpdateValidator, Item as CrateItem, UpdateValidator};
4use parking_lot::RwLock;
5use std::collections::hash_map::RandomState;
6use std::collections::HashMap;
7use std::fmt::{Debug, Formatter};
8use std::hash::BuildHasher;
9use std::mem;
10use std::sync::Arc;
11
/// Number of independently locked shards in a [`ShardedMap`]; a key is placed in
/// shard `key % NUM_OF_SHARDS`.
const NUM_OF_SHARDS: usize = 1 << 8;
13
14pub(crate) struct StoreItem<V> {
15    pub(crate) key: u64,
16    pub(crate) conflict: u64,
17    pub(crate) value: SharedValue<V>,
18    pub(crate) expiration: Time,
19}
20
21impl<V> Debug for StoreItem<V> {
22    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("StoreItem")
24            .field("key", &self.key)
25            .field("conflict", &self.conflict)
26            .field("expiration", &self.expiration)
27            .finish()
28    }
29}
30
31type Shards<V, SS> = Box<[RwLock<HashMap<u64, StoreItem<V>, SS>>; NUM_OF_SHARDS]>;
32
33pub(crate) struct ShardedMap<V, U = DefaultUpdateValidator<V>, SS = RandomState, ES = RandomState> {
34    shards: Shards<V, SS>,
35    em: ExpirationMap<ES>,
36    store_item_size: usize,
37    validator: U,
38}
39
40impl<V: Send + Sync + 'static> ShardedMap<V> {
41    #[allow(dead_code)]
42    pub fn new() -> Self {
43        Self::with_validator(ExpirationMap::new(), DefaultUpdateValidator::default())
44    }
45}
46
47impl<V: Send + Sync + 'static, U: UpdateValidator<Value = V>> ShardedMap<V, U> {
48    #[allow(dead_code)]
49    pub fn with_validator(em: ExpirationMap<RandomState>, validator: U) -> Self {
50        let shards = Box::new(
51            (0..NUM_OF_SHARDS)
52                .map(|_| RwLock::new(HashMap::new()))
53                .collect::<Vec<_>>()
54                .try_into()
55                .unwrap(),
56        );
57
58        let size = mem::size_of::<StoreItem<V>>();
59        Self {
60            shards,
61            em,
62            store_item_size: size,
63            validator,
64        }
65    }
66}
67
68impl<
69        V: Send + Sync + 'static,
70        U: UpdateValidator<Value = V>,
71        SS: BuildHasher + Clone + 'static,
72        ES: BuildHasher + Clone + 'static,
73    > ShardedMap<V, U, SS, ES>
74{
75    pub fn with_validator_and_hasher(em: ExpirationMap<ES>, validator: U, hasher: SS) -> Self {
76        let shards = Box::new(
77            (0..NUM_OF_SHARDS)
78                .map(|_| RwLock::new(HashMap::with_hasher(hasher.clone())))
79                .collect::<Vec<_>>()
80                .try_into()
81                .unwrap(),
82        );
83
84        let size = mem::size_of::<StoreItem<V>>();
85        Self {
86            shards,
87            em,
88            store_item_size: size,
89            validator,
90        }
91    }
92
93    pub fn get(&self, key: &u64, conflict: u64) -> Option<ValueRef<'_, V, SS>> {
94        let data = self.shards[(*key as usize) % NUM_OF_SHARDS].read();
95
96        if let Some(item) = data.get(key) {
97            if conflict != 0 && (conflict != item.conflict) {
98                return None;
99            }
100
101            // Handle expired items
102            if !item.expiration.is_zero() && item.expiration.is_expired() {
103                return None;
104            }
105
106            unsafe {
107                let vptr = change_lifetime_const(item);
108                Some(ValueRef::new(data, vptr))
109            }
110        } else {
111            None
112        }
113    }
114
115    pub fn get_mut(&self, key: &u64, conflict: u64) -> Option<ValueRefMut<'_, V, SS>> {
116        let data = self.shards[(*key as usize) % NUM_OF_SHARDS].write();
117
118        if let Some(item) = data.get(key) {
119            if conflict != 0 && (conflict != item.conflict) {
120                return None;
121            }
122
123            // Handle expired items
124            if !item.expiration.is_zero() && item.expiration.is_expired() {
125                return None;
126            }
127
128            unsafe {
129                let vptr = &mut *item.value.as_ptr();
130                Some(ValueRefMut::new(data, vptr))
131            }
132        } else {
133            None
134        }
135    }
136
137    pub fn try_insert(
138        &self,
139        key: u64,
140        val: V,
141        conflict: u64,
142        expiration: Time,
143    ) -> Result<(), CacheError> {
144        let mut data = self.shards[(key as usize) % NUM_OF_SHARDS].write();
145
146        match data.get(&key) {
147            None => {
148                // The value is not in the map already. There's no need to return anything.
149                // Simply add the expiration map.
150                self.em.try_insert(key, conflict, expiration)?;
151            }
152            Some(sitem) => {
153                // The item existed already. We need to check the conflict key and reject the
154                // update if they do not match. Only after that the expiration map is updated.
155                if conflict != 0 && (conflict != sitem.conflict) {
156                    return Ok(());
157                }
158
159                if !self.validator.should_update(sitem.value.get(), &val) {
160                    return Ok(());
161                }
162
163                self.em
164                    .try_update(key, conflict, sitem.expiration, expiration)?;
165            }
166        }
167
168        data.insert(
169            key,
170            StoreItem {
171                key,
172                conflict,
173                value: SharedValue::new(val),
174                expiration,
175            },
176        );
177
178        Ok(())
179    }
180
181    pub fn try_update(
182        &self,
183        key: u64,
184        mut val: V,
185        conflict: u64,
186        expiration: Time,
187    ) -> Result<UpdateResult<V>, CacheError> {
188        let mut data = self.shards[(key as usize) % NUM_OF_SHARDS].write();
189        match data.get_mut(&key) {
190            None => Ok(UpdateResult::NotExist(val)),
191            Some(item) => {
192                if conflict != 0 && (conflict != item.conflict) {
193                    return Ok(UpdateResult::Conflict(val));
194                }
195
196                if !self.validator.should_update(item.value.get(), &val) {
197                    return Ok(UpdateResult::Reject(val));
198                }
199
200                self.em
201                    .try_update(key, conflict, item.expiration, expiration)?;
202                mem::swap(&mut val, item.value.get_mut());
203                item.expiration = expiration;
204                Ok(UpdateResult::Update(val))
205            }
206        }
207    }
208
209    pub fn len(&self) -> usize {
210        self.shards.iter().map(|l| l.read().len()).sum()
211    }
212
213    pub fn try_remove(&self, key: &u64, conflict: u64) -> Result<Option<StoreItem<V>>, CacheError> {
214        let mut data = self.shards[(*key as usize) % NUM_OF_SHARDS].write();
215
216        match data.get(key) {
217            None => Ok(None),
218            Some(item) => {
219                if conflict != 0 && (conflict != item.conflict) {
220                    return Ok(None);
221                }
222
223                if !item.expiration.is_zero() {
224                    self.em.try_remove(key, item.expiration)?;
225                }
226
227                Ok(data.remove(key))
228            }
229        }
230    }
231
232    pub fn expiration(&self, key: &u64) -> Option<Time> {
233        self.shards[((*key) as usize) % NUM_OF_SHARDS]
234            .read()
235            .get(key)
236            .map(|val| val.expiration)
237    }
238
239    #[cfg(feature = "sync")]
240    pub fn try_cleanup<PS: BuildHasher + Clone + 'static>(
241        &self,
242        policy: Arc<crate::policy::LFUPolicy<PS>>,
243    ) -> Result<Vec<CrateItem<V>>, CacheError> {
244        let now = Time::now();
245        Ok(self
246            .em
247            .try_cleanup(now)?
248            .map_or(Vec::with_capacity(0), |m| {
249                m.iter()
250                    // Sanity check. Verify that the store agrees that this key is expired.
251                    .filter_map(|(k, v)| {
252                        self.expiration(k)
253                            .and_then(|t| {
254                                if t.is_expired() {
255                                    let cost = policy.cost(k);
256                                    policy.remove(k);
257                                    self.try_remove(k, *v)
258                                        .map(|maybe_sitem| {
259                                            maybe_sitem.map(|sitem| CrateItem {
260                                                val: Some(sitem.value.into_inner()),
261                                                index: sitem.key,
262                                                conflict: sitem.conflict,
263                                                cost,
264                                                exp: t,
265                                            })
266                                        })
267                                        .ok()
268                                } else {
269                                    None
270                                }
271                            })
272                            .flatten()
273                    })
274                    .collect()
275            }))
276    }
277
278    #[cfg(feature = "async")]
279    pub fn try_cleanup_async<PS: BuildHasher + Clone + 'static>(
280        &self,
281        policy: Arc<crate::policy::AsyncLFUPolicy<PS>>,
282    ) -> Result<Vec<CrateItem<V>>, CacheError> {
283        let now = Time::now();
284        let items = self.em.try_cleanup(now)?;
285
286        let mut removed_items = Vec::new();
287        if let Some(items) = items {
288            for (k, v) in items.iter() {
289                let expiration = self.expiration(k);
290                if let Some(t) = expiration {
291                    if t.is_expired() {
292                        let cost = policy.cost(k);
293                        policy.remove(k);
294                        let removed_item = self.try_remove(k, *v)?;
295                        if let Some(sitem) = removed_item {
296                            removed_items.push(CrateItem {
297                                val: Some(sitem.value.into_inner()),
298                                index: sitem.key,
299                                conflict: sitem.conflict,
300                                cost,
301                                exp: t,
302                            })
303                        }
304                    }
305                }
306            }
307        }
308
309        Ok(removed_items)
310    }
311
312    pub fn clear(&self) {
313        // TODO: item call back
314        self.shards.iter().for_each(|shard| shard.write().clear());
315    }
316
317    pub fn hasher(&self) -> ES {
318        self.em.hasher()
319    }
320
321    pub fn item_size(&self) -> usize {
322        self.store_item_size
323    }
324}
325
326unsafe impl<V: Send + Sync + 'static, U: UpdateValidator<Value = V>, SS: BuildHasher, ES: BuildHasher>
327    Send for ShardedMap<V, U, SS, ES>
328{
329}
330unsafe impl<V: Send + Sync + 'static, U: UpdateValidator<Value = V>, SS: BuildHasher, ES: BuildHasher>
331    Sync for ShardedMap<V, U, SS, ES>
332{
333}
334
/// Outcome of `ShardedMap::try_update`. Every variant hands a value back to the caller,
/// so nothing passed in is dropped silently.
///
/// The type itself carries no trait bounds; the impls that need them declare them.
pub(crate) enum UpdateResult<V> {
    /// The key was absent; carries the new value that was not stored.
    NotExist(V),
    /// The update validator declined; carries the new value that was not stored.
    Reject(V),
    /// A non-zero conflict key did not match; carries the new value that was not stored.
    Conflict(V),
    /// The update was applied; carries the previous value that was replaced.
    Update(V),
}
341
#[cfg(test)]
impl<V: Send + Sync + 'static> UpdateResult<V> {
    /// Test helper: unwraps the value carried by whichever variant this is.
    fn into_inner(self) -> V {
        match self {
            UpdateResult::NotExist(inner)
            | UpdateResult::Reject(inner)
            | UpdateResult::Conflict(inner)
            | UpdateResult::Update(inner) => inner,
        }
    }
}
353
#[cfg(test)]
mod test {
    use crate::store::{ShardedMap, StoreItem};
    use crate::ttl::Time;
    use crate::utils::SharedValue;
    use std::sync::Arc;
    use std::time::Duration;

    #[test]
    fn test_store_item_debug() {
        let item = StoreItem {
            key: 0,
            conflict: 0,
            value: SharedValue::new(3),
            expiration: Time::now(),
        };

        eprintln!("{:?}", item);
    }

    #[test]
    fn test_store() {
        let _s: ShardedMap<u64> = ShardedMap::new();
    }

    #[test]
    fn test_store_set_get() {
        let s: ShardedMap<u64> = ShardedMap::new();

        s.try_insert(1, 2, 0, Time::now()).unwrap();
        let val = s.get(&1, 0).unwrap();
        assert_eq!(&2, val.value());
        val.release();

        let mut val = s.get_mut(&1, 0).unwrap();
        *val.value_mut() = 3;
        val.release();

        let v = s.get(&1, 0).unwrap();
        assert_eq!(&3, v.value());
    }

    #[test]
    fn test_concurrent_get_insert() {
        let s = Arc::new(ShardedMap::new());
        let s1 = s.clone();

        let writer = std::thread::spawn(move || {
            s.try_insert(1, 2, 0, Time::now()).unwrap();
        });

        loop {
            match s1.get(&1, 0) {
                None => continue,
                Some(val) => {
                    assert_eq!(val.read(), 2);
                    break;
                }
            }
        }

        // Join so that a panic in the writer thread fails this test instead of being lost.
        writer.join().unwrap();
    }

    #[test]
    fn test_concurrent_get_mut_insert() {
        let s = Arc::new(ShardedMap::new());
        let s1 = s.clone();

        let reader = std::thread::spawn(move || {
            s.try_insert(1, 2, 0, Time::now()).unwrap();
            loop {
                match s.get(&1, 0) {
                    None => continue,
                    Some(val) => {
                        let val = val.read();
                        if val == 2 {
                            continue;
                        } else if val == 7 {
                            break;
                        } else {
                            panic!("get wrong value")
                        }
                    }
                }
            }
        });

        loop {
            match s1.get(&1, 0) {
                None => continue,
                Some(val) => {
                    assert_eq!(val.read(), 2);
                    break;
                }
            }
        }

        s1.get_mut(&1, 0).unwrap().write(7);

        // Without this join the reader's "get wrong value" check could never fail the test.
        reader.join().unwrap();
    }

    #[test]
    fn test_store_remove() {
        let s: ShardedMap<u64> = ShardedMap::new();

        s.try_insert(1, 2, 0, Time::now()).unwrap();
        assert_eq!(s.try_remove(&1, 0).unwrap().unwrap().value.into_inner(), 2);
        let v = s.get(&1, 0);
        assert!(v.is_none());
        assert!(s.try_remove(&2, 0).unwrap().is_none());
    }

    #[test]
    fn test_store_update() {
        let s = ShardedMap::new();
        s.try_insert(1, 1, 0, Time::now()).unwrap();
        let v = s.try_update(1, 2, 0, Time::now()).unwrap();
        assert_eq!(v.into_inner(), 1);

        assert_eq!(s.get(&1, 0).unwrap().read(), 2);

        let v = s.try_update(1, 3, 0, Time::now()).unwrap();
        assert_eq!(v.into_inner(), 2);

        assert_eq!(s.get(&1, 0).unwrap().read(), 3);

        let v = s.try_update(2, 2, 0, Time::now()).unwrap();
        assert_eq!(v.into_inner(), 2);
        let v = s.get(&2, 0);
        assert!(v.is_none());
    }

    #[test]
    fn test_store_expiration() {
        let exp = Time::now_with_expiration(Duration::from_secs(1));
        let s = ShardedMap::new();
        s.try_insert(1, 1, 0, exp).unwrap();

        assert_eq!(s.get(&1, 0).unwrap().read(), 1);

        let ttl = s.expiration(&1);
        assert_eq!(exp, ttl.unwrap());

        s.try_remove(&1, 0).unwrap();
        assert!(s.get(&1, 0).is_none());
        let ttl = s.expiration(&1);
        assert!(ttl.is_none());

        assert!(s.expiration(&4340958203495).is_none());
    }

    #[test]
    fn test_store_collision() {
        let s = ShardedMap::new();
        // Key 1 lives in shard 1 (1 % NUM_OF_SHARDS).
        let mut data1 = s.shards[1].write();
        data1.insert(
            1,
            StoreItem {
                key: 1,
                conflict: 0,
                value: SharedValue::new(1),
                expiration: Time::now(),
            },
        );
        drop(data1);
        assert!(s.get(&1, 1).is_none());

        s.try_insert(1, 2, 1, Time::now()).unwrap();
        assert_ne!(s.get(&1, 0).unwrap().read(), 2);

        let v = s.try_update(1, 2, 1, Time::now()).unwrap();
        assert_eq!(v.into_inner(), 2);
        assert_ne!(s.get(&1, 0).unwrap().read(), 2);

        assert!(s.try_remove(&1, 1).unwrap().is_none());
        assert_eq!(s.get(&1, 0).unwrap().read(), 1);
    }
}