use alloc::boxed::Box;
use core::iter;
#[cfg(test)]
use crate::{
allocator::Global,
raw::{InsertPrefixError, InsertResult, OpaqueNodePtr},
AsBytes, TreeMap,
};
/// Exchange the two components of a pair.
///
/// Handy as a function item in iterator adapters, e.g.
/// `it.enumerate().map(swap)` turns `(index, item)` into `(item, index)`.
pub fn swap<A, B>(pair: (A, B)) -> (B, A) {
    let (first, second) = pair;
    (second, first)
}
/// Generate `max_len` bytestring keys with lengths `1, 2, ..., max_len`.
///
/// The key of length `n` consists of `n - 1` zero bytes followed by a single
/// `u8::MAX` byte. The trailing `u8::MAX` ensures that no generated key is a
/// prefix of another generated key.
///
/// # Panics
///
/// Panics (eagerly, at call time) if `max_len` is 0.
pub fn generate_keys_skewed(max_len: usize) -> impl Iterator<Item = Box<[u8]>> {
    assert!(max_len > 0, "the fixed key length must be greater than 0");

    (1..=max_len).map(|len| {
        iter::repeat(u8::MIN)
            .take(len - 1)
            .chain(iter::once(u8::MAX))
            .collect::<Box<[u8]>>()
    })
}
/// Generate every `KEY_LENGTH`-byte key whose byte `i` ranges over
/// `0..=level_widths[i]`, in lexicographic (odometer) order.
///
/// Each entry of `level_widths` is the *inclusive maximum* for that digit, so
/// the number of keys produced is the product of `level_widths[i] + 1`. For
/// example, `[3, 2, 1]` yields `4 * 3 * 2` keys.
///
/// # Panics
///
/// Panics (eagerly, at call time) if `KEY_LENGTH` is 0 or if any entry of
/// `level_widths` is 0.
pub fn generate_key_fixed_length<const KEY_LENGTH: usize>(
    level_widths: [u8; KEY_LENGTH],
) -> impl Iterator<Item = [u8; KEY_LENGTH]> {
    /// Odometer-style iterator over all keys bounded digit-wise by `max_digits`.
    struct FixedLengthKeys<const KEY_LENGTH: usize> {
        /// Inclusive upper bound for each digit.
        max_digits: [u8; KEY_LENGTH],
        /// Key to yield on the next call, or `None` once the sequence is done.
        pending: Option<[u8; KEY_LENGTH]>,
    }

    impl<const KEY_LENGTH: usize> FixedLengthKeys<KEY_LENGTH> {
        fn new(max_digits: [u8; KEY_LENGTH]) -> Self {
            assert!(
                KEY_LENGTH > 0,
                "the fixed key length must be greater than 0"
            );
            assert!(
                !max_digits.contains(&0),
                "the number of distinct values for each key digit must be greater than 0"
            );

            Self {
                max_digits,
                pending: Some([u8::MIN; KEY_LENGTH]),
            }
        }
    }

    impl<const KEY_LENGTH: usize> Iterator for FixedLengthKeys<KEY_LENGTH> {
        type Item = [u8; KEY_LENGTH];

        fn next(&mut self) -> Option<Self::Item> {
            let current = self.pending.take()?;

            // Increment the last digit, carrying leftwards over any digit that
            // is already at its maximum. If the carry runs off the front, every
            // digit was at its maximum: `current` is the final key, so
            // `pending` stays `None`.
            let mut successor = current;
            let mut carried_out = true;
            for (digit, &max_digit) in successor.iter_mut().zip(&self.max_digits).rev() {
                if *digit == max_digit {
                    *digit = u8::MIN;
                } else {
                    *digit = (*digit).saturating_add(1);
                    carried_out = false;
                    break;
                }
            }

            if !carried_out {
                self.pending = Some(successor);
            }
            Some(current)
        }
    }

    FixedLengthKeys::new(level_widths)
}
/// Describes how one byte of a base key is repeated when building longer keys
/// in [`generate_key_with_prefix`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PrefixExpansion {
    /// Position of the byte in the base key that gets repeated.
    pub base_index: usize,
    /// Total number of copies of that byte in the expanded key; `1` leaves the
    /// key unchanged at that position.
    pub expanded_length: usize,
}
/// Generate bytestring keys by taking every key produced by
/// [`generate_key_fixed_length`] for `level_widths` and repeating selected
/// bytes according to `prefix_expansions`.
///
/// Each [`PrefixExpansion`] replaces the byte at `base_index` of the base key
/// with `expanded_length` copies of that same byte. Every generated key
/// therefore has length `KEY_LENGTH + Σ (expanded_length - 1)`. Exactly one
/// expanded key is produced per base key, so the key count matches
/// [`generate_key_fixed_length`]. `prefix_expansions` may be given in any
/// order.
///
/// # Panics
///
/// Panics (eagerly, at call time):
///  - if any `base_index` is not less than `KEY_LENGTH`;
///  - if any `expanded_length` is 0;
///  - if two expansions share the same `base_index`;
///  - in any case where [`generate_key_fixed_length`] panics for `level_widths`.
pub fn generate_key_with_prefix<const KEY_LENGTH: usize>(
    level_widths: [u8; KEY_LENGTH],
    prefix_expansions: impl AsRef<[PrefixExpansion]>,
) -> impl Iterator<Item = Box<[u8]>> {
    /// Build the expanded form of `old_key` in a copy of `new_key_template`
    /// (a buffer already sized to the final key length).
    ///
    /// `sorted_expansions` must be sorted by strictly increasing `base_index`.
    fn apply_expansions_to_key(
        old_key: &[u8],
        new_key_template: &[u8],
        sorted_expansions: &[PrefixExpansion],
    ) -> Box<[u8]> {
        let mut new_key: Box<[u8]> = new_key_template.into();
        let mut new_key_index = 0usize;
        let mut old_key_index = 0usize;

        for expansion in sorted_expansions {
            // Copy the untouched bytes between the previous expansion and this one.
            let before_len = expansion.base_index - old_key_index;
            new_key[new_key_index..(new_key_index + before_len)]
                .copy_from_slice(&old_key[old_key_index..expansion.base_index]);
            // Repeat the expanded byte `expanded_length` times.
            new_key[(new_key_index + before_len)
                ..(new_key_index + before_len + expansion.expanded_length)]
                .fill(old_key[expansion.base_index]);

            old_key_index = expansion.base_index + 1;
            new_key_index += before_len + expansion.expanded_length;
        }

        // Copy whatever follows the last expansion.
        new_key[new_key_index..].copy_from_slice(&old_key[old_key_index..]);

        new_key
    }

    let expansions = prefix_expansions.as_ref();
    assert!(
        expansions
            .iter()
            .all(|expand| expand.base_index < KEY_LENGTH),
        "the prefix expansion index must be less than `KEY_LENGTH`."
    );
    assert!(
        expansions
            .iter()
            .all(|expand| expand.expanded_length > 0),
        "the prefix expansion length must be greater than 0."
    );

    let mut sorted_expansions = expansions.to_vec();
    sorted_expansions.sort_unstable_by_key(|expand| expand.base_index);
    // Check uniqueness on the sorted copy (adjacent duplicates) so it works
    // without `std`. A duplicate index would otherwise underflow `before_len`
    // in `apply_expansions_to_key` and only panic later, during iteration.
    assert!(
        sorted_expansions
            .windows(2)
            .all(|pair| pair[0].base_index != pair[1].base_index),
        "the prefix expansion index must be unique"
    );

    // Each expansion replaces 1 byte with `expanded_length` bytes.
    let full_key_len = sorted_expansions
        .iter()
        .map(|expand| expand.expanded_length - 1)
        .sum::<usize>()
        + KEY_LENGTH;
    let full_key_template = vec![u8::MIN; full_key_len].into_boxed_slice();

    generate_key_fixed_length(level_widths)
        .map(move |key| apply_expansions_to_key(&key, &full_key_template, &sorted_expansions))
}
#[cfg(test)]
/// Insert `key` and `value` directly into the raw tree rooted at `root`,
/// passing the [`Global`] allocator to the insert point.
///
/// # Errors
///
/// Propagates the [`InsertPrefixError`] returned by
/// `crate::raw::search_for_insert_point`. NOTE(review): presumably returned
/// when `key` and an existing key are prefixes of one another; confirm
/// against `crate::raw`.
///
/// # Safety
///
/// The caller must uphold the safety contracts of both
/// `crate::raw::search_for_insert_point` (for `root`) and the `apply` method
/// on the insert point it returns. NOTE(review): presumably `root` must be a
/// live tree root with no other outstanding references, whose nodes were
/// allocated with `Global`; confirm against `crate::raw`.
pub(crate) unsafe fn insert_unchecked<'a, K, V, const PREFIX_LEN: usize>(
    root: OpaqueNodePtr<K, V, PREFIX_LEN>,
    key: K,
    value: V,
) -> Result<InsertResult<'a, K, V, PREFIX_LEN>, InsertPrefixError>
where
    K: AsBytes + 'a,
{
    use crate::raw::search_for_insert_point;

    // SAFETY: the requirements on `root` are forwarded to our caller (see `# Safety`).
    let insert_point = unsafe { search_for_insert_point(root, key.as_bytes()) }?;
    // SAFETY: `insert_point` was just produced for this same tree; remaining
    // requirements are forwarded to our caller (see `# Safety`).
    let inserted = unsafe { insert_point.apply(key, value, &Global) };
    Ok(inserted)
}
#[cfg(test)]
/// Build a [`TreeMap`] from `entries_it` and return its raw root pointer,
/// transferring ownership of the tree to the caller.
///
/// # Panics
///
/// Panics if any `try_insert` call fails, or if `TreeMap::into_raw` returns
/// `None`. NOTE(review): presumably `into_raw` returns `None` for an empty
/// tree; confirm.
pub(crate) fn setup_tree_from_entries<K, V, const PREFIX_LEN: usize>(
    entries_it: impl Iterator<Item = (K, V)>,
) -> OpaqueNodePtr<K, V, PREFIX_LEN>
where
    K: AsBytes,
{
    let mut tree = TreeMap::with_prefix_len();
    entries_it.for_each(|(key, value)| {
        // Whatever `try_insert` yields on success is intentionally discarded.
        let _ = tree.try_insert(key, value).unwrap();
    });
    TreeMap::into_raw(tree).unwrap()
}
#[cfg(all(test, not(miri)))]
mod tests {
    // `super::*` also brings in the parent's imports (`TreeMap`, `AsBytes`, ...).
    use super::*;

