use super::*;
use num_traits::{bounds::Bounded, AsPrimitive, FromPrimitive, Unsigned};
/// Picks the key element at which a bucket should be cut into two parts of roughly equal size.
///
/// Every key in `bucket` that is longer than `index` contributes its element at position
/// `index` to a histogram covering `start..=end`. The histogram is then scanned upwards from
/// `start` until about half of the contributing keys have been seen. The returned value is the
/// first element that belongs to the upper part, i.e. callers partition with
/// `key[index] >= returned_value`.
///
/// The result is capped at `K::max_value()`. Without the cap, a bucket whose median element is
/// `K::max_value()` would produce `max + 1`, which `K` cannot represent.
///
/// # Panics
/// Panics if a counted key element lies outside `start..=end`.
fn find_half_point<K, V>(
    bucket: &ArrayHash<twox_hash::XxHash64, Box<[K]>, V>,
    index: usize,
    start: usize,
    end: usize,
) -> K
where
    K: AsPrimitive<usize> + core::hash::Hash + FromPrimitive + Bounded + PartialEq + PartialOrd + Unsigned,
    Box<[K]>: Clone + core::cmp::PartialEq,
    V: Clone,
{
    // One histogram slot per possible element value in `start..=end`.
    let mut count = vec![0usize; end - start + 1];
    let mut total = 0usize;
    for (k, _) in bucket.iter() {
        // Keys that are too short to have an element at `index` take no part in the split.
        if k.len() > index {
            count[k[index].as_() - start] += 1;
            total += 1;
        }
    }

    // Walk the histogram until roughly half of the counted keys are below the cursor.
    let half = total / 2;
    let mut offset = 0;
    let mut seen = 0;
    while seen < half {
        seen += count[offset];
        offset += 1;
    }

    // `offset` is one past the median slot, so it can reach `K::max_value() + 1`; cap it so
    // the conversion back into `K` cannot fail.
    let largest: usize = K::max_value().as_();
    let split_point = (start + offset).min(largest);
    K::from_usize(split_point).expect("split point was capped at K::max_value()")
}
/// Child table of a trie node.
///
/// The child nodes themselves live in `childs`; `_internal_pointer` is a lookup table that is
/// indexed by a key element (converted with `as_()`) and yields a raw pointer to the child
/// responsible for that element. Several table slots may point at the same child.
// NOTE(review): the derived `Clone` copies the raw pointers verbatim, so the table of a clone
// still addresses the `childs` buffer of the value it was cloned from — confirm whether a
// hand-written `Clone` that rebuilds the table is needed.
#[derive(Clone, Debug)]
struct ChildsContainer<K, T, V>
where K: AsPrimitive<usize> + core::hash::Hash + PartialEq + PartialOrd,
Box<[K]>: Clone + core::cmp::PartialEq,
T: Clone + TrieNode<K, V>,
V: Clone {
/// Owned child nodes. The pointers in `_internal_pointer` refer to elements of this vector,
/// so they are only valid as long as the vector's buffer is not reallocated.
childs: Vec<NodeType<K, T, V>>,
/// Lookup table from key element to the child in `childs` that serves it.
_internal_pointer: Box<[*mut NodeType<K, T, V>]>
}
impl<K, T, V> ChildsContainer<K, T, V>
where
    K: AsPrimitive<usize> + Bounded + core::fmt::Debug + core::hash::Hash + FromPrimitive + PartialEq + PartialOrd + Unsigned,
    Box<[K]>: Clone + core::cmp::PartialEq,
    T: Clone + TrieNode<K, V>,
    V: Clone + Default,
{
    /// Builds a container whose `size` lookup slots all route to a single empty hybrid bucket
    /// covering `K::min_value()..=K::max_value()`.
    ///
    /// Room for `size` children is reserved up front.
    pub fn new(size: usize) -> Self {
        let mut childs = Vec::with_capacity(size);
        childs.push(NodeType::Hybrid((
            ArrayHashBuilder::default().build(),
            K::min_value()..=K::max_value(),
        )));
        // Every slot starts out pointing at the one and only child.
        let root = childs.as_mut_ptr();
        ChildsContainer {
            childs,
            _internal_pointer: vec![root; size].into_boxed_slice(),
        }
    }

    /// Restructures the child that `key` routes to if its bucket has reached
    /// `super::BURST_THRESHOLD` entries.
    ///
    /// * A `Pure` bucket is replaced by a `Trie` node built with `T::new_split`.
    /// * A `Hybrid` bucket that spans a single key element is turned into a `Pure` bucket.
    /// * Any other `Hybrid` bucket is cut in two around the element chosen by
    ///   `find_half_point`. The lower part stays in place, the upper part is appended to
    ///   `childs`, and only the lookup slots of the upper part's range are redirected to it.
    ///   A part that ends up spanning a single key element becomes a `Pure` bucket.
    ///
    /// Returns `true` when the child was restructured and `false` otherwise.
    fn maybe_split(&mut self, key: K) -> bool {
        let mut burst_pure = false;
        let mut collapse_hybrid = false;
        let mut split_result = None;
        match &mut self[key] {
            NodeType::Pure(bucket) => {
                if bucket.len() < super::BURST_THRESHOLD {
                    return false;
                }
                burst_pure = true;
            }
            NodeType::Hybrid((bucket, range)) => {
                if bucket.len() < super::BURST_THRESHOLD {
                    return false;
                }
                let start: usize = range.start().as_();
                let end: usize = range.end().as_();
                if start == end {
                    // Nothing left to split on: every key starts with the same element.
                    collapse_hybrid = true;
                } else {
                    let median: usize = find_half_point(bucket, 0, start, end + 1).as_();
                    // Keep both parts non-empty ranges. Without the clamp a median on the
                    // last element of the range would move nothing and the same bucket would
                    // be "split" again on every call.
                    let split_point = median.max(start + 1).min(end);
                    let split_key = K::from_usize(split_point).unwrap();
                    let new_bucket = bucket.split_by(|(k, _)| k[0] >= split_key);
                    *range = K::from_usize(start).unwrap()..=K::from_usize(split_point - 1).unwrap();
                    split_result = Some((new_bucket, [start, split_point, end]));
                }
            }
            _ => return false,
        }

        if burst_pure {
            let slot = &mut self[key];
            if let NodeType::Pure(bucket) = std::mem::take(slot) {
                *slot = NodeType::Trie(T::new_split(bucket));
            }
            return true;
        }

        if collapse_hybrid {
            let slot = &mut self[key];
            if let NodeType::Hybrid((table, _)) = std::mem::take(slot) {
                *slot = NodeType::Pure(table);
            }
            return true;
        }

        if let Some((new_bucket, [start, split_point, end])) = split_result {
            if split_point - 1 == start {
                // The lower part now spans exactly one key element. The lookup table has not
                // been touched yet, so `key` still resolves to the lower part.
                let slot = &mut self[key];
                if let NodeType::Hybrid((table, _)) = std::mem::take(slot) {
                    *slot = NodeType::Pure(table);
                }
            }

            let upper = if split_point < end {
                NodeType::Hybrid((
                    new_bucket,
                    K::from_usize(split_point).unwrap()..=K::from_usize(end).unwrap(),
                ))
            } else {
                NodeType::Pure(new_bucket)
            };

            // `push` may move the buffer of `childs`; if it does, every pointer in the lookup
            // table has to be rebased onto the new buffer before it is used again.
            let old_base = self.childs.as_ptr() as usize;
            self.childs.push(upper);
            let new_base = self.childs.as_mut_ptr();
            if new_base as usize != old_base {
                let stride = core::mem::size_of::<NodeType<K, T, V>>();
                for ptr in self._internal_pointer.iter_mut() {
                    let slot = (*ptr as usize - old_base) / stride;
                    // SAFETY: `slot` is the index the pointer had in the old buffer, and all
                    // elements of the old buffer were moved to the same index in the new one.
                    *ptr = unsafe { new_base.add(slot) };
                }
            }

            // Redirect only the slots of the upper part. Slots above `end` belong to other
            // children and must keep pointing at them.
            let last = self.childs.len() - 1;
            let new_child = &mut self.childs[last] as *mut NodeType<K, T, V>;
            for ptr in self._internal_pointer[split_point..=end].iter_mut() {
                *ptr = new_child;
            }
            return true;
        }
        false
    }
}
impl<K, T, V> core::ops::Index<K> for ChildsContainer<K, T, V>
where
    K: AsPrimitive<usize> + core::hash::Hash + PartialEq + PartialOrd + Unsigned,
    Box<[K]>: Clone + core::cmp::PartialEq,
    T: Clone + TrieNode<K, V>,
    V: Clone,
{
    type Output = NodeType<K, T, V>;

    /// Returns the child that the key element `idx` is routed to.
    ///
    /// # Panics
    /// Panics if `idx` converted to `usize` is not a valid slot of the lookup table.
    fn index(&self, idx: K) -> &Self::Output {
        let slot: usize = idx.as_();
        let child = self._internal_pointer[slot];
        // SAFETY: relies on every lookup-table entry pointing at a live element of
        // `self.childs`; the code that fills the table is responsible for upholding this.
        unsafe { &*child }
    }
}
impl<K, T, V> core::ops::IndexMut<K> for ChildsContainer<K, T, V>
where
    K: AsPrimitive<usize> + core::hash::Hash + PartialEq + PartialOrd + Unsigned,
    Box<[K]>: Clone + core::cmp::PartialEq,
    T: Clone + TrieNode<K, V>,
    V: Clone,
{
    /// Returns the child that the key element `idx` is routed to, for modification.
    ///
    /// # Panics
    /// Panics if `idx` converted to `usize` is not a valid slot of the lookup table.
    fn index_mut(&mut self, idx: K) -> &mut Self::Output {
        let slot: usize = idx.as_();
        let child = self._internal_pointer[slot];
        // SAFETY: relies on every lookup-table entry pointing at a live element of
        // `self.childs`; the code that fills the table is responsible for upholding this.
        unsafe { &mut *child }
    }
}
/// Trie node that routes each possible key element to a child through a dense lookup table
/// (see `ChildsContainer`) and may additionally carry a value of its own.
#[derive(Clone, Debug)]
pub struct DenseVecTrieNode<K, V>
where K: AsPrimitive<usize> + Bounded + Copy + core::fmt::Debug + core::hash::Hash + FromPrimitive + PartialEq + PartialOrd + Sized + Unsigned,
Box<[K]>: Clone + PartialEq,
V: Clone + Default {
/// Children of this node, addressed by the next key element.
childs: ChildsContainer<K, Self, V>,
/// Value stored for the key that ends exactly at this node, if any.
value: Option<V>,
}
impl<K, V> TrieNode<K, V> for DenseVecTrieNode<K, V>
where
    K: AsPrimitive<usize> + Bounded + Copy + core::fmt::Debug + core::hash::Hash + FromPrimitive + PartialEq + PartialOrd + Sized + Unsigned,
    Box<[K]>: Clone + PartialEq,
    V: Clone + Default,
{
    /// Turns a bucket whose keys all share their first element into a trie node.
    ///
    /// The first element of every key is dropped. A key that consisted of that element only
    /// supplies the value of the new node. All remaining keys are distributed over two hybrid
    /// buckets, divided at the element chosen by `find_half_point` on the second key element.
    ///
    /// # Panics
    /// Panics if two keys become equal after their first element is dropped, or if the number
    /// of redistributed entries does not match the number of entries drained from `bucket`.
    fn new_split(mut bucket: ArrayHash<twox_hash::XxHash64, Box<[K]>, V>) -> Self {
        let start: usize = K::min_value().as_();
        let end: usize = K::max_value().as_();
        let mut old_size = bucket.len();

        let median: usize = find_half_point(&bucket, 1, start, end).as_();
        // Keep the lower range non-empty so that `split_point - 1` below cannot underflow,
        // and keep the split point inside the key range.
        let split_point_usize = median.max(start + 1).min(end);
        let split_point = K::from_usize(split_point_usize).unwrap();

        let mut node_value = None;
        let mut left = ArrayHashBuilder::default().build();
        let mut right = ArrayHashBuilder::default().build();
        for (key, value) in bucket.drain() {
            if key.len() == 1 {
                // The key ends at the node being created: its value becomes the node value.
                node_value = Some(value);
                old_size -= 1;
                continue;
            }
            if key[1] >= split_point {
                assert!(right.put(key[1..].into(), value).is_none());
            } else {
                assert!(left.put(key[1..].into(), value).is_none());
            }
        }
        assert_eq!(old_size, left.len() + right.len());

        // Reserve one slot per possible key element, like `ChildsContainer::new` does, so
        // that later pushes cannot reallocate the buffer the raw pointers below refer to.
        let mut childs = Vec::with_capacity(end - start + 1);
        childs.push(NodeType::Hybrid((
            left,
            K::from_usize(start).unwrap()..=K::from_usize(split_point_usize - 1).unwrap(),
        )));
        childs.push(NodeType::Hybrid((right, split_point..=K::from_usize(end).unwrap())));

        let base: *mut NodeType<K, Self, V> = childs.as_mut_ptr();
        let ptr = (start..=end)
            .map(|key| {
                let slot = if key >= split_point_usize { 1 } else { 0 };
                // SAFETY: `childs` holds exactly two elements, so both offsets are in bounds.
                unsafe { base.add(slot) }
            })
            .collect::<Vec<_>>()
            .into_boxed_slice();

        DenseVecTrieNode {
            childs: ChildsContainer {
                childs,
                _internal_pointer: ptr,
            },
            value: node_value,
        }
    }

    /// Returns the child that the key element `key` is routed to.
    fn child<'a>(&'a self, key: &K) -> &'a NodeType<K, Self, V> {
        &self.childs[*key]
    }

    /// Returns the value stored on this node itself, if any.
    fn value(&self) -> Option<&V> {
        self.value.as_ref()
    }

    /// Gives mutable access to the value slot of this node.
    fn value_mut(&mut self) -> &mut Option<V> {
        &mut self.value
    }

    /// Stores `value` under `key`, replacing any value already stored there.
    ///
    /// Before each level is visited, the child that the current key element routes to is
    /// given the chance to split (see `ChildsContainer::maybe_split`).
    ///
    /// Returns the value that was previously stored under `key`, or `None` if there was
    /// none. An empty `key` stores nothing and returns `None`.
    fn put(&mut self, key: &[K], value: V) -> Option<V> {
        if key.is_empty() {
            return None
        }
        let mut offset = 0;
        let mut parent = self;
        let mut nodes = &mut parent.childs;
        nodes.maybe_split(key[0]);
        let mut node;
        loop {
            node = &mut nodes[key[offset]];
            match node {
                NodeType::None => {
                    let mut bucket = ArrayHashBuilder::default().build();
                    bucket.put(key[offset..].into(), value);
                    *node = NodeType::Pure(bucket);
                    return None
                },
                NodeType::Trie(t) => {
                    // Descend one level; a key that ends here lives on the trie node itself.
                    parent = t;
                    offset += 1;
                    if offset >= key.len() {
                        return parent.value.replace(value)
                    }
                    nodes = &mut parent.childs;
                    nodes.maybe_split(key[offset]);
                },
                NodeType::Pure(childs) => {
                    return childs.put(key[offset..].into(), value).map(|(_, v)| v)
                },
                NodeType::Hybrid(bucket) => {
                    return bucket.0.put(key[offset..].into(), value).map(|(_, v)| v)
                },
            }
        }
    }

    /// Stores `value` under `key` only if the trie-node value for `key` is still empty; for
    /// keys that end inside a bucket the decision is delegated to the bucket's `try_put`.
    ///
    /// Returns `None` when `value` was stored on a trie node or in a fresh bucket, and a
    /// reference to the existing value when a trie node already holds one. An empty `key`
    /// stores nothing and returns `None`.
    ///
    /// Unlike `put`, this method never asks a child to split.
    fn try_put<'a>(&'a mut self, key: &[K], value: V) -> Option<&'a V> where K: 'a {
        if key.is_empty() {
            return None
        }
        let mut offset = 0;
        let mut parent = self;
        let mut nodes = &mut parent.childs;
        let mut node;
        loop {
            node = &mut nodes[key[offset]];
            match node {
                NodeType::None => {
                    let mut bucket = ArrayHashBuilder::default().build();
                    bucket.put(key[offset..].into(), value);
                    *node = NodeType::Pure(bucket);
                    return None
                },
                NodeType::Trie(t) => {
                    parent = t;
                    offset += 1;
                    // The end-of-key check must come before the next `key[offset]` lookup,
                    // otherwise a key ending on a trie node indexes past its last element.
                    if offset >= key.len() {
                        if parent.value.is_none() {
                            parent.value = Some(value);
                            return None
                        }
                        return parent.value.as_ref()
                    }
                    nodes = &mut parent.childs;
                },
                NodeType::Pure(childs) => {
                    return childs.try_put(key[offset..].into(), value)
                },
                NodeType::Hybrid(bucket) => {
                    return bucket.0.try_put(key[offset..].into(), value)
                },
            }
        }
    }
}
impl<K, V> DenseVecTrieNode<K, V>
where
    K: Copy + AsPrimitive<usize> + Bounded + core::fmt::Debug + core::hash::Hash + FromPrimitive + PartialEq + PartialOrd + Sized + Unsigned,
    Box<[K]>: Clone + PartialEq,
    V: Clone + Default,
{
    /// Creates an empty node without a value.
    ///
    /// The child table gets `2^(8 * size_of::<K>())` lookup slots, i.e. one per bit pattern
    /// of `K`.
    pub fn new() -> DenseVecTrieNode<K, V> {
        let key_bits = std::mem::size_of::<K>() as u32 * 8u32;
        let slot_count = 2usize.pow(key_bits);
        DenseVecTrieNode {
            value: None,
            childs: ChildsContainer::new(slot_count),
        }
    }
}
#[cfg(test)]
mod tests;