scry_index/map.rs

1//! The primary public types: [`LearnedMap`], [`Guard`], and [`MapRef`].
2#![allow(unsafe_code)]
3
4use std::sync::atomic::{AtomicUsize, Ordering};
5
6use crossbeam_epoch::{self as epoch, Atomic, Owned};
7
8use crate::build;
9use crate::config::Config;
10use crate::error::Result;
11use crate::insert::{self, InsertResult};
12use std::ops::RangeBounds;
13
14use crate::iter::{self, Iter, Range};
15use crate::key::Key;
16use crate::lookup;
17use crate::model::LinearModel;
18use crate::node::Node;
19use crate::remove;
20
21/// Number of entries at which the first automatic root rebuild triggers.
22/// 64 entries gives FMCD enough data points for a well-fitted model,
23/// reducing the number of rebuilds during the critical ramp-up phase
24/// (first ~1000 inserts). The slightly longer initial bad-model phase
25/// is offset by fewer total rebuilds and better model quality.
26const INITIAL_ROOT_REBUILD_THRESHOLD: usize = 64;
27
28/// Growth factor between successive root rebuild thresholds.
29/// Schedule: 64, 128, 256, 512, 1024, 2048, ...
30/// Total amortized rebuild cost: sum of geometric series ≈ 2N = O(N).
31/// Using 2x (not 4x) keeps the root model fresh. This matters because keys
32/// outside the model's fitted range all clamp to the last slot.
33const ROOT_REBUILD_GROWTH_FACTOR: usize = 2;
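
// Worked example of the schedule above: starting from an empty map and
// inserting 1_000 entries, root rebuilds fire when the map reaches 64, 128,
// 256, and 512 entries, so the total number of entries ever reprocessed by a
// root rebuild is 64 + 128 + 256 + 512 = 960, within the ≈ 2N bound and O(1)
// amortized work per insert.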
34
35/// An epoch guard that keeps the current thread pinned.
36///
37/// While a guard exists, any memory retired during this epoch will not be
38/// reclaimed. Guards should be short-lived to avoid delaying reclamation.
39///
40/// Obtain a guard via [`LearnedMap::guard`] or use [`LearnedMap::pin`] for
41/// the convenience [`MapRef`] wrapper.
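///
/// # Example
///
/// A minimal sketch of obtaining and releasing a guard:
///
/// ```
/// use scry_index::LearnedMap;
///
/// let map = LearnedMap::new();
/// let guard = map.guard();
/// map.insert(1u64, "one", &guard);
/// assert_eq!(map.get(&1, &guard), Some(&"one"));
/// drop(guard); // unpin promptly so retired memory can be reclaimed
/// ```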
42pub struct Guard {
43    inner: epoch::Guard,
44}
45
46impl Guard {
47    fn new(inner: epoch::Guard) -> Self {
48        Self { inner }
49    }
50}
51
52/// A convenience handle that bundles a map reference with an epoch guard.
53///
54/// All operations on `MapRef` are forwarded to the underlying [`LearnedMap`]
55/// using the guard owned by this handle. This avoids passing a guard to
56/// every method call.
57///
58/// # Example
59///
60/// ```
61/// use scry_index::LearnedMap;
62///
63/// let map = LearnedMap::new();
64/// let m = map.pin();
65/// m.insert(1u64, "hello");
66/// assert_eq!(m.get(&1), Some(&"hello"));
67/// ```
68pub struct MapRef<'a, K: Key, V> {
69    map: &'a LearnedMap<K, V>,
70    guard: Guard,
71}
72
73impl<K: Key, V: Clone + Send + Sync> MapRef<'_, K, V> {
74    /// Look up a key, returning a reference to the value if found.
75    pub fn get(&self, key: &K) -> Option<&V> {
76        self.map.get(key, &self.guard)
77    }
78
79    /// Insert a key-value pair. Returns `true` if the key was newly inserted.
80    pub fn insert(&self, key: K, value: V) -> bool {
81        self.map.insert(key, value, &self.guard)
82    }
83
84    /// Remove a key. Returns `true` if the key was present and removed.
85    pub fn remove(&self, key: &K) -> bool {
86        self.map.remove(key, &self.guard)
87    }
88
89    /// Atomically get an existing value or insert a new one.
90    ///
91    /// See [`LearnedMap::get_or_insert`] for details.
92    pub fn get_or_insert(&self, key: K, value: V) -> &V {
93        self.map.get_or_insert(key, value, &self.guard)
94    }
95
96    /// Atomically get an existing value or insert a computed one.
97    ///
98    /// See [`LearnedMap::get_or_insert_with`] for details.
99    pub fn get_or_insert_with(&self, key: K, f: impl FnOnce() -> V) -> &V {
100        self.map.get_or_insert_with(key, f, &self.guard)
101    }
102
103    /// Check whether the map contains a key.
104    pub fn contains_key(&self, key: &K) -> bool {
105        self.map.contains_key(key, &self.guard)
106    }
107
108    /// Return the approximate number of key-value pairs in the map.
109    ///
110    /// See [`LearnedMap::len`] for details on relaxed-atomic staleness
111    /// under concurrency.
112    pub fn len(&self) -> usize {
113        self.map.len()
114    }
115
116    /// Return `true` if the map contains no entries.
117    ///
118    /// Subject to the same relaxed-atomic staleness as [`len`](Self::len).
119    pub fn is_empty(&self) -> bool {
120        self.map.is_empty()
121    }
122
123    /// Iterate over all key-value pairs in sorted order.
124    #[allow(clippy::iter_without_into_iter)]
125    pub fn iter(&self) -> Iter<'_, K, V> {
126        self.map.iter(&self.guard)
127    }
128
129    /// Collect all key-value pairs in sorted order (cloned).
130    pub fn iter_sorted(&self) -> Vec<(K, V)> {
131        self.map.iter_sorted(&self.guard)
132    }
133
134    /// Return an iterator over key-value pairs within the given range.
135    pub fn range<R: RangeBounds<K>>(&self, range: R) -> Range<'_, K, V> {
136        self.map.range(range, &self.guard)
137    }
138
139    /// Return the first (minimum) key-value pair.
140    pub fn first_key_value(&self) -> Option<(&K, &V)> {
141        self.map.first_key_value(&self.guard)
142    }
143
144    /// Return the last (maximum) key-value pair.
145    pub fn last_key_value(&self) -> Option<(&K, &V)> {
146        self.map.last_key_value(&self.guard)
147    }
148
149    /// Count the number of entries within the given range.
150    pub fn range_count<R: RangeBounds<K>>(&self, range: R) -> usize {
151        self.map.range_count(range, &self.guard)
152    }
153
154    /// Estimate the total heap memory allocated by this map, in bytes.
155    ///
156    /// See [`LearnedMap::allocated_bytes`] for details.
157    pub fn allocated_bytes(&self) -> usize {
158        self.map.allocated_bytes(&self.guard)
159    }
160
161    /// Return the maximum depth of the tree.
162    pub fn max_depth(&self) -> usize {
163        self.map.max_depth(&self.guard)
164    }
165
166    /// Rebuild the tree from scratch using bulk load.
167    pub fn rebuild(&self) {
168        self.map.rebuild(&self.guard);
169    }
170
171    /// Remove all entries from the map and return them as a sorted `Vec`.
172    pub fn drain(&self) -> Vec<(K, V)> {
173        self.map.drain(&self.guard)
174    }
175
176    /// Remove all entries from the map.
177    pub fn clear(&self) {
178        self.map.clear(&self.guard);
179    }
180}
181
182/// A sorted key-value map backed by a learned index.
183///
184/// Uses piecewise linear models to predict key positions, achieving O(1)
185/// expected lookup time for keys matching the data distribution.
186///
187/// # When to use
188///
189/// Best suited for read-heavy, concurrent workloads with sorted keys
190/// (time-series queries, lookup tables, analytics indexes). Point lookups
191/// are faster than `BTreeMap`; writes are competitive for sequential and
192/// append-only patterns. For random-key insert-heavy workloads, prefer
193/// [`bulk_load`](Self::bulk_load) over one-by-one insertion.
194///
195/// # Concurrency
196///
197/// All operations take `&self` and are safe to call from multiple threads.
198/// Reads are lock-free (atomic loads under an epoch guard). Writes use
199/// compare-and-swap retry loops on individual slots. No global lock.
200///
201/// # Example
202///
203/// ```
204/// use scry_index::LearnedMap;
205///
206/// let map = LearnedMap::new();
207/// let guard = map.guard();
208///
209/// map.insert(42u64, "hello", &guard);
210/// map.insert(17, "world", &guard);
211///
212/// assert_eq!(map.get(&42, &guard), Some(&"hello"));
213/// assert_eq!(map.get(&99, &guard), None);
214/// assert_eq!(map.len(), 2);
215/// ```
216pub struct LearnedMap<K: Key, V> {
217    root: Atomic<Node<K, V>>,
218    len: AtomicUsize,
219    config: Config,
220    /// Entry count at which the next automatic root rebuild triggers.
221    next_root_rebuild: AtomicUsize,
222}
223
224impl<K: Key, V> std::fmt::Debug for LearnedMap<K, V> {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        f.debug_struct("LearnedMap")
227            .field("len", &self.len.load(Ordering::Relaxed))
228            .finish_non_exhaustive()
229    }
230}
231
232impl<K: Key, V: Clone + Send + Sync> LearnedMap<K, V> {
233    /// Create a new empty learned map with default configuration.
234    pub fn new() -> Self {
235        Self::with_config(Config::default())
236    }
237
238    /// Create a new empty learned map with the given configuration.
239    pub fn with_config(config: Config) -> Self {
240        let root = Node::with_capacity(LinearModel::new(1.0, 0.0), 64);
241        let root_atomic = Atomic::new(root);
242        Self {
243            root: root_atomic,
244            len: AtomicUsize::new(0),
245            next_root_rebuild: AtomicUsize::new(INITIAL_ROOT_REBUILD_THRESHOLD),
246            config,
247        }
248    }
249
250    /// Create a learned map from sorted key-value pairs.
251    ///
252    /// Faster than inserting one-by-one because it builds the tree
253    /// structure in one pass using FMCD model fitting.
254    ///
255    /// # Errors
256    ///
257    /// Returns an error if `pairs` is empty or not sorted by key.
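    ///
    /// # Example
    ///
    /// A small sketch (mirrors the unit tests below):
    ///
    /// ```
    /// use scry_index::LearnedMap;
    ///
    /// let pairs: Vec<(u64, u64)> = (0..100).map(|i| (i, i * 10)).collect();
    /// let map = LearnedMap::bulk_load(&pairs).unwrap();
    /// let guard = map.guard();
    /// assert_eq!(map.len(), 100);
    /// assert_eq!(map.get(&7, &guard), Some(&70));
    /// ```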
258    pub fn bulk_load(pairs: &[(K, V)]) -> Result<Self> {
259        Self::bulk_load_with_config(pairs, Config::default())
260    }
261
262    /// Create a learned map from sorted key-value pairs with configuration.
263    ///
264    /// # Errors
265    ///
266    /// Returns an error if `pairs` is empty or not sorted by key.
267    pub fn bulk_load_with_config(pairs: &[(K, V)], config: Config) -> Result<Self> {
268        // Build with headroom so appends beyond the loaded range don't all
269        // clamp to the last slot. This is the most common pattern: bulk_load
270        // historical data, then stream new entries.
271        let build_config = Config {
272            range_headroom: config.range_headroom.max(1.0),
273            ..config
274        };
275        let root = build::bulk_load(pairs, &build_config)?;
276        let root_atomic = Atomic::new(root);
277        let next_threshold = pairs.len().saturating_mul(ROOT_REBUILD_GROWTH_FACTOR);
278        Ok(Self {
279            len: AtomicUsize::new(pairs.len()),
280            root: root_atomic,
281            next_root_rebuild: AtomicUsize::new(next_threshold),
282            config,
283        })
284    }
285
286    /// Create a learned map from sorted key-value pairs, deduplicating keys.
287    ///
288    /// When duplicate keys are present, the **last** value for each key is kept,
289    /// matching the semantics of `BTreeMap::from_iter`. The input must be sorted
290    /// by key in ascending order but may contain duplicates.
291    ///
292    /// This is equivalent to deduplicating and then calling [`bulk_load`](Self::bulk_load).
293    ///
294    /// # Errors
295    ///
296    /// Returns an error if `pairs` is empty (after dedup) or not sorted by key.
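    ///
    /// # Example
    ///
    /// A small sketch of the last-wins semantics (mirrors the unit tests below):
    ///
    /// ```
    /// use scry_index::LearnedMap;
    ///
    /// let pairs = vec![(1u64, "a"), (1, "A"), (2, "b")];
    /// let map = LearnedMap::bulk_load_dedup(&pairs).unwrap();
    /// let guard = map.guard();
    /// assert_eq!(map.len(), 2);
    /// assert_eq!(map.get(&1, &guard), Some(&"A"));
    /// ```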
297    pub fn bulk_load_dedup(pairs: &[(K, V)]) -> Result<Self> {
298        Self::bulk_load_dedup_with_config(pairs, Config::default())
299    }
300
301    /// Create a learned map from sorted key-value pairs with dedup and config.
302    ///
303    /// See [`bulk_load_dedup`](Self::bulk_load_dedup) for semantics.
304    ///
305    /// # Errors
306    ///
307    /// Returns an error if `pairs` is empty (after dedup) or not sorted by key.
308    pub fn bulk_load_dedup_with_config(pairs: &[(K, V)], config: Config) -> Result<Self> {
309        if pairs.is_empty() {
310            return Err(crate::error::Error::EmptyData);
311        }
312
313        // Verify sorted (non-strictly: duplicates allowed).
314        for window in pairs.windows(2) {
315            if window[0].0 > window[1].0 {
316                return Err(crate::error::Error::NotSorted);
317            }
318        }
319
320        // Dedup: keep the last value per key by dropping each element whose successor has the same key.
321        let mut deduped = Vec::with_capacity(pairs.len());
322        for window in pairs.windows(2) {
323            if window[0].0 != window[1].0 {
324                deduped.push(window[0].clone());
325            }
326        }
327        // Always include the last element.
328        if let Some(last) = pairs.last() {
329            deduped.push(last.clone());
330        }
331
332        if deduped.is_empty() {
333            return Err(crate::error::Error::EmptyData);
334        }
335
336        Self::bulk_load_with_config(&deduped, config)
337    }
338
339    /// Acquire an epoch guard for use with operations on this map.
340    ///
341    /// The guard pins the current thread to an epoch, preventing any
342    /// concurrently retired memory from being reclaimed while the guard
343    /// is held. Keep guards short-lived.
344    pub fn guard(&self) -> Guard {
345        Guard::new(epoch::pin())
346    }
347
348    /// Pin the current epoch and return a [`MapRef`] convenience handle.
349    ///
350    /// This is equivalent to calling `guard()` and passing the guard to every
351    /// method, but is more ergonomic for sequences of operations.
352    pub fn pin(&self) -> MapRef<'_, K, V> {
353        MapRef {
354            map: self,
355            guard: self.guard(),
356        }
357    }
358
359    /// Look up a key, returning a reference to the value if found.
360    ///
361    /// The returned reference is valid for the lifetime of the guard.
362    pub fn get<'g>(&self, key: &K, guard: &'g Guard) -> Option<&'g V> {
363        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
364        // SAFETY: root is always non-null (set during construction, never nulled).
365        let root = unsafe { root_shared.deref() };
366        lookup::get(root, key, &guard.inner)
367    }
368
369    /// Insert a key-value pair. Returns `true` if the key was newly inserted,
370    /// `false` if an existing key's value was updated.
371    ///
372    /// When `auto_rebuild` is enabled, the insert path tracks descent depth
373    /// and triggers a localized subtree rebuild if the depth exceeds the
374    /// configured threshold. No global lock is required.
375    ///
376    /// If a concurrent root rebuild replaces the tree, the insert detects the
377    /// change and retries against the new root. A `was_new` flag tracks
378    /// whether the key was newly inserted across retries so `len` is
379    /// incremented exactly once.
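    ///
    /// # Example
    ///
    /// A small sketch of the return value (mirrors the unit tests below):
    ///
    /// ```
    /// use scry_index::LearnedMap;
    ///
    /// let map = LearnedMap::new();
    /// let guard = map.guard();
    /// assert!(map.insert(1u64, "one", &guard));  // newly inserted
    /// assert!(!map.insert(1, "uno", &guard));    // existing key: value updated
    /// assert_eq!(map.get(&1, &guard), Some(&"uno"));
    /// ```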
380    #[allow(clippy::needless_pass_by_value)]
381    pub fn insert(&self, key: K, value: V, guard: &Guard) -> bool {
382        let mut was_new = false;
383        let backoff = crossbeam_utils::Backoff::new();
384        loop {
385            let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
386            // If root is frozen (tagged), a global rebuild is in progress.
387            if root_shared.tag() != 0 {
388                backoff.snooze();
389                continue;
390            }
391            // SAFETY: root is always non-null.
392            let root = unsafe { root_shared.deref() };
393            let result = insert::insert(root, key.clone(), &value, &self.config, &guard.inner);
394            // Validate: root wasn't replaced or frozen by a concurrent rebuild.
395            if self.root.load(Ordering::Acquire, &guard.inner) != root_shared {
396                if result == InsertResult::Inserted {
397                    was_new = true;
398                }
399                continue;
400            }
401            let is_new = result == InsertResult::Inserted || was_new;
402            if is_new {
403                let new_len = self.len.fetch_add(1, Ordering::Relaxed).wrapping_add(1);
404                self.maybe_rebuild_root(new_len, guard);
405            }
406            return is_new;
407        }
408    }
409
410    /// Check if the root should be rebuilt and, if so, claim the rebuild via CAS.
411    ///
412    /// Only one thread wins the CAS on `next_root_rebuild`. The winner performs
413    /// the rebuild inline; all other threads proceed without blocking.
414    fn maybe_rebuild_root(&self, current_len: usize, guard: &Guard) {
415        if !self.config.auto_rebuild {
416            return;
417        }
418
419        let threshold = self.next_root_rebuild.load(Ordering::Relaxed);
420        if current_len < threshold {
421            return;
422        }
423
424        // Try to claim the rebuild by CAS-advancing the threshold.
425        let next_threshold = threshold.saturating_mul(ROOT_REBUILD_GROWTH_FACTOR);
426        if self
427            .next_root_rebuild
428            .compare_exchange(
429                threshold,
430                next_threshold,
431                Ordering::AcqRel,
432                Ordering::Relaxed,
433            )
434            .is_ok()
435        {
436            self.rebuild(guard);
437        }
438    }
439
440    /// Remove a key. Returns `true` if the key was present and removed.
441    ///
442    /// If a concurrent root rebuild replaces the tree, the remove detects the
443    /// change and retries against the new root. A `was_removed` flag tracks
444    /// whether the key was successfully removed across retries so `len` is
445    /// decremented exactly once.
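    ///
    /// # Example
    ///
    /// A small sketch (mirrors the unit tests below):
    ///
    /// ```
    /// use scry_index::LearnedMap;
    ///
    /// let map = LearnedMap::new();
    /// let guard = map.guard();
    /// map.insert(5u64, "v", &guard);
    /// assert!(map.remove(&5, &guard));   // present: removed
    /// assert!(!map.remove(&5, &guard));  // already gone
    /// ```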
446    pub fn remove(&self, key: &K, guard: &Guard) -> bool {
447        let mut was_removed = false;
448        let backoff = crossbeam_utils::Backoff::new();
449        loop {
450            let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
451            // If root is frozen (tagged), a global rebuild is in progress.
452            if root_shared.tag() != 0 {
453                backoff.snooze();
454                continue;
455            }
456            // SAFETY: root is always non-null.
457            let root = unsafe { root_shared.deref() };
458            let removed = remove::remove(root, key, &self.config, &guard.inner);
459            // Validate: root wasn't replaced or frozen by a concurrent rebuild.
460            if self.root.load(Ordering::Acquire, &guard.inner) != root_shared {
461                if removed {
462                    was_removed = true;
463                }
464                continue;
465            }
466            let did_remove = removed || was_removed;
467            if did_remove {
468                self.len.fetch_sub(1, Ordering::Relaxed);
469            }
470            return did_remove;
471        }
472    }
473
474    /// Atomically get an existing value or insert a new one.
475    ///
476    /// If the key already exists, returns a reference to the existing value
477    /// without modifying it. If the key is absent, inserts the key-value pair
478    /// and returns a reference to the newly inserted value.
479    ///
480    /// This is atomic with respect to concurrent operations. There is no
481    /// TOCTOU race between checking and inserting.
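    ///
    /// # Example
    ///
    /// A small sketch of the insert-then-return-existing behavior:
    ///
    /// ```
    /// use scry_index::LearnedMap;
    ///
    /// let map = LearnedMap::new();
    /// let guard = map.guard();
    /// assert_eq!(map.get_or_insert(7u64, "first", &guard), &"first");
    /// // Key already present: the existing value is returned, "second" is discarded.
    /// assert_eq!(map.get_or_insert(7, "second", &guard), &"first");
    /// ```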
482    #[allow(clippy::needless_pass_by_value)]
483    pub fn get_or_insert<'g>(&self, key: K, value: V, guard: &'g Guard) -> &'g V {
484        let mut was_new = false;
485        let backoff = crossbeam_utils::Backoff::new();
486        loop {
487            let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
488            if root_shared.tag() != 0 {
489                backoff.snooze();
490                continue;
491            }
492            // SAFETY: root is always non-null.
493            let root = unsafe { root_shared.deref() };
494            let (val, result) =
495                insert::get_or_insert(root, key.clone(), &value, &self.config, &guard.inner);
496            // Validate: root wasn't replaced or frozen by a concurrent rebuild.
497            if self.root.load(Ordering::Acquire, &guard.inner) != root_shared {
498                if result == InsertResult::Inserted {
499                    was_new = true;
500                }
501                continue;
502            }
503            let is_new = result == InsertResult::Inserted || was_new;
504            if is_new {
505                let new_len = self.len.fetch_add(1, Ordering::Relaxed).wrapping_add(1);
506                self.maybe_rebuild_root(new_len, guard);
507            }
508            return val;
509        }
510    }
511
512    /// Atomically get an existing value or insert a computed one.
513    ///
514    /// Like [`get_or_insert`](Self::get_or_insert), but the value is lazily
515    /// computed by `f` only if the key is absent during the initial lookup.
516    /// Under concurrent inserts, `f` may be called even if another thread
517    /// inserts the same key first; in that case the computed value is
518    /// discarded and the existing value is returned.
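    ///
    /// # Example
    ///
    /// A small sketch of the lazy path (single-threaded, so the closure runs
    /// only when the key is absent at the initial lookup):
    ///
    /// ```
    /// use scry_index::LearnedMap;
    ///
    /// let map = LearnedMap::new();
    /// let guard = map.guard();
    /// let v = map.get_or_insert_with(3u64, || "computed".to_string(), &guard);
    /// assert_eq!(v, "computed");
    /// // Key now present: the fast path returns without calling the closure.
    /// let v2 = map.get_or_insert_with(3, || unreachable!(), &guard);
    /// assert_eq!(v2, "computed");
    /// ```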
519    pub fn get_or_insert_with<'g>(&self, key: K, f: impl FnOnce() -> V, guard: &'g Guard) -> &'g V {
520        // Fast path: key already exists.
521        if let Some(val) = self.get(&key, guard) {
522            return val;
523        }
524        // Slow path: compute value and atomically insert-or-get.
525        let value = f();
526        self.get_or_insert(key, value, guard)
527    }
528
529    /// Check whether the map contains a key.
530    pub fn contains_key(&self, key: &K, guard: &Guard) -> bool {
531        self.get(key, guard).is_some()
532    }
533
534    /// Return the approximate number of key-value pairs in the map.
535    ///
536    /// Uses a relaxed atomic load internally. Under concurrent inserts or
537    /// removes the returned value may be slightly stale. It is **not**
538    /// linearizable with respect to other operations. For an exact count,
539    /// call [`iter`](Self::iter) and count the entries observed under the
540    /// epoch guard.
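    ///
    /// # Example
    ///
    /// A small sketch contrasting the approximate and exact counts:
    ///
    /// ```
    /// use scry_index::LearnedMap;
    ///
    /// let map = LearnedMap::new();
    /// let guard = map.guard();
    /// map.insert(1u64, "a", &guard);
    /// map.insert(2, "b", &guard);
    /// assert_eq!(map.len(), 2);                 // approximate (relaxed load)
    /// assert_eq!(map.iter(&guard).count(), 2);  // exact count under the guard
    /// ```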
541    pub fn len(&self) -> usize {
542        self.len.load(Ordering::Relaxed)
543    }
544
545    /// Return `true` if the map contains no entries.
546    ///
547    /// Subject to the same relaxed-atomic staleness as [`len`](Self::len).
548    pub fn is_empty(&self) -> bool {
549        self.len() == 0
550    }
551
552    /// Iterate over all key-value pairs in sorted order.
553    ///
554    /// The returned references are valid for the lifetime of the guard.
555    pub fn iter<'g>(&self, guard: &'g Guard) -> Iter<'g, K, V> {
556        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
557        // SAFETY: root is always non-null.
558        let root = unsafe { root_shared.deref() };
559        Iter::with_hint(root, &guard.inner, self.len())
560    }
561
562    /// Collect all key-value pairs in sorted order (cloned).
563    ///
564    /// Performs a full traversal and clones all entries.
565    pub fn iter_sorted(&self, guard: &Guard) -> Vec<(K, V)> {
566        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
567        // SAFETY: root is always non-null.
568        let root = unsafe { root_shared.deref() };
569        iter::sorted_pairs(root, &guard.inner)
570    }
571
572    /// Return an iterator over key-value pairs within the given range.
573    ///
574    /// The iterator yields entries in ascending key order. Uses model-guided
575    /// seek for efficient initialization when a start bound is provided.
576    ///
577    /// Accepts any range syntax: `a..b`, `a..=b`, `a..`, `..b`, `..=b`, `..`.
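    ///
    /// # Example
    ///
    /// A small sketch over a bulk-loaded map:
    ///
    /// ```
    /// use scry_index::LearnedMap;
    ///
    /// let map = LearnedMap::bulk_load(&[(1u64, "a"), (2, "b"), (3, "c")]).unwrap();
    /// let guard = map.guard();
    /// assert_eq!(map.range(2..=3, &guard).count(), 2);
    /// assert_eq!(map.range(.., &guard).count(), 3);
    /// ```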
578    pub fn range<'g, R: RangeBounds<K>>(&self, range: R, guard: &'g Guard) -> Range<'g, K, V> {
579        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
580        // SAFETY: root is always non-null.
581        let root = unsafe { root_shared.deref() };
582        Range::new(root, range, &guard.inner)
583    }
584
585    /// Return the first (minimum) key-value pair in the map.
586    pub fn first_key_value<'g>(&self, guard: &'g Guard) -> Option<(&'g K, &'g V)> {
587        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
588        // SAFETY: root is always non-null.
589        let root = unsafe { root_shared.deref() };
590        iter::first_entry(root, &guard.inner)
591    }
592
593    /// Return the last (maximum) key-value pair in the map.
594    pub fn last_key_value<'g>(&self, guard: &'g Guard) -> Option<(&'g K, &'g V)> {
595        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
596        // SAFETY: root is always non-null.
597        let root = unsafe { root_shared.deref() };
598        iter::last_entry(root, &guard.inner)
599    }
600
601    /// Count the number of entries within the given range.
602    pub fn range_count<R: RangeBounds<K>>(&self, range: R, guard: &Guard) -> usize {
603        self.range(range, guard).count()
604    }
605
606    /// Estimate the total heap memory allocated by this map, in bytes.
607    ///
608    /// Walks the entire tree and sums up node structs, slot arrays, and
609    /// per-entry allocations. This is an approximation: it does not account
610    /// for allocator overhead, alignment padding, or epoch-deferred garbage.
611    pub fn allocated_bytes(&self, guard: &Guard) -> usize {
612        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
613        // SAFETY: root is always non-null.
614        let root = unsafe { root_shared.deref() };
615        root.allocated_bytes(&guard.inner)
616    }
617
618    /// Return the maximum depth of the tree.
619    ///
620    /// A well-fit model keeps depth low (typically 1-3 after bulk load).
621    pub fn max_depth(&self, guard: &Guard) -> usize {
622        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
623        // SAFETY: root is always non-null.
624        let root = unsafe { root_shared.deref() };
625        root.max_depth(&guard.inner)
626    }
627
628    /// Rebuild the tree from scratch using bulk load.
629    ///
630    /// Collects all key-value pairs, sorts them, and rebuilds with optimal
631    /// FMCD model fitting. This compacts the tree and restores O(1) lookups
632    /// after many incremental inserts.
633    ///
634    /// The root is *frozen* (tagged) during the rebuild so concurrent inserts
635    /// spin-wait and then retry against the new root. In-flight inserts that
636    /// loaded the old root before the freeze detect the change via post-insert
637    /// validation in [`LearnedMap::insert`] and retry automatically.
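    ///
    /// # Example
    ///
    /// A small sketch (mirrors the unit tests below): reverse-order inserts
    /// degrade the model, and a manual rebuild restores it.
    ///
    /// ```
    /// use scry_index::LearnedMap;
    ///
    /// let map = LearnedMap::new();
    /// let guard = map.guard();
    /// for i in (0..100u64).rev() {
    ///     map.insert(i, i * 10, &guard);
    /// }
    /// map.rebuild(&guard);
    /// // All entries survive the rebuild (use a fresh guard afterwards).
    /// let g2 = map.guard();
    /// assert_eq!(map.get(&50, &g2), Some(&500));
    /// assert_eq!(map.len(), 100);
    /// ```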
638    pub fn rebuild(&self, guard: &Guard) {
639        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
640        if root_shared.is_null() || root_shared.tag() != 0 {
641            return;
642        }
643        // SAFETY: root is not null.
644        let root = unsafe { root_shared.deref() };
645
646        // Freeze the root: tag it so concurrent inserts spin-wait instead
647        // of operating on the old tree during the rebuild.
648        let frozen = root_shared.with_tag(1);
649        if self
650            .root
651            .compare_exchange(
652                root_shared,
653                frozen,
654                Ordering::AcqRel,
655                Ordering::Acquire,
656                &guard.inner,
657            )
658            .is_err()
659        {
660            return;
661        }
662
663        let pairs = iter::sorted_pairs(root, &guard.inner);
664        if pairs.is_empty() {
665            // Unfreeze.
666            let _ = self.root.compare_exchange(
667                frozen,
668                root_shared,
669                Ordering::AcqRel,
670                Ordering::Relaxed,
671                &guard.inner,
672            );
673            return;
674        }
675
676        // Root rebuilds use range headroom so the model covers beyond the
677        // current key range. Without this, all keys inserted after the rebuild
678        // that exceed the training range clamp to the last slot.
679        let rebuild_config = Config {
680            range_headroom: 1.0,
681            ..self.config.clone()
682        };
683        // DEFENSIVE: sorted_pairs always produces sorted, non-empty data when
684        // the pairs.is_empty() check above passes. bulk_load only fails on
685        // EmptyData or NotSorted, neither of which can occur here.
686        let Ok(new_root) = build::bulk_load(&pairs, &rebuild_config) else {
687            // Unfreeze.
688            let _ = self.root.compare_exchange(
689                frozen,
690                root_shared,
691                Ordering::AcqRel,
692                Ordering::Relaxed,
693                &guard.inner,
694            );
695            return;
696        };
697        let new_owned = Owned::new(new_root);
698        if self
699            .root
700            .compare_exchange(
701                frozen,
702                new_owned,
703                Ordering::AcqRel,
704                Ordering::Acquire,
705                &guard.inner,
706            )
707            .is_ok()
708        {
709            // SAFETY: CAS succeeded; old root is unreachable to new readers.
710            // In-flight inserts detect the change via map::insert validation.
711            unsafe {
712                guard.inner.defer_destroy(root_shared);
713            }
714            // Don't reset len. It's maintained solely by map::insert
715            // (fetch_add) and map::remove (fetch_sub). Resetting via store
716            // would race with concurrent fetch_adds.
717            let count = pairs.len();
718            // Reset the rebuild schedule relative to the new tree's size.
719            self.next_root_rebuild.store(
720                count.saturating_mul(ROOT_REBUILD_GROWTH_FACTOR),
721                Ordering::Relaxed,
722            );
723        } else {
724            // DEFENSIVE: CAS on a frozen pointer we own should always succeed
725            // (no other code path modifies a frozen root). Unfreeze as safety net.
726            let _ = self.root.compare_exchange(
727                frozen,
728                root_shared,
729                Ordering::AcqRel,
730                Ordering::Relaxed,
731                &guard.inner,
732            );
733        }
734    }
735
736    /// Remove all entries from the map and return them as a sorted `Vec`.
737    ///
738    /// Uses the same freeze protocol as [`rebuild`](Self::rebuild) to coordinate
739    /// with concurrent inserts. After draining, the map has a fresh root identical
740    /// to one created by [`new`](Self::new).
741    ///
742    /// Returns an empty `Vec` if the map is already empty or if the freeze CAS
743    /// fails (another thread is rebuilding concurrently).
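    ///
    /// # Example
    ///
    /// A small sketch (mirrors the unit tests below):
    ///
    /// ```
    /// use scry_index::LearnedMap;
    ///
    /// let map = LearnedMap::new();
    /// let guard = map.guard();
    /// map.insert(2u64, "b", &guard);
    /// map.insert(1, "a", &guard);
    /// let drained = map.drain(&guard);
    /// assert_eq!(drained, vec![(1, "a"), (2, "b")]);
    /// assert!(map.is_empty());
    /// ```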
744    pub fn drain(&self, guard: &Guard) -> Vec<(K, V)> {
745        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
746        if root_shared.is_null() || root_shared.tag() != 0 {
747            return Vec::new();
748        }
749
750        // Freeze the root so concurrent inserts spin-wait.
751        let frozen = root_shared.with_tag(1);
752        if self
753            .root
754            .compare_exchange(
755                root_shared,
756                frozen,
757                Ordering::AcqRel,
758                Ordering::Acquire,
759                &guard.inner,
760            )
761            .is_err()
762        {
763            return Vec::new();
764        }
765
766        // SAFETY: root is not null (checked above) and frozen by us.
767        let root = unsafe { root_shared.deref() };
768        let pairs = iter::sorted_pairs(root, &guard.inner);
769
770        let new_root = Node::with_capacity(LinearModel::new(1.0, 0.0), 64);
771        let new_owned = Owned::new(new_root);
772        if self
773            .root
774            .compare_exchange(
775                frozen,
776                new_owned,
777                Ordering::AcqRel,
778                Ordering::Acquire,
779                &guard.inner,
780            )
781            .is_ok()
782        {
783            // SAFETY: CAS succeeded; old root is unreachable to new readers.
784            unsafe {
785                guard.inner.defer_destroy(root_shared);
786            }
787            // Subtract the drained count instead of storing zero. A plain
788            // store(0) would race with concurrent fetch_adds from inserts
789            // landing in the new root between the CAS above and this point.
790            self.len.fetch_sub(pairs.len(), Ordering::Relaxed);
791            self.next_root_rebuild
792                .store(INITIAL_ROOT_REBUILD_THRESHOLD, Ordering::Relaxed);
793        } else {
794            // DEFENSIVE: CAS on a frozen pointer we own should always succeed.
795            let _ = self.root.compare_exchange(
796                frozen,
797                root_shared,
798                Ordering::AcqRel,
799                Ordering::Relaxed,
800                &guard.inner,
801            );
802        }
803
804        pairs
805    }
806
807    /// Remove all entries from the map, resetting it to an empty state.
808    ///
809    /// Uses the same freeze protocol as [`rebuild`](Self::rebuild) to coordinate
810    /// with concurrent inserts. After clearing, the map has a fresh root identical
811    /// to one created by [`new`](Self::new).
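    ///
    /// # Example
    ///
    /// A small sketch:
    ///
    /// ```
    /// use scry_index::LearnedMap;
    ///
    /// let map = LearnedMap::new();
    /// let guard = map.guard();
    /// map.insert(1u64, "a", &guard);
    /// map.clear(&guard);
    /// let g2 = map.guard();
    /// assert!(map.is_empty());
    /// assert_eq!(map.get(&1, &g2), None);
    /// ```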
812    pub fn clear(&self, guard: &Guard) {
813        let root_shared = self.root.load(Ordering::Acquire, &guard.inner);
814        if root_shared.is_null() || root_shared.tag() != 0 {
815            return;
816        }
817
818        // Freeze the root so concurrent inserts spin-wait.
819        let frozen = root_shared.with_tag(1);
820        if self
821            .root
822            .compare_exchange(
823                root_shared,
824                frozen,
825                Ordering::AcqRel,
826                Ordering::Acquire,
827                &guard.inner,
828            )
829            .is_err()
830        {
831            return;
832        }
833
834        // Count live entries while the tree is frozen (no concurrent
835        // inserts/removes can proceed). Used for the fetch_sub below.
836        // SAFETY: root is non-null (checked above) and frozen by us.
837        let old_root = unsafe { root_shared.deref() };
838        let entry_count = Iter::new(old_root, &guard.inner).count();
839
840        let new_root = Node::with_capacity(LinearModel::new(1.0, 0.0), 64);
841        let new_owned = Owned::new(new_root);
842        if self
843            .root
844            .compare_exchange(
845                frozen,
846                new_owned,
847                Ordering::AcqRel,
848                Ordering::Acquire,
849                &guard.inner,
850            )
851            .is_ok()
852        {
853            // SAFETY: CAS succeeded; old root is unreachable to new readers.
854            unsafe {
855                guard.inner.defer_destroy(root_shared);
856            }
857            // Subtract the cleared count instead of storing zero. A plain
858            // store(0) would race with concurrent fetch_adds from inserts
859            // landing in the new root between the CAS above and this point.
860            self.len.fetch_sub(entry_count, Ordering::Relaxed);
861            self.next_root_rebuild
862                .store(INITIAL_ROOT_REBUILD_THRESHOLD, Ordering::Relaxed);
863        } else {
864            // DEFENSIVE: CAS on a frozen pointer we own should always succeed.
865            let _ = self.root.compare_exchange(
866                frozen,
867                root_shared,
868                Ordering::AcqRel,
869                Ordering::Relaxed,
870                &guard.inner,
871            );
872        }
873    }
874}
875
876#[cfg(feature = "serde")]
877impl<K, V> serde::Serialize for LearnedMap<K, V>
878where
879    K: Key + serde::Serialize,
880    V: Clone + Send + Sync + serde::Serialize,
881{
882    fn serialize<S: serde::Serializer>(
883        &self,
884        serializer: S,
885    ) -> std::result::Result<S::Ok, S::Error> {
886        use serde::ser::SerializeSeq;
887
888        let guard = self.guard();
889        let len = self.len();
890        let mut seq = serializer.serialize_seq(Some(len))?;
891        for (k, v) in self.iter(&guard) {
892            seq.serialize_element(&(k, v))?;
893        }
894        seq.end()
895    }
896}
897
898#[cfg(feature = "serde")]
899impl<'de, K, V> serde::Deserialize<'de> for LearnedMap<K, V>
900where
901    K: Key + serde::Deserialize<'de>,
902    V: Clone + Send + Sync + serde::Deserialize<'de>,
903{
904    fn deserialize<D: serde::Deserializer<'de>>(
905        deserializer: D,
906    ) -> std::result::Result<Self, D::Error> {
907        let pairs: Vec<(K, V)> = Vec::deserialize(deserializer)?;
908        if pairs.is_empty() {
909            return Ok(Self::new());
910        }
911        Self::bulk_load_dedup(&pairs).map_err(serde::de::Error::custom)
912    }
913}
914
915impl<K: Key, V: Clone + Send + Sync> Default for LearnedMap<K, V> {
916    fn default() -> Self {
917        Self::new()
918    }
919}
920
921impl<K: Key, V: Clone + Send + Sync> Extend<(K, V)> for LearnedMap<K, V> {
922    fn extend<I: IntoIterator<Item = (K, V)>>(&mut self, iter: I) {
923        let guard = self.guard();
924        for (k, v) in iter {
925            self.insert(k, v, &guard);
926        }
927    }
928}
929
930/// Note: `from_iter` inserts elements one at a time into an empty map.
931/// For pre-sorted data, [`LearnedMap::bulk_load`] is significantly faster
932/// because it fits an optimal model in a single pass rather than
933/// building incrementally with conflict resolution.
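///
/// A small sketch of the trade-off:
///
/// ```
/// use scry_index::LearnedMap;
///
/// // from_iter: accepts unsorted input, inserting one entry at a time.
/// let a: LearnedMap<u64, &str> = vec![(3, "c"), (1, "a"), (2, "b")].into_iter().collect();
/// // bulk_load: preferred when the data is already sorted.
/// let b = LearnedMap::bulk_load(&[(1u64, "a"), (2, "b"), (3, "c")]).unwrap();
/// assert_eq!(a.len(), b.len());
/// ```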
934impl<K: Key, V: Clone + Send + Sync> FromIterator<(K, V)> for LearnedMap<K, V> {
935    fn from_iter<I: IntoIterator<Item = (K, V)>>(iter: I) -> Self {
936        let map = Self::new();
937        let guard = map.guard();
938        for (k, v) in iter {
939            map.insert(k, v, &guard);
940        }
941        map
942    }
943}
944
945impl<K: Key, V> Drop for LearnedMap<K, V> {
946    fn drop(&mut self) {
947        // We must defer destruction of the root rather than freeing immediately.
948        // Our Guard type does not borrow the map, so a caller can hold a Guard
949        // (and references from get()) after the map is dropped. Using
950        // defer_destroy ensures the tree lives until all such guards are gone.
951        //
952        // SAFETY: We pin the current epoch and schedule the root for deferred
953        // destruction. Crossbeam guarantees the root won't be freed until all
954        // guards active at this epoch are dropped. Node::drop (which uses
955        // unprotected + into_owned to free children) is safe at that point
956        // because no guard can still reference the tree.
957        unsafe {
958            let guard = epoch::pin();
959            let shared = self.root.load(Ordering::Relaxed, &guard);
960            if !shared.is_null() {
961                guard.defer_destroy(shared);
962            }
963        }
964    }
965}
966
967// SAFETY: LearnedMap is Send+Sync when K and V are Send+Sync. All interior
968// mutation goes through atomic operations and epoch-based reclamation.
969unsafe impl<K: Key, V: Send + Sync> Send for LearnedMap<K, V> {}
970unsafe impl<K: Key, V: Send + Sync> Sync for LearnedMap<K, V> {}
971
972#[cfg(test)]
973mod tests {
974    use super::*;
975
976    #[test]
977    fn new_map_is_empty() {
978        let map = LearnedMap::<u64, ()>::new();
979        assert!(map.is_empty());
980        assert_eq!(map.len(), 0);
981    }
982
983    #[test]
984    fn insert_and_get() {
985        let map = LearnedMap::new();
986        let g = map.guard();
987        assert!(map.insert(42u64, "hello", &g));
988        assert_eq!(map.get(&42, &g), Some(&"hello"));
989        assert_eq!(map.len(), 1);
990    }
991
992    #[test]
993    fn insert_duplicate_updates() {
994        let map = LearnedMap::new();
995        let g = map.guard();
996        assert!(map.insert(1u64, "one", &g));
997        assert!(!map.insert(1, "ONE", &g));
998        assert_eq!(map.get(&1, &g), Some(&"ONE"));
999        assert_eq!(map.len(), 1);
1000    }
1001
1002    #[test]
1003    fn remove_existing() {
1004        let map = LearnedMap::new();
1005        let g = map.guard();
1006        map.insert(1u64, "a", &g);
1007        map.insert(2, "b", &g);
1008        assert!(map.remove(&1, &g));
1009        assert_eq!(map.len(), 1);
1010        assert!(!map.contains_key(&1, &g));
1011        assert!(map.contains_key(&2, &g));
1012    }
1013
1014    #[test]
1015    fn remove_missing() {
1016        let map = LearnedMap::new();
1017        let g = map.guard();
1018        map.insert(1u64, "a", &g);
1019        assert!(!map.remove(&99, &g));
1020        assert_eq!(map.len(), 1);
1021    }
1022
1023    #[test]
1024    fn bulk_load_basic() {
1025        let pairs: Vec<(u64, u64)> = (0..100).map(|i| (i, i * 10)).collect();
1026        let map = LearnedMap::bulk_load(&pairs).unwrap();
1027        let g = map.guard();
1028        assert_eq!(map.len(), 100);
1029        for (k, v) in &pairs {
1030            assert_eq!(map.get(k, &g), Some(v));
1031        }
1032    }
1033
1034    #[test]
1035    fn bulk_load_then_insert() {
1036        let pairs: Vec<(u64, u64)> = vec![(10, 1), (20, 2), (30, 3)];
1037        let map = LearnedMap::bulk_load(&pairs).unwrap();
1038        let g = map.guard();
1039        map.insert(15, 15, &g);
1040        map.insert(25, 25, &g);
1041        assert_eq!(map.len(), 5);
1042        assert_eq!(map.get(&15, &g), Some(&15));
1043        assert_eq!(map.get(&25, &g), Some(&25));
1044    }
1045
1046    #[test]
1047    fn bulk_load_dedup_keeps_last() {
1048        let pairs: Vec<(u64, &str)> = vec![(1, "a"), (1, "A"), (2, "b"), (3, "c"), (3, "C")];
1049        let map = LearnedMap::bulk_load_dedup(&pairs).unwrap();
1050        let g = map.guard();
1051        assert_eq!(map.len(), 3);
1052        assert_eq!(map.get(&1, &g), Some(&"A"));
1053        assert_eq!(map.get(&2, &g), Some(&"b"));
1054        assert_eq!(map.get(&3, &g), Some(&"C"));
1055    }
1056
1057    #[test]
1058    fn bulk_load_dedup_no_duplicates() {
1059        let pairs: Vec<(u64, u64)> = (0..50).map(|i| (i, i * 10)).collect();
1060        let map = LearnedMap::bulk_load_dedup(&pairs).unwrap();
1061        let g = map.guard();
1062        assert_eq!(map.len(), 50);
1063        for (k, v) in &pairs {
1064            assert_eq!(map.get(k, &g), Some(v));
1065        }
1066    }
1067
1068    #[test]
1069    fn bulk_load_dedup_all_same_key() {
1070        let pairs: Vec<(u64, u64)> = (0..10).map(|i| (42, i)).collect();
1071        let map = LearnedMap::bulk_load_dedup(&pairs).unwrap();
1072        let g = map.guard();
1073        assert_eq!(map.len(), 1);
1074        assert_eq!(map.get(&42, &g), Some(&9));
1075    }
1076
1077    #[test]
1078    fn bulk_load_dedup_empty() {
1079        let result = LearnedMap::<u64, u64>::bulk_load_dedup(&[]);
1080        assert!(result.is_err());
1081    }
1082
1083    #[test]
1084    fn bulk_load_dedup_not_sorted() {
1085        let pairs: Vec<(u64, u64)> = vec![(3, 0), (1, 0), (2, 0)];
1086        let result = LearnedMap::bulk_load_dedup(&pairs);
1087        assert!(result.is_err());
1088    }
1089
1090    #[test]
1091    fn from_iterator() {
1092        let map: LearnedMap<u64, &str> = vec![(1, "a"), (2, "b"), (3, "c")].into_iter().collect();
1093        let g = map.guard();
1094        assert_eq!(map.len(), 3);
1095        assert_eq!(map.get(&2, &g), Some(&"b"));
1096    }
1097
1098    #[test]
1099    fn extend_map() {
1100        let mut map = LearnedMap::new();
1101        {
1102            let g = map.guard();
1103            map.insert(1u64, 10, &g);
1104        }
1105        map.extend(vec![(2, 20), (3, 30)]);
1106        assert_eq!(map.len(), 3);
1107    }
1108
1109    #[test]
1110    fn iter_sorted_order() {
1111        let map = LearnedMap::new();
1112        let g = map.guard();
1113        map.insert(30u64, "c", &g);
1114        map.insert(10, "a", &g);
1115        map.insert(20, "b", &g);
1116
1117        let items: Vec<(u64, &str)> = map.iter_sorted(&g);
1118        assert_eq!(items, vec![(10, "a"), (20, "b"), (30, "c")]);
1119    }
1120
1121    #[test]
1122    fn max_depth_bounded() {
1123        let pairs: Vec<(u64, u64)> = (0..1000).map(|i| (i, i)).collect();
1124        let map = LearnedMap::bulk_load(&pairs).unwrap();
1125        let g = map.guard();
1126        assert!(
1127            map.max_depth(&g) <= 5,
1128            "depth {} is too high for 1000 sequential keys",
1129            map.max_depth(&g)
1130        );
1131    }
1132
1133    #[test]
1134    fn stress_insert_lookup_remove() {
1135        let map = LearnedMap::new();
1136        let g = map.guard();
1137        let n = 500u64;
1138
1139        for i in 0..n {
1140            map.insert(i * 3, i, &g);
1141        }
1142        assert_eq!(map.len(), n as usize);
1143
1144        for i in 0..n {
1145            assert_eq!(map.get(&(i * 3), &g), Some(&i), "key {} missing", i * 3);
1146        }
1147
1148        for i in (0..n).filter(|i| i % 2 == 0) {
1149            map.remove(&(i * 3), &g);
1150        }
1151        assert_eq!(map.len(), (n / 2) as usize);
1152
1153        for i in (0..n).filter(|i| i % 2 != 0) {
1154            assert_eq!(map.get(&(i * 3), &g), Some(&i));
1155        }
1156    }
1157
1158    #[test]
1159    fn manual_rebuild() {
1160        let map = LearnedMap::new();
1161        let g = map.guard();
1162        for i in (0..100u64).rev() {
1163            map.insert(i, i * 10, &g);
1164        }
1165        let depth_before = map.max_depth(&g);
1166        map.rebuild(&g);
1167        let depth_after = map.max_depth(&g);
1168        assert!(
1169            depth_after <= depth_before,
1170            "rebuild didn't help: {depth_before} -> {depth_after}"
1171        );
1172        // All keys still present (need fresh guard after rebuild)
1173        let g2 = map.guard();
1174        for i in 0..100u64 {
1175            assert_eq!(map.get(&i, &g2), Some(&(i * 10)));
1176        }
1177    }
1178
1179    #[test]
1180    fn rebuild_empty_is_noop() {
1181        let map = LearnedMap::<u64, u64>::new();
1182        let g = map.guard();
1183        map.rebuild(&g);
1184        assert!(map.is_empty());
1185    }
1186
1187    #[test]
1188    fn large_incremental_insert() {
1189        let map = LearnedMap::new();
1190        let g = map.guard();
1191        for i in 0..1000u64 {
1192            map.insert(i, i, &g);
1193        }
1194        assert_eq!(map.len(), 1000);
1195        for i in 0..1000u64 {
1196            assert_eq!(map.get(&i, &g), Some(&i));
1197        }
1198    }
1199
1200    #[test]
1201    fn pin_convenience() {
1202        let map = LearnedMap::new();
1203        let m = map.pin();
1204        m.insert(1u64, "one");
1205        m.insert(2, "two");
1206        assert_eq!(m.get(&1), Some(&"one"));
1207        assert_eq!(m.get(&2), Some(&"two"));
1208        assert_eq!(m.len(), 2);
1209        assert!(!m.is_empty());
1210    }
1211
1212    #[test]
1213    fn map_ref_remove() {
1214        let map = LearnedMap::new();
1215        let m = map.pin();
1216        m.insert(10u64, 100);
1217        m.insert(20, 200);
1218        assert!(m.remove(&10));
1219        assert!(!m.remove(&10));
1220        assert_eq!(m.len(), 1);
1221        assert!(m.contains_key(&20));
1222    }
1223
1224    #[test]
1225    fn map_ref_iter_sorted() {
1226        let map = LearnedMap::new();
1227        let m = map.pin();
1228        m.insert(3u64, "c");
1229        m.insert(1, "a");
1230        m.insert(2, "b");
1231        let items = m.iter_sorted();
1232        assert_eq!(items, vec![(1, "a"), (2, "b"), (3, "c")]);
1233    }
1234
1235    #[test]
1236    fn auto_root_rebuild_from_empty() {
1237        let map = LearnedMap::new();
1238        let g = map.guard();
1239        for i in 0..200u64 {
1240            map.insert(i, i, &g);
1241        }
1242        let g2 = map.guard();
1243        let depth = map.max_depth(&g2);
1244        assert!(
1245            depth <= 12,
1246            "depth {depth} too high after auto root rebuild"
1247        );
1248        for i in 0..200u64 {
1249            assert_eq!(map.get(&i, &g2), Some(&i), "key {i} missing");
1250        }
1251    }
1252
1253    #[test]
1254    fn auto_root_rebuild_disabled() {
1255        let map = LearnedMap::with_config(Config::new().auto_rebuild(false));
1256        let g = map.guard();
1257        for i in 0..200u64 {
1258            map.insert(i, i, &g);
1259        }
1260        let depth = map.max_depth(&g);
1261        assert!(depth > 5, "depth {depth} too low without auto rebuild");
1262    }
1263
1264    #[test]
1265    fn bulk_load_no_early_rebuild() {
1266        let pairs: Vec<(u64, u64)> = (0..100).map(|i| (i, i)).collect();
1267        let map = LearnedMap::bulk_load(&pairs).unwrap();
1268        let g = map.guard();
1269        let depth = map.max_depth(&g);
1270        assert!(depth <= 3, "bulk-loaded tree depth {depth} too high");
1271        assert_eq!(map.len(), 100);
1272    }
1273
1274    #[test]
1275    fn manual_rebuild_resets_threshold() {
1276        let map = LearnedMap::new();
1277        let g = map.guard();
1278        for i in 0..50u64 {
1279            map.insert(i, i, &g);
1280        }
1281        map.rebuild(&g);
1282        let g2 = map.guard();
1283        for i in 50..150u64 {
1284            map.insert(i, i, &g2);
1285        }
1286        assert_eq!(map.len(), 150);
1287        for i in 0..150u64 {
1288            assert_eq!(map.get(&i, &g2), Some(&i));
1289        }
1290    }
1291
1292    #[test]
1293    fn clear_empties_map() {
1294        let map = LearnedMap::new();
1295        let g = map.guard();
1296        for i in 0..100u64 {
1297            map.insert(i, i, &g);
1298        }
1299        assert_eq!(map.len(), 100);
1300        map.clear(&g);
1301        let g2 = map.guard();
1302        assert_eq!(map.len(), 0);
1303        assert!(map.is_empty());
1304        for i in 0..100u64 {
1305            assert_eq!(map.get(&i, &g2), None);
1306        }
1307    }
1308
1309    #[test]
1310    fn clear_then_reinsert() {
1311        let map = LearnedMap::new();
1312        let g = map.guard();
1313        for i in 0..50u64 {
1314            map.insert(i, i * 10, &g);
1315        }
1316        map.clear(&g);
1317        let g2 = map.guard();
1318        for i in 0..30u64 {
1319            map.insert(i + 100, i, &g2);
1320        }
1321        assert_eq!(map.len(), 30);
1322        assert_eq!(map.get(&100, &g2), Some(&0));
1323        assert_eq!(map.get(&0, &g2), None);
1324    }
1325
1326    #[test]
1327    fn clear_empty_is_noop() {
1328        let map = LearnedMap::<u64, u64>::new();
1329        let g = map.guard();
1330        map.clear(&g);
1331        assert!(map.is_empty());
1332    }
1333
1334    #[test]
1335    fn map_ref_clear() {
1336        let map = LearnedMap::new();
1337        let m = map.pin();
1338        m.insert(1u64, "a");
1339        m.insert(2, "b");
1340        assert_eq!(m.len(), 2);
1341        m.clear();
1342        assert!(m.is_empty());
1343        assert_eq!(m.get(&1), None);
1344    }
1345
1346    #[test]
1347    fn drain_returns_sorted_entries() {
1348        let map = LearnedMap::new();
1349        let g = map.guard();
1350        for i in (0..50u64).rev() {
1351            map.insert(i, i * 10, &g);
1352        }
1353        assert_eq!(map.len(), 50);
1354        let drained = map.drain(&g);
1355        assert_eq!(drained.len(), 50);
1356        // Verify sorted order
1357        for w in drained.windows(2) {
1358            assert!(w[0].0 < w[1].0);
1359        }
1360        // Verify contents
1361        for (i, (k, v)) in drained.iter().enumerate() {
1362            assert_eq!(*k, i as u64);
1363            assert_eq!(*v, (i as u64) * 10);
1364        }
1365        // Map should be empty after drain
1366        let g2 = map.guard();
1367        assert!(map.is_empty());
1368        assert_eq!(map.get(&0, &g2), None);
1369    }
1370
1371    #[test]
1372    fn drain_empty_returns_empty() {
1373        let map = LearnedMap::<u64, u64>::new();
1374        let g = map.guard();
1375        let drained = map.drain(&g);
1376        assert!(drained.is_empty());
1377        assert!(map.is_empty());
1378    }
1379
1380    #[test]
1381    fn drain_then_reinsert() {
1382        let map = LearnedMap::new();
1383        let g = map.guard();
1384        for i in 0..30u64 {
1385            map.insert(i, i, &g);
1386        }
1387        let drained = map.drain(&g);
1388        assert_eq!(drained.len(), 30);
1389        let g2 = map.guard();
1390        for i in 100..110u64 {
1391            map.insert(i, i, &g2);
1392        }
1393        assert_eq!(map.len(), 10);
1394        assert_eq!(map.get(&100, &g2), Some(&100));
1395        assert_eq!(map.get(&0, &g2), None);
1396    }
1397
1398    #[test]
1399    fn map_ref_drain() {
1400        let map = LearnedMap::new();
1401        let m = map.pin();
1402        m.insert(3u64, "c");
1403        m.insert(1, "a");
1404        m.insert(2, "b");
1405        let drained = m.drain();
1406        assert_eq!(drained, vec![(1, "a"), (2, "b"), (3, "c")]);
1407        assert!(m.is_empty());
1408    }
1409
1410    #[test]
1411    fn allocated_bytes_empty() {
1412        let map = LearnedMap::<u64, u64>::new();
1413        let g = map.guard();
1414        let bytes = map.allocated_bytes(&g);
1415        // An empty map still has the root node + slot array.
1416        assert!(bytes > 0, "empty map should have non-zero allocation");
1417    }
1418
1419    #[test]
1420    fn allocated_bytes_grows_with_entries() {
1421        let map = LearnedMap::new();
1422        let g = map.guard();
1423        let empty_bytes = map.allocated_bytes(&g);
1424
1425        for i in 0..100u64 {
1426            map.insert(i, i, &g);
1427        }
1428        let g2 = map.guard();
1429        let full_bytes = map.allocated_bytes(&g2);
1430        assert!(
1431            full_bytes > empty_bytes,
1432            "100 entries should use more memory than empty: {full_bytes} vs {empty_bytes}"
1433        );
1434    }
1435
1436    #[test]
1437    fn allocated_bytes_bulk_load() {
1438        let pairs: Vec<(u64, u64)> = (0..500).map(|i| (i, i)).collect();
1439        let map = LearnedMap::bulk_load(&pairs).unwrap();
1440        let g = map.guard();
1441        let bytes = map.allocated_bytes(&g);
1442        // Sanity: at minimum each entry occupies size_of key + value.
1443        let min_data_bytes = 500 * std::mem::size_of::<u64>() * 2;
1444        assert!(
1445            bytes > min_data_bytes,
1446            "allocated_bytes {bytes} is less than minimum data size {min_data_bytes}"
1447        );
1448    }
1449
1450    #[test]
1451    fn map_ref_allocated_bytes() {
1452        let map = LearnedMap::new();
1453        let m = map.pin();
1454        m.insert(1u64, 1u64);
1455        m.insert(2, 2);
1456        let bytes = m.allocated_bytes();
1457        assert!(bytes > 0);
1458    }
1459}