use std::{
sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
Arc,
},
time::Duration,
};
use bit_set::BitSet;
use clap::ValueEnum;
use rayon::iter::{ParallelBridge, ParallelIterator};
use tokio::{runtime::Runtime, sync::oneshot, time::timeout as tktimeout};
use crate::{
bounds::{state_bounds, Bound},
canonize::CanonizeMode,
kernels::KernelMode,
matches::Matches,
memoize::{Cache, MemoizeMode},
molecule::Molecule,
state::State,
utils::connected_components_under_edges,
};
/// How [`recurse_index_search`] spreads work over candidate matches.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, ValueEnum)]
pub enum ParallelMode {
    // (Plain comments on purpose: doc comments on `ValueEnum` variants become CLI help text.)
    // Every level of the search iterates its matches sequentially.
    None,
    // Only the top-level call fans out in parallel; its children run with `None`.
    DepthOne,
    // Every level of the search fans out in parallel.
    Always,
}
pub fn depth(mol: &Molecule) -> u32 {
let mut ix = u32::MAX;
for (left, right) in mol.partitions().unwrap() {
let l = if left.is_basic_unit() {
0
} else {
depth(&left)
};
let r = if right.is_basic_unit() {
0
} else {
depth(&right)
};
ix = ix.min(l.max(r) + 1)
}
ix
}
/// Builds the fragment list that results from removing the matched pair `(h1, h2)` from `state`.
///
/// The fragment(s) of `state` that contain `h1` and `h2` are replaced by the pieces that
/// `connected_components_under_edges` returns for whatever is left once the halves are cut out.
/// Fragments with `len() <= 1` are then dropped, and `h1` is appended as a fragment of its own.
///
/// Returns `None` when either half is not a subset of any fragment in `state`.
fn fragments(mol: &Molecule, state: &[BitSet], h1: &BitSet, h2: &BitSet) -> Option<Vec<BitSet>> {
    // Index of the first fragment that fully contains the given half.
    let owner_of = |half: &BitSet| state.iter().position(|frag| half.is_subset(frag));
    let first = owner_of(h1)?;
    let second = owner_of(h2)?;

    let mut next = state.to_vec();
    if first == second {
        // Both halves live in one fragment: cut out their union in a single step.
        let mut cut = h1.clone();
        cut.union_with(h2);
        let mut leftover = state[first].clone();
        leftover.difference_with(&cut);
        next.extend(connected_components_under_edges(mol.graph(), &leftover));
        next.swap_remove(first);
    } else {
        // Each half lives in its own fragment: cut each out of its owner, h1's first.
        for (owner, half) in [(first, h1), (second, h2)] {
            let mut leftover = state[owner].clone();
            leftover.difference_with(half);
            next.extend(connected_components_under_edges(mol.graph(), &leftover));
        }
        // Remove the higher index first so the lower index still points at its fragment.
        let (lower, higher) = (first.min(second), first.max(second));
        next.swap_remove(higher);
        next.swap_remove(lower);
    }

    next.retain(|frag| frag.len() > 1);
    next.push(h1.clone());
    Some(next)
}
/// Recursively searches assembly states reachable from `state` by removing matches.
///
/// Returns `(index, states_searched)`:
/// - `index` is the smallest index reported by any child, or `state.index()` if none is smaller.
/// - `states_searched` counts the states visited in this subtree, including `state` itself.
///
/// The shared `best_index` is lowered with every child result, so concurrent and later
/// branches compare against it. A state is not expanded when `state_bounds` rejects it or
/// `cache.memoize_state` reports it. Children get their own `cache.clone()`.
pub fn recurse_index_search(
    mol: &Molecule,
    matches: &Matches,
    state: &State,
    best_index: Arc<AtomicUsize>,
    bounds: &[Bound],
    cache: &mut Cache,
    parallel_mode: ParallelMode,
) -> (usize, usize) {
    // `||` short-circuits: the cache is only consulted when the bounds did not prune.
    let pruned = state_bounds(mol, state, best_index.load(Relaxed), bounds)
        || cache.memoize_state(mol, state);
    if pruned {
        return (state.index(), 1);
    }

    let (frags, candidates): (Vec<BitSet>, Vec<usize>) =
        matches.matches_to_remove(mol, state, best_index.load(Relaxed), bounds);

    // `DepthOne` only fans out at this level; the other modes pass through unchanged.
    let child_mode = match parallel_mode {
        ParallelMode::DepthOne => ParallelMode::None,
        ParallelMode::None | ParallelMode::Always => parallel_mode,
    };

    let subtree_best = AtomicUsize::from(state.index());
    let visited = AtomicUsize::from(1);

    let explore = |pos: usize, match_ix: usize| {
        let (h1, h2) = matches.match_fragments(match_ix);
        let Some(next_frags) = fragments(mol, &frags, h1, h2) else {
            return;
        };
        let child = state.update(next_frags, pos, match_ix, h1.len());
        let (child_index, child_visited) = recurse_index_search(
            mol,
            matches,
            &child,
            Arc::clone(&best_index),
            bounds,
            &mut cache.clone(),
            child_mode,
        );
        subtree_best.fetch_min(child_index, Relaxed);
        best_index.fetch_min(subtree_best.load(Relaxed), Relaxed);
        visited.fetch_add(child_visited, Relaxed);
    };

    let work = candidates.iter().copied().enumerate();
    match parallel_mode {
        ParallelMode::None => work.for_each(|(pos, match_ix)| explore(pos, match_ix)),
        ParallelMode::DepthOne | ParallelMode::Always => {
            work.par_bridge()
                .for_each(|(pos, match_ix)| explore(pos, match_ix))
        }
    }

    (subtree_best.load(Relaxed), visited.load(Relaxed))
}
pub fn index_search(
mol: &Molecule,
timeout: Option<u64>,
canonize_mode: CanonizeMode,
parallel_mode: ParallelMode,
memoize_mode: MemoizeMode,
kernel_mode: KernelMode,
bounds: &[Bound],
) -> (u32, u32, Option<usize>) {
if kernel_mode != KernelMode::None {
panic!("The chosen --kernel mode is not implemented yet!")
}
let state = State::new(mol);
let mut cache = Cache::new(memoize_mode, canonize_mode);
let matches = Matches::new(mol, canonize_mode);
let best_index = Arc::new(AtomicUsize::from(mol.graph().edge_count() - 1));
if let Some(timeout) = timeout {
let best_index_copy = best_index.clone();
let mol = mol.clone();
let bounds = bounds.to_vec();
let num_matches = matches.len();
let rt = Runtime::new().unwrap();
let result = rt.block_on(async {
let (send, recv) = oneshot::channel();
rayon::spawn(move || {
let _ = send.send(recurse_index_search(
&mol,
&matches,
&state,
best_index_copy,
&bounds,
&mut cache,
parallel_mode,
));
});
tktimeout(Duration::from_millis(timeout), recv).await
});
let (index, states_searched) = match result {
Ok(Ok((index, states_searched))) => (index, Some(states_searched)),
Err(_) => (best_index.load(Relaxed), None),
_ => panic!("An unexpected error occurred in async index_search"),
};
(index as u32, num_matches as u32, states_searched)
} else {
let (index, states_searched) = recurse_index_search(
mol,
&matches,
&state,
best_index,
bounds,
&mut cache,
parallel_mode,
);
(index as u32, matches.len() as u32, Some(states_searched))
}
}
/// Returns the assembly index of `mol` using the default search configuration.
///
/// The search runs with:
/// - no timeout;
/// - `CanonizeMode::TreeNauty`, `ParallelMode::DepthOne`, `MemoizeMode::CanonIndex` and
///   `KernelMode::None`;
/// - the bounds `Bound::Int` and `Bound::MatchableEdges`.
///
/// Only the index component of `index_search`'s result is returned.
pub fn index(mol: &Molecule) -> u32 {
    let default_bounds = [Bound::Int, Bound::MatchableEdges];
    let (assembly_index, _num_matches, _states_searched) = index_search(
        mol,
        None,
        CanonizeMode::TreeNauty,
        ParallelMode::DepthOne,
        MemoizeMode::CanonIndex,
        KernelMode::None,
        &default_bounds,
    );
    assembly_index
}