use itertools::Itertools;
use swar::*;
/// Top-level radius search: combines two `search_radius64` searches.
///
/// `sp`, `sc` and `tp` are each split with `halve()`. The left halves are
/// searched with the whole `radius`; for every left result the right halves
/// are searched with whatever part of the radius the left result did not use.
///
/// Each yielded item is the `Bits1::union` of a left and a right choice paired
/// with the sum of their two distances.
///
/// NOTE(review): the names suggest `sp`/`sc` are a source parent/child and
/// `tp` a target parent — confirm against the callers. `bits` is only
/// forwarded to the next level.
pub fn search_radius128(
    bits: u32,
    sp: Bits2<u128>,
    sc: Bits1<u128>,
    tp: Bits2<u128>,
    radius: u32,
) -> impl Iterator<Item = (Bits1<u128>, u32)> {
    let (sp_left, sp_right) = sp.halve();
    let (sc_left, sc_right) = sc.halve();
    let (tp_left, tp_right) = tp.halve();

    let left_results = search_radius64(bits, sp_left, sc_left, tp_left, radius);
    let joined = left_results.flat_map(move |(left, left_sod)| {
        // The left side has already consumed `left_sod` of the radius. The
        // subtraction relies on the lower level never exceeding its radius.
        let remaining = radius - left_sod;
        let right_results = search_radius64(bits, sp_right, sc_right, tp_right, remaining);
        right_results.map(move |(right, right_sod)| {
            let merged = Bits1::union(left, right);
            (merged, left_sod + right_sod)
        })
    });
    Box::new(joined)
}
/// Radius search one level below `search_radius128`: combines two
/// `search_radius32` searches.
///
/// `sp`, `sc` and `tp` are each split with `halve()`. The left halves are
/// searched with the whole `radius`; for every left result the right halves
/// are searched with the part of the radius the left result left over.
///
/// Each yielded item is the `Bits2::union` of a left and a right choice paired
/// with the sum of their two distances. `bits` is only forwarded.
pub fn search_radius64(
    bits: u32,
    sp: Bits4<u128>,
    sc: Bits2<u128>,
    tp: Bits4<u128>,
    radius: u32,
) -> impl Iterator<Item = (Bits2<u128>, u32)> {
    let (sp_left, sp_right) = sp.halve();
    let (sc_left, sc_right) = sc.halve();
    let (tp_left, tp_right) = tp.halve();

    let left_results = search_radius32(bits, sp_left, sc_left, tp_left, radius);
    let joined = left_results.flat_map(move |(left, left_sod)| {
        // Only the unused part of the radius is offered to the right side;
        // this relies on `left_sod` never exceeding `radius`.
        let remaining = radius - left_sod;
        let right_results = search_radius32(bits, sp_right, sc_right, tp_right, remaining);
        right_results.map(move |(right, right_sod)| {
            let merged = Bits2::union(left, right);
            (merged, left_sod + right_sod)
        })
    });
    Box::new(joined)
}
/// Radius search that combines two `search_radius16` searches.
///
/// `sp`, `sc` and `tp` are each split with `halve()`. The left halves are
/// searched with the whole `radius`; for every left result the right halves
/// are searched with the part of the radius the left result left over.
///
/// Each yielded item is the `Bits4::union` of a left and a right choice paired
/// with the sum of their two distances. `bits` is only forwarded.
pub fn search_radius32(
    bits: u32,
    sp: Bits8<u128>,
    sc: Bits4<u128>,
    tp: Bits8<u128>,
    radius: u32,
) -> impl Iterator<Item = (Bits4<u128>, u32)> {
    let (sp_left, sp_right) = sp.halve();
    let (sc_left, sc_right) = sc.halve();
    let (tp_left, tp_right) = tp.halve();

    let left_results = search_radius16(bits, sp_left, sc_left, tp_left, radius);
    let joined = left_results.flat_map(move |(left, left_sod)| {
        // Only the unused part of the radius is offered to the right side;
        // this relies on `left_sod` never exceeding `radius`.
        let remaining = radius - left_sod;
        let right_results = search_radius16(bits, sp_right, sc_right, tp_right, remaining);
        right_results.map(move |(right, right_sod)| {
            let merged = Bits4::union(left, right);
            (merged, left_sod + right_sod)
        })
    });
    Box::new(joined)
}
/// Radius search that combines two `search_radius8` searches.
///
/// `sp`, `sc` and `tp` are each split with `halve()`. The left halves are
/// searched with the whole `radius`; for every left result the right halves
/// are searched with the part of the radius the left result left over.
///
/// Each yielded item is the `Bits8::union` of a left and a right choice paired
/// with the sum of their two distances. `bits` is only forwarded. Unlike some
/// of the sibling levels, the resulting iterator is returned without boxing.
pub fn search_radius16(
    bits: u32,
    sp: Bits16<u128>,
    sc: Bits8<u128>,
    tp: Bits16<u128>,
    radius: u32,
) -> impl Iterator<Item = (Bits8<u128>, u32)> {
    let (sp_left, sp_right) = sp.halve();
    let (sc_left, sc_right) = sc.halve();
    let (tp_left, tp_right) = tp.halve();

    let left_results = search_radius8(bits, sp_left, sc_left, tp_left, radius);
    left_results.flat_map(move |(left, left_sod)| {
        // Only the unused part of the radius is offered to the right side;
        // this relies on `left_sod` never exceeding `radius`.
        let remaining = radius - left_sod;
        let right_results = search_radius8(bits, sp_right, sc_right, tp_right, remaining);
        right_results.map(move |(right, right_sod)| {
            let merged = Bits8::union(left, right);
            (merged, left_sod + right_sod)
        })
    })
}
/// Radius search that combines two `search_radius4` searches.
///
/// `sp`, `sc` and `tp` are each split with `halve()`. The left halves are
/// searched with the whole `radius`; for every left result the right halves
/// are searched with the part of the radius the left result left over.
///
/// Each yielded item is the `Bits16::union` of a left and a right choice
/// paired with the sum of their two distances. `bits` is only forwarded.
pub fn search_radius8(
    bits: u32,
    sp: Bits32<u128>,
    sc: Bits16<u128>,
    tp: Bits32<u128>,
    radius: u32,
) -> impl Iterator<Item = (Bits16<u128>, u32)> {
    let (sp_left, sp_right) = sp.halve();
    let (sc_left, sc_right) = sc.halve();
    let (tp_left, tp_right) = tp.halve();

    let left_results = search_radius4(bits, sp_left, sc_left, tp_left, radius);
    let joined = left_results.flat_map(move |(left, left_sod)| {
        // Only the unused part of the radius is offered to the right side;
        // this relies on `left_sod` never exceeding `radius`.
        let remaining = radius - left_sod;
        let right_results = search_radius4(bits, sp_right, sc_right, tp_right, remaining);
        right_results.map(move |(right, right_sod)| {
            let merged = Bits16::union(left, right);
            (merged, left_sod + right_sod)
        })
    });
    Box::new(joined)
}
/// Radius search that combines two `search_radius2` searches.
///
/// `sp`, `sc` and `tp` are each split with `halve()`. The left halves are
/// searched with the whole `radius`; for every left result the right halves
/// are searched with the part of the radius the left result left over.
///
/// Each yielded item is the `Bits32::union` of a left and a right choice
/// paired with the sum of their two distances. `bits` is only forwarded. The
/// resulting iterator is returned without boxing.
pub fn search_radius4(
    bits: u32,
    sp: Bits64<u128>,
    sc: Bits32<u128>,
    tp: Bits64<u128>,
    radius: u32,
) -> impl Iterator<Item = (Bits32<u128>, u32)> {
    let (sp_left, sp_right) = sp.halve();
    let (sc_left, sc_right) = sc.halve();
    let (tp_left, tp_right) = tp.halve();

    let left_results = search_radius2(bits, sp_left, sc_left, tp_left, radius);
    left_results.flat_map(move |(left, left_sod)| {
        // Only the unused part of the radius is offered to the right side;
        // this relies on `left_sod` never exceeding `radius`.
        let remaining = radius - left_sod;
        let right_results = search_radius2(bits, sp_right, sc_right, tp_right, remaining);
        right_results.map(move |(right, right_sod)| {
            let merged = Bits32::union(left, right);
            (merged, left_sod + right_sod)
        })
    })
}
/// Base level of the lane-wise search: turns the inputs into plain counts,
/// delegates to `search_radius`, and re-encodes each result as a `Bits64`.
///
/// The counts passed on are the number of set bits in `sp`, in `sc` shifted
/// right by 64, and in `tp`. For every `[tl, tr]` pair that `search_radius`
/// yields, the returned `Bits64` wraps an integer whose part above bit 64 has
/// its lowest `tl` bits set and whose part below bit 64 has its lowest `tr`
/// bits set. The distance is passed through unchanged.
///
/// NOTE(review): `1 << tl` / `1 << tr` require both counts to stay below the
/// width of the wrapped integer — presumably guaranteed by `bits`; confirm.
pub fn search_radius2(
    bits: u32,
    sp: Bits128<u128>,
    sc: Bits64<u128>,
    tp: Bits128<u128>,
    radius: u32,
) -> impl Iterator<Item = (Bits64<u128>, u32)> {
    let source_total = sp.count_ones();
    let source_left = (sc >> 64).count_ones();
    let target_total = tp.count_ones();

    search_radius(bits, source_left, source_total, target_total, radius).map(|(split, sod)| {
        let [tl, tr] = split;
        // Runs of `tl` and `tr` low ones, placed in the upper and lower 64 bits.
        let upper = ((1 << tl) - 1) << 64;
        let lower = (1 << tr) - 1;
        (Bits64(upper | lower), sod)
    })
}
pub fn search_radius(
bits: u32,
sl: u32,
sw: u32,
tw: u32,
radius: u32,
) -> impl Iterator<Item = ([u32; 2], u32)> {
let max = std::cmp::min(tw, bits);
let min = tw - max;
let sl = sl as i32;
let sw = sw as i32;
let tw = tw as i32;
let radius = radius as i32;
let c = 2 * sl - sw + tw;
let filter = move |&tl: &i32| tl >= min as i32 && tl <= max as i32;
let map = move |tl: i32| {
(
[tl as u32, (tw - tl) as u32],
((tl - sl).abs() + ((tw - tl) - (sw - sl)).abs()) as u32,
)
};
let bottom_distance = (tw - sw).abs();
if bottom_distance <= radius {
let start = (-radius + c + 1) / 2;
let inflection1 = sl;
let inflection2 = sl - sw + tw;
let min_inflection = std::cmp::min(inflection1, inflection2);
let max_inflection = std::cmp::max(inflection1, inflection2);
let end = (radius + c) / 2;
let down = start..min_inflection;
let flat = min_inflection..=max_inflection;
let up = max_inflection + 1..=end;
flat.chain(down.interleave(up)).filter(filter).map(map)
} else {
let down = 0..0;
let flat = 0..=-1;
let up = 0..=-1;
flat.chain(down.interleave(up)).filter(filter).map(map)
}
}