1use derive_more::Debug;
2use std::borrow::Borrow;
3use std::collections::hash_map::*;
4use std::collections::HashMap;
5use std::collections::TryReserveError;
6use std::hash::Hash;
7use std::iter::{FromIterator, IntoIterator};
8use std::ops::{Index, IndexMut};
9
10use crate::DefaultFn;
11
/// A `HashMap` wrapper that yields a default value for keys it does not contain.
///
/// Shared lookups (`get`, `Index`) return a reference to the cached `default`
/// without inserting anything. Mutable access (`get_mut`, `IndexMut`) first
/// inserts a value produced by `default_fn` and then hands out the reference.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DefaultHashMap<K: Eq + Hash, V> {
    /// The entries actually stored in the map.
    map: HashMap<K, V>,
    /// Value returned by shared lookups of keys that are not stored.
    default: V,
    /// Produces a fresh value whenever a missing key is inserted on demand.
    ///
    /// Excluded from `Debug` output and from (de)serialization.
    // NOTE(review): because serde skips this field, a deserialized map gets
    // whatever `Default` yields for `Box<dyn DefaultFn<V>>` (defined elsewhere),
    // which may disagree with the deserialized `default` value — confirm.
    #[debug(skip)]
    #[cfg_attr(feature = "with-serde", serde(skip))]
    default_fn: Box<dyn DefaultFn<V>>,
}
22
23impl<K: Eq + Hash, V: PartialEq> PartialEq for DefaultHashMap<K, V> {
24 fn eq(&self, other: &Self) -> bool {
25 self.map == other.map && self.default == other.default
26 }
27}
28
29impl<K: Eq + Hash, V: Eq> Eq for DefaultHashMap<K, V> {}
30
31impl<K: Eq + Hash, V: Default> DefaultHashMap<K, V> {
32 pub fn new() -> DefaultHashMap<K, V> {
37 DefaultHashMap {
38 map: HashMap::default(),
39 default_fn: Box::new(|| V::default()),
40 default: V::default(),
41 }
42 }
43}
44
45impl<K: Eq + Hash, V: Default> Default for DefaultHashMap<K, V> {
46 fn default() -> DefaultHashMap<K, V> {
48 DefaultHashMap::new()
49 }
50}
51
52impl<K: Eq + Hash, V: Default> From<HashMap<K, V>> for DefaultHashMap<K, V> {
53 fn from(map: HashMap<K, V>) -> DefaultHashMap<K, V> {
59 DefaultHashMap {
60 map,
61 default_fn: Box::new(|| V::default()),
62 default: V::default(),
63 }
64 }
65}
66
67impl<K: Eq + Hash, V> From<DefaultHashMap<K, V>> for HashMap<K, V> {
68 fn from(default_map: DefaultHashMap<K, V>) -> HashMap<K, V> {
71 default_map.map
72 }
73}
74
75impl<K: Eq + Hash, V: Clone + 'static> DefaultHashMap<K, V> {
76 pub fn with_default(default: V) -> DefaultHashMap<K, V> {
80 DefaultHashMap {
81 map: HashMap::new(),
82 default: default.clone(),
83 default_fn: Box::new(move || default.clone()),
84 }
85 }
86
87 pub fn from_map_with_default(map: HashMap<K, V>, default: V) -> DefaultHashMap<K, V> {
91 DefaultHashMap {
92 map,
93 default: default.clone(),
94 default_fn: Box::new(move || default.clone()),
95 }
96 }
97
98 pub fn set_default(&mut self, new_default: V) {
100 self.default = new_default.clone();
101 self.default_fn = Box::new(move || new_default.clone());
102 }
103}
104
105impl<K: Eq + Hash, V> DefaultHashMap<K, V> {
106 pub fn get<Q, QB: Borrow<Q>>(&self, key: QB) -> &V
111 where
112 K: Borrow<Q>,
113 Q: ?Sized + Hash + Eq,
114 {
115 self.map.get(key.borrow()).unwrap_or(&self.default)
116 }
117
118 pub fn get_default(&self) -> V {
124 self.default_fn.call()
125 }
126
127 pub fn with_fn(default_fn: impl DefaultFn<V> + 'static) -> DefaultHashMap<K, V> {
131 DefaultHashMap {
132 map: HashMap::new(),
133 default: default_fn.call(),
134 default_fn: Box::new(default_fn),
135 }
136 }
137
138 pub fn from_map_with_fn(
142 map: HashMap<K, V>,
143 default_fn: impl DefaultFn<V> + 'static,
144 ) -> DefaultHashMap<K, V> {
145 DefaultHashMap {
146 map,
147 default: default_fn.call(),
148 default_fn: Box::new(default_fn),
149 }
150 }
151}
152
153impl<K: Eq + Hash, V> DefaultHashMap<K, V> {
154 pub fn get_mut(&mut self, key: K) -> &mut V {
160 let entry = self.map.entry(key);
161 match entry {
162 Entry::Occupied(occupied) => occupied.into_mut(),
163 Entry::Vacant(vacant) => vacant.insert(self.default_fn.call()),
164 }
165 }
166}
167
168impl<K: Eq + Hash, KB: Borrow<K>, V> Index<KB> for DefaultHashMap<K, V> {
171 type Output = V;
172
173 fn index(&self, index: KB) -> &V {
174 self.get(index)
175 }
176}
177
178impl<K: Eq + Hash, V> IndexMut<K> for DefaultHashMap<K, V> {
181 #[inline]
182 fn index_mut(&mut self, index: K) -> &mut V {
183 self.get_mut(index)
184 }
185}
186
187impl<K: Eq + Hash, V> DefaultHashMap<K, V> {
192 pub fn capacity(&self) -> usize {
193 self.map.capacity()
194 }
195 #[inline]
196 pub fn keys(&self) -> Keys<K, V> {
197 self.map.keys()
198 }
199 #[inline]
200 pub fn into_keys(self) -> IntoKeys<K, V> {
201 self.map.into_keys()
202 }
203 #[inline]
204 pub fn values(&self) -> Values<K, V> {
205 self.map.values()
206 }
207 #[inline]
208 pub fn values_mut(&mut self) -> ValuesMut<K, V> {
209 self.map.values_mut()
210 }
211 #[inline]
212 pub fn into_values(self) -> IntoValues<K, V> {
213 self.map.into_values()
214 }
215 #[inline]
216 pub fn iter(&self) -> Iter<K, V> {
217 self.map.iter()
218 }
219 #[inline]
220 pub fn iter_mut(&mut self) -> IterMut<K, V> {
221 self.map.iter_mut()
222 }
223 #[inline]
224 pub fn len(&self) -> usize {
225 self.map.len()
226 }
227 #[inline]
228 pub fn is_empty(&self) -> bool {
229 self.map.is_empty()
230 }
231 #[inline]
232 pub fn drain(&mut self) -> Drain<K, V> {
233 self.map.drain()
234 }
235 #[inline]
236 pub fn retain<RF>(&mut self, f: RF)
237 where
238 RF: FnMut(&K, &mut V) -> bool,
239 {
240 self.map.retain(f)
241 }
242 #[inline]
243 pub fn clear(&mut self) {
244 self.map.clear()
245 }
246 #[inline]
247 pub fn reserve(&mut self, additional: usize) {
248 self.map.reserve(additional)
249 }
250 #[inline]
251 pub fn try_reserve(&mut self, additional: usize) -> Result<(), TryReserveError> {
252 self.map.try_reserve(additional)
253 }
254 #[inline]
255 pub fn shrink_to_fit(&mut self) {
256 self.map.shrink_to_fit()
257 }
258 #[inline]
259 pub fn shrink_to(&mut self, min_capacity: usize) {
260 self.map.shrink_to(min_capacity);
261 }
262 #[inline]
263 pub fn entry(&mut self, key: K) -> Entry<K, V> {
264 self.map.entry(key)
265 }
266
267 #[inline]
268 pub fn insert(&mut self, k: K, v: V) -> Option<V> {
269 self.map.insert(k, v)
270 }
271 #[inline]
272 pub fn contains_key<Q>(&self, k: &Q) -> bool
273 where
274 K: Borrow<Q>,
275 Q: ?Sized + Hash + Eq,
276 {
277 self.map.contains_key(k)
278 }
279 #[inline]
280 pub fn remove<Q>(&mut self, k: &Q) -> Option<V>
281 where
282 K: Borrow<Q>,
283 Q: ?Sized + Hash + Eq,
284 {
285 self.map.remove(k)
286 }
287 #[inline]
288 pub fn remove_entry<Q: ?Sized>(&mut self, k: &Q) -> Option<(K, V)>
289 where
290 K: Borrow<Q>,
291 Q: Hash + Eq,
292 {
293 self.map.remove_entry(k)
294 }
295}
296impl<K: Eq + Hash, V: Default> FromIterator<(K, V)> for DefaultHashMap<K, V> {
299 fn from_iter<I>(iter: I) -> Self
300 where
301 I: IntoIterator<Item = (K, V)>,
302 {
303 Self {
304 map: HashMap::from_iter(iter),
305 default: V::default(),
306 default_fn: Box::new(|| V::default()),
307 }
308 }
309}
310
/// Builds a `DefaultHashMap` from literal entries.
///
/// ```ignore
/// let a: DefaultHashMap<i32, i32> = defaulthashmap! { 1 => 2, 3 => 4 }; // default: V::default()
/// let b: DefaultHashMap<i32, i32> = defaulthashmap! { 7, 1 => 2 };      // default: 7
/// let c: DefaultHashMap<i32, i32> = defaulthashmap! { 7 };              // no entries, default: 7
/// ```
///
/// Trailing commas are accepted in every form. Rules starting with `@` are
/// internal helpers and not part of the public interface.
#[macro_export]
macro_rules! defaulthashmap {
    // Internal: expands to `()` for any token tree, so entries can be counted.
    (@single $($x:tt)*) => (());
    // Internal: number of comma-separated expressions, as the length of a unit slice.
    (@count $($rest:expr),*) => (<[()]>::len(&[$($crate::defaulthashmap!(@single $rest)),*]));
    // Internal: a plain `HashMap` pre-sized for the given entries.
    (@hashmap $($key:expr => $value:expr),*) => {
        {
            let _cap = $crate::defaulthashmap!(@count $($key),*);
            let mut _map = ::std::collections::HashMap::with_capacity(_cap);
            $(
                _map.insert($key, $value);
            )*
            _map
        }
    };

    // Entries only, with a trailing comma.
    ($($key:expr => $value:expr,)+) => { $crate::defaulthashmap!($($key => $value),+) };
    // Entries only (possibly none); the default is `V::default()`.
    ($($key:expr => $value:expr),*) => {
        {
            let _map = $crate::defaulthashmap!(@hashmap $($key => $value),*);
            $crate::DefaultHashMap::from(_map)
        }
    };

    // Explicit default followed by entries, with a trailing comma.
    ($default:expr$(, $key:expr => $value:expr)+ ,) => { $crate::defaulthashmap!($default, $($key => $value),+) };
    // Explicit default followed by zero or more entries.
    ($default:expr$(, $key:expr => $value:expr)*) => {
        {
            let _map = $crate::defaulthashmap!(@hashmap $($key => $value),*);
            $crate::DefaultHashMap::from_map_with_default(_map, $default)
        }
    };
    // Explicit default with a trailing comma and no entries, e.g. `defaulthashmap!{5,}`.
    // Placed last so it can only catch input that every earlier rule rejected.
    ($default:expr ,) => { $crate::defaulthashmap!($default) };
}
370
/// Unit tests for `DefaultHashMap` and the `defaulthashmap!` macro.
#[cfg(test)]
mod tests {
    use super::DefaultHashMap;
    use std::collections::HashMap;

