digit_bin_index/lib.rs
1//! A `DigitBinIndex` is a tree-based data structure that organizes a large
2//! collection of weighted items to enable highly efficient weighted random
3//! selection and removal.
4//!
5//! It is a specialized tool, purpose-built for scenarios with millions of
6//! items where probabilities are approximate and high performance is critical,
7//! particularly for simulations involving sequential sampling like Wallenius'
8//! noncentral hypergeometric distribution.
9
10use wyrand::WyRand;
11use rand::{distr::{Distribution, Uniform}, Rng, SeedableRng};
12use roaring::{RoaringBitmap, RoaringTreemap};
13
/// Number of decimal digits used to bin weights when no precision is given.
const DEFAULT_PRECISION: u8 = 3;
/// Largest supported precision. Also the length of the fixed digit buffers
/// used while walking the tree (one slot per tree level).
const MAX_PRECISION: usize = 9;
17
18/// Trait for types that can be used as leaf bins in a `DigitBinIndex`.
19///
20/// Implement this trait for any container you want to use for storing IDs in the leaf nodes.
21/// Provided implementations: [`Vec<u32>`], [`RoaringBitmap`].
22pub trait DigitBin: Clone + Default {
23 fn insert(&mut self, id: u64);
24 fn remove(&mut self, id: u64) -> bool;
25 fn len(&self) -> usize;
26 fn is_empty(&self) -> bool;
27 fn get_random(&self, rng: &mut impl rand::Rng) -> Option<u64>;
28 fn get_random_and_remove(&mut self, rng: &mut impl rand::Rng) -> Option<u64>;
29}
30
31impl DigitBin for Vec<u32> {
32 fn insert(&mut self, id: u64) { self.push(id as u32); }
33 fn remove(&mut self, id: u64) -> bool {
34 if let Some(pos) = self.iter().position(|&x| x == id as u32) {
35 self.swap_remove(pos);
36 true
37 } else {
38 false
39 }
40 }
41 fn len(&self) -> usize { self.len() }
42 fn is_empty(&self) -> bool { self.is_empty() }
43 fn get_random(&self, rng: &mut impl rand::Rng) -> Option<u64> {
44 if self.is_empty() { None } else { Some(self[rng.random_range(0..self.len())] as u64) }
45 }
46 fn get_random_and_remove(&mut self, rng: &mut impl rand::Rng) -> Option<u64> {
47 if self.is_empty() { None } else {
48 let pos = rng.random_range(0..self.len());
49 Some(self.swap_remove(pos) as u64)
50 }
51 }
52}
53
54impl DigitBin for RoaringBitmap {
55 fn insert(&mut self, id: u64) { self.insert(id as u32); }
56 fn remove(&mut self, id: u64) -> bool { self.remove(id as u32) }
57 fn len(&self) -> usize { self.len() as usize }
58 fn is_empty(&self) -> bool { self.is_empty() }
59 fn get_random(&self, rng: &mut impl rand::Rng) -> Option<u64> {
60 if self.is_empty() { None } else {
61 let idx = rng.random_range(0..self.len() as u32);
62 self.select(idx).map(|v| v as u64)
63 }
64 }
65 fn get_random_and_remove(&mut self, rng: &mut impl rand::Rng) -> Option<u64> {
66 if self.is_empty() { None } else {
67 let idx = rng.random_range(0..self.len() as u32);
68 let selected = self.select(idx);
69 self.remove(selected.unwrap());
70 selected.map(|v| v as u64)
71 }
72 }
73}
74
75impl DigitBin for RoaringTreemap {
76 fn insert(&mut self, id: u64) { self.insert(id); }
77 fn remove(&mut self, id: u64) -> bool { self.remove(id) }
78 fn len(&self) -> usize { self.len() as usize }
79 fn is_empty(&self) -> bool { self.is_empty() }
80 fn get_random(&self, rng: &mut impl rand::Rng) -> Option<u64> {
81 if self.is_empty() { None } else {
82 let idx = rng.random_range(0..self.len() as u64);
83 self.select(idx)
84 }
85 }
86 fn get_random_and_remove(&mut self, rng: &mut impl rand::Rng) -> Option<u64> {
87 if self.is_empty() { None } else {
88 let idx = rng.random_range(0..self.len());
89 let selected = self.select(idx);
90 self.remove(selected.unwrap());
91 selected
92 }
93 }
94}
95
96// Helper to create an array of Option<T>
97fn new_children_array<B: DigitBin>() -> Box<[Option<Node<B>>; 10]> {
98 // This is a standard way to initialize an array of non-Copy types.
99 let data: [Option<Node<B>>; 10] = Default::default();
100 Box::new(data)
101}
102
103/// The content of a node, which is either more nodes or a leaf with individuals.
104#[derive(Debug, Clone)]
105pub enum NodeContent<B: DigitBin> {
106 /// An internal node that contains children for the next digit (0-9).
107 DigitIndex(Box<[Option<Node<B>>; 10]>),
108 /// A leaf node that contains a bin of IDs for individuals in this bin.
109 Bin(B),
110}
111
112/// A node within the DigitBinIndex tree.
113#[derive(Debug, Clone)]
114pub struct Node<B: DigitBin> {
115 /// The content of this node, either more nodes or a list of individual IDs.
116 pub content: NodeContent<B>,
117 /// The total sum of scaled values stored under this node.
118 pub accumulated_value: u64,
119 /// The total count of individuals stored under this node.
120 pub content_count: u64,
121}
122
123impl<B: DigitBin> Node<B> {
124 /// Creates a new, empty internal node.
125 fn new_internal() -> Self {
126 Self {
127 content: NodeContent::DigitIndex(new_children_array()),
128 accumulated_value: 0u64,
129 content_count: 0,
130 }
131 }
132}
133
134/// A data structure that organizes weighted items into bins based on their
135/// decimal digits to enable fast weighted random selection and updates.
136///
137/// This structure is a specialized radix tree optimized for sequential sampling
138/// (like in Wallenius' distribution). It makes a deliberate engineering trade-off:
139/// it sacrifices a small, controllable amount of precision by binning items,
140/// but in return, it achieves O(P) performance for its core operations, where P
141/// is the configured precision. This is significantly faster than the O(log N)
142/// performance of general-purpose structures like a Fenwick Tree for its
143/// ideal use case.
144///
145/// # Examples
146///
147/// ```
148/// use digit_bin_index::DigitBinIndex;
149/// let mut index = DigitBinIndex::with_precision_and_capacity(3, 100);
150/// ```
151#[derive(Debug, Clone)]
152pub enum DigitBinIndex {
153 Small(DigitBinIndexGeneric<Vec<u32>>),
154 Medium(DigitBinIndexGeneric<RoaringBitmap>),
155 Large(DigitBinIndexGeneric<RoaringTreemap>),
156}
157
158impl DigitBinIndex {
159 /// Creates a new DigitBinIndex with the given precision and expected capacity.
160 /// Uses Vec<u32> for small bins, RoaringBitmap for large bins.
161 ///
162 /// # Arguments
163 ///
164 /// * `precision` - The number of decimal places for binning (1 to 9).
165 /// * `capacity` - The expected number of items to be stored in the index.
166 ///
167 /// # Returns
168 ///
169 /// A new `DigitBinIndex` instance with the appropriate bin type.
170 ///
171 /// # Panics
172 ///
173 /// Panics if `precision` is 0 or greater than 9.
174 ///
175 /// # Examples
176 ///
177 /// ```
178 /// use digit_bin_index::DigitBinIndex;
179 ///
180 /// let index = DigitBinIndex::with_precision_and_capacity(3, 100);
181 /// // Uses Vec<u32> because capacity is small
182 /// ```
183 pub fn with_precision_and_capacity(precision: u8, capacity: u64) -> Self {
184 let max_bins = 10u64.pow(precision as u32);
185 if capacity / max_bins > 1_000_000_000 {
186 // Heuristic: Use RoaringTreemap if average bin size (capacity / 10^precision) exceeds threshold
187 DigitBinIndex::Large(DigitBinIndexGeneric::<RoaringTreemap>::with_precision(precision))
188 }
189 else if capacity / max_bins > 1_000 {
190 // Heuristic: Use RoaringBitmap if average bin size (capacity / 10^precision) exceeds threshold
191 DigitBinIndex::Medium(DigitBinIndexGeneric::<RoaringBitmap>::with_precision(precision))
192 } else {
193 // Heuristic: Use Vec<u32> for small average bin sizes
194 DigitBinIndex::Small(DigitBinIndexGeneric::<Vec<u32>>::with_precision(precision))
195 }
196 }
197
198 /// Creates a new DigitBinIndex with Vec<u32> bins and the specified precision.
199 ///
200 /// Optimized for small to medium-sized problems (average <= 1,000 items per bin).
201 /// Provides the fastest O(1) select_and_remove performance but truncates u64 IDs to u32.
202 ///
203 /// # Arguments
204 ///
205 /// * `precision` - The number of decimal places for binning (1 to 9).
206 ///
207 /// # Returns
208 ///
209 /// A new `DigitBinIndex` instance with Vec<u32> bins.
210 ///
211 /// # Panics
212 ///
213 /// Panics if `precision` is 0 or greater than 9.
214 ///
215 /// # Examples
216 ///
217 /// ```
218 /// use digit_bin_index::DigitBinIndex;
219 ///
220 /// let index = DigitBinIndex::small(3);
221 /// assert_eq!(index.precision(), 3);
222 /// ```
223 pub fn small(precision: u8) -> Self {
224 DigitBinIndex::Small(DigitBinIndexGeneric::<Vec<u32>>::with_precision(precision))
225 }
226
227 /// Creates a new DigitBinIndex with RoaringBitmap bins and the specified precision.
228 ///
229 /// Optimized for large-scale problems (average > 1,000 items per bin) where IDs fit within u32.
230 /// Provides excellent memory compression and fast set operations but truncates u64 IDs to u32.
231 ///
232 /// # Arguments
233 ///
234 /// * `precision` - The number of decimal places for binning (1 to 9).
235 ///
236 /// # Returns
237 ///
238 /// A new `DigitBinIndex` instance with RoaringBitmap bins.
239 ///
240 /// # Panics
241 ///
242 /// Panics if `precision` is 0 or greater than 9.
243 ///
244 /// # Examples
245 ///
246 /// ```
247 /// use digit_bin_index::DigitBinIndex;
248 ///
249 /// let index = DigitBinIndex::medium(3);
250 /// assert_eq!(index.precision(), 3);
251 /// ```
252 pub fn medium(precision: u8) -> Self {
253 DigitBinIndex::Medium(DigitBinIndexGeneric::<RoaringBitmap>::with_precision(precision))
254 }
255
256 /// Creates a new DigitBinIndex with RoaringTreemap bins and the specified precision.
257 ///
258 /// Optimized for massive-scale problems requiring full u64 ID support (average > 1,000,000,000 items per bin).
259 /// Supports the full 64-bit ID space, ideal for extremely large datasets.
260 ///
261 /// # Arguments
262 ///
263 /// * `precision` - The number of decimal places for binning (1 to 9).
264 ///
265 /// # Returns
266 ///
267 /// A new `DigitBinIndex` instance with RoaringTreemap bins.
268 ///
269 /// # Panics
270 ///
271 /// Panics if `precision` is 0 or greater than 9.
272 ///
273 /// # Examples
274 ///
275 /// ```
276 /// use digit_bin_index::DigitBinIndex;
277 ///
278 /// let index = DigitBinIndex::large(3);
279 /// assert_eq!(index.precision(), 3);
280 /// ```
281 pub fn large(precision: u8) -> Self {
282 DigitBinIndex::Large(DigitBinIndexGeneric::<RoaringTreemap>::with_precision(precision))
283 }
284
285 /// Creates a new `DigitBinIndex` instance with the default precision.
286 ///
287 /// The default precision is set to 3 decimal places, which provides a good balance
288 /// between accuracy and performance for most use cases. For custom precision, use
289 /// [`with_precision`](Self::with_precision).
290 ///
291 /// # Returns
292 ///
293 /// A new `DigitBinIndex` instance.
294 ///
295 /// # Examples
296 ///
297 /// ```
298 /// use digit_bin_index::DigitBinIndex;
299 ///
300 /// let index = DigitBinIndex::new();
301 /// assert_eq!(index.precision(), 3);
302 /// ```
303 pub fn new() -> Self {
304 DigitBinIndex::Small(DigitBinIndexGeneric::<Vec<u32>>::new())
305 }
306
307 /// Creates a new `DigitBinIndex` instance with the specified precision.
308 ///
309 /// The precision determines the number of decimal places used for binning weights.
310 /// Higher precision improves sampling accuracy but increases memory usage and tree depth.
311 /// Precision must be between 1 and 9 (inclusive).
312 ///
313 /// # Arguments
314 ///
315 /// * `precision` - The number of decimal places for binning (1 to 9).
316 ///
317 /// # Returns
318 ///
319 /// A new `DigitBinIndex` instance with the given precision.
320 ///
321 /// # Panics
322 ///
323 /// Panics if `precision` is 0 or greater than 9.
324 ///
325 /// # Examples
326 ///
327 /// ```
328 /// use digit_bin_index::DigitBinIndex;
329 ///
330 /// let index = DigitBinIndex::with_precision(4);
331 /// assert_eq!(index.precision(), 4);
332 /// ```
333 pub fn with_precision(precision: u8) -> Self {
334 DigitBinIndex::Small(DigitBinIndexGeneric::<Vec<u32>>::with_precision(precision))
335 }
336
337 /// Adds an item with the given ID and weight to the index.
338 ///
339 /// The weight is rescaled to the index's precision and binned accordingly.
340 /// If the weight is non-positive or becomes zero after scaling, the item is not added.
341 ///
342 /// # Arguments
343 ///
344 /// * `individual_id` - The unique ID of the item to add (u32).
345 /// * `weight` - The positive weight (probability) of the item.
346 ///
347 /// # Returns
348 ///
349 /// `true` if the item was successfully added, `false` otherwise (e.g., invalid weight).
350 ///
351 /// # Examples
352 ///
353 /// ```
354 /// use digit_bin_index::DigitBinIndex;
355 ///
356 /// let mut index = DigitBinIndex::new();
357 /// let added = index.add(1, 0.5);
358 /// assert_eq!(index.count(), 1);
359 /// ```
360 pub fn add(&mut self, id: u64, weight: f64) {
361 match self {
362 DigitBinIndex::Small(index) => index.add(id, weight),
363 DigitBinIndex::Medium(index) => index.add(id, weight),
364 DigitBinIndex::Large(index) => index.add(id, weight),
365 }
366 }
367
368 /// Adds multiple items to the index in a highly optimized batch operation.
369 ///
370 /// This method is significantly faster than calling `add` in a loop for large
371 /// collections of items. It works by pre-processing the input, grouping items
372 /// by their shared weight, and then propagating each group through the tree in
373 /// a single pass. This minimizes cache misses and reduces function call overhead.
374 ///
375 /// Weights are rescaled to the index's precision and binned accordingly.
376 /// Items with non-positive weights or weights that become zero after scaling
377 /// will be ignored.
378 ///
379 /// # Arguments
380 ///
381 /// * `items` - A slice of `(id, weight)` tuples to add to the index.
382 ///
383 /// # Examples
384 ///
385 /// ```
386 /// use digit_bin_index::DigitBinIndex;
387 ///
388 /// let mut index = DigitBinIndex::new();
389 /// let items_to_add = vec![(1, 0.1), (2, 0.2), (3, 0.1)];
390 /// index.add_many(&items_to_add);
391 ///
392 /// assert_eq!(index.count(), 3);
393 /// // The total weight should be 0.1 + 0.2 + 0.1 = 0.4
394 /// assert!((index.total_weight() - 0.4).abs() < f64::EPSILON);
395 /// ```
396 pub fn add_many(&mut self, items: &[(u64, f64)]) {
397 match self {
398 DigitBinIndex::Small(index) => index.add_many(items),
399 DigitBinIndex::Medium(index) => index.add_many(items),
400 DigitBinIndex::Large(index) => index.add_many(items),
401 }
402 }
403
404 /// Removes an item with the given ID and weight from the index.
405 ///
406 /// The weight must match the one used during addition (after rescaling).
407 /// If the item is not found in the corresponding bin, no removal occurs.
408 ///
409 /// # Arguments
410 ///
411 /// * `individual_id` - The ID of the item to remove.
412 /// * `weight` - The weight of the item (must match the added weight).
413 ///
414 /// # Examples
415 ///
416 /// ```
417 /// use digit_bin_index::DigitBinIndex;
418 ///
419 /// let mut index = DigitBinIndex::new();
420 /// index.add(1, 0.5);
421 /// index.remove(1, 0.5);
422 /// assert_eq!(index.count(), 0);
423 /// ```
424 pub fn remove(&mut self, id: u64, weight: f64) -> bool {
425 match self {
426 DigitBinIndex::Small(index) => index.remove(id, weight),
427 DigitBinIndex::Medium(index) => index.remove(id, weight),
428 DigitBinIndex::Large(index) => index.remove(id, weight),
429 }
430 }
431
432 /// Removes multiple items from the index in a highly optimized batch operation.
433 ///
434 /// This method is significantly faster than calling `remove` in a loop. It
435 /// groups the items to be removed by their weight path and traverses the tree
436 /// only once per group, performing aggregated updates on the way up.
437 ///
438 /// The `(id, weight)` pairs must match items that are currently in the index.
439 /// If a given pair is not found, it is silently ignored.
440 ///
441 /// # Arguments
442 ///
443 /// * `items` - A slice of `(id, weight)` tuples to remove from the index.
444 ///
445 /// # Examples
446 ///
447 /// ```
448 /// use digit_bin_index::DigitBinIndex;
449 ///
450 /// let mut index = DigitBinIndex::new();
451 /// let items_to_add = vec![(1, 0.1), (2, 0.2), (3, 0.1), (4, 0.3)];
452 /// index.add_many(&items_to_add);
453 /// assert_eq!(index.count(), 4);
454 ///
455 /// let items_to_remove = vec![(2, 0.2), (3, 0.1)];
456 /// index.remove_many(&items_to_remove);
457 ///
458 /// assert_eq!(index.count(), 2); // Items 1 and 4 should remain
459 /// // The total weight should be 0.1 + 0.3 = 0.4
460 /// assert!((index.total_weight() - 0.4).abs() < f64::EPSILON);
461 /// ```
462 pub fn remove_many(&mut self, items: &[(u64, f64)]) -> bool {
463 match self {
464 DigitBinIndex::Small(index) => index.remove_many(items),
465 DigitBinIndex::Medium(index) => index.remove_many(items),
466 DigitBinIndex::Large(index) => index.remove_many(items),
467 }
468 }
469
470 /// Selects a single item randomly based on weights without removal.
471 ///
472 /// Performs weighted random selection. Returns `None` if the index is empty.
473 ///
474 /// # Returns
475 ///
476 /// An `Option` containing the selected item's ID and its (rescaled) weight.
477 ///
478 /// # Examples
479 ///
480 /// ```
481 /// use digit_bin_index::DigitBinIndex;
482 ///
483 /// let mut index = DigitBinIndex::new();
484 /// index.add(1, 0.5);
485 /// if let Some((id, weight)) = index.select() {
486 /// assert_eq!(id, 1);
487 /// assert_eq!(weight, 0.5);
488 /// }
489 /// ```
490 pub fn select(&mut self) -> Option<(u64, f64)> {
491 match self {
492 DigitBinIndex::Small(index) => index.select(),
493 DigitBinIndex::Medium(index) => index.select(),
494 DigitBinIndex::Large(index) => index.select(),
495 }
496 }
497
498 /// Selects a single item randomly and removes it from the index.
499 ///
500 /// Combines selection and removal in one operation. Returns `None` if empty.
501 ///
502 /// # Returns
503 ///
504 /// An `Option` containing the selected item's ID and weight.
505 ///
506 /// # Examples
507 ///
508 /// ```
509 /// use digit_bin_index::DigitBinIndex;
510 ///
511 /// let mut index = DigitBinIndex::new();
512 /// index.add(1, 0.5);
513 /// if let Some((id, _)) = index.select_and_remove() {
514 /// assert_eq!(id, 1);
515 /// }
516 /// assert_eq!(index.count(), 0);
517 /// ```
518 pub fn select_and_remove(&mut self) -> Option<(u64, f64)> {
519 match self {
520 DigitBinIndex::Small(index) => index.select_and_remove(),
521 DigitBinIndex::Medium(index) => index.select_and_remove(),
522 DigitBinIndex::Large(index) => index.select_and_remove(),
523 }
524 }
525
526 /// Selects multiple unique items randomly based on weights without removal.
527 ///
528 /// Uses rejection sampling to ensure uniqueness. Returns `None` if `num_to_draw`
529 /// exceeds the number of items in the index.
530 ///
531 /// # Arguments
532 ///
533 /// * `num_to_draw` - The number of unique items to select.
534 ///
535 /// # Returns
536 ///
537 /// An `Option` containing a vector of selected (ID, weight) pairs.
538 ///
539 /// # Examples
540 ///
541 /// ```
542 /// use digit_bin_index::DigitBinIndex;
543 ///
544 /// let mut index = DigitBinIndex::new();
545 /// index.add(1, 0.3);
546 /// index.add(2, 0.7);
547 /// if let Some(selected) = index.select_many(2) {
548 /// assert_eq!(selected.len(), 2);
549 /// }
550 /// ```
551 pub fn select_many(&mut self, num_to_draw: u64) -> Option<Vec<(u64, f64)>> {
552 match self {
553 DigitBinIndex::Small(index) => index.select_many(num_to_draw),
554 DigitBinIndex::Medium(index) => index.select_many(num_to_draw),
555 DigitBinIndex::Large(index) => index.select_many(num_to_draw),
556 }
557 }
558
559 /// Selects multiple unique items randomly and removes them from the index.
560 ///
561 /// Selects and removes in batch. Returns `None` if `num_to_draw` exceeds item count.
562 ///
563 /// # Arguments
564 ///
565 /// * `num_to_draw` - The number of unique items to select and remove.
566 ///
567 /// # Returns
568 ///
569 /// An `Option` containing a vector of selected (ID, weight) pairs.
570 ///
571 /// # Examples
572 ///
573 /// ```
574 /// use digit_bin_index::DigitBinIndex;
575 ///
576 /// let mut index = DigitBinIndex::new();
577 /// index.add(1, 0.3);
578 /// index.add(2, 0.7);
579 /// if let Some(selected) = index.select_many_and_remove(2) {
580 /// assert_eq!(selected.len(), 2);
581 /// }
582 /// assert_eq!(index.count(), 0);
583 /// ```
584 pub fn select_many_and_remove(&mut self, num_to_draw: u64) -> Option<Vec<(u64, f64)>> {
585 match self {
586 DigitBinIndex::Small(index) => index.select_many_and_remove(num_to_draw),
587 DigitBinIndex::Medium(index) => index.select_many_and_remove(num_to_draw),
588 DigitBinIndex::Large(index) => index.select_many_and_remove(num_to_draw),
589 }
590 }
591
592 /// Returns the total number of items currently in the index.
593 ///
594 /// # Returns
595 ///
596 /// The count of items as a `u32`.
597 ///
598 /// # Examples
599 ///
600 /// ```
601 /// use digit_bin_index::DigitBinIndex;
602 ///
603 /// let mut index = DigitBinIndex::new();
604 /// assert_eq!(index.count(), 0);
605 /// ```
606 pub fn count(&self) -> u64 {
607 match self {
608 DigitBinIndex::Small(index) => index.count(),
609 DigitBinIndex::Medium(index) => index.count(),
610 DigitBinIndex::Large(index) => index.count(),
611 }
612 }
613
614 /// Returns the sum of all weights in the index.
615 ///
616 /// This represents the total accumulated probability mass.
617 ///
618 /// # Returns
619 ///
620 /// The total weight as a `f64`.
621 ///
622 /// # Examples
623 ///
624 /// ```
625 /// use digit_bin_index::DigitBinIndex;
626 ///
627 /// let mut index = DigitBinIndex::new();
628 /// index.add(1, 0.5);
629 /// assert_eq!(index.total_weight(), 0.5);
630 /// ```
631 pub fn total_weight(&self) -> f64 {
632 match self {
633 DigitBinIndex::Small(index) => index.total_weight(),
634 DigitBinIndex::Medium(index) => index.total_weight(),
635 DigitBinIndex::Large(index) => index.total_weight(),
636 }
637 }
638
639 /// Prints detailed statistics about the index's structure, memory usage,
640 /// and data distribution.
641 pub fn print_stats(&self) {
642 println!("DigitBinIndex Statistics:");
643 println!("=========================");
644 match self {
645 DigitBinIndex::Small(idx) => {
646 println!("- Index Type: Small (Vec<u32>)");
647 idx.print_stats_generic();
648 },
649 DigitBinIndex::Medium(idx) => {
650 println!("- Index Type: Medium (RoaringBitmap)");
651 idx.print_stats_generic();
652 },
653 DigitBinIndex::Large(idx) => {
654 println!("- Index Type: Large (RoaringTreemap)");
655 idx.print_stats_generic();
656 },
657 }
658 }
659
660 /// Returns the precision (number of decimal places) used for binning.
661 pub fn precision(&self) -> u8 {
662 match self {
663 DigitBinIndex::Small(idx) => idx.precision,
664 DigitBinIndex::Medium(idx) => idx.precision,
665 DigitBinIndex::Large(idx) => idx.precision,
666 }
667 }
668}
669
670/// A data structure that organizes weighted items into bins based on their
671/// decimal digits to enable fast weighted random selection and updates.
672///
673/// This structure is a specialized radix tree optimized for sequential sampling
674/// (like in Wallenius' distribution). It makes a deliberate engineering trade-off:
675/// it sacrifices a small, controllable amount of precision by binning items,
676/// but in return, it achieves O(P) performance for its core operations, where P
677/// is the configured precision. This is significantly faster than the O(log N)
678/// performance of general-purpose structures like a Fenwick Tree for its
679/// ideal use case.
680///
681/// # Type Parameters
682///
683/// * `B` - The bin container type for leaf nodes. Must implement the [`DigitBin`] trait.
684/// Common choices are [`Vec<u32>`] for maximum speed with small bins, or [`RoaringBitmap`]
685/// for memory efficiency with large, sparse bins.
686///
687/// # Examples
688///
689/// ```
690/// use digit_bin_index::DigitBinIndexGeneric;
691/// // Use Vec<u32> for leaf bins
692/// let mut index = DigitBinIndexGeneric::<Vec<u32>>::new();
693/// // Or use RoaringBitmap for leaf bins
694/// // let mut index = DigitBinIndexGeneric::<roaring::RoaringBitmap>::new();
695/// ```
696#[derive(Debug, Clone)]
697pub struct DigitBinIndexGeneric<B: DigitBin> {
698 /// The root node of the tree.
699 pub root: Node<B>,
700 /// The precision (number of decimal places) used for binning.
701 pub precision: u8,
702 /// The scaling factor (10^precision) as f64 for conversions.
703 scale: f64,
704}
705
706impl<B: DigitBin> Default for DigitBinIndexGeneric<B> {
707 fn default() -> Self {
708 Self::new()
709 }
710}
711
712impl<B: DigitBin> DigitBinIndexGeneric<B> {
713 #[must_use]
714 pub fn new() -> Self {
715 Self::with_precision(DEFAULT_PRECISION)
716 }
717
718 #[must_use]
719 pub fn with_precision(precision: u8) -> Self {
720 assert!(precision > 0, "Precision must be at least 1.");
721 assert!(precision <= MAX_PRECISION as u8, "Precision cannot be larger than {}.", MAX_PRECISION);
722 Self {
723 root: Node::new_internal(),
724 precision,
725 scale: 10f64.powi(precision as i32),
726 }
727 }
728
729 /// Converts a f64 weight to an array of digits [0-9] for the given precision and the scaled u64 value.
730 /// Returns None if the weight is invalid (non-positive or zero after scaling).
731 fn weight_to_digits(&self, weight: f64, digits: &mut [u8; MAX_PRECISION]) -> Option<u64> {
732 if weight <= 0.0 || weight >= 1.0 {
733 return None;
734 }
735
736 let scaled = (weight * self.scale) as u64;
737 if scaled == 0 {
738 return None;
739 }
740
741 let mut temp = scaled;
742 for i in (0..self.precision as usize).rev() {
743 digits[i] = (temp % 10) as u8;
744 temp /= 10;
745 }
746 Some(scaled)
747 }
748
749 // --- Standard Functions ---
750
751 pub fn add(&mut self, individual_id: u64, weight: f64) {
752 let mut digits = [0u8; MAX_PRECISION];
753 if let Some(scaled) = self.weight_to_digits(weight, &mut digits) {
754 Self::add_recurse(&mut self.root, individual_id, scaled, &digits, 1, self.precision)
755 }
756 }
757
758 /// Recursive private method to handle adding individuals.
759 fn add_recurse(
760 node: &mut Node<B>,
761 individual_id: u64,
762 scaled: u64, // Scaled weight as u64
763 digits: &[u8; MAX_PRECISION],
764 current_depth: u8,
765 max_depth: u8,
766 ) {
767 node.content_count += 1;
768 node.accumulated_value += scaled;
769
770 if current_depth > max_depth {
771 if let NodeContent::DigitIndex(_) = &node.content {
772 node.content = NodeContent::Bin(B::default());
773 }
774 if let NodeContent::Bin(bin) = &mut node.content {
775 bin.insert(individual_id);
776 }
777 return;
778 }
779
780 let digit = digits[current_depth as usize - 1] as usize;
781 if let NodeContent::DigitIndex(children) = &mut node.content {
782 // Get the child, creating it if it doesn't exist.
783 let child_node = children[digit].get_or_insert_with(Node::new_internal);
784 Self::add_recurse(child_node, individual_id, scaled, digits, current_depth + 1, max_depth);
785 }
786 }
787
788 /// Adds multiple items to the index in a highly optimized batch operation.
789 ///
790 /// This method is significantly faster than calling `add` in a loop for large
791 /// collections of items. It works by first pre-processing the input items,
792 /// grouping them by their shared weight path (e.g., all items with weight 0.123...).
793 /// It then traverses the tree once per group, rather than once per item,
794 /// drastically reducing function call overhead and improving CPU cache performance
795 /// by performing aggregated updates at each node.
796 ///
797 /// # Arguments
798 ///
799 /// * `items` - A slice of `(individual_id, weight)` tuples to add to the index.
800 ///
801 pub fn add_many(&mut self, items: &[(u64, f64)]) {
802 if items.is_empty() {
803 return;
804 }
805
806 let mut digits = [0u8; MAX_PRECISION];
807 for &(id, weight) in items {
808 if let Some(scaled) = self.weight_to_digits(weight, &mut digits) {
809 Self::add_recurse(&mut self.root, id, scaled, &digits, 1, self.precision)
810 }
811 }
812 }
813
814 pub fn remove(&mut self, individual_id: u64, weight: f64) -> bool{
815 let mut digits = [0u8; MAX_PRECISION];
816 if let Some(scaled) = self.weight_to_digits(weight, &mut digits) {
817 return Self::remove_recurse(&mut self.root, individual_id, scaled, &digits, 1, self.precision);
818 }
819 false
820 }
821
822 /// Recursive private method to handle removing individuals.
823 fn remove_recurse(
824 node: &mut Node<B>,
825 individual_id: u64,
826 scaled: u64,
827 digits: &[u8; MAX_PRECISION],
828 current_depth: u8,
829 max_depth: u8,
830 ) -> bool {
831 if current_depth > max_depth {
832 if let NodeContent::Bin(bin) = &mut node.content {
833 let orig_len = bin.len();
834 bin.remove(individual_id);
835 if bin.len() < orig_len {
836 node.content_count -= 1;
837 node.accumulated_value -= scaled;
838 return true;
839 }
840 }
841 return false;
842 }
843
844 let digit = digits[current_depth as usize - 1] as usize;
845 if let NodeContent::DigitIndex(children) = &mut node.content {
846 // Check if the child at 'digit' exists and get a mutable reference to it.
847 if let Some(child_node) = children[digit].as_mut() {
848 // If it exists, recurse. If the recursion returns true (success)...
849 if Self::remove_recurse(child_node, individual_id, scaled, digits, current_depth + 1, max_depth) {
850 // ...then update this node's stats and propagate the success upwards.
851 node.content_count -= 1;
852 node.accumulated_value -= scaled;
853 return true;
854 }
855 }
856 }
857 false
858 }
859
860 /// Removes multiple items from the index in a highly optimized batch operation.
861 ///
862 /// This method is significantly faster than calling `remove` in a loop. It
863 /// groups the items to be removed by their weight path and traverses the tree
864 /// only once per group, performing aggregated updates on the way up.
865 ///
866 /// The `(id, weight)` pairs must match items that are currently in the index.
867 /// If a given pair is not found, it is silently ignored.
868 ///
869 /// # Arguments
870 ///
871 /// * `items` - A slice of `(id, weight)` tuples to remove from the index.
872 ///
873 pub fn remove_many(&mut self, items: &[(u64, f64)]) -> bool {
874 if items.is_empty() {
875 return false;
876 }
877
878 let mut digits = [0u8; MAX_PRECISION];
879 let mut success = true;
880 for &(id, weight) in items {
881 if let Some(scaled) = self.weight_to_digits(weight, &mut digits) {
882 success &= Self::remove_recurse(&mut self.root, id, scaled, &digits, 1, self.precision)
883 } else {
884 success &= false;
885 }
886 }
887 success
888 }
889
890 // --- Selection Functions ---
891
892 pub fn select(&mut self) -> Option<(u64, f64)> {
893 self.select_and_optionally_remove(false)
894 }
895
896 pub fn select_many(&mut self, num_to_draw: u64) -> Option<Vec<(u64, f64)>> {
897 self.select_many_and_optionally_remove(num_to_draw, false)
898 }
899
900 pub fn select_and_remove(&mut self) -> Option<(u64, f64)> {
901 self.select_and_optionally_remove(true)
902 }
903
904 // Wrapper function to handle both select and select_and_remove
905 pub fn select_and_optionally_remove(&mut self, with_removal: bool) -> Option<(u64, f64)> {
906 if self.root.content_count == 0 {
907 return None;
908 }
909 let mut rng = WyRand::from_os_rng();
910 let random_target = rng.random_range(0u64..self.root.accumulated_value);
911 Self::select_and_optionally_remove_recurse(&mut self.root, random_target, 1, self.precision, &mut rng, with_removal, self.scale)
912 }
913
914 // Helper function
915 fn select_and_optionally_remove_recurse(
916 node: &mut Node<B>,
917 target: u64,
918 current_depth: u8,
919 max_depth: u8,
920 rng: &mut WyRand,
921 with_removal: bool,
922 scale: f64,
923 ) -> Option<(u64, f64)> {
924 // Base case: Bin node
925 if current_depth > max_depth {
926 if let NodeContent::Bin(bin) = &mut node.content {
927 if bin.is_empty() {
928 return None;
929 }
930 let scaled_weight = node.accumulated_value / node.content_count as u64;
931 let weight = scaled_weight as f64 / scale;
932 let selected_id = if with_removal {
933 bin.get_random_and_remove(rng)?
934 } else {
935 bin.get_random(rng)?
936 };
937 if with_removal {
938 node.content_count -= 1;
939 node.accumulated_value -= scaled_weight;
940 }
941 return Some((selected_id, weight));
942 }
943 return None;
944 }
945
946 // Recursive case: DigitIndex node
947 if let NodeContent::DigitIndex(children) = &mut node.content {
948 let mut cum: u64 = 0;
949 // The iterator now gives us a mutable reference to the Option.
950 for child_option in children.iter_mut() {
951 // We pattern match to see if a child Node exists.
952 if let Some(child) = child_option.as_mut() {
953 // Now, 'child' is a &mut Node<B>, and we can proceed with the original logic.
954 if child.accumulated_value == 0 {
955 continue;
956 }
957 if target < cum + child.accumulated_value {
958 if let Some((selected_id, weight)) = Self::select_and_optionally_remove_recurse(
959 child,
960 target - cum,
961 current_depth + 1,
962 max_depth,
963 rng,
964 with_removal,
965 scale,
966 ) {
967 if with_removal {
968 node.content_count -= 1;
969 node.accumulated_value -= (weight * scale).round() as u64;
970 }
971 return Some((selected_id, weight));
972 }
973 // This path is taken if recursion fails, which implies an empty bin was selected.
974 return None;
975 }
976 cum += child.accumulated_value;
977 }
978 }
979 }
980 None
981 }
982
983 pub fn select_many_and_remove(&mut self, num_to_draw: u64) -> Option<Vec<(u64, f64)>> {
984 self.select_many_and_optionally_remove(num_to_draw, true)
985 }
986
987 // Wrapper function to handle both select_many and select_many_and_remove
988 pub fn select_many_and_optionally_remove(&mut self, num_to_draw: u64, with_removal: bool) -> Option<Vec<(u64, f64)>> {
989 if num_to_draw > self.count() || num_to_draw == 0 {
990 return if num_to_draw == 0 { Some(Vec::new()) } else { None };
991 }
992 let mut rng = WyRand::from_os_rng();
993 let mut selected: Vec<(u64, f64)> = Vec::with_capacity(num_to_draw as usize);
994 let total_accum = self.root.accumulated_value;
995 // Create a Uniform distribution for the range [0, total_accum)
996 let uniform = Uniform::new(0u64, total_accum).expect("Valid range for Uniform");
997 // Generate num_to_draw random numbers using sample_iter
998 let passed_targets: Vec<u64> = uniform
999 .sample_iter(&mut rng)
1000 .take(num_to_draw as usize)
1001 .collect();
1002 Self::select_many_and_optionally_remove_recurse(
1003 &mut self.root,
1004 total_accum,
1005 &mut selected,
1006 &mut rng,
1007 1,
1008 self.precision,
1009 with_removal,
1010 passed_targets,
1011 self.scale,
1012 );
1013 if selected.len() == num_to_draw as usize {
1014 Some(selected)
1015 } else {
1016 None // Should not happen if logic is correct
1017 }
1018 }
1019
1020 /// Recursive helper for batch selection and removal.
1021 /// - node: Current subtree root.
1022 /// - subtree_total: Accumulated value of this node (passed to avoid borrowing issues).
1023 /// - selected: Mutable vec to collect (id, weight) from leaves.
1024 /// - rng: Mutable RNG.
1025 /// - current_depth: Current digit level.
1026 /// - precision: The precision of the DigitBinIndex (passed explicitly).
1027 /// - with_removal: Whether to remove selected items.
1028 /// - passed_targets: Pre-computed relative targets from parent (in [0, subtree_total)).
1029 /// - scale: The scaling factor for weight conversions.
1030 fn select_many_and_optionally_remove_recurse(
1031 node: &mut Node<B>,
1032 subtree_total: u64,
1033 selected: &mut Vec<(u64, f64)>,
1034 rng: &mut WyRand,
1035 current_depth: u8,
1036 precision: u8,
1037 with_removal: bool,
1038 passed_targets: Vec<u64>,
1039 scale: f64,
1040 ) {
1041 let original_target_count = passed_targets.len() as u64;
1042 if original_target_count == 0 {
1043 return;
1044 }
1045
1046 // This base case (leaf node) logic does not change, as it doesn't interact
1047 // with the DigitIndex.
1048 if current_depth > precision {
1049 if let NodeContent::Bin(bin) = &mut node.content {
1050 let bin_scaled = if node.content_count > 0 {
1051 node.accumulated_value / node.content_count as u64
1052 } else {
1053 0u64
1054 };
1055 let bin_weight = bin_scaled as f64 / scale;
1056 let to_select = original_target_count.min(node.content_count);
1057 let mut picked = 0u64;
1058 while picked < to_select && !bin.is_empty() {
1059 let id = if with_removal {
1060 bin.get_random_and_remove(rng).unwrap()
1061 } else {
1062 bin.get_random(rng).unwrap()
1063 };
1064 selected.push((id, bin_weight));
1065 picked += 1;
1066 }
1067 if with_removal {
1068 node.content_count -= picked;
1069 node.accumulated_value -= bin_scaled * picked as u64;
1070 }
1071 }
1072 return;
1073 }
1074
1075 // --- START OF MODIFIED LOGIC ---
1076 if let NodeContent::DigitIndex(children) = &mut node.content {
1077 // CHANGE: Use fixed-size arrays of length 10 instead of dynamically sized Vecs.
1078 let mut child_assigned = [0u64; 10];
1079 // Note: `Default::default()` works for arrays where the element type is `Default`.
1080 let mut child_rel_targets: [Vec<u64>; 10] = Default::default();
1081 let mut assigned = 0u64;
1082
1083 // --- Main assignment loop ---
1084 for &target in &passed_targets {
1085 let mut cum: u64 = 0;
1086 let mut chosen_idx = None;
1087 // CHANGE: Iterate over the array of Options.
1088 for (i, child_option) in children.iter().enumerate() {
1089 // CHANGE: Only process existing children.
1090 if let Some(child) = child_option {
1091 if child.accumulated_value == 0 {
1092 continue;
1093 }
1094 if target < cum + child.accumulated_value {
1095 if child_assigned[i] + 1 <= child.content_count {
1096 chosen_idx = Some(i);
1097 }
1098 break;
1099 }
1100 cum += child.accumulated_value;
1101 }
1102 }
1103 if let Some(idx) = chosen_idx {
1104 child_assigned[idx] += 1;
1105 // We need to re-calculate `cum` up to the chosen index to get the relative target.
1106 let start_of_child_range: u64 = children[..idx].iter().filter_map(|c| c.as_ref()).map(|c| c.accumulated_value).sum();
1107 let rel_target = target - start_of_child_range;
1108 child_rel_targets[idx].push(rel_target);
1109 assigned += 1;
1110 }
1111 }
1112
1113 // --- Rejection sampling for any remaining targets ---
1114 let remaining = original_target_count - assigned;
1115 let mut additional_assigned = 0u64;
1116 while additional_assigned < remaining {
1117 let target = rng.random_range(0u64..subtree_total);
1118 let mut cum: u64 = 0;
1119 let mut chosen_idx = None;
1120 // CHANGE: Same iteration pattern as the loop above.
1121 for (i, child_option) in children.iter().enumerate() {
1122 if let Some(child) = child_option {
1123 if child.accumulated_value == 0 {
1124 continue;
1125 }
1126 if target < cum + child.accumulated_value {
1127 if child_assigned[i] + 1 <= child.content_count {
1128 chosen_idx = Some(i);
1129 }
1130 break;
1131 }
1132 cum += child.accumulated_value;
1133 }
1134 }
1135 if let Some(idx) = chosen_idx {
1136 child_assigned[idx] += 1;
1137 let start_of_child_range: u64 = children[..idx].iter().filter_map(|c| c.as_ref()).map(|c| c.accumulated_value).sum();
1138 let rel_target = target - start_of_child_range;
1139 child_rel_targets[idx].push(rel_target);
1140 additional_assigned += 1;
1141 }
1142 }
1143
1144 // CHANGE: Store accumulated values in a fixed-size array for the recursive calls.
1145 let child_accums: [u64; 10] = std::array::from_fn(|i| {
1146 children[i].as_ref().map_or(0, |c| c.accumulated_value)
1147 });
1148
1149 // --- Recurse into children ---
1150 // CHANGE: Iterate through mutable options.
1151 for (i, child_option) in children.iter_mut().enumerate() {
1152 let assign_count = child_assigned[i];
1153 if assign_count > 0 {
1154 // We must have a child here if it was assigned targets.
1155 if let Some(child) = child_option {
1156 let rel_targets = std::mem::take(&mut child_rel_targets[i]);
1157 Self::select_many_and_optionally_remove_recurse(
1158 child,
1159 child_accums[i],
1160 selected,
1161 rng,
1162 current_depth + 1,
1163 precision,
1164 with_removal,
1165 rel_targets,
1166 scale,
1167 );
1168 }
1169 }
1170 }
1171
1172 if with_removal {
1173 // --- Unwind: Update this node's stats ---
1174 // CHANGE: Sum up counts and values from the existing children in the array.
1175 node.content_count = children.iter().filter_map(|c| c.as_ref()).map(|c| c.content_count).sum();
1176 node.accumulated_value = children.iter().filter_map(|c| c.as_ref()).map(|c| c.accumulated_value).sum();
1177 }
1178 }
1179 }
1180
1181 pub fn count(&self) -> u64 {
1182 self.root.content_count
1183 }
1184
1185 pub fn total_weight(&self) -> f64 {
1186 self.root.accumulated_value as f64 / self.scale
1187 }
1188
1189 /// Prints detailed statistics about the tree: node count, bin stats, and weight stats.
1190 pub fn print_stats_generic(&self) {
1191 // This struct holds all the metrics we want to collect.
1192 struct Stats {
1193 node_count: usize,
1194 non_empty_node_count: usize,
1195 internal_node_count: usize, // NEW: For branching factor
1196 child_slots_used: usize, // NEW: For branching factor
1197 bin_count: usize,
1198 empty_bin_count: usize,
1199 total_bin_items: u64,
1200 min_weight: Option<f64>,
1201 max_weight: Option<f64>,
1202 // We collect all bin sizes to calculate standard deviation later.
1203 bin_sizes: Vec<usize>,
1204 // Memory estimates
1205 mem_nodes: usize,
1206 mem_bins: usize,
1207 }
1208
1209 fn traverse<B: DigitBin>(
1210 node: &Node<B>,
1211 stats: &mut Stats,
1212 scale: f64,
1213 ) {
1214 stats.node_count += 1;
1215 stats.mem_nodes += std::mem::size_of::<Node<B>>();
1216
1217 if node.content_count > 0 {
1218 stats.non_empty_node_count += 1;
1219 }
1220
1221 match &node.content {
1222 NodeContent::DigitIndex(children) => {
1223 // --- NEW: Calculate branching factor stats ---
1224 stats.internal_node_count += 1;
1225 let used_children = children.iter().filter(|c| c.is_some()).count();
1226 stats.child_slots_used += used_children;
1227 // --- END NEW ---
1228
1229 // Add memory for the heap-allocated array of 10 optional nodes.
1230 stats.mem_nodes += std::mem::size_of::<[Option<Node<B>>; 10]>();
1231
1232 // Iterate over the options in the array
1233 for child_option in children.iter() {
1234 // Only recurse into the children that actually exist (are Some)
1235 if let Some(child) = child_option {
1236 traverse(child, stats, scale);
1237 }
1238 }
1239 }
1240 NodeContent::Bin(bin) => {
1241 stats.bin_count += 1;
1242 let bin_size = bin.len();
1243 stats.bin_sizes.push(bin_size);
1244 stats.total_bin_items += bin_size as u64;
1245
1246 // Estimate memory for the bin's contents.
1247 // This is an approximation. For RoaringBitmap, `serialized_size()` would be more accurate.
1248 stats.mem_bins += bin_size * std::mem::size_of::<u32>();
1249
1250 if bin_size == 0 {
1251 stats.empty_bin_count += 1;
1252 } else {
1253 // All items in a bin share the same weight.
1254 let scaled_weight = node.accumulated_value / node.content_count;
1255 let weight = scaled_weight as f64 / scale;
1256 stats.min_weight = Some(stats.min_weight.map_or(weight, |min| min.min(weight)));
1257 stats.max_weight = Some(stats.max_weight.map_or(weight, |max| max.max(weight)));
1258 }
1259 }
1260 }
1261 }
1262
1263 let mut stats = Stats {
1264 node_count: 0,
1265 non_empty_node_count: 0,
1266 internal_node_count: 0, // NEW
1267 child_slots_used: 0, // NEW
1268 bin_count: 0,
1269 empty_bin_count: 0,
1270 total_bin_items: 0,
1271 min_weight: None,
1272 max_weight: None,
1273 bin_sizes: Vec::new(),
1274 mem_nodes: 0,
1275 mem_bins: 0,
1276 };
1277
1278 traverse(&self.root, &mut stats, self.scale);
1279
1280 // --- Calculations ---
1281 let fill_ratio = if stats.node_count > 0 {
1282 stats.non_empty_node_count as f64 / stats.node_count as f64 * 100.0
1283 } else { 0.0 };
1284
1285 // NEW: Calculate average branching factor
1286 let avg_branching_factor = if stats.internal_node_count > 0 {
1287 stats.child_slots_used as f64 / stats.internal_node_count as f64
1288 } else { 0.0 };
1289
1290 let avg_bin_size = if stats.bin_count > 0 {
1291 stats.total_bin_items as f64 / stats.bin_count as f64
1292 } else { 0.0 };
1293
1294 let std_dev_bin_size = if stats.bin_count > 1 {
1295 let variance = stats.bin_sizes.iter()
1296 .map(|&size| (size as f64 - avg_bin_size).powi(2)) // Corrected: removed the `*`
1297 .sum::<f64>() / (stats.bin_count - 1) as f64;
1298 variance.sqrt()
1299 } else { 0.0 };
1300
1301 // NEW: Calculate bin size quartiles
1302 let (q1_bin_size, median_bin_size, q3_bin_size) = if !stats.bin_sizes.is_empty() {
1303 let mut sorted_sizes = stats.bin_sizes.clone();
1304 sorted_sizes.sort_unstable();
1305 let q1 = sorted_sizes.get(sorted_sizes.len() / 4).cloned().unwrap_or(0);
1306 let median = sorted_sizes.get(sorted_sizes.len() / 2).cloned().unwrap_or(0);
1307 let q3 = sorted_sizes.get(sorted_sizes.len() * 3 / 4).cloned().unwrap_or(0);
1308 (q1, median, q3)
1309 } else {
1310 (0, 0, 0)
1311 };
1312
1313 let total_mem_mb = (stats.mem_nodes + stats.mem_bins) as f64 / (1024.0 * 1024.0);
1314 let nodes_mem_mb = stats.mem_nodes as f64 / (1024.0 * 1024.0);
1315 let bins_mem_mb = stats.mem_bins as f64 / (1024.0 * 1024.0);
1316
1317 // NEW: Calculate average weight
1318 let avg_weight = if self.count() > 0 {
1319 self.total_weight() / self.count() as f64
1320 } else { 0.0 };
1321
1322
1323 // --- Printing ---
1324 println!("\n[Tree Structure]");
1325 println!("- Total Nodes Created: {}", stats.node_count);
1326 println!("- Internal Nodes: {}", stats.internal_node_count); // NEW
1327 println!("- Avg Branching Factor: {:.2} / 10", avg_branching_factor); // NEW
1328 println!("- Tree Fill Ratio: {:.2}%", fill_ratio);
1329 println!("- Max Depth: {}", self.precision);
1330
1331 println!("\n[Memory (Estimated)]");
1332 println!("- Tree Structure: {:.2} MB", nodes_mem_mb);
1333 println!("- Leaf Bins: {:.2} MB", bins_mem_mb);
1334 println!("- Total Estimated: {:.2} MB", total_mem_mb);
1335
1336 println!("\n[Items & Bins]");
1337 println!("- Total Items: {}", stats.total_bin_items);
1338 println!("- Total Bins (Leaves): {}", stats.bin_count);
1339 println!("- Empty Bins: {}", stats.empty_bin_count);
1340 println!("- Avg Items per Bin: {:.2}", avg_bin_size);
1341 println!("- Std Dev of Bin Size: {:.2}", std_dev_bin_size);
1342 println!("- Bin Size (min/max): {} / {}", stats.bin_sizes.iter().min().map_or(0, |v| *v), stats.bin_sizes.iter().max().map_or(0, |v| *v));
1343 println!("- Bin Size (Q1/Med/Q3): {} / {} / {}", q1_bin_size, median_bin_size, q3_bin_size); // NEW
1344
1345 println!("\n[Weights]");
1346 println!("- Smallest Weight: {}", stats.min_weight.map_or("-".to_string(), |v| format!("{:.prec$}", v, prec = self.precision as usize)));
1347 println!("- Largest Weight: {}", stats.max_weight.map_or("-".to_string(), |v| format!("{:.prec$}", v, prec = self.precision as usize)));
1348 println!("- Average Weight: {:.prec$}", avg_weight, prec = self.precision as usize); // NEW
1349 }
1350}
1351
#[cfg(feature = "python-bindings")]
mod python {
    use super::*;
    use pyo3::exceptions::PyValueError;
    use pyo3::prelude::*;

