1#![allow(unsafe_code)]
3
4use std::sync::atomic::{AtomicUsize, Ordering};
5
6use crossbeam_epoch::{self as epoch, Atomic, Owned};
7
8use crate::build;
9use crate::config::Config;
10use crate::error::Result;
11use crate::insert::{self, InsertResult};
12use std::ops::RangeBounds;
13
14use crate::iter::{self, Iter, Range};
15use crate::key::Key;
16use crate::lookup;
17use crate::model::LinearModel;
18use crate::node::Node;
19use crate::remove;
20
21const INITIAL_ROOT_REBUILD_THRESHOLD: usize = 64;
27
28const ROOT_REBUILD_GROWTH_FACTOR: usize = 2;
34
/// An RAII handle that pins the current thread to the crossbeam epoch.
///
/// Every read/write operation on [`LearnedMap`] takes a `&Guard`; references
/// returned by those operations stay valid for the guard's lifetime, because
/// the pinned epoch prevents deferred destruction of the nodes they point into.
pub struct Guard {
    // The underlying crossbeam-epoch guard; private so callers cannot
    // interact with the epoch machinery directly.
    inner: epoch::Guard,
}

impl Guard {
    /// Wraps a raw crossbeam-epoch guard. Internal constructor — users
    /// obtain guards via `LearnedMap::guard` / `LearnedMap::pin`.
    fn new(inner: epoch::Guard) -> Self {
        Self { inner }
    }
}
51
/// A convenience handle bundling a map reference with an owned [`Guard`],
/// so the guard does not need to be threaded through every call.
///
/// Obtained from [`LearnedMap::pin`]. References returned by its methods
/// live as long as this handle (which keeps the epoch pinned).
pub struct MapRef<'a, K: Key, V> {
    map: &'a LearnedMap<K, V>,
    guard: Guard,
}

impl<K: Key, V: Clone + Send + Sync> MapRef<'_, K, V> {
    /// See [`LearnedMap::get`].
    pub fn get(&self, key: &K) -> Option<&V> {
        self.map.get(key, &self.guard)
    }

    /// See [`LearnedMap::insert`]. Returns `true` if the key was new.
    pub fn insert(&self, key: K, value: V) -> bool {
        self.map.insert(key, value, &self.guard)
    }

    /// See [`LearnedMap::remove`]. Returns `true` if the key was present.
    pub fn remove(&self, key: &K) -> bool {
        self.map.remove(key, &self.guard)
    }

    /// See [`LearnedMap::get_or_insert`].
    pub fn get_or_insert(&self, key: K, value: V) -> &V {
        self.map.get_or_insert(key, value, &self.guard)
    }

    /// See [`LearnedMap::get_or_insert_with`]; `f` runs only on a miss.
    pub fn get_or_insert_with(&self, key: K, f: impl FnOnce() -> V) -> &V {
        self.map.get_or_insert_with(key, f, &self.guard)
    }

    /// See [`LearnedMap::contains_key`].
    pub fn contains_key(&self, key: &K) -> bool {
        self.map.contains_key(key, &self.guard)
    }

    /// See [`LearnedMap::len`].
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// See [`LearnedMap::is_empty`].
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// See [`LearnedMap::iter`].
    #[allow(clippy::iter_without_into_iter)]
    pub fn iter(&self) -> Iter<'_, K, V> {
        self.map.iter(&self.guard)
    }

    /// See [`LearnedMap::iter_sorted`].
    pub fn iter_sorted(&self) -> Vec<(K, V)> {
        self.map.iter_sorted(&self.guard)
    }

    /// See [`LearnedMap::range`].
    pub fn range<R: RangeBounds<K>>(&self, range: R) -> Range<'_, K, V> {
        self.map.range(range, &self.guard)
    }

    /// See [`LearnedMap::first_key_value`].
    pub fn first_key_value(&self) -> Option<(&K, &V)> {
        self.map.first_key_value(&self.guard)
    }

    /// See [`LearnedMap::last_key_value`].
    pub fn last_key_value(&self) -> Option<(&K, &V)> {
        self.map.last_key_value(&self.guard)
    }

    /// See [`LearnedMap::range_count`].
    pub fn range_count<R: RangeBounds<K>>(&self, range: R) -> usize {
        self.map.range_count(range, &self.guard)
    }

    /// See [`LearnedMap::allocated_bytes`].
    pub fn allocated_bytes(&self) -> usize {
        self.map.allocated_bytes(&self.guard)
    }

    /// See [`LearnedMap::max_depth`].
    pub fn max_depth(&self) -> usize {
        self.map.max_depth(&self.guard)
    }

    /// See [`LearnedMap::rebuild`].
    pub fn rebuild(&self) {
        self.map.rebuild(&self.guard);
    }

    /// See [`LearnedMap::drain`].
    pub fn drain(&self) -> Vec<(K, V)> {
        self.map.drain(&self.guard)
    }

    /// See [`LearnedMap::clear`].
    pub fn clear(&self) {
        self.map.clear(&self.guard);
    }
}
181
/// A concurrent, epoch-reclaimed map backed by a learned-index tree.
///
/// The root pointer's tag bit is used as a freeze flag: rebuild / drain /
/// clear set it to 1 while swapping the tree, and mutating operations
/// back off until the swap completes.
pub struct LearnedMap<K: Key, V> {
    // Epoch-managed pointer to the root node; tag 1 == "frozen".
    root: Atomic<Node<K, V>>,
    // Entry count, maintained with relaxed atomics (approximate under
    // concurrent mutation).
    len: AtomicUsize,
    config: Config,
    // Entry count at which the next automatic root rebuild is triggered;
    // advanced geometrically after each rebuild.
    next_root_rebuild: AtomicUsize,
}