    #[test]
    fn macro_test() {
        let macro_map: DefaultHashMap<i32, i32> = defaulthashmap! {};
        let normal_map = DefaultHashMap::<i32, i32>::default();
        assert_eq!(macro_map, normal_map);

        let macro_map: DefaultHashMap<_, _> = defaulthashmap! {
            1 => 2,
            2 => 3,
        };
        let mut normal_map = DefaultHashMap::<_, _>::default();
        normal_map[1] = 2;
        normal_map[2] = 3;
        assert_eq!(macro_map, normal_map);

        let macro_map: DefaultHashMap<i32, i32> = defaulthashmap! {5};
        let normal_map = DefaultHashMap::<i32, i32>::with_default(5);
        assert_eq!(macro_map, normal_map);

        let macro_map: DefaultHashMap<_, _> = defaulthashmap! {
            5,
            1 => 2,
            2 => 3,
        };
        let mut normal_map = DefaultHashMap::<_, _>::with_default(5);
        normal_map[1] = 2;
        normal_map[2] = 3;
        assert_eq!(macro_map, normal_map);
    }

    #[test]
    fn add() {
        let mut map: DefaultHashMap<i32, i32> = DefaultHashMap::default();
        *map.get_mut(0) += 1;
        map[1] += 4;
        map[2] = map[0] + map.get(&1);
        assert_eq!(*map.get(0), 1);
        assert_eq!(*map.get(&0), 1);
        assert_eq!(map[0], 1);
        assert_eq!(map[&0], 1);
        assert_eq!(*map.get(&1), 4);
        assert_eq!(*map.get(&2), 5);
        assert_eq!(*map.get(999), 0);
        assert_eq!(*map.get(&999), 0);
        assert_eq!(map[999], 0);
        assert_eq!(map[&999], 0);
    }

