assembly_theory/memoize.rs

//! Memoize assembly states to avoid redundant recursive search.
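//!
//! A rough usage sketch (the bindings `canonize_mode`, `mol`, `state`,
//! `state_index`, and `removal_order` stand in for values the search code
//! elsewhere in the crate already has; this is illustrative, not the actual
//! driver):
//!
//! ```ignore
//! let cache = Cache::new(MemoizeMode::CanonIndex, canonize_mode);
//! // Inside the recursive search, before expanding an assembly state:
//! if cache.memoize_state(&mol, &state, state_index, &removal_order) {
//!     return; // an equivalent-or-better cached state preempts this branch
//! }
//! ```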

use std::sync::Arc;

use bit_set::BitSet;
use clap::ValueEnum;
use dashmap::DashMap;

use crate::{
    canonize::{canonize, CanonizeMode, Labeling},
    molecule::Molecule,
};

/// Strategy for memoizing assembly states in the search phase.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
pub enum MemoizeMode {
    /// Do not use memoization.
    None,
    /// Cache states by their fragments and store the states' assembly index
    /// upper bounds.
    FragsIndex,
    /// Like `FragsIndex`, but cache states by their fragments' canonical
    /// labelings, allowing isomorphic assembly states to hash to the same value.
    CanonIndex,
}

/// Key type for the memoization cache.
#[derive(Clone, PartialEq, Eq, Hash)]
enum CacheKey {
    /// Use fragments as keys, as in [`MemoizeMode::FragsIndex`].
    Frags(Vec<BitSet>),
    /// Use fragments' canonical labelings as keys, as in
    /// [`MemoizeMode::CanonIndex`].
    Canon(Vec<Labeling>),
}

/// A memoization cache for assembly states in the search phase.
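///
/// Both internal maps are wrapped in [`Arc`], so cloning a [`Cache`] is cheap
/// and all clones share the same underlying [`DashMap`]s, letting one cache be
/// shared across parallel branches of the search.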
#[derive(Clone)]
pub struct Cache {
    /// Memoization mode.
    memoize_mode: MemoizeMode,
    /// Canonization mode; only used with [`MemoizeMode::CanonIndex`].
    canonize_mode: CanonizeMode,
    /// A parallel-aware cache mapping keys (either fragments or canonical
    /// labelings, depending on the memoization mode) to their assembly index
    /// upper bounds and match removal order.
    cache: Arc<DashMap<CacheKey, (usize, Vec<usize>)>>,
    /// A parallel-aware map from fragments to their canonical labelings; only
    /// used with [`MemoizeMode::CanonIndex`].
    fragment_labels: Arc<DashMap<BitSet, Labeling>>,
}

impl Cache {
    /// Construct a new [`Cache`] with the specified modes.
    pub fn new(memoize_mode: MemoizeMode, canonize_mode: CanonizeMode) -> Self {
        Self {
            memoize_mode,
            canonize_mode,
            cache: Arc::new(DashMap::<CacheKey, (usize, Vec<usize>)>::new()),
            fragment_labels: Arc::new(DashMap::<BitSet, Labeling>::new()),
        }
    }

    /// Create a [`CacheKey`] for the given assembly state.
    ///
    /// If using [`MemoizeMode::FragsIndex`], keys are the lexicographically
    /// sorted fragment [`BitSet`]s. If using [`MemoizeMode::CanonIndex`], keys
    /// are lexicographically sorted fragment canonical labelings created using
    /// the specified [`CanonizeMode`]. These labelings are stored for reuse.
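    ///
    /// In both cases, sorting makes the key independent of the order in which
    /// fragments appear in the state, so states differing only in fragment
    /// order share a cache entry.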
    fn key(&self, mol: &Molecule, state: &[BitSet]) -> Option<CacheKey> {
        match self.memoize_mode {
            MemoizeMode::None => None,
            MemoizeMode::FragsIndex => {
                let mut fragments = state.to_vec();
                fragments.sort_by_key(|a| a.iter().next());
                Some(CacheKey::Frags(fragments))
            }
            MemoizeMode::CanonIndex => {
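                // Reuse each fragment's canonical labeling if one was computed
                // for an earlier state; otherwise canonize it now and cache it.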
                let mut labelings: Vec<Labeling> = state
                    .iter()
                    .map(|fragment| {
                        self.fragment_labels
                            .entry(fragment.clone())
                            .or_insert_with(|| canonize(mol, fragment, self.canonize_mode))
                            .value()
                            .clone()
                    })
                    .collect();
                labelings.sort();
                Some(CacheKey::Canon(labelings))
            }
        }
    }

    /// Return `true` iff memoization is enabled and this assembly state is
    /// preempted by the cached assembly state.
    /// See <https://github.com/DaymudeLab/assembly-theory/pull/95> for details.
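    ///
    /// For example, if the cache already holds an entry with index upper bound
    /// 5 and removal order `[0, 2]`, a new state with the same key, upper bound
    /// 6, and removal order `[0, 3]` is preempted and this returns `true`; a
    /// new state with upper bound 4 instead overwrites the cached entry and is
    /// not preempted.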
    pub fn memoize_state(
        &self,
        mol: &Molecule,
        state: &[BitSet],
        state_index: usize,
        removal_order: &Vec<usize>,
    ) -> bool {
        // If memoization is enabled, get this assembly state's cache key.
        if let Some(cache_key) = self.key(mol, state) {
            // Do all of the following atomically: Access the cache entry. If
            // the cached entry has a worse index upper bound or later removal
            // order than this state, or if it does not exist, then cache this
            // state's values and return `false`. Otherwise, the cached entry
            // preempts this assembly state, so return `true`.
            let (cached_index, cached_order) = self
                .cache
                .entry(cache_key)
                .and_modify(|val| {
                    if val.0 > state_index || val.1 > *removal_order {
                        val.0 = state_index;
                        val.1 = removal_order.clone();
                    }
                })
                .or_insert((state_index, removal_order.clone()))
                .value()
                .clone();
            if cached_index <= state_index && cached_order < *removal_order {
                return true;
            }
        }
        false
    }
}