impl<K: Key, V> std::fmt::Debug for LearnedMap<K, V> {
    // Only the length is shown; the tree itself is not traversed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LearnedMap")
            .field("len", &self.len.load(Ordering::Relaxed))
            .finish_non_exhaustive()
    }
}
231
232impl<K: Key, V: Clone + Send + Sync> LearnedMap<K, V> {
233 pub fn new() -> Self {
235 Self::with_config(Config::default())
236 }
237
238 pub fn with_config(config: Config) -> Self {
240 let root = Node::with_capacity(LinearModel::new(1.0, 0.0), 64);
241 let root_atomic = Atomic::new(root);
242 Self {
243 root: root_atomic,
244 len: AtomicUsize::new(0),
245 next_root_rebuild: AtomicUsize::new(INITIAL_ROOT_REBUILD_THRESHOLD),
246 config,
247 }
248 }
249
250 pub fn bulk_load(pairs: &[(K, V)]) -> Result<Self> {
259 Self::bulk_load_with_config(pairs, Config::default())
260 }
261
262 pub fn bulk_load_with_config(pairs: &[(K, V)], config: Config) -> Result<Self> {
268 let build_config = Config {
272 range_headroom: config.range_headroom.max(1.0),
273 ..config
274 };
275 let root = build::bulk_load(pairs, &build_config)?;
276 let root_atomic = Atomic::new(root);
277 let next_threshold = pairs.len().saturating_mul(ROOT_REBUILD_GROWTH_FACTOR);
278 Ok(Self {
279 len: AtomicUsize::new(pairs.len()),
280 root: root_atomic,
281 next_root_rebuild: AtomicUsize::new(next_threshold),
282 config,
283 })
284 }
285
286 pub fn bulk_load_dedup(pairs: &[(K, V)]) -> Result<Self> {
298 Self::bulk_load_dedup_with_config(pairs, Config::default())
299 }
300
301 pub fn bulk_load_dedup_with_config(pairs: &[(K, V)], config: Config) -> Result<Self> {
309 if pairs.is_empty() {
310 return Err(crate::error::Error::EmptyData);
311 }
312
313 for window in pairs.windows(2) {
315 if window[0].0 > window[1].0 {
316 return Err(crate::error::Error::NotSorted);
317 }
318 }
319
320 let mut deduped = Vec::with_capacity(pairs.len());
322 for window in pairs.windows(2) {
323 if window[0].0 != window[1].0 {
324 deduped.push(window[0].clone());
325 }
326 }
327 if let Some(last) = pairs.last() {
329 deduped.push(last.clone());
330 }
331
332 if deduped.is_empty() {
333 return Err(crate::error::Error::EmptyData);
334 }
335
336 Self::bulk_load_with_config(&deduped, config)
337 }
338
339 pub fn guard(&self) -> Guard {
345 Guard::new(epoch::pin())
346 }
347
348 pub fn pin(&self) -> MapRef<'_, K, V> {
353 MapRef {
354 map: self,
355 guard: self.guard(),
356 }
357 }
358
    /// Looks up `key`, returning a reference valid for the guard's lifetime.
    ///
    /// Reads do not check the freeze tag: a root frozen by a concurrent
    /// rebuild is still a valid snapshot to read from.
    pub fn get<'g>(&self, key: &K, guard: &'g Guard) -> Option<&'g V> {
        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
        // SAFETY: the root is initialized at construction and never stored
        // as null while the map is alive; the epoch guard keeps the node
        // from being reclaimed while this reference exists.
        let root = unsafe { root_shared.deref() };
        lookup::get(root, key, &guard.inner)
    }
368
    /// Inserts `key` → `value`, returning `true` if the key was new
    /// (an existing key has its value replaced and `false` is returned).
    #[allow(clippy::needless_pass_by_value)]
    pub fn insert(&self, key: K, value: V, guard: &Guard) -> bool {
        // Remembers whether an earlier, discarded attempt inserted a fresh
        // key, so the length counter is bumped exactly once across retries.
        let mut was_new = false;
        let backoff = crossbeam_utils::Backoff::new();
        loop {
            let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
            // Tag != 0 means rebuild/drain/clear froze the root; back off
            // instead of writing into a tree that is about to be replaced.
            if root_shared.tag() != 0 {
                backoff.snooze();
                continue;
            }
            // SAFETY: non-null by construction; protected by the epoch guard.
            let root = unsafe { root_shared.deref() };
            let result = insert::insert(root, key.clone(), &value, &self.config, &guard.inner);
            // If the root was swapped underneath us, our write may have gone
            // into the old tree; record whether it was a fresh key and retry
            // against the new root.
            if self.root.load(Ordering::Acquire, &guard.inner) != root_shared {
                if result == InsertResult::Inserted {
                    was_new = true;
                }
                continue;
            }
            let is_new = result == InsertResult::Inserted || was_new;
            if is_new {
                let new_len = self.len.fetch_add(1, Ordering::Relaxed).wrapping_add(1);
                self.maybe_rebuild_root(new_len, guard);
            }
            return is_new;
        }
    }
