assembly_theory/memoize.rs
//! Memoize assembly states to avoid redundant recursive search.

use std::sync::Arc;

use bit_set::BitSet;
use clap::ValueEnum;
use dashmap::DashMap;

use crate::{
    canonize::{canonize, CanonizeMode, Labeling},
    molecule::Molecule,
};

/// Strategy for memoizing assembly states in the search phase.
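///
/// # Examples
///
/// A parsing sketch, assuming clap's default kebab-case [`ValueEnum`] name
/// mapping (actual CLI wiring may differ):
///
/// ```ignore
/// use clap::ValueEnum;
///
/// let mode = MemoizeMode::from_str("canon-index", true).unwrap();
/// assert_eq!(mode, MemoizeMode::CanonIndex);
/// ```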
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
pub enum MemoizeMode {
    /// Do not use memoization.
    None,
    /// Cache states by fragments and store their assembly index upper bounds.
    FragsIndex,
    /// Like `FragsIndex`, but cache states by fragments' canonical labelings,
    /// allowing isomorphic assembly states to hash to the same value.
    CanonIndex,
}

/// Key type for the memoization cache.
#[derive(Clone, PartialEq, Eq, Hash)]
enum CacheKey {
    /// Use fragments as keys, as in [`MemoizeMode::FragsIndex`].
    Frags(Vec<BitSet>),
    /// Use fragments' canonical labelings as keys, as in
    /// [`MemoizeMode::CanonIndex`].
    Canon(Vec<Labeling>),
}

/// Memoization cache for assembly states.
#[derive(Clone)]
pub struct Cache {
    /// Memoization mode.
    memoize_mode: MemoizeMode,
    /// Canonization mode; only used with [`MemoizeMode::CanonIndex`].
    canonize_mode: CanonizeMode,
    /// A parallel-aware cache mapping keys (either fragments or canonical
    /// labelings, depending on the memoization mode) to their assembly index
    /// upper bounds and match removal orders.
    cache: Arc<DashMap<CacheKey, (usize, Vec<usize>)>>,
    /// A parallel-aware map from fragments to their canonical labelings; only
    /// used with [`MemoizeMode::CanonIndex`].
    fragment_labels: Arc<DashMap<BitSet, Labeling>>,
}

impl Cache {
    /// Construct a new [`Cache`] with the specified modes.
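    ///
    /// # Examples
    ///
    /// A minimal construction sketch; the `CanonizeMode` variant named here
    /// is a placeholder for whichever variant the crate actually defines:
    ///
    /// ```ignore
    /// use assembly_theory::canonize::CanonizeMode;
    /// use assembly_theory::memoize::{Cache, MemoizeMode};
    ///
    /// // Fragment-keyed memoization; the canonize mode is unused in this
    /// // mode but must still be supplied.
    /// let cache = Cache::new(MemoizeMode::FragsIndex, CanonizeMode::Nauty);
    /// ```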
    pub fn new(memoize_mode: MemoizeMode, canonize_mode: CanonizeMode) -> Self {
        Self {
            memoize_mode,
            canonize_mode,
            cache: Arc::new(DashMap::<CacheKey, (usize, Vec<usize>)>::new()),
            fragment_labels: Arc::new(DashMap::<BitSet, Labeling>::new()),
        }
    }

    /// Create a [`CacheKey`] for the given assembly state.
    ///
    /// If using [`MemoizeMode::FragsIndex`], keys are the lexicographically
    /// sorted fragment [`BitSet`]s. If using [`MemoizeMode::CanonIndex`], keys
    /// are lexicographically sorted fragment canonical labelings created using
    /// the specified [`CanonizeMode`]. These labelings are stored for reuse.
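    ///
    /// For example (a usage sketch with `cache` and `mol` assumed in scope),
    /// fragment order does not affect the key:
    ///
    /// ```ignore
    /// use bit_set::BitSet;
    ///
    /// let a = BitSet::from_iter([0, 1]);
    /// let b = BitSet::from_iter([2, 3]);
    /// // Both orderings sort to the same fragment sequence.
    /// assert!(cache.key(&mol, &[a.clone(), b.clone()]) == cache.key(&mol, &[b, a]));
    /// ```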
    fn key(&self, mol: &Molecule, state: &[BitSet]) -> Option<CacheKey> {
        match self.memoize_mode {
            MemoizeMode::None => None,
            MemoizeMode::FragsIndex => {
                // Sort fragments by their first set bits for a deterministic
                // key.
                let mut fragments = state.to_vec();
                fragments.sort_by_key(|a| a.iter().next());
                Some(CacheKey::Frags(fragments))
            }
            MemoizeMode::CanonIndex => {
                // Look up each fragment's canonical labeling; `or_insert_with`
                // ensures the labeling is computed only when absent rather
                // than on every lookup.
                let mut labelings: Vec<Labeling> = state
                    .iter()
                    .map(|fragment| {
                        self.fragment_labels
                            .entry(fragment.clone())
                            .or_insert_with(|| canonize(mol, fragment, self.canonize_mode))
                            .value()
                            .clone()
                    })
                    .collect();
                labelings.sort();
                Some(CacheKey::Canon(labelings))
            }
        }
    }

    /// Return `true` iff memoization is enabled and this assembly state is
    /// preempted by the cached assembly state.
    ///
    /// See <https://github.com/DaymudeLab/assembly-theory/pull/95> for
    /// details.
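    ///
    /// A minimal behavioral sketch (`cache`, `mol`, and `state` assumed in
    /// scope, with memoization enabled):
    ///
    /// ```ignore
    /// // First visit: nothing is cached, so this state is recorded and not
    /// // preempted.
    /// assert!(!cache.memoize_state(&mol, &state, 5, &vec![0, 2]));
    /// // An equivalent state with the same upper bound but a later removal
    /// // order is preempted by the cached entry.
    /// assert!(cache.memoize_state(&mol, &state, 5, &vec![1, 3]));
    /// ```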
    pub fn memoize_state(
        &self,
        mol: &Molecule,
        state: &[BitSet],
        state_index: usize,
        removal_order: &Vec<usize>,
    ) -> bool {
        // If memoization is enabled, get this assembly state's cache key.
        if let Some(cache_key) = self.key(mol, state) {
            // Do all of the following atomically: Access the cache entry. If
            // the cached entry has a worse index upper bound or later removal
            // order than this state, or if it does not exist, then cache this
            // state's values and return `false`. Otherwise, the cached entry
            // preempts this assembly state, so return `true`.
            let (cached_index, cached_order) = self
                .cache
                .entry(cache_key)
                .and_modify(|val| {
                    if val.0 > state_index || val.1 > *removal_order {
                        val.0 = state_index;
                        val.1 = removal_order.clone();
                    }
                })
                .or_insert((state_index, removal_order.clone()))
                .value()
                .clone();
            if cached_index <= state_index && cached_order < *removal_order {
                return true;
            }
        }
        false
    }
}