/// Largest number of residue cells (`m^phi`) a single modulus table may cover.
///
/// `ModularPrune::build` skips any candidate modulus whose cell count exceeds
/// this value.
pub const MODULAR_CELL_BUDGET: u64 = 65_536;
/// Per-modulus reachability tables used to prune displacements that provably
/// cannot be reached within a given number of steps.
#[derive(Debug, Clone)]
pub struct ModularPrune {
    /// Number of coordinates in each displacement / step vector.
    pub phi: usize,
    /// Moduli that passed the cell-budget check in `build`, in candidate order.
    pub moduli: Vec<i64>,
    /// `tables[i][r]` holds the packed residues (modulo `moduli[i]`) of every sum
    /// of at most `r` step vectors.
    pub tables: Vec<Vec<rustc_hash::FxHashSet<u64>>>,
}
impl ModularPrune {
    /// Builds per-modulus reachability tables for the given step vectors.
    ///
    /// Candidate moduli are `[2, 3, 4, 6]`, or `moduli_override` when it is
    /// `Some`. A candidate `m` is kept only if `m >= 2` and its residue-cell
    /// count `m^phi` fits in `u64` and is at most [`MODULAR_CELL_BUDGET`]. Other
    /// candidates are skipped without error. One table is built per kept modulus.
    ///
    /// * `units` - step vectors, added component-wise. Presumably each has `phi`
    ///   entries (TODO confirm at call sites).
    /// * `phi` - number of coordinates per vector.
    /// * `max_steps` - largest step count that gets its own table layer.
    /// * `moduli_override` - replaces the default candidate list when `Some`.
    pub fn build(
        units: &[Vec<i64>],
        phi: usize,
        max_steps: usize,
        moduli_override: Option<&[i64]>,
    ) -> Self {
        const DEFAULT_CANDIDATES: [i64; 4] = [2, 3, 4, 6];
        let candidates: &[i64] = moduli_override.unwrap_or(&DEFAULT_CANDIDATES);

        let moduli: Vec<i64> = candidates
            .iter()
            .copied()
            .filter(|&m| {
                m >= 2
                    && match Self::checked_cell_count(m, phi) {
                        Some(cells) => cells <= MODULAR_CELL_BUDGET,
                        // Too many cells to even count: certainly over budget.
                        None => false,
                    }
            })
            .collect();

        let tables: Vec<Vec<rustc_hash::FxHashSet<u64>>> = moduli
            .iter()
            .map(|&m| Self::build_one_modulus(units, phi, m, max_steps))
            .collect();

        ModularPrune {
            phi,
            moduli,
            tables,
        }
    }

    /// Returns `m^phi` when it fits in `u64`, otherwise `None`.
    ///
    /// `phi` values that do not fit in `u32` also yield `None`, rather than
    /// being truncated by an `as` cast to a small exponent.
    fn checked_cell_count(m: i64, phi: usize) -> Option<u64> {
        if phi > u32::MAX as usize {
            return None;
        }
        (m as u64).checked_pow(phi as u32)
    }