    /// Python-facing wrapper around a `DigitBinIndex`.
    #[pyclass(name = "DigitBinIndex")]
    struct PyDigitBinIndex {
        index: DigitBinIndex,
    }

    #[pymethods]
    impl PyDigitBinIndex {
        /// Create a DigitBinIndex via `DigitBinIndex::new()`.
        #[new]
        fn new() -> Self {
            PyDigitBinIndex {
                index: DigitBinIndex::new(),
            }
        }

        /// Create a DigitBinIndex with a specific precision.
        ///
        /// Raises `ValueError` if `precision` does not fit the integer type the
        /// Rust constructor takes, instead of panicking across the FFI boundary.
        #[staticmethod]
        fn with_precision(precision: u64) -> PyResult<Self> {
            let index = DigitBinIndex::with_precision(precision.try_into().map_err(|_| {
                PyValueError::new_err(format!("precision {} is out of range", precision))
            })?);
            Ok(PyDigitBinIndex { index })
        }

        /// Create a DigitBinIndex with a specific precision and expected capacity.
        #[staticmethod]
        fn with_precision_and_capacity(precision: u8, capacity: u64) -> Self {
            PyDigitBinIndex {
                index: DigitBinIndex::with_precision_and_capacity(precision, capacity),
            }
        }

        /// Create a DigitBinIndex with Vec<u32> bins and the specified precision.
        #[staticmethod]
        fn small(precision: u8) -> Self {
            PyDigitBinIndex {
                index: DigitBinIndex::small(precision),
            }
        }