    #[test]
    fn key_generator_returns_expected_number_of_entries() {
        /// Insert every key from `it` into a fresh `TreeMap`, using each key's
        /// position as its value. Then assert that the iterator yielded
        /// `expected_num_entries` keys and that the tree length equals the
        /// number of keys yielded (assuming map semantics for `try_insert`,
        /// this catches duplicate keys).
        #[track_caller]
        fn check<K: AsBytes>(it: impl IntoIterator<Item = K>, expected_num_entries: usize) {
            let mut num_entries = 0;
            let it = it.into_iter().inspect(|_| num_entries += 1);
            let mut tree = TreeMap::new();
            for (key, value) in it.enumerate().map(swap) {
                tree.try_insert(key, value).unwrap();
            }
            assert_eq!(num_entries, tree.len());
            assert_eq!(expected_num_entries, num_entries);
        }

        check(generate_keys_skewed(1), 1);
        check(generate_keys_skewed(16), 16);

        check(generate_key_fixed_length([3, 2, 1]), 4 * 3 * 2);
        check(generate_key_fixed_length([15, 2]), 16 * 3);
        check(generate_key_fixed_length([255]), 256);
        check(generate_key_fixed_length([127]), 128);
        check(generate_key_fixed_length([7; 5]), 8 * 8 * 8 * 8 * 8);

        let no_op_expansion = [PrefixExpansion {
            base_index: 0,
            expanded_length: 1,
        }];

        check(
            generate_key_with_prefix([3, 2, 1], no_op_expansion),
            4 * 3 * 2,
        );
        check(generate_key_with_prefix([15, 2], no_op_expansion), 16 * 3);
        check(generate_key_with_prefix([255], no_op_expansion), 256);
        check(generate_key_with_prefix([127], no_op_expansion), 128);
        check(
            generate_key_with_prefix(
                [3, 2, 1],
                [
                    PrefixExpansion {
                        base_index: 0,
                        expanded_length: 1,
                    },
                    PrefixExpansion {
                        base_index: 1,
                        expanded_length: 1,
                    },
                    PrefixExpansion {
                        base_index: 2,
                        expanded_length: 1,
                    },
                ],
            ),
            4 * 3 * 2,
        );
        check(
            generate_key_with_prefix(
                [3, 2, 1],
                [
                    PrefixExpansion {
                        base_index: 0,
                        expanded_length: 3,
                    },
                    PrefixExpansion {
                        base_index: 1,
                        expanded_length: 256,
                    },
                    PrefixExpansion {
                        base_index: 2,
                        expanded_length: 127,
                    },
                ],
            ),
            4 * 3 * 2,
        );
    }

    #[test]
    fn key_generators_produce_expected_keys() {
        // Skewed keys: `n - 1` zero bytes followed by `u8::MAX`.
        let expected_skewed: [&[u8]; 3] = [&[255], &[0, 255], &[0, 0, 255]];
        let mut skewed = generate_keys_skewed(3);
        for &want in expected_skewed.iter() {
            assert_eq!(skewed.next().as_deref(), Some(want));
        }
        assert!(skewed.next().is_none());

        // Expansions are given out of order on purpose; each base key
        // `[a, b, c]` must become `[a, a, b, c, c, c]`.
        let expanded = generate_key_with_prefix(
            [1, 1, 1],
            [
                PrefixExpansion {
                    base_index: 2,
                    expanded_length: 3,
                },
                PrefixExpansion {
                    base_index: 0,
                    expanded_length: 2,
                },
            ],
        );
        let mut count = 0;
        for (key, [a, b, c]) in expanded.zip(generate_key_fixed_length([1, 1, 1])) {
            assert_eq!(&*key, &[a, a, b, c, c, c][..]);
            count += 1;
        }
        assert_eq!(count, 8);
    }
}