    #[test]
    fn get_does_not_insert() {
        let map: DefaultHashMap<i32, i32> = DefaultHashMap::with_default(7);
        assert_eq!(*map.get(1), 7);
        assert_eq!(map[2], 7);
        assert!(map.is_empty());

        let mut map = map;
        map.get_mut(3);
        assert_eq!(map.len(), 1);
        assert_eq!(map[3], 7);
    }

    #[test]
    fn counter() {
        let nums = [1, 4, 3, 3, 4, 2, 4];
        let mut counts: DefaultHashMap<i32, i32> = DefaultHashMap::default();
        for num in nums.iter() {
            counts[*num] += 1;
        }

        assert_eq!(1, counts[1]);
        assert_eq!(1, counts[2]);
        assert_eq!(2, counts[3]);
        assert_eq!(3, counts[4]);
        assert_eq!(0, counts[5]);
    }

    #[test]
    fn change_default() {
        let mut numbers: DefaultHashMap<i32, String> =
            DefaultHashMap::with_default("Mexico".to_string());

        assert_eq!("Mexico", numbers.get_mut(1));
        assert_eq!("Mexico", numbers.get_mut(2));
        assert_eq!("Mexico", numbers[3]);

        numbers.set_default("Cybernetics".to_string());
        assert_eq!("Mexico", numbers[1]);
        assert_eq!("Mexico", numbers[2]);
        assert_eq!("Cybernetics", numbers[3]);
        assert_eq!("Cybernetics", numbers[4]);
        assert_eq!("Cybernetics", numbers[5]);
    }

