use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
use crate::cyclotomic::IsRing;
use crate::geom::snake::Snake;
use crate::rat_enum::canonical::CanonicalOps;
use crate::rat_enum::dfs::{SeedGather, collect_seeds, hashset_recorder, rat_enum_step};
use crate::rat_enum::prune::Prunes;
use crate::rat_enum::stats::DfsStats;
/// Chooses how many DFS levels to expand sequentially before fanning out to threads.
///
/// Returns the smallest `depth >= 1` such that `branching.pow(depth) >= 10 * n_threads`.
/// That is the shallowest level at which a full `branching`-ary tree has at least
/// ten nodes per thread. Returns `0` (no split) when `n_threads <= 1` or
/// `branching <= 1`.
///
/// The result is computed with saturating integer arithmetic. It is exact (no
/// `ln`/`ceil` rounding that could overshoot by one level at exact powers) and it
/// cannot overflow, even for extreme `n_threads`.
pub fn splitting_depth(n_threads: usize, branching: usize) -> usize {
    if n_threads <= 1 || branching <= 1 {
        return 0;
    }
    let target = n_threads.saturating_mul(10);
    let mut depth = 1;
    let mut leaves = branching;
    // `leaves` saturates at usize::MAX >= target, so this always terminates.
    while leaves < target {
        leaves = leaves.saturating_mul(branching);
        depth += 1;
    }
    depth
}
/// Enumerates results in parallel by splitting the DFS tree into independent seeds.
///
/// The search runs in two phases:
/// 1. A single-threaded [`collect_seeds`] pass expands the tree down to
///    [`splitting_depth`] levels. It gathers the partial paths found there as seeds,
///    and any results reported through its `record_closed` recorder go into a set.
/// 2. [`parallel_drain_seeds`] resumes the DFS from every seed across `n_threads`
///    workers and merges their results with those of phase 1.
///
/// `n_threads == 0` is treated as `1`, so the collected seeds are always drained.
/// Progress lines are printed to stdout using `label` and `prefix`.
///
/// # Returns
/// The deduplicated results, sorted by length and then lexicographically, so that
/// the order is reproducible across runs. They are returned together with the
/// seed-phase statistics merged with every worker's statistics.
///
/// # Panics
/// Panics if a worker thread panics (propagated from [`parallel_drain_seeds`]).
#[allow(clippy::too_many_arguments)]
pub fn rat_enum_parallel<ZZ: IsRing>(
    max_steps: usize,
    step: i8,
    n_threads: usize,
    ops: CanonicalOps,
    label: &str,
    prefix: &str,
    paranoid: bool,
    prunes: &Prunes,
) -> (Vec<Vec<i8>>, DfsStats) {
    // With zero workers no seed would ever be drained; always run at least one.
    let n_threads = n_threads.max(1);

    // Fan-out estimate used only to size the sequential split:
    // 2 * ((hturn - 1) / step) + 1, with `step` clamped to >= 1.
    let hm1 = (ZZ::hturn() as usize).saturating_sub(1);
    let branching = 2 * (hm1 / step.max(1) as usize) + 1;
    let split_depth = splitting_depth(n_threads, branching);

    println!("-------- {label} started --------");
    if paranoid {
        println!("paranoid: per-step fresh-snake cross-check enabled");
    }
    println!("parallel: n_threads={n_threads} branching={branching} split_depth={split_depth}");

    // Phase 1: sequential seed collection.
    let mut closed_main: HashSet<Vec<i8>> = HashSet::new();
    let mut seeds: Vec<Vec<i8>> = Vec::new();
    let mut seed_stats = DfsStats::default();
    {
        // Scoped so the mutable borrows of `closed_main`, `seeds` and `seed_stats`
        // held by the recorder/gatherer end before phase 2 consumes them.
        let mut snake: Snake<ZZ> = Snake::new();
        let mut record_closed = hashset_recorder(&mut closed_main);
        let mut gather = SeedGather {
            seeds: &mut seeds,
            record_closed: &mut record_closed,
            stats: &mut seed_stats,
        };
        collect_seeds::<ZZ>(
            &mut snake,
            max_steps,
            step,
            split_depth,
            &mut gather,
            ops,
            paranoid,
            prunes,
        );
    }
    println!("parallel: {} seed states collected", seeds.len());

    // Phase 2: parallel drain of every seed.
    let (merged, worker_stats) = parallel_drain_seeds::<ZZ>(
        &seeds,
        closed_main,
        seed_stats,
        max_steps,
        step,
        n_threads,
        ops,
        paranoid,
        prunes,
    );
    println!(
        "-------- {label} completed --------\n{prefix}{} rats found",
        merged.len()
    );

    let mut result: Vec<Vec<i8>> = merged.into_iter().collect();
    // `merged` is a HashSet with randomized iteration order. Breaking length ties
    // lexicographically makes the output identical across runs.
    result.sort_unstable_by(|a, b| a.len().cmp(&b.len()).then_with(|| a.cmp(b)));
    (result, worker_stats)
}
/// Resumes the DFS from every seed in parallel and merges all results.
///
/// Each seed is turned into a snake via `Snake::from_slice_unsafe` and passed to
/// [`rat_enum_step`]. Workers claim seeds one at a time from a shared atomic
/// cursor, so a seed with an expensive subtree does not stall the rest of the
/// queue. Each worker records results into its own local set and [`DfsStats`].
/// Once every worker has been joined, these are folded into `closed_main` and
/// `seed_stats`, which are returned as the combined result.
///
/// The number of spawned workers is `n_threads` clamped to `1..=seeds.len()`.
/// No thread is spawned that could never claim a seed, and at least one is
/// spawned whenever there is work. Previously `n_threads == 0` dropped every seed.
///
/// # Panics
/// Panics with `"worker panic"` if any worker thread panics.
#[allow(clippy::too_many_arguments)]
pub fn parallel_drain_seeds<ZZ: IsRing>(
    seeds: &[Vec<i8>],
    closed_main: HashSet<Vec<i8>>,
    seed_stats: DfsStats,
    max_steps: usize,
    step: i8,
    n_threads: usize,
    ops: CanonicalOps,
    paranoid: bool,
    prunes: &Prunes,
) -> (HashSet<Vec<i8>>, DfsStats) {
    // At least one worker if there is any seed; never more workers than seeds.
    let workers = n_threads.max(1).min(seeds.len());
    let next_idx = AtomicUsize::new(0);
    let next_ref = &next_idx;
    thread::scope(|s| {
        let mut handles = Vec::with_capacity(workers);
        for _ in 0..workers {
            handles.push(s.spawn(move || -> (HashSet<Vec<i8>>, DfsStats) {
                let mut local: HashSet<Vec<i8>> = HashSet::new();
                let mut stats = DfsStats::default();
                loop {
                    // Relaxed is enough: the counter only hands out unique indices.
                    // `seeds` is read-only, and results are published via `join`.
                    let i = next_ref.fetch_add(1, Ordering::Relaxed);
                    let Some(seed) = seeds.get(i) else {
                        break;
                    };
                    let mut snake: Snake<ZZ> = Snake::from_slice_unsafe(seed);
                    let mut record = hashset_recorder(&mut local);
                    rat_enum_step::<ZZ>(
                        &mut snake,
                        max_steps,
                        step,
                        &mut record,
                        &mut stats,
                        ops,
                        paranoid,
                        prunes,
                    );
                }
                (local, stats)
            }));
        }

        let mut merged = closed_main;
        let mut total_stats = seed_stats;
        for h in handles {
            let (local, wstats) = h.join().expect("worker panic");
            merged.extend(local);
            total_stats.merge(&wstats);
        }
        (merged, total_stats)
    })
}