    /// Builds the layered reachability table for a single modulus `m`.
    ///
    /// The result has `max_steps + 1` entries. Entry `r` holds the packed
    /// residues (see `pack_coeffs`) of every component-wise sum of at most `r`
    /// (not necessarily distinct) vectors from `units`, reduced modulo `m`.
    /// Once every one of the `m^phi` residues is reachable, expansion stops and
    /// the remaining layers are copies of the full set.
    fn build_one_modulus(
        units: &[Vec<i64>],
        phi: usize,
        m: i64,
        max_steps: usize,
    ) -> Vec<rustc_hash::FxHashSet<u64>> {
        // Callers only pass moduli within budget, so this cannot overflow.
        let total = (m as u64).pow(phi as u32);

        // Reduce the step vectors once so every coordinate stays in [0, m).
        let units_mod: Vec<Vec<i64>> = units
            .iter()
            .map(|u| u.iter().map(|&c| c.rem_euclid(m)).collect())
            .collect();

        let mut layers: Vec<rustc_hash::FxHashSet<u64>> = Vec::with_capacity(max_steps + 1);

        // Residues reachable in at most r steps; zero steps reach only the origin.
        let mut cumulative: rustc_hash::FxHashSet<u64> = rustc_hash::FxHashSet::default();
        cumulative.insert(0u64);
        layers.push(cumulative.clone());

        // Residue vectors reachable in exactly r steps (the BFS frontier).
        let mut current_vec: rustc_hash::FxHashSet<Vec<i64>> = rustc_hash::FxHashSet::default();
        current_vec.insert(vec![0i64; phi]);

        let mut saturated = cumulative.len() as u64 == total;
        for _ in 1..=max_steps {
            if saturated {
                // Every residue is already reachable; later layers cannot grow.
                layers.push(cumulative.clone());
                continue;
            }

            let mut next: rustc_hash::FxHashSet<Vec<i64>> = rustc_hash::FxHashSet::default();
            // Cap the capacity hint at `total`. With `phi`-entry vectors there are
            // only that many distinct residues, so the raw product over-allocates.
            next.reserve(
                current_vec
                    .len()
                    .saturating_mul(units_mod.len())
                    .min(total as usize),
            );
            for v in &current_vec {
                for u in &units_mod {
                    // Both operands are already in [0, m), so the sum cannot overflow.
                    let sum: Vec<i64> = v
                        .iter()
                        .zip(u.iter())
                        .map(|(a, b)| (a + b).rem_euclid(m))
                        .collect();
                    next.insert(sum);
                }
            }

            cumulative.extend(next.iter().map(|v| pack_coeffs(v, m)));
            layers.push(cumulative.clone());
            current_vec = next;
            saturated = cumulative.len() as u64 == total;
        }
        layers
    }

    /// Tests whether `disp` could be the sum of at most `remaining` step vectors.
    ///
    /// Returns `false` as soon as some kept modulus shows that the residue of
    /// `disp` is not reachable in at most `remaining` steps. Otherwise returns
    /// `true`. A `true` result is only a necessary condition, never a guarantee.
    ///
    /// A modulus whose table has no layer for `remaining` (i.e. `remaining`
    /// exceeds the `max_steps` used in `build`) imposes no constraint.
    ///
    /// NOTE(review): `disp` is checked as given. Callers presumably pass the
    /// displacement still to be covered; confirm the sign convention at call sites.
    #[inline]
    pub fn allows_closure(&self, disp: &[i64], remaining: usize) -> bool {
        for (i, &m) in self.moduli.iter().enumerate() {
            let reachable = match self.tables[i].get(remaining) {
                Some(layer) => layer,
                None => continue,
            };
            if !reachable.contains(&pack_coeffs(disp, m)) {
                return false;
            }
        }
        true
    }

    /// Returns `m^phi` (the residue-cell count) for each kept modulus, in the
    /// same order as `moduli`.
    pub fn cell_counts(&self) -> Vec<u64> {
        self.moduli
            .iter()
            .map(|&m| (m as u64).pow(self.phi as u32))
            .collect()
    }
}
/// Packs the residues of `coeffs` modulo `m` into one integer key.
///
/// The key is a base-`m` number whose first coefficient is the least
/// significant digit. Each coefficient is first reduced with `rem_euclid`, so
/// negative inputs map into `[0, m)`.
#[inline]
fn pack_coeffs(coeffs: &[i64], m: i64) -> u64 {
    let base = m as u64;
    let (packed, _place) = coeffs.iter().fold((0u64, 1u64), |(acc, place), &coeff| {
        let digit = coeff.rem_euclid(m) as u64;
        // The place value may step past the last needed power, so it wraps.
        (acc + digit * place, place.wrapping_mul(base))
    });
    packed
}