    #[test]
    fn synonyms() {
        let synonym_tuples = [
            ("nice", "sweet"),
            ("sweet", "candy"),
            ("nice", "entertaining"),
            ("nice", "good"),
            ("entertaining", "absorbing"),
        ];

        let mut synonym_map: DefaultHashMap<&str, Vec<&str>> = DefaultHashMap::default();

        for &(l, r) in synonym_tuples.iter() {
            synonym_map[l].push(r);
            synonym_map[r].push(l);
        }

        println!("{:#?}", synonym_map);
        assert_eq!(synonym_map["good"], vec!["nice"]);
        assert_eq!(synonym_map["nice"], vec!["sweet", "entertaining", "good"]);
        assert_eq!(synonym_map["evil"], Vec::<&str>::new());
    }

    #[derive(Clone)]
    struct Clonable;

    #[derive(Default, Clone)]
    struct DefaultableValue;

    #[derive(Hash, Eq, PartialEq)]
    struct Hashable(i32);

    #[test]
    fn minimal_derives() {
        let _: DefaultHashMap<Hashable, Clonable> = DefaultHashMap::with_default(Clonable);
        let _: DefaultHashMap<Hashable, DefaultableValue> = DefaultHashMap::default();
    }

    #[test]
    fn from() {
        let normal: HashMap<i32, i32> = vec![(0, 1), (2, 3)].into_iter().collect();
        let mut default: DefaultHashMap<_, _> = normal.into();
        default.get_mut(4);
        assert_eq!(default[0], 1);
        assert_eq!(default[2], 3);
        assert_eq!(default[1], 0);
        assert_eq!(default[4], 0);
        let expected: HashMap<i32, i32> = vec![(0, 1), (2, 3), (4, 0)].into_iter().collect();
        assert_eq!(expected, default.into());
    }