409
    /// Triggers an automatic root rebuild once `current_len` crosses the
    /// stored threshold (no-op when `auto_rebuild` is disabled).
    ///
    /// The threshold is advanced with a CAS so that, of all threads that
    /// cross it at the same time, exactly one performs the rebuild.
    fn maybe_rebuild_root(&self, current_len: usize, guard: &Guard) {
        if !self.config.auto_rebuild {
            return;
        }

        let threshold = self.next_root_rebuild.load(Ordering::Relaxed);
        if current_len < threshold {
            return;
        }

        // Grow the threshold geometrically so rebuild cost stays amortized.
        let next_threshold = threshold.saturating_mul(ROOT_REBUILD_GROWTH_FACTOR);
        // Winning this CAS elects the current thread as the rebuilder;
        // losers return immediately — another thread is handling it.
        if self
            .next_root_rebuild
            .compare_exchange(
                threshold,
                next_threshold,
                Ordering::AcqRel,
                Ordering::Relaxed,
            )
            .is_ok()
        {
            self.rebuild(guard);
        }
    }
439
    /// Removes `key`, returning `true` if it was present.
    ///
    /// Follows the same retry protocol as [`Self::insert`]: back off while
    /// the root is frozen, and retry if the root was swapped mid-operation.
    pub fn remove(&self, key: &K, guard: &Guard) -> bool {
        // Remembers a removal that landed in a discarded tree so the length
        // counter is decremented exactly once across retries.
        let mut was_removed = false;
        let backoff = crossbeam_utils::Backoff::new();
        loop {
            let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
            // Root frozen by rebuild/drain/clear — wait for the swap.
            if root_shared.tag() != 0 {
                backoff.snooze();
                continue;
            }
            // SAFETY: non-null by construction; protected by the epoch guard.
            let root = unsafe { root_shared.deref() };
            let removed = remove::remove(root, key, &self.config, &guard.inner);
            if self.root.load(Ordering::Acquire, &guard.inner) != root_shared {
                if removed {
                    was_removed = true;
                }
                continue;
            }
            let did_remove = removed || was_removed;
            if did_remove {
                self.len.fetch_sub(1, Ordering::Relaxed);
            }
            return did_remove;
        }
    }
473
    /// Returns a reference to the value for `key`, inserting `value` first
    /// if the key is absent. The reference lives as long as the guard.
    #[allow(clippy::needless_pass_by_value)]
    pub fn get_or_insert<'g>(&self, key: K, value: V, guard: &'g Guard) -> &'g V {
        // Same retry bookkeeping as `insert`: a successful insert into a
        // tree that was swapped away still counts toward the length once.
        let mut was_new = false;
        let backoff = crossbeam_utils::Backoff::new();
        loop {
            let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
            if root_shared.tag() != 0 {
                backoff.snooze();
                continue;
            }
            // SAFETY: non-null by construction; protected by the epoch guard.
            let root = unsafe { root_shared.deref() };
            let (val, result) =
                insert::get_or_insert(root, key.clone(), &value, &self.config, &guard.inner);
            if self.root.load(Ordering::Acquire, &guard.inner) != root_shared {
                if result == InsertResult::Inserted {
                    was_new = true;
                }
                continue;
            }
            let is_new = result == InsertResult::Inserted || was_new;
            if is_new {
                let new_len = self.len.fetch_add(1, Ordering::Relaxed).wrapping_add(1);
                self.maybe_rebuild_root(new_len, guard);
            }
            return val;
        }
    }
511
512 pub fn get_or_insert_with<'g>(&self, key: K, f: impl FnOnce() -> V, guard: &'g Guard) -> &'g V {
520 if let Some(val) = self.get(&key, guard) {
522 return val;
523 }
524 let value = f();
526 self.get_or_insert(key, value, guard)
527 }
528
    /// Returns `true` if `key` is present.
    pub fn contains_key(&self, key: &K, guard: &Guard) -> bool {
        self.get(key, guard).is_some()
    }

    /// Returns the number of entries.
    ///
    /// Maintained with relaxed atomics, so under concurrent mutation this
    /// is an approximation rather than a linearizable count.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }

    /// Returns `true` if the map holds no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
551
    /// Returns an iterator over the entries (order unspecified; use
    /// [`Self::iter_sorted`] for key order).
    pub fn iter<'g>(&self, guard: &'g Guard) -> Iter<'g, K, V> {
        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
        // SAFETY: root is never null while the map is alive; protected by
        // the epoch guard.
        let root = unsafe { root_shared.deref() };
        Iter::with_hint(root, &guard.inner, self.len())
    }

    /// Collects all entries into a `Vec` sorted by key.
    pub fn iter_sorted(&self, guard: &Guard) -> Vec<(K, V)> {
        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
        // SAFETY: see `iter`.
        let root = unsafe { root_shared.deref() };
        iter::sorted_pairs(root, &guard.inner)
    }

    /// Returns an iterator over the entries whose keys fall in `range`.
    pub fn range<'g, R: RangeBounds<K>>(&self, range: R, guard: &'g Guard) -> Range<'g, K, V> {
        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
        // SAFETY: see `iter`.
        let root = unsafe { root_shared.deref() };
        Range::new(root, range, &guard.inner)
    }

    /// Returns the entry with the smallest key, if any.
    pub fn first_key_value<'g>(&self, guard: &'g Guard) -> Option<(&'g K, &'g V)> {
        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
        // SAFETY: see `iter`.
        let root = unsafe { root_shared.deref() };
        iter::first_entry(root, &guard.inner)
    }

    /// Returns the entry with the largest key, if any.
    pub fn last_key_value<'g>(&self, guard: &'g Guard) -> Option<(&'g K, &'g V)> {
        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
        // SAFETY: see `iter`.
        let root = unsafe { root_shared.deref() };
        iter::last_entry(root, &guard.inner)
    }

    /// Counts the entries whose keys fall in `range` (walks the range).
    pub fn range_count<R: RangeBounds<K>>(&self, range: R, guard: &Guard) -> usize {
        self.range(range, guard).count()
    }

    /// Returns the approximate heap footprint of the tree in bytes.
    pub fn allocated_bytes(&self, guard: &Guard) -> usize {
        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
        // SAFETY: see `iter`.
        let root = unsafe { root_shared.deref() };
        root.allocated_bytes(&guard.inner)
    }

    /// Returns the maximum depth of the tree.
    pub fn max_depth(&self, guard: &Guard) -> usize {
        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
        // SAFETY: see `iter`.
        let root = unsafe { root_shared.deref() };
        root.max_depth(&guard.inner)
    }
