use itertools::Itertools;
/// Returns the number of buckets addressed by `search128` for the target weights `tws`.
///
/// `search128` searches 64 pairs with one bit per half. A pair's target weight is
/// split between its two halves, and the number of possible splits is
/// `min(tw, 1) - (tw - min(tw, 1)) + 1`:
///
/// * weight 0 has one split (`0 + 0`);
/// * weight 1 has two splits (`0 + 1` or `1 + 0`);
/// * weight 2 has one split (`1 + 1`).
///
/// The bucket count is therefore `2^(number of pairs with weight 1)`.
///
/// # Panics
/// Panics if 32 or more pairs have weight 1.
pub fn compute_bucket_len(tws: [u32; 64]) -> usize {
    // Only weight-1 pairs double the bucket count. The previous `(tw & 1) ^ (tw >> 1)`
    // also counted weight 2 (binary `10`), which has exactly one split.
    let split_pairs: u32 = tws.iter().map(|&tw| u32::from(tw == 1)).sum();
    assert!(split_pairs < 32);
    1 << split_pairs
}
/// Enumerates bucket indices for a 128-bit `feature` made of 64 pairs of single bits.
///
/// `tws[i]` is the target weight of pair `i`, with pairs packed from the least
/// significant bit upward (two bits per pair). The low and high 64 bits are each
/// searched with [`search64`] using one bit per half. The high side only receives
/// the radius budget that the low side left over.
///
/// Each yielded value is `high_index * low_bucket_size + low_index`.
pub fn search128(feature: u128, tws: [u32; 64], radius: u32) -> impl Iterator<Item = usize> {
    const HALF_WIDTH: u32 = 64;
    let low_feature = feature & ((1u128 << HALF_WIDTH) - 1);
    let high_feature = feature >> HALF_WIDTH;

    // Pairs 0..32 belong to the low 64 bits, pairs 32..64 to the high 64 bits.
    let mut low_tws = [0u32; 32];
    low_tws.copy_from_slice(&tws[..32]);
    let mut high_tws = [0u32; 32];
    high_tws.copy_from_slice(&tws[32..]);

    let low_hits = search64(1, low_feature, low_tws, radius);
    low_hits.flat_map(move |(low_index, low_sod, low_bucket_size, _)| {
        search64(1, high_feature, high_tws, radius - low_sod)
            .map(move |(high_index, ..)| high_index * low_bucket_size + low_index)
    })
}
pub fn search64(
bits: u32,
feature: u128,
tws: [u32; 32],
radius: u32,
) -> impl Iterator<Item = (usize, u32, usize, [u32; 64])> {
const NPAIRS: u32 = 32;
let mask = (1u128 << (bits * NPAIRS)) - 1;
let substrings = [feature & mask, feature >> (NPAIRS * bits)];
let low_indices = search32(
bits,
substrings[0],
[
tws[0], tws[1], tws[2], tws[3], tws[4], tws[5], tws[6], tws[7], tws[8], tws[9],
tws[10], tws[11], tws[12], tws[13], tws[14], tws[15],
],
radius,
);
low_indices.flat_map(move |(low_index, low_sod, low_bucket_size, low_tws)| {
let high_indices = search32(
bits,
substrings[1],
[
tws[16], tws[17], tws[18], tws[19], tws[20], tws[21], tws[22], tws[23], tws[24],
tws[25], tws[26], tws[27], tws[28], tws[29], tws[30], tws[31],
],
radius - low_sod,
);
high_indices.map(move |(high_index, high_sod, high_bucket_size, high_tws)| {
(
high_index * low_bucket_size + low_index,
low_sod + high_sod,
low_bucket_size * high_bucket_size,
[
low_tws[0],
low_tws[1],
low_tws[2],
low_tws[3],
low_tws[4],
low_tws[5],
low_tws[6],
low_tws[7],
low_tws[8],
low_tws[9],
low_tws[10],
low_tws[11],
low_tws[12],
low_tws[13],
low_tws[14],
low_tws[15],
low_tws[16],
low_tws[17],
low_tws[18],
low_tws[19],
low_tws[20],
low_tws[21],
low_tws[22],
low_tws[23],
low_tws[24],
low_tws[25],
low_tws[26],
low_tws[27],
low_tws[28],
low_tws[29],
low_tws[30],
low_tws[31],
high_tws[0],
high_tws[1],
high_tws[2],
high_tws[3],
high_tws[4],
high_tws[5],
high_tws[6],
high_tws[7],
high_tws[8],
high_tws[9],
high_tws[10],
high_tws[11],
high_tws[12],
high_tws[13],
high_tws[14],
high_tws[15],
high_tws[16],
high_tws[17],
high_tws[18],
high_tws[19],
high_tws[20],
high_tws[21],
high_tws[22],
high_tws[23],
high_tws[24],
high_tws[25],
high_tws[26],
high_tws[27],
high_tws[28],
high_tws[29],
high_tws[30],
high_tws[31],
],
)
})
})
}
pub fn search32(
bits: u32,
feature: u128,
tws: [u32; 16],
radius: u32,
) -> impl Iterator<Item = (usize, u32, usize, [u32; 32])> {
const NPAIRS: u32 = 16;
let mask = (1u128 << (bits * NPAIRS)) - 1;
let substrings = [feature & mask, feature >> (NPAIRS * bits)];
let low_indices = search16(
bits,
substrings[0],
[
tws[0], tws[1], tws[2], tws[3], tws[4], tws[5], tws[6], tws[7],
],
radius,
);
low_indices.flat_map(move |(low_index, low_sod, low_bucket_size, low_tws)| {
let high_indices = search16(
bits,
substrings[1],
[
tws[8], tws[9], tws[10], tws[11], tws[12], tws[13], tws[14], tws[15],
],
radius - low_sod,
);
high_indices.map(move |(high_index, high_sod, high_bucket_size, high_tws)| {
(
high_index * low_bucket_size + low_index,
low_sod + high_sod,
low_bucket_size * high_bucket_size,
[
low_tws[0],
low_tws[1],
low_tws[2],
low_tws[3],
low_tws[4],
low_tws[5],
low_tws[6],
low_tws[7],
low_tws[8],
low_tws[9],
low_tws[10],
low_tws[11],
low_tws[12],
low_tws[13],
low_tws[14],
low_tws[15],
high_tws[0],
high_tws[1],
high_tws[2],
high_tws[3],
high_tws[4],
high_tws[5],
high_tws[6],
high_tws[7],
high_tws[8],
high_tws[9],
high_tws[10],
high_tws[11],
high_tws[12],
high_tws[13],
high_tws[14],
high_tws[15],
],
)
})
})
}
pub fn search16(
bits: u32,
feature: u128,
tws: [u32; 8],
radius: u32,
) -> impl Iterator<Item = (usize, u32, usize, [u32; 16])> {
const NPAIRS: u32 = 8;
let mask = (1u128 << (bits * NPAIRS)) - 1;
let substrings = [feature & mask, feature >> (NPAIRS * bits)];
let low_indices = search8(
bits,
substrings[0],
[tws[0], tws[1], tws[2], tws[3]],
radius,
);
low_indices.flat_map(move |(low_index, low_sod, low_bucket_size, low_tws)| {
let high_indices = search8(
bits,
substrings[1],
[tws[4], tws[5], tws[6], tws[7]],
radius - low_sod,
);
high_indices.map(move |(high_index, high_sod, high_bucket_size, high_tws)| {
(
high_index * low_bucket_size + low_index,
low_sod + high_sod,
low_bucket_size * high_bucket_size,
[
low_tws[0],
low_tws[1],
low_tws[2],
low_tws[3],
low_tws[4],
low_tws[5],
low_tws[6],
low_tws[7],
high_tws[0],
high_tws[1],
high_tws[2],
high_tws[3],
high_tws[4],
high_tws[5],
high_tws[6],
high_tws[7],
],
)
})
})
}
pub fn search8(
bits: u32,
feature: u128,
tws: [u32; 4],
radius: u32,
) -> impl Iterator<Item = (usize, u32, usize, [u32; 8])> {
const NPAIRS: u32 = 4;
let mask = (1u128 << (bits * NPAIRS)) - 1;
let substrings = [feature & mask, feature >> (NPAIRS * bits)];
let low_indices = search4(bits, substrings[0], [tws[0], tws[1]], radius);
low_indices.flat_map(move |(low_index, low_sod, low_bucket_size, low_tws)| {
let high_indices = search4(bits, substrings[1], [tws[2], tws[3]], radius - low_sod);
high_indices.map(move |(high_index, high_sod, high_bucket_size, high_tws)| {
(
high_index * low_bucket_size + low_index,
low_sod + high_sod,
low_bucket_size * high_bucket_size,
[
low_tws[0],
low_tws[1],
low_tws[2],
low_tws[3],
high_tws[0],
high_tws[1],
high_tws[2],
high_tws[3],
],
)
})
})
}
/// Enumerates weight assignments for 2 pairs, each made of two `bits`-wide halves.
///
/// Pair 0 occupies the low `2 * bits` bits of `feature` and has target weight
/// `tws[0]`. Pair 1 occupies the next `2 * bits` bits and has target weight
/// `tws[1]`. Each pair is searched with [`search2`]. Pair 1 only receives the radius
/// budget left by pair 0.
///
/// Yields `(index, sod, bucket_size, weights)`:
///
/// * `index` is `high_index * low_bucket_size + low_index`;
/// * `sod` is the summed sod;
/// * `bucket_size` is the product of the bucket sizes;
/// * `weights` holds the four half weights (pair 0's two, then pair 1's two).
pub fn search4(
    bits: u32,
    feature: u128,
    tws: [u32; 2],
    radius: u32,
) -> impl Iterator<Item = (usize, u32, usize, [u32; 4])> {
    let half_width = 2 * bits;
    let low_feature = feature & ((1u128 << half_width) - 1);
    let high_feature = feature >> half_width;
    let [low_tw, high_tw] = tws;

    let low_hits = search2(bits, low_feature, low_tw, radius);
    low_hits.flat_map(move |(low_index, low_sod, low_size, low_weights)| {
        search2(bits, high_feature, high_tw, radius - low_sod).map(
            move |(high_index, high_sod, high_size, high_weights)| {
                let [w0, w1] = low_weights;
                let [w2, w3] = high_weights;
                (
                    high_index * low_size + low_index,
                    low_sod + high_sod,
                    low_size * high_size,
                    [w0, w1, w2, w3],
                )
            },
        )
    })
}
/// Enumerates weight splits for a single pair of `bits`-wide halves.
///
/// `feature` holds the pair in its low `2 * bits` bits, low half first. Every set
/// bit of `feature` counts toward the pair's total popcount, so any higher bits
/// should be clear. `tw` is the pair's target weight, which [`search`] splits into a
/// weight for each half.
///
/// Yields `(index, sod, bucket_size, [low_weight, high_weight])`, where
/// `low_weight + high_weight == tw`. `index` is the offset of `high_weight` above
/// its smallest feasible value, and `bucket_size` is the number of feasible splits.
pub fn search2(
    bits: u32,
    feature: u128,
    tw: u32,
    radius: u32,
) -> impl Iterator<Item = (usize, u32, usize, [u32; 2])> {
    let total_ones = feature.count_ones();
    let high_half = feature >> bits;
    let high_ones = (high_half & ((1u128 << bits) - 1)).count_ones();
    // The high half must carry whatever exceeds the low half's `bits`-wide capacity.
    let high_floor = tw.saturating_sub(bits);
    let (candidates, bucket_size) = search(bits, high_ones, total_ones, tw, radius);
    let bucket_size = bucket_size as usize;
    candidates.map(move |(offset, sod)| {
        let high_weight = offset + high_floor;
        let low_weight = tw - high_weight;
        (offset as usize, sod, bucket_size, [low_weight, high_weight])
    })
}
pub fn search(
bits: u32,
sl: u32,
sw: u32,
tw: u32,
radius: u32,
) -> (impl Iterator<Item = (u32, u32)>, u32) {
let max = std::cmp::min(tw, bits);
let min = tw - max;
let sl = sl as i32;
let sw = sw as i32;
let tw = tw as i32;
let radius = radius as i32;
let c = 2 * sl - sw + tw;
let filter = move |&tl: &i32| tl >= min as i32 && tl <= max as i32;
let map = move |tl: i32| {
(
tl as u32 - min,
((tl - sl).abs() + ((tw - tl) - (sw - sl)).abs()) as u32,
)
};
if ((radius + c) / 2 - sl).abs() + (tw - (radius + c) / 2 - sw + sl).abs() <= radius {
let start = (-radius + c + 1) / 2;
let inflection1 = sl;
let inflection2 = sl - sw + tw;
let min_inflection = std::cmp::min(inflection1, inflection2);
let max_inflection = std::cmp::max(inflection1, inflection2);
let end = (radius + c) / 2;
let down = start..min_inflection;
let flat = min_inflection..=max_inflection;
let up = max_inflection + 1..=end;
(
flat.chain(down.interleave(up)).filter(filter).map(map),
max - min + 1,
)
} else {
let down = 0..0;
let flat = 0..=-1;
let up = 0..=-1;
(
flat.chain(down.interleave(up)).filter(filter).map(map),
max - min + 1,
)
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Runs [`search`] and returns its candidates ordered by index, plus the bucket size.
    fn search_sort(bits: u32, sl: u32, sw: u32, tw: u32, radius: u32) -> (Vec<(u32, u32)>, u32) {
        let (candidates, bucket_size) = search(bits, sl, sw, tw, radius);
        let mut sorted: Vec<(u32, u32)> = candidates.collect();
        sorted.sort_unstable_by_key(|&(index, _)| index);
        (sorted, bucket_size)
    }

    #[test]
    fn test_search_base() {
        // (bits, sl, sw, tw, radius, expected (index, sod) pairs, expected bucket size)
        type Case = (u32, u32, u32, u32, u32, &'static [(u32, u32)], u32);
        let cases: &[Case] = &[
            (64, 3, 5, 4, 1, &[(2, 1), (3, 1)], 5),
            (64, 58, 66, 40, 1, &[], 41),
            (64, 58, 66, 66, 1, &[(56, 0)], 63),
            (
                64,
                58,
                66,
                66,
                5,
                &[(54, 4), (55, 2), (56, 0), (57, 2), (58, 4)],
                63,
            ),
            (
                64,
                58,
                72,
                68,
                10,
                &[
                    (47, 10),
                    (48, 8),
                    (49, 6),
                    (50, 4),
                    (51, 4),
                    (52, 4),
                    (53, 4),
                    (54, 4),
                    (55, 6),
                    (56, 8),
                    (57, 10),
                ],
                61,
            ),
            (
                64,
                58,
                72,
                75,
                10,
                &[
                    (44, 9),
                    (45, 7),
                    (46, 5),
                    (47, 3),
                    (48, 3),
                    (49, 3),
                    (50, 3),
                    (51, 5),
                    (52, 7),
                    (53, 9),
                ],
                54,
            ),
            (
                64,
                58,
                72,
                76,
                10,
                &[
                    (43, 10),
                    (44, 8),
                    (45, 6),
                    (46, 4),
                    (47, 4),
                    (48, 4),
                    (49, 4),
                    (50, 4),
                    (51, 6),
                    (52, 8),
                ],
                53,
            ),
            (
                64,
                58,
                72,
                82,
                10,
                &[
                    (40, 10),
                    (41, 10),
                    (42, 10),
                    (43, 10),
                    (44, 10),
                    (45, 10),
                    (46, 10),
                ],
                47,
            ),
            (64, 58, 72, 83, 10, &[], 46),
            (64, 0, 2, 2, 0, &[(0, 0)], 3),
            (1, 1, 2, 2, 1, &[(0, 0)], 1),
        ];
        for &(bits, sl, sw, tw, radius, expected, expected_size) in cases {
            let (found, size) = search_sort(bits, sl, sw, tw, radius);
            assert_eq!(found.as_slice(), expected);
            assert_eq!(size, expected_size);
        }
    }

    /// Runs [`search2`] and returns its results ordered by index.
    fn search_sort2(
        bits: u32,
        feature: u128,
        tw: u32,
        radius: u32,
    ) -> Vec<(usize, u32, usize, [u32; 2])> {
        let mut hits: Vec<_> = search2(bits, feature, tw, radius).collect();
        hits.sort_unstable_by_key(|hit| hit.0);
        hits
    }

    #[test]
    fn test_search2() {
        type Hit = (usize, u32, usize, [u32; 2]);
        // (bits, feature, tw, radius, expected hits)
        type Case = (u32, u128, u32, u32, &'static [Hit]);
        let cases: &[Case] = &[
            (2, 0b101, 2, 1, &[(1, 0, 3, [1, 1])]),
            (
                2,
                0b101,
                2,
                2,
                &[(0, 2, 3, [2, 0]), (1, 0, 3, [1, 1]), (2, 2, 3, [0, 2])],
            ),
            (1, 0b0, 1, 1, &[(0, 1, 2, [1, 0]), (1, 1, 2, [0, 1])]),
            (1, 0b11, 1, 1, &[(0, 1, 2, [1, 0]), (1, 1, 2, [0, 1])]),
            (1, 0b11, 2, 1, &[(0, 0, 1, [1, 1])]),
        ];
        for &(bits, feature, tw, radius, expected) in cases {
            assert_eq!(search_sort2(bits, feature, tw, radius).as_slice(), expected);
        }
    }

    /// Runs [`search128`] and returns the bucket indices in ascending order.
    fn search_sort128(feature: u128, tws: [u32; 64], radius: u32) -> Vec<usize> {
        let mut found: Vec<usize> = search128(feature, tws, radius).collect();
        found.sort_unstable();
        found
    }

    #[test]
    fn test_search128() {
        let mut tws = [0u32; 64];
        tws[..2].copy_from_slice(&[1, 1]);
        assert_eq!(search_sort128(0b101, tws, 1), [0]);
    }
}