    #[test]
    fn with_fn() {
        let i: i32 = 20;
        let mut map = DefaultHashMap::with_fn(move || i);
        map[0] += 1;
        assert_eq!(21, map[0]);
        assert_eq!(20, map[1]);
    }

    #[test]
    fn from_map_with_fn() {
        let i: i32 = 20;
        let normal: HashMap<i32, i32> = vec![(0, 1), (2, 3)].into_iter().collect();
        let mut map = DefaultHashMap::from_map_with_fn(normal, move || i);
        map[0] += 1;
        assert_eq!(map[0], 2);
        assert_eq!(map[1], 20);
        assert_eq!(map[2], 3);
    }

    #[cfg(feature = "with-serde")]
    mod serde_tests {
        use super::*;

        #[test]
        fn deserialize_static() {
            let s = "{
                \"map\" :
                    { \"foo\": 3,
                      \"bar\": 5
                    },
                \"default\":15
            }";
            let h: Result<DefaultHashMap<&str, i32>, _> = serde_json::from_str(s);
            let h = h.unwrap();
            assert_eq!(h["foo"] * h["bar"], h["foobar"])
        }

        #[test]
        fn serialize_and_back() {
            let h1: DefaultHashMap<i32, u64> = defaulthashmap!(1 => 10, 2 => 20, 3 => 30);
            let s = serde_json::to_string(&h1).unwrap();
            let h2: DefaultHashMap<i32, u64> = serde_json::from_str(&s).unwrap();
            // Compare the round-tripped map against the original (was `h2 == h2`,
            // which could never fail).
            assert_eq!(h1, h2);
            assert_eq!(h2[3], 30);
        }

        #[test]
        fn serialize_default() {
            let h1: DefaultHashMap<&str, u64> = DefaultHashMap::with_default(42);
            let s = serde_json::to_string(&h1).unwrap();
            let h2: DefaultHashMap<&str, u64> = serde_json::from_str(&s).unwrap();
            assert_eq!(h2["answer"], 42);
        }

        #[test]
        fn std_hashmap() {
            let h1: DefaultHashMap<i32, i32> = defaulthashmap!(1=> 10, 2=> 20);
            let stdhm: std::collections::HashMap<i32, i32> = h1.clone().into();
            let s = serde_json::to_string(&stdhm).unwrap();
            let h2: DefaultHashMap<i32, i32> = DefaultHashMap::from_map_with_default(
                serde_json::from_str(&s).unwrap(),
                i32::default(),
            );
            assert_eq!(h1, h2);
        }
    }
}