627
    /// Rebuilds the tree from a sorted snapshot of its current entries.
    ///
    /// Protocol: (1) freeze the root by CAS-ing its tag to 1, which makes
    /// writers back off; (2) snapshot the entries; (3) build a new tree and
    /// CAS it in; on any failure, restore the unfrozen root. Only one
    /// freeze-based operation (rebuild/drain/clear) can be in flight at a
    /// time, because the others bail out on a non-zero tag.
    pub fn rebuild(&self, guard: &Guard) {
        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
        // Bail if another freeze-based operation is already in progress.
        if root_shared.is_null() || root_shared.tag() != 0 {
            return;
        }
        let root = unsafe { root_shared.deref() };

        // Step 1: freeze. Same pointer, tag set to 1.
        let frozen = root_shared.with_tag(1);
        if self
            .root
            .compare_exchange(
                root_shared,
                frozen,
                Ordering::AcqRel,
                Ordering::Acquire,
                &guard.inner,
            )
            .is_err()
        {
            // Someone else changed the root first — let them win.
            return;
        }

        // Step 2: snapshot. Writers are backing off on the frozen tag, so
        // the entry set is stable from here on.
        let pairs = iter::sorted_pairs(root, &guard.inner);
        if pairs.is_empty() {
            // Nothing to rebuild — unfreeze and return.
            let _ = self.root.compare_exchange(
                frozen,
                root_shared,
                Ordering::AcqRel,
                Ordering::Relaxed,
                &guard.inner,
            );
            return;
        }

        // Step 3: build a tight tree (no extra headroom) and swap it in.
        let rebuild_config = Config {
            range_headroom: 1.0,
            ..self.config.clone()
        };
        let Ok(new_root) = build::bulk_load(&pairs, &rebuild_config) else {
            // Build failed — restore the original (unfrozen) root.
            let _ = self.root.compare_exchange(
                frozen,
                root_shared,
                Ordering::AcqRel,
                Ordering::Relaxed,
                &guard.inner,
            );
            return;
        };
        let new_owned = Owned::new(new_root);
        if self
            .root
            .compare_exchange(
                frozen,
                new_owned,
                Ordering::AcqRel,
                Ordering::Acquire,
                &guard.inner,
            )
            .is_ok()
        {
            // SAFETY: the old root is unlinked; readers may still hold
            // references, so destruction is deferred past this epoch.
            unsafe {
                guard.inner.defer_destroy(root_shared);
            }
            let count = pairs.len();
            // Re-arm the auto-rebuild threshold relative to the new size.
            self.next_root_rebuild.store(
                count.saturating_mul(ROOT_REBUILD_GROWTH_FACTOR),
                Ordering::Relaxed,
            );
        } else {
            // Swap lost (should not happen while we hold the freeze) —
            // restore the original root.
            let _ = self.root.compare_exchange(
                frozen,
                root_shared,
                Ordering::AcqRel,
                Ordering::Relaxed,
                &guard.inner,
            );
        }
    }
735
    /// Removes all entries and returns them sorted by key.
    ///
    /// Uses the same freeze-swap protocol as [`Self::rebuild`], replacing
    /// the tree with a fresh empty root. Returns an empty `Vec` if another
    /// freeze-based operation is in progress.
    pub fn drain(&self, guard: &Guard) -> Vec<(K, V)> {
        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
        if root_shared.is_null() || root_shared.tag() != 0 {
            return Vec::new();
        }

        // Freeze the root so writers back off while we snapshot.
        let frozen = root_shared.with_tag(1);
        if self
            .root
            .compare_exchange(
                root_shared,
                frozen,
                Ordering::AcqRel,
                Ordering::Acquire,
                &guard.inner,
            )
            .is_err()
        {
            return Vec::new();
        }

        // SAFETY: non-null (checked above); protected by the epoch guard.
        let root = unsafe { root_shared.deref() };
        let pairs = iter::sorted_pairs(root, &guard.inner);

        // Swap in a fresh empty root.
        let new_root = Node::with_capacity(LinearModel::new(1.0, 0.0), 64);
        let new_owned = Owned::new(new_root);
        if self
            .root
            .compare_exchange(
                frozen,
                new_owned,
                Ordering::AcqRel,
                Ordering::Acquire,
                &guard.inner,
            )
            .is_ok()
        {
            // SAFETY: old root unlinked; destruction deferred past readers.
            unsafe {
                guard.inner.defer_destroy(root_shared);
            }
            self.len.fetch_sub(pairs.len(), Ordering::Relaxed);
            self.next_root_rebuild
                .store(INITIAL_ROOT_REBUILD_THRESHOLD, Ordering::Relaxed);
        } else {
            // NOTE(review): this restore path still returns `pairs` below
            // even though the map was NOT emptied and `len` was not
            // adjusted. It appears unreachable (we hold the freeze), but
            // confirm the intended semantics if it can ever be taken.
            let _ = self.root.compare_exchange(
                frozen,
                root_shared,
                Ordering::AcqRel,
                Ordering::Relaxed,
                &guard.inner,
            );
        }

        pairs
    }