        /// Create a DigitBinIndex with RoaringBitmap bins and the specified precision.
        #[staticmethod]
        fn medium(precision: u8) -> Self {
            PyDigitBinIndex {
                index: DigitBinIndex::medium(precision),
            }
        }

        /// Create a DigitBinIndex with RoaringTreemap bins and the specified precision.
        #[staticmethod]
        fn large(precision: u8) -> Self {
            PyDigitBinIndex {
                index: DigitBinIndex::large(precision),
            }
        }

        /// Add one item with the given weight.
        fn add(&mut self, id: u64, weight: f64) {
            self.index.add(id, weight)
        }

        /// Add a batch of `(id, weight)` items.
        fn add_many(&mut self, items: Vec<(u64, f64)>) {
            self.index.add_many(&items);
        }

        /// Remove one item; returns whether the removal succeeded.
        fn remove(&mut self, id: u64, weight: f64) -> bool {
            self.index.remove(id, weight)
        }

        /// Remove a batch of items; returns `True` only if every removal succeeded.
        fn remove_many(&mut self, items: Vec<(u64, f64)>) -> bool {
            self.index.remove_many(&items)
        }

        /// Weighted random selection without removal.
        fn select(&mut self) -> Option<(u64, f64)> {
            self.index.select()
        }

        /// Batch weighted selection without removal.
        fn select_many(&mut self, n: u64) -> Option<Vec<(u64, f64)>> {
            self.index.select_many(n)
        }

