use crate::indices::*;
use crate::search::*;
use crate::{FeatureHeap, NodeQueue};
use hashbrown::HashMap;
use log::trace;
use swar::*;
const TAU: usize = 1 << 16;
const TABLE_TAUS: [usize; 7] = [
START_TAU_MAG,
START_TAU_MAG << 1,
START_TAU_MAG << 2,
START_TAU_MAG << 2,
START_TAU_MAG << 3,
START_TAU_MAG << 5,
START_TAU_MAG << 6,
];
const START_TAU_MAG: usize = 128;
const INITIAL_CAPACITY: usize = 16;
pub(crate) type InternalMap = HashMap<u128, u32, std::hash::BuildHasherDefault<ahash::AHasher>>;
#[derive(Debug)]
enum Internal {
Vec(Vec<u128>),
Map(InternalMap),
}
impl Default for Internal {
fn default() -> Self {
Internal::Vec(Vec::with_capacity(INITIAL_CAPACITY))
}
}
pub struct Hwt {
internals: Vec<Internal>,
count: usize,
}
impl Hwt {
pub fn new() -> Self {
Self::default()
}
pub fn len(&self) -> usize {
self.count
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
fn allocate_internal(&mut self) -> u32 {
let internal = self.internals.len() as u32;
assert!(internal < std::u32::MAX);
self.internals.push(Internal::default());
internal
}
fn convert(&mut self, internal: usize, level: usize) {
let mut old_vec = Internal::Vec(Vec::new());
std::mem::swap(&mut self.internals[internal], &mut old_vec);
self.internals[internal] = match old_vec {
Internal::Vec(v) => {
let mut map = InternalMap::default();
for feature in v.into_iter() {
let index = indices128(feature)[level];
let new_internal =
*map.entry(index).or_insert_with(|| self.allocate_internal());
if let Internal::Vec(ref mut v) = self.internals[new_internal as usize] {
v.push(feature);
} else {
unreachable!(
"cannot have InternalStore::Map in subtable when just created"
);
}
}
Internal::Map(map.into_iter().collect())
}
_ => panic!("tried to convert an InternalStore::Map"),
}
}
pub fn insert(&mut self, feature: u128) {
self.count += 1;
let indices = indices128(feature);
let mut bucket = 0;
let mut create_internal = None;
for (i, &tc) in indices.iter().enumerate() {
match &mut self.internals[bucket] {
Internal::Vec(ref mut v) => {
v.push(feature);
if v.len() > TAU {
self.convert(bucket, i);
}
return;
}
Internal::Map(ref mut map) => {
match map.get(&tc) {
Some(&internal) => {
bucket = internal as usize;
}
None => {
create_internal = Some(tc);
break;
}
}
}
}
}
if let Some(vacant_node) = create_internal {
let new_internal = self.allocate_internal();
if let Internal::Vec(ref mut v) = self.internals[new_internal as usize] {
v.push(feature);
} else {
unreachable!("cannot have InternalStore::Map in subtable when just created");
}
if let Internal::Map(ref mut map) = &mut self.internals[bucket] {
map.insert(vacant_node, new_internal);
} else {
unreachable!("shouldn't ever get vec after finding vacant map node");
}
} else {
match self.internals[bucket] {
Internal::Vec(ref mut v) => v.push(feature),
_ => panic!("Can't have InternalStore::Map at bottom of tree"),
}
}
}
pub fn contains(&mut self, feature: u128) -> bool {
let indices = indices128(feature);
let mut bucket = 0;
for &index in &indices {
match &self.internals[bucket] {
Internal::Vec(vec) => return vec.iter().cloned().any(|n| n == feature),
Internal::Map(map) => {
if let Some((_, &internal)) = map.iter().find(|&(&tc, _)| tc == index) {
bucket = internal as usize;
} else {
return false;
}
}
}
}
false
}
#[allow(clippy::cognitive_complexity)]
pub fn nearest<'a>(
&self,
feature: u128,
max_weight: u32,
max_error: u32,
node_queue: &mut NodeQueue,
feature_heap: &mut FeatureHeap,
dest: &'a mut [u128],
) -> &'a mut [u128] {
trace!(
"nearest feature({:032X}) weight({})",
feature,
feature.count_ones()
);
let indices = indices128(feature);
node_queue.clear();
feature_heap.reset(dest.len(), feature);
match &self.internals[0] {
Internal::Vec(v) => {
trace!("nearest sole leaf node len({})", v.len());
feature_heap.add(v.as_slice());
return feature_heap.fill_slice(dest);
}
Internal::Map(m) => {
trace!("nearest emptying root len({})", m.len());
for (distance, node) in m
.iter()
.map(|(&tc, &node)| {
let distance = (tc ^ indices[0]).count_ones();
(distance, node)
})
.filter(|&(distance, _)| distance <= max_weight)
{
match unsafe {
std::mem::transmute::<_, &'static Internal>(&self.internals[node as usize])
} {
Internal::Vec(v) => {
feature_heap.add(v.as_slice());
if feature_heap.done() {
return feature_heap.fill_slice(dest);
}
}
Internal::Map(m) => {
node_queue.add_one((distance, &m, 0));
}
}
}
}
}
for distance in 0..=max_weight {
trace!("searching distance({})", distance);
feature_heap.search_distance(std::cmp::min(128, distance + max_error));
if feature_heap.done() {
return feature_heap.fill_slice(dest);
}
while node_queue.distance() == Some(distance) {
if let Some((_, internal, level)) = node_queue.pop() {
if level == 7 {
unreachable!("hwt: it is impossible to have an internal node at layer 7");
}
trace!(
"nearest node distance({}) len({}) level({})",
distance,
internal.len(),
level
);
if internal.len() < TABLE_TAUS[level as usize] {
trace!("nearest brute force");
for (child_distance, child) in internal.iter().map(|(&tc, &child)| {
let child_distance = (tc ^ indices[(level + 1) as usize]).count_ones();
(child_distance, child)
}) {
match unsafe {
std::mem::transmute::<_, &'static Internal>(
&self.internals[child as usize],
)
} {
Internal::Vec(leaves) => {
feature_heap.add(leaves.as_slice());
if feature_heap.done() {
return feature_heap.fill_slice(dest);
}
}
Internal::Map(m) => {
node_queue.add_one((child_distance, &m, level + 1));
}
}
}
} else {
trace!("nearest precision search");
match level {
0 => {
let tp = Bits64(*internal.iter().next().unwrap().0).pack_ones();
for Bits64(tc) in search_exact2(
64,
Bits128(indices[level as usize]),
Bits64(indices[level as usize + 1]),
tp,
distance,
) {
if let Some(&child) = internal.get(&tc) {
match unsafe {
std::mem::transmute::<_, &'static Internal>(
&self.internals[child as usize],
)
} {
Internal::Vec(leaves) => {
trace!("nearest leaves len({})", leaves.len());
feature_heap.add(leaves.as_slice());
if feature_heap.done() {
return feature_heap.fill_slice(dest);
}
}
Internal::Map(m) => {
trace!("nearest map len({})", m.len());
let child_distance =
(tc ^ indices[level as usize + 1]).count_ones();
node_queue.add_one((child_distance, &m, level + 1));
}
}
}
}
}
1 => {
let tp = Bits32(*internal.iter().next().unwrap().0).pack_ones();
for Bits32(tc) in search_exact4(
32,
Bits64(indices[level as usize]),
Bits32(indices[level as usize + 1]),
tp,
distance,
) {
if let Some(&child) = internal.get(&tc) {
match unsafe {
std::mem::transmute::<_, &'static Internal>(
&self.internals[child as usize],
)
} {
Internal::Vec(leaves) => {
trace!("nearest leaves len({})", leaves.len());
feature_heap.add(leaves.as_slice());
if feature_heap.done() {
return feature_heap.fill_slice(dest);
}
}
Internal::Map(m) => {
trace!("nearest map len({})", m.len());
let child_distance =
(tc ^ indices[level as usize + 1]).count_ones();
node_queue.add_one((child_distance, &m, level + 1));
}
}
}
}
}
2 => {
let tp = Bits16(*internal.iter().next().unwrap().0).pack_ones();
for Bits16(tc) in search_exact8(
16,
Bits32(indices[level as usize]),
Bits16(indices[level as usize + 1]),
tp,
distance,
) {
if let Some(&child) = internal.get(&tc) {
match unsafe {
std::mem::transmute::<_, &'static Internal>(
&self.internals[child as usize],
)
} {
Internal::Vec(leaves) => {
trace!("nearest leaves len({})", leaves.len());
feature_heap.add(leaves.as_slice());
if feature_heap.done() {
return feature_heap.fill_slice(dest);
}
}
Internal::Map(m) => {
trace!("nearest map len({})", m.len());
let child_distance =
(tc ^ indices[level as usize + 1]).count_ones();
node_queue.add_one((child_distance, &m, level + 1));
}
}
}
}
}
3 => {
let tp = Bits8(*internal.iter().next().unwrap().0).pack_ones();
for Bits8(tc) in search_exact16(
8,
Bits16(indices[level as usize]),
Bits8(indices[level as usize + 1]),
tp,
distance,
) {
if let Some(&child) = internal.get(&tc) {
match unsafe {
std::mem::transmute::<_, &'static Internal>(
&self.internals[child as usize],
)
} {
Internal::Vec(leaves) => {
trace!("nearest leaves len({})", leaves.len());
feature_heap.add(leaves.as_slice());
if feature_heap.done() {
return feature_heap.fill_slice(dest);
}
}
Internal::Map(m) => {
trace!("nearest map len({})", m.len());
let child_distance =
(tc ^ indices[level as usize + 1]).count_ones();
node_queue.add_one((child_distance, &m, level + 1));
}
}
}
}
}
4 => {
let tp = Bits4(*internal.iter().next().unwrap().0).pack_ones();
for Bits4(tc) in search_exact32(
4,
Bits8(indices[level as usize]),
Bits4(indices[level as usize + 1]),
tp,
distance,
) {
if let Some(&child) = internal.get(&tc) {
match unsafe {
std::mem::transmute::<_, &'static Internal>(
&self.internals[child as usize],
)
} {
Internal::Vec(leaves) => {
trace!("nearest leaves len({})", leaves.len());
feature_heap.add(leaves.as_slice());
if feature_heap.done() {
return feature_heap.fill_slice(dest);
}
}
Internal::Map(m) => {
trace!("nearest map len({})", m.len());
let child_distance =
(tc ^ indices[level as usize + 1]).count_ones();
node_queue.add_one((child_distance, &m, level + 1));
}
}
}
}
}
5 => {
let tp = Bits2(*internal.iter().next().unwrap().0).pack_ones();
for Bits2(tc) in search_exact64(
2,
Bits4(indices[level as usize]),
Bits2(indices[level as usize + 1]),
tp,
distance,
) {
if let Some(&child) = internal.get(&tc) {
match unsafe {
std::mem::transmute::<_, &'static Internal>(
&self.internals[child as usize],
)
} {
Internal::Vec(leaves) => {
trace!("nearest leaves len({})", leaves.len());
feature_heap.add(leaves.as_slice());
if feature_heap.done() {
return feature_heap.fill_slice(dest);
}
}
Internal::Map(m) => {
trace!("nearest map len({})", m.len());
let child_distance =
(tc ^ indices[level as usize + 1]).count_ones();
node_queue.add_one((child_distance, &m, level + 1));
}
}
}
}
}
6 => {
let tp = Bits1(*internal.iter().next().unwrap().0).pack_ones();
for Bits1(tc) in search_exact128(
1,
Bits2(indices[level as usize]),
Bits1(indices[level as usize + 1]),
tp,
distance,
) {
if let Some(&child) = internal.get(&tc) {
match unsafe {
std::mem::transmute::<_, &'static Internal>(
&self.internals[child as usize],
)
} {
Internal::Vec(leaves) => {
trace!("nearest leaves len({})", leaves.len());
feature_heap.add(leaves.as_slice());
if feature_heap.done() {
return feature_heap.fill_slice(dest);
}
}
Internal::Map(m) => {
trace!("nearest map len({})", m.len());
let child_distance =
(tc ^ indices[level as usize + 1]).count_ones();
node_queue.add_one((child_distance, &m, level + 1));
}
}
}
}
}
_ => unreachable!(),
}
if distance != 128 {
node_queue.add_one((distance + 1, internal, level));
}
}
}
}
}
feature_heap.fill_slice(dest)
}
pub fn search_radius<'a>(
&'a self,
radius: u32,
feature: u128,
) -> impl Iterator<Item = u128> + 'a {
let index = indices128(feature)[1];
self.bucket_scan_radius(radius, feature, 0, Self::radius2, move |tc| {
(tc ^ index).count_ones() <= radius
})
}
fn radius2<'a>(
&'a self,
radius: u32,
feature: u128,
bucket: usize,
) -> impl Iterator<Item = u128> + 'a {
let index = indices128(feature)[2];
self.bucket_scan_radius(radius, feature, bucket, Self::radius4, move |tc| {
(tc ^ index).count_ones() <= radius
})
}
fn radius4<'a>(
&'a self,
radius: u32,
feature: u128,
bucket: usize,
) -> impl Iterator<Item = u128> + 'a {
let index = indices128(feature)[3];
self.bucket_scan_radius(radius, feature, bucket, Self::radius8, move |tc| {
(tc ^ index).count_ones() <= radius
})
}
fn radius8<'a>(
&'a self,
radius: u32,
feature: u128,
bucket: usize,
) -> impl Iterator<Item = u128> + 'a {
let index = indices128(feature)[4];
self.bucket_scan_radius(radius, feature, bucket, Self::radius16, move |tc| {
(tc ^ index).count_ones() <= radius
})
}
fn radius16<'a>(
&'a self,
radius: u32,
feature: u128,
bucket: usize,
) -> impl Iterator<Item = u128> + 'a {
let index = indices128(feature)[5];
self.bucket_scan_radius(radius, feature, bucket, Self::radius32, move |tc| {
(tc ^ index).count_ones() <= radius
})
}
fn radius32<'a>(
&'a self,
radius: u32,
feature: u128,
bucket: usize,
) -> impl Iterator<Item = u128> + 'a {
let index = indices128(feature)[6];
self.bucket_scan_radius(radius, feature, bucket, Self::radius64, move |tc| {
(tc ^ index).count_ones() <= radius
})
}
fn radius64<'a>(
&'a self,
radius: u32,
feature: u128,
bucket: usize,
) -> impl Iterator<Item = u128> + 'a {
let index = indices128(feature)[7];
self.bucket_scan_radius(radius, feature, bucket, Self::radius128, move |tc| {
(tc ^ index).count_ones() <= radius
})
}
fn radius128<'a>(
&'a self,
radius: u32,
feature: u128,
bucket: usize,
) -> impl Iterator<Item = u128> + 'a {
self.bucket_scan_radius(
radius,
feature,
bucket,
|_, _, _, bucket| -> Box<dyn Iterator<Item = u128> + 'a> {
panic!(
"hwt::Hwt::neighbors128(): it is an error to find an internal node this far down in the tree (bucket: {})", bucket,
)
},
move |tc| panic!("hwt::Hwt::neighbors128(): it is an error to find an internal node this far down in the tree (tc: {})", tc)
)
}
#[allow(clippy::too_many_arguments)]
fn bucket_scan_radius<'a, I: 'a>(
&'a self,
radius: u32,
feature: u128,
bucket: usize,
subtable: fn(&'a Self, u32, u128, usize) -> I,
filter: impl Fn(u128) -> bool + 'a,
) -> Box<dyn Iterator<Item = u128> + 'a>
where
I: Iterator<Item = u128>,
{
trace!(
"bucket_scan_radius feature({:032X}) radius({}) bucket({})",
feature,
radius,
bucket,
);
let lookup_distance = move |leaf: u128| (leaf ^ feature).count_ones();
match &self.internals[bucket] {
Internal::Vec(v) => Box::new(
v.iter()
.cloned()
.filter(move |&leaf| lookup_distance(leaf) <= radius),
),
Internal::Map(m) => Box::new(
m.iter()
.filter(move |&(&key, _)| filter(key))
.flat_map(move |(_, &node)| subtable(self, radius, feature, node as usize)),
),
}
}
}
impl Default for Hwt {
fn default() -> Self {
Self {
internals: vec![Internal::default()],
count: 0,
}
}
}