806
    /// Removes all entries without returning them.
    ///
    /// Same freeze-swap protocol as [`Self::drain`]; no-op if another
    /// freeze-based operation is in progress.
    pub fn clear(&self, guard: &Guard) {
        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
        if root_shared.is_null() || root_shared.tag() != 0 {
            return;
        }

        // Freeze the root so the entry count below is stable.
        let frozen = root_shared.with_tag(1);
        if self
            .root
            .compare_exchange(
                root_shared,
                frozen,
                Ordering::AcqRel,
                Ordering::Acquire,
                &guard.inner,
            )
            .is_err()
        {
            return;
        }

        // SAFETY: non-null (checked above); protected by the epoch guard.
        let old_root = unsafe { root_shared.deref() };
        // Count entries in the frozen snapshot so `len` can be adjusted by
        // exactly the number of entries discarded.
        let entry_count = Iter::new(old_root, &guard.inner).count();

        let new_root = Node::with_capacity(LinearModel::new(1.0, 0.0), 64);
        let new_owned = Owned::new(new_root);
        if self
            .root
            .compare_exchange(
                frozen,
                new_owned,
                Ordering::AcqRel,
                Ordering::Acquire,
                &guard.inner,
            )
            .is_ok()
        {
            // SAFETY: old root unlinked; destruction deferred past readers.
            unsafe {
                guard.inner.defer_destroy(root_shared);
            }
            self.len.fetch_sub(entry_count, Ordering::Relaxed);
            self.next_root_rebuild
                .store(INITIAL_ROOT_REBUILD_THRESHOLD, Ordering::Relaxed);
        } else {
            // Swap lost (should not happen while we hold the freeze) —
            // restore the original root.
            let _ = self.root.compare_exchange(
                frozen,
                root_shared,
                Ordering::AcqRel,
                Ordering::Relaxed,
                &guard.inner,
            );
        }
    }
874}
875
#[cfg(feature = "serde")]
impl<K, V> serde::Serialize for LearnedMap<K, V>
where
    K: Key + serde::Serialize,
    V: Clone + Send + Sync + serde::Serialize,
{
    /// Serializes the map as a sequence of `(key, value)` pairs.
    ///
    /// NOTE(review): pairs are emitted in `iter` order, which is not
    /// obviously key-sorted (there is a separate `iter_sorted`), while
    /// `Deserialize` feeds the data to `bulk_load_dedup`, which rejects
    /// unsorted input — confirm that a serialize→deserialize round-trip
    /// succeeds, or emit via `iter_sorted` here.
    fn serialize<S: serde::Serializer>(
        &self,
        serializer: S,
    ) -> std::result::Result<S::Ok, S::Error> {
        use serde::ser::SerializeSeq;

        let guard = self.guard();
        // Relaxed counter — used only as a sequence-length hint.
        let len = self.len();
        let mut seq = serializer.serialize_seq(Some(len))?;
        for (k, v) in self.iter(&guard) {
            seq.serialize_element(&(k, v))?;
        }
        seq.end()
    }
}
897
#[cfg(feature = "serde")]
impl<'de, K, V> serde::Deserialize<'de> for LearnedMap<K, V>
where
    K: Key + serde::Deserialize<'de>,
    V: Clone + Send + Sync + serde::Deserialize<'de>,
{
    /// Deserializes a sequence of `(key, value)` pairs into a map.
    ///
    /// Pairs are stable-sorted by key before loading, so input produced by
    /// an unsorted iterator (e.g. this type's own `Serialize` impl, which
    /// emits `iter` order) round-trips correctly instead of failing with
    /// `NotSorted` in `bulk_load_dedup`. Stability preserves the
    /// keep-last-duplicate semantics of [`LearnedMap::bulk_load_dedup`].
    ///
    /// # Errors
    /// Returns a custom deserializer error if the bulk load fails.
    fn deserialize<D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> std::result::Result<Self, D::Error> {
        let mut pairs: Vec<(K, V)> = Vec::deserialize(deserializer)?;
        if pairs.is_empty() {
            return Ok(Self::new());
        }
        // `bulk_load_dedup` requires key-sorted input; sort defensively so
        // deserialization accepts any pair order.
        pairs.sort_by(|a, b| a.0.cmp(&b.0));
        Self::bulk_load_dedup(&pairs).map_err(serde::de::Error::custom)
    }
}
914
impl<K: Key, V: Clone + Send + Sync> Default for LearnedMap<K, V> {
    /// Equivalent to [`LearnedMap::new`].
    fn default() -> Self {
        Self::new()
    }
}
920
921impl<K: Key, V: Clone + Send + Sync> Extend<(K, V)> for LearnedMap<K, V> {
922 fn extend<I: IntoIterator<Item = (K, V)>>(&mut self, iter: I) {
923 let guard = self.guard();
924 for (k, v) in iter {
925 self.insert(k, v, &guard);
926 }
927 }
928}
929
930impl<K: Key, V: Clone + Send + Sync> FromIterator<(K, V)> for LearnedMap<K, V> {
935 fn from_iter<I: IntoIterator<Item = (K, V)>>(iter: I) -> Self {
936 let map = Self::new();
937 let guard = map.guard();
938 for (k, v) in iter {
939 map.insert(k, v, &guard);
940 }
941 map
942 }
943}
944
impl<K: Key, V> Drop for LearnedMap<K, V> {
    /// Schedules the root tree for epoch-deferred destruction.
    fn drop(&mut self) {
        // SAFETY: `&mut self` gives exclusive access, so no new readers can
        // appear; deferring destruction through the epoch lets any
        // still-pinned guards finish before the tree is freed.
        unsafe {
            let guard = epoch::pin();
            let shared = self.root.load(Ordering::Relaxed, &guard);
            if !shared.is_null() {
                guard.defer_destroy(shared);
            }
        }
    }
}
966
// SAFETY(review): the tree is shared only through epoch-protected atomics,
// and all shared access goes through `&self` methods guarded by epoch pins.
// Soundness additionally depends on `Node`'s internals being safe to access
// from multiple threads — confirm against `node.rs`.
unsafe impl<K: Key, V: Send + Sync> Send for LearnedMap<K, V> {}
// SAFETY(review): see the note on the `Send` impl above.
unsafe impl<K: Key, V: Send + Sync> Sync for LearnedMap<K, V> {}
971
#[cfg(test)]
mod tests {
    // Single-threaded unit tests for the public API: construction, CRUD,
    // bulk loading, iteration, rebuild/drain/clear, and memory accounting.
    use super::*;

