1use core::hash::Hash;
4use std::collections::{HashMap, VecDeque};
5use std::sync::Mutex;
6
7use crate::cache::Cache;
8use crate::error::CacheError;
9use crate::util::MutexExt;
10
11pub struct SizedCache<K, V> {
54 max_weight: usize,
55 weigher: fn(&V) -> usize,
56 inner: Mutex<Inner<K, V>>,
57}
58
/// A cached value paired with the weight that was computed for it when stored.
struct Entry<V> {
    /// The cached value handed back (cloned) by lookups.
    value: V,
    /// Weight reported by the cache's weigher for `value`; subtracted from the
    /// running total when this entry is removed or evicted.
    weight: usize,
}
63
64struct Inner<K, V> {
65 map: HashMap<K, Entry<V>>,
66 order: VecDeque<K>,
68 total_weight: usize,
69}
70
71impl<K, V> SizedCache<K, V>
72where
73 K: Eq + Hash + Clone,
74 V: Clone,
75{
76 pub fn new(max_weight: usize, weigher: fn(&V) -> usize) -> Result<Self, CacheError> {
91 if max_weight == 0 {
92 return Err(CacheError::InvalidCapacity);
93 }
94 Ok(Self {
95 max_weight,
96 weigher,
97 inner: Mutex::new(Inner {
98 map: HashMap::new(),
99 order: VecDeque::new(),
100 total_weight: 0,
101 }),
102 })
103 }
104
105 pub fn max_weight(&self) -> usize {
107 self.max_weight
108 }
109
110 pub fn total_weight(&self) -> usize {
112 self.inner.lock_recover().total_weight
113 }
114}
115
116impl<K, V> Cache<K, V> for SizedCache<K, V>
117where
118 K: Eq + Hash + Clone,
119 V: Clone,
120{
121 fn get(&self, key: &K) -> Option<V> {
122 let mut inner = self.inner.lock_recover();
123 let value = inner.map.get(key)?.value.clone();
124 promote(&mut inner.order, key);
125 Some(value)
126 }
127
128 fn insert(&self, key: K, value: V) -> Option<V> {
129 let new_weight = (self.weigher)(&value);
130 if new_weight > self.max_weight {
133 return None;
134 }
135
136 let mut inner = self.inner.lock_recover();
137
138 if let Some(existing) = inner.map.get_mut(&key) {
140 let old_value = core::mem::replace(&mut existing.value, value);
141 let old_weight = existing.weight;
142 existing.weight = new_weight;
143 inner.total_weight = inner
146 .total_weight
147 .saturating_add(new_weight)
148 .saturating_sub(old_weight);
149 promote(&mut inner.order, &key);
150 evict_until_fits(&mut inner, self.max_weight);
151 return Some(old_value);
152 }
153
154 let projected_total = inner.total_weight.saturating_add(new_weight);
156 if projected_total > self.max_weight {
157 evict_until_fits_for_new(&mut inner, self.max_weight, new_weight);
158 }
159 inner.order.push_front(key.clone());
160 let _ = inner.map.insert(
161 key,
162 Entry {
163 value,
164 weight: new_weight,
165 },
166 );
167 inner.total_weight = inner.total_weight.saturating_add(new_weight);
168 None
169 }
170
171 fn remove(&self, key: &K) -> Option<V> {
172 let mut inner = self.inner.lock_recover();
173 let entry = inner.map.remove(key)?;
174 inner.total_weight = inner.total_weight.saturating_sub(entry.weight);
175 if let Some(pos) = inner.order.iter().position(|k| k == key) {
176 let _ = inner.order.remove(pos);
177 }
178 Some(entry.value)
179 }
180
181 fn contains_key(&self, key: &K) -> bool {
182 self.inner.lock_recover().map.contains_key(key)
183 }
184
185 fn len(&self) -> usize {
186 self.inner.lock_recover().map.len()
187 }
188
189 fn clear(&self) {
190 let mut inner = self.inner.lock_recover();
191 inner.map.clear();
192 inner.order.clear();
193 inner.total_weight = 0;
194 }
195
196 fn capacity(&self) -> usize {
199 self.max_weight
200 }
201}
202
/// Moves the first occurrence of `key` in `order` to the front, marking it most
/// recently used. Does nothing when `key` is absent.
///
/// Runs a linear scan of `order` followed by a `VecDeque::remove`.
fn promote<K: Eq>(order: &mut VecDeque<K>, key: &K) {
    let Some(idx) = order.iter().position(|candidate| candidate == key) else {
        return;
    };
    if let Some(found) = order.remove(idx) {
        order.push_front(found);
    }
}
210
211fn evict_until_fits<K, V>(inner: &mut Inner<K, V>, max_weight: usize)
214where
215 K: Eq + Hash,
216{
217 while inner.total_weight > max_weight {
218 let Some(victim_key) = inner.order.pop_back() else {
219 break;
220 };
221 if let Some(victim) = inner.map.remove(&victim_key) {
222 inner.total_weight = inner.total_weight.saturating_sub(victim.weight);
223 }
224 }
225}
226
227fn evict_until_fits_for_new<K, V>(inner: &mut Inner<K, V>, max_weight: usize, incoming: usize)
231where
232 K: Eq + Hash,
233{
234 while inner.total_weight.saturating_add(incoming) > max_weight {
235 let Some(victim_key) = inner.order.pop_back() else {
236 break;
237 };
238 if let Some(victim) = inner.map.remove(&victim_key) {
239 inner.total_weight = inner.total_weight.saturating_sub(victim.weight);
240 }
241 }
242}