use crate::indices::*;
use crate::search::*;
use hashbrown::{hash_map::Entry, HashMap};
use std::cmp::{max, min};
use swar::*;
/// Bucket-size threshold: a `Vec` bucket that grows beyond `TAU` leaves is
/// converted into a `Map` on insert, and maps holding fewer than `TAU`
/// children are scanned linearly during searches instead of being probed by key.
const TAU: usize = 1024;
/// Capacity reserved for the leaf vector of every freshly created bucket.
const INITIAL_CAPACITY: usize = 16;
/// A node of the tree, stored in the `Hwt::internals` arena.
#[derive(Debug)]
enum Internal {
/// Leaf bucket: the item ids stored at this node.
Vec(Vec<u32>),
/// Branch node: maps a child key (one entry of `indices128(feature)`) to the
/// arena index of the child node.
Map(HashMap<u128, u32, std::hash::BuildHasherDefault<ahash::AHasher>>),
}
impl Default for Internal {
    /// A new node starts out as an empty leaf bucket with `INITIAL_CAPACITY`
    /// slots reserved.
    fn default() -> Self {
        let leaves = Vec::with_capacity(INITIAL_CAPACITY);
        Internal::Vec(leaves)
    }
}
/// Tree index over 128-bit features. Items are `u32` ids; the features
/// themselves are not stored and are resolved through caller-supplied
/// `lookup` closures.
pub struct Hwt {
/// Arena of all nodes; index 0 is the root, and `Internal::Map` values are
/// indices into this vector.
internals: Vec<Internal>,
/// Number of `insert` calls performed so far.
count: usize,
}
impl Hwt {
pub fn new() -> Self {
Self::default()
}
/// Returns the number of items that have been inserted.
pub fn len(&self) -> usize {
self.count
}
/// Returns `true` when nothing has been inserted yet.
pub fn is_empty(&self) -> bool {
    self.count == 0
}
/// Appends a fresh, empty leaf bucket to the node arena and returns its index.
///
/// # Panics
///
/// Panics if the new node's index would not be strictly below `u32::MAX`.
fn allocate_internal(&mut self) -> u32 {
    let index = self.internals.len();
    // Check the bound in `usize` *before* narrowing: `len as u32` silently
    // wraps, so asserting on the already-truncated value could pass and hand
    // out an index that aliases an existing node.
    assert!(
        index < std::u32::MAX as usize,
        "hwt::Hwt: node arena exceeds the u32 index space"
    );
    self.internals.push(Internal::default());
    index as u32
}
/// Turns the leaf bucket at arena index `internal` into a branch node.
///
/// Every leaf is re-resolved through `lookup`, keyed by entry `level` of its
/// `indices128` array, and moved into a child bucket (allocated on first use
/// of each key).
///
/// # Panics
///
/// Panics if the node at `internal` is already a `Map`.
fn convert<F>(&mut self, internal: usize, level: usize, mut lookup: F)
where
    F: FnMut(u32) -> u128,
{
    // Take the leaves out, leaving an empty placeholder until the map is ready.
    let placeholder = Internal::Vec(Vec::new());
    let leaves = match std::mem::replace(&mut self.internals[internal], placeholder) {
        Internal::Vec(leaves) => leaves,
        Internal::Map(_) => panic!("tried to convert an InternalStore::Map"),
    };
    let mut children = HashMap::default();
    for leaf in leaves {
        let key = indices128(lookup(leaf))[level];
        let child = *children
            .entry(key)
            .or_insert_with(|| self.allocate_internal());
        match self.internals[child as usize] {
            Internal::Vec(ref mut bucket) => bucket.push(leaf),
            Internal::Map(_) => {
                unreachable!("cannot have InternalStore::Map in subtable when just created")
            }
        }
    }
    self.internals[internal] = Internal::Map(children);
}
/// Inserts `item`, whose feature is `feature`, into the tree.
///
/// The tree is descended one level per entry of `indices128(feature)`. The
/// item is pushed into the first leaf bucket reached; if that bucket then
/// holds more than `TAU` leaves it is split via `convert`. If a branch node
/// has no child for the item's key, a new leaf bucket is created for it.
///
/// `lookup` must return the feature of any stored item id. Because the item
/// is pushed before a split, `lookup` may be called with `item` itself.
///
/// # Panics
///
/// Panics if a branch node is found below the last level of the tree.
pub fn insert<F>(&mut self, feature: u128, item: u32, mut lookup: F)
where
    F: FnMut(u32) -> u128,
{
    self.count += 1;
    let indices = indices128(feature);
    let mut bucket = 0;
    for (depth, &key) in indices.iter().enumerate() {
        // Either finish in a leaf bucket or find out which child to follow.
        let child = match &mut self.internals[bucket] {
            Internal::Vec(leaves) => {
                leaves.push(item);
                if leaves.len() > TAU {
                    self.convert(bucket, depth, &mut lookup);
                }
                return;
            }
            Internal::Map(children) => children.get(&key).cloned(),
        };
        match child {
            Some(next) => bucket = next as usize,
            None => {
                // No child for this key yet: create its bucket and link it.
                let fresh = self.allocate_internal();
                match self.internals[fresh as usize] {
                    Internal::Vec(ref mut leaves) => leaves.push(item),
                    Internal::Map(_) => unreachable!(
                        "cannot have InternalStore::Map in subtable when just created"
                    ),
                }
                match self.internals[bucket] {
                    Internal::Map(ref mut children) => {
                        children.insert(key, fresh);
                    }
                    Internal::Vec(_) => {
                        unreachable!("shouldn't ever get vec after finding vacant map node")
                    }
                }
                return;
            }
        }
    }
    // Every level was a branch node; `bucket` is now the bottom-most bucket.
    match self.internals[bucket] {
        Internal::Vec(ref mut leaves) => leaves.push(item),
        Internal::Map(_) => panic!("Can't have InternalStore::Map at bottom of tree"),
    }
}
/// Looks up an item whose feature equals `feature` exactly.
///
/// Descends one level per entry of `indices128(feature)` and linearly scans
/// the first leaf bucket reached, resolving each stored id through `lookup`.
/// Returns the first id whose looked-up feature equals `feature`, or `None`
/// if the path does not exist or no leaf matches.
///
/// NOTE(review): the body never mutates `self`; `&mut self` is kept only to
/// leave the public signature untouched.
pub fn get<F>(&mut self, feature: u128, mut lookup: F) -> Option<u32>
where
    F: FnMut(u32) -> u128,
{
    let indices = indices128(feature);
    let mut bucket = 0;
    for index in &indices {
        match &self.internals[bucket] {
            Internal::Vec(vec) => return vec.iter().cloned().find(|&n| lookup(n) == feature),
            Internal::Map(map) => match map.get(index) {
                Some(&child) => bucket = child as usize,
                None => return None,
            },
        }
    }
    // Every level was a branch node, so `bucket` is the bottom-most bucket.
    // `insert` stores items there, so it has to be searched as well; giving up
    // here would make those items impossible to find.
    match &self.internals[bucket] {
        Internal::Vec(vec) => vec.iter().cloned().find(|&n| lookup(n) == feature),
        Internal::Map(_) => None,
    }
}
/// Chains the results of `search_exact` for every radius from 0 to 128, in
/// ascending radius order, into one lazy iterator.
pub fn nearest<'a, F: 'a>(
    &'a self,
    feature: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let radii = 0..=128;
    radii.flat_map(move |radius| self.search_exact(radius, feature, lookup))
}
/// Returns the items whose feature is at Hamming distance exactly `radius`
/// from `feature`, resolving stored ids through `lookup`.
///
/// The search starts at the root (bucket 0), whose child keys are total
/// weights (`indices128(..)[0]`). Only children whose weight lies within
/// `radius` of the query's weight are visited, via `exact2`.
pub fn search_exact<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    let sw = indices[0] as i32;
    // Weights only span 0..=128, so clamp before the signed cast; a huge
    // `radius` would otherwise wrap and corrupt the range bounds.
    let reach = min(radius, 128) as i32;
    let start = max(0, sw - reach) as u128;
    let end = min(128, sw + reach) as u128;
    self.bucket_scan_exact(
        radius,
        feature,
        0,
        lookup,
        start..=end,
        Self::exact2,
        // Root keys are plain weights, so the small-map filter accepts exactly
        // the keys the range above would probe. A weight difference is only a
        // lower bound on the Hamming distance, hence a range and not equality.
        move |tc| start <= tc && tc <= end,
    )
}
/// Exact-distance search step for a node one level below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits64`-typed (the same lane width `search_exact2`
/// produces). Matching children are searched with `exact4`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn exact2<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_exact(
        radius,
        feature,
        bucket,
        lookup,
        search_exact2(Bits128(indices[0]), Bits64(indices[1]), Bits128(tp), radius)
            .map(|tc| tc.0),
        Self::exact4,
        // Keys of this node must be read with this node's lane width and
        // compared against the matching query index. The bound is `<=`
        // because a lower bound equal to `radius` is not required for an
        // exact match further down; leaves are re-checked exactly.
        move |tc| Bits64(tc).hwd(Bits64(indices[1])).sum_weight() as u32 <= radius,
    )
}
/// Exact-distance search step for a node two levels below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits32`-typed (the lane width `search_exact4`
/// produces). Matching children are searched with `exact8`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn exact4<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_exact(
        radius,
        feature,
        bucket,
        lookup,
        search_exact4(Bits64(indices[1]), Bits32(indices[2]), Bits64(tp), radius)
            .map(|tc| tc.0),
        Self::exact8,
        // Filter keys with this node's own lane width; `<=` because the
        // weight distance only bounds the Hamming distance from below.
        move |tc| Bits32(tc).hwd(Bits32(indices[2])).sum_weight() as u32 <= radius,
    )
}
/// Exact-distance search step for a node three levels below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits16`-typed (the lane width `search_exact8`
/// produces). Matching children are searched with `exact16`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn exact8<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_exact(
        radius,
        feature,
        bucket,
        lookup,
        search_exact8(Bits32(indices[2]), Bits16(indices[3]), Bits32(tp), radius)
            .map(|tc| tc.0),
        Self::exact16,
        // Filter keys with this node's own lane width; `<=` because the
        // weight distance only bounds the Hamming distance from below.
        move |tc| Bits16(tc).hwd(Bits16(indices[3])).sum_weight() as u32 <= radius,
    )
}
/// Exact-distance search step for a node four levels below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits8`-typed (the lane width `search_exact16`
/// produces). Matching children are searched with `exact32`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn exact16<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_exact(
        radius,
        feature,
        bucket,
        lookup,
        search_exact16(Bits16(indices[3]), Bits8(indices[4]), Bits16(tp), radius)
            .map(|tc| tc.0),
        Self::exact32,
        // Filter keys with this node's own lane width; `<=` because the
        // weight distance only bounds the Hamming distance from below.
        move |tc| Bits8(tc).hwd(Bits8(indices[4])).sum_weight() as u32 <= radius,
    )
}
/// Exact-distance search step for a node five levels below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits4`-typed (the lane width `search_exact32`
/// produces). Matching children are searched with `exact64`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn exact32<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_exact(
        radius,
        feature,
        bucket,
        lookup,
        search_exact32(Bits8(indices[4]), Bits4(indices[5]), Bits8(tp), radius).map(|tc| tc.0),
        Self::exact64,
        // Filter keys with this node's own lane width; `<=` because the
        // weight distance only bounds the Hamming distance from below.
        move |tc| Bits4(tc).hwd(Bits4(indices[5])).sum_weight() as u32 <= radius,
    )
}
/// Exact-distance search step for a node six levels below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits2`-typed (the lane width `search_exact64`
/// produces). Matching children are searched with `exact128`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn exact64<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_exact(
        radius,
        feature,
        bucket,
        lookup,
        search_exact64(Bits4(indices[5]), Bits2(indices[6]), Bits4(tp), radius).map(|tc| tc.0),
        Self::exact128,
        // Filter keys with this node's own lane width; `<=` because the
        // weight distance only bounds the Hamming distance from below.
        move |tc| Bits2(tc).hwd(Bits2(indices[6])).sum_weight() as u32 <= radius,
    )
}
fn exact128<'a, F: 'a>(
&'a self,
radius: u32,
feature: u128,
bucket: usize,
tp: u128,
lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
F: Fn(u32) -> u128,
{
let indices = indices128(feature);
self.bucket_scan_exact(
radius,
feature,
bucket,
lookup,
search_exact128(Bits2(indices[6]), Bits1(indices[7]), Bits2(tp), radius).map(|tc| tc.0),
|_, _, _, bucket, _, _| -> Box<dyn Iterator<Item = u32> + 'a> {
panic!(
"hwt::Hwt::neighbors128(): it is an error to find an internal node this far down in the tree (bucket: {})", bucket,
)
},
move |tc| panic!("hwt::Hwt::neighbors128(): it is an error to find an internal node this far down in the tree (tc: {})", tc)
)
}
/// Scans the node at arena index `bucket` for items whose feature is at
/// Hamming distance exactly `radius` from `feature`.
///
/// * Leaf bucket: every stored id is resolved through `lookup` and kept when
///   `(lookup(id) ^ feature).count_ones() == radius`.
/// * Branch node with fewer than `TAU` children: every child key is tested
///   with `filter`, and each surviving child is searched through `subtable`.
/// * Larger branch node: each candidate key from `indices` is probed in the
///   map, and each child found is searched through `subtable`.
///
/// `subtable` receives `(self, radius, feature, child_bucket, child_key, lookup)`.
#[allow(clippy::too_many_arguments)]
fn bucket_scan_exact<'a, F: 'a, I: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    lookup: &'a F,
    indices: impl Iterator<Item = u128> + 'a,
    subtable: impl Fn(&'a Self, u32, u128, usize, u128, &'a F) -> I + 'a,
    filter: impl Fn(u128) -> bool + 'a,
) -> Box<dyn Iterator<Item = u32> + 'a>
where
    F: Fn(u32) -> u128,
    I: Iterator<Item = u32>,
{
    match &self.internals[bucket] {
        Internal::Vec(v) => Box::new(
            v.iter()
                .cloned()
                .filter(move |&leaf| (lookup(leaf) ^ feature).count_ones() == radius),
        ) as Box<dyn Iterator<Item = u32> + 'a>,
        Internal::Map(m) => {
            if m.len() < TAU {
                // Map values are arena indices of child nodes, not items, so
                // each surviving child has to be searched; yielding the value
                // itself would leak node indices to the caller as results.
                Box::new(
                    m.iter()
                        .filter(move |&(&key, _)| filter(key))
                        .flat_map(move |(&tc, &node)| {
                            subtable(self, radius, feature, node as usize, tc, lookup)
                        }),
                ) as Box<dyn Iterator<Item = u32> + 'a>
            } else {
                Box::new(indices.flat_map(move |tc| {
                    if let Some(&occupied_node) = m.get(&tc) {
                        let subbucket = occupied_node as usize;
                        either::Right(subtable(self, radius, feature, subbucket, tc, lookup))
                    } else {
                        either::Left(None.into_iter())
                    }
                })) as Box<dyn Iterator<Item = u32> + 'a>
            }
        }
    }
}
/// Returns the items whose feature is within Hamming distance `radius` of
/// `feature` (inclusive), resolving stored ids through `lookup`.
///
/// The search starts at the root (bucket 0), whose child keys are total
/// weights (`indices128(..)[0]`). Only children whose weight lies within
/// `radius` of the query's weight are visited, via `radius2`.
pub fn search_radius<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    let sw = indices[0] as i32;
    // Weights only span 0..=128, so clamp before the signed cast; a huge
    // `radius` would otherwise wrap and corrupt the range bounds.
    let reach = min(radius, 128) as i32;
    let start = max(0, sw - reach) as u128;
    let end = min(128, sw + reach) as u128;
    self.bucket_scan_radius(
        radius,
        feature,
        0,
        lookup,
        start..=end,
        Self::radius2,
        // Root keys are plain weights, so the small-map filter accepts exactly
        // the keys the range above would probe.
        move |tc| start <= tc && tc <= end,
    )
}
/// Within-radius search step for a node one level below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits64`-typed (the lane width `search_radius2`
/// produces). Matching children are searched with `radius4`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn radius2<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_radius(
        radius,
        feature,
        bucket,
        lookup,
        search_radius2(Bits128(indices[0]), Bits64(indices[1]), Bits128(tp), radius)
            .map(|(tc, _sod)| tc.0),
        Self::radius4,
        // Keys of this node must be read with this node's lane width and
        // compared against the matching query index.
        move |tc| Bits64(tc).hwd(Bits64(indices[1])).sum_weight() as u32 <= radius,
    )
}
/// Within-radius search step for a node two levels below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits32`-typed (the lane width `search_radius4`
/// produces). Matching children are searched with `radius8`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn radius4<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_radius(
        radius,
        feature,
        bucket,
        lookup,
        search_radius4(Bits64(indices[1]), Bits32(indices[2]), Bits64(tp), radius)
            .map(|(tc, _sod)| tc.0),
        Self::radius8,
        // Filter keys with this node's own lane width and query index.
        move |tc| Bits32(tc).hwd(Bits32(indices[2])).sum_weight() as u32 <= radius,
    )
}
/// Within-radius search step for a node three levels below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits16`-typed (the lane width `search_radius8`
/// produces). Matching children are searched with `radius16`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn radius8<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_radius(
        radius,
        feature,
        bucket,
        lookup,
        search_radius8(Bits32(indices[2]), Bits16(indices[3]), Bits32(tp), radius)
            .map(|(tc, _sod)| tc.0),
        Self::radius16,
        // Filter keys with this node's own lane width and query index.
        move |tc| Bits16(tc).hwd(Bits16(indices[3])).sum_weight() as u32 <= radius,
    )
}
/// Within-radius search step for a node four levels below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits8`-typed (the lane width `search_radius16`
/// produces). Matching children are searched with `radius32`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn radius16<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_radius(
        radius,
        feature,
        bucket,
        lookup,
        search_radius16(Bits16(indices[3]), Bits8(indices[4]), Bits16(tp), radius)
            .map(|(tc, _sod)| tc.0),
        Self::radius32,
        // Filter keys with this node's own lane width and query index.
        move |tc| Bits8(tc).hwd(Bits8(indices[4])).sum_weight() as u32 <= radius,
    )
}
/// Within-radius search step for a node five levels below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits4`-typed (the lane width `search_radius32`
/// produces). Matching children are searched with `radius64`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn radius32<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_radius(
        radius,
        feature,
        bucket,
        lookup,
        search_radius32(Bits8(indices[4]), Bits4(indices[5]), Bits8(tp), radius)
            .map(|(tc, _sod)| tc.0),
        Self::radius64,
        // Filter keys with this node's own lane width and query index.
        move |tc| Bits4(tc).hwd(Bits4(indices[5])).sum_weight() as u32 <= radius,
    )
}
/// Within-radius search step for a node six levels below the root.
///
/// `bucket` is the node to scan and `tp` the key under which it was reached.
/// Its child keys are `Bits2`-typed (the lane width `search_radius64`
/// produces). Matching children are searched with `radius128`.
///
/// NOTE(review): assumes `hwd(..).sum_weight()` is the summed per-lane weight
/// difference, i.e. a lower bound on the Hamming distance — confirm against
/// the `swar` crate.
fn radius64<'a, F: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    tp: u128,
    lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
    F: Fn(u32) -> u128,
{
    let indices = indices128(feature);
    self.bucket_scan_radius(
        radius,
        feature,
        bucket,
        lookup,
        search_radius64(Bits4(indices[5]), Bits2(indices[6]), Bits4(tp), radius)
            .map(|(tc, _sod)| tc.0),
        Self::radius128,
        // Filter keys with this node's own lane width and query index.
        move |tc| Bits2(tc).hwd(Bits2(indices[6])).sum_weight() as u32 <= radius,
    )
}
fn radius128<'a, F: 'a>(
&'a self,
radius: u32,
feature: u128,
bucket: usize,
tp: u128,
lookup: &'a F,
) -> impl Iterator<Item = u32> + 'a
where
F: Fn(u32) -> u128,
{
let indices = indices128(feature);
self.bucket_scan_radius(
radius,
feature,
bucket,
lookup,
search_radius128(Bits2(indices[6]), Bits1(indices[7]), Bits2(tp), radius).map(|(tc, _sod)| tc.0),
|_, _, _, bucket, _, _| -> Box<dyn Iterator<Item = u32> + 'a> {
panic!(
"hwt::Hwt::neighbors128(): it is an error to find an internal node this far down in the tree (bucket: {})", bucket,
)
},
move |tc| panic!("hwt::Hwt::neighbors128(): it is an error to find an internal node this far down in the tree (tc: {})", tc)
)
}
/// Scans the node at arena index `bucket` for items whose feature is within
/// Hamming distance `radius` of `feature` (inclusive).
///
/// * Leaf bucket: every stored id is resolved through `lookup` and kept when
///   `(lookup(id) ^ feature).count_ones() <= radius`.
/// * Branch node with fewer than `TAU` children: every child key is tested
///   with `filter`, and each surviving child is searched through `subtable`.
/// * Larger branch node: each candidate key from `indices` is probed in the
///   map, and each child found is searched through `subtable`.
///
/// `subtable` receives `(self, radius, feature, child_bucket, child_key, lookup)`.
#[allow(clippy::too_many_arguments)]
fn bucket_scan_radius<'a, F: 'a, I: 'a>(
    &'a self,
    radius: u32,
    feature: u128,
    bucket: usize,
    lookup: &'a F,
    indices: impl Iterator<Item = u128> + 'a,
    subtable: impl Fn(&'a Self, u32, u128, usize, u128, &'a F) -> I + 'a,
    filter: impl Fn(u128) -> bool + 'a,
) -> Box<dyn Iterator<Item = u32> + 'a>
where
    F: Fn(u32) -> u128,
    I: Iterator<Item = u32>,
{
    match &self.internals[bucket] {
        Internal::Vec(v) => Box::new(
            v.iter()
                .cloned()
                .filter(move |&leaf| (lookup(leaf) ^ feature).count_ones() <= radius),
        ) as Box<dyn Iterator<Item = u32> + 'a>,
        Internal::Map(m) => {
            if m.len() < TAU {
                // Map values are arena indices of child nodes, not items, so
                // each surviving child has to be searched; yielding the value
                // itself would leak node indices to the caller as results.
                Box::new(
                    m.iter()
                        .filter(move |&(&key, _)| filter(key))
                        .flat_map(move |(&tc, &node)| {
                            subtable(self, radius, feature, node as usize, tc, lookup)
                        }),
                ) as Box<dyn Iterator<Item = u32> + 'a>
            } else {
                Box::new(indices.flat_map(move |tc| {
                    if let Some(&occupied_node) = m.get(&tc) {
                        let subbucket = occupied_node as usize;
                        either::Right(subtable(self, radius, feature, subbucket, tc, lookup))
                    } else {
                        either::Left(None.into_iter())
                    }
                })) as Box<dyn Iterator<Item = u32> + 'a>
            }
        }
    }
}
}
impl Default for Hwt {
fn default() -> Self {
Self {
internals: vec![Internal::default()],
count: 0,
}
}
}