    #[test]
    fn new_map_is_empty() {
        let map = LearnedMap::<u64, ()>::new();
        assert!(map.is_empty());
        assert_eq!(map.len(), 0);
    }

    #[test]
    fn insert_and_get() {
        let map = LearnedMap::new();
        let g = map.guard();
        assert!(map.insert(42u64, "hello", &g));
        assert_eq!(map.get(&42, &g), Some(&"hello"));
        assert_eq!(map.len(), 1);
    }

    #[test]
    fn insert_duplicate_updates() {
        let map = LearnedMap::new();
        let g = map.guard();
        assert!(map.insert(1u64, "one", &g));
        // Second insert of the same key replaces the value, returns false.
        assert!(!map.insert(1, "ONE", &g));
        assert_eq!(map.get(&1, &g), Some(&"ONE"));
        assert_eq!(map.len(), 1);
    }

    #[test]
    fn remove_existing() {
        let map = LearnedMap::new();
        let g = map.guard();
        map.insert(1u64, "a", &g);
        map.insert(2, "b", &g);
        assert!(map.remove(&1, &g));
        assert_eq!(map.len(), 1);
        assert!(!map.contains_key(&1, &g));
        assert!(map.contains_key(&2, &g));
    }

    #[test]
    fn remove_missing() {
        let map = LearnedMap::new();
        let g = map.guard();
        map.insert(1u64, "a", &g);
        assert!(!map.remove(&99, &g));
        assert_eq!(map.len(), 1);
    }

    #[test]
    fn bulk_load_basic() {
        let pairs: Vec<(u64, u64)> = (0..100).map(|i| (i, i * 10)).collect();
        let map = LearnedMap::bulk_load(&pairs).unwrap();
        let g = map.guard();
        assert_eq!(map.len(), 100);
        for (k, v) in &pairs {
            assert_eq!(map.get(k, &g), Some(v));
        }
    }

    #[test]
    fn bulk_load_then_insert() {
        let pairs: Vec<(u64, u64)> = vec![(10, 1), (20, 2), (30, 3)];
        let map = LearnedMap::bulk_load(&pairs).unwrap();
        let g = map.guard();
        map.insert(15, 15, &g);
        map.insert(25, 25, &g);
        assert_eq!(map.len(), 5);
        assert_eq!(map.get(&15, &g), Some(&15));
        assert_eq!(map.get(&25, &g), Some(&25));
    }

    #[test]
    fn bulk_load_dedup_keeps_last() {
        // For duplicate keys, the last occurrence in the input wins.
        let pairs: Vec<(u64, &str)> = vec![(1, "a"), (1, "A"), (2, "b"), (3, "c"), (3, "C")];
        let map = LearnedMap::bulk_load_dedup(&pairs).unwrap();
        let g = map.guard();
        assert_eq!(map.len(), 3);
        assert_eq!(map.get(&1, &g), Some(&"A"));
        assert_eq!(map.get(&2, &g), Some(&"b"));
        assert_eq!(map.get(&3, &g), Some(&"C"));
    }

    #[test]
    fn bulk_load_dedup_no_duplicates() {
        let pairs: Vec<(u64, u64)> = (0..50).map(|i| (i, i * 10)).collect();
        let map = LearnedMap::bulk_load_dedup(&pairs).unwrap();
        let g = map.guard();
        assert_eq!(map.len(), 50);
        for (k, v) in &pairs {
            assert_eq!(map.get(k, &g), Some(v));
        }
    }