        /// Weighted random selection that also removes the chosen item.
        fn select_and_remove(&mut self) -> Option<(u64, f64)> {
            self.index.select_and_remove()
        }

        /// Batch weighted selection that also removes the chosen items.
        fn select_many_and_remove(&mut self, n: u64) -> Option<Vec<(u64, f64)>> {
            self.index.select_many_and_remove(n)
        }

        /// Sum of all stored weights.
        fn total_weight(&self) -> f64 {
            self.index.total_weight()
        }

        /// Number of stored items.
        fn count(&self) -> u64 {
            self.index.count()
        }

        /// Print tree statistics to stdout.
        fn print_stats(&self) {
            self.index.print_stats();
        }
    }

    #[pymodule]
    fn digit_bin_index(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
        m.add_class::<PyDigitBinIndex>()?;
        Ok(())
    }
}
1462
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_select_and_remove() {
        // Four distinct weights at precision 3, so every item sits in its own bin.
        let mut index = DigitBinIndex::with_precision(3);
        index.add(1, 0.085);
        index.add(2, 0.205);
        index.add(3, 0.346);
        index.add(4, 0.364);
        index.print_stats();
        println!("Initial state: {} individuals, total weight = {}", index.count(), index.total_weight());

        let (first_id, weight) = index
            .select_and_remove()
            .expect("selecting from a non-empty index must succeed");
        println!("Selected ID: {} with weight: {}", first_id, weight);
        assert!((1..=4).contains(&first_id), "Selected unknown ID {}", first_id);
        assert!(
            index.count() == 3,
            "The count is now {} and not 3 as expected",
            index.count()
        );
        println!("Intermediate state: {} individuals, total weight = {}", index.count(), index.total_weight());

        let selection = index
            .select_many_and_remove(2)
            .expect("two of the three remaining items must be selectable");
        println!("Selection: {:?}", selection);
        assert_eq!(selection.len(), 2);
        // Removed items cannot be drawn again, so all three ids are distinct.
        assert_ne!(selection[0].0, selection[1].0);
        assert!(selection.iter().all(|&(id, _)| id != first_id));
        assert!(
            index.count() == 1,
            "The count is now {} and not 1 as expected",
            index.count()
        );
        println!("Final state: {} individuals, total weight = {}", index.count(), index.total_weight());
    }

    #[test]
    fn test_wallenius_distribution_is_correct() {
        // --- Setup: a controlled population of two equally sized groups ---
        const ITEMS_PER_GROUP: u64 = 1000;
        const TOTAL_ITEMS: u64 = ITEMS_PER_GROUP * 2;
        const NUM_DRAWS: u64 = TOTAL_ITEMS / 2;

        let low_risk_weight = 0.1f64;
        let high_risk_weight = 0.2f64;

        // --- Execution: run many simulations to average out randomness ---
        const NUM_SIMULATIONS: u32 = 100;
        let mut total_high_risk_selected = 0;

        for _ in 0..NUM_SIMULATIONS {
            let mut index = DigitBinIndex::with_precision_and_capacity(3, TOTAL_ITEMS);
            for i in 0..ITEMS_PER_GROUP { index.add(i, low_risk_weight); }
            for i in ITEMS_PER_GROUP..TOTAL_ITEMS { index.add(i, high_risk_weight); }

            // Sequential draws with removal follow Wallenius' distribution.
            let mut high_risk_in_this_run = 0;
            for _ in 0..NUM_DRAWS {
                if let Some((selected_id, _)) = index.select_and_remove() {
                    if selected_id >= ITEMS_PER_GROUP {
                        high_risk_in_this_run += 1;
                    }
                }
            }
            total_high_risk_selected += high_risk_in_this_run;
        }

        // --- Validation: statistical properties of a Wallenius' draw ---
        let avg_high_risk = total_high_risk_selected as f64 / NUM_SIMULATIONS as f64;

        // 1. The mean of a uniform draw (central hypergeometric) would be 500.
        let uniform_mean = NUM_DRAWS as f64 * 0.5;

        // 2. The mean of a simultaneous draw (Fisher's NCG) follows the initial proportions.
        let fishers_mean = NUM_DRAWS as f64 * (2.0 / 3.0); // ~666.67

        // The mean of a Wallenius' draw lies strictly between the two.
        assert!(
            avg_high_risk > uniform_mean,
            "Test failed: Result {:.2} was not biased towards higher weights (uniform mean is {:.2})",
            avg_high_risk, uniform_mean
        );

        assert!(
            avg_high_risk < fishers_mean,
            "Test failed: Result {:.2} showed too much bias. It should be less than the Fisher's mean of {:.2} due to the Wallenius effect.",
            avg_high_risk, fishers_mean
        );

        println!(
            "Distribution test passed: Got an average of {:.2} high-risk selections.",
            avg_high_risk
        );
        println!(
            "This correctly lies between the uniform mean ({:.2}) and the Fisher's mean ({:.2}), confirming the Wallenius' distribution behavior.",
            uniform_mean, fishers_mean
        );
    }

    #[test]
    fn test_fisher_distribution_is_correct() {
        const ITEMS_PER_GROUP: u64 = 1000;
        const TOTAL_ITEMS: u64 = ITEMS_PER_GROUP * 2;
        const NUM_DRAWS: u64 = TOTAL_ITEMS / 2;

        let low_risk_weight = 0.1f64;
        let high_risk_weight = 0.2f64;

        const NUM_SIMULATIONS: u32 = 100;
        let mut total_high_risk_selected = 0;

        for _ in 0..NUM_SIMULATIONS {
            let mut index = DigitBinIndex::with_precision_and_capacity(3, TOTAL_ITEMS);
            for i in 0..ITEMS_PER_GROUP { index.add(i, low_risk_weight); }
            for i in ITEMS_PER_GROUP..TOTAL_ITEMS { index.add(i, high_risk_weight); }

            // Batch selection draws every target against the initial weights.
            if let Some(selected_ids) = index.select_many_and_remove(NUM_DRAWS) {
                let high_risk_in_this_run = selected_ids.iter().filter(|&&(id, _)| id >= ITEMS_PER_GROUP).count();
                total_high_risk_selected += high_risk_in_this_run as u32;
            }
        }

        let avg_high_risk = total_high_risk_selected as f64 / NUM_SIMULATIONS as f64;
        let fishers_mean = NUM_DRAWS as f64 * (2.0 / 3.0);
        let tolerance = fishers_mean * 0.02;

        // The mean of a Fisher's draw should be very close to the naive expectation.
        assert!(
            (avg_high_risk - fishers_mean).abs() < tolerance,
            "Fisher's test failed: Result {:.2} was not close to the expected mean of {:.2}",
            avg_high_risk, fishers_mean
        );

        println!(
            "Fisher's test passed: Got avg {:.2} high-risk selections (expected ~{:.2}).",
            avg_high_risk, fishers_mean
        );
    }
}
1600
#[cfg(test)]
#[test]
fn test_weight_to_digits() {
    // Vec<u32> bins keep the instance simple; only the digit conversion is exercised.
    let index = DigitBinIndexGeneric::<Vec<u32>>::with_precision(3);
    let mut digits = [0u8; MAX_PRECISION];

    // A representable weight yields its scaled value and its leading digits.
    match index.weight_to_digits(0.123, &mut digits) {
        Some(scaled) => {
            assert_eq!(scaled, 123);
            assert_eq!(digits[0..3], [1, 2, 3]);
            // Digits beyond the configured precision are zero-padded, whatever
            // the value of MAX_PRECISION.
            assert!(digits[3..].iter().all(|&d| d == 0), "digits not zero-padded: {:?}", digits);
        }
        None => panic!("Expected Some for valid weight"),
    }

    // Non-positive weights, and weights that round to zero, are rejected.
    assert!(index.weight_to_digits(0.0, &mut digits).is_none());
    assert!(index.weight_to_digits(-0.1, &mut digits).is_none());
    assert!(index.weight_to_digits(0.0000001, &mut digits).is_none());

    // Weights whose scaled value needs more than `precision` digits are rejected.
    assert!(index.weight_to_digits(2.0, &mut digits).is_none());
}
1625
#[cfg(test)]
#[test]
fn test_add_many() {
    // Bulk insertion must build the same index contents as one-by-one insertion.
    const CAPACITY: u64 = 1_000_000u64;
    let mut index_one_at_a_time = DigitBinIndex::with_precision_and_capacity(3, CAPACITY);
    let mut index_all_at_once = DigitBinIndex::with_precision_and_capacity(3, CAPACITY);
    let mut population = Vec::with_capacity(CAPACITY as usize);
    let mut rng = WyRand::from_os_rng();
    for i in 0..CAPACITY {
        let weight: f64 = rng.random_range(0.001..=0.999);
        population.push((i, weight));
        index_one_at_a_time.add(i, weight);
    }
    index_all_at_once.add_many(&population);
    index_one_at_a_time.print_stats();
    index_all_at_once.print_stats();

    assert_eq!(index_one_at_a_time.count(), CAPACITY);
    assert_eq!(index_all_at_once.count(), index_one_at_a_time.count());
    let (w_one, w_all) = (index_one_at_a_time.total_weight(), index_all_at_once.total_weight());
    assert!(
        (w_one - w_all).abs() <= 1e-9 * w_one.max(1.0),
        "total weights differ: one-at-a-time {} vs all-at-once {}",
        w_one, w_all
    );
}
#[cfg(test)]
#[test]
fn test_add() {
    // A single exactly representable weight must be reflected in count and total.
    let mut index = DigitBinIndex::new();
    index.add(1, 0.5);
    index.print_stats();
    assert_eq!(index.count(), 1);
    assert!(
        (index.total_weight() - 0.5).abs() < 1e-9,
        "total weight is {} instead of 0.5",
        index.total_weight()
    );
}
1649