    #[test]
    fn bulk_load_dedup_all_same_key() {
        let pairs: Vec<(u64, u64)> = (0..10).map(|i| (42, i)).collect();
        let map = LearnedMap::bulk_load_dedup(&pairs).unwrap();
        let g = map.guard();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get(&42, &g), Some(&9));
    }

    #[test]
    fn bulk_load_dedup_empty() {
        let result = LearnedMap::<u64, u64>::bulk_load_dedup(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn bulk_load_dedup_not_sorted() {
        let pairs: Vec<(u64, u64)> = vec![(3, 0), (1, 0), (2, 0)];
        let result = LearnedMap::bulk_load_dedup(&pairs);
        assert!(result.is_err());
    }

    #[test]
    fn from_iterator() {
        let map: LearnedMap<u64, &str> = vec![(1, "a"), (2, "b"), (3, "c")].into_iter().collect();
        let g = map.guard();
        assert_eq!(map.len(), 3);
        assert_eq!(map.get(&2, &g), Some(&"b"));
    }

    #[test]
    fn extend_map() {
        let mut map = LearnedMap::new();
        {
            let g = map.guard();
            map.insert(1u64, 10, &g);
        }
        map.extend(vec![(2, 20), (3, 30)]);
        assert_eq!(map.len(), 3);
    }

    #[test]
    fn iter_sorted_order() {
        let map = LearnedMap::new();
        let g = map.guard();
        map.insert(30u64, "c", &g);
        map.insert(10, "a", &g);
        map.insert(20, "b", &g);

        let items: Vec<(u64, &str)> = map.iter_sorted(&g);
        assert_eq!(items, vec![(10, "a"), (20, "b"), (30, "c")]);
    }

    #[test]
    fn max_depth_bounded() {
        let pairs: Vec<(u64, u64)> = (0..1000).map(|i| (i, i)).collect();
        let map = LearnedMap::bulk_load(&pairs).unwrap();
        let g = map.guard();
        assert!(
            map.max_depth(&g) <= 5,
            "depth {} is too high for 1000 sequential keys",
            map.max_depth(&g)
        );
    }

    #[test]
    fn stress_insert_lookup_remove() {
        let map = LearnedMap::new();
        let g = map.guard();
        let n = 500u64;

        for i in 0..n {
            map.insert(i * 3, i, &g);
        }
        assert_eq!(map.len(), n as usize);

        for i in 0..n {
            assert_eq!(map.get(&(i * 3), &g), Some(&i), "key {} missing", i * 3);
        }

        // Remove every even-indexed key, then verify the odds survive.
        for i in (0..n).filter(|i| i % 2 == 0) {
            map.remove(&(i * 3), &g);
        }
        assert_eq!(map.len(), (n / 2) as usize);

        for i in (0..n).filter(|i| i % 2 != 0) {
            assert_eq!(map.get(&(i * 3), &g), Some(&i));
        }
    }

    #[test]
    fn manual_rebuild() {
        let map = LearnedMap::new();
        let g = map.guard();
        // Reverse-order inserts deliberately degrade the tree shape.
        for i in (0..100u64).rev() {
            map.insert(i, i * 10, &g);
        }
        let depth_before = map.max_depth(&g);
        map.rebuild(&g);
        let depth_after = map.max_depth(&g);
        assert!(
            depth_after <= depth_before,
            "rebuild didn't help: {depth_before} -> {depth_after}"
        );
        let g2 = map.guard();
        for i in 0..100u64 {
            assert_eq!(map.get(&i, &g2), Some(&(i * 10)));
        }
    }

    #[test]
    fn rebuild_empty_is_noop() {
        let map = LearnedMap::<u64, u64>::new();
        let g = map.guard();
        map.rebuild(&g);
        assert!(map.is_empty());
    }

    #[test]
    fn large_incremental_insert() {
        let map = LearnedMap::new();
        let g = map.guard();
        for i in 0..1000u64 {
            map.insert(i, i, &g);
        }
        assert_eq!(map.len(), 1000);
        for i in 0..1000u64 {
            assert_eq!(map.get(&i, &g), Some(&i));
        }
    }

    #[test]
    fn pin_convenience() {
        let map = LearnedMap::new();
        let m = map.pin();
        m.insert(1u64, "one");
        m.insert(2, "two");
        assert_eq!(m.get(&1), Some(&"one"));
        assert_eq!(m.get(&2), Some(&"two"));
        assert_eq!(m.len(), 2);
        assert!(!m.is_empty());
    }

    #[test]
    fn map_ref_remove() {
        let map = LearnedMap::new();
        let m = map.pin();
        m.insert(10u64, 100);
        m.insert(20, 200);
        assert!(m.remove(&10));
        assert!(!m.remove(&10));
        assert_eq!(m.len(), 1);
        assert!(m.contains_key(&20));
    }

    #[test]
    fn map_ref_iter_sorted() {
        let map = LearnedMap::new();
        let m = map.pin();
        m.insert(3u64, "c");
        m.insert(1, "a");
        m.insert(2, "b");
        let items = m.iter_sorted();
        assert_eq!(items, vec![(1, "a"), (2, "b"), (3, "c")]);
    }

    #[test]
    fn auto_root_rebuild_from_empty() {
        let map = LearnedMap::new();
        let g = map.guard();
        for i in 0..200u64 {
            map.insert(i, i, &g);
        }
        let g2 = map.guard();
        let depth = map.max_depth(&g2);
        assert!(
            depth <= 12,
            "depth {depth} too high after auto root rebuild"
        );
        for i in 0..200u64 {
            assert_eq!(map.get(&i, &g2), Some(&i), "key {i} missing");
        }
    }

    #[test]
    fn auto_root_rebuild_disabled() {
        let map = LearnedMap::with_config(Config::new().auto_rebuild(false));
        let g = map.guard();
        for i in 0..200u64 {
            map.insert(i, i, &g);
        }
        // Without auto-rebuild the tree stays deep — proves rebuild didn't run.
        let depth = map.max_depth(&g);
        assert!(depth > 5, "depth {depth} too low without auto rebuild");
    }

    #[test]
    fn bulk_load_no_early_rebuild() {
        let pairs: Vec<(u64, u64)> = (0..100).map(|i| (i, i)).collect();
        let map = LearnedMap::bulk_load(&pairs).unwrap();
        let g = map.guard();
        let depth = map.max_depth(&g);
        assert!(depth <= 3, "bulk-loaded tree depth {depth} too high");
        assert_eq!(map.len(), 100);
    }

    #[test]
    fn manual_rebuild_resets_threshold() {
        let map = LearnedMap::new();
        let g = map.guard();
        for i in 0..50u64 {
            map.insert(i, i, &g);
        }
        map.rebuild(&g);
        let g2 = map.guard();
        for i in 50..150u64 {
            map.insert(i, i, &g2);
        }
        assert_eq!(map.len(), 150);
        for i in 0..150u64 {
            assert_eq!(map.get(&i, &g2), Some(&i));
        }
    }

    #[test]
    fn clear_empties_map() {
        let map = LearnedMap::new();
        let g = map.guard();
        for i in 0..100u64 {
            map.insert(i, i, &g);
        }
        assert_eq!(map.len(), 100);
        map.clear(&g);
        let g2 = map.guard();
        assert_eq!(map.len(), 0);
        assert!(map.is_empty());
        for i in 0..100u64 {
            assert_eq!(map.get(&i, &g2), None);
        }
    }

    #[test]
    fn clear_then_reinsert() {
        let map = LearnedMap::new();
        let g = map.guard();
        for i in 0..50u64 {
            map.insert(i, i * 10, &g);
        }
        map.clear(&g);
        let g2 = map.guard();
        for i in 0..30u64 {
            map.insert(i + 100, i, &g2);
        }
        assert_eq!(map.len(), 30);
        assert_eq!(map.get(&100, &g2), Some(&0));
        assert_eq!(map.get(&0, &g2), None);
    }

    #[test]
    fn clear_empty_is_noop() {
        let map = LearnedMap::<u64, u64>::new();
        let g = map.guard();
        map.clear(&g);
        assert!(map.is_empty());
    }

    #[test]
    fn map_ref_clear() {
        let map = LearnedMap::new();
        let m = map.pin();
        m.insert(1u64, "a");
        m.insert(2, "b");
        assert_eq!(m.len(), 2);
        m.clear();
        assert!(m.is_empty());
        assert_eq!(m.get(&1), None);
    }

    #[test]
    fn drain_returns_sorted_entries() {
        let map = LearnedMap::new();
        let g = map.guard();
        for i in (0..50u64).rev() {
            map.insert(i, i * 10, &g);
        }
        assert_eq!(map.len(), 50);
        let drained = map.drain(&g);
        assert_eq!(drained.len(), 50);
        // Drained output must be strictly key-sorted.
        for w in drained.windows(2) {
            assert!(w[0].0 < w[1].0);
        }
        for (i, (k, v)) in drained.iter().enumerate() {
            assert_eq!(*k, i as u64);
            assert_eq!(*v, (i as u64) * 10);
        }
        let g2 = map.guard();
        assert!(map.is_empty());
        assert_eq!(map.get(&0, &g2), None);
    }

    #[test]
    fn drain_empty_returns_empty() {
        let map = LearnedMap::<u64, u64>::new();
        let g = map.guard();
        let drained = map.drain(&g);
        assert!(drained.is_empty());
        assert!(map.is_empty());
    }

    #[test]
    fn drain_then_reinsert() {
        let map = LearnedMap::new();
        let g = map.guard();
        for i in 0..30u64 {
            map.insert(i, i, &g);
        }
        let drained = map.drain(&g);
        assert_eq!(drained.len(), 30);
        let g2 = map.guard();
        for i in 100..110u64 {
            map.insert(i, i, &g2);
        }
        assert_eq!(map.len(), 10);
        assert_eq!(map.get(&100, &g2), Some(&100));
        assert_eq!(map.get(&0, &g2), None);
    }

    #[test]
    fn map_ref_drain() {
        let map = LearnedMap::new();
        let m = map.pin();
        m.insert(3u64, "c");
        m.insert(1, "a");
        m.insert(2, "b");
        let drained = m.drain();
        assert_eq!(drained, vec![(1, "a"), (2, "b"), (3, "c")]);
        assert!(m.is_empty());
    }

    #[test]
    fn allocated_bytes_empty() {
        let map = LearnedMap::<u64, u64>::new();
        let g = map.guard();
        let bytes = map.allocated_bytes(&g);
        assert!(bytes > 0, "empty map should have non-zero allocation");
    }

    #[test]
    fn allocated_bytes_grows_with_entries() {
        let map = LearnedMap::new();
        let g = map.guard();
        let empty_bytes = map.allocated_bytes(&g);

        for i in 0..100u64 {
            map.insert(i, i, &g);
        }
        let g2 = map.guard();
        let full_bytes = map.allocated_bytes(&g2);
        assert!(
            full_bytes > empty_bytes,
            "100 entries should use more memory than empty: {full_bytes} vs {empty_bytes}"
        );
    }

    #[test]
    fn allocated_bytes_bulk_load() {
        let pairs: Vec<(u64, u64)> = (0..500).map(|i| (i, i)).collect();
        let map = LearnedMap::bulk_load(&pairs).unwrap();
        let g = map.guard();
        let bytes = map.allocated_bytes(&g);
        // Lower bound: raw key + value bytes, ignoring node overhead.
        let min_data_bytes = 500 * std::mem::size_of::<u64>() * 2;
        assert!(
            bytes > min_data_bytes,
            "allocated_bytes {bytes} is less than minimum data size {min_data_bytes}"
        );
    }

    #[test]
    fn map_ref_allocated_bytes() {
        let map = LearnedMap::new();
        let m = map.pin();
        m.insert(1u64, 1u64);
        m.insert(2, 2);
        let bytes = m.allocated_bytes();
        assert!(bytes > 0);